#!/usr/bin/env python
'Submit a batch job to a queue system'
##   Copyright (C) 2010 University of Texas at Austin
##  
##   This program is free software; you can redistribute it and/or modify
##   it under the terms of the GNU General Public License as published by
##   the Free Software Foundation; either version 2 of the License, or
##   (at your option) any later version.
##  
##   This program is distributed in the hope that it will be useful,
##   but WITHOUT ANY WARRANTY; without even the implied warranty of
##   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
##   GNU General Public License for more details.
##  
##   You should have received a copy of the GNU General Public License
##   along with this program; if not, write to the Free Software
##   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

from __future__ import print_function, division, absolute_import
import os, sys, time
import subprocess

# Keys that must be supplied (in ~/.batch, a batchfile, or on the command
# line), each mapped to the help text shown by the self-documentation.
required_pars = {
    'batch': 'pbs, sge, or slurm',
    'queue': 'type of the queue',
    'mail': 'email for the user',
    'acct': 'account name',
    'wall': 'wall time',
    'nodes': 'number of nodes',
    'ppn': 'processes per node',
    'np': 'number of tasks',
    'exe': 'command to run',
}

# Keys that may be omitted; the script fills in a default for each of them.
optional_pars = {
    'modules': 'modules to load',
    'path': 'directory path',
    'job': 'script name',
}

if len(sys.argv) < 2:
    # selfdoc: with no arguments, list every recognised key and exit with
    # status 1.  Whitespace inside the messages is spelled out with escapes
    # so the output stays exactly as before.
    usage = ('\n'
             '    Usage: %s [batchfile=parfile] [key=value] \n'
             '    \n'
             '    The parfile can contain Python-syntax key=value entries\n'
             '    for the following keys:')
    sections = (
        (usage % sys.argv[0], required_pars),
        ('\n\n    The following keys are optional:', optional_pars),
    )
    for heading, table in sections:
        print(heading)
        for key, text in table.items():
            # pad short keys so that the '#' descriptions line up
            print('    %s= %s# %s' % (key, ' ' * (7 - len(key)), text))
    sys.exit(1)

def _read_batchfile(filename):
    '''Return the text of parameter file *filename*.

    If the file cannot be opened or read, print the "Could not read
    batchfile" message and exit with status 2.
    '''
    try:
        # the context manager closes the file even when read() fails
        with open(filename) as stream:
            return stream.read()
    except (IOError, OSError, UnicodeError):
        print('\n        Could not read batchfile "%s".\n        ' % filename)
        sys.exit(2)

# NOTE: parameter files are executed as Python code in this module's
# namespace, so every name they assign becomes a parameter.  Only point
# this script at files you trust.

# Read the default parameter file from ~/.batch
# (expanduser() still finds the home directory when $HOME is unset,
# whereas os.environ.get('HOME') would hand None to os.path.join)
parfile = os.path.join(os.path.expanduser('~'), '.batch')

if os.path.isfile(parfile):
    # get keys from the default parameter file
    pars = _read_batchfile(parfile)
    exec(pars, globals())

# Get a parameter file from the command line; when batchfile= is given
# more than once the last one wins, and every occurrence is consumed.
parfile = None
for arg in sys.argv[1:]:
    if arg.startswith('batchfile='):
        parfile = arg[len('batchfile='):]
        sys.argv.remove(arg)

if parfile:
    # keys from this file override those set by ~/.batch
    pars = _read_batchfile(parfile)
    exec(pars, globals())

cwd = os.getcwd()
# name of the directory that contains the current working directory
dirname = os.path.basename(os.path.dirname(cwd))

# Command-line key=value arguments override values from the parameter
# files.  Each matching argument is stored as a module-level variable
# (always as a string) and removed from sys.argv.
for key in list(required_pars) + list(optional_pars):
    prefix = key + '='
    for arg in sys.argv[1:]:
        if arg.startswith(prefix):
            globals()[key] = arg[len(prefix):]
            sys.argv.remove(arg)

# Every required key must be defined by now, either by a parameter file
# or on the command line.
for key in required_pars:
    if key in globals():
        continue
    print('\n        %s: need %s= parameter.\n        ' % (sys.argv[0], key))
    sys.exit(3)

# Optional keys that are still undefined get their default values; the
# defaults are built lazily so the timestamp is only taken when needed.
_default_makers = {
    'modules': lambda: '',
    'path': lambda: cwd,
    'job': lambda: '-'.join([dirname, time.strftime('%Y-%m-%d@%H-%M-%S')]),
}
for key in optional_pars:
    if key not in globals() and key in _default_makers:
        globals()[key] = _default_makers[key]()

#exe = ' '.join(sys.argv[1:])

# Marker file in the submission directory: the generated batch script
# writes its job id here as its last command.
jobid = os.path.join(cwd,'jobid')

# figure out the number of nodes
# (values may arrive as strings from the command line or as numbers
# from a parameter file)
nodes = int(nodes)
np    = int(np)
ppn   = int(ppn)

# Do not request more nodes than np tasks can fill at ppn tasks per node,
# but always request at least one.  Floor division is required here:
# this file enables true division, so np/ppn would turn nodes (and then
# np) into floats and produce directives such as "nodes=2.0" or
# "nodes=2.5".
if nodes > np // ppn:
    nodes = max(1, np // ppn)

# the task count is always a whole number of full nodes
np = nodes*ppn

# Prefix the command with one "module load" line per requested module.
# NOTE(review): each line is prepended, so modules end up being loaded in
# the reverse of the order listed -- confirm this is intended.
for module in modules.split(','):
    if module:
        exe = 'module load %s\n' % module + exe
    
# Build the batch script for the selected queue system.  The text is
# assembled completely before the file is opened, and written through a
# context manager, so a failure while formatting a directive can neither
# leak the file handle nor leave a truncated script behind.
#
# NOTE(review): the '.%j.out' / '.%j.err' suffix is written for all three
# systems; '%j' is the sbatch file-name pattern -- confirm that PBS and SGE
# treat it as intended.
if batch == 'pbs':
    submit = '/usr/bin/qsub'
    jobfile = job + '.pbs'
    script_lines = [
        '#PBS -N ' + job,
        '#PBS -A ' + acct,
        '#PBS -V',
        '#PBS -l walltime=' + wall,
        '#PBS -l nodes=' + str(nodes) + ':ppn=' + str(ppn),
        '#PBS -q ' + queue,
        '#PBS -m abe',
        '#PBS -M ' + mail,
        '#PBS -o ' + job + '.%j.out',
        '#PBS -e ' + job + '.%j.err',
        'cd ' + path,
        exe,
        'echo $PBS_JOBID > ' + jobid,
    ]
elif batch == 'slurm':
    submit = '/usr/bin/sbatch'
    jobfile = job + '.bat'
    script_lines = [
        '#!/bin/bash',
        '#SBATCH -J ' + job,
        '#SBATCH -A ' + acct,
        '#SBATCH -t ' + wall,
        '#SBATCH -n ' + str(np),
        '#SBATCH -N ' + str(nodes),
        '#SBATCH -p ' + queue,
        '#SBATCH --mail-type=begin',
        '#SBATCH --mail-type=end',
        '#SBATCH --mail-user=' + mail,
        # a '#SBATCH --uid=$USER' directive used to be written here
        '#SBATCH -o ' + job + '.%j.out',
        '#SBATCH -e ' + job + '.%j.err',
        'cd ' + path,
        exe,
        'echo $SLURM_JOBID > ' + jobid,
    ]
elif batch == 'sge':
    submit = '/usr/bin/qsub'
    jobfile = job + '.sge'
    script_lines = [
        '#$ -N ' + job,
        '#$ -A ' + acct,
        '#$ -V',
        '#$ -l h_rt=' + wall,
        '#$ -pe %dway %d' % (ppn, nodes),
        '#$ -q ' + queue,
        '#$ -m abe',
        '#$ -M ' + mail,
        '#$ -o ' + job + '.%j.out',
        '#$ -e ' + job + '.%j.err',
        'cd ' + path,
        exe,
        'echo $JOB_ID > ' + jobid,
    ]
else:
    print('\n    Unknown batch system "%s".\n    ' % batch)
    sys.exit(4)

# one directive/command per line, with a newline after the last one
with open(jobfile, 'w') as f:
    f.write('\n'.join(script_lines) + '\n')

# Keep the command as an argument vector and run it without a shell, so a
# job file name containing spaces or shell metacharacters (the default
# name is derived from a directory name) reaches the submit program as a
# single, uninterpreted argument.
submit_argv = [submit, jobfile]
submit = ' '.join(submit_argv)
print(submit)

# Remove a stale marker left by a previous run; otherwise the wait loop
# below would finish immediately.
if os.path.isfile(jobid):
    os.unlink(jobid)

try:
    status = subprocess.call(submit_argv)
except OSError:
    # the submit program is missing or not executable
    status = 1

if status:
    print('Failed to submit a job.')
    sys.exit(5)

# Block until the batch script writes its job id to the marker file,
# which is the last command in the generated script.  There is no
# timeout: the wait lasts as long as the job does.
while not os.path.isfile(jobid):
    time.sleep(5)

sys.exit(0)
