#!/usr/bin/python -tt

import sys, os
import syslog
import pwd
import syslog
import glob
import base64
import commands
import re
import selinux

# Whitelist of commands accepted from SSH_ORIGINAL_COMMAND.  The key is the
# basename of the requested command; the value is the executable that is
# exec'd for it.  Anything not listed here is rejected as invalid.
#
# Most commands are run through bash ...
commands_map = dict.fromkeys(
    (
        "snapshot",
        "restore",
        "rhcsh",
        "java",
        "scp",
        "cd",
        "set",
        "mkdir",
        "test",
        "rsync",
        "ctl_all",
        "deploy.sh",
        "rhc-list-ports",
        "post_deploy.sh",
    ),
    "/bin/bash"
)
# ... while these are exec'd directly.
commands_map.update({
    "git-receive-pack": "/usr/bin/git-receive-pack",
    "git-upload-pack": "/usr/bin/git-upload-pack",
    "tail": "/usr/bin/tail",
    "true": "/bin/true",
})

#
# Read in uservars variables.
#
def _set_env_uservars(uservars_dir):
    for env in os.listdir(uservars_dir):
        fp = open(os.path.join(uservars_dir, env), 'r')
        env_var = fp.readlines()[0].strip().strip('\'"')
        fp.close()
        os.putenv(env, env_var)
    pass

#
# Read in environment variables
#
def read_env_vars():
    """Export the variables stored under ~/.env/ into the process environment.

    Each regular file is expected to hold a single assignment on its first
    line (anything before the first '=' is ignored, e.g. "export NAME").
    The file name becomes the variable name and the text after the first
    '=' -- with surrounding quote characters removed -- becomes the value.

    Special entries:
      * "USER_VARS" is skipped.
      * The ".uservars" directory is handed to _set_env_uservars().
      * Any other directory is skipped.
      * A file whose first line has no '=' is logged and skipped.

    Values are set with os.putenv(), so they reach exec'd programs but are
    not reflected in os.environ.
    """
    envdir = os.path.expanduser('~/.env/')
    for env in os.listdir(envdir):
        if env == 'USER_VARS':
            continue

        path = os.path.join(envdir, env)
        if os.path.isdir(path):
            if env == '.uservars':
                _set_env_uservars(path)
            continue

        # 'with' closes the file even when the read raises; readline()
        # yields '' for an empty file instead of raising IndexError.
        with open(path, 'r') as fp:
            first_line = fp.readline().strip()

        # Split on the first '=' only, so values that themselves contain
        # '=' (URLs with query strings, base64 padding, ...) stay intact.
        _lhs, sep, raw_value = first_line.partition('=')
        if not sep:
            # One malformed file should not abort the whole session.
            syslog.syslog("Skipping malformed env file %s" % path)
            continue

        os.putenv(env, raw_value.strip('\'"'))

def read_config(path='/etc/stickshift/stickshift-node.conf'):
    """Parse a shell-style KEY=VALUE configuration file into a dict.

    path -- file to read; defaults to the stickshift node configuration.

    Parsing rules:
      * Blank lines, lines starting with '#', and lines without '=' are
        ignored.
      * The line is split on the first '=' only, so values may contain '='.
      * Everything from the first '#' in the value (with the whitespace in
        front of it) is dropped as a trailing comment.  This also applies
        to a '#' inside quotes.
      * Surrounding whitespace and then surrounding quote characters
        (' and ") are stripped from the value; whitespace is stripped from
        the key.

    Returns a dict mapping key to value (both strings).  Raises IOError if
    the file cannot be opened.
    """
    config = {}
    with open(path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            key, sep, raw_value = line.partition('=')
            if not sep:
                # Not an assignment; nothing to record.
                continue
            value = re.sub(r"[ \t]*#[^\n]*", "", raw_value)     #bash comments
            config[key.strip()] = value.strip().strip('\'"')
    return config

def _reject_command(orig_cmd):
    """Log and report an unusable command, then exit with status 2."""
    syslog.syslog("Invalid command %s" % orig_cmd)
    sys.stderr.write("Invalid command %s" % orig_cmd)
    sys.exit(2)


def _tail_args(allargs, orig_cmd):
    """Build the argument list passed to tail.

    allargs  -- the whitespace-split command; allargs[0] is the command.
    orig_cmd -- the unsplit command, used only for error reporting.

    Accepted form: tail [--opts <base64 of tail options>] <glob>...
    Decoded options that start with '..' or '/' are refused (exit 88).
    '-f' is appended unless the options already ask for following.
    Exits 32 when the globs match nothing, and 2 when '--opts' has no
    payload or the payload is not valid base64.
    """
    options = []
    files_start_index = 1
    add_follow = True
    # Guard the index: a bare "tail" must not raise IndexError.
    if len(allargs) > 1 and allargs[1] == '--opts':
        if len(allargs) < 3:
            _reject_command(orig_cmd)
        files_start_index = 3
        try:
            options = base64.standard_b64decode(allargs[2]).split()
        except (TypeError, ValueError):
            # Malformed base64 (TypeError on Python 2).
            _reject_command(orig_cmd)
        for arg in options:
            if arg.startswith(('..', '/')):
                sys.stdout.write("All paths must be relative: " + arg + "\n")
                sys.exit(88)
            elif arg == '-f' or arg == '-F' or arg.startswith('--follow'):
                add_follow = False

    files = []
    for pattern in allargs[files_start_index:]:
        files.extend(glob.glob(pattern))
    if not files:
        sys.stdout.write("Could not find any files matching glob\n")
        sys.exit(32)

    tail_args = list(options)
    if add_follow:
        tail_args.append('-f')
    tail_args.extend(files)
    return tail_args


def _git_args(allargs, gear_base_dir):
    """Validate the repository argument of a git-*-pack command.

    allargs       -- the whitespace-split command; allargs[0] is the command.
    gear_base_dir -- directory every repository must live under.

    Returns a one-element list holding the cleaned repository argument.
    Exits 3 if the path is outside gear_base_dir or is not a directory.
    """
    thearg = ' '.join(allargs[1:])
    # startswith/endswith are safe on an empty string, unlike indexing.
    if thearg.startswith("'") and thearg.endswith("'"):
        thearg = thearg.replace("'", "")
    thearg = thearg.replace("\\'", "")
    thearg = thearg.replace("//", "/")

    # replace leading tilde (~) with user's home path
    realpath = os.path.expanduser(thearg)

    # Compare normalized paths on a path-component boundary so that '..'
    # segments or a sibling directory sharing the prefix cannot slip past
    # the check.  Symbolic links are not resolved here.
    base = os.path.normpath(gear_base_dir)
    candidate = os.path.normpath(realpath)
    inside = (candidate == base or
              candidate.startswith(base.rstrip(os.sep) + os.sep))
    if not inside:
        syslog.syslog("Invalid repository: not in stickshift_root - %s: (%s)" %
                      (thearg, realpath))
        sys.stdout.write(
            "Invalid repository %s: not in application root\n" % thearg)
        sys.exit(3)

    if not os.path.isdir(realpath):
        syslog.syslog("Invalid repository %s (%s)" %
                      (thearg, realpath))
        sys.stdout.write("Invalid repository %s: not a directory\n" % thearg)
        sys.exit(3)

    return [thearg]


def main():
    """Run the whitelisted command named by SSH_ORIGINAL_COMMAND.

    Loads the environment and node configuration, maps the requested
    command (default "rhcsh") through commands_map, rewrites its arguments
    for the chosen executable and replaces this process with it.  When the
    current SELinux context is not the stickshift one, the command is
    started through runcon in that context.  Unknown or empty commands
    exit with status 2.
    """
    read_env_vars()
    config = read_config()

    orig_cmd = os.environ.get('SSH_ORIGINAL_COMMAND', "rhcsh")
    syslog.syslog(orig_cmd)
    allargs = orig_cmd.split()
    try:
        basecmd = os.path.basename(allargs[0])
        cmd = commands_map[basecmd]
    except (IndexError, KeyError):
        # IndexError: empty command; KeyError: command not whitelisted.
        _reject_command(orig_cmd)

    if basecmd == 'snapshot':
        # This gets called with "snapshot"
        allargs = ['snapshot.sh']
    elif basecmd == 'restore':
        # This gets called with "restore <INCLUDE_GIT>"
        include_git = len(allargs) > 1 and allargs[1] == 'INCLUDE_GIT'
        allargs = ['restore.sh']
        if include_git:
            allargs.append('INCLUDE_GIT')
    elif basecmd == 'rhcsh':
        os.environ["PS1"] = "rhcsh> "
        if len(allargs) < 2:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-i']
        else:
            allargs = ['--init-file', '/usr/bin/rhcsh', '-c',
                       ' '.join(allargs[1:])]
    elif basecmd == 'ctl_all':
        allargs = ['-c', '. /usr/bin/rhcsh > /dev/null ; ctl_all %s' % allargs[-1]]
    elif basecmd in ('java', 'set', 'scp', 'cd', 'test', 'mkdir', 'rsync',
                     'deploy.sh', 'post_deploy.sh', 'rhc-list-ports'):
        # The whole command line is handed to bash unchanged.
        allargs = ['-c', ' '.join(allargs)]
    elif basecmd == 'tail':
        allargs = _tail_args(allargs, orig_cmd)
    elif basecmd in ('git-receive-pack', 'git-upload-pack'):
        # git repositories need to be parsed specially
        allargs = _git_args(allargs, config['GEAR_BASE_DIR'])
    # Any other whitelisted command (e.g. "true") keeps its arguments as is.

    runcon = '/usr/bin/runcon'

    target_context = 'system_u:system_r:stickshift_t:s0'
    actual_context = selinux.getcon()[1]
    # Confirm nothing fishy is going on
    if target_context != actual_context:
        argv = [runcon, target_context, cmd] + allargs
    else:
        argv = [cmd] + allargs

    try:
        os.execv(argv[0], argv)
    except OSError as e:
        # execv only returns by raising; report it instead of a traceback.
        syslog.syslog("Failed to execute %s: %s" % (argv[0], e.strerror))
        sys.exit(1)


if __name__ == '__main__':
    main()
