#!/usr/bin/python3.3

# Copyright (C) 2010, 2011, 2012  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""
Statistics daemon in BIND 10

"""
import sys; sys.path.append ('/usr/lib/python3.3/site-packages')
import os
from time import time, strftime, gmtime
from optparse import OptionParser, OptionValueError
import errno
import select

import isc
import isc.util.process
import isc.log
from isc.log_messages.stats_messages import *

# Module-level setup.  Everything below runs at import time, in this
# order: logging is initialised, the module logger is created, and the
# process is renamed before any Stats object exists.
isc.log.init("b10-stats", buffer=True)
logger = isc.log.Logger("stats")

# Some constants for debug levels.
# All debug messages emitted by this module use this single level.
DBG_STATS_MESSAGING = logger.DBGLVL_COMMAND

# This is for boot_time of Stats
# Captured once at import time; Stats.__init__ formats it with
# get_datetime() and reports it as the 'boot_time' item.
_BASETIME = gmtime()

# for setproctitle
isc.util.process.rename()

# If B10_FROM_SOURCE is set in the environment, we use data files
# from a directory relative to that, otherwise we use the ones
# installed on the system
if "B10_FROM_SOURCE" in os.environ:
    # <B10_FROM_SOURCE>/src/bin/stats/stats.spec
    SPECFILE_LOCATION = os.environ["B10_FROM_SOURCE"] + os.sep + \
        "src" + os.sep + "bin" + os.sep + "stats" + os.sep + "stats.spec"
else:
    # NOTE(review): PREFIX, DATAROOTDIR and the replace() calls below
    # look like residue of build-time template substitution.  The
    # literal assigned to SPECFILE_LOCATION contains neither
    # "${datarootdir}" nor "${prefix}", so both replace() calls leave
    # it unchanged.
    PREFIX = "/usr"
    DATAROOTDIR = "${prefix}/share"
    SPECFILE_LOCATION = "/usr/share" + os.sep + "bind10" + os.sep + "stats.spec"
    SPECFILE_LOCATION = SPECFILE_LOCATION.replace("${datarootdir}", DATAROOTDIR)\
        .replace("${prefix}", PREFIX)

def get_timestamp():
    """
    Return the current time as the floating point number given by
    time.time().
    """
    now = time()
    return now

def get_datetime(gmt=None):
    """
    Format a time as "YYYY-MM-DDThh:mm:ssZ".

    'gmt' is a time value accepted by time.strftime(); when it is
    omitted (or otherwise false) the current time from gmtime() is
    used instead.
    """
    moment = gmt if gmt else gmtime()
    return strftime("%Y-%m-%dT%H:%M:%SZ", moment)

def get_spec_defaults(spec):
    """
    Extract the default value of every item in 'spec'.

    'spec' is a list of item specifications (dicts having at least
    'item_name' and 'item_type').  The return value is a dict which
    maps each item name to its default value.  If 'spec' is not a
    list, an empty dict is returned.

    For an item without 'item_default' the default depends on its
    type: 0 for "integer", 0.0 for "real", False for "boolean", "" for
    "string", a one-element list built from 'list_item_spec' for
    "list", a dict built from 'map_item_spec' for "map", {} for
    "named_set" and None for anything else.
    """
    if not isinstance(spec, list):
        return {}

    def _get_spec_defaults(spec):
        """Return the default value of the single item 'spec'."""
        item_type = spec['item_type']
        if item_type == "integer":
            return int(spec.get('item_default', 0))
        elif item_type == "real":
            return float(spec.get('item_default', 0.0))
        elif item_type == "boolean":
            return bool(spec.get('item_default', False))
        elif item_type == "string":
            return str(spec.get('item_default', ""))
        elif item_type == "list":
            # Only descend into the sub-spec when no explicit default
            # exists; an explicit default must not depend on the
            # presence (or validity) of 'list_item_spec'.
            if "item_default" in spec:
                return spec["item_default"]
            return [_get_spec_defaults(spec["list_item_spec"])]
        elif item_type == "map":
            # Same as for "list": the sub-spec is a fallback only.
            if "item_default" in spec:
                return spec["item_default"]
            return dict((s["item_name"], _get_spec_defaults(s))
                        for s in spec["map_item_spec"])
        elif item_type == "named_set":
            # in a named_set type, it returns {} as a default value
            return spec.get("item_default", {})
        else:
            return spec.get("item_default", None)

    return dict((s['item_name'], _get_spec_defaults(s)) for s in spec)

def _accum(a, b):
    """If the first arg is dict or list type, two values
    would be merged and accumlated. This is for internal use."""

    # If both of args are dict or list type, two
    # values are merged.
    if type(a) is dict and type(b) is dict:
        return dict([ (k, _accum(v, b[k])) \
                          if k in b else (k, v) \
                          for (k, v) in a.items() ] \
                        + [ (k, v) \
                                for (k, v) in b.items() \
                                if k not in a ])
    elif type(a) is list and type(b) is list:
        return [ _accum(a[i], b[i]) \
                     if len(b) > i else a[i] \
                     for i in range(len(a)) ] \
                     + [ b[i] \
                             for i in range(len(b)) \
                             if len(a) <= i ]
    # If both of args are integer or float type, two
    # values are added.
    elif (type(a) is int and type(b) is int) \
            or (type(a) is float or type(b) is float):
        return a + b

    # If both of args are string type,
    # values are compared and bigger one is returned.
    elif type(a) is str and type(b) is str:
        if a < b: return b
        return a

    # If the first arg is None type, the second value is returned.
    elif a is None:
        return b

    # Nothing matches above, the first arg is returned
    return a

def merge_oldnew(old, new):
    """
    Return the result of overlaying 'new' on top of 'old'.

    Two dicts are merged key by key and two lists index by index,
    recursively: where both sides have an entry the merged entry is
    used, entries found on one side only are kept as they are.  For
    any other combination of types 'new' simply replaces 'old'.
    Neither argument is modified.
    """
    if type(old) is dict and type(new) is dict:
        merged = {}
        # entries of the new data first, merged with the old ones
        for key, value in new.items():
            if key in old:
                merged[key] = merge_oldnew(old[key], value)
            else:
                merged[key] = value
        # then whatever exists in the old data only
        for key, value in old.items():
            if key not in new:
                merged[key] = value
        return merged

    if type(old) is list and type(new) is list:
        overlap = [merge_oldnew(o, n) for (o, n) in zip(old, new)]
        # at most one of the two tails is non-empty
        return overlap + new[len(old):] + old[len(new):]

    return new

class Callback():
    """
    A Callback handler class

    It stores a callable together with default positional and keyword
    arguments.  Calling the object invokes the stored callable; the
    stored defaults are used for whichever of the positional/keyword
    arguments is not supplied at call time.
    """
    def __init__(self, command=None, args=(), kwargs=None):
        """
        command -- the callable to invoke (may be None)
        args    -- default positional arguments
        kwargs  -- default keyword arguments; None means "none"
        """
        self.command = command
        self.args = args
        # A fresh dict per instance: a shared mutable default would
        # leak changes made through one instance's 'kwargs' attribute
        # into every other instance created without kwargs.
        self.kwargs = {} if kwargs is None else kwargs

    def __call__(self, *args, **kwargs):
        """
        Invoke the stored command and return its result.  Returns None
        if no command is stored.
        """
        if not args: args = self.args
        if not kwargs: kwargs = self.kwargs
        if self.command: return self.command(*args, **kwargs)

class StatsError(Exception):
    """Error type raised by the Stats class and caught by the
    start-up code of this module."""

class Stats:
    """
    Main class of stats module
    """
    def __init__(self):
        """
        Set up the session with the command channel, register one
        callback per command found in the module spec, load the
        configuration and initialise the statistics data.

        Raises StatsError if the spec names a command for which this
        class has no 'command_<name>' method.
        """
        self.running = False
        # create ModuleCCSession object
        self.mccs = isc.config.ModuleCCSession(SPECFILE_LOCATION,
                                               self.config_handler,
                                               self.command_handler)
        self.cc_session = self.mccs._session
        # get module spec
        self.module_name = self.mccs.get_module_spec().get_module_name()
        self.modules = {}
        self.statistics_data = {}
        # statistics data by each mid
        self.statistics_data_bymid = {}
        # get commands spec
        self.commands_spec = self.mccs.get_module_spec().get_commands_spec()
        # map "command_<name>" to a Callback wrapping the method of the
        # same name, with the spec defaults as keyword arguments
        self.callbacks = {}
        for command_spec in self.commands_spec:
            handler_name = "command_" + command_spec["command_name"]
            try:
                method = getattr(self, handler_name)
                defaults = get_spec_defaults(command_spec["command_args"])
                self.callbacks[handler_name] = Callback(command=method,
                                                        kwargs=defaults)
            except AttributeError:
                raise StatsError(STATS_UNKNOWN_COMMAND_IN_SPEC,
                                 command_spec["command_name"])
        # config_handler may run while the session starts, so the
        # attribute has to exist beforehand
        self.config = {}
        self.mccs.start()
        # setup my config
        self.config = {
            item['item_name']: self.mccs.get_value(item['item_name'])[0]
            for item in self.mccs.get_module_spec().get_config_spec()
        }
        # set a absolute timestamp polling at next time
        self.next_polltime = get_timestamp() + self.get_interval()
        # initialized Statistics data
        self.update_modules()
        own_lname = self.cc_session.lname
        initial = {'lname': own_lname,
                   'boot_time': get_datetime(_BASETIME),
                   'last_update_time': get_datetime()}
        if self.update_statistics_data(self.module_name, own_lname, initial):
            logger.warn(STATS_RECEIVED_INVALID_STATISTICS_DATA,
                        self.module_name)
        # define the variable of the last time of polling
        self._lasttime_poll = 0.0

    def get_interval(self):
        """Return the value currently configured for 'poll-interval'."""
        interval = self.config['poll-interval']
        return interval

    def do_polling(self):
        """Poll every known module for its statistics data.

        The method first asks Init ('show_processes') which processes
        are running so that multiple instances of one module can be
        counted, then sends 'getstats' to each module, collects the
        answers and finally stores them with update_statistics_data().
        It returns nothing."""

        session = self.cc_session
        ccsession = isc.config.ccsession

        # It counts the number of instances of same module by
        # examining the third value from the array result of
        # 'show_processes' of Init
        seq = session.group_sendmsg(
            ccsession.create_command("show_processes"), 'Init')
        (answer, env) = session.group_recvmsg(False, seq)
        running = []
        if answer:
            (rcode, value) = ccsession.parse_answer(answer)
            if rcode == 0 and type(value) is list:
                # NOTE: the "show_processes" command of Init is
                # assumed to answer with a list of lists such as
                #   [ ..., [20061, "b10-auth", "Auth"],
                #          [20103, "b10-auth-2", "Auth"], ... ]
                # If multiple instances of the same module are
                # running, their address names (the third element)
                # are the same, so that element is what gets counted.
                # This is a workaround for counting the instances and
                # should be replaced by a proper way in the future.
                running = [ entry[2]
                            if type(entry) is list and len(entry) > 2
                            else None
                            for entry in value ]

        # start requesting each module to collect statistics data
        pending = []
        for module_name in self.get_statistics_data():
            # the Stats module does not ask itself
            if module_name == self.module_name:
                continue
            logger.debug(DBG_STATS_MESSAGING, STATS_SEND_STATISTICS_REQUEST,
                         module_name)
            request = ccsession.create_command("getstats", None) # no argument
            seq = session.group_sendmsg(request, module_name)
            # expect one answer per running instance, at least one
            instances = running.count(module_name)
            pending.extend([(module_name, seq)] * max(instances, 1))

        # start receiving statistics data
        collected = []
        for (module_name, seq) in pending:
            try:
                answer, env = session.group_recvmsg(False, seq)
                if answer:
                    rcode, args = ccsession.parse_answer(answer)
                    if rcode == 0:
                        collected.append((module_name, env['from'], args))
            # skip this module if SessionTimeout raised
            except isc.cc.session.SessionTimeout:
                continue

        # update statistics data
        self.update_modules()
        for (owner, lname, args) in collected:
            if self.update_statistics_data(owner, lname, args):
                logger.warn(
                    STATS_RECEIVED_INVALID_STATISTICS_DATA,
                    owner)
                continue
            # the data was accepted: refresh our own update time
            if self.update_statistics_data(
                    self.module_name,
                    self.cc_session.lname,
                    {'last_update_time': get_datetime()}):
                logger.warn(
                    STATS_RECEIVED_INVALID_STATISTICS_DATA,
                    self.module_name)
        # if successfully done, set the last time of polling
        self._lasttime_poll = get_timestamp()

    def start(self):
        """
        Start stats module

        Runs the main loop until self.running becomes false: each
        iteration waits for a command for at most half a second,
        dispatches it if one arrived, and polls the modules when
        'poll-interval' is positive and the next polling time has
        been reached.  The stopping notification is always sent when
        the loop ends, including when it ends by an exception.
        """
        logger.info(STATS_STARTING)

        def _check_command(nonblock=False):
            """Wait for one message with a 500ms session timeout and
            hand it to the ModuleCCSession.  A timeout just means
            that nothing arrived."""
            # backup original timeout
            orig_timeout = self.cc_session.get_timeout()
            # set cc-session timeout to half of a second(500ms)
            self.cc_session.set_timeout(500)
            try:
                answer, env = self.cc_session.group_recvmsg(nonblock)
                self.mccs.check_command_without_recvmsg(answer, env)
            except isc.cc.session.SessionTimeout:
                pass # nothing received within the timeout
            finally:
                # restore timeout, also when a command handler or the
                # session raised something other than SessionTimeout
                self.cc_session.set_timeout(orig_timeout)

        try:
            self.running = True
            while self.running:
                _check_command()
                now = get_timestamp()
                intval = self.get_interval()
                if intval > 0 and now >= self.next_polltime:
                    # decide the next polling timestamp
                    self.next_polltime += intval
                    # adjust next time: never schedule in the past
                    if self.next_polltime < now:
                        self.next_polltime = now
                    self.do_polling()
        finally:
            self.mccs.send_stopping()

    def config_handler(self, new_config):
        """
        Handle a configuration update received from the cc channel.

        The new configuration is validated against the module spec
        and a negative 'poll-interval' is refused; in both cases an
        answer with result code 1 is returned and nothing is changed.
        Otherwise the configuration is merged into self.config, the
        next polling time is rescheduled and an answer with result
        code 0 is returned.
        """
        logger.debug(DBG_STATS_MESSAGING, STATS_RECEIVED_NEW_CONFIG,
                     new_config)
        errors = []
        module_spec = self.mccs.get_module_spec()
        if not module_spec.validate_config(False, new_config, errors):
            return isc.config.ccsession.create_answer(
                1, ", ".join(errors))

        negative_interval = 'poll-interval' in new_config \
            and new_config['poll-interval'] < 0
        if negative_interval:
            return isc.config.ccsession.create_answer(
                1, "Negative integer ignored")

        self.config.update(new_config)
        if 'poll-interval' in self.config:
            # update next polling timestamp
            now = get_timestamp()
            self.next_polltime = now + self.get_interval()
        return isc.config.create_answer(0)

    def command_handler(self, command, kwargs):
        """
        Dispatch a command received from the cc channel to the
        callback registered as 'command_<command>' and return its
        result.  'kwargs' is passed on as keyword arguments when it is
        not empty.  An unknown command is logged and answered with
        result code 1.
        """
        name = 'command_' + command
        if name not in self.callbacks:
            logger.error(STATS_RECEIVED_UNKNOWN_COMMAND, command)
            return isc.config.create_answer(1, "Unknown command: '"+str(command)+"'")
        callback = self.callbacks[name]
        return callback(**kwargs) if kwargs else callback()

    def update_modules(self):
        """
        Refresh self.modules with the statistics spec of each module.

        The specs are requested from the config manager; each one is
        wrapped into a ModuleSpec object and stored under its module
        name.  The spec of the stats module itself is always taken
        from its own ModuleCCSession.  If the config manager answers
        with a non-zero result code, StatsError is raised.
        """
        request = isc.config.ccsession.create_command(
            isc.config.ccsession.COMMAND_GET_STATISTICS_SPEC)
        seq = self.cc_session.group_sendmsg(request, 'ConfigManager')
        (answer, env) = self.cc_session.group_recvmsg(False, seq)
        specs = {}
        if answer:
            (rcode, value) = isc.config.ccsession.parse_answer(answer)
            if rcode != 0:
                raise StatsError("Updating module spec fails: " + str(value))
            for mod in value:
                raw_spec = { "module_name" : mod }
                stats_spec = value[mod]
                # only a non-empty list counts as a statistics spec
                if stats_spec and type(stats_spec) is list:
                    raw_spec["statistics"] = stats_spec
                specs[mod] = isc.config.module_spec.ModuleSpec(raw_spec)
        specs[self.module_name] = self.mccs.get_module_spec()
        self.modules = specs

    def get_statistics_data(self, owner=None, name=None):
        """
        Return the statistics data held by the stats module.

        Without arguments all data is returned.  With 'owner' only
        the data of that module is returned as {owner: data}; with
        'owner' and 'name' only that single item as
        {owner: {name: value}}.  StatsError is raised if the requested
        data does not exist or if 'name' is given without 'owner'.
        The module information is refreshed first in every case.
        """
        self.update_modules()
        if not owner and not name:
            return self.statistics_data
        if owner:
            try:
                owner_data = self.statistics_data[owner]
                if name:
                    return {owner: {name: owner_data[name]}}
                return {owner: owner_data}
            except KeyError:
                # reported as StatsError below
                pass
        raise StatsError("No statistics data found: "
                         + "owner: " + str(owner) + ", "
                         + "name: " + str(name))

    def update_statistics_data(self, owner=None, mid=None, data=None):
        """
        Rebuild self.statistics_data, optionally recording new data
        of one module instance first.

        self.statistics_data is always reset to the default values
        taken from the statistics spec of each module in self.modules
        and then overlaid with the values accumulated over all sender
        ids stored in self.statistics_data_bymid.

        If both 'owner' and 'data' are specified, each item of 'data'
        is validated against the statistics spec of 'owner' and
        recorded under self.statistics_data_bymid[owner][mid]
        beforehand.  A key may also be an identifier such as
        "xx/yy/zz[0]"; such a value is stored with isc.cc.data.set().
        The recorded data is replaced only if every item is accepted;
        if any error occurs it is left exactly as it was.

        The 'mid' argument is an identifier of the sender module in
        order for stats to identify which instance sends statistics
        data in the situation that multiple instances are working.

        If specified data is invalid for statistics spec of specified
        owner, it returns a list of error messages. If there is no
        error or if neither owner nor data is specified in args, it
        returns None.
        """
        # Note:
        # The fix of #1751 is for multiple instances working. It is
        # assumed here that they send different statistics data with
        # each sender module id (mid). Stats should save their statistics data by
        # mid. The statistics data, which is the existing variable, is
        # preserved by accumulating from statistics data by the mid. This
        # is an ad-hoc fix because administrators can not see
        # statistics by each instance via bindctl or HTTP/XML. These
        # interfaces aren't changed in this fix.

        # local import: only this method needs it
        import copy

        def _accum_bymodule(statistics_data_bymid):
            """This is an internal method for the superordinate
            method. It accumulates statistics data of each module id
            by module. It returns a accumulated result."""
            # FIXME: A issue might happen when consolidating
            # statistics of the multiple instances. If they have
            # different statistics data which are not for adding each
            # other, this might happen: If these are integer or float,
            # these are added each other. If these are string , these
            # are compared and consolidated into bigger one.  If one
            # of them is None type , these might be consolidated
            # into not None-type one. Otherwise these are overwritten
            # into one of them.
            ret = {}
            for data in statistics_data_bymid.values():
                ret.update(_accum(data, ret))
            return ret

        # Firstly, it gets default statistics data in each spec file.
        statistics_data = {}
        for (name, module) in self.modules.items():
            value = get_spec_defaults(module.get_statistics_spec())
            if module.validate_statistics(True, value):
                statistics_data[name] = value
        self.statistics_data = statistics_data

        # If the "owner" and "data" arguments in this function are
        # specified, then the variable of statistics data of each module id
        # would be updated.
        errors = []
        if owner and data:
            # Work on a private copy so that a rejected update leaves
            # self.statistics_data_bymid untouched.  Only the entry of
            # 'owner' is ever modified, so only that entry needs a
            # deep copy; a shallow copy alone would share the nested
            # dicts with the stored data.
            _data = self.statistics_data_bymid.copy()
            if owner in _data:
                _data[owner] = copy.deepcopy(_data[owner])
            try:
                for (_key, _val) in data.items():
                    # an unknown owner raises KeyError here
                    module_spec = self.modules[owner]
                    if module_spec.validate_statistics(
                        False, {_key: _val}, errors):
                        if owner not in _data:
                            _data[owner] = { mid: { _key: _val } }
                        elif mid not in _data[owner]:
                            _data[owner][mid] = { _key: _val }
                        else:
                            # merge recursively old value and new
                            # value each other
                            _data[owner][mid] = \
                                merge_oldnew(_data[owner][mid],
                                             {_key: _val})
                        continue
                    # the key string might be a "xx/yy/zz[0]"
                    # type. try it.
                    if _key.find('/') >= 0 or \
                            isc.cc.data.identifier_has_list_index(_key):
                        # remove the last error
                        if errors: errors.pop()
                        # try the update on a scratch copy of this
                        # instance's data and validate it in advance;
                        # it is stored only if it turns out valid
                        trial = copy.deepcopy(
                            _data.get(owner, {}).get(mid, {}))
                        # use the isc.cc.data.set method
                        try:
                            isc.cc.data.set(trial, _key, _val)
                            if module_spec.validate_statistics(
                                False, trial, errors):
                                _data.setdefault(owner, {})[mid] = trial
                        except Exception as e:
                            errors.append(
                                "%s: %s" % (e.__class__.__name__, e))
            except KeyError:
                errors.append("unknown module name: " + str(owner))
            if not errors:
                self.statistics_data_bymid = _data

        # Just consolidate statistics data of each module without
        # removing that of modules which have been already dead
        for (m, bymid) in list(self.statistics_data_bymid.items()):
            if not bymid or m not in self.statistics_data:
                continue
            # propagate the default values by times of instances:
            # the defaults are accumulated once per instance, so N
            # instances yield N times the defaults
            defaults = self.statistics_data[m]
            propagated = defaults
            for _ in range(len(bymid) - 1):
                propagated = _accum(propagated, defaults)
            # replace the default values with summaries of the
            # collected values of each module. But the default
            # values which are not included in collected
            # values are not replaced.
            self.statistics_data[m] = merge_oldnew(
                propagated, _accum_bymodule(bymid))

        if errors: return errors

    def command_status(self):
        """
        Handle the status command: answer with result code 0 and a
        message containing the process id of this process.
        """
        logger.debug(DBG_STATS_MESSAGING, STATS_RECEIVED_STATUS_COMMAND)
        message = "Stats is up. (PID " + str(os.getpid()) + ")"
        return isc.config.create_answer(0, message)

    def command_shutdown(self, pid=None):
        """
        Handle the shutdown command by clearing self.running, which
        makes the main loop in start() finish.

        The pid argument is ignored, it is here to match the signature.
        """
        logger.info(STATS_RECEIVED_SHUTDOWN_COMMAND)
        self.running = False
        answer = isc.config.create_answer(0)
        return answer

    def command_show(self, owner=None, name=None):
        """
        Handle the show command: answer with the statistics data
        selected by 'owner' and 'name' (all data if both are omitted).

        An answer with result code 1 is returned if the selection
        does not exist.  StatsError is raised if the stats module's
        own report items are rejected by its spec.
        """
        # Poll the modules first unless that was done within the last
        # second; otherwise the data already held is shown as it is.
        seconds_since_poll = get_timestamp() - self._lasttime_poll
        if seconds_since_poll > 1.0:
            self.do_polling()

        if owner or name:
            logger.debug(DBG_STATS_MESSAGING,
                         STATS_RECEIVED_SHOW_NAME_COMMAND,
                         str(owner)+", "+str(name))
        else:
            logger.debug(DBG_STATS_MESSAGING,
                         STATS_RECEIVED_SHOW_ALL_COMMAND)

        # stamp the report with the current time
        report = {'timestamp': get_timestamp(),
                  'report_time': get_datetime()}
        errors = self.update_statistics_data(
            self.module_name, self.cc_session.lname, report)
        if errors:
            raise StatsError("stats spec file is incorrect: "
                             + ", ".join(errors))

        try:
            selected = self.get_statistics_data(owner, name)
            return isc.config.create_answer(0, selected)
        except StatsError:
            return isc.config.create_answer(
                1, "specified arguments are incorrect: " \
                    + "owner: " + str(owner) + ", name: " + str(name))

    def command_showschema(self, owner=None, name=None):
        """
        Handle the showschema command: answer with the statistics
        spec selected by 'owner' and 'name'.

        Without arguments the specs of all modules having one are
        returned.  With 'owner' the spec of that module is returned
        as {owner: spec}; with 'owner' and 'name' only the spec of
        that item as {owner: [item_spec]}.  An answer with result
        code 1 is returned if 'name' is given without 'owner' or if
        the selection does not exist.
        """
        if owner or name:
            logger.debug(DBG_STATS_MESSAGING,
                         STATS_RECEIVED_SHOWSCHEMA_NAME_COMMAND,
                         str(owner)+", "+str(name))
        else:
            logger.debug(DBG_STATS_MESSAGING,
                         STATS_RECEIVED_SHOWSCHEMA_ALL_COMMAND)
        self.update_modules()

        # schema:        module name -> list of item specs (modules
        #                with a non-empty spec only)
        # schema_byname: module name -> {item name -> item spec}
        schema = {}
        schema_byname = {}
        for mod in self.modules:
            spec = self.modules[mod].get_statistics_spec()
            by_item = {}
            schema_byname[mod] = by_item
            if spec:
                schema[mod] = spec
                for item in spec:
                    by_item[item['item_name']] = item

        if not owner:
            if name:
                return isc.config.create_answer(1, "module name is not specified")
            return isc.config.create_answer(0, schema)

        try:
            if name:
                return isc.config.create_answer(0, {owner:[schema_byname[owner][name]]})
            return isc.config.create_answer(0, {owner:schema[owner]})
        except KeyError:
            # unknown owner or item: reported below
            pass
        return isc.config.create_answer(
                1, "specified arguments are incorrect: " \
                    + "owner: " + str(owner) + ", name: " + str(name))

if __name__ == "__main__":
    # Script entry point: parse the command line, run the daemon and
    # turn the known fatal errors into a log message and exit code 1.
    try:
        parser = OptionParser()
        parser.add_option(
            "-v", "--verbose",
            dest="verbose",
            action="store_true",
            help="enable maximum debug logging")
        (options, args) = parser.parse_args()
        if options.verbose:
            # re-initialise logging at the highest debug level
            isc.log.init("b10-stats", "DEBUG", 99, buffer=True)
        stats = Stats()
        stats.start()
    except OptionValueError as ove:
        logger.fatal(STATS_BAD_OPTION_VALUE, ove)
        sys.exit(1)
    except isc.cc.session.SessionError as se:
        logger.fatal(STATS_CC_SESSION_ERROR, se)
        sys.exit(1)
    except StatsError as se:
        logger.fatal(STATS_START_ERROR, se)
        sys.exit(1)
    except KeyboardInterrupt:
        # interrupted from the keyboard: not an error
        logger.info(STATS_STOPPED_BY_KEYBOARD)

    logger.info(STATS_EXITING)
