+/*
+ * hl_wait_multi_cs_completion_fini - return completion structure and set as
+ * unused
+ *
+ * @mcs_compl: pointer to the completion structure
+ */
+static void hl_wait_multi_cs_completion_fini(
+ struct multi_cs_completion *mcs_compl)
+{
+ /*
+ * free completion structure, do it under lock to be in-sync with the
+ * thread that signals completion
+ */
+ spin_lock(&mcs_compl->lock);
+ mcs_compl->used = 0;
+ spin_unlock(&mcs_compl->lock);
+}
+
+/*
+ * hl_wait_multi_cs_completion - wait for first CS to complete
+ *
+ * @mcs_data: multi-CS internal data
+ *
+ * @return 0 on success, otherwise non 0 error code
+ */
+static int hl_wait_multi_cs_completion(struct multi_cs_data *mcs_data)
+{
+ struct hl_device *hdev = mcs_data->ctx->hdev;
+ struct multi_cs_completion *mcs_compl;
+ long completion_rc;
+
+ mcs_compl = hl_wait_multi_cs_completion_init(hdev,
+ mcs_data->stream_master_qid_map);
+ if (IS_ERR(mcs_compl))
+ return PTR_ERR(mcs_compl);
+
+ completion_rc = wait_for_completion_interruptible_timeout(
+ &mcs_compl->completion,
+ usecs_to_jiffies(mcs_data->timeout_us));
+
+ /* update timestamp */
+ if (completion_rc > 0)
+ mcs_data->timestamp = mcs_compl->timestamp;
+
+ hl_wait_multi_cs_completion_fini(mcs_compl);
+
+ mcs_data->wait_status = completion_rc;
+
+ return 0;
+}
+
+/*
+ * hl_multi_cs_completion_init - init array of multi-CS completion structures
+ *
+ * @hdev: pointer to habanalabs device structure
+ */
+void hl_multi_cs_completion_init(struct hl_device *hdev)
+{
+ struct multi_cs_completion *mcs_cmpl;
+ int i;
+
+ for (i = 0; i < MULTI_CS_MAX_USER_CTX; i++) {
+ mcs_cmpl = &hdev->multi_cs_completion[i];
+ mcs_cmpl->used = 0;
+ spin_lock_init(&mcs_cmpl->lock);
+ init_completion(&mcs_cmpl->completion);
+ }
+}
+
+/*
+ * hl_multi_cs_wait_ioctl - implementation of the multi-CS wait ioctl
+ *
+ * @hpriv: pointer to the private data of the fd
+ * @data: pointer to multi-CS wait ioctl in/out args
+ *
+ */
+static int hl_multi_cs_wait_ioctl(struct hl_fpriv *hpriv, void *data)
+{
+ struct hl_device *hdev = hpriv->hdev;
+ struct multi_cs_data mcs_data = {0};
+ union hl_wait_cs_args *args = data;
+ struct hl_ctx *ctx = hpriv->ctx;
+ struct hl_fence **fence_arr;
+ void __user *seq_arr;
+ u32 size_to_copy;
+ u64 *cs_seq_arr;
+ u8 seq_arr_len;
+ int rc;
+
+ if (!hdev->supports_wait_for_multi_cs) {
+ dev_err(hdev->dev, "Wait for multi CS is not supported\n");
+ return -EPERM;
+ }
+
+ seq_arr_len = args->in.seq_arr_len;
+
+ if (seq_arr_len > HL_WAIT_MULTI_CS_LIST_MAX_LEN) {
+ dev_err(hdev->dev, "Can wait only up to %d CSs, input sequence is of length %u\n",
+ HL_WAIT_MULTI_CS_LIST_MAX_LEN, seq_arr_len);
+ return -EINVAL;
+ }
+
+ /* allocate memory for sequence array */
+ cs_seq_arr =
+ kmalloc_array(seq_arr_len, sizeof(*cs_seq_arr), GFP_KERNEL);
+ if (!cs_seq_arr)
+ return -ENOMEM;
+
+ /* copy CS sequence array from user */
+ seq_arr = (void __user *) (uintptr_t) args->in.seq;
+ size_to_copy = seq_arr_len * sizeof(*cs_seq_arr);
+ if (copy_from_user(cs_seq_arr, seq_arr, size_to_copy)) {
+ dev_err(hdev->dev, "Failed to copy multi-cs sequence array from user\n");
+ rc = -EFAULT;
+ goto free_seq_arr;
+ }
+
+ /* allocate array for the fences */
+ fence_arr = kmalloc_array(seq_arr_len, sizeof(*fence_arr), GFP_KERNEL);
+ if (!fence_arr) {
+ rc = -ENOMEM;
+ goto free_seq_arr;
+ }
+
+ /* initialize the multi-CS internal data */
+ mcs_data.ctx = ctx;
+ mcs_data.seq_arr = cs_seq_arr;
+ mcs_data.fence_arr = fence_arr;
+ mcs_data.arr_len = seq_arr_len;
+
+ hl_ctx_get(hdev, ctx);
+
+ /* poll all CS fences, extract timestamp */
+ mcs_data.update_ts = true;
+ rc = hl_cs_poll_fences(&mcs_data);
+ /*
+ * skip wait for CS completion when one of the below is true:
+ * - an error on the poll function
+ * - one or more CS in the list completed
+ * - the user called ioctl with timeout 0
+ */
+ if (rc || mcs_data.completion_bitmap || !args->in.timeout_us)
+ goto put_ctx;
+
+ /* wait (with timeout) for the first CS to be completed */
+ mcs_data.timeout_us = args->in.timeout_us;
+ rc = hl_wait_multi_cs_completion(&mcs_data);
+ if (rc)
+ goto put_ctx;
+
+ if (mcs_data.wait_status > 0) {
+ /*
+ * poll fences once again to update the CS map.
+ * no timestamp should be updated this time.
+ */
+ mcs_data.update_ts = false;
+ rc = hl_cs_poll_fences(&mcs_data);
+
+ /*
+ * if hl_wait_multi_cs_completion returned before timeout (i.e.
+ * it got a completion) we expect to see at least one CS
+ * completed after the poll function.
+ */
+ if (!mcs_data.completion_bitmap) {
+ dev_err(hdev->dev, "Multi-CS got completion on wait but no CS completed\n");
+ rc = -EFAULT;
+ }