#!/usr/bin/python3.3

# Copyright (C) 2012  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import sys
sys.path.append('/usr/lib64/python3.3/site-packages')
import time
import signal
from optparse import OptionParser
from isc.dns import *
from isc.datasrc import *
import isc.util.process
import isc.log
from isc.log_messages.loadzone_messages import *
from datetime import timedelta

isc.util.process.rename()

# These are needed for logger settings
import bind10_config
import json
from isc.config import module_spec_from_file
from isc.config.ccsession import path_search

# Initialize the isc.log framework for this program and create the
# module-level logger used by the rest of this file.  (The exact meaning of
# the two name arguments is defined by isc.log; presumably the program-wide
# root name and this module's logger name -- confirm against isc.log.)
isc.log.init("b10-loadzone")
logger = isc.log.Logger("loadzone")

# The default interval of progress reports, expressed as the number of RRs
# loaded between two consecutive reports.  It is also used as the size of
# each incremental load step when progress reports are suppressed.  The
# value is an arbitrary choice, but is intended to be small enough that an
# emergency exit (e.g. on a signal) is handled without much delay.
LOAD_INTERVAL_DEFAULT = 10000

class BadArgument(Exception):
    '''Raised when a command line option or argument is invalid.

    It carries no state of its own; the human readable reason is given
    as the exception message.

    '''

class LoadFailure(Exception):
    '''Raised when the zone loading operation itself fails.

    It carries no state of its own; the human readable reason is given
    as the exception message.

    '''

def set_cmd_options(parser):
    '''Register all loadzone command-line options on the given parser.

    parser is the option parser object to be configured (the caller in
    this file passes an optparse.OptionParser); it is modified in place
    and nothing is returned.

    '''
    # Each entry is (short flag, long flag, add_option() keyword arguments).
    # The options are registered in this order.
    option_table = [
        ("-c", "--datasrc-conf",
         dict(dest="conf", action="store", metavar='CONFIG',
              help="configuration of datasrc to load the zone in.\n"
                   "Example: '{\"database_file\": "
                   "\"/path/to/dbfile/db.sqlite3\"}'")),
        ("-d", "--debug",
         dict(dest="debug_level", type='int', action="store", default=None,
              help="enable debug logs with the specified level [0-99]")),
        ("-i", "--report-interval",
         dict(dest="report_interval", type='int', action="store",
              default=LOAD_INTERVAL_DEFAULT,
              help="report logs progress per specified number of RRs\n"
                   "(specify 0 to suppress report) [default: %default]")),
        ("-t", "--datasrc-type",
         dict(dest="datasrc_type", action="store", default='sqlite3',
              help="type of data source (e.g., 'sqlite3')\n\n"
                   "[default: %default]")),
        ("-C", "--class",
         dict(dest="zone_class", action="store", default='IN',
              help="RR class of the zone; currently must be 'IN'\n"
                   "[default: %default]")),
    ]
    for short_flag, long_flag, settings in option_table:
        parser.add_option(short_flag, long_flag, **settings)

class LoadZoneRunner:
    '''Main logic for the loadzone.

    This is implemented as a class mainly for the convenience of tests.

    Typical usage is to construct it with the command line arguments
    (without the program name) and call run(), whose return value is
    suitable as the process exit code (0 on success, 1 on failure).

    '''
    def __init__(self, command_args):
        '''Constructor.

        Parameters:
          command_args: list of command line arguments (without the
            program name) to be parsed later by _parse_args().

        It builds the base logging configuration and applies the initial
        logging policy (severity 'INFO', debug level 0) via _config_log().

        '''
        self.__command_args = command_args
        self.__interrupted = False # will be set to True on receiving signal

        # system-wide log configuration.  We need to configure logging this
        # way so that the logging policy applies to underlying libraries, too.
        self.__log_spec = json.dumps(isc.config.module_spec_from_file(
                path_search('logging.spec', bind10_config.PLUGIN_PATHS)).
                                     get_full_spec())
        # "severity" and "debuglevel" are the tunable parameters, which will
        # be set in _config_log().
        self.__log_conf_base = {"loggers":
                                    [{"name": "*",
                                      "output_options":
                                          [{"output": "stderr",
                                            "destination": "console"}]}]}

        # These are essentially private, but defined as "protected" for the
        # convenience of tests inspecting them
        self._loaded_rrs = 0
        self._zone_class = None
        self._zone_name = None
        self._zone_file = None
        self._datasrc_config = None
        self._datasrc_type = None
        self._log_severity = 'INFO'
        self._log_debuglevel = 0
        self._report_interval = LOAD_INTERVAL_DEFAULT
        self._start_time = None
        # This one will be used in (rare) cases where we want to allow tests to
        # fake time.time()
        self._get_time = time.time

        self._config_log()

    def _config_log(self):
        '''Configure logging policy.

        The current values of self._log_severity and self._log_debuglevel
        are stored in the base logging configuration, which is then passed
        (JSON encoded) to isc.log.log_config_update() together with the
        logging spec.

        This is essentially private, but defined as "protected" for tests.

        '''
        self.__log_conf_base['loggers'][0]['severity'] = self._log_severity
        self.__log_conf_base['loggers'][0]['debuglevel'] = self._log_debuglevel
        isc.log.log_config_update(json.dumps(self.__log_conf_base),
                                  self.__log_spec)

    def _parse_args(self):
        '''Parse command line options and other arguments.

        On success the parsed values are stored in the corresponding
        "protected" attributes (self._zone_name, self._zone_file,
        self._zone_class, self._datasrc_type, self._datasrc_config,
        self._report_interval and the logging related ones).

        Raises BadArgument if the debug level or the report interval is
        negative, if the RR class is invalid or is not IN, if no default
        data source configuration is available, if the number of
        positional arguments is not 2, or if the zone name is invalid.
        Errors detected by optparse itself (unknown options, non integer
        values for integer options) are handled by optparse.

        This is essentially private, but defined as "protected" for tests.

        '''

        usage_txt = \
            'usage: %prog [options] -c datasrc_config zonename zonefile'
        parser = OptionParser(usage=usage_txt)
        set_cmd_options(parser)
        (options, args) = parser.parse_args(args=self.__command_args)

        # Configure logging policy as early as possible
        if options.debug_level is not None:
            self._log_severity = 'DEBUG'
            # optparse performs type check
            self._log_debuglevel = int(options.debug_level)
            if self._log_debuglevel < 0:
                raise BadArgument(
                    'Invalid debug level (must be non negative): %d' %
                    self._log_debuglevel)
        self._config_log()

        self._datasrc_type = options.datasrc_type
        self._datasrc_config = options.conf
        if options.conf is None:
            # No explicit -c option: fall back to the per-type default.
            self._datasrc_config = self._get_datasrc_config(self._datasrc_type)
        try:
            self._zone_class = RRClass(options.zone_class)
        except isc.dns.InvalidRRClass as ex:
            raise BadArgument('Invalid zone class: ' + str(ex))
        if self._zone_class != RRClass.IN:
            raise BadArgument("RR class is not supported: " +
                              str(self._zone_class))

        self._report_interval = int(options.report_interval)
        if self._report_interval < 0:
            raise BadArgument(
                'Invalid report interval (must be non negative): %d' %
                self._report_interval)

        if len(args) != 2:
            raise BadArgument('Unexpected number of arguments: %d (must be 2)'
                              % (len(args)))
        try:
            self._zone_name = Name(args[0])
        except Exception as ex: # too broad, but there's no better granularity
            raise BadArgument("Invalid zone name '" + args[0] + "': " +
                              str(ex))
        self._zone_file = args[1]

    def _get_datasrc_config(self, datasrc_type):
        '''Return the default data source configuration of given type.

        Right now, it only supports SQLite3, and hardcodes the syntax
        of the default configuration.  It's a kind of workaround to balance
        convenience of users and minimizing hardcoding of data source
        specific logic in the entire tool.  In future this should be
        more sophisticated.

        Parameters:
          datasrc_type: name of the data source type (string).

        Returns the configuration as a JSON formatted string.
        Raises BadArgument if datasrc_type is not 'sqlite3'.

        This is essentially a private helper method for _parse_args(),
        but defined as "protected" so tests can use it directly.

        '''
        if datasrc_type != 'sqlite3':
            raise BadArgument('default config is not available for ' +
                              datasrc_type)

        default_db_file = bind10_config.DATA_PATH + '/zone.sqlite3'
        logger.info(LOADZONE_SQLITE3_USING_DEFAULT_CONFIG, default_db_file)
        # Let the json module do the encoding so that the result is valid
        # JSON even if the path contains characters that need escaping
        # (quotes, backslashes).  For ordinary paths the result is the same
        # text as plain string concatenation would produce.
        return json.dumps({"database_file": default_db_file})

    @staticmethod
    def _estimate_remaining(elapsed, progress):
        '''Return the estimated remaining time of the load in seconds.

        Parameters:
          elapsed: time spent so far, in seconds.
          progress: completed fraction of the load; the caller must not
            pass the "unknown progress" marker here.

        Returns None when no estimate can be made, that is, when progress
        is not positive (a progress of 0 would otherwise cause a division
        by zero).

        '''
        if progress <= 0:
            return None
        return (1 - progress) * (elapsed / progress)

    def _report_progress(self, loaded_rrs, progress, dump=True):
        '''Dump the current progress report to stdout.

        Parameters:
          loaded_rrs: number of RRs loaded so far.
          progress: completed fraction of the load as reported by the
            loader, or ZoneLoader.PROGRESS_UNKNOWN.
          dump: if True, write the report to stdout and return None;
            if False, return the report text without writing it.

        This is essentially private, but defined as "protected" for tests.
        Normally dump is True, but tests will set it False to get the
        text to be reported.  Tests may also fake self._get_time (which
        is set to time.time() by default) and self._start_time for control
        time related conditions.

        '''
        elapsed = self._get_time() - self._start_time
        speed = int(loaded_rrs / elapsed) if elapsed > 0 else 0
        progress_known = progress != ZoneLoader.PROGRESS_UNKNOWN
        etc = None            # calculate estimated time of completion
        if progress_known:
            etc = self._estimate_remaining(elapsed, progress)

        # Build report text
        report_txt = '\r%d RRs' % loaded_rrs
        if progress_known:
            report_txt += ' (%.1f%%)' % (progress * 100)
        report_txt += ' in %s, %d RRs/sec' % \
            (str(timedelta(seconds=int(elapsed))), speed)
        if etc is not None:
            report_txt += ', %s ETC' % str(timedelta(seconds=int(etc)))

        # Dump or return the report text.
        if dump:
            # Overwrite the previous report on the same line: blank the
            # line first, then write the new text (which begins with '\r').
            sys.stdout.write("\r" + (80 * " "))
            sys.stdout.write(report_txt)
        else:
            return report_txt

    def _do_load(self):
        '''Main part of the load logic.

        It creates the data source client, creates the zone in it (a
        zone that cannot be newly created is treated as being updated),
        and loads the zone file incrementally, reporting progress if
        enabled.  On success self._loaded_rrs holds the RR count
        reported by the loader.

        Raises LoadFailure on any failure, including interruption by a
        signal.  If the zone was newly created by this call it is deleted
        again before the exception is raised.

        This is essentially private, but defined as "protected" for tests.

        '''
        created = False
        try:
            datasrc_client = DataSourceClient(self._datasrc_type,
                                              self._datasrc_config)
            created = datasrc_client.create_zone(self._zone_name)
            if created:
                logger.info(LOADZONE_ZONE_CREATED, self._zone_name,
                            self._zone_class)
            else:
                logger.info(LOADZONE_ZONE_UPDATING, self._zone_name,
                            self._zone_class)
            loader = ZoneLoader(datasrc_client, self._zone_name,
                                self._zone_file)
            # Use the same (possibly faked) clock as _report_progress().
            self._start_time = self._get_time()
            if self._report_interval > 0:
                limit = self._report_interval
            else:
                # Even if progress report is suppressed, we still load
                # incrementally so we won't delay catching signals too long.
                limit = LOAD_INTERVAL_DEFAULT
            while (not self.__interrupted and
                   not loader.load_incremental(limit)):
                # Count by the step actually requested from the loader, so
                # the running total also advances when reports are
                # suppressed (report interval of 0).  This assumes an
                # unfinished step loaded exactly 'limit' RRs; the exact
                # count is taken from the loader on completion anyway.
                self._loaded_rrs += limit
                if self._report_interval > 0:
                    self._report_progress(self._loaded_rrs,
                                          loader.get_progress())
            if self.__interrupted:
                raise LoadFailure('loading interrupted by signal')

            # On successful completion, add final '\n' to the progress
            # report output (on failure don't bother to make it prettier).
            if (self._report_interval > 0 and
                self._loaded_rrs >= self._report_interval):
                sys.stdout.write('\n')
            # record the final count of the loaded RRs for logging
            self._loaded_rrs = loader.get_rr_count()
        except Exception as ex:
            # release any remaining lock held in the loader
            loader = None
            if created:
                # Don't leave a newly created zone behind on failure.
                datasrc_client.delete_zone(self._zone_name)
                logger.error(LOADZONE_CANCEL_CREATE_ZONE, self._zone_name,
                             self._zone_class)
            raise LoadFailure(str(ex)) from ex

    def _set_signal_handlers(self):
        '''Install _interrupt_handler() for SIGINT and SIGTERM.'''
        signal.signal(signal.SIGINT, self._interrupt_handler)
        signal.signal(signal.SIGTERM, self._interrupt_handler)

    def _interrupt_handler(self, signal, frame):
        '''Signal handler: just record that an interruption was requested.

        Both parameters are the standard signal handler arguments and are
        ignored (note that 'signal' shadows the signal module within this
        method).  The flag set here is checked by the loading loop in
        _do_load().

        '''
        self.__interrupted = True

    def run(self):
        '''Top-level method, simply calling other helpers.

        Returns 0 if the zone has been loaded successfully; otherwise the
        error is logged and 1 is returned.

        '''

        try:
            self._set_signal_handlers()
            self._parse_args()
            self._do_load()
            total_elapsed_txt = "%.2f" % (self._get_time() - self._start_time)
            logger.info(LOADZONE_DONE, self._loaded_rrs, self._zone_name,
                        self._zone_class, total_elapsed_txt)
            return 0
        except BadArgument as ex:
            logger.error(LOADZONE_ARGUMENT_ERROR, ex)
        except LoadFailure as ex:
            logger.error(LOADZONE_LOAD_ERROR, self._zone_name,
                         self._zone_class, ex)
        except Exception as ex:
            logger.error(LOADZONE_UNEXPECTED_FAILURE, ex)
        return 1

if __name__ == '__main__':
    # Script entry point: run the loader on the command line arguments
    # (without the program name) and use its result as the exit status.
    sys.exit(LoadZoneRunner(sys.argv[1:]).run())

## Local Variables:
## mode: python
## End:
