#!/usr/bin/python
#
# Copyright (c) 2010 Red Hat, Inc.
#
# This software is licensed to you under the GNU General Public License,
# version 2 (GPLv2). There is NO WARRANTY for this software, express or
# implied, including the implied warranties of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. You should have received a copy of GPLv2
# along with this software; if not, see
# http://www.gnu.org/licenses/old-licenses/gpl-2.0.txt.
#
# Red Hat trademarks are not licensed under GPLv2. No permission is
# granted to use or replicate Red Hat trademarks that are incorporated
# in this software or its documentation.
#

import sys
import os
import xmlrpclib
import httplib
import getpass
import libxml2
import subprocess
import re
import simplejson as json
import shutil
import logging
import traceback
import base64
from datetime import datetime
from rhsm.connection import UEPConnection

# Directory holding the subscription-manager python modules imported below.
_LIBPATH = "/usr/share/rhsm"
# add to the path if need be
if _LIBPATH not in sys.path:
    sys.path.append(_LIBPATH)

# i18n is configured before `_` is bound so that every message translated
# further down goes through the configured gettext setup.
from subscription_manager.i18n import configure_i18n
configure_i18n()

import gettext
# `_` is the translation function used for all user-facing strings in this file.
_ = gettext.gettext

# quick check to see if you are a super-user.
# Runs at import time, before any option parsing; exits with status 8.
if os.getuid() != 0:
    print _("Must be root user to execute\n")
    sys.exit(8)

# access the rhn/up2date python libraries and read the up2date config file
_RHNLIBPATH = "/usr/share/rhn"
if _RHNLIBPATH not in sys.path:
    sys.path.append(_RHNLIBPATH)

# These imports only resolve once _RHNLIBPATH is on sys.path (see above).
from up2date_client import up2dateErrors
from up2date_client.rhnChannel import  getChannels
import up2date_client.config

# up2date (RHN Classic) configuration; read below through keys such as
# 'serverURL', 'systemIdPath', 'enableProxy', 'httpProxy', 'enableProxyAuth',
# 'proxyUser' and 'proxyPassword'.
rhncfg = up2date_client.config.initUp2dateConfig()

from optparse import Option, OptionParser
# Command-line switches understood by the migration script.  The parsed values
# are read later through the module-level `options` object:
#   options.force        - continue even if some channels cannot be migrated
#   options.gui          - launch subscription-manager-gui for subscribing
#   options.noauto       - skip the registration/auto-subscribe step in main()
#   options.servicelevel - requested service level (None when not given)
options_table = [
    Option("-f", "--force", action="store_true", default=False,
           help=_("Ignore channels not available on RHSM")),
    Option("-g", "--gui", action="store_true", default=False, dest='gui',
           help=_("Launch the GUI tool to subscribe the system, instead of autosubscribing")),
    Option("-n", "--no-auto", action="store_true", default=False, dest='noauto',
           help=_("Don't execute the autosubscribe option while registering with subscription manager.")),
    # The help text must name the switch exactly as it is defined here; it
    # previously advertised "--service-level", which this parser rejects.
    Option("-s", "--servicelevel", dest="servicelevel",
           help=_("Service level to subscribe this system to. For no service "
                  "level use --servicelevel=\"\"")),
]

# Parsing happens at import time; `args` (positional arguments) is not used
# anywhere else in this file.
parser = OptionParser(option_list=options_table, usage=_("%s [OPTIONS]") % os.path.basename(sys.argv[0]))
(options, args) = parser.parse_args()

# access the rhsm python libraries, read rhsm config file and setup logging
# (_RHSMLIBPATH holds the same directory as _LIBPATH above, so this append is
# normally a no-op.)
_RHSMLIBPATH = "/usr/share/rhsm"
if _RHSMLIBPATH not in sys.path:
    sys.path.append(_RHSMLIBPATH)

try:
    # try importing from RHEL 6 location
    from subscription_manager.certlib import ConsumerIdentity, ProductDirectory
    from subscription_manager import repolib, logutil
except ImportError:
    # old RHEL 5 location
    from certlib import ConsumerIdentity, ProductDirectory
    import repolib
    import logutil

import rhsm.config

logutil.init_logger()
# Module logger; the error messages in this file point users at
# /var/log/rhsm/rhsm.log for the details logged through it.
log = logging.getLogger('rhsm-app.' + __name__)
# rhsm configuration object; proxy settings are written to it and saved by
# transferHttpProxySettings().
rhsmcfg = rhsm.config.initConfig()

# Proxy settings shared between transferHttpProxySettings() (which fills them
# in from the up2date configuration) and connectToRhn() (which reads them).
proxyHost = ""
proxyPort = ""
proxyUser = ""
proxyPass = ""

# Shared message used whenever a call to the certificate server fails.
CONNECTION_FAILURE = _("Unable to connect to certificate server.  See /var/log/rhsm/rhsm.log for more details.")


class InvalidChoiceError(Exception):
    """Raised by Menu.getItem when a selection does not match any menu entry."""


class Menu(object):
    """Numbered text menu read from standard input.

    choices is a sequence of (display string, value) pairs; the entries are
    numbered from 1 in the order given.  header is printed above the entries
    each time the menu is shown.
    """

    def __init__(self, choices, header):
        self.header = header
        self.choices = choices

    def choose(self):
        """Show the menu and prompt until a valid entry is picked; return its value."""
        while True:
            self.display()
            answer = raw_input("? ").strip()
            try:
                picked = self.getItem(answer)
            except InvalidChoiceError:
                print(_("You have entered an invalid choice."))
                continue
            return picked

    def display(self):
        """Print the header followed by one numbered line per choice."""
        print(self.header)
        number = 1
        for entry in self.choices:
            print("%s. %s" % (number, entry[0]))
            number += 1

    def getItem(self, selection):
        """Return the value of the entry numbered ``selection`` (1-based).

        Raises InvalidChoiceError when the selection is not an integer, is
        zero or negative, or is larger than the number of entries.
        """
        try:
            position = int(selection)
        except (TypeError, ValueError):
            raise InvalidChoiceError

        # Zero and negative numbers must not wrap around to the end of the list.
        if position < 1:
            raise InvalidChoiceError

        try:
            return self.choices[position - 1][1]
        except IndexError:
            raise InvalidChoiceError


class ProxiedTransport(xmlrpclib.Transport):
    """xmlrpclib transport that sends its requests through an HTTP proxy.

    set_proxy() has to be called before the transport is used: the other
    methods read self.proxy and self.credentials, which are only assigned
    there.
    """
    def set_proxy(self, proxy, credentials):
        """Store the proxy address and the proxy credentials.

        proxy is handed unchanged to httplib.HTTP (connectToRhn passes
        "host:port").  credentials is appended to "Basic " in the
        Proxy-Authorization header; an empty value suppresses that header.
        """
        self.proxy = proxy
        self.credentials = credentials

    def make_connection(self, host):
        """Return an httplib.HTTP object for the proxy instead of ``host``.

        The real target host is remembered for the request line and the
        Host header.
        """
        self.realhost = host
        return httplib.HTTP(self.proxy)

    def send_request(self, connection, handler, request_body):
        """Start a POST whose request line carries the absolute target URL.

        request_body is not used here.
        """
        # NOTE(review): the URL is always built with the "http://" scheme,
        # while connectToRhn() constructs an "https://" server URL -- confirm
        # that sending the API traffic to the proxy this way is intended.
        connection.putrequest("POST", 'http://%s%s' % (self.realhost, handler))

    def send_host(self, connection, host):
        """Send the Host header for the real host, plus proxy auth if set.

        The ``host`` argument is ignored in favour of the host remembered by
        make_connection().
        """
        connection.putheader('Host', self.realhost)
        if self.credentials:
            connection.putheader('Proxy-Authorization', 'Basic ' + self.credentials)


def systemExit(code, msgs=None):
    """Write the optional message(s) to stderr and exit with status ``code``.

    msgs may be a single message or a list/tuple of messages; each one is
    converted with str() and written on its own line.  A false value (None,
    empty string, empty sequence) writes nothing.
    """
    if msgs:
        if type(msgs) in (list, tuple):
            pending = msgs
        else:
            # A lone message is handled as a one-element sequence.
            pending = (msgs, )
        for text in pending:
            sys.stderr.write(str(text) + '\n')
    sys.exit(code)


def checkOkToProceed(username, password):
    """Run the pre-migration checks and return a certificate-server connection.

    Exits with status 1 when the machine already has a valid consumer
    identity, when the up2date serverURL does not look like hosted RHN
    Classic, or when the certificate server cannot be queried with the given
    credentials.  Otherwise returns the UEPConnection that was created.
    """
    # check if this machine is already registered to Certificate-based RHN
    if ConsumerIdentity.existsAndValid():
        print(_("\nThis machine appears to be already registered to Certificate-based RHN.  Exiting."))
        identity = ConsumerIdentity.read()
        profile_hint = _("\nPlease visit https://access.redhat.com/management/consumers/%s to view the profile details.") % identity.getConsumerId()
        systemExit(1, profile_hint)

    # Check to make sure we're hooked up to Hosted.
    classic_url = rhncfg['serverURL']
    looks_hosted = re.search(r'xmlrpc\.rhn\..*redhat\.com', classic_url)
    if not looks_hosted:
        log.info("serverURL = %s" % classic_url)
        classic_host = classic_url.split('/')[2]
        systemExit(1, _("This migration script requires the system to be registered to RHN Classic.\nHowever this system appears to be registered to '%s'.\nExiting.") % classic_host)

    # Check to make sure we can connect to the certificate server, using the
    # proxy settings currently stored in the rhsm configuration.
    proxy_hostname = rhsmcfg.get('server', 'proxy_hostname')
    proxy_port = rhsmcfg.get('server', 'proxy_port')
    proxy_user = rhsmcfg.get('server', 'proxy_user')
    proxy_password = rhsmcfg.get('server', 'proxy_password')
    cp = UEPConnection(username=username, password=password,
                       proxy_hostname=proxy_hostname,
                       proxy_port=proxy_port,
                       proxy_user=proxy_user,
                       proxy_password=proxy_password)

    try:
        # The result is not used; the call only proves the connection works.
        cp.getOwnerList(username)
    except Exception:
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE)
    return cp


def checkIsOrgAdmin(sc, sk, username):
    """Exit unless ``username`` holds the "org_admin" role in RHN Classic.

    sc is the XML-RPC server proxy and sk the session key, as returned by
    connectToRhn().  Exits with status 1 when the roles cannot be retrieved
    or when "org_admin" is not among them; returns None otherwise.
    """
    try:
        roles = sc.user.listRoles(sk, username)
    except Exception:
        # Deliberately not a bare "except:" -- an interrupt (Ctrl-C) must not
        # be reported as a problem with the RHN Classic roles lookup.
        log.error(traceback.format_exc())
        systemExit(1, _("Problem encountered determining user roles in RHN Classic.  Exiting."))
    if "org_admin" not in roles:
        systemExit(1, _("You must be an org admin to successfully run this script."))


def selectServiceLevel(cp, consumer, servicelevel):
    """Return the service level to use, prompting the user when necessary.

    The service levels available to the consumer's owner are fetched from the
    certificate server (exits with status 1 if that fails).  ``servicelevel``
    is returned unchanged when it is one of those levels or the empty string
    (meaning "no service level").  When it is None or unknown, a menu of the
    available levels is shown and the user's selection is returned.
    """
    uuid = consumer.getConsumerId()
    try:
        org_key = cp.getOwner(uuid)['key']
        levels = cp.getServiceLevelList(org_key)
    except Exception:
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE)

    # Menu entries: every real level is shown under its own name, while the
    # empty-string level gets a readable label.
    slas = [(sla, sla) for sla in levels]
    slas.append((_("No service level"), ""))

    # The empty string is a valid level too.  Work on a copy so the list
    # returned by the server is left untouched.
    valid_levels = list(levels) + [""]
    if servicelevel is not None and servicelevel in valid_levels:
        return servicelevel

    if servicelevel is not None:
        # Translate the template first and substitute afterwards; formatting
        # inside _() would look up the already-substituted text, which is
        # never in the message catalog.
        print(_("\nService level \"%s\" is not available.") % servicelevel)
    menu = Menu(slas, _("Please select a service level agreement for this system."))
    return menu.choose()


def setServiceLevel(cp, consumer, servicelevel):
    """Store ``servicelevel`` on the consumer at the certificate server.

    An empty string means "no service level".  Exits with status 1 when the
    update call fails.
    """
    uuid = consumer.getConsumerId()
    # Compare by value: "is" tests object identity, which only happens to
    # work for strings the interpreter chooses to share.
    if servicelevel == "":
        print(_("\nNot subscribing to any service level."))
    else:
        # Substitute after translating so the catalog lookup uses the template.
        print(_("\nSubscribing to service level %s") % servicelevel)

    try:
        cp.updateConsumer(uuid, service_level=servicelevel)
    except Exception:
        log.error(traceback.format_exc())
        systemExit(1, CONNECTION_FAILURE)


def getSubscribedChannelsList():
    """Return the labels of the RHN Classic channels this system uses.

    Exits with status 1 when the system has no channels, when no systemid
    file can be found, or when the channel lookup fails for any other reason.
    """
    try:
        labels = [channel['label'] for channel in getChannels().channels()]
    except up2dateErrors.NoChannelsError:
        systemExit(1, _("This system is not associated with any channel."))
    except up2dateErrors.NoSystemIdError:
        systemExit(1, _("Unable to locate SystemId file. Is this system registered?"))
    except:
        # NOTE(review): left as a catch-all on purpose; presumably some
        # up2date_client errors do not derive from Exception on older
        # releases -- confirm before narrowing this clause.
        log.error(traceback.format_exc())
        systemExit(1, _("Problem encountered getting the list of subscribed channels.  Exiting."))
    return labels


def getSystemId():
    """Return the numeric RHN Classic system id stored in the systemid file.

    The file named by the up2date 'systemIdPath' setting is parsed as XML.
    The "system_id" member holds a value of the form "<prefix>-<number>";
    the part after the first dash is returned as an int.
    """
    systemIdPath = rhncfg["systemIdPath"]
    id_file = open(systemIdPath)
    try:
        contents = id_file.read()
    finally:
        id_file.close()

    doc = libxml2.parseDoc(contents)
    try:
        raw_id = doc.xpathEval('string(//member[* = "system_id"]/value/string)')
    finally:
        # libxml2 documents are not released automatically by the bindings.
        doc.freeDoc()
    return int(raw_id.split('-')[1])


def connectToRhn(username, password):
    """Log in to the RHN Classic XML-RPC API.

    The API URL is derived from the host part of the up2date 'serverURL'
    setting.  When up2date has a proxy enabled, the module-level proxy
    settings (filled in by transferHttpProxySettings) are used through
    ProxiedTransport.

    Returns a (server proxy, session key) tuple; exits with status 1 when
    the connection or the login fails.
    """
    hostname = rhncfg['serverURL'].split('/')[2]
    server_url = 'https://%s/rpc/api' % (hostname)
    try:
        if rhncfg['enableProxy']:
            pt = ProxiedTransport()
            if rhncfg['enableProxyAuth']:
                # b64encode never inserts line breaks.  encodestring() wraps
                # its output every 76 characters, which would put a newline
                # in the middle of the header for long user/password pairs.
                credentials = base64.b64encode('%s:%s' % (proxyUser, proxyPass))
            else:
                credentials = ""

            pt.set_proxy("%s:%s" % (proxyHost, proxyPort), credentials)
            log.info("Using proxy %s:%s for RHN API methods" % (proxyHost, proxyPort))
            sc = xmlrpclib.Server(server_url, transport=pt)
        else:
            sc = xmlrpclib.Server(server_url)

        sk = sc.auth.login(username, password)
        return (sc, sk)
    except Exception:
        # Not a bare "except:" so that an interrupt is not reported as an
        # authentication failure.
        log.error(traceback.format_exc())
        systemExit(1, _("Unable to authenticate to RHN Classic.  See /var/log/rhsm/rhsm.log for more details."))


def unRegisterSystemFromRhnClassic(sc, sk):
    """Delete this system's profile from RHN Classic and clean up locally.

    sc is the XML-RPC server proxy and sk the session key.  When the server
    reports success, the systemid file is removed and the yum rhn plugin is
    disabled; otherwise the script exits with status 1.
    """
    id_file_path = rhncfg["systemIdPath"]
    classic_id = getSystemId()

    log.info("Deleting system %s from RHN Classic...", classic_id)
    deleted = sc.system.deleteSystems(sk, classic_id)
    if not deleted:
        systemExit(1, _("Unable to unregister system from RHN Classic.  Exiting."))
    else:
        log.info("System %s deleted.  Removing systemid file and disabling rhnplugin.conf", classic_id)
        os.remove(id_file_path)
        disableYumRhnPlugin()
        print(_("System successfully unregistered from RHN Classic."))


def disableYumRhnPlugin(plugin_conf='/etc/yum/pluginconf.d/rhnplugin.conf'):
    """Disable yum-rhn-plugin by setting enabled = 0 in its [main] section.

    'Inspired by' up2date_client/rhnreg.py.

    plugin_conf is the plugin configuration file to rewrite; it defaults to
    /etc/yum/pluginconf.d/rhnplugin.conf.  Only "enabled" lines inside the
    [main] section are changed; every other line is written back untouched.

    Can raise IOError if the file cannot be read or written.
    """
    log.info("Disabling rhnplugin.conf")
    conf_file = open(plugin_conf, 'r')
    try:
        lines = conf_file.readlines()
    finally:
        conf_file.close()

    # Build the complete new content before the file is opened for writing,
    # so that a failure while processing cannot leave it truncated.
    main_section = False
    updated = []
    for line in lines:
        if re.match(r"^\[.*]", line):
            # Every section header switches the flag on or off.
            main_section = bool(re.match(r"^\[main]", line))
        if main_section:
            line = re.sub(r'^(\s*)enabled\s*=.+', r'\1enabled = 0', line)
        updated.append(line)

    conf_file = open(plugin_conf, 'w')
    try:
        conf_file.writelines(updated)
    finally:
        conf_file.close()


def readChannelCertMapping(mappingfile):
    """Parse a channel-to-certificate mapping file into a dictionary.

    Only lines that start with an ASCII letter are considered; each is
    expected to have the form "<channel>: <value>".  The text before the
    first ": " becomes the key and everything after it the value.  Lines
    without a ": " separator are skipped.

    Raises IOError if the file cannot be opened or read.
    """
    mapping_file = open(mappingfile)
    try:
        lines = mapping_file.readlines()
    finally:
        mapping_file.close()

    dic_data = {}
    for line in lines:
        if not re.match("^[a-zA-Z]", line):
            continue
        # Split on the first separator only, so a value that itself contains
        # ": " is kept whole instead of breaking the unpacking.
        fields = line.rstrip("\r\n").split(": ", 1)
        if len(fields) != 2:
            continue
        dic_data[fields[0]] = fields[1]
    return dic_data


def transferHttpProxySettings():
    """Copy the up2date HTTP proxy settings into rhsm.conf.

    Does nothing when up2date has no proxy enabled.  Otherwise the proxy host
    and port (and the user/password when proxy authentication is enabled,
    empty strings when it is not) are stored in the module-level proxy*
    variables and in the [server] section of the rhsm configuration, which is
    then saved.

    Exits with status 1 when the up2date 'httpProxy' value cannot be split
    into exactly a host and a port.
    """
    global proxyHost, proxyPort, proxyUser, proxyPass
    if not rhncfg['enableProxy']:
        return

    httpProxy = rhncfg['httpProxy']
    if httpProxy[:7] == "http://":
        httpProxy = httpProxy[7:]
    try:
        host, port = httpProxy.split(':')
    except ValueError:
        # No port, or additional colons (for example an unexpected scheme):
        # stop with a readable message instead of an unhandled traceback.
        log.error("Unable to split up2date httpProxy value '%s' into host and port", httpProxy)
        log.error(traceback.format_exc())
        systemExit(1, _("Could not read legacy proxy settings.  See /var/log/rhsm/rhsm.log for more details."))
    proxyHost, proxyPort = host, port

    log.info("Using proxy %s:%s - transferring settings to rhsm.conf" % (proxyHost, proxyPort))
    rhsmcfg.set('server', 'proxy_hostname', proxyHost)
    rhsmcfg.set('server', 'proxy_port', proxyPort)
    if rhncfg['enableProxyAuth']:
        proxyUser = rhncfg['proxyUser']
        proxyPass = rhncfg['proxyPassword']
        rhsmcfg.set('server', 'proxy_user', proxyUser)
        rhsmcfg.set('server', 'proxy_password', proxyPass)
    else:
        rhsmcfg.set('server', 'proxy_user', '')
        rhsmcfg.set('server', 'proxy_password', '')
    rhsmcfg.save()


def runSubscriptionManager(username, password):
    """Register the system with subscription-manager and subscribe it.

    Registration always goes through the command-line tool so the credentials
    entered earlier can be reused (the GUI would prompt for them again);
    exits with status 2 if it fails.  Subscribing uses the GUI tool when
    --gui was given, DISPLAY is set and the GUI binary exists; otherwise an
    auto-subscribe is attempted, whose failure is reported but not fatal.

    Returns the ConsumerIdentity read after the registration.
    """
    print(_("\nAttempting to register system to Certificate-based RHN ..."))
    # NOTE(review): the password is handed over as a command-line argument,
    # where other local users may be able to see it in the process list --
    # confirm whether the tool offers a safer way to supply it.
    result = subprocess.call(['subscription-manager', 'register', '--username=' + username, '--password=' + password])
    if result != 0:
        systemExit(2, _("\nUnable to register.\nFor further assistance, please contact Red Hat Global Support Services."))

    consumer = ConsumerIdentity.read()
    print(_("System '%s' successfully registered to Certificate-based RHN.\n") % consumer.getConsumerName())

    # For subscribing, use the GUI tool if the DISPLAY environment variable is set and the gui tool exists
    if os.getenv('DISPLAY') and os.path.exists('/usr/bin/subscription-manager-gui') and options.gui:
        print(_("Launching the GUI tool to manually subscribe the system ..."))
        # The GUI's stderr output is discarded; close the handle afterwards
        # instead of leaving it open.  The GUI's exit status is not checked.
        devnull = open(os.devnull, 'w')
        try:
            subprocess.call(['subscription-manager-gui'], stderr=devnull)
        finally:
            devnull.close()
    else:
        print(_("Attempting to auto-subscribe to appropriate subscriptions ..."))
        result = subprocess.call(['subscription-manager', 'subscribe', '--auto'])
        if result != 0:
            print(_("\nUnable to auto-subscribe.  Do your existing subscriptions match the products installed on this system?"))
    print(_("\nPlease visit https://access.redhat.com/management/consumers/%s to view the details, and to make changes if necessary.") % consumer.getConsumerId())
    return consumer


def _classifyChannels(subscribedChannels, dic_data):
    """Sort channels by what the channel-to-certificate mapping says about them.

    Returns a tuple (applicableCerts, validRhsmChannels, invalidRhsmChannels,
    unrecognizedChannels):
      applicableCerts      - distinct certificate names, in first-seen order
      validRhsmChannels    - channels mapped to a certificate
      invalidRhsmChannels  - channels explicitly mapped to 'none'
      unrecognizedChannels - channels missing from the mapping
    """
    applicableCerts = []
    validRhsmChannels = []
    invalidRhsmChannels = []
    unrecognizedChannels = []

    for channel in subscribedChannels:
        if channel not in dic_data:
            unrecognizedChannels.append(channel)
            continue
        cert = dic_data[channel]
        if cert == 'none':
            invalidRhsmChannels.append(channel)
            log.info("%s None", channel)
            continue
        validRhsmChannels.append(channel)
        log.info("mapping found for : %s = %s", channel, cert)
        if cert not in applicableCerts:
            applicableCerts.append(cert)

    return (applicableCerts, validRhsmChannels, invalidRhsmChannels,
            unrecognizedChannels)


def deployProdCertificates(subscribedChannels):
    """Copy the product certificates matching the subscribed channels.

    The mapping file for the running release translates each RHN Classic
    channel label into a product certificate name (or 'none').  Channels
    mapped to 'none' and channels missing from the mapping are listed; unless
    --force was given their presence ends the script with status 1.  The
    certificates of all remaining channels are copied into the product
    certificate directory, named after the last dash-separated part of the
    source file name.

    Exits with status 1 when the mapping file cannot be read.
    """
    # Determined once; used for both the mapping file and the certificates.
    release = getRelease()
    mappingfile = "/usr/share/rhsm/product/" + release + "/channel-cert-mapping.txt"
    log.info("Using mapping file %s", mappingfile)

    try:
        dic_data = readChannelCertMapping(mappingfile)
    except IOError:
        log.error(traceback.format_exc())
        systemExit(1, _("Unable to read mapping file: %s") % mappingfile)

    (applicableCerts, validRhsmChannels, invalidRhsmChannels,
     unrecognizedChannels) = _classifyChannels(subscribedChannels, dic_data)

    if invalidRhsmChannels:
        print("\n+--------------------------------------------------+")
        print(_("Channels not available on RHSM:"))
        print("+--------------------------------------------------+")
        for i in invalidRhsmChannels:
            print(i)

    if unrecognizedChannels:
        # The translated message starts with its own newline, so it is
        # appended directly to the rule above it.
        print("\n+--------------------------------------------------+" +
              _("\nUnrecognized channels. Channel to product certificate mapping missing for these channels."))
        print("+--------------------------------------------------+")
        for i in unrecognizedChannels:
            print(i)

    if unrecognizedChannels or invalidRhsmChannels:
        if not options.force:
            print(_("\nUse --force to ignore these channels and continue the migration.\n"))
            sys.exit(1)

    log.info("certs to be copied: %s", applicableCerts)

    print(_("\nList of channels for which certs are being copied"))
    for i in validRhsmChannels:
        print(i)

    # creates the product directory if it doesn't already exist
    productDir = ProductDirectory()
    for cert in applicableCerts:
        sourcepath = "/usr/share/rhsm/product/" + release + "/" + cert
        truncated_cert_name = cert.split('-')[-1]
        destinationpath = os.path.join(str(productDir), truncated_cert_name)
        log.info("cp %s %s ", sourcepath, destinationpath)
        shutil.copy2(sourcepath, destinationpath)
    print(_("\nProduct certificates copied successfully to %s") % str(productDir))


def getRelease(release_file='/etc/redhat-release'):
    """Return "RHEL-<major version>" for the running system.

    release_file is the file to read the release string from; it defaults to
    /etc/redhat-release.  The major version is the run of digits following
    the word "release", e.g. "... release 5.8 (Tikanga)" yields "RHEL-5".

    Exits with status 1 when no version can be found in the file.
    """
    handle = open(release_file)
    try:
        contents = handle.read()
    finally:
        handle.close()

    # Anchor on the word "release" rather than on a fixed word position, so
    # the number of words in front of it does not matter.
    found = re.search(r'release\s+(\d+)', contents)
    if found is None:
        systemExit(1, _("Unable to determine the release of this system from %s.") % release_file)
    return "RHEL-" + found.group(1)


def enableExtraChannels(subscribedChannels, cp):
    """Enable the repositories matching the extra RHN Classic channels.

    If the system was subscribed to a supplementary, optional or productivity
    channel, redhat.repo is regenerated (using the connection ``cp``) and the
    repository sections whose names end in "supplementary", "optional-rpms"
    or "productivity-rpms" are enabled for the kinds that were in use.
    Nothing is touched when no such channel was subscribed.  Problems while
    enabling the sections are reported to the user but are not fatal.
    """
    extraChannels = {'supplementary': False, 'productivity': False, 'optional': False}
    for subscribedChannel in subscribedChannels:
        if 'supplementary' in subscribedChannel:
            extraChannels['supplementary'] = True
        elif 'optional' in subscribedChannel:
            extraChannels['optional'] = True
        elif 'productivity' in subscribedChannel:
            extraChannels['productivity'] = True

    if True not in extraChannels.values():
        return

    # create and populate the redhat.repo file
    repolib.RepoLib(uep=cp).update()

    # read in the redhat.repo file
    repofile = repolib.RepoFile()
    repofile.read()

    # enable any extra channels we are using and write out redhat.repo
    try:
        changed = False
        for rhsmChannel in repofile.sections():
            if ((extraChannels['supplementary'] and re.search('supplementary$', rhsmChannel)) or
                    (extraChannels['optional'] and re.search('optional-rpms$', rhsmChannel)) or
                    (extraChannels['productivity'] and re.search('productivity-rpms$', rhsmChannel))):
                log.info("Enabling extra channel '%s'" % rhsmChannel)
                repofile.set(rhsmChannel, 'enabled', '1')
                changed = True
        # One write after all sections are updated, instead of rewriting the
        # file once per enabled section.
        if changed:
            repofile.write()
    except Exception:
        # Not a bare "except:" so that an interrupt is not reported as a
        # repository problem.
        print(_("\nUnable to enable extra repositories."))
        print(_("Please ensure system is subscribed to entitlements, and alter /etc/yum.repos.d/redhat.repo to enable extra repositories."))


def writeMigrationFacts():
    """Record the migration in /etc/rhsm/facts/migration.facts.

    The file stores the RHN Classic system id, the migration source and the
    current local time in ISO format, as JSON.  An existing file is left
    untouched.
    """
    FACT_FILE = "/etc/rhsm/facts/migration.facts"
    if os.path.exists(FACT_FILE):
        return

    # Gather every value before the file is created: if reading the system id
    # fails, no empty facts file is left behind to block a later attempt.
    facts = {"migration.classic_system_id": getSystemId(),
             "migration.migrated_from": "rhn_hosted_classic",
             "migration.migration_date": datetime.now().isoformat()}

    fact_file = open(FACT_FILE, 'w')
    try:
        json.dump(facts, fact_file)
    finally:
        fact_file.close()


def cleanUp(subscribedChannels):
    """
    Hack to address BZ 786257.

    Removes 68.pem from the product certificate directory when the system was
    subscribed to both channels of one of the client/workstation pairs listed
    below.  A failure to remove the file is only logged.
    """
    ch = (
            ('rhel-x86_64-client-supplementary-5', 'rhel-x86_64-client-workstation-5'),
            ('rhel-x86_64-client-5', 'rhel-x86_64-client-workstation-5'),
            ('rhel-i386-client-supplementary-5', 'rhel-i386-client-workstation-5'),
            ('rhel-i386-client-5', 'rhel-i386-client-workstation-5')
         )

    productDir = ProductDirectory()
    for client_channel, workstation_channel in ch:
        if client_channel not in subscribedChannels or workstation_channel not in subscribedChannels:
            continue
        try:
            os.remove(os.path.join(str(productDir), "68.pem"))
            log.info("Removed 68.pem due to existence of both %s and %s" % (client_channel, workstation_channel))
        except OSError:
            # Log the exception instance itself (e.g. the file is already gone).
            log.info(sys.exc_info()[1])


def main():
    """Drive the migration from RHN Classic to Certificate-based RHN.

    Steps, in order: prompt for credentials, transfer the proxy settings,
    run the pre-flight checks, log in to RHN Classic and verify the org admin
    role, list the subscribed channels, deploy the matching product
    certificates, write the migration facts, unregister from RHN Classic,
    register/subscribe with subscription-manager (unless --no-auto), set the
    service level, enable extra repositories and apply the final clean-up.
    Most helpers exit the process themselves on failure.
    """
    username = raw_input(_("RHN Username: ")).strip()
    password = getpass.getpass()
    # The proxy settings must be in rhsm.conf before checkOkToProceed() reads
    # them back for its connection to the certificate server.
    transferHttpProxySettings()
    cp = checkOkToProceed(username, password)

    (sc, sk) = connectToRhn(username, password)
    checkIsOrgAdmin(sc, sk, username)

    # get a list of RHN classic channels this machine is subscribed to
    print _("\nRetrieving existing RHN Classic subscription information ...")
    subscribedChannels = getSubscribedChannelsList()
    print "+----------------------------------+"
    print _("System is currently subscribed to:")
    print "+----------------------------------+"
    for channel in subscribedChannels:
        print channel

    deployProdCertificates(subscribedChannels)

    # The facts are written before unregistering because they include the
    # system id read from the systemid file, which the unregistration removes.
    writeMigrationFacts()
    print _("\nPreparing to unregister system from RHN classic ...")
    unRegisterSystemFromRhnClassic(sc, sk)

    # register the system to Certificate-based RHN and consume a subscription
    # NOTE(review): with --no-auto this whole step is skipped, so the system
    # is not registered by this script after having been unregistered from
    # RHN Classic above -- confirm that this is the intended behaviour.
    # NOTE(review): the service level is selected and stored only after
    # runSubscriptionManager() has already attempted the subscription --
    # confirm the intended order.
    if not options.noauto:
        consumer = runSubscriptionManager(username, password)
        servicelevel = selectServiceLevel(cp, consumer, options.servicelevel)
        setServiceLevel(cp, consumer, servicelevel)
    # check if we need to enable to supplementary/optional channels
    enableExtraChannels(subscribedChannels, cp)
    cleanUp(subscribedChannels)


# Run the migration only when executed as a script.
if __name__ == '__main__':
    main()
