#!/usr/bin/python3
#
#   Copyright 2021 Canonical Ltd.
#   Author: Alberto Milone <alberto.milone@canonical.com>
#

import json
import re
import logging
import sys
import argparse


class DeviceInfo(object):
    """Wrap one chip entry from the supported-GPUs JSON and expose its IDs.

    PCI ID fields are stored with any '0x' text removed; an ID that is
    missing or empty in the source entry is stored as the wildcard '*'.
    """

    def __init__(self, device):
        # Vendor is fixed: every entry in this list is an NVIDIA (0x10DE) chip.
        self._vendor = '10DE'
        self._device = device
        self._devid = self._strip_id(device.get('devid'))
        self._subvendorid = self._strip_id(device.get('subvendorid'))
        self._subdevid = self._strip_id(device.get('subdevid'))

    @staticmethod
    def _strip_id(value):
        """Return *value* without '0x', or '*' when it is missing/empty."""
        return value.replace('0x', '') if value else '*'

    def get_vendor(self):
        return self._vendor

    def get_devid(self):
        return self._devid

    def get_subvendorid(self):
        return self._subvendorid

    def get_subdevid(self):
        return self._subdevid

    def get_features(self):
        """Return the raw 'features' value of the entry (None if absent)."""
        return self._device.get('features')

    def get_legacy_series(self):
        """Return the 'legacybranch' value of the entry (None if absent)."""
        return self._device.get('legacybranch')

    def get_device_name(self):
        return self._device.get('name')

    def _has_feature(self, feature):
        """Return True if *feature* is listed; False when there are no features."""
        features = self.get_features()
        # Guard against entries with no 'features' key: `x in None` raises.
        return bool(features) and feature in features

    def supports_runtimepm(self):
        return self._has_feature('runtimepm')

    def supports_kernelopen(self):
        return self._has_feature('kernelopen')

    def is_tesla(self):
        return self._has_feature('tesla')


def _get_devices_from_data(data):
    devices = {}
    match = None
    for chip in data["chips"]:
        devices[chip.get('devid')] = chip

    return devices

def _supports_runtimepm(device):
    features = device.get('features')
    if features:
        return ('runtimepm' in features)
    return None

def _is_legacy(device):
    legacy = device.get('legacybranch')
    return legacy != None

def _supports_kernelopen(device):
    features = device.get('features')
    if features:
        return ('kernelopen' in features)
    return None

def _get_supported_devices(data, open_module=None):
    """Return DeviceInfo objects for every non-legacy chip in *data*.

    :param data: parsed JSON with a "chips" list.
    :param open_module: when truthy, keep only chips listing the
        'kernelopen' feature. When None, fall back to the module-level
        ``open_module`` flag set by the command-line entry point, or False
        if it was never set (e.g. when this module is imported).
    """
    if open_module is None:
        # The flag is only assigned under __main__; avoid a NameError on import.
        open_module = globals().get('open_module', False)

    supported_devices = []
    for device in _get_devices_from_data(data).values():
        if _is_legacy(device):
            continue
        if open_module and not _supports_kernelopen(device):
            continue
        supported_devices.append(DeviceInfo(device))
    return supported_devices

def print_aliases(data, module, package):
    """Print a header plus one modalias line per supported device.

    :param data: parsed JSON with a "chips" list.
    :param module: kernel module name written after each alias pattern.
    :param package: package name written after the module name.
    """
    devices = _get_supported_devices(data)
    template = 'alias pci:v0000%sd0000%ssv%ssd%sbc03sc*i* %s %s'
    header = '# List generated by nvidia_supported. Do not edit manually.\n'

    # Subsystem vendor/device matching is disabled for now: the subsystem
    # IDs in supported-gpus.json are the Nvidia IDs (finding based on the
    # contents of nvidia-driver-470). To re-enable it, substitute
    # '0000' + device.get_subvendorid() / get_subdevid() when they are not '*'.
    subvendor_id = '*'
    subdev_id = '*'

    aliases = [template % (device.get_vendor(), device.get_devid(),
                           subvendor_id, subdev_id, module, package)
               for device in devices]
    # join() yields the newline-separated list without a trailing newline,
    # exactly like the previous index-based loop.
    print(header + '\n'.join(aliases))

def main(file_stream, module, package):
    """Load the device list from *file_stream* and print the modalias lines.

    :param file_stream: open text stream containing the JSON device list.
    :param module: kernel module name to emit in each alias line.
    :param package: package name to emit in each alias line.
    """
    # basicConfig() configures the root logger and returns None, so there
    # is no logger object to keep.
    logging.basicConfig(level=logging.DEBUG, stream=sys.stdout)
    data = json.load(file_stream)

    print_aliases(data, module, package)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('module_name', help='The module name')
    parser.add_argument('package_name', help='The package name')
    parser.add_argument('source', help='The .json source file to parse')
    parser.add_argument('-o', '--open', action='store_true',
                        help="The open kernel module")
    args = parser.parse_args()
    # Module-level flag read by _get_supported_devices(); store_true already
    # produces a bool, and this assignment is at module scope so no `global`
    # statement is needed.
    open_module = args.open

    # JSON text is UTF-8; don't depend on the locale's default encoding.
    # The `with` block closes the file, so no explicit close() is required.
    with open(args.source, encoding='utf-8') as f:
        main(f, args.module_name, args.package_name)
