#!/usr/bin/python
#
# (c) Copyright 2015 Hewlett Packard Enterprise Development LP.
#
"""
------------.-- - -- --------- --- -- ----- ---------

--- ---- -- -- ------ --- ------- -- --- ---- ----.
"""

import io
import pycurl
import os
import json
import sys
import traceback
import optparse
import re
import time
import subprocess


#
class ExitCode:
    """Process exit codes this script terminates with."""

    em_flash_success = 0            # requested operation completed
    no_em_found = 101               # target did not answer as an EM
    user_terminated = 120           # interrupted from the keyboard
    incorrect_parameters = 140      # bad or unsupported command-line options
    login_failed = 160              # could not create a session
    em_not_active = 162             # target EM is not the active one
    em_update_timeout = 163         # update did not finish in time
    em_not_resp_after_reboot = 164  # RIS root unavailable after reboot
    em_internal_error = 165         # failure reported by / attributed to the EM
    internal_error = 166            # failure inside this script
    no_local_privileges = 169       # firmware package could not be opened
    update_in_progress = 170        # another update is already running
    invalid_package = 171           # EM rejected the firmware package
    em_update_interrupted = 172     # EM reported the update was interrupted
    standby_em_failed = 180         # standby EM update failed
    em_staged_fwver_conflict = 181  # staged firmware versions conflict

    @staticmethod
    def contains(val):
        """
        Tell whether `val` equals one of the integer codes defined above.
        Returns True when it does, False otherwise.
        """

        return any(type(member) is int and member == val
                   for member in vars(ExitCode).values())


# HTTP status codes that the response code of a RIS request is compared
# against (see ris_perform and its callers).
ris_rc_success = 200
ris_rc_created = 201
ris_rc_accepted = 202
ris_rc_bad_request = 400
ris_rc_unauthorized = 401
ris_rc_forbidden = 403
ris_rc_not_found = 404
ris_rc_resource_conflict = 409
ris_rc_server_error = 500

# Error identifiers searched for (as substrings) in the "ResultMessage" field
# or in the "MessageID" of an extended error response.
em_err_invalid_fw_pkg = 'InvalidFirmwarePackage'
em_err_update_interrupted = 'UpdateInterrupted'
em_err_standby_failed = 'StandbyUpdateFailed'
em_err_msgid_unsupported_options = 'em.UnsupportedDataInRequest'
em_err_staged_fwver_conflict = 'em.StagedFwVersionsConflict'

# Helper utility run by pkg_size() and open_pkg() when updating the peer EM.
EM_FWPKGUTIL = "/usr/bin/em-fwpkgutil"

OPT_STAGE_FW = "--stage-firmware"
OPT_ACTIVATE_STAGED_FW = "--activate-staged-firmware"
OPT_CLEAR_STAGED_FW = "--clear-staged-firmware"
OPT_SYNC_STAGED_FW = "--sync-staged-firmware"

# Values sent as the "Action" member of the JSON body posted to UpdateService.
ACTION_ACTIVATE_STAGED_FW = "ActivateStagedFirmware"
ACTION_CLEAR_STAGED_FW = "ClearStagedFirmware"
ACTION_SYNC_STAGED_FW = "SyncStagedFirmware"

# Maps a command-line option to the action posted for it (see post_update).
actions_dict = {
    # command-line option       action name
    OPT_ACTIVATE_STAGED_FW :    ACTION_ACTIVATE_STAGED_FW,
    OPT_CLEAR_STAGED_FW:        ACTION_CLEAR_STAGED_FW,
    OPT_SYNC_STAGED_FW:         ACTION_SYNC_STAGED_FW
}

#
class MyBytesIO(io.BytesIO):
    """In-memory byte buffer; adds nothing to io.BytesIO."""


class EmCurlFactory():
    """
    Builds pycurl handles for one EM's RIS (REST) interface and owns the
    login session whose token is attached to those handles.
    """

    # A login is attempted this many times, this many seconds apart.
    _LOGIN_ATTEMPTS = 24
    _LOGIN_RETRY_SECS = 5

    def __init__(self, ip_addr, use_http, user, password, updating_peer_em):
        """
        Fetch the RIS root from `ip_addr`, record the resource links it
        advertises, and log in.

        If the first request fails, it is retried once with the square
        brackets around the address toggled (added when missing, removed when
        present). HTTPS is always used when `updating_peer_em` is set, even
        if `use_http` was requested. Unless we are updating the peer EM, the
        address is afterwards replaced by the one the EM reports for itself
        (see _replace_floating_ip_addr).

        Exits with ExitCode.no_em_found when the RIS root cannot be read.
        """

        self.ip_addr = ip_addr
        self.user = user
        self.password = password
        self.logged_in = False
        self.use_peer_login = updating_peer_em

        self.use_https = (not use_http or updating_peer_em)

        print("Getting RIS root..")
        (rc, data) = ris_perform(self.make_curl_obj("/rest/v1"))
        if rc != ris_rc_success:
            if ip_addr.startswith("[") and ip_addr.endswith("]"):
                print("will retry without square brackets..")
                self.ip_addr = ip_addr[1:-1]
            else:
                print("will retry with square brackets..")
                self.ip_addr = "[" + ip_addr + "]"
            (rc, data) = ris_perform(self.make_curl_obj("/rest/v1"))
            if rc != ris_rc_success:
                print("ERROR: Target is not an EM or is not responsive")
                sys.exit(ExitCode.no_em_found)

        self.resources = {}
        try:
            hp_links = data["Oem"]["Hp"]["links"]
        except (KeyError, TypeError):
            # A 200 answer that is not JSON, or lacks the HP OEM links, did
            # not come from an EM.
            print("ERROR: Target is not an EM or is not responsive")
            sys.exit(ExitCode.no_em_found)
        for key in hp_links:
            # str() yields the native string type that URLs are built from.
            self.resources[key] = str(hp_links[key]["href"])
        # Standard links are recorded second, so they win on a name clash.
        if "links" in data:
            links = data["links"]
            for key in links:
                self.resources[key] = str(links[key]["href"])

        self.login()
        if not updating_peer_em:
            self._replace_floating_ip_addr()

    @staticmethod
    def _to_native_str(raw):
        """Return `raw` as a native str, decoding it when it is bytes."""
        if isinstance(raw, str):
            return raw
        return raw.decode("ascii", "replace")

    @staticmethod
    def _swap_address(host, new_addr):
        """
        Return `host` with its leading address replaced by `new_addr`.

        Only the first run of hex digits, colons and dots is replaced, so
        anything around it (such as square brackets) is preserved. The dot is
        part of the character class so that a dotted IPv4 address is replaced
        as a whole rather than only its first number.
        """
        return re.sub(r'[\da-fA-F:.]+', lambda _match: new_addr, host, count=1)

    def _replace_floating_ip_addr(self):
        """
        Switch to the IP address the EM reports in UpdateService
        ("EmIpAddress") when it differs from the one we were given.

        The current session is closed and a new one is created against the
        reported address. Exits with ExitCode.em_internal_error when the
        UpdateService resource cannot be read.
        """

        print("")
        print("Checking the EM static IP address..")

        (rc, data) = ris_perform(self.make_curl_obj("UpdateService"))
        if rc != ris_rc_success or not data:
            print("ERROR: Failed to get the UpdateService RIS resource.")
            sys.exit(ExitCode.em_internal_error)

        em_ip_addr = self._swap_address(self.ip_addr, str(data["EmIpAddress"]))

        if self.ip_addr.lower() != em_ip_addr.lower():
            print("")
            print("Will switch from floating to static IP address.")
            print("The IP address we were given: %s" % self.ip_addr)
            print("The IP address we will use instead: %s" % em_ip_addr)
            print("Logging out and creating a new session..")
            self.logout()
            self.ip_addr = em_ip_addr
            self.login()

    def login(self):
        """
        Create a session and remember its auth token and URI.

        Does nothing beyond clearing `logged_in` when plain HTTP is in use.
        For a peer EM the session is obtained from the external
        `create_peer_em_ris_session` command; otherwise the credentials are
        posted to the Sessions resource. Exits with ExitCode.login_failed
        when no session can be created.
        """

        self.logged_in = False
        if not self.use_https:
            return

        if self.use_peer_login:

            print("")
            print("Creating session with peer EM..")

            attempts = 0
            while True:
                try:
                    output = subprocess.check_output(["create_peer_em_ris_session", "1800"])
                    # The command is expected to print the session id and
                    # the token, separated by whitespace.
                    session_id, token = re.findall(r'\S+', self._to_native_str(output))
                    break
                except Exception:
                    # KeyboardInterrupt is deliberately not caught here, so
                    # the user can still abort during the retries.
                    attempts += 1
                    if attempts < self._LOGIN_ATTEMPTS:
                        print("Session creation attempt #%d failed... waiting for retry" % attempts)
                        time.sleep(self._LOGIN_RETRY_SECS)
                    else:
                        print("ERROR: unable to create session with the standby EM")
                        sys.exit(ExitCode.login_failed)

            self.auth_token = token
            self.session_uri = self.resources["Sessions"] + "/" + session_id

        else:

            print("")
            print("Creating session..")

            curl = self.make_curl_obj("Sessions", http_header=['Content-Type: application/json'])
            curl.setopt(curl.POST, 1)
            curl.setopt(curl.POSTFIELDS,
                        json.dumps({"UserName": self.user, "Password": self.password}))
            # Response headers are included in the body so that the token and
            # the session location can be parsed out of it below.
            curl.setopt(curl.HEADER, 1)

            rc = resp = None
            for _attempt in range(self._LOGIN_ATTEMPTS):
                (rc, resp) = ris_perform(curl, jsonify_resp=False)
                if rc == ris_rc_created:
                    break
                time.sleep(self._LOGIN_RETRY_SECS)

            if rc != ris_rc_created:
                print("ERROR: Failed to create session with the given credentials")
                sys.exit(ExitCode.login_failed)

            resp = self._to_native_str(resp)

            # HTTP header names are case-insensitive.
            m = re.search(r"X-Auth-Token: (\S+)", resp, re.IGNORECASE)
            if m is None:
                print("ERROR: Failed to parse auth token from login response")
                sys.exit(ExitCode.login_failed)
            self.auth_token = m.group(1)

            m = re.search(r"Location: (\S+)", resp, re.IGNORECASE)
            if m is None:
                print("ERROR: Failed to parse session location from login response")
                sys.exit(ExitCode.login_failed)
            self.session_uri = m.group(1)

        self.logged_in = True

    def logout(self):
        """
        Delete the current session, if any. Best effort: failures are
        reported on stdout but never raised (except a keyboard interrupt).
        """

        if not self.logged_in:
            return

        print("")
        print("Logging out..")

        try:
            curl = self.make_curl_obj(self.session_uri)
            curl.setopt(curl.CUSTOMREQUEST, 'DELETE')

            (rc, resp) = ris_perform(curl, jsonify_resp=False)
            if rc != ris_rc_success:
                print("ERROR: Failed to delete session")
            else:
                self.logged_in = False
        except (Exception, SystemExit):
            # SystemExit is included because make_curl_obj exits on an
            # unknown resource; a failed logout must not end the program.
            print("ERROR: Failed to delete session due to exception..")
            traceback.print_exc(file=sys.stdout)

    def __del__(self):
        """Log out when the object is destroyed."""
        # getattr: __init__ may not have run far enough to set the attribute.
        if getattr(self, "logged_in", False):
            self.logout()

    def make_curl_obj(self, resource, timeout=60, http_header=None):
        """
        Return a pycurl handle for `resource` on this EM.

        `resource` is either a path starting with "/" or the name of a link
        found in the RIS root. `timeout` is the transfer timeout in seconds.
        `http_header` is an optional list of header lines; it is copied, never
        modified. The auth token header is added when logged in.

        Exits with ExitCode.internal_error for an unknown resource name.
        """

        # Work on a copy so the caller's list does not grow on every call.
        headers = list(http_header) if http_header is not None else []

        if not resource.startswith("/"):
            if resource not in self.resources:
                print("ERROR: Unknown resource: %s" % resource)
                sys.exit(ExitCode.internal_error)
            resource = self.resources[resource]

        curl = pycurl.Curl()
        curl.setopt(curl.TIMEOUT, timeout)

        # Failing to select the TLS version is reported but not fatal.
        try:
            curl.setopt(curl.SSLVERSION, curl.SSLVERSION_TLSv1)
        except AttributeError as err:
            print("Error while setting curl.SSLVERSION_TLSv1: %s" % (err,))
        except pycurl.error as err:
            print("curl error: %s" % (err,))

        # NOTE(review): certificate and host name verification are disabled,
        # so the EM's identity is not authenticated by TLS.
        curl.setopt(curl.SSL_VERIFYPEER, False)
        curl.setopt(curl.SSL_VERIFYHOST, False)
        if self.logged_in:
            headers.append('X-Auth-Token: ' + self.auth_token)
        if headers:
            curl.setopt(curl.HTTPHEADER, headers)

        scheme = "https://" if self.use_https else "http://"
        curl.setopt(curl.URL, scheme + self.ip_addr + resource)

        return curl


def get_cmd_line_args(argv):
    """
    Parse and validate the command line; return the optparse options object.

    `argv` is the argument list to parse. Any validation failure prints the
    usage and exits with ExitCode.incorrect_parameters. When no password is
    given (and HTTPS is used) it is prompted for interactively.
    """

    # Same as OptionParser, but errors exit with our own exit code.
    class MyOptionParser(optparse.OptionParser):
        def error(self, msg):
            self.print_usage(sys.stderr)
            self.exit(ExitCode.incorrect_parameters, "%s: error: %s\n" % (self.get_prog_name(), msg))

    fw_and_staging_opt = "{ -f <fw_package> [" + OPT_STAGE_FW + "] }"
    other_json_opt = "{ " + OPT_ACTIVATE_STAGED_FW + " | " + OPT_CLEAR_STAGED_FW + " | " + OPT_SYNC_STAGED_FW + " }"
    usage = "usage: %prog -e <ip_addr> " + fw_and_staging_opt + " | " + other_json_opt + " -u <username> [-p <password>]"
    parser = MyOptionParser(usage=usage)
    parser.add_option("-f", "--file", dest="fw_pkg", help="EM firmware package filename", metavar="FILE")
    parser.add_option("-e", "--em", dest="em_ip_addr", help="Targeted EM IP address", metavar="IP_ADDR")
    parser.add_option("-u", "--user", dest="user", help="Username for EM", metavar="STRING")
    parser.add_option("-p", "--pass", dest="password", help="Password for EM", metavar="STRING")
    parser.add_option("-H", "--http", dest="use_http", help="Use unautheticated HTTP (developer builds only)", action="store_true")
    parser.add_option("-s", "--allow-standby", dest="allow_standby", help="Allow targeting the standby EM only", action="store_true")
    parser.add_option("--updating-peer-em", dest="updating_peer_em", help="I am an EM updating my peer. Don't bother creating session.", action="store_true")

    # Staged-firmware operations.
    parser.add_option("-S", OPT_STAGE_FW, dest="stage_fw", help="Only stage firmware into the EMs, do not update them.", action="store_true")
    parser.add_option("-A", OPT_ACTIVATE_STAGED_FW, dest="activate_fw", help="Only activate firmware that has been staged on the EMs.", action="store_true")
    parser.add_option("-C", OPT_CLEAR_STAGED_FW, dest="clear_fw", help="Clear firmware that has been staged on the EMs.", action="store_true")
    parser.add_option("-Y", OPT_SYNC_STAGED_FW, dest="sync_fw", help="Synchronize firmware from the Current partition onto the Staged partition of the EMs.", action="store_true")

    (opts, _positional) = parser.parse_args(argv)
    if not opts.em_ip_addr:
        parser.error('EM IP address not given')
    if opts.activate_fw or opts.clear_fw or opts.sync_fw:
        # These operations act on already staged firmware; no package allowed.
        if opts.fw_pkg:
            parser.error('Specifying a Firmware package not valid for the operation')
    elif not opts.fw_pkg:
        parser.error('Firmware package not given')
    elif not os.path.exists(opts.fw_pkg):
        parser.error("Firmware package doesn't exist")
    if not opts.use_http:
        if not opts.user:
            parser.error('Username not given')
        if not opts.password:
            import getpass
            opts.password = getpass.getpass()
    if opts.updating_peer_em:
        opts.allow_standby = True

    # Pairs of flags that must not be given together. Each pair is stored as
    # a frozenset so the lookup does not depend on the order in which the
    # two flags happen to be visited.
    incompatible_pairs = set(frozenset(pair) for pair in (
        ('stage_fw', 'activate_fw'),
        ('stage_fw', 'clear_fw'),
        ('stage_fw', 'sync_fw'),
        ('activate_fw', 'clear_fw'),
        ('activate_fw', 'sync_fw'),
        ('clear_fw', 'sync_fw'),
    ))

    # Names of all boolean flags that were switched on.
    enabled = sorted(name for name, value in vars(opts).items() if value is True)

    for index, first in enumerate(enabled):
        for second in enabled[index + 1:]:
            if frozenset((first, second)) in incompatible_pairs:
                parser.error('The specified arguments are incompatible. See usage.')
    return opts


def verify_em_active(curl_factory):
    """
    Check that the target is the active EM.

    Reads the UpdateService resource through `curl_factory` and inspects its
    "EmIsActive" member. Exits with ExitCode.internal_error when the resource
    cannot be read and with ExitCode.em_not_active when the EM is not active.
    """

    print("")
    print("Verifying EM role..")

    curl = curl_factory.make_curl_obj("UpdateService")
    (rc, data) = ris_perform(curl)
    # `data` is None when the body could not be parsed as JSON.
    if rc != ris_rc_success or not data:
        print("ERROR: Failed to get UpdateService RIS resource")
        sys.exit(ExitCode.internal_error)
    if not data["EmIsActive"]:
        # The override is the -s/--allow-standby command-line option.
        print("ERROR: Target EM is not active. Use '--allow-standby' to override.")
        sys.exit(ExitCode.em_not_active)


def pkg_size(fw_pkg, updating_peer_em):
    """
    Return the size of `fw_pkg` as an int.

    For a peer-EM update the size is whatever the em-fwpkgutil "rawcatsize"
    command prints; otherwise it is the size of the file on disk.
    """
    if updating_peer_em:
        reported = subprocess.check_output([EM_FWPKGUTIL, "rawcatsize", fw_pkg])
        return int(reported)
    return os.path.getsize(fw_pkg)


def open_pkg(fw_pkg, updating_peer_em):
    """
    Return a readable binary stream for `fw_pkg`.

    For a peer-EM update this is the stdout pipe of an em-fwpkgutil "rawcat"
    process; otherwise it is the file itself, opened in binary mode. Exits
    with ExitCode.no_local_privileges when the file cannot be opened.
    """
    if not updating_peer_em:
        try:
            pkg_stream = open(fw_pkg, 'rb')
        except IOError:
            print("ERROR: Unable to open", fw_pkg)
            exit(ExitCode.no_local_privileges)
        else:
            return pkg_stream
    else:
        rawcat = subprocess.Popen([EM_FWPKGUTIL, "rawcat", fw_pkg, "/dev/stdout"],
                                  stdout=subprocess.PIPE)
        return rawcat.stdout


def check_print_extended_err_response(data, msgid):
    """
    Look for `msgid` in an "Extended Error" response.

    `data` is the decoded response body (or None). When it is an extended
    error and one of its messages has a "MessageID" containing `msgid`, the
    first such message's text and resolution are printed and True is
    returned. In every other case False is returned.
    """
    is_extended_error = (data and "Name" in data and data["Name"] == "Extended Error")
    if not is_extended_error:
        return False

    # All messages are examined; only the first match is reported.
    matching = [entry for entry in data["Messages"] if msgid in entry["MessageID"]]
    if not matching:
        return False

    first = matching[0]
    print("ERROR: %s %s" %(first["Message"], first["Resolution"]))
    return True


def post_update(curl_factory, fw_pkg, updating_peer_em, op = None):
    """
    Post the update request to the EM's UpdateService resource.

    With `op` None or OPT_STAGE_FW the firmware package `fw_pkg` is uploaded
    as an octet stream (with an "X-Stage-Only" header when staging). With any
    other `op` a JSON body naming the matching action is posted instead.

    Returns once the EM answers "accepted". Conflict and bad-request answers
    exit immediately with the matching exit code; other failures are retried
    (three attempts in total) before exiting.
    """

    octet_blob = (op is None or op == OPT_STAGE_FW)

    attempts = 1
    while True:

        print("")
        print("Posting update request..")

        fin = None
        try:
            if not octet_blob:
                # Activation gets a longer timeout than the other actions.
                timeout = 300 if op == OPT_ACTIVATE_STAGED_FW else 180
                curl = curl_factory.make_curl_obj("UpdateService", timeout=timeout, http_header=['Content-Type: application/json'])
                curl.setopt(curl.POST, 1)
                data = json.dumps({"Action":actions_dict[op]})
                curl.setopt(curl.POSTFIELDS, data)
                print(data)
            else:
                if op == OPT_STAGE_FW:
                    curl = curl_factory.make_curl_obj("UpdateService", timeout=180, http_header=['X-Stage-Only:1', 'Content-Type: application/octet-stream'])
                else:
                    curl = curl_factory.make_curl_obj("UpdateService", timeout=300, http_header=['Content-Type: application/octet-stream'])
                curl.setopt(curl.POST, 1)
                curl.setopt(curl.POSTFIELDSIZE, pkg_size(fw_pkg, updating_peer_em))
                fin = open_pkg(fw_pkg, updating_peer_em)
                curl.setopt(curl.READFUNCTION, fin.read)

            (rc, data) = ris_perform(curl)
        finally:
            # A fresh stream is opened on every attempt; release this one.
            if fin is not None:
                fin.close()

        if (rc == ris_rc_accepted):
            return
        elif (rc == ris_rc_resource_conflict):
            if check_print_extended_err_response(data, em_err_staged_fwver_conflict):
                sys.exit(ExitCode.em_staged_fwver_conflict)
            elif check_print_extended_err_response(data, em_err_standby_failed):
                sys.exit(ExitCode.standby_em_failed)
            else:
                print("ERROR: Another update is already in progress")
                sys.exit(ExitCode.update_in_progress)
        elif (rc == ris_rc_bad_request):
            if (data and "ResultMessage" in data and em_err_invalid_fw_pkg in data["ResultMessage"]):
                print("ERROR: Invalid firmware package")
                sys.exit(ExitCode.invalid_package)
            elif check_print_extended_err_response(data, em_err_msgid_unsupported_options):
                sys.exit(ExitCode.incorrect_parameters)
            else:
                print("ERROR: Update request failed")
                sys.exit(ExitCode.em_internal_error)
        elif (rc == ris_rc_server_error):
            print("ERROR: Update request failed")
            exit_code = ExitCode.em_internal_error
        else:
            # Includes rc None, i.e. the request itself failed.
            print("ERROR: Unexpected response to update request")
            exit_code = ExitCode.internal_error

        # Only server errors and unexpected answers reach this point.
        if attempts < 3:
            attempts += 1
            print("")
            print("Waiting to retry...")
            time.sleep(5)
        else:
            sys.exit(exit_code)

def ris_perform(curl, jsonify_resp=True):
    """
    Perform the request prepared in `curl` and collect the response.

    Returns a tuple (response code, response data):
      * (None, None) when the transfer itself fails;
      * (code, None) for an "accepted" response, whose body is ignored;
      * (code, text) with the raw body when `jsonify_resp` is false;
      * (code, decoded JSON) otherwise, with None when the body is not
        valid JSON.

    The response code and body are echoed on stdout.
    """

    fout = MyBytesIO()
    curl.setopt(curl.WRITEFUNCTION, fout.write)

    try:
        curl.perform()
    except pycurl.error as err:
        print("curl error: %s" % (err,))
        return (None, None)

    rc = curl.getinfo(curl.RESPONSE_CODE)
    print("response code: %s" % (rc,))
    if (rc == ris_rc_accepted):
        return (rc, None)

    resp_buf = fout.getvalue()
    # Hand callers the native string type: they search it with str patterns.
    if not isinstance(resp_buf, str):
        resp_buf = resp_buf.decode("utf-8", "replace")

    if not jsonify_resp:
        # NOTE(review): for a login request this echoes the response headers,
        # including the auth token, to stdout.
        print(resp_buf)
        return (rc, resp_buf)

    json_data = None
    try:
        json_data = json.loads(resp_buf)
        print("data: %s" % (json_data,))
    except ValueError as error:
        print(resp_buf)
        print(error)

    return (rc, json_data)


def poll_em(curl_factory, url, max_wait_time, sleep_time, break_condition):
    """
    Request `url` every `sleep_time` seconds until `break_condition(rc, data)`
    is true or `max_wait_time` seconds have been used up.

    Each round sleeps first and queries afterwards. Returns
    (True, response code, response data) as soon as the condition holds, or
    (False, None, None) when the time budget runs out.
    """

    budget = max_wait_time
    while budget > 0:
        time.sleep(sleep_time)
        budget -= sleep_time
        rc, data = ris_perform(curl_factory.make_curl_obj(url))
        if break_condition(rc, data):
            return (True, rc, data)
    return (False, None, None)


def wait_for_update_complete(curl_factory, op = None):
    """
    Poll UpdateService until the update finishes, then check its outcome.

    Polling stops when the status reports a "Successful" value or when the
    EM answers "unauthorized". Exits with the matching exit code on timeout,
    interruption, standby failure or general failure. When `op` is None or
    OPT_ACTIVATE_STAGED_FW, the EM is afterwards expected to come back from a
    reboot and the final status is verified again.
    """

    print("")
    print("Waiting for update to complete..")

    def break_condition(rc, data):
        # Test the response code first: the body of an "unauthorized" answer
        # is not an update status and need not have a "Successful" member.
        if rc == ris_rc_unauthorized:
            return True
        return isinstance(data, dict) and data.get("Successful") is not None

    (completed, rc, data) = poll_em(curl_factory, "UpdateService", 900, 15, break_condition)
    if not completed:
        print("ERROR: Timed out waiting for update to complete.")
        sys.exit(ExitCode.em_update_timeout)

    if rc == ris_rc_unauthorized:
        # Whatever came with the 401 is an error payload, not a status.
        data = None

    if data:
        result_message = data.get("ResultMessage") or ""
        if em_err_update_interrupted in result_message:
            print("ERROR: Update interrupted")
            sys.exit(ExitCode.em_update_interrupted)
        elif em_err_standby_failed in result_message:
            print("ERROR: Failed to update the standby EM, update aborted.")
            sys.exit(ExitCode.standby_em_failed)
        elif not data["Successful"]:
            print("ERROR: Update failed")
            sys.exit(ExitCode.em_internal_error)

    if (op is None or op == OPT_ACTIVATE_STAGED_FW):
        # Only wait for the shutdown when the EM says a reboot is pending;
        # after a 401 there is no status to consult.
        if data and data.get("RebootPending"):
            wait_for_em_shutdown(curl_factory)

        verify_update_status_after_reboot(curl_factory)


def wait_for_em_shutdown(curl_factory):
    """
    Poll UpdateService until a request gets no response code at all, which
    is taken to mean the EM has gone down. Exits with
    ExitCode.em_internal_error if that does not happen within the polling
    window.
    """

    print("")
    print("Waiting for EM to shutdown for reboot..")

    # ris_perform reports a failed transfer as a response code of None.
    no_answer = lambda rc, data: rc is None
    went_down = poll_em(curl_factory, "UpdateService", 600, 5, no_answer)[0]
    if went_down:
        return

    print("ERROR: Did not see EM go down for reboot")
    exit(ExitCode.em_internal_error)


def verify_update_status_after_reboot(curl_factory):
    """
    Run the post-reboot checks in order: wait for the RIS root, log in
    again, then read the final update status.
    """
    post_reboot_steps = (
        wait_for_ris_root,
        wait_for_login,
        wait_for_final_update_status,
    )
    for step in post_reboot_steps:
        step(curl_factory)


def wait_for_ris_root(curl_factory):
    """
    Poll "/rest/v1" until it answers with success. Exits with
    ExitCode.em_not_resp_after_reboot if it never does within the polling
    window.
    """
    print("")
    print("Waiting for RIS root to be available..")

    # Marking the factory as logged out keeps the auth token header off the
    # requests made below.
    curl_factory.logged_in = False

    root_is_up = lambda rc, data: rc == ris_rc_success
    outcome = poll_em(curl_factory, "/rest/v1", 600, 5, root_is_up)
    if not outcome[0]:
        print("ERROR: EM or RIS not available after reboot")
        exit(ExitCode.em_not_resp_after_reboot)


def wait_for_login(curl_factory):
    """
    Call `curl_factory.login()` until it succeeds, retrying every 5 seconds.

    Exits with ExitCode.em_internal_error when a login still fails after the
    150-second retry budget is used up. A keyboard interrupt is not treated
    as a failed login and propagates to the caller.
    """
    time_remaining = 150
    while True:
        try:
            curl_factory.login()
            return
        except (Exception, SystemExit):
            # login() reports failure by exiting, hence SystemExit. A bare
            # except would also swallow KeyboardInterrupt and keep retrying.
            if time_remaining <= 0:
                print("ERROR: Unable to login after reboot")
                sys.exit(ExitCode.em_internal_error)
            print("Unable to login yet, will retry...")
        time.sleep(5)
        time_remaining -= 5


def wait_for_final_update_status(curl_factory):
    """
    Read UpdateService once it answers at all and check the update outcome.

    Exits with ExitCode.em_internal_error when no usable status is obtained
    or the update was not successful, and with ExitCode.em_update_interrupted
    when the result message reports an interruption.
    """
    print("")
    print("Check final update status..")

    # Any response code, whatever its value, ends the polling.
    got_answer = lambda rc, data: rc is not None
    (completed, rc, data) = poll_em(curl_factory, "UpdateService", 120, 5, got_answer)

    if data is None or not completed:
        print("ERROR: Failed to get update status after EM reboot")
        exit(ExitCode.em_internal_error)

    was_interrupted = em_err_update_interrupted in data["ResultMessage"]
    if was_interrupted:
        print("ERROR: Update interrupted")
        exit(ExitCode.em_update_interrupted)
    elif not data["Successful"]:
        print("ERROR: Update failed")
        exit(ExitCode.em_internal_error)


def main():
    """
    Script entry point: parse the command line, connect to the EM, post the
    requested operation and wait for it to finish.

    Always terminates through SystemExit. An exit code that is not one of
    the ExitCode values is replaced by ExitCode.internal_error.
    """
    try:
        opts = get_cmd_line_args(sys.argv)

        # The first flag that is set, in this order, selects the operation.
        op = None
        for flag, option in (("stage_fw", OPT_STAGE_FW),
                             ("activate_fw", OPT_ACTIVATE_STAGED_FW),
                             ("clear_fw", OPT_CLEAR_STAGED_FW),
                             ("sync_fw", OPT_SYNC_STAGED_FW)):
            if getattr(opts, flag):
                op = option
                break

        curl_factory = EmCurlFactory(opts.em_ip_addr, opts.use_http, opts.user, opts.password, opts.updating_peer_em)
        if not opts.allow_standby:
            verify_em_active(curl_factory)
        post_update(curl_factory, opts.fw_pkg, opts.updating_peer_em, op)
        wait_for_update_complete(curl_factory, op)
        print("")
        print("Success!")
        exit(ExitCode.em_flash_success)
    except KeyboardInterrupt:
        print("Terminated by user")
        exit(ExitCode.user_terminated)
    except SystemExit as e:
        known_code = ExitCode.contains(e.code)
        if not known_code:
            print("ERROR: Unknown exit code attempted:", e.code)
            e.code = ExitCode.internal_error
        raise e
    except:
        # Last-resort handler: report anything else as an internal error.
        print("ERROR: Unhandled exception")
        traceback.print_exc(file=sys.stdout)
        exit(ExitCode.internal_error)


# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
