#!/usr/bin/python
#
# Copyright (c) 2010 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
#
# Red Hat trademarks are not licensed under GPLv2. No permission is
# granted to use or replicate Red Hat trademarks that are incorporated
# in this software or its documentation.
#

import sys
import os
import xmlrpclib
import httplib
import getpass
import libxml2
import subprocess
import re
import simplejson as json
import shutil
import logging
import traceback
import base64
from datetime import datetime
from rhsm.connection import UEPConnection, RemoteServerException, RestlibException
from M2Crypto.SSL import SSLError

_LIBPATH = "/usr/share/rhsm"
# add to the path if need be
if _LIBPATH not in sys.path:
    sys.path.append(_LIBPATH)

from subscription_manager.utils import parse_server_info, ServerUrlParseError
from subscription_manager.i18n import configure_i18n
configure_i18n()

import gettext
_ = gettext.gettext

# quick check to see if you are a super-user.
# This runs at import time, before the up2date (RHN Classic) libraries and
# configuration are loaded; any non-root caller exits with status 8.
if os.getuid() != 0:
    print _("Must be root user to execute\n")
    sys.exit(8)

# access the rhn/up2date python libaries and read the up2date config file
_RHNLIBPATH = "/usr/share/rhn"
if _RHNLIBPATH not in sys.path:
    sys.path.append(_RHNLIBPATH)

from up2date_client import up2dateErrors
from up2date_client.rhnChannel import  getChannels
import up2date_client.config

rhncfg = up2date_client.config.initUp2dateConfig()

from optparse import Option
from subscription_manager.i18n_optparse import OptionParser, WrappedIndentedHelpFormatter

# Command line options accepted by the script.  The parsed result is kept in
# the module-level ``options`` object that the functions below consult.
options_table = [
    Option("-f", "--force", action="store_true", default=False,
           help=_("Ignore channels not available on RHSM")),
    Option("-g", "--gui", action="store_true", default=False, dest='gui',
           help=_("Launch the GUI tool to subscribe the system, instead of autosubscribing")),
    Option("-n", "--no-auto", action="store_true", default=False, dest='noauto',
           help=_("Don't execute the autosubscribe option while registering with subscription manager.")),
    # The help text must name the option exactly as it is declared here
    # (--servicelevel, without a dash between the two words).
    Option("-s", "--servicelevel", dest="servicelevel",
           help=_("Service level to subscribe this system to. For no service "
                  "level use --servicelevel=\"\"")),
    Option("--serverurl", dest='serverurl',
           help=_("Specify the Subscription Management Server to migrate TO.")),
]

parser = OptionParser(usage=_("%prog [OPTIONS]"),
                      option_list=options_table,
                      formatter=WrappedIndentedHelpFormatter())

# Parsing happens at import time; ``args`` holds any positional leftovers.
(options, args) = parser.parse_args()

# access the rhsm python libraries, read rhsm config file and setup logging
_RHSMLIBPATH = "/usr/share/rhsm"
if _RHSMLIBPATH not in sys.path:
    sys.path.append(_RHSMLIBPATH)

from subscription_manager.certlib import ConsumerIdentity, ProductDirectory
from subscription_manager import repolib, logutil

import rhsm.config

logutil.init_logger()
log = logging.getLogger('rhsm-app.' + __name__)
rhsmcfg = rhsm.config.initConfig()

# HTTP proxy settings shared between functions.  They stay empty unless
# transferHttpProxySettings() fills them in from the up2date configuration;
# connectToRhn() reads them when the RHN proxy is enabled.
proxyHost = ""
proxyPort = ""
proxyUser = ""
proxyPass = ""

# Message template used when the certificate server cannot be reached; the
# %s placeholder receives the error that was raised.
CONNECTION_FAILURE = _(u"Unable to connect to certificate server: %s.  " \
        "See /var/log/rhsm/rhsm.log for more details.")


class UserCredentials(object):
    """Simple holder pairing a username with its password."""

    def __init__(self, username, password):
        self.username, self.password = username, password


class InvalidChoiceError(Exception):
    """Raised when a menu selection does not map to an available choice."""


class Menu(object):
    def __init__(self, choices, header):
        # choices is a tuple with the first value being the display string
        # and the second value being the value to return.
        self.choices = choices
        self.header = header

    def choose(self):
        while True:
            self.display()
            selection = raw_input("? ").strip()
            try:
                return self.getItem(selection)
            except InvalidChoiceError:
                print _("You have entered an invalid choice.")

    def display(self):
        print self.header
        for index, entry in enumerate(self.choices):
            print "%s. %s" % (index + 1, entry[0])

    def getItem(self, selection):
        try:
            index = int(selection) - 1
            # In case some joker enters zero or a negative number
            if index < 0:
                raise InvalidChoiceError
        except TypeError:
            raise InvalidChoiceError
        except ValueError:
            raise InvalidChoiceError

        try:
            return self.choices[index][1]
        except IndexError:
            raise InvalidChoiceError


class ProxiedTransport(xmlrpclib.Transport):
    """XML-RPC transport that sends its requests through an HTTP proxy."""

    def set_proxy(self, proxy, credentials):
        """Remember the proxy address and the credentials string that
        send_host() appends after 'Basic ' in the authorization header."""
        self.credentials = credentials
        self.proxy = proxy

    def make_connection(self, host):
        """Connect to the proxy, keeping the real target host for later."""
        self.realhost = host
        proxy_connection = httplib.HTTP(self.proxy)
        return proxy_connection

    def send_request(self, connection, handler, request_body):
        """Start the POST using an absolute URL built from the real host."""
        target_url = 'http://%s%s' % (self.realhost, handler)
        connection.putrequest("POST", target_url)

    def send_host(self, connection, host):
        """Send the Host header for the real host and, when credentials are
        set, the Proxy-Authorization header."""
        connection.putheader('Host', self.realhost)
        if not self.credentials:
            return
        connection.putheader('Proxy-Authorization', 'Basic ' + self.credentials)


def systemExit(code, msgs=None):
    """Write the optional message(s) to stderr, then exit with code.

    msgs may be a single message or a list/tuple of messages; each one is
    written UTF-8 encoded on its own line.
    """
    if msgs:
        if type(msgs) in (list, tuple):
            pending = msgs
        else:
            # A lone message is handled like a one-element sequence.
            pending = (msgs, )
        for text in pending:
            sys.stderr.write(unicode(text).encode('utf-8') + '\n')
    sys.exit(code)


def isHosted():
    """Return True when the server hostname from the rhsm configuration
    matches the subscription.rhn.*redhat.com pattern, False otherwise."""
    configured_host = rhsmcfg.get('server', 'hostname')
    match = re.search('subscription\.rhn\.(.*\.)*redhat\.com', configured_host)
    # Turn the match object (or None) into a real boolean.
    return match is not None


def checkOkToProceed(secreds, serverurl):
    # check if this machine is already registered to Certicate-based RHN
    if ConsumerIdentity.existsAndValid():
        print _("\nThis machine appears to be already registered to Red Hat Subscription Management.  Exiting.")
        consumer = ConsumerIdentity.read()
        systemExit(1, _("\nPlease visit https://access.redhat.com/management/consumers/%s to view the profile details.") % consumer.getConsumerId())

    try:
        if serverurl is None:
            hostname = rhsmcfg.get('server', 'hostname')
            port = rhsmcfg.get('server', 'port')
            prefix = rhsmcfg.get('server', 'prefix')
        else:
            (hostname, port, prefix) = parse_server_info(serverurl)
    except ServerUrlParseError, e:
        print _("Error parsing serverurl: %s" % e.msg)
        sys.exit(-1)

    # Check to make sure we can connect to the certificate server.
    cp = UEPConnection(host=hostname,
            ssl_port=int(port),
            handler=prefix,
            username=secreds.username,
            password=secreds.password,
            proxy_hostname=rhsmcfg.get('server', 'proxy_hostname'),
            proxy_port=rhsmcfg.get('server', 'proxy_port'),
            proxy_user=rhsmcfg.get('server', 'proxy_user'),
            proxy_password=rhsmcfg.get('server', 'proxy_password'))

    try:
        cp.getOwnerList(secreds.username)
    except SSLError, e:
        print _("Error: CA certificate for subscription service has not been installed.")
        systemExit(1, CONNECTION_FAILURE % e)
    except Exception, e:
        log.error(e)
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE % e)
    return cp


def getOrg(cp, username):
    """Return the key of the organization to register the system with.

    cp       -- server connection providing getOwnerList()
    username -- account whose organizations are listed

    With exactly one organization its key is returned directly.  With
    several, the user is prompted and the answer must match one of the
    keys.  Exits when the list cannot be fetched, when it is empty, or when
    the answer is not a known organization.
    """
    try:
        owner_list = cp.getOwnerList(username)
    except Exception:
        # Fetch the active exception explicitly so this handler parses the
        # same way on every Python version.
        e = sys.exc_info()[1]
        log.error(e)
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE % e)

    if len(owner_list) == 0:
        systemExit(1, _("%s cannot register to any organizations.") % username)

    org_keys = [x['key'] for x in owner_list]
    if len(org_keys) == 1:
        # Nothing to choose from: use the only organization instead of
        # failing the membership check below with an unset value.
        return org_keys[0]

    org = raw_input(_("Org: ")).strip()
    if org not in org_keys:
        systemExit(1, _("No such org: %s") % org)

    return org


def getEnvironment(cp, owner_key):
    environment_list = []
    try:
        if cp.supports_resource('environments'):
            environment_list = cp.getEnvironmentList(owner_key)
    except Exception, e:
        log.error(e)
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE % e)

    environment = None
    if len(environment_list) > 1:
        environment = raw_input(_("Environment") + ": ").strip()
        if environment not in [x['name'] for x in environment_list]:
            systemExit(1, _("No such environment: %s") % environment)

    return environment


def checkIsOrgAdmin(sc, sk, username):
    """Exit unless username holds the "org_admin" role in RHN Classic.

    sc is the RHN API server proxy and sk the session key returned by the
    login call.  A failure while listing the roles also exits.
    """
    try:
        assigned_roles = sc.user.listRoles(sk, username)
    except:
        log.error(traceback.format_exc())
        systemExit(1, _("Problem encountered determining user roles in RHN Classic.  Exiting."))

    if "org_admin" in assigned_roles:
        return
    systemExit(1, _("You must be an org admin to successfully run this script."))


def selectServiceLevel(cp, consumer, servicelevel):
    not_supported = _("Error: The service-level command is not supported by "
                      "the server.")
    uuid = consumer.getConsumerId()
    try:
        org_key = cp.getOwner(uuid)['key']
        levels = cp.getServiceLevelList(org_key)
    except RemoteServerException, e:
        systemExit(-1, not_supported)
    except RestlibException, e:
            if e.code == 404:
                # no need to die, just skip it
                print not_supported
                return None
            else:
                # server supports it but something went wrong, die.
                raise e

    # Create the sla tuple before appending the empty string to the list of
    # valid slas.
    slas = [(sla, sla) for sla in levels]
    # Display an actual message for the empty string level.
    slas.append((_("No service level preference"), ""))

    # The empty string is a valid level so append it to the list.
    levels.append("")
    if servicelevel is None or \
        servicelevel.upper() not in (level.upper() for level in levels):
        if servicelevel is not None:
            print _("\nService level \"%s\" is not available." % servicelevel)
        menu = Menu(slas, _("Please select a service level agreement for this system."))
        servicelevel = menu.choose()
    return servicelevel


def getSubscribedChannelsList():
    """Return the labels of the RHN Classic channels reported for this
    system.  Exits with an explanatory message when they cannot be listed."""
    try:
        labels = [channel['label'] for channel in getChannels().channels()]
    except up2dateErrors.NoChannelsError:
        systemExit(1, _("This system is not associated with any channel."))
    except up2dateErrors.NoSystemIdError:
        systemExit(1, _("Unable to locate SystemId file. Is this system registered?"))
    except:
        # Anything else is logged in full and reported generically.
        log.error(traceback.format_exc())
        systemExit(1, _("Problem encountered getting the list of subscribed channels.  Exiting."))
    return labels


def checkForConflictingChannels(subscribedChannels):
    """Exit when more than one channel label starts with "jbappplatform"."""
    jboss_seen = 0
    for label in subscribedChannels:
        if not label.startswith("jbappplatform"):
            continue
        jboss_seen += 1
        if jboss_seen > 1:
            systemExit(1, _("You are subscribed to more than one jbappplatform channel."
                            "  This script does not support that configuration.  Exiting."))


def getSystemId():
    """Return the numeric RHN Classic system id from the systemid file.

    The file named by rhncfg["systemIdPath"] is parsed as XML.  The string
    value of its "system_id" member is split on '-' and the part after the
    first dash is converted to an int.
    """
    systemIdPath = rhncfg["systemIdPath"]
    f = open(systemIdPath)
    try:
        contents = f.read()
    finally:
        f.close()

    doc = libxml2.parseDoc(contents)
    try:
        rawId = doc.xpathEval('string(//member[* = "system_id"]/value/string)')
    finally:
        # Release the parsed document explicitly once the value is read.
        doc.freeDoc()
    return int(rawId.split('-')[1])


def connectToRhn(credentials):
    """Log in to the RHN Classic XML-RPC API.

    credentials -- object with the username and password to log in with

    The API host is taken from the up2date serverURL setting.  When the
    up2date proxy is enabled the connection goes through ProxiedTransport
    using the module-level proxy settings.

    Returns a (server proxy, session key) tuple.  Any failure while
    connecting or logging in is logged and exits the script.
    """
    hostname = rhncfg['serverURL'].split('/')[2]
    server_url = 'https://%s/rpc/api' % (hostname)
    try:
        if rhncfg['enableProxy']:
            pt = ProxiedTransport()
            if rhncfg['enableProxyAuth']:
                # b64encode never inserts line breaks; encodestring wraps
                # long input, which would put a newline inside the header.
                proxy_credentials = base64.b64encode('%s:%s' % (proxyUser, proxyPass))
            else:
                proxy_credentials = ""

            pt.set_proxy("%s:%s" % (proxyHost, proxyPort), proxy_credentials)
            log.info("Using proxy %s:%s for RHN API methods", proxyHost, proxyPort)
            sc = xmlrpclib.Server(server_url, transport=pt)
        else:
            sc = xmlrpclib.Server(server_url)

        sk = sc.auth.login(credentials.username, credentials.password)
        return (sc, sk)
    except:
        log.error(traceback.format_exc())
        systemExit(1, _("Unable to authenticate to RHN Classic.  See /var/log/rhsm/rhsm.log for more details."))


def unRegisterSystemFromRhnClassic(sc, sk):
    #getSystemIdPath
    systemIdPath = rhncfg["systemIdPath"]
    systemId = getSystemId()

    log.info("Deleting system %s from RHN Classic...", systemId)
    result = sc.system.deleteSystems(sk, systemId)
    if result:
        log.info("System %s deleted.  Removing systemid file and disabling rhnplugin.conf", systemId)
        os.remove(systemIdPath)
        disableYumRhnPlugin()
        print _("System successfully unregistered from RHN Classic.")
    else:
        systemExit(1, _("Unable to unregister system from RHN Classic.  Exiting."))


def disableYumRhnPlugin():
    # 'Inspired by' up2date_client/rhnreg.py
    """ Disable yum-rhn-plugin by rewriting the "enabled" setting of the
        [main] section in /etc/yum/pluginconf.d/rhnplugin.conf to 0.
        Settings in other sections are left untouched.
        Can raise IOError.
    """
    log.info("Disabling rhnplugin.conf")
    YUM_PLUGIN_CONF = '/etc/yum/pluginconf.d/rhnplugin.conf'
    conf = open(YUM_PLUGIN_CONF, 'r')
    original = conf.readlines()
    conf.close()

    in_main = False
    rewritten = []
    for text in original:
        # Every section header switches the "inside [main]" state.
        if re.match("^\[.*]", text):
            in_main = re.match("^\[main]", text) is not None
        if in_main:
            text = re.sub('^(\s*)enabled\s*=.+', r'\1enabled = 0', text)
        rewritten.append(text)

    conf = open(YUM_PLUGIN_CONF, 'w')
    conf.writelines(rewritten)
    conf.close()


def readChannelCertMapping(mappingfile):
    """Parse a channel-to-certificate mapping file into a dict.

    Only lines starting with an ASCII letter are considered; each must look
    like "<channel>: <certificate>".  The line is split at the first ": "
    only, so the value may itself contain ": ".

    Raises IOError when the file cannot be read and ValueError for a
    considered line that has no ": " separator.
    """
    f = open(mappingfile)
    try:
        lines = f.readlines()
    finally:
        # Close the handle even when reading fails.
        f.close()

    dic_data = {}
    for line in lines:
        if re.match("^[a-zA-Z]", line):
            key, val = line.replace("\n", "").split(": ", 1)
            dic_data[key] = val
    return dic_data


def transferHttpProxySettings():
    # transfer http proxy information from up2date to rhsm.conf
    global proxyHost, proxyPort, proxyUser, proxyPass
    if rhncfg['enableProxy']:
        httpProxy = rhncfg['httpProxy']
        if httpProxy[:7] == "http://":
            httpProxy = httpProxy[7:]
        try:
            proxyHost, proxyPort = httpProxy.split(':')
        except ValueError, e:
            log.exception(e)
            systemExit(1, _("Unable to read RHN proxy settings."))

        log.info("Using proxy %s:%s - transferring settings to rhsm.conf" % (proxyHost, proxyPort))
        rhsmcfg.set('server', 'proxy_hostname', proxyHost)
        rhsmcfg.set('server', 'proxy_port', proxyPort)
        if rhncfg['enableProxyAuth']:
            proxyUser = rhncfg['proxyUser']
            proxyPass = rhncfg['proxyPassword']
            rhsmcfg.set('server', 'proxy_user', proxyUser)
            rhsmcfg.set('server', 'proxy_password', proxyPass)
        else:
            rhsmcfg.set('server', 'proxy_user', '')
            rhsmcfg.set('server', 'proxy_password', '')
        rhsmcfg.save()


def register(credentials, serverurl, org, environment):
    # For registering the machine, use the CLI tool to reuse the username/password (because the GUI will prompt for them again)
    print _("\nAttempting to register system to Red Hat Subscription Management ...")
    cmd = ['subscription-manager', 'register', '--username=' + credentials.username, '--password=' + credentials.password]
    if serverurl:
        # insert just after register (not sure if order matters,
        # but just in case)
        cmd.insert(2, '--serverurl=' + serverurl)

    if org:
        cmd.append('--org=' + org)
    if environment:
        cmd.append('--environment=' + environment)

    result = subprocess.call(cmd)

    if result != 0:
        systemExit(2, _("\nUnable to register.\nFor further assistance, please contact Red Hat Global Support Services."))
    else:
        consumer = ConsumerIdentity.read()
        print _("System '%s' successfully registered to Red Hat Subscription Management.\n") % consumer.getConsumerName()
    return consumer


def subscribe(consumer, servicelevel):
    # For subscribing, use the GUI tool if the DISPLAY environment variable is set and the gui tool exists
    if os.getenv('DISPLAY') and os.path.exists('/usr/bin/subscription-manager-gui') and options.gui:
        print _("Launching the GUI tool to manually subscribe the system ...")
        result = subprocess.call(['subscription-manager-gui'], stderr=open(os.devnull, 'w'))
    else:
        print _("Attempting to auto-subscribe to appropriate subscriptions ...")
        cmd = ['subscription-manager', 'subscribe', '--auto']

        # only add servicelevel if one was passed in
        if servicelevel:
            cmd.append('--servicelevel=' + servicelevel)

        result = subprocess.call(cmd)
        if result != 0:
            print _("\nUnable to auto-subscribe.  Do your existing subscriptions match the products installed on this system?")
    # don't show url for katello/CFSE/SAM
    if isHosted():
        print _("\nPlease visit https://access.redhat.com/management/consumers/%s to view the details, and to make changes if necessary.") % consumer.getConsumerId()


def print_banner(msg):
        print "\n+-----------------------------------------------------+",
        print msg
        print "+-----------------------------------------------------+"


def deployProdCertificates(subscribedChannels):
    release = getRelease()
    mappingfile = "/usr/share/rhsm/product/" + release + "/channel-cert-mapping.txt"
    log.info("Using mapping file %s", mappingfile)

    try:
        dic_data = readChannelCertMapping(mappingfile)
    except IOError, e:
        log.exception(e)
        systemExit(1, _("Unable to read mapping file: %s") % mappingfile)

    applicableCerts = []
    validRhsmChannels = []
    invalidRhsmChannels = []
    unrecognizedChannels = []

    for channel in subscribedChannels:
        try:
            if dic_data[channel] != 'none':
                validRhsmChannels.append(channel)
                log.info("mapping found for : %s = %s", channel, dic_data[channel])
                if dic_data[channel] not in applicableCerts:
                    applicableCerts.append(dic_data[channel])
            else:
                invalidRhsmChannels.append(channel)
                log.info("%s None", channel)
        except:
            unrecognizedChannels.append(channel)

    if invalidRhsmChannels:
        print_banner("\n" + _("Channels not available on RHSM:"))
        for i in invalidRhsmChannels:
            print i

    if unrecognizedChannels:
        print_banner(_("\nNo product certificates are mapped to these RHN Classic channels:"))
        for i in unrecognizedChannels:
            print i

    if unrecognizedChannels or invalidRhsmChannels:
        if not options.force:
            print(_("\nUse --force to ignore these channels and continue the migration.\n"))
            sys.exit(1)

    log.info("certs to be installed: %s", applicableCerts)

    print_banner(_("\nInstalling product certificates for these RHN Classic channels:"))
    for i in validRhsmChannels:
        print i

    release = getRelease()

    # creates the product directory if it doesn't already exist
    productDir = ProductDirectory()
    for cert in applicableCerts:
        sourcepath = "/usr/share/rhsm/product/" + release + "/" + cert
        truncated_cert_name = cert.split('-')[-1]
        destinationpath = str(productDir) + "/" + truncated_cert_name
        log.info("cp %s %s ", sourcepath, destinationpath)
        shutil.copy2(sourcepath, destinationpath)
    print _("\nProduct certificates installed successfully to %s.") % str(productDir)


def getRelease(release_file='/etc/redhat-release'):
    """Return "RHEL-<major version>" for the release in release_file.

    release_file -- path of the release file to read; defaults to
                    /etc/redhat-release

    The major version is the first run of digits following the word
    "release", so the result does not depend on how many words precede it
    or on whether a minor version is present.

    Raises IOError when the file cannot be read and ValueError when no
    release number is found in it.
    """
    f = open(release_file)
    try:
        contents = f.read()
    finally:
        f.close()

    found = re.search(r'release\s+(\d+)', contents)
    if found is None:
        raise ValueError("Unable to determine the release from %s" % release_file)
    return "RHEL-" + found.group(1)


def enableExtraChannels(subscribedChannels, cp):
    # Check if system was subscribed to extra channels like supplementary, optional, fastrack etc.
    # If so, enable them in the redhat.repo file

    extraChannels = {'supplementary': False, 'productivity': False, 'optional': False}
    for subscribedChannel in subscribedChannels:
        if 'supplementary' in subscribedChannel:
            extraChannels['supplementary'] = True
        elif 'optional' in subscribedChannel:
            extraChannels['optional'] = True
        elif 'productivity' in subscribedChannel:
            extraChannels['productivity'] = True

    if True not in extraChannels.values():
        return

    # create and populate the redhat.repo file
    repolib.RepoLib(uep=cp).update()

    # read in the redhat.repo file
    repofile = repolib.RepoFile()
    repofile.read()

    # enable any extra channels we are using and write out redhat.repo
    try:
        for rhsmChannel in repofile.sections():
            if ((extraChannels['supplementary'] and re.search('supplementary$', rhsmChannel)) or
            (extraChannels['optional']  and re.search('optional-rpms$', rhsmChannel)) or
            (extraChannels['productivity']  and re.search('productivity-rpms$', rhsmChannel))):
                log.info("Enabling extra channel '%s'" % rhsmChannel)
                repofile.set(rhsmChannel, 'enabled', '1')
                repofile.write()
    except:
        print _("\nUnable to enable extra repositories.")
        print _("Please ensure system has subscriptions attached, and see 'subscription-manager repos --help' to enable additional repositories")


def writeMigrationFacts():
    """Record the migration in /etc/rhsm/facts/migration.facts.

    The file receives the RHN Classic system id, the migration source and
    the current timestamp as JSON.  An existing file is left untouched.
    """
    FACT_FILE = "/etc/rhsm/facts/migration.facts"
    if os.path.exists(FACT_FILE):
        return

    # Gather everything before creating the file: a failure while reading
    # the system id must not leave an empty facts file behind, because an
    # existing file is never rewritten.
    facts = {"migration.classic_system_id": getSystemId(),
             "migration.migrated_from": "rhn_hosted_classic",
             "migration.migration_date": datetime.now().isoformat()}

    f = open(FACT_FILE, 'w')
    try:
        json.dump(facts, f)
    finally:
        f.close()


def cleanUp(subscribedChannels):
    #Hack to address BZ 853233
    productDir = ProductDirectory()
    if os.path.isfile(os.path.join(str(productDir), "68.pem")) and \
        os.path.isfile(os.path.join(str(productDir), "71.pem")):
        try:
            os.remove(os.path.join(str(productDir), "68.pem"))
            log.info("Removed 68.pem due to existence of 71.pem")
        except OSError, e:
            log.info(e)

    #Hack to address double mapping for 180.pem and 17{6|8}.pem
    double_mapped = "rhel-.*?-(client|server)-dts-(5|6)-beta(-debuginfo)?"
    #The (?!-beta) bit is a negative lookahead assertion.  So we won't match
    #if the 5 or 6 is followed by the word "-beta"
    single_mapped = "rhel-.*?-(client|server)-dts-(5|6)(?!-beta)(-debuginfo)?"

    is_double_mapped = [x for x in subscribedChannels if re.match(double_mapped, x)]
    is_single_mapped = [x for x in subscribedChannels if re.match(single_mapped, x)]

    if is_double_mapped and is_single_mapped:
        try:
            os.remove(os.path.join(str(productDir), "180.pem"))
            log.info("Removed 180.pem")
        except OSError, e:
            log.info(e)


def authenticate(prompt):
    """Ask for a username using prompt, then for a password via getpass,
    and return both wrapped in UserCredentials."""
    entered_name = raw_input(prompt).strip()
    return UserCredentials(entered_name, getpass.getpass())


def validateOptions():
    """Exit when mutually exclusive command line options were combined."""
    conflicting = options.servicelevel and options.noauto
    if not conflicting:
        return
    systemExit(1, _("The --servicelevel and --no-auto options cannot be used together."))


def main():
    validateOptions()
    serverurl = None
    if options.serverurl:
        rhncreds = authenticate(_("Red Hat account: "))
        secreds = authenticate(_("System Engine Username: "))
        serverurl = options.serverurl
    else:
        rhncreds = authenticate(_("Red Hat account: "))
        if not isHosted():
            secreds = authenticate(_("System Engine Username: "))
        else:
            secreds = rhncreds  # make them the same

    transferHttpProxySettings()
    cp = checkOkToProceed(secreds, serverurl)

    org = getOrg(cp, secreds.username)
    environment = getEnvironment(cp, org)

    (sc, sk) = connectToRhn(rhncreds)
    checkIsOrgAdmin(sc, sk, rhncreds.username)

    # get a list of RHN classic channels this machine is subscribed to
    print _("\nRetrieving existing RHN Classic subscription information ...")
    subscribedChannels = getSubscribedChannelsList()
    print_banner("\n" + _("System is currently subscribed to these RHN Classic Channels:"))
    for channel in subscribedChannels:
        print channel

    checkForConflictingChannels(subscribedChannels)
    deployProdCertificates(subscribedChannels)
    cleanUp(subscribedChannels)

    writeMigrationFacts()
    print _("\nPreparing to unregister system from RHN Classic ...")
    unRegisterSystemFromRhnClassic(sc, sk)

    # register the system to Certificate-based RHN and consume a subscription
    consumer = register(secreds, serverurl, org, environment)
    if not options.noauto:
        if options.servicelevel:
            servicelevel = selectServiceLevel(cp, consumer, options.servicelevel)
            subscribe(consumer, servicelevel)
        else:
            subscribe(consumer, None)
    # check if we need to enable to supplementary/optional channels
    enableExtraChannels(subscribedChannels, cp)


if __name__ == '__main__':
    main()
