#!/usr/bin/env python

# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as
# published by the Free Software Foundation; either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Copyright 2012, Andy Grover <agrover@redhat.com>
#
# A server that exposes a network interface for the LIO
# kernel target.

import os
import sys
import contextlib
import setproctitle
from rtslib import (Target, TPG, NodeACL, FabricModule, BlockStorageObject,
                    NetworkPortal, LUN, MappedLUN, RTSLibError, RTSLibNotInCFS)
import lvm
import json
from BaseHTTPServer import BaseHTTPRequestHandler, HTTPServer
from SocketServer import ThreadingMixIn
import socket
from threading import Lock
import yaml
import time
from targetcli import UIRoot
from configshell import ConfigShell
import ssl

setproctitle.setproctitle("targetd")

config_path = "/etc/target/targetd.yaml"

# Built-in settings. Keys missing from the YAML config file are filled in
# from here. There is deliberately no default password: one must be set in
# the config file.
default_config = {
    "pool_name": "vg-targetd",
    "user": "admin",
    "target_name": "iqn.2003-01.org.linux-iscsi.%s:targetd" % socket.gethostname(),
    "ssl": False,
    "ssl_cert": "/etc/target/targetd_cert.pem",
    "ssl_key": "/etc/target/targetd_key.pem",
}

config = {}
if os.path.isfile(config_path):
    config = yaml.load(open(config_path).read())
    if config == None:
        config = {}

for key, value in default_config.iteritems():
    if key not in config:
        config[key] = value

if os.getuid() != 0:
    print "targetd must run as root."
    sys.exit(-1)

if not config.get('password', None):
    print "password not set in %s, aborting" % config_path
    sys.exit(-1)

# fail early if can't access vg
test_vg = lvm.vgOpen(config['pool_name'], "w")
test_vg.close()

#
# Helper function to check/close vg for us.
#
@contextlib.contextmanager
def vgopen(pool_name):
    """
    Context manager yielding an open lvm volume group handle.

    pool_name is first validated with pool_check(). The vg is opened with
    mode "w" and its handle is closed when the with-block exits, whether
    normally or via an exception.
    """
    pool_check(pool_name)
    vg_handle = lvm.vgOpen(pool_name, "w")
    try:
        yield vg_handle
    finally:
        vg_handle.close()

def pool_check(pool_name):
    '''
    Reject any pool other than the one targetd is configured to manage.

    pool_name comes from the client and must not be trusted. Every function
    taking a pool argument has to call this, either directly or via
    vgopen(). Raises IOError("Invalid pool name") on mismatch.
    '''
    configured_pool = config['pool_name']
    if pool_name == configured_pool:
        return
    raise IOError("Invalid pool name")

def volumes(req, pool):
    """
    List the logical volumes of pool.

    Returns a list of dicts with keys name, size and uuid, one per LV.
    """
    with vgopen(pool) as vg:
        return [dict(name=lv.getName(), size=lv.getSize(), uuid=lv.getUuid())
                for lv in vg.listLVs()]

def create(req, pool, name, size):
    """
    Create a linear logical volume called name in pool.

    size is passed through int(). It is in the same units that
    lv.getSize() reports, because copy() feeds getSize() straight in here.
    """
    with vgopen(pool) as vg:
        vg.createLvLinear(name, int(size))

def destroy(req, pool, name):
    """
    Remove logical volume name from pool.

    Raises ValueError if a LUN of TPG 1 on the configured iSCSI target still
    uses the storage object "<pool>:<name>". When that target/TPG does not
    exist (RTSLibNotInCFS), nothing can be exported and removal proceeds.
    """
    so_name = "%s:%s" % (pool, name)
    try:
        tpg = TPG(Target(FabricModule('iscsi'), config['target_name'],
                         mode='lookup'),
                  1, mode='lookup')
        exported = any(lun.storage_object.name == so_name for lun in tpg.luns)
    except RTSLibNotInCFS:
        exported = False

    if exported:
        raise ValueError("Volume '%s' cannot be removed while exported" % name)

    with vgopen(pool) as vg:
        vg.lvFromName(name).remove()

def copy(req, pool, vol_orig, vol_new, timeout=10):
    """
    Create vol_new in pool as a block-for-block copy of vol_orig.

    vol_new is created with vol_orig's size. It is then filled by reading
    /dev/<pool>/<vol_orig> in 1 MiB chunks.

    If copying runs past `timeout` seconds of wall-clock time, the request
    switches to async mode via req.async_completion(). Progress is then
    published with async_status(). complete_if_async() is called once the
    copy succeeds.

    On any failure the partially written vol_new is destroyed and the
    exception is re-raised. If the request had already gone async, the
    failure is first recorded with async_status(req, -1) so pollers can
    see it.
    """
    with vgopen(pool) as vg:
        copy_size = vg.lvFromName(vol_orig).getSize()

    create(req, pool, vol_new, copy_size)

    try:
        src_path = "/dev/%s/%s" % (pool, vol_orig)
        dst_path = "/dev/%s/%s" % (pool, vol_new)

        # Use time.time(), not time.clock(). On Unix clock() returns CPU
        # time, which advances far more slowly than real time while the
        # process is blocked on disk I/O, so the timeout fired late or never.
        deadline = time.time() + timeout
        with open(src_path, 'rb') as fsrc:
            with open(dst_path, 'wb') as fdst:
                copied = 0
                while copied != copy_size:
                    buf = fsrc.read(1024 * 1024)
                    if not buf:
                        break
                    fdst.write(buf)
                    copied += len(buf)
                    if time.time() > deadline:
                        # async_completion() only replies on its first call
                        req.async_completion()
                        async_status(req, 0, int((float(copied) / copy_size) * 100))
        complete_if_async(req, 0)

    except:
        if req.async_id:
            # the client was already answered; leave an error for pollers
            async_status(req, -1)
        # destroy() takes (req, pool, name). The old call omitted pool,
        # raising TypeError here, hiding the real error and leaking vol_new.
        destroy(req, pool, vol_new)
        raise

def export_list(req):
    """
    Describe every mapped LUN in TPG 1 of the configured iSCSI target.

    Returns a list of dicts with keys initiator_wwn, lun, vol_name, pool,
    vol_uuid and vol_size. Returns [] when the target or TPG does not exist.
    """
    try:
        tpg = TPG(Target(FabricModule('iscsi'), config['target_name'],
                         mode='lookup'),
                  1, mode='lookup')
    except RTSLibNotInCFS:
        return []

    result = []
    for acl in tpg.node_acls:
        for mapped in acl.mapped_luns:
            # udev_path is expected to be /dev/<vg>/<lv>, which is how
            # export_create() builds it
            vg_name, lv_name = \
                mapped.tpg_lun.storage_object.udev_path.split("/")[2:]
            with vgopen(vg_name) as vg:
                lv = vg.lvFromName(lv_name)
                result.append(dict(initiator_wwn=acl.node_wwn,
                                   lun=mapped.mapped_lun,
                                   vol_name=lv_name, pool=vg_name,
                                   vol_uuid=lv.getUuid(),
                                   vol_size=lv.getSize()))
    return result

#
# HACK: call targetcli saveconfig method to save state
#
def _exports_save_config():
    """Persist the live LIO configuration using targetcli's saveconfig command."""
    UIRoot(ConfigShell(), as_root=True).ui_command_saveconfig()

def export_create(req, pool, vol, initiator_wwn, lun):
    """
    Export volume vol of pool to initiator_wwn as mapped LUN number lun.

    Objects are created, or reused where they already exist:
    - the block storage object "<pool>:<vol>" (its wwn set to the LV uuid);
    - the configured iSCSI target and TPG 1 (enabled, "authentication"
      attribute set to 0, network portal on 0.0.0.0);
    - the initiator's NodeACL;
    - a TPG LUN for the storage object;
    - the initiator's mapped LUN.

    The configuration is then saved with targetcli.
    """
    # get wwn of volume so LIO can export as vpd83 info
    with vgopen(pool) as vg:
        vol_serial = vg.lvFromName(vol).getUuid()

    # only add new SO if it doesn't exist
    # so.name concats pool & vol names separated by ':'
    so_name = "%s:%s" % (pool, vol)
    try:
        so = BlockStorageObject(so_name)
    except RTSLibError:
        so = BlockStorageObject(so_name, dev="/dev/%s/%s" % (pool, vol))
        so.wwn = vol_serial

    fm = FabricModule('iscsi')
    t = Target(fm, config['target_name'])
    tpg = TPG(t, 1)
    tpg.enable = True
    tpg.set_attribute("authentication", 0)
    NetworkPortal(tpg, "0.0.0.0")
    na = NodeACL(tpg, initiator_wwn)

    # only add tpg lun if it doesn't exist
    for tmp_lun in tpg.luns:
        if tmp_lun.storage_object.name == so.name \
                and tmp_lun.storage_object.plugin == 'block':
            tpg_lun = tmp_lun
            break
    else:
        tpg_lun = LUN(tpg, storage_object=so)

    # Only add the mapped lun if *this initiator* doesn't already have it.
    # tpg_lun.mapped_luns covers the mappings of every NodeACL in the TPG
    # (export_destroy relies on that before deleting a tpg lun). Matching on
    # the LUN number alone would therefore skip this initiator's mapping
    # whenever another initiator uses the same number for this volume.
    already_mapped = any(
        mlun.mapped_lun == lun and mlun.parent_nodeacl.node_wwn == na.node_wwn
        for mlun in tpg_lun.mapped_luns)
    if not already_mapped:
        MappedLUN(na, lun, tpg_lun)

    _exports_save_config()

def export_destroy(req, pool, vol, initiator_wwn):
    """
    Remove the export of pool/vol to initiator_wwn.

    Looks up, and never creates, the configured iSCSI target, TPG 1 and the
    initiator's NodeACL. It deletes the mapped LUN that points at
    /dev/<pool>/<vol>, then prunes whatever became empty: the TPG LUN and its
    storage object, the NodeACL, the TPG and the target. Finally the
    configuration is saved.

    Raises LookupError if the volume is not exported to that initiator.
    """
    pool_check(pool)
    not_found_msg = "Volume '%s' not found in %s exports" % (vol, initiator_wwn)

    # Lookup mode (as destroy()/export_list() use): without it the target,
    # TPG and ACL would be created just to report that the export doesn't
    # exist, leaving stray objects behind in configfs.
    try:
        fm = FabricModule('iscsi')
        t = Target(fm, config['target_name'], mode='lookup')
        tpg = TPG(t, 1, mode='lookup')
        na = NodeACL(tpg, initiator_wwn, mode='lookup')
    except RTSLibNotInCFS:
        raise LookupError(not_found_msg)

    for mlun in na.mapped_luns:
        # all SOs are Block so we can access udev_path safely
        mlun_vg, mlun_name = mlun.tpg_lun.storage_object.udev_path.split("/")[2:]

        if mlun_vg == pool and mlun_name == vol:
            tpg_lun = mlun.tpg_lun
            mlun.delete()
            # be tidy: drop the tpg lun and its storage object once no
            # mapping uses it any more
            if not list(tpg_lun.mapped_luns):
                so = tpg_lun.storage_object
                tpg_lun.delete()
                so.delete()
            break
    else:
        raise LookupError(not_found_msg)

    # Clean up tree if branch has no leaf
    if not list(na.mapped_luns):
        na.delete()
        if not list(tpg.node_acls):
            tpg.delete()
            if not list(t.tpgs):
                t.delete()

    _exports_save_config()

def pools(req):
    """
    Describe the storage pools targetd manages.

    Only the single configured volume group is supported, so this returns a
    one-element list holding a dict with keys name, size and free_size.
    """
    with vgopen(config['pool_name']) as vg:
        info = dict(name=vg.getName(), size=vg.getSize(),
                    free_size=vg.getFreeSize())
    return [info]

def async_list(req):
    '''
    Return a snapshot dict of async_id -> (code, pct_complete) for ongoing
    processes.

    Entries whose code is non-zero (an error, recorded as -err via
    async_status()) are returned once and then delisted.
    '''
    with long_op_status_lock:
        status_dict = long_op_status.copy()
        # Walk the snapshot's items. Iterating the dict itself yields bare
        # keys, so the old `for key, item in long_op_status` raised as soon
        # as any op existed. It also avoids mutating the dict being iterated.
        for key, item in status_dict.items():
            # error codes are negative (-err), so test non-zero, not > 0
            if item[0] != 0:
                del long_op_status[key]
    return status_dict


async_id_lock = Lock()
async_id = 100

def new_async_id():
    """
    Return the next async operation id, starting at 100.

    The module-level counter is read and advanced under async_id_lock, so
    concurrent request threads never receive the same id.
    """
    global async_id
    with async_id_lock:
        allocated, async_id = async_id, async_id + 1
    return allocated

# Long-running threads update their progress here
long_op_status_lock = Lock()
# Maps async_id -> (code, pct_complete); guarded by long_op_status_lock.
long_op_status = {}

def async_status(req, code, pct_complete=None):
    '''
    Record the latest status of req's async operation.

    code: 0 while all is well, otherwise -err
    pct_complete: integer percentage 0-100, or None
    The pair is stored in long_op_status under req.async_id.
    '''
    entry = (code, pct_complete)
    with long_op_status_lock:
        long_op_status[req.async_id] = entry

def complete_if_async(req, code):
    '''
    Mark req's operation as finished, if it went async.

    On success (code == 0) its entry is dropped from long_op_status. On
    failure the code is recorded, keeping the last reported percentage, so
    async_list() can report it. Does nothing if req never went async.
    '''
    if not req.async_id:
        return
    with long_op_status_lock:
        if code:
            last = long_op_status.get(req.async_id, (0, None))
            long_op_status[req.async_id] = (code, last[1])
        else:
            # pop(): an op may have gone async without ever reporting status
            long_op_status.pop(req.async_id, None)


# JSON-RPC method name -> implementation. TargetHandler.do_POST calls each
# as func(handler, **params), or func(handler) when params are empty.
mapping = {
    'vol_list': volumes,
    'vol_create': create,
    'vol_destroy': destroy,
    'vol_copy': copy,
    'export_list': export_list,
    'export_create': export_create,
    'export_destroy': export_destroy,
    'pool_list': pools,
    'async_list': async_list,
}


class TargetHandler(BaseHTTPRequestHandler):
    """
    HTTP handler serving targetd's JSON-RPC 2.0 API on POST /targetrpc.

    Requests must carry HTTP Basic credentials that match config['user'] and
    config['password']. The body's "method" is dispatched through the
    module-level `mapping`, with this handler passed as the first argument.
    A long-running method may call async_completion() to answer early with
    an async id; do_POST then writes nothing more.
    """

    def log_request(self, code='-', size='-'):
        # override base class - don't log good requests
        pass

    def do_POST(self):
        """Authenticate, parse and dispatch one JSON-RPC request."""

        self.async_id = None

        # get basic auth string, strip "Basic "
        try:
            auth64 = self.headers.getheader("Authorization")[6:]
            # split on the first ':' only: passwords may contain colons
            in_user, in_pass = auth64.decode('base64').split(":", 1)
        except Exception:
            self.send_error(400)
            return

        # NOTE(review): != is not constant-time; consider
        # hmac.compare_digest (Python 2.7.7+).
        if in_user != config['user'] or in_pass != config['password']:
            self.send_error(401)
            return

        if not self.path == "/targetrpc":
            self.send_error(404)
            return

        # Send headers before anything can fail. The error path below also
        # writes a JSON body; previously a parse error produced that body
        # with no HTTP status line or headers at all.
        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()

        error = (-1, "jsonrpc error")
        self.id = None
        try:
            try:
                # TypeError: missing content-length header (int(None))
                content_len = int(self.headers.getheader('content-length'))
                req = json.loads(self.rfile.read(content_len))
            except (TypeError, ValueError):
                # see http://www.jsonrpc.org/specification for errcodes
                # (this used to assign a misspelled `errcode` variable)
                error = (-32700, "parse error")
                raise

            try:
                version = req['jsonrpc']
                if version != "2.0":
                    raise ValueError
                method = req['method']
                self.id = int(req['id'])
                params = req.get('params', None)
            except (KeyError, ValueError, TypeError):
                # TypeError: body is valid JSON but not an object, or id is null
                error = (-32600, "not a valid jsonrpc-2.0 request")
                raise

            # Look the method up separately so a KeyError raised *inside* a
            # method is not misreported as "method not found".
            try:
                func = mapping[method]
            except (KeyError, TypeError):
                error = (-32601, "method %s not found" % method)
                raise

            try:
                if params:
                    result = func(self, **params)
                else:
                    result = func(self)
            except TypeError:
                error = (-32602, "invalid method parameter(s)")
                raise
            except Exception as e:
                error = (-1, "%s: %s" % (type(e).__name__, e))
                raise

            rpcdata = json.dumps(dict(result=result, id=self.id))

        except:
            # Top-level boundary: every failure becomes a JSON-RPC error
            # reply, so rpcdata is always bound for the finally clause.
            rpcdata = json.dumps(dict(error=dict(code=error[0], message=error[1]), id=self.id))
        finally:
            # once async, the reply was already sent by async_completion()
            if not self.async_id:
                self.wfile.write(rpcdata)

    def async_completion(self):
        """
        Switch this request to asynchronous mode; only the first call acts.

        Allocates an async id and immediately replies with a JSON-RPC error
        whose code is that id and whose message is "Async Operation". It then
        shuts down the write side of the connection so the client gets the
        reply now.
        """
        if not self.async_id:
            self.async_id = new_async_id()
            rpcdata = json.dumps(dict(error=dict(code=self.async_id, message="Async Operation"), id=self.id))
            self.wfile.write(rpcdata)
            # wfile is buffered, need to do this to flush the response
            self.connection.shutdown(socket.SHUT_WR)

class ThreadedHTTPServer(ThreadingMixIn, HTTPServer, object):
    """
    HTTPServer that handles each request on its own thread.

    ThreadingMixIn comes first in the bases so its process_request wins.
    The trailing `object` base makes this a new-style class under Python 2.
    """

class TLSThreadedHTTPServer(ThreadedHTTPServer):
    """Threaded HTTP server that wraps each accepted connection in TLS."""

    def finish_request(self, sock, addr):
        # Wrapping happens here rather than at accept time. With
        # ThreadingMixIn, finish_request runs on the per-request thread, so
        # the server-side TLS handshake does not hold up accept().
        tls_options = dict(
            server_side=True,
            keyfile=config["ssl_key"],
            certfile=config["ssl_cert"],
            ciphers="HIGH:-aNULL:-eNULL:-PSK",
            suppress_ragged_eofs=True,
        )
        tls_sock = ssl.wrap_socket(sock, **tls_options)
        return self.RequestHandlerClass(tls_sock, addr, self)


if config['ssl']:
    server_class = TLSThreadedHTTPServer
    note = "(TLS yes)"
else:
    server_class = ThreadedHTTPServer
    note = "(TLS no)"

try:
    server = server_class(('', 18700), TargetHandler)
    print "started server", note
    server.serve_forever()
except KeyboardInterrupt:
    print "SIGINT received, shutting down"
    server.socket.close()
