#!/usr/bin/python -Es
#
# Authors: Dan Walsh <dwalsh@redhat.com>
#
# Copyright (C) 2012 Red Hat, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
#

# Program name; also used as the gettext translation domain.
PROGNAME = "virt-sandbox-service"

from gi.repository import LibvirtGObject
from gi.repository import LibvirtSandbox
from gi.repository import GLib
import gi
import os, sys, shutil, errno, stat
import rpm
from subprocess import Popen, PIPE, STDOUT
import gettext
import selinux

# Initialise the libvirt-gobject and libvirt-sandbox bindings before any of
# their classes are used below.
LibvirtGObject.init_object_check(None)
LibvirtSandbox.init_check(None)

# Bind the translation domain and install _() as a builtin.
gettext.bindtextdomain(PROGNAME, "/usr/share/locale")
gettext.textdomain(PROGNAME)
try:
    gettext.install(
        PROGNAME,
        localedir="/usr/share/locale",
        unicode=False,
        codeset='utf-8')
except IOError:
    # No usable catalogue: fall back to an identity-like _() so that the
    # _("...") calls in this file keep working.
    import __builtin__
    __builtin__.__dict__['_'] = unicode

class Container:
    """A sandbox container running systemd unit files under libvirt-sandbox.

    The class builds the on-disk content for a container, generates its
    libvirt-sandbox configuration and host systemd unit, and starts, stops
    or attaches to the running container.
    """

    # Only RPM-owned directories below these prefixes are considered for the
    # container (see __extract_rpm_for_unit).
    DEFAULT_DIRS  = [ "/etc", "/var" ]
    # Directory holding one "<name>.sandbox" config file per container.
    CONFIG_PATH   = "/etc/libvirt-sandbox/services/"
    # File inside the container that receives the HOSTNAME= setting.
    NETWORK_CONFIG = "/etc/sysconfig/network"
    SYSINIT_WANTS_PATH = "/usr/lib/systemd/system/sysinit.target.wants"
    SOCKET_WANTS_PATH = "/usr/lib/systemd/system/sockets.target.wants"
    # Directories created under the container root and bind mounted over the
    # same path inside the container.
    BIND_SYSTEM_DIRS   = [ "/etc/pki", "/var/tmp", "/var/log", "/var/lib/dhclient", "/var/lib/dbus", "/home", "/root", "/var/spool", "/var/cache", "/etc/systemd/system", "/usr/lib/systemd/system/basic.target.wants", "/usr/lib/systemd/system/local-fs.target.wants", SYSINIT_WANTS_PATH, SOCKET_WANTS_PATH ]
    # Files created (empty) under the container root and bind mounted.
    BIND_SYSTEM_FILES   = [ "/etc/machine-id", "/etc/adjtime", "/etc/fstab", NETWORK_CONFIG ]
    # Maps a "wants" directory to the units symlinked into it ("../<unit>").
    LOCAL_LINK_FILES   = { SYSINIT_WANTS_PATH : [ "systemd-tmpfiles-setup.service" ] , SOCKET_WANTS_PATH : [ "dbus.socket", "systemd-journald.socket", "systemd-shutdownd.socket" ] }

    # SELinux type applied to the container's directory tree.
    SELINUX_FILE_TYPE  = "svirt_lxc_file_t"
    # Default parent directory of container roots.
    DEFAULT_PATH  = "/var/lib/libvirt/filesystems"
    # Image file path template; %s is the container name.
    DEFAULT_IMAGE = "/var/lib/libvirt/images/%s.raw"

    def __init__(self, name=None, uri = "lxc:///", path = DEFAULT_PATH, config=None, create=False):
        """Set up a container object, optionally loading or creating its config.

        name   -- container name; when given it determines the config path
                  and overrides the ``config`` argument.
        uri    -- libvirt connection URI.
        path   -- parent directory under which the container root lives.
        config -- explicit path of a ".sandbox" config file to load.
        create -- when true, start from a fresh service config instead of
                  loading one.

        Raises ValueError if ``create`` is set and the config file already
        exists, and OSError if an existing config cannot be loaded.
        """
        self.uri = uri
        self.clone = False
        self.use_image = False
        self.path = path
        self.conn = None
        self.image = None
        self.size = 10 * MB
        self.config = None
        self.file_type = self.SELINUX_FILE_TYPE
        self.unit_file_list = []
        # These are filled in by create()/delete() and the RPM scan.  Define
        # them up front so every instance has the attributes and a premature
        # access does not fail with AttributeError.
        self.unitfile = None
        self.dest = None
        self.dirs = []
        self.all_dirs = []

        if name:
            config = self.get_config_path(name)

        if create:
            if config and os.path.exists(config):
                raise ValueError(["Container already exists"])

            self.config = LibvirtSandbox.ConfigService.new(name)
            return

        if config:
            try:
                self.config = LibvirtSandbox.Config.load_from_path(config)
                self.unitfile = self.__gen_unit_file_name()
            except gi._glib.GError as e:
                raise OSError(config + ": " + str(e))

    def __follow_units(self):
	unitst=""
	for i in self.unit_file_list:
	    unitst += "#ReloadPropagatedFrom=%s\n" % i

	return unitst

    def __gen_unit_file_name(self):
	if len(self.unit_file_list) != 1:
	    return "/etc/systemd/system/%s.service" % self.config.get_name()
	return "/etc/systemd/system/%s@%s.service" %  (self.unit_file_list[0].split(".")[0], self.config.get_name())

    def add_unit_file(self, unit_file):
	if unit_file not in self.unit_file_list:
	    self.unit_file_list.append(unit_file)

    def delete_unit_file(self, unit_file):
	self.unit_file_list.delete(unit_file)

    def get_unit_file(self, unit_file):
	return self.unit_file_list

    def set_unit_file_list(self, unit_file_list):
	self.unit_file_list = unit_file_list

    def __create_system_unit(self):
	self.unitfile = self.__gen_unit_file_name()
	name = self.config.get_name()
	unit = r"""
[Unit]
Description=Secure Sandbox Container %(NAME)s
Requires=libvirtd.service
After=libvirtd.service
%(FOLLOW)s
[Service]
Type=simple
ExecStart=/usr/bin/virt-sandbox-service start %(NAME)s
ExecReload=/usr/bin/virt-sandbox-service reload %(NAME)s
ExecStop=/usr/bin/virt-sandbox-service stop %(NAME)s

[Install]
WantedBy=multi-user.target
""" % { 'NAME':name, 'FOLLOW':self.__follow_units() }
	fd = open(self.unitfile, "w")
	fd.write(unit)
	fd.close()
	sys.stdout.write("Created unit file %s\n" % self.unitfile)

    def __add_dir(self, newd):
	self.all_dirs.append(newd)
	tmp = self.dirs
	for d in tmp:
	    if newd.startswith(d):
		return
	    if d.startswith(newd):
		self.dirs.remove(d)
		self.dirs.append(newd)
		return
	self.dirs.append(newd)

    def get_config_path(self, name = None):
	if not name:
	    name = self.get_name()
	return self.CONFIG_PATH + name + ".sandbox"

    def get_name(self):
	if self.config:
	    return self.config.get_name()
	raise ValueError(["Name not configured"])

    def set_clone(self, clone):
	self.clone = clone

    def get_dynamic(self):
	return self.config.get_security_dynamic()

    def set_dynamic(self, val):
	return self.config.set_security_dynamic(val)

    def get_security_type(self):
	if self.config:
	    con = self.config.get_security_label().split(':')
	    return con[2]
	else:
	    return "svirt_lxc_net_t"

    def get_security_level(self):
	if self.config:
	    con = self.config.get_security_label().split(':')
	    return ":".join(con[3:])
	else:
	    return "s0"

    def set_security_type(self, security_type):
        """Switch the label's SELinux type, keeping the current level.

        The new label is validated with selinux.security_check_context();
        an OSError from validation or from the setter is re-raised with an
        "Invalid Security Type" message.
        """
        level = self.get_security_level()
        label = "system_u:system_r:%s:%s" % (security_type, level)
        try:
            selinux.security_check_context(label)
            self.config.set_security_label(label)
        except OSError as err:
            raise OSError("Invalid Security Type %s: %s " %  (security_type, err))

    def set_security_level(self, security_level):
        """Switch the label's level, keeping the current SELinux type.

        The new label is validated with selinux.security_check_context();
        an OSError from validation or from the setter is re-raised with an
        "Invalid Security Level" message.
        """
        setype = self.get_security_type()
        label = "system_u:system_r:%s:%s" % (setype, security_level)
        try:
            selinux.security_check_context(label)
            self.config.set_security_label(label)
        except OSError as err:
            raise OSError("Invalid Security Level %s: %s " %  (security_level, err))

    def set_image(self, size):
        """Request an image-backed container of ``size`` megabytes (x MB)."""
        self.size = size * MB
        self.use_image = True

    def set_path(self, path):
	self.path = path

    def set_name(self, name):
	if self.config:
	    raise ValueError(["Cannot modify Name"])
	self.config = LibvirtSandbox.ConfigService.new(name)

    def __extract_rpms(self):
	self.all_dirs = []
	self.dirs = []

	for u in self.unit_file_list:
	    self.__extract_rpm_for_unit(u)

    def __extract_rpm_for_unit(self, unitfile):
        """Add the directories owned by the RPM that ships unitfile.

        The RPM database is queried for the package owning
        /lib/systemd/system/<unitfile>; only the first match is used.  Each
        existing directory of that package under one of DEFAULT_DIRS is
        passed to __add_dir().  A unit owned by no package is ignored.
        """
        try:
            ts = rpm.ts()
            matches = ts.dbMatch(rpm.RPMTAG_BASENAMES, "/lib/systemd/system/" + unitfile)
            header = matches.next()
            for rec in header.fiFromHeader():
                path = rec[0]
                # /var/run is a symbolic link to /run
                if path.startswith("/var/run/"):
                    continue
                for base in self.DEFAULT_DIRS:
                    if path.startswith(base) and os.path.isdir(path):
                        self.__add_dir(path)
        except StopIteration:
            pass

    def __add_mount(self, source, dest):
        """Add a host bind mount of source onto dest to the config."""
        self.config.add_mount(LibvirtSandbox.ConfigMountHostBind.new(source, dest))

    def __gen_filesystems(self):
        """Register every mount of the container in its config.

        In order: the optional image mounted on the container root, RAM
        file systems for /run, /tmp and /dev/shm (with /var/run bound to
        /run), host bind mounts for BIND_SYSTEM_DIRS and BIND_SYSTEM_FILES,
        and finally the RPM-derived directories in self.dirs that are not
        already covered by a BIND_SYSTEM_DIRS entry.
        """
        def covered(parent, child):
            # Whole-component comparison: "/var/log" covers "/var/log/x" but
            # not the sibling "/var/logfoo".
            parent = parent.rstrip("/")
            return child == parent or child.startswith(parent + "/")

        if self.use_image:
            self.image = self.DEFAULT_IMAGE % self.get_name()
            mount = LibvirtSandbox.ConfigMountHostImage.new(self.image, self.dest)
            self.config.add_mount(mount)

        # 10 MB /run
        mount = LibvirtSandbox.ConfigMountRam.new("/run", 10 * 1024 * 1024)
        self.config.add_mount(mount)

        # Link /var/run to /run
        mount = LibvirtSandbox.ConfigMountGuestBind.new("/run", "/var/run")
        self.config.add_mount(mount)

        # 100 MB /tmp
        mount = LibvirtSandbox.ConfigMountRam.new("/tmp", 100 * 1024 * 1024)
        self.config.add_mount(mount)

        # 100 MB /dev/shm
        mount = LibvirtSandbox.ConfigMountRam.new("/dev/shm", 100 * 1024 * 1024)
        self.config.add_mount(mount)

        for d in self.BIND_SYSTEM_DIRS:
            source = "%s%s" % (self.dest, d)
            self.__add_mount(source, d)

        for f in self.BIND_SYSTEM_FILES:
            source = "%s%s" % (self.dest, f)
            self.__add_mount(source, f)

        for d in self.dirs:
            # Don't add dirs whose parents are in BIND_SYSTEM_DIRS
            if any(covered(s, d) for s in self.BIND_SYSTEM_DIRS):
                continue
            source = "%s%s" % (self.dest, d)
            self.__add_mount(source, d)

    def __get_init_path(self):
	return "%s/%s/init" % (self.path, self.get_name())

    def __create_container_unit(self, src, dest, unit):
	    fd = open(dest + "/" + unit, "w")
	    fd.write(""".include %s
[Service]
PrivateTmp=false
PrivateNetwork=false
""" % src )
	    fd.close()

    def __gen_content(self):
	if self.clone:
	    for d in self.dirs:
		shutil.copytree(d, "%s%s" % (self.dest, d), symlinks=True)
	else:
	    for d in self.all_dirs:
		os.makedirs("%s%s" % (self.dest, d))

	for d in self.BIND_SYSTEM_DIRS:
	    try:
		os.makedirs("%s%s" % (self.dest, d))
	    except OSError, e:
		if not e.errno == errno.EEXIST:
		    raise e

	for f in self.BIND_SYSTEM_FILES:
	    try:
		try:
		    os.makedirs("%s%s" % (self.dest, os.path.dirname(f)))
		except OSError, e:
		    if not e.errno == errno.EEXIST:
			raise e
		fd=open("%s%s" % (self.dest, f), "w")
		fd.close()
	    except OSError, e:
		if not e.errno == errno.EEXIST:
		    raise e

	fd=open(self.dest + self.NETWORK_CONFIG, "w")
	fd.write("HOSTNAME=%s\n" % self.config.get_name() )
	fd.close()

	for k in self.LOCAL_LINK_FILES:
	    for d in self.LOCAL_LINK_FILES[k]:
		src = "../%s" % ( d)
		dest = "%s%s/%s" % ( self.dest, k, d)
		os.symlink(src,dest)

	unitdir = self.dest + "/etc/systemd/system"
	try:
	    os.makedirs(unitdir)
	except OSError, e:
	    if not e.errno == errno.EEXIST:
		raise e

	tgtdir = unitdir + "/sandbox.target.wants"
	try:
	    os.makedirs(tgtdir)
	except OSError, e:
	    if not e.errno == errno.EEXIST:
		raise e

	for i in self.unit_file_list:
	    src = "/etc/systemd/system/" + i
	    if not os.path.exists(src):
		src = "/lib/systemd/system/" + i
	    if not os.path.exists(src):
		raise OSError("Requested unit %s does not exist" % i)
	    self.__create_container_unit(src, unitdir, i)
	    os.symlink("../" + i, tgtdir + "/" + i)

	tgtfile = unitdir + "/sandbox.target"
	try:
	    fd = open(tgtfile, "w")
	    fd.write("[Unit]\n")
	    fd.write("Description=Sandbox target\n")
	    fd.close()
	except OSError, e:
	    if not e.errno == errno.EEXIST:
		raise e


    def __umount(self):
        """Unmount the container root; raise OSError if umount fails."""
        proc = Popen(["/bin/umount", self.dest])
        proc.communicate()
        status = proc.returncode
        if status:
            raise OSError("Failed to unmount image %s from %s " %  (self.image, self.dest))

    def __create_image(self):
        """Create the image file, format it as ext4 and mount it on self.dest.

        The file is created sparse by truncating it to self.size bytes.
        Raises OSError if mkfs or mount exits with a non-zero status.
        """
        # "with" closes the handle even if truncate() fails.
        with open(self.image, "w") as fd:
            fd.truncate(self.size)

        p = Popen(["/sbin/mkfs","-F", "-t", "ext4", self.image],stdout=PIPE, stderr=PIPE)
        p.communicate()
        if p.returncode:
            raise OSError("Failed to build image %s" % self.image )

        p = Popen(["/bin/mount", self.image, self.dest])
        p.communicate()
        if p.returncode:
            raise OSError("Failed to mount image %s on %s " %  (self.image, self.dest))

    def save_config(self):
	config = self.get_config_path()
	if os.path.exists(config):
	    os.remove(config)
	self.config.save_to_path(config)
	sys.stdout.write("Created sandbox config %s\n" % config)

    def __get_image_path(self):
	    mounts = self.config.get_mounts()
	    for m in mounts:
		if type(m) != LibvirtSandbox.ConfigMountHostImage:
		    continue

		if m.get_target() == self.dest:
		    return m.get_source()
	    return None

    def delete(self):
        """Stop the container and remove its content, config and unit file.

        Stopping is best effort: a container that is not running (or cannot
        be reached) does not prevent the files from being removed.
        """
        # Stop service if it is running
        try:
            self.stop()
        except Exception:
            # Deliberately ignored, but unlike a bare "except:" this lets
            # KeyboardInterrupt and SystemExit through.
            pass

        self.dest = "%s/%s" % (self.path, self.get_name())
        if os.path.exists(self.dest):
            shutil.rmtree(self.dest)
            mount_path = self.__get_image_path()
            # The image may already have been removed by hand.
            if mount_path and os.path.exists(mount_path):
                os.remove(mount_path)

        config = self.get_config_path()
        if os.path.exists(config):
            os.remove(config)

        # unitfile is only set once a config has been loaded or a unit has
        # been created.
        # NOTE(review): for a loaded container the name is computed with an
        # empty unit list, so a "<unit>@<name>.service" file written by
        # create() is not matched here -- confirm and persist the unit list.
        unitfile = getattr(self, "unitfile", None)
        if unitfile and os.path.exists(unitfile):
            os.remove(unitfile)

    def create(self):
        """Build a new container: content, config file and host unit.

        Raises OSError if the container directory already exists or if
        creating/mounting the image fails.
        """
        self.dest = "%s/%s" % (self.path, self.get_name())
        if os.path.exists(self.dest):
            raise OSError("%s already exists" % self.dest)

        net = LibvirtSandbox.ConfigNetwork().new()
        net.set_dhcp(True)
        self.config.add_network(net)
        self.config.set_shell(True)
        self.config.set_boot_target("sandbox.target")
        self.__extract_rpms()
        self.__gen_filesystems()
        os.mkdir(self.dest)
        sys.stdout.write("Created sandbox container dir %s\n" % self.dest)
        if self.image:
            self.__create_image()
            # The image is mounted on self.dest from here on; make sure it
            # is unmounted again even when generating the content fails.
            try:
                self.__gen_content()
            finally:
                self.__umount()
        else:
            self.__gen_content()
            selinux.chcon(self.dest, "system_u:object_r:%s:%s" % (self.file_type, self.get_security_level()), True)

        self.save_config()
        self.__create_system_unit()

    def reload(self):
	# Crude way of doing this.
	self.stop()
	self.start()
	# self.execute("systemctl reload %s.service" % self.get_name())

    def __connect(self):
        """Open the libvirt connection for self.uri unless already open.

        self.conn is assigned only after open() succeeded, so a failed
        attempt does not leave an unopened connection cached for reuse.
        """
        if self.conn:
            return
        conn = LibvirtGObject.Connection.new(self.uri)
        conn.open(None)
        self.conn = conn

    def __disconnect(self):
	if self.conn:
	    self.conn.close()
	    self.conn = None

    def running(self):
        """Return 1 if the container can be attached to, otherwise 0.

        The libvirt connection is closed again in both cases, so probing
        many containers does not leave connections open.
        """
        try:
            self.__connect()
            context = LibvirtSandbox.ContextService.new(self.conn, self.config)
            context.attach()
            return 1
        except gi._glib.GError:
            return 0
        finally:
            self.__disconnect()

    def start(self):
        """Start the container and block until its log console closes.

        The container's log console is attached to stderr.  Once it closes,
        the console is detached and the container stopped (both best
        effort).  A GError from libvirt is re-raised as OSError.
        """
        def closed(obj, error):
            self.loop.quit()

        try:
            self.__connect()
            context = LibvirtSandbox.ContextService.new(self.conn, self.config)
            context.start()
            console = context.get_log_console()
            # Create the loop before wiring up the handler that quits it.
            self.loop = GLib.MainLoop()
            console.connect("closed", closed)
            console.attach_stderr()
            self.loop.run()
            # Cleanup failures are ignored on purpose; "except Exception"
            # still lets KeyboardInterrupt/SystemExit propagate.
            try:
                console.detach()
            except Exception:
                pass
            try:
                context.stop()
            except Exception:
                pass
        except gi._glib.GError as e:
            raise OSError(str(e))

    def stop(self):
        """Attach to the running container and stop it.

        A GError from libvirt is re-raised as OSError.
        """
        try:
            self.__connect()
            ctx = LibvirtSandbox.ContextService.new(self.conn, self.config)
            ctx.attach()
            ctx.stop()
        except gi._glib.GError as err:
            raise OSError(str(err))

    def connect(self):
        """Attach stdio to the container's shell console until it closes.

        A GError from libvirt is re-raised as OSError.
        """
        def closed(obj, error):
            self.loop.quit()

        try:
            self.__connect()
            context = LibvirtSandbox.ContextService.new(self.conn, self.config)
            context.attach()
            console = context.get_shell_console()
            # Create the loop before wiring up the handler that quits it.
            self.loop = GLib.MainLoop()
            console.connect("closed", closed)
            console.attach_stdio()
            self.loop.run()
            # Detaching is best effort; "except Exception" still lets
            # KeyboardInterrupt/SystemExit propagate.
            try:
                console.detach()
            except Exception:
                pass
        except gi._glib.GError as e:
            raise OSError(str(e))

    def execute(self, command):
        """Placeholder for running command inside the container.

        Not implemented: the shell console is attached, a notice is
        printed, and the console is detached again.  ``command`` is unused.
        A GError from libvirt is re-raised as OSError.
        """
        def closed(obj, error):
            self.loop.quit()

        try:
            self.__connect()
            context = LibvirtSandbox.ContextService.new(self.conn, self.config)
            context.attach()
            console = context.get_shell_console()
            console.connect("closed", closed)
            console.attach_stdio()
            sys.stdout.write("not currently implemented\n")
            console.detach()
            # The main-loop code that used to follow an unconditional
            # return here was unreachable and has been dropped.
        except gi._glib.GError as e:
            raise OSError(str(e))

# Bytes per (decimal) megabyte; used to scale image sizes.
MB = 1000 * 1000

def delete(args):
    """Handle the "delete" sub-command: remove the container args.name.

    The connection URI from -c/--connect is honoured, as in the other
    sub-commands that talk to libvirt.
    """
    container = Container(args.name, args.uri)
    container.delete()

def create(args):
    """Handle the "create" sub-command: build a new container from args.

    Applies the connection URI, content path, clone flag, SELinux
    type/level, unit files and optional image size before creating.
    """
    container = Container(name = args.name, uri = args.uri, create = True)
    # -p/--path defaults to None, meaning "keep Container.DEFAULT_PATH".
    if args.path:
        container.set_path(args.path)
    container.set_clone(args.clone)
    container.set_dynamic(args.dynamic)
    container.set_security_type(args.type)
    container.set_unit_file_list(args.unitfiles)
    if args.level:
        container.set_security_level(args.level)

    if args.imagesize:
        container.set_image(args.imagesize)

    container.create()

def usage(parser, msg):
    """Print the parser's help, then msg on stderr, and exit with status 1."""
    parser.print_help()

    err = sys.stderr
    err.write("\n%s\n" % msg)
    err.flush()
    raise SystemExit(1)

def sandbox_list(args):
    """Handle the "list" sub-command: print the name of every container.

    Every "*.sandbox" file in Container.CONFIG_PATH is loaded; with
    args.running set, only containers that are currently running are
    printed.  Configs that fail to load are reported and skipped.
    """
    import glob
    uri = getattr(args, "uri", "lxc:///")
    for path in glob.glob(Container.CONFIG_PATH + "*.sandbox"):
        try:
            c = Container(uri = uri, config = path)
            if not args.running or c.running():
                sys.stdout.write("%s\n" % c.get_name())
        except OSError as e:
            # The message must be %-formatted; passing the tuple as a second
            # print argument showed the raw format string instead.
            sys.stdout.write("Invalid container %s: %s\n" % (path, e))


def sandbox_reload(args):
    """Handle the "reload" sub-command: restart the container args.name."""
    Container(args.name, args.uri).reload()

def start(args):
    """Handle "start": replace this process with the helper run with -s."""
    helper = "/usr/libexec/virt-sandbox-service-util"
    os.execv(helper, ["virt-sandbox-service-util", "-s", args.name])
#    container = Container(args.name, args.uri)
#    container.start()

def stop(args):
    """Handle "stop": replace this process with the helper run with -S."""
    helper = "/usr/libexec/virt-sandbox-service-util"
    os.execv(helper, ["virt-sandbox-service-util", "-S", args.name])
#    container = Container(args.name, args.uri)
#    container.stop()

def connect(args):
    """Handle "connect": print a banner, then exec the helper with -a."""
    sys.stdout.write("""\
Connected to %s.
Escape character is '^]'.

""" % ( args.name ))
    # exec replaces the process image without flushing stdio buffers, so
    # flush explicitly or the banner is lost when stdout is not a terminal.
    sys.stdout.flush()
    os.execl("/usr/libexec/virt-sandbox-service-util", "virt-sandbox-service-util","-a", args.name)
#    container = Container(args.name, args.uri)
#    container.connect()

#
# Search Path for command to execute within the container.
#
def fullpath(cmd):
    """Resolve cmd to a full path using the PATH environment variable.

    A cmd that starts with "/", "./" or "../" is returned unchanged.
    Otherwise the first executable regular file named cmd in a PATH
    directory is returned; if there is none, cmd is returned as is.
    """
    if cmd.startswith(("/", "./", "../")):
        return cmd
    # Fall back to the platform default search path when PATH is unset
    # instead of failing with KeyError.
    for directory in os.environ.get("PATH", os.defpath).split(':'):
        candidate = "%s/%s" % (directory, cmd)
        # Directories also pass the X_OK test, so require a regular file.
        if os.path.isfile(candidate) and os.access(candidate, os.X_OK):
            return candidate
    return cmd

def execute(args):
    """Handle "execute": run args.command in the container via the helper.

    The first word of the command is resolved with fullpath(); the helper
    is then exec'ed with the command, the pid option and the container
    name.  Raises ValueError if the command is empty or blank.
    """
    myargs = args.command.split()
    if not myargs:
        # Reported through the ValueError handler of the main block rather
        # than crashing with IndexError on myargs[0].
        raise ValueError(["No command specified"])
    myargs[0] = fullpath(myargs[0])
    cmd = " ".join(myargs)
    myexec = [ "virt-sandbox-service-util", "-e", cmd, "-p", args.pid, args.name ]
    os.execv("/usr/libexec/virt-sandbox-service-util", myexec)

def clone(args):
    """Handle "clone": save the source container's config under a new name.

    NOTE(review): Container.set_name() raises ValueError whenever a config
    is already loaded, which is the case for an existing source container,
    so this looks like it cannot succeed as written -- confirm.
    """
    source = Container(args.source, args.uri)
    level = args.level
    if level:
        source.set_security_level(level)
    source.set_name(args.dest)
    source.save_config()

import argparse
class SizeAction(argparse.Action):
    """argparse action that stores the option value as a non-negative int."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Convert values to int and store it on namespace under self.dest.

        Invalid input is reported through parser.error() (usage message,
        exit status 2) instead of leaking a raw ValueError to the caller.
        """
        try:
            size = int(values)
        except (TypeError, ValueError):
            parser.error("invalid size %r: an integer is required" % (values,))
        if size < 0:
            parser.error("invalid size %r: must not be negative" % (values,))
        setattr(namespace, self.dest, size)

if __name__ == '__main__':
    # A bare Container supplies the defaults quoted in the help texts.
    c = Container()

    parser = argparse.ArgumentParser(description='Sandbox Container Tool')
    parser.add_argument("-c", "--connect", required=False, dest="uri",  default="lxc:///",
                        help=_("libvirt connection URI to use (lxc:/// [default] or qemu:///session)"))

    subparsers = parser.add_subparsers(help=_("commands"))

    # create
    parser_create = subparsers.add_parser("create",
                                          help=_("create a sandbox container"))

    parser_create.add_argument("-u", "--unitfile", required=True,
                               dest="unitfiles",  action='append',
                               help=_("Systemd Unit file to run within the sandbox container"))
    parser_create.add_argument("-p", "--path", dest="path",  default=None,
                               help=_("select path to store sandbox content.  Default: %s") % c.DEFAULT_PATH)
    parser_create.add_argument("-t", "--type", dest="type",
                               default=c.get_security_type(),
                               help=_("SELinux type with which to run the sandbox.  Default: %s") % c.get_security_type())
    # A fixed MCS level and dynamic labelling exclude each other.
    group = parser_create.add_mutually_exclusive_group()
    group.add_argument("-l", "--level", dest="level", default=None,
                       help=_("MCS Level with which to run the sandbox"))
    group.add_argument("-d", "--dynamic", dest="dynamic", default=False,
                       action="store_true",
                       help=_("use dynamic MCS labeling for the sandbox"))
    image = parser_create.add_argument_group("Create On Disk Image File")

    image.add_argument("-i", "--imagesize", dest="imagesize", default = None,
                       action=SizeAction,
                       help=_("create image of this many megabytes."))
    parser_create.add_argument("-C", "--clone", default=False,
                               action="store_true",
                               help=_("clone content from /etc and /var directories that will be mounted within the sandbox"))

    parser_create.set_defaults(func=create)
    parser_create.add_argument("name",
                               help=_("name of the sandbox container"))

    # connect
    parser_connect = subparsers.add_parser("connect",
                                           help=_("Connect to a sandbox container"))
    parser_connect.set_defaults(func=connect)
    parser_connect.add_argument("name",
                                help=_("name of the sandbox container"))

    # clone
    parser_clone = subparsers.add_parser("clone",
                                         help=_("Clone an existing sandbox container"))
    parser_clone.set_defaults(func=clone)
    parser_clone.add_argument("-l", "--level", dest="level", default=None,
                              help=_("MCS Level with which to run the sandbox"))
    parser_clone.add_argument("-s", "--source", required=True,
                              dest="source", default=None,
                              help=_("source sandbox container name"))
    parser_clone.add_argument("dest",
                              help=_("dest name of the new sandbox container"))

    # execute
    parser_execute = subparsers.add_parser("execute",
                                           help=_("Execute a command within a sandbox container"))
    parser_execute.set_defaults(func=execute)
    parser_execute.add_argument("-C", "--command", required=True,
                                dest="command", default=None,
                                help=_("command to execute within the sandbox"))
    # This option should be removed as soon as we can query libvirt for the
    # pid of libvirt_lxc or systemd associated with the container
    parser_execute.add_argument("-p", "--pid", required=True,
                                dest="pid", default=None,
                                help=_("pid file of systemd"))
    parser_execute.add_argument("name",
                                help=_("name of the sandbox container"))

    # delete
    parser_delete = subparsers.add_parser("delete",
                                          help=_("Delete a sandbox container"))
    parser_delete.set_defaults(func=delete)
    parser_delete.add_argument("name",
                               help=_("name of the sandbox container"))

    # stop
    parser_stop = subparsers.add_parser("stop",
                                        help=_("Stop a sandbox container"))
    parser_stop.set_defaults(func=stop)
    parser_stop.add_argument("name",
                             help=_("name of the sandbox container"))

    # start
    parser_start = subparsers.add_parser("start",
                                         help=_("Start a sandbox container"))
    parser_start.set_defaults(func=start)
    parser_start.add_argument("name",
                              help=_("name of the sandbox container"))

    # reload
    parser_reload = subparsers.add_parser("reload",
                                          help=_("Reload a running sandbox container"))
    parser_reload.set_defaults(func=sandbox_reload)
    parser_reload.add_argument("name",
                               help=_("name of the sandbox container"))

    # list
    parser_list = subparsers.add_parser("list",
                                        help=_("List all sandbox container"))
    parser_list.set_defaults(func=sandbox_list)
    parser_list.add_argument("-r", "--running", dest="running",
                             default=False, action="store_true",
                             help=_("list only running sandbox containers"))

    try:
        args = parser.parse_args()
        args.func(args)
        sys.exit(0)
    except KeyboardInterrupt:
        sys.exit(0)
    except ValueError as e:
        # Errors raised in this file carry a list of messages as their only
        # argument; ValueErrors from elsewhere carry a plain string, which
        # must not be iterated character by character.
        for arg in e.args:
            if isinstance(arg, (list, tuple)):
                messages = arg
            else:
                messages = [arg]
            for message in messages:
                sys.stderr.write("%s: %s\n" % (sys.argv[0], message))
        sys.stderr.flush()
        sys.exit(1)
    except (OSError, IOError) as e:
        # open() failures are IOError, which is distinct from OSError in
        # Python 2; report both the same way instead of a traceback.
        sys.stderr.write("%s: %s\n" % (sys.argv[0], e))
        sys.stderr.flush()
        sys.exit(1)
