# vim: set ft=c:

# Binding entry: creates a stream object. Takes an info handle as input and
# returns the newly created stream through the `stream` output parameter.
MPIX_Stream_create:
    .desc: Create a new stream object
    info: INFO, [info argument]
    stream: STREAM, direction=out, [stream object created]

# Binding entry: frees a stream object. `stream` is declared direction=inout,
# i.e. the handle is both consumed and written back by the call.
# NOTE(review): presumably the handle is reset to a null value on return -- confirm.
MPIX_Stream_free:
    .desc: Free a stream object
    stream: STREAM, direction=inout, [stream object]

# Binding entry: builds a new communicator from `comm` with a single local
# stream attached. The result is returned through `newcomm`.
MPIX_Stream_comm_create:
    .desc: Create a new communicator with local stream attached
    comm: COMMUNICATOR, [communicator]
    stream: STREAM, [stream object]
    newcomm: COMMUNICATOR, direction=out, [new stream-associated communicator]

# Binding entry: like the single-stream variant, but attaches several local
# streams. `array_of_streams` holds `count` stream handles (length=count);
# the result is returned through `newcomm`.
MPIX_Stream_comm_create_multiplex:
    .desc: Create a new communicator with multiple local streams attached
    comm: COMMUNICATOR, [communicator]
    count: ARRAY_LENGTH_NNI, [list length]
    array_of_streams: STREAM, length=count, [stream object array]
    newcomm: COMMUNICATOR, direction=out, [new stream-associated communicator]

# Binding entry: returns, through `stream`, a stream attached to `comm`.
# `idx` selects which attached stream to return.
# NOTE(review): presumably `idx` indexes the stream list supplied at
# communicator creation (0 for a single-stream communicator) -- confirm.
MPIX_Comm_get_stream:
    .desc: Get the stream object that is attached to the communicator
    comm: COMMUNICATOR, [communicator]
    idx: INDEX
    stream: STREAM, direction=out, [stream object]

# Binding entry: invokes progress for one stream, identified by `stream`.
# NOTE(review): `.impl: mpid` presumably makes the generated binding call the
# device-layer implementation directly -- confirm against the binding generator.
MPIX_Stream_progress:
    .desc: Invoke progress on the given stream
    .impl: mpid
    stream: STREAM, [stream object]

# Binding entry: starts a progress thread that polls progress on `stream`.
# Counterpart of MPIX_Stop_progress_thread.
MPIX_Start_progress_thread:
    .desc: Start a progress thread that will poll progress on the given stream
    stream: STREAM, [stream object]

# Binding entry: stops the progress thread that polls progress on `stream`.
# Counterpart of MPIX_Start_progress_thread.
MPIX_Stop_progress_thread:
    .desc: Stop the progress thread that polls progress on the given stream
    stream: STREAM, [stream object]

MPIX_Stream_send:
    .desc: Send message from a specific local stream to a specific destination stream
    buf: BUFFER, constant=True, [initial address of send buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
    datatype: DATATYPE, [datatype of each send buffer element]
    dest: RANK, [rank of destination]
    tag: TAG, [message tag]
    comm: COMMUNICATOR
    source_stream_index: INDEX
    dest_stream_index: INDEX
{
    /* Blocking stream send: encode the (source stream, destination stream)
     * pair into an attribute, pass the message to the device layer with that
     * attribute, then wait on the request if one was returned. */
    MPIR_Request *request_ptr = NULL;

    /* attr is produced from this rank (the sender), the destination rank and
     * the two stream indices */
    int attr;
    mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, comm_ptr->rank, dest,
                                          source_stream_index, dest_stream_index, &attr);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = MPID_Send(buf, count, datatype, dest, tag, comm_ptr, attr, &request_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* no request was returned, so there is nothing to wait on
     * (presumably the send already completed -- confirm) */
    if (request_ptr == NULL) {
        goto fn_exit;
    }

    /* the request is released before the error check so that it is freed
     * whether or not the wait reported an error */
    mpi_errno = MPIR_Wait(request_ptr, MPI_STATUS_IGNORE);
    MPIR_Request_free(request_ptr);
    MPIR_ERR_CHECK(mpi_errno);
}

MPIX_Stream_isend:
    .desc: Start a nonblocking send from a specific local stream to a specific remote stream
    buf: BUFFER, asynchronous=True, constant=True, [initial address of send buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
    datatype: DATATYPE, [datatype of each send buffer element]
    dest: RANK, [rank of destination]
    tag: TAG, [message tag]
    comm: COMMUNICATOR
    source_stream_index: INDEX
    dest_stream_index: INDEX
    request: REQUEST, direction=out
{
    /* Nonblocking stream send: encode the (source stream, destination stream)
     * pair into an attribute, start the send in the device layer with that
     * attribute, and return the resulting request handle to the caller. */
    MPIR_Request *request_ptr = NULL;

    /* attr is produced from this rank (the sender), the destination rank and
     * the two stream indices */
    int attr;
    mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, comm_ptr->rank, dest,
                                          source_stream_index, dest_stream_index, &attr);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = MPID_Isend(buf, count, datatype, dest, tag, comm_ptr,
                           attr, &request_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* record the pending send in the send queue
     * (presumably for message-queue debugging support -- confirm) */
    MPII_SENDQ_REMEMBER(request_ptr, dest, tag, comm_ptr->context_id, buf, count);

    /* return the handle of the request to the user */
    /* MPIU_OBJ_HANDLE_PUBLISH is unnecessary for isend, lower-level access is
     * responsible for its own consistency, while upper-level field access is
     * controlled by the completion counter */
    *request = request_ptr->handle;
}

MPIX_Stream_recv:
    .desc: Receive a message from a specific source stream to a specific local stream
    buf: BUFFER, direction=out, [initial address of receive buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
    datatype: DATATYPE, [datatype of each receive buffer element]
    source: RANK, [rank of source or MPI_ANY_SOURCE]
    tag: TAG, [message tag or MPI_ANY_TAG]
    comm: COMMUNICATOR
    source_stream_index: INDEX
    dest_stream_index: INDEX
    status: STATUS, direction=out
{
    /* Blocking stream receive: encode the (source stream, destination stream)
     * pair into an attribute, post the receive in the device layer with that
     * attribute, then wait on the request if one was returned. */
    MPIR_Request *request_ptr = NULL;

    /* attr is produced from the source rank, this rank (the receiver) and
     * the two stream indices */
    int attr;
    mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, source, comm_ptr->rank,
                                          source_stream_index, dest_stream_index, &attr);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = MPID_Recv(buf, count, datatype, source, tag, comm_ptr,
                          attr, status, &request_ptr);
    MPIR_ERR_CHECK(mpi_errno);

    /* no request was returned, so there is nothing to wait on */
    if (request_ptr == NULL) {
        goto fn_exit;
    }

    /* The receive is still pending. Hand the caller's status to the wait so
     * that it is filled in on completion; waiting with MPI_STATUS_IGNORE here
     * would leave the user's status object unset on this path. The request is
     * released before the error check so it is freed in either outcome. */
    mpi_errno = MPIR_Wait(request_ptr, status);
    MPIR_Request_free(request_ptr);
    MPIR_ERR_CHECK(mpi_errno);
}

MPIX_Stream_irecv:
    .desc: Start a nonblocking receive from a specific source stream to a specific local stream
    buf: BUFFER, direction=out, asynchronous=True, suppress=f08_intent, [initial address of receive buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
    datatype: DATATYPE, [datatype of each receive buffer element]
    source: RANK, [rank of source or MPI_ANY_SOURCE]
    tag: TAG, [message tag or MPI_ANY_TAG]
    comm: COMMUNICATOR
    source_stream_index: INDEX
    dest_stream_index: INDEX
    request: REQUEST, direction=out
{
    /* Nonblocking stream receive: encode the (source stream, destination
     * stream) pair into an attribute, post the receive in the device layer
     * with that attribute, and return the request handle to the caller. */
    MPIR_Request *request_ptr = NULL;

    /* attr is produced from the source rank, this rank (the receiver) and
     * the two stream indices */
    int attr;
    mpi_errno = MPIR_Stream_comm_set_attr(comm_ptr, source, comm_ptr->rank,
                                          source_stream_index, dest_stream_index, &attr);
    MPIR_ERR_CHECK(mpi_errno);

    mpi_errno = MPID_Irecv(buf, count, datatype, source, tag, comm_ptr,
                           attr, &request_ptr);

    /* Publish the handle before the error check, so a request that was
     * returned together with an error code still reaches the caller. Guard
     * the dereference: request_ptr starts out NULL and stays NULL if the
     * device call failed without creating a request. */
    if (request_ptr != NULL) {
        *request = request_ptr->handle;
    }
    MPIR_ERR_CHECK(mpi_errno);
}

# Binding entry: enqueues a send onto a GPU stream. Same buffer/count/datatype/
# dest/tag/comm arguments as a regular send; no stream argument is passed.
# NOTE(review): the GPU stream is presumably taken from the local stream
# attached to `comm` -- confirm.
# NOTE(review): `.impl: mpid` presumably routes the binding to the device-layer
# implementation and `.decl` presumably requests a prototype for the named
# MPIR_*_impl function -- confirm against the binding generator.
MPIX_Send_enqueue:
    .desc: Enqueue a send operation to a GPU stream that is associated with the local stream
    buf: BUFFER, constant=True, [initial address of send buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
    datatype: DATATYPE, [datatype of each send buffer element]
    dest: RANK, [rank of destination]
    tag: TAG, [message tag]
    comm: COMMUNICATOR
    .impl: mpid
    .decl: MPIR_Send_enqueue_impl

# Binding entry: enqueues a receive onto a GPU stream. Same arguments as a
# regular receive, including the `status` output; no stream argument is passed.
# NOTE(review): the GPU stream is presumably taken from the local stream
# attached to `comm` -- confirm.
# NOTE(review): `.impl: mpid` presumably routes the binding to the device-layer
# implementation and `.decl` presumably requests a prototype for the named
# MPIR_*_impl function -- confirm against the binding generator.
MPIX_Recv_enqueue:
    .desc: Enqueue a receive operation to a GPU stream that is associated with the local stream
    buf: BUFFER, direction=out, [initial address of receive buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
    datatype: DATATYPE, [datatype of each receive buffer element]
    source: RANK, [rank of source or MPI_ANY_SOURCE]
    tag: TAG, [message tag or MPI_ANY_TAG]
    comm: COMMUNICATOR
    status: STATUS, direction=out
    .impl: mpid
    .decl: MPIR_Recv_enqueue_impl

# Binding entry: enqueues a nonblocking send onto a GPU stream and returns a
# request through `request`. No stream argument is passed.
# NOTE(review): the GPU stream is presumably taken from the local stream
# attached to `comm` -- confirm.
# NOTE(review): `.impl: mpid` presumably routes the binding to the device-layer
# implementation and `.decl` presumably requests a prototype for the named
# MPIR_*_impl function -- confirm against the binding generator.
MPIX_Isend_enqueue:
    .desc: Enqueue a nonblocking send operation to a GPU stream that is associated with the local stream
    buf: BUFFER, constant=True, [initial address of send buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
    datatype: DATATYPE, [datatype of each send buffer element]
    dest: RANK, [rank of destination]
    tag: TAG, [message tag]
    comm: COMMUNICATOR
    request: REQUEST, direction=out
    .impl: mpid
    .decl: MPIR_Isend_enqueue_impl

# Binding entry: enqueues a nonblocking receive onto a GPU stream and returns a
# request through `request`. No stream argument is passed.
# NOTE(review): the GPU stream is presumably taken from the local stream
# attached to `comm` -- confirm.
# NOTE(review): `.impl: mpid` presumably routes the binding to the device-layer
# implementation and `.decl` presumably requests a prototype for the named
# MPIR_*_impl function -- confirm against the binding generator.
MPIX_Irecv_enqueue:
    .desc: Enqueue a nonblocking receive operation to a GPU stream that is associated with the local stream
    buf: BUFFER, direction=out, [initial address of receive buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in receive buffer]
    datatype: DATATYPE, [datatype of each receive buffer element]
    source: RANK, [rank of source or MPI_ANY_SOURCE]
    tag: TAG, [message tag or MPI_ANY_TAG]
    comm: COMMUNICATOR
    request: REQUEST, direction=out
    .impl: mpid
    .decl: MPIR_Irecv_enqueue_impl

# Binding entry: enqueues a wait on `request` onto a GPU stream. `request` is
# direction=inout and the completion status is returned through `status`.
# NOTE(review): `.impl: mpid` presumably routes the binding to the device-layer
# implementation and `.decl` presumably requests a prototype for the named
# MPIR_*_impl function -- confirm against the binding generator.
MPIX_Wait_enqueue:
    .desc: Enqueue a wait operation to a GPU stream that is associated with the local stream
    request: REQUEST, direction=inout, [request]
    status: STATUS, direction=out
    .impl: mpid
    .decl: MPIR_Wait_enqueue_impl

MPIX_Waitall_enqueue:
    .desc: Enqueue a waitall operation to a GPU stream that is associated with the local stream
    count: ARRAY_LENGTH_NNI, [lists length]
    array_of_requests: REQUEST, direction=inout, length=count, [array of requests]
    array_of_statuses: STATUS, direction=out, length=*, pointer=False, [array of status objects]
    .impl: mpid
    .decl: MPIR_Waitall_enqueue_impl
{ -- error_check -- array_of_statuses
    /* Custom argument check for array_of_statuses: the NULL test is applied
     * only when count > 0, so the pointer is not validated for an empty
     * request list. */
    if (count > 0) {
        MPIR_ERRTEST_ARGNULL(array_of_statuses, "array_of_statuses", mpi_errno);
    }
}

# Binding entry: enqueues an allreduce onto a GPU stream. Takes the usual
# allreduce arguments (send/receive buffers, count, datatype, op, comm); no
# stream argument is passed.
# NOTE(review): the GPU stream is presumably taken from the local stream
# attached to `comm` -- confirm.
MPIX_Allreduce_enqueue:
    .desc: Enqueue an allreduce operation to a GPU stream that is associated with the local stream
    sendbuf: BUFFER, constant=True, [starting address of send buffer]
    recvbuf: BUFFER, direction=out, [starting address of receive buffer]
    count: POLYXFER_NUM_ELEM_NNI, [number of elements in send buffer]
    datatype: DATATYPE, [data type of elements of send buffer]
    op: OPERATION, [operation]
    comm: COMMUNICATOR, [communicator]
