#!/usr/bin/python -Es
#
# Authors: Dan Walsh <dwalsh@redhat.com>
#
# Copyright (C) 2012 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

PROGNAME="sandboxcontainer"
from gi.repository import LibvirtGObject
from gi.repository import LibvirtSandbox
from gi.repository import GLib
import gi
import os, sys, shutil, errno, stat
import rpm
from subprocess import Popen, PIPE, STDOUT
import gettext
import selinux

# Initialise the libvirt GObject and libvirt-sandbox bindings once at import
# time; both calls are made with None as their only argument.
LibvirtGObject.init_object_check(None)
LibvirtSandbox.init_check(None)

# Set up message translation for PROGNAME using the system locale directory.
gettext.bindtextdomain(PROGNAME, "/usr/share/locale")
gettext.textdomain(PROGNAME)
try:
    # gettext.install() makes the translation function available as the
    # builtin `_` (Python 2 signature: `unicode` and `codeset` keywords).
    gettext.install(PROGNAME,
                    localedir="/usr/share/locale",
                    unicode=False,
                    codeset = 'utf-8')
except IOError:
    # Fallback when installing the catalog fails: bind the builtin `_` to
    # `unicode`, so `_()` calls still work but return the text untranslated.
    import __builtin__
    __builtin__.__dict__['_'] = unicode

class Container:
    """Create, run and remove a libvirt-sandbox service container.

    A container is made of three things on the host:

    * a sandbox config file at ``CONFIG_PATH % name``,
    * a content tree at ``<path>/<name>`` (optionally backed by the image
      file ``DEFAULT_IMAGE % name``), whose sub-directories are bind mounted
      into the container,
    * a systemd unit file under /etc/systemd/system that drives the
      container through /usr/sbin/virt-sandbox-service.

    Constructing the object with a ``name`` loads the saved config of an
    existing container; constructing it without one and calling create()
    builds a new container.
    """

    DEFAULT_DIRS  = [ "/etc", "/var" ]
    CONFIG_PATH   = "/etc/libvirt-sandbox/services/%s.sandbox"
    SYSTEM_DIRS   = [ "/tmp", "/etc/pki", "/var/tmp", "/dev/shm", "/var/log", "/var/lib/dhclient", "/home", "/root", "/var/spool" ]
    SELINUX_FILE_TYPE  = "svirt_lxc_file_t"
    SELINUX_TYPE  = "svirt_lxc_net_t"
    SELINUX_LEVEL = "s0"
    DEFAULT_PATH  = "/var/lib/libvirt/filesystems"
    DEFAULT_IMAGE = "/var/lib/libvirt/images/%s.raw"

    def __init__(self, name=None, uri = "lxc:///", path = DEFAULT_PATH):
        """Initialise defaults and, when ``name`` is given, load its config.

        name -- name of an existing container, or None for a new one
        uri  -- libvirt connection URI used by start()/stop()/connect()
        path -- directory under which the container content tree lives

        Raises OSError if the saved config for ``name`` cannot be loaded.
        """
        self.uri = uri
        self.clone = False
        self.use_image = False
        self.name = name
        self.path = path
        self.service_files = []
        self.conn = None
        self.image = None
        self.mcs = "s0"
        self.dynamic = False
        self.file_type = self.SELINUX_FILE_TYPE
        self.security_type = self.SELINUX_TYPE
        self.security_level = self.SELINUX_LEVEL
        # MB is a module-level constant; it is looked up when an instance is
        # created, not when the class is defined.
        self.size = 10 * MB
        # Filled in by create() / __extract_rpm(); initialised here so that
        # reading them earlier does not raise AttributeError.
        self.rpm_name = None
        self.executable = None
        self.all_dirs = []
        self.dirs = []
        if self.name:
            config = self.CONFIG_PATH % self.name
            try:
                self.config = LibvirtSandbox.Config.load_from_path(config)
            except gi._glib.GError as e:
                raise OSError(config + ": " + str(e))

    @staticmethod
    def __is_under(path, parent):
        """Return True if ``path`` is ``parent`` or lies beneath it.

        The comparison is done on whole path components, so "/etc/foobar"
        is not considered to be under "/etc/foo".
        """
        return path == parent or path.startswith(parent.rstrip("/") + "/")

    def __follow_units(self):
        """Return one (commented-out) ReloadPropagatedFrom line per service file."""
        return "".join("#ReloadPropagatedFrom=%s\n" % i
                       for i in self.service_files)

    def __create_system_unit(self):
        """Write the systemd unit file for this container.

        The unit is named "<rpm_name>@<name>.service" when the owning RPM
        is known and "<name>.service" otherwise; its path is stored in
        self.unitfile.
        """
        if self.rpm_name:
            name = "%s@%s.service" % (self.rpm_name, self.name)
        else:
            # The ".service" suffix is required for systemd to treat the file
            # as a service unit, and it is the name delete() removes.
            name = "%s.service" % self.name

        self.unitfile = "/etc/systemd/system/" + name
        unit = """
[Unit]
Description=Secure Container %s for %s Server
Requires=libvirtd.service
After=libvirtd.service
%s
[Service]
Type=simple
ExecStart=/usr/sbin/virt-sandbox-service start %s
ExecReload=/usr/sbin/virt-sandbox-service reload %s
ExecStop=/usr/sbin/virt-sandbox-service stop %s

[Install]
WantedBy=multi-user.target
""" % ( self.name , self.executable, self.__follow_units(), self.name, self.name, self.name)
        with open(self.unitfile, "w") as fd:
            fd.write(unit)
        sys.stdout.write("Created unit file %s\n" % self.unitfile)

    def __add_dir(self, newd):
        """Record ``newd`` and keep self.dirs free of nested entries.

        self.all_dirs receives every directory; self.dirs only keeps the
        top-most ones, since a parent directory already covers everything
        beneath it.
        """
        self.all_dirs.append(newd)
        # Already covered by a directory we have?  Nothing more to do.
        for d in self.dirs:
            if self.__is_under(newd, d):
                return
        # newd may be the parent of several existing entries: drop them all.
        self.dirs = [d for d in self.dirs if not self.__is_under(d, newd)]
        self.dirs.append(newd)

    def set_clone(self, clone):
        """Choose whether create() copies the host directories' content."""
        self.clone = clone

    def get_dynamic(self):
        """Return whether dynamic security labelling is requested."""
        return self.dynamic

    def set_dynamic(self, dynamic):
        """Request (or not) dynamic security labelling for the container."""
        self.dynamic = dynamic

    def set_security_type(self, security_type):
        """Set the SELinux type used in the container's security label."""
        self.security_type = security_type

    def set_security_level(self, security_level):
        """Set the SELinux level used in the container's security label."""
        self.security_level = security_level

    def set_image(self, size):
        """Back the container content with an image file of ``size`` MB."""
        self.use_image = True
        self.size = size * MB

    def set_path(self, path):
        """Set the directory under which the content tree is created."""
        self.path = path

    def set_mcs(self, mcs):
        """Set the MCS level used by get_process_label()/get_file_label()."""
        self.mcs = mcs

    def __extract_rpm(self):
        """Collect directories and service files of the RPM owning the executable.

        Resets self.dirs / self.all_dirs, then records every existing
        directory of the package that lives under one of DEFAULT_DIRS and
        every packaged "*.service" file.  If no package owns the executable,
        self.rpm_name is left as None and nothing is collected.
        """
        self.all_dirs = []
        self.dirs = []
        self.rpm_name = None

        ts = rpm.ts()
        mi = ts.dbMatch(rpm.RPMTAG_BASENAMES, self.executable)
        try:
            h = next(mi)
        except StopIteration:
            # Executable is not owned by any installed package.
            return

        self.rpm_name=h['name']
        for rec in h.fiFromHeader():
            f = rec[0]
            if os.path.isdir(f) and any(self.__is_under(f, b)
                                        for b in self.DEFAULT_DIRS):
                self.__add_dir(f)

            if f.endswith(".service"):
                self.service_files.append(f)

    def __add_mount(self, source, dest):
        """Add a host bind mount of ``source`` onto ``dest`` to the config."""
        mount = LibvirtSandbox.ConfigMount.new(dest)
        mount.set_root(source)
        self.config.add_host_bind_mount(mount)

    def __gen_filesystems(self):
        """Register the image mount (if any) and all bind mounts in the config."""
        if self.use_image:
            self.image = self.DEFAULT_IMAGE % self.name
            mount = LibvirtSandbox.ConfigMount.new(self.dest)
            mount.set_root(self.image)
            self.config.add_host_image_mount(mount)

        for d in self.SYSTEM_DIRS + self.dirs:
            source = "%s%s" % ( self.dest, d)
            self.__add_mount(source, d)

    def __get_init_path(self):
        """Return the path "<path>/<name>/init"."""
        return "%s/%s/init" % (self.path, self.name)

    def __makedirs(self, path):
        """os.makedirs() that tolerates ``path`` already existing."""
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise

    def __gen_dirs(self):
        """Populate the content tree under self.dest.

        With clone enabled the top-level package directories are copied from
        the host; otherwise empty directories are created for every package
        directory.  The SYSTEM_DIRS are always created.
        """
        if self.clone:
            for d in self.dirs:
                shutil.copytree(d, "%s%s" % (self.dest, d), symlinks=True)
        else:
            # all_dirs may list a child before (or after) its parent, so an
            # earlier makedirs() can already have created a later entry.
            for d in self.all_dirs:
                self.__makedirs("%s%s" % (self.dest, d))

        for d in self.SYSTEM_DIRS:
            self.__makedirs("%s%s" % (self.dest, d))

    def __run(self, argv, failure, quiet=False):
        """Run ``argv`` and raise OSError(failure) on a non-zero exit status.

        With quiet=True the command's stdout/stderr are captured and dropped.
        """
        if quiet:
            p = Popen(argv, stdout=PIPE, stderr=PIPE)
        else:
            p = Popen(argv)
        p.communicate()
        if p.returncode != 0:
            raise OSError(failure)

    def __umount(self):
        """Unmount the image from self.dest; raise OSError on failure."""
        self.__run(["/bin/umount", self.dest],
                   "Failed to unmount image %s from %s " %  (self.image, self.dest))

    def __create_image(self):
        """Create a sparse ext4 image of self.size bytes and mount it on self.dest.

        Raises OSError if mkfs or mount fails.
        """
        with open(self.image, "w") as fd:
            fd.truncate(self.size)

        self.__run(["/sbin/mkfs","-F", "-t", "ext4", self.image],
                   "Failed to build image %s" % self.image,
                   quiet=True)
        self.__run(["/bin/mount", self.image, self.dest],
                   "Failed to mount image %s on %s " %  (self.image, self.dest))

    def __save(self):
        """Write the sandbox config to CONFIG_PATH, replacing any old file."""
        config = self.CONFIG_PATH % self.name
        if os.path.exists(config):
            os.remove(config)
        self.config.save_to_path(config)
        sys.stdout.write("Created sandbox config %s\n" % config)

    def __get_image_path(self):
        """Return the image file mounted on self.dest, or None if there is none."""
        mounts = self.config.get_host_image_mounts()
        if len(mounts) and mounts[0].get_target() == self.dest:
            return mounts[0].get_root()
        return None

    def delete(self):
        """Stop the container and remove its content, image, config and unit file."""
        # Stop service if it is running; failure to stop (e.g. it is not
        # running) must not prevent the clean-up below.
        try:
            self.stop()
        except Exception:
            pass

        self.dest = "%s/%s" % (self.path, self.name)
        if os.path.exists(self.dest):
            shutil.rmtree(self.dest)

        # Remove the backing image even if the content tree is already gone.
        mount_path = self.__get_image_path()
        if mount_path and os.path.exists(mount_path):
            os.remove(mount_path)

        config = self.CONFIG_PATH % self.name
        if os.path.exists(config):
            os.remove(config)

        # NOTE(review): when create() knew the owning RPM it wrote the unit as
        # "<rpm_name>@<name>.service"; that name is not known here, so such a
        # unit file is left behind.
        unitfile = "/etc/systemd/system/%s.service" % self.name
        if os.path.exists(unitfile):
            os.remove(unitfile)

    def create(self, name, executable):
        """Create a new container called ``name`` running ``executable``.

        Builds the sandbox config, creates the content tree (inside a fresh
        image when set_image() was called), labels a directory-backed tree
        with the SELinux file context, then saves the config and writes the
        systemd unit.

        Raises OSError if the content directory already exists or if
        building/mounting the image fails.
        """
        self.dest = "%s/%s" % (self.path, name)
        if os.path.exists(self.dest):
            raise OSError("%s already exists" % self.dest)

        self.config = LibvirtSandbox.ConfigService.new(name)
        self.config.set_security_label("system_u:system_r:%s:%s" % (self.security_type, self.security_level))
        self.config.set_security_dynamic(self.dynamic)
        net = LibvirtSandbox.ConfigNetwork.new()
        net.set_dhcp(True)
        self.config.add_network(net)
        self.name = name
        self.executable = executable
        self.config.set_executable(self.executable)
        self.config.set_shell(True)

        self.__extract_rpm()
        self.__gen_filesystems()
        os.mkdir(self.dest)
        sys.stdout.write("Created container dir %s\n" % self.dest)
        if self.image:
            # Populate the tree inside the mounted image, then release it.
            self.__create_image()
            self.__gen_dirs()
            self.__umount()
        else:
            self.__gen_dirs()
            selinux.chcon(self.dest, "system_u:object_r:%s:%s" % (self.file_type, self.security_level), True)

#        resolv_conf = self.dest + "/etc/resolv.conf"
#        mount = LibvirtSandbox.ConfigMount.new("/etc/resolv.conf")
#        mount.set_root(resolv_conf)
#        self.config.add_host_mount(mount)
#        fd = open(resolv_conf,"w")
#        fd.close()

        self.__save()
        self.__create_system_unit()

    def reload(self):
        """Restart the container by stopping and starting it again."""
        # Crude way of doing this.
        self.stop()
        self.start()

    def __connect(self):
        """Open the libvirt connection on first use and cache it in self.conn."""
        if not self.conn:
            self.conn=LibvirtGObject.Connection.new(self.uri)
            self.conn.open(None)

    def __on_console_closed(self, obj, error):
        """Console "closed" signal handler: leave the main loop."""
        self.loop.quit()

    def start(self):
        """Start the container and block until its log console is closed."""
        self.__connect()
        context = LibvirtSandbox.ContextService.new(self.conn, self.config)
        context.start()
        console = context.get_log_console()
        # Create the loop before wiring the handler that quits it.
        self.loop = GLib.MainLoop()
        console.connect("closed", self.__on_console_closed)
        console.attach_stderr()
        self.loop.run()
        # Best-effort tear-down: the console/context may already be gone.
        try:
            console.detach()
        except Exception:
            pass
        try:
            context.stop()
        except Exception:
            pass

    def stop(self):
        """Attach to the container's context and stop it."""
        self.__connect()
        context = LibvirtSandbox.ContextService.new(self.conn, self.config)
        context.attach()
        context.stop()

    def connect(self):
        """Attach stdio to the container's shell console until it is closed."""
        self.__connect()
        context = LibvirtSandbox.ContextService.new(self.conn, self.config)
        context.attach()
        console = context.get_shell_console()
        # Create the loop before wiring the handler that quits it.
        self.loop = GLib.MainLoop()
        console.connect("closed", self.__on_console_closed)
        console.attach_stdio()
        self.loop.run()
        # Best-effort tear-down: the console may already be gone.
        try:
            console.detach()
        except Exception:
            pass

    def get_process_label(self):
        """Return the SELinux process label built from security_type and mcs."""
        return "system_u:system_r:%s:%s" % (self.security_type, self.mcs)

    def get_file_label(self):
        """Return the SELinux file label built from file_type and mcs."""
        return "system_u:object_r:%s:%s" % (self.file_type, self.mcs)

# Number of bytes in one (decimal) megabyte; image sizes are multiplied by it.
MB = 10 ** 6

def delete(args):
    """Handle the "delete" sub-command: remove the container ``args.name``.

    The libvirt URI from ``args.uri`` is passed along, as for the other
    sub-commands, because Container.delete() first connects to libvirt to
    stop the container.
    """
    container = Container(args.name, args.uri)
    container.delete()

def create(args):
    """Handle the "create" sub-command: build a container from parsed options.

    Reads clone, dynamic, type, path, level, image, size, name and
    executable from ``args`` and applies them to a new Container before
    calling its create() method.
    """
    container = Container()
    container.set_clone(args.clone)
    container.set_dynamic(args.dynamic)
    container.set_security_type(args.type)

    # --path was parsed but never applied; honour it when given.
    if args.path:
        container.set_path(args.path)

    # An explicit level is used both for the security label and as MCS level.
    # Without one the container's default MCS level is kept (the dynamic
    # flag is a boolean and is not a valid level).
    if args.level:
        container.set_security_level(args.level)
        container.set_mcs(args.level)

    if args.image:
        container.set_image(args.size)

    container.create(args.name, args.executable)

def usage(parser, msg):
    """Print the parser's help, then ``msg`` on stderr, and exit with status 1."""
    parser.print_help()

    err = sys.stderr
    err.write("\n%s\n" % msg)
    err.flush()
    raise SystemExit(1)

def start(args):
    """Handle the "start" sub-command for container ``args.name`` on ``args.uri``."""
    Container(args.name, args.uri).start()

def reload(args):
    """Handle the "reload" sub-command for container ``args.name`` on ``args.uri``."""
    Container(args.name, args.uri).reload()

def stop(args):
    """Handle the "stop" sub-command for container ``args.name`` on ``args.uri``."""
    Container(args.name, args.uri).stop()

def connect(args):
    """Handle the "connect" sub-command for container ``args.name`` on ``args.uri``."""
    Container(args.name, args.uri).connect()

import argparse
class SizeAction(argparse.Action):
    """argparse action for --size: store the value as a positive integer.

    The option is only accepted once --image has been seen on the command
    line.  Every rejection is reported as ``ValueError(parser, message)`` so
    the caller can print the parser's usage together with the message.
    """

    def __call__(self, parser, namespace, values, option_string=None):
        if not getattr(namespace, "image", False):
            raise ValueError(parser, "%s --size options requires --image option" % parser.prog)

        # Convert here so that a bad value is reported in the same
        # (parser, message) form as the check above.
        try:
            size = int(values)
        except (TypeError, ValueError):
            raise ValueError(parser, "%s --size requires an integer value, got %r" % (parser.prog, values))
        if size <= 0:
            raise ValueError(parser, "%s --size must be greater than zero, got %r" % (parser.prog, values))

        setattr(namespace, self.dest, size)

if __name__ == '__main__':
    # Command-line entry point: build the parser, parse the arguments and run
    # the function registered for the chosen sub-command.
    parser = argparse.ArgumentParser(description='Secure Container Tool')
    parser.add_argument("-c", "--connect", required=False, dest="uri",  default="lxc:///",
                        help="libvirt connection URI to use (lxc:/// [default] or qemu:///session)")

    subparsers = parser.add_subparsers(help="commands")
    parser_create = subparsers.add_parser("create", help="create a security container")

    parser_create.add_argument("-e", "--executable", required=True, dest="executable",  default=None,
                       help="executable to run within the container")
    parser_create.add_argument("-p", "--path", dest="path",  default=None,
                       help="select path to store container content.  Default: %s" % Container.DEFAULT_PATH)
    parser_create.add_argument("-t", "--type", dest="type", default=Container.SELINUX_TYPE,
                       help="SELinux type with which to run the container.  Default: %s" % Container.SELINUX_TYPE)
    group = parser_create.add_mutually_exclusive_group()
    group.add_argument("-l", "--level", dest="level", default=None,
                       help="MCS Level with which to run the container")
    group.add_argument("-d", "--dynamic", dest="dynamic", default=False, action="store_true",
                       help="use dynamic MCS labeling for the container")
    image = parser_create.add_argument_group("Create On Disk Image File")

    image.add_argument("-i", "--image", dest="image", default=False, action="store_true",
                       help="use an image file for the container content")
    # The size is expressed in megabytes: Container.set_image() multiplies it
    # by MB itself, so the default must be 10, not 10 * MB.
    image.add_argument("-s", "--size", dest="size", default=10, action=SizeAction,
                       help="create image of this size. Requires --image option. Default 10m.")
    parser_create.add_argument("-n", "--clone", default=False, action="store_true",
                       help="clone content from /etc and /var directories that will be mounted within the container")

    parser_create.set_defaults(func=create)

    parser_connect = subparsers.add_parser("connect", help="Connect to a security container")
    parser_connect.set_defaults(func=connect)

    parser_delete = subparsers.add_parser("delete", help="Delete a security container")
    parser_delete.set_defaults(func=delete)

    parser_stop = subparsers.add_parser("stop", help="Stop a security container")
    parser_stop.set_defaults(func=stop)

    parser_start = subparsers.add_parser("start", help="Start a security container")
    parser_start.set_defaults(func=start)

    parser_reload = subparsers.add_parser("reload", help="Reload a running security container")
    parser_reload.set_defaults(func=reload)

    parser.add_argument("name", help="name of the container")

    try:
        args = parser.parse_args()
    except ValueError as e:
        # SizeAction reports problems as ValueError(parser, message); any
        # other ValueError is shown against the top-level parser.
        if len(e.args) == 2:
            usage(e.args[0], e.args[1])
        else:
            usage(parser, str(e))

    try:
        args.func(args)
        sys.exit(0)
    except KeyboardInterrupt:
        sys.exit(0)
    except (EnvironmentError, ValueError) as e:
        # Report expected failures (e.g. "already exists", unreadable config)
        # as a single line on stderr and signal the error to the caller.
        sys.stderr.write("%s: %s\n" % (parser.prog, e))
        sys.exit(1)
