#!/usr/bin/python
#
# Copyright 2017 Cumulus Networks, Inc. All rights reserved.
# Authors:
#           Roopa Prabhu, roopa@cumulusnetworks.com
#           Julien Fortin, julien@cumulusnetworks.com
#
# ifupdown2 --
#    tool to configure network interfaces
#

import os
import re
import json
import fcntl
import struct
import signal
import socket
import daemon
import select
import datetime
import threading

try:
    import ifupdown2.ifupdown.argv

    from ifupdown2.ifupdown.log import log
    from ifupdown2.ifupdown.main import Ifupdown2
except ImportError:
    import ifupdown.argv

    from ifupdown.log import log
    from ifupdown.main import Ifupdown2


class Daemon:
    shutdown_event = threading.Event()

    def __init__(self):
        """
            Prepare the runtime directory, daemonize the process and open
            the listening UNIX domain socket.

            Steps, in order:
                - create the working directory if missing and remove any
                  stale socket file left at the server address
                - build the DaemonContext (signal map, umask) and open it
                - preload the modules needed to serve requests
                - create, bind and listen on the UNIX socket

            Raises Exception (prefixed with the failing step: 'socket',
            'bind', 'setsockopt' or 'listen') when the socket setup fails,
            and ImportError when a module cannot be preloaded.
        """
        self.uds = None
        self.context = None
        self.working_directory = '/var/run/ifupdown2d/'
        self.server_address = '/var/run/ifupdown2d/uds'

        if not os.path.exists(self.working_directory):
            log.info('creating %s' % self.working_directory)
            os.makedirs(self.working_directory, mode=0o755)

        if os.path.exists(self.server_address):
            log.info('removing uds %s' % self.server_address)
            os.remove(self.server_address)

        self.context = daemon.DaemonContext(
            working_directory=self.working_directory,
            signal_map={
                signal.SIGINT: self.signal_handler,
                signal.SIGTERM: self.signal_handler,
                signal.SIGQUIT: self.signal_handler,
            },
            umask=0o22
        )

        # Not every python build exposes SO_PASSCRED/SO_PEERCRED in the
        # socket module, so fall back on the raw option numbers.
        # powerpc is the only non-generic we care about. alpha, mips,
        # sparc, and parisc also have non-generic values.
        if re.search(r'^(ppc|powerpc)', os.uname()[4]):
            default_passcred, default_peercred = 20, 21
        else:
            default_passcred, default_peercred = 16, 17

        # Both options are resolved independently: SO_PASSCRED is needed by
        # setsockopt() below even when socket.SO_PEERCRED is available.
        self.SO_PASSCRED = getattr(socket, 'SO_PASSCRED', default_passcred)
        self.SO_PEERCRED = getattr(socket, 'SO_PEERCRED', default_peercred)

        log.info('daemonizing ifupdown2d...')
        self.context.open()

        log.info('preloading all necessary modules')
        self.preload_imports()

        # NOTE(review): the socket is created only after context.open(),
        # presumably because daemonizing closes inherited file descriptors
        # - confirm against the python-daemon documentation.
        try:
            log.info('opening UNIX socket')
            self.uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            fcntl.fcntl(self.uds.fileno(), fcntl.F_SETFD, fcntl.FD_CLOEXEC)
        except Exception as e:
            raise Exception('socket: %s' % str(e))
        try:
            self.uds.bind(self.server_address)
        except Exception as e:
            raise Exception('bind: %s' % str(e))
        try:
            self.uds.setsockopt(socket.SOL_SOCKET, self.SO_PASSCRED, 1)
        except Exception as e:
            raise Exception('setsockopt: %s' % str(e))
        try:
            self.uds.listen(1)
        except Exception as e:
            raise Exception('listen: %s' % str(e))

        # the socket file is made accessible to every local user; the uid
        # of each client is read from the socket in get_client_uid()
        os.chmod(self.server_address, 0o777)

    def __del__(self):
        """
            close the daemon context first, then the listening socket
            (each one only if it has been set)
        """
        daemon_context = self.context
        if daemon_context:
            daemon_context.close()

        listening_socket = self.uds
        if listening_socket:
            listening_socket.close()

    @staticmethod
    def preload_imports():
        """
            preloading all the necessary modules
            at first will increase performances

            The imports are done once here (called from __init__, before
            run() starts forking children) so the modules are already
            loaded when a request is served.

            Raises ImportError, with ' - required module not found'
            appended to the original message, if any module is missing.
        """
        try:
            import io
            import pdb
            import imp
            import sets
            import json
            import glob
            import time
            import copy
            import errno
            import pprint
            import atexit
            import ipaddr
            import cPickle
            import logging
            import argparse
            import StringIO
            import datetime
            import traceback
            import itertools
            import subprocess
            import argcomplete
            import collections
            import ConfigParser
            import pkg_resources

            import ifupdown2.ifupdown.exceptions
            import ifupdown2.ifupdown.graph
            import ifupdown2.ifupdown.iface
            import ifupdown2.ifupdown.iff
            import ifupdown2.ifupdown.ifupdownbase
            import ifupdown2.ifupdown.ifupdownconfig
            import ifupdown2.ifupdown.ifupdownflags
            import ifupdown2.ifupdown.ifupdownmain
            import ifupdown2.ifupdown.netlink
            import ifupdown2.ifupdown.networkinterfaces
            import ifupdown2.ifupdown.policymanager
            import ifupdown2.ifupdown.scheduler
            import ifupdown2.ifupdown.statemanager
            import ifupdown2.ifupdown.template
            import ifupdown2.ifupdown.utils

            import ifupdown2.ifupdownaddons.cache
            import ifupdown2.ifupdownaddons.dhclient
            import ifupdown2.ifupdownaddons.mstpctlutil
            import ifupdown2.ifupdownaddons.LinkUtils
            import ifupdown2.ifupdownaddons.modulebase
            import ifupdown2.ifupdownaddons.systemutils
            import ifupdown2.ifupdownaddons.utilsbase
        except ImportError as e:
            raise ImportError('%s - required module not found' % str(e))

    @staticmethod
    def signal_handler(sig, frame):
        """
            Log the received signal and request the main loop to stop.

            sig:   number of the signal being handled
            frame: current stack frame (unused)

            Only sets Daemon.shutdown_event; run() checks that event at the
            top of its loop.
        """
        signal_names = {
            signal.SIGINT: 'SIGINT',
            signal.SIGTERM: 'SIGTERM',
            signal.SIGQUIT: 'SIGQUIT',
        }
        # the name is resolved before formatting: '%' binds tighter than
        # the conditional expression, which used to drop the 'received'
        # prefix for anything but SIGINT
        log.info('received %s' % signal_names.get(sig, 'signal %d' % sig))
        Daemon.shutdown_event.set()

    @staticmethod
    def user_waiting_for_reply():
        """
            return True when the logger is not in syslog mode
            (i.e. log.is_syslog() is false), False otherwise
        """
        logging_to_syslog = log.is_syslog()
        return not logging_to_syslog

    def run(self):
        """
            Accept clients on the UNIX socket and serve them one at a time.

            For every accepted connection a child process is forked to
            handle the request (see ifupdown2()). The parent sends the
            child pid to the client, waits for the child to terminate and
            then sends its exit status:
                - the child's exit code when it exited normally
                - 128 + signal number when it was killed by a signal
                  (shell convention), so a crash is never reported as 0

            The loop stops when the shutdown event is set, when accept()
            fails or on any unexpected exception (which is logged). The
            listening socket is closed before returning.
        """
        try:
            while True:
                if Daemon.shutdown_event.is_set():
                    log.info("shutdown signal RXed, breaking out loop")
                    break

                try:
                    (client_socket, _) = self.uds.accept()
                except socket.error as e:
                    log.error(str(e))
                    break

                pid = os.fork()
                if pid == 0:
                    # child: serve the request and exit with its status
                    exit(self.ifupdown2(client_socket))

                # parent: always release our copy of the client socket,
                # even if talking to the client or waiting fails
                try:
                    log.tx_data(json.dumps({'pid': pid}), socket=client_socket)

                    start = datetime.datetime.now()
                    wait_status = os.waitpid(pid, 0)[1]
                    end = datetime.datetime.now()

                    # WEXITSTATUS is only meaningful after a normal exit
                    if os.WIFEXITED(wait_status):
                        status = os.WEXITSTATUS(wait_status)
                    elif os.WIFSIGNALED(wait_status):
                        status = 128 + os.WTERMSIG(wait_status)
                    else:
                        status = 1

                    log.tx_data(json.dumps({'status': status}), socket=client_socket)
                finally:
                    client_socket.close()

                log.info('exit status %d - in %ssecs'
                         % (status, (end - start).total_seconds()))

        except Exception as e:
            log.error(e)
        self.uds.close()

    def get_client_uid(self, client_socket):
        """
            Return the uid of the process connected to client_socket.

            The credentials are read with the SO_PEERCRED socket option
            and decoded as three native 32-bit values (pid, uid, gid).
            uid and gid are unpacked as unsigned, so that ids above 2^31
            are not returned as negative numbers.
        """
        ucred_format = 'iII'
        creds = client_socket.getsockopt(socket.SOL_SOCKET, self.SO_PEERCRED,
                                         struct.calcsize(ucred_format))
        (_pid, uid, _gid) = struct.unpack(ucred_format, creds)
        log.debug('client uid %d' % uid)
        return uid

    @staticmethod
    def get_client_request(client_socket):
        """
            This function handles requests of any length.

                a request larger than one recv buffer (65536 bytes) arrives
                in several pieces: every piece is stored and the whole
                buffer is handed to the json library until it decodes.

            Returns the decoded json request.

            Raises EOFError if the client closes the connection before a
            complete request has been received.
        """
        chunks = []
        while True:
            log.debug('waiting for request on client socket')
            ready = select.select([client_socket], [], [])

            if not (ready and ready[0] and ready[0][0] == client_socket):
                continue

            # data available start reading
            raw_data = client_socket.recv(65536)

            if not raw_data:
                # an empty read means the peer closed the connection: the
                # socket would stay readable forever, so stop here instead
                # of spinning on select/recv
                raise EOFError('client closed the connection before '
                               'sending a complete request')

            chunks.append(raw_data)

            try:
                # always decode the accumulated data, never a lone chunk
                return json.loads(''.join(chunks))
            except ValueError:
                # the json is incomplete, wait for more data
                continue

    def ifupdown2(self, client_socket):
        """
            Serve one client request (runs in the forked child).

            Reads the json request from client_socket, parses its 'argv',
            redirects the logger to the client socket and runs the
            ifupdown2 main with the request's 'stdin'.

            Returns the status of the run: the value returned by main(),
            or 1 when argv parsing fails or any other exception is raised.
            The logger is flushed and detached from the socket, and the
            socket is closed, before returning.
        """
        # Resolved locally, with the same fallback as the file-level
        # imports. The instance below must not be named like the package,
        # otherwise the 'except' clause would look the exception class up
        # on the instance (or on an unbound local) and raise on its own.
        try:
            from ifupdown2.ifupdown.argv import ArgvParseError
        except ImportError:
            from ifupdown.argv import ArgvParseError

        try:
            fcntl.fcntl(client_socket.fileno(), fcntl.F_SETFD, fcntl.FD_CLOEXEC)

            handler = Ifupdown2(daemon=True, uid=self.get_client_uid(client_socket))
            handler.set_signal_handlers()

            request = self.get_client_request(client_socket)
            log.info('request: %s' % request['argv'])

            handler.parse_argv(request['argv'])
            # adjust the logger with argv
            handler.update_logger(socket=client_socket)

            try:
                status = handler.main(request['stdin'])
            except Exception as e:
                log.error(str(e))
                status = 1

        except ArgvParseError as e:
            # report the parsing error to the client rather than to syslog
            log.update_current_logger(syslog=False, verbose=True, debug=False)
            log.set_socket(client_socket)
            e.log_error()
            status = 1
        except Exception as e:
            log.error(e)
            status = 1

        log.flush()
        log.set_socket(None)
        client_socket.close()
        return status


if __name__ == '__main__':
    # Script entry point: build the daemon (which daemonizes the process)
    # and serve requests until shutdown. Any failure is printed, logged
    # with its traceback, and turned into exit status 1.
    try:
        Daemon().run()
    except Exception as e:
        import traceback

        # function-call form: same output as the print statement for a
        # single argument, and also valid on newer interpreters
        print(e)
        log.error(str(e))
        log.error(traceback.format_exc())
        exit(1)
