# Source: https://github.com/JuliaParallel/MPI.jl
# Raw File
# Tip revision: 69ce0bf5c15e4a29cb4a4c3df3e8d02f3e9ba0ad authored by Simon Byrne on 23 January 2023, 18:49:26 UTC
# Tag v0.20.8 (#712)
# Tip revision: 69ce0bf
# File: test_shared_win.jl
using MPI: MPI, Comm, Win
using Test

"""
    mpi_shared_array(node_comm::Comm, ::Type{T}, sz::Tuple{Vararg{Int}}; owner_rank=0) where T

Create an array with element type `T` and dimensions `sz` in an MPI shared-memory window
on `node_comm`.

Only the process with `MPI.Comm_rank(node_comm) == owner_rank` allocates storage for the
array. Every other process takes part in the collective allocation with zero elements and
then maps the owner's segment with `MPI.Win_shared_query`, so all processes see the same
memory.

This is collective over `node_comm`: every process in it must call this function with the
same `T`, `sz` and `owner_rank`. It assumes all processes in `node_comm` are on the same
node or, more precisely, that they can create and access a shared-memory block between
them.

# Arguments
- `node_comm`: communicator whose processes share the array.
- `T`: element type of the array.
- `sz`: dimensions of the array.

# Keywords
- `owner_rank=0`: rank (within `node_comm`) of the process that allocates the storage.

# Returns
A tuple `(win, array)`. `array` views memory belonging to the window `win`, so it must not
be accessed after `MPI.free(win)`.

# Throws
- `ArgumentError` if `owner_rank` is not a valid rank of `node_comm`.

# Examples
```julia
nrows, ncols = 100, 11
win, arr = mpi_shared_array(MPI.COMM_WORLD, Int, (nrows, ncols); owner_rank=0)
```
"""
function mpi_shared_array(node_comm::Comm, ::Type{T}, sz::Tuple{Vararg{Int}}; owner_rank=0) where T
    node_rank = MPI.Comm_rank(node_comm)
    node_size = MPI.Comm_size(node_comm)
    # Checked before any collective call, and identically on every process, so a bad
    # `owner_rank` fails everywhere with a clear error instead of deep inside MPI.
    0 <= owner_rank < node_size || throw(ArgumentError(
        "owner_rank must be in 0:$(node_size - 1), got $owner_rank"))

    if node_rank == owner_rank
        # The owner allocates the whole block and can use its local array directly.
        win, array = MPI.Win_allocate_shared(Array{T}, sz, node_comm)
        return win, array
    end

    # Non-owners must still join the collective allocation, but they contribute zero
    # elements. They then map the owner's segment with the full shape `sz`.
    win, _ = MPI.Win_allocate_shared(Array{T}, 0, node_comm)
    array = MPI.Win_shared_query(Array{T}, sz, win, owner_rank)
    return win, array
end

# Build a shared window owned by rank 1 on a communicator that contains every rank of
# `world`. Rank 0 fills column 1 and rank 1 fills column 2; then check that every rank
# sees both columns and that the queried window layout matches the array.
function _check_shared_window(world)
    color = 1
    rank = MPI.Comm_rank(world) # do this differently in real code
    # Every process passes the same color, so the split keeps all ranks, ordered by `rank`.
    comm = MPI.Comm_split(world, color, rank)
    owner = 1

    win, arr = mpi_shared_array(comm, Float32, (100, 2); owner_rank=owner)

    if rank == 0
        arr[:, 1] .= 1:100
    elseif rank == 1
        arr[:, 2] .= 901:1000
    end

    MPI.Barrier(comm) # finish writing before reading

    # Both writers' columns must be visible on every rank.
    @test all(arr[:, 1] .== 1:100)
    @test all(arr[:, 2] .== 901:1000)

    if rank in (0, 1)
        nbytes, elsize, base = MPI.Win_shared_query(Ptr{Float32}, win, owner)
        @test elsize == sizeof(Float32)
        @test nbytes == sizeof(arr)
        @test base == pointer(arr)
    end

    MPI.free(win)
    return nothing
end

# Initialize MPI, run the shared-window check when more than one process is available,
# and finalize MPI. The check uses rank 1 as the owner, so it needs at least two processes.
function main()
    MPI.Init()

    world = MPI.COMM_WORLD
    MPI.Comm_size(world) > 1 && _check_shared_window(world)

    MPI.Finalize()
end

# Run with `mpirun -np 3 julia --project test_shared_win.jl`.
# With a single process, `main` only initializes and finalizes MPI (the check is skipped).
main()
# Collect garbage after MPI has been finalized. Presumably this checks that finalizers of
# MPI objects created in `main` (e.g. the split communicator, which is never freed
# explicitly) are safe to run once MPI is finalized — TODO confirm intent.
GC.gc()

@test MPI.Finalized()
back to top