#!/usr/bin/env bash

#  ------------------------------------------------------------------------------------------
#  Copyright (c) Microsoft Corporation. All rights reserved.
#  Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
#  ------------------------------------------------------------------------------------------

# Extend Julia's load path with the Batch task working directory so that code
# staged there can be found by `using` / `import`.
JULIA_LOAD_PATH="${JULIA_LOAD_PATH}:${AZ_BATCH_TASK_WORKING_DIR}"

# Make executables under the Julia depot's bin directory callable by name
# (presumably where `mpiexecjl` lives -- confirm against the pool setup).
PATH="${PATH}:${JULIA_DEPOT_PATH}/bin"

export JULIA_LOAD_PATH PATH

# Execute julia runtime
#######################################
# Launch batch_runtime.jl with the launcher selected by the environment.
# Globals (read):
#   MPI_RUN                    "0" -> plain julia, anything else -> mpiexecjl
#   INTER_NODE_CONNECTION      "0" -> single node, anything else -> multi node
#   NUM_NODES_PER_TASK         worker count (no MPI) / node count (MPI)
#   NUM_PROCS_PER_NODE         MPI ranks per node
#   AZ_BATCH_TASK_WORKING_DIR  location of batch_runtime.jl for single-node runs
#   AZ_BATCH_TASK_SHARED_DIR   location of batch_runtime.jl for multi-node runs
#   AZ_BATCH_HOST_LIST         value handed to `-host` for multi-node MPI runs
# Outputs:
#   One status line on stdout, followed by the launched program's output.
# Returns:
#   Exit status of the launched command, or 1 if a multi-node MPI run is
#   requested while AZ_BATCH_HOST_LIST is empty.
#######################################
run_batch_runtime() {
    local num_ranks

    if [[ "$MPI_RUN" == 0 ]]; then

        if [[ "$INTER_NODE_CONNECTION" == 0 ]]; then
            # Serial execution: one task per node
            echo "Serial execution on single node."
            julia "$AZ_BATCH_TASK_WORKING_DIR/batch_runtime.jl"
        else
            # Parallel julia session on multi nodes w/o MPI
            echo "Run parallel Julia session for each task with $NUM_NODES_PER_TASK workers."
            julia --procs "$NUM_NODES_PER_TASK" "$AZ_BATCH_TASK_SHARED_DIR/batch_runtime.jl"
        fi

    else

        # Total number of MPI ranks
        num_ranks=$(( NUM_NODES_PER_TASK * NUM_PROCS_PER_NODE ))

        if [[ "$INTER_NODE_CONNECTION" == 0 ]]; then
            # Parallel execution with mpi (single node)
            echo "MPI execution on single node with $num_ranks MPI ranks."
            mpiexecjl -n "$num_ranks" julia "$AZ_BATCH_TASK_WORKING_DIR/batch_runtime.jl"
        else
            # Without a host list, `-host` would consume the next word
            # ("julia") as its argument; fail early with a clear message.
            if [[ -z "$AZ_BATCH_HOST_LIST" ]]; then
                echo "AZ_BATCH_HOST_LIST is empty; cannot start multi-node MPI run." >&2
                return 1
            fi

            # Parallel execution with mpi (multi node)
            echo "MPI execution on $NUM_NODES_PER_TASK nodes with $num_ranks total MPI ranks."
            mpiexecjl -n "$num_ranks" -host "$AZ_BATCH_HOST_LIST" julia "$AZ_BATCH_TASK_SHARED_DIR/batch_runtime.jl"
        fi
    fi
}

run_batch_runtime

