#! /usr/bin/python3 -s
import os
import sys
import syslog

# sshd exports the client's requested command and connection details through
# the environment when this script is installed as a ForceCommand.
ssh_connection = os.getenv('SSH_CONNECTION')
command = os.getenv('SSH_ORIGINAL_COMMAND')

# Without SSH_ORIGINAL_COMMAND there is nothing to validate, so refuse to run.
if command is None:
    sys.stderr.write('This command must be run via SSH ForceCommand (see man 5 sshd_config).\n')
    sys.exit(1)

# Tag every syslog entry emitted below with this wrapper's name.
syslog.openlog('nova_migration_wrapper')

def allow_command(user, args):
    """Log an accepted command and replace this process with it via sudo.

    :param user: account name handed to ``sudo -u``.
    :param args: argv list of the command to run as ``user``.

    On success this function does not return: the current process image is
    replaced by ``sudo``.
    """
    message = "Allowing connection='{}' command={} ".format(ssh_connection, repr(args))
    syslog.syslog(syslog.LOG_INFO, message)

    # argv[0] must be repeated explicitly for the exec* family.
    sudo_argv = ['sudo', '-u', user]
    sudo_argv.extend(args)
    os.execvp('sudo', sudo_argv)

def deny_command(args):
    """Log a rejected command, tell the client, and exit with status 1.

    :param args: the argv list that was refused (only used for logging).
    """
    message = "Denying connection='{}' command={}".format(ssh_connection, repr(args))
    syslog.syslog(syslog.LOG_ERR, message)

    sys.stderr.write('Forbidden\n')
    raise SystemExit(1)

# Handle libvirt ssh tunnel script snippet
# https://github.com/libvirt/libvirt/blob/f0803dae93d62a4b8a2f67f4873c290a76d978b3/src/rpc/virnetsocket.c#L890
#
# The string below is compared verbatim against the command requested by the
# client, so spacing and quoting must stay exactly as written.
libvirt_sock = '/var/run/libvirt/libvirt-sock'
live_migration_tunnel_cmd = (
    "sh -c '"
    "if 'nc' -q 2>&1 | grep \"requires an argument\" >/dev/null 2>&1; then "
    "ARG=-q0;"
    "else "
    "ARG=;"
    "fi;"
    "'nc' $ARG -U " + libvirt_sock + "'"
)

# Cold migration: only these exact argv prefixes are accepted, each followed
# by a single target path that must live under cold_migration_root.
cold_migration_root = '/var/lib/nova/instances/'
cold_migration_cmds = [
    prefix.split(' ')
    for prefix in (
        'mkdir -p',
        'rm -rf',
        'touch',
        'rm',
        'scp -r -t',
        'scp -r -f',
        'scp -t',
        'scp -f',
    )
]

# Accepted cold-migration commands are prefixed with these arguments.
rootwrap_args = [
    '/usr/bin/nova-rootwrap',
    '/etc/nova/migration/rootwrap.conf',
]

def validate_cold_migration_cmd(args):
    """Return True if ``args`` is a whitelisted cold-migration command.

    The last element of ``args`` is treated as the target path and everything
    before it as the command prefix. The command is accepted only when the
    prefix exactly equals one entry of ``cold_migration_cmds`` and the
    normalised target path starts with ``cold_migration_root``.

    :param args: argv list of the requested command.
    :returns: ``True`` when the command may be executed, ``False`` otherwise
        (including when ``args`` is empty).
    """
    if not args:
        # Nothing to inspect; refuse rather than fail on args[-1].
        return False

    # normpath collapses '..' and repeated slashes so that traversal such as
    # '<root>/../../etc' is rejected. It also strips a trailing slash, so the
    # root directory itself does not pass the prefix test below.
    # NOTE: normpath is purely lexical and does not resolve symlinks.
    target_path = os.path.normpath(args[-1])
    cmd = args[:-1]
    return cmd in cold_migration_cmds and target_path.startswith(cold_migration_root)

# Rules
# The requested command is split on single spaces only (no shell-style quote
# handling); the resulting pieces are what gets validated and logged.
args = command.split(' ')
# Live migration: the request must equal the libvirt tunnel snippet exactly.
# The client-supplied shell text is not executed; a fixed nc invocation is run
# as the 'nova' user instead. allow_command() execs, so on success control
# never reaches the checks below.
if command == live_migration_tunnel_cmd:
    args = ['nc', '-U', libvirt_sock]
    allow_command('nova', args)
# Cold migration: a whitelisted command with a target under the instances
# directory is run as root with the rootwrap arguments prepended.
if validate_cold_migration_cmd(args):
    args = rootwrap_args + args
    allow_command('root', args)
# Anything that matched no rule is logged and refused (exits with status 1).
deny_command(args)
