#!/usr/bin/python3
# SPDX-License-Identifier: AGPL-3.0-or-later

import argparse
import importlib
import inspect
import json
import logging
import os
import sys
import traceback
import types
import typing

import plinth.log
from plinth import cfg, module_loader

# Exit status used by main() when a SyntaxError or TypeError is raised.
EXIT_SYNTAX = 10
# Exit status used by main() when a PermissionError is raised.
EXIT_PERM = 20

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def main():
    """Parse the command line, run the requested action and emit its result.

    The module and action names are taken from the command line. Unless
    --no-args is given, stdin is read and, when not empty, parsed as a JSON
    document that replaces the default {'args': [], 'kwargs': {}}. The value
    returned by _call() is serialized as JSON and written to the file
    descriptor given by --write-fd (1 by default).

    The process exits with EXIT_PERM on PermissionError, EXIT_SYNTAX on
    SyntaxError or TypeError, and 1 on any other exception.
    """
    plinth.log.action_init()

    parser = argparse.ArgumentParser()
    parser.add_argument('module', help='Module to trigger action in')
    parser.add_argument('action', help='Action to trigger in module')
    parser.add_argument('--write-fd', type=int, default=1,
                        help='File descriptor to write output to')
    parser.add_argument('--no-args', default=False, action='store_true',
                        help='Do not read arguments from stdin')
    args = parser.parse_args()

    try:
        # Used as-is when stdin is skipped or empty.
        arguments = {'args': [], 'kwargs': {}}
        if not args.no_args:
            input_ = sys.stdin.read()
            if input_:
                try:
                    arguments = json.loads(input_)
                except json.JSONDecodeError as exception:
                    raise SyntaxError(
                        'Arguments on stdin not JSON.') from exception

        return_value = _call(args.module, args.action, arguments)
        with os.fdopen(args.write_fd, 'w') as write_file_handle:
            write_file_handle.write(json.dumps(return_value))
    except PermissionError as exception:
        # Log the exception itself rather than args[0]: for errors raised by
        # the OS args[0] is only the errno number, and args may be empty.
        logger.error('%s', exception)
        sys.exit(EXIT_PERM)
    except (SyntaxError, TypeError) as exception:
        logger.error('%s', exception)
        sys.exit(EXIT_SYNTAX)
    except Exception as exception:
        logger.exception(exception)
        sys.exit(1)


def _call(module_name, action_name, arguments):
    """Locate a privileged action, validate its arguments and invoke it.

    The action is looked up by name in the '<import path>.privileged' module,
    where the import path is 'plinth' for the module name 'plinth' and is
    otherwise obtained from module_loader. The function stored in the
    action's __wrapped__ attribute is the one that gets called.

    Returns a dict with 'result' set to 'success' and the return value under
    'return', or with 'result' set to 'exception' and a description of the
    exception raised by the action under 'exception'.

    Raises SyntaxError when the module name contains a dot, the module or
    action cannot be found, or the action is not marked as privileged.
    """
    if '.' in module_name:
        raise SyntaxError('Invalid module name')

    cfg.read()
    import_path = 'plinth'
    if module_name != 'plinth':
        import_path = module_loader.get_module_import_path(module_name)

    try:
        module = importlib.import_module(import_path + '.privileged')
    except ModuleNotFoundError as exception:
        raise SyntaxError('Specified module not found') from exception

    try:
        action = getattr(module, action_name)
    except AttributeError as exception:
        raise SyntaxError('Specified action not found') from exception

    if not getattr(action, '_privileged', None):
        raise SyntaxError('Specified action is not privileged action')

    func = action.__wrapped__
    _assert_valid_arguments(func, arguments)

    try:
        outcome = func(*arguments['args'], **arguments['kwargs'])
    except Exception as exception:
        # Failures of the action itself are reported in the returned value
        # instead of being propagated to the caller.
        logger.exception('Error executing action: %s', exception)
        exception_type = type(exception)
        return {
            'result': 'exception',
            'exception': {
                'module': exception_type.__module__,
                'name': exception_type.__name__,
                'args': exception.args,
                'traceback': traceback.format_tb(exception.__traceback__)
            }
        }

    return {'result': 'success', 'return': outcome}


def _assert_valid_arguments(func, arguments):
    """Verify that arguments can be used to call func.

    arguments must be a dict holding a list under 'args' and a dict under
    'kwargs'. Together they must not exceed the number of parameters of func,
    must provide every parameter that has no default, must not name unknown
    parameters and must not provide a parameter both by position and by
    keyword. Each value is then checked against the annotation of its
    parameter.

    Raises SyntaxError for structural problems; type mismatches are reported
    by _assert_valid_type().
    """
    well_formed = (isinstance(arguments, dict) and 'args' in arguments
                   and 'kwargs' in arguments
                   and isinstance(arguments['args'], list)
                   and isinstance(arguments['kwargs'], dict))
    if not well_formed:
        raise SyntaxError('Invalid arguments format')

    positional = arguments['args']
    keyword = arguments['kwargs']

    spec = inspect.getfullargspec(func)
    names = spec.args
    if len(positional) + len(keyword) > len(names):
        raise SyntaxError('Too many arguments')

    # Parameters without a default that are not covered by a positional
    # argument have to be passed by keyword.
    required_count = len(names) - len(spec.defaults or ())
    missing = next((name for name in names[len(positional):required_count]
                    if name not in keyword), None)
    if missing is not None:
        raise SyntaxError(f'Argument not provided: {missing}')

    # Parameters already bound by the positional arguments.
    bound_by_position = names[:len(positional)]
    for name, value in keyword.items():
        if name not in names:
            raise SyntaxError(f'Unknown argument: {name}')

        if name in bound_by_position:
            raise SyntaxError(f'Duplicate argument: {name}')

        _assert_valid_type(f'arg {name}', value, spec.annotations[name])

    for index, (name, value) in enumerate(zip(names, positional)):
        _assert_valid_type(f'arg #{index}', value, spec.annotations[name])


def _assert_valid_type(arg_name, value, annotation):
    """Assert that the type of argument value matches the annotation."""
    if annotation == typing.Any:
        return

    NoneType = type(None)
    if annotation == NoneType:
        if value is not None:
            raise TypeError('Expected None for {arg_name}')

        return

    basic_types = {bool, int, str, float}
    if annotation in basic_types:
        if not isinstance(value, annotation):
            raise TypeError(
                f'Expected type {annotation.__name__} for {arg_name}')

        return

    # 'int | str' or 'typing.Union[int, str]'
    if (isinstance(annotation, types.UnionType)
            or getattr(annotation, '__origin__', None) == typing.Union):
        for arg in annotation.__args__:
            try:
                _assert_valid_type(arg_name, value, arg)
                return
            except TypeError:
                pass

        raise TypeError(f'Expected one of unioned types for {arg_name}')

    # 'list[int]' or 'typing.List[int]'
    if getattr(annotation, '__origin__', None) == list:
        if not isinstance(value, list):
            raise TypeError(f'Expected type list for {arg_name}')

        for index, inner_item in enumerate(value):
            _assert_valid_type(f'{arg_name}[{index}]', inner_item,
                               annotation.__args__[0])

        return

    # 'list[dict]' or 'typing.List[dict]'
    if getattr(annotation, '__origin__', None) == dict:
        if not isinstance(value, dict):
            raise TypeError(f'Expected type dict for {arg_name}')

        for inner_key, inner_value in value.items():
            _assert_valid_type(f'{arg_name}[{inner_key}]', inner_key,
                               annotation.__args__[0])
            _assert_valid_type(f'{arg_name}[{inner_value}]', inner_value,
                               annotation.__args__[1])

        return

    raise TypeError('Unsupported annotation type')


# Run the command-line entry point only when executed as a script.
if __name__ == '__main__':
    main()
