'''
PROJECT:     ReactOS apisets
LICENSE:     MIT (https://spdx.org/licenses/MIT)
PURPOSE:     Create apiset lookup table based on the data files of https://apisets.info
COPYRIGHT:   Copyright 2024 Mark Jansen <mark.jansen@reactos.org>
'''

from pathlib import Path
from dataclasses import dataclass, field
import sys
import json

# These are modules we do not have, so redirect them to ones we do have.
# Keys must be lowercase: lookups are done with the lowercased host name.
REDIRECT_HOSTS = {
    'kernelbase.dll': 'kernel32.dll',
    'kernel.appcore.dll': 'kernel32.dll',
    'combase.dll': 'ole32.dll',
    'ucrtbase.dll': 'msvcrt.dll',
    'shcore.dll': 'shell32.dll',
    'winmmbase.dll': 'winmm.dll',
    'gdi32full.dll': 'gdi32.dll'
}

# C source prologue, written verbatim before the generated table rows.
# It ends by opening the initializer of the g_Apisets array.
OUTPUT_HEADER = """/*
 * PROJECT:     ReactOS apisets
 * LICENSE:     LGPL-2.1-or-later (https://spdx.org/licenses/LGPL-2.1-or-later)
 * PURPOSE:     Autogenerated table of all apisets
 * COPYRIGHT:   Copyright 2024 Mark Jansen <mark.jansen@reactos.org>
 */

#include <ndk/umtypes.h>
#include <ndk/rtlfuncs.h>
#include "apisetsp.h"

const ROS_APISET g_Apisets[] = {
"""

# C source epilogue, written verbatim after the generated table rows.
# It closes the g_Apisets array and defines g_ApisetsCount from its size.
OUTPUT_FOOTER = """};

const LONG g_ApisetsCount = RTL_NUMBER_OF(g_Apisets);
"""

def winver_to_name(version):
    """Map a Windows ProductVersion string to its APISET_* C constant name.

    ``version`` is a dotted four-part version string such as ``'10.0.19041.1'``.
    Windows 10 and Windows 11 share major/minor ``10.0`` and are distinguished
    by the build number (Windows 11 starts at build 22000).

    Raises:
        ValueError: if ``version`` is not four integer components, or does not
            correspond to a supported Windows release.
    """
    major, minor, build, _ = map(int, version.split('.'))

    pre_win10 = {
        (6, 1): 'APISET_WIN7',
        (6, 2): 'APISET_WIN8',
        (6, 3): 'APISET_WIN81',
    }
    if (major, minor) in pre_win10:
        return pre_win10[(major, minor)]

    if (major, minor) == (10, 0):
        return 'APISET_WIN10' if build < 22000 else 'APISET_WIN11'

    # An explicit raise (instead of ``assert False``) keeps this check under
    # ``python -O``; otherwise the function would silently return None.
    raise ValueError(
        f'Unsupported Windows version {version!r} '
        f'(major={major}, minor={minor}, build={build})'
    )

@dataclass
class Apiset:
    """A single apiset: its dll name, the dll that hosts it, and its versions.

    ``versions`` collects APISET_* constant names in the order they were added;
    they are OR-ed together in the generated C row.
    """

    name: str
    host: str
    versions: list[str] = field(default_factory=list)

    def add_version(self, version):
        """Record ``version`` once; duplicates are ignored, insertion order is kept."""
        if version not in self.versions:
            self.versions.append(version)

    def __str__(self):
        """Render this apiset as one C initializer row for the ``g_Apisets`` table.

        The ``.dll`` suffix is stripped from the name. An entry with an empty
        host is emitted commented out. A host found (case-insensitively) in
        ``REDIRECT_HOSTS`` is replaced by its redirect target, and the original
        host is kept in a trailing comment.

        Raises:
            ValueError: if ``name`` does not end in ``.dll`` (case-insensitive).
        """
        # Explicit raise rather than ``assert``: this validates data read from
        # the schema files and must not disappear under ``python -O``.
        if self.name[-4:].lower() != '.dll':
            raise ValueError(f'Apiset name does not end in .dll: {self.name!r}')
        name = self.name[:-4]
        version_str = ' | '.join(self.versions)

        prefix, postfix = '', ''
        host = self.host
        if host == '':
            # Disable forwarders that have an empty host
            prefix = '// '
        else:
            # Check to see if there is any dll we want to swap (kernelbase -> kernel32)
            replace = REDIRECT_HOSTS.get(host.lower())
            if replace:
                postfix = ' // ' + host
                host = replace
        return (f'    {prefix}{{ RTL_CONSTANT_STRING(L"{name}"), '
                f'RTL_CONSTANT_STRING(L"{host}"), {version_str} }},{postfix}')


class ApisetSchema:
    """One parsed apisetschema JSON file."""

    def __init__(self, file):
        """Load the schema JSON from ``file`` (a path object supporting ``open``).

        Reads ``PE.ProductVersion``, mapped to an APISET_* name via
        ``winver_to_name``, and ``PE.Machine`` from the document.
        """
        # Open in binary mode inside ``with`` so the handle is closed promptly.
        # ``json.load`` then receives bytes and detects the encoding itself
        # (UTF-8/16/32, with or without a BOM), independent of the locale.
        with file.open('rb') as fh:
            self._data = json.load(fh)
        pe_info = self._data['PE']
        self.version = winver_to_name(pe_info['ProductVersion'])
        self._arch = pe_info['Machine']

    def apisets(self):
        """Yield an ``Apiset`` (with no versions yet) for each ``namespaces`` entry."""
        for elem in self._data['namespaces']:
            yield Apiset(elem['name'], elem['host'])


class CombinedSchemas:
    """Merge the apisets of several schemas, keyed case-insensitively by name.

    The first schema that mentions an apiset supplies its name and host;
    every schema that mentions it contributes its version.
    """

    def __init__(self):
        # lowercased apiset name -> Apiset
        self._apisets = {}

    def add(self, schema: ApisetSchema):
        """Fold every apiset of ``schema`` in, tagging each with ``schema.version``."""
        for candidate in schema.apisets():
            # setdefault returns the already-stored entry when the name was seen
            # before, and stores (and returns) ``candidate`` otherwise.
            merged = self._apisets.setdefault(candidate.name.lower(), candidate)
            merged.add_version(schema.version)

    def generate(self, output):
        """Write one UTF-8 encoded C row per apiset to the binary stream ``output``.

        Rows are ordered by lowercased apiset name.
        """
        for _, entry in sorted(self._apisets.items()):
            row = str(entry) + '\n'
            output.write(row.encode('utf-8'))


def process_apisetschemas(input_dir: Path, output_file):
    """Combine all ``*.json`` schemas in ``input_dir`` and write the C table.

    ``output_file`` must be a writable binary file-like object. It receives the
    header, one row per apiset, then the footer. Windows 11 schemas are
    currently skipped.
    """
    schemas = CombinedSchemas()

    # Path.glob yields entries in filesystem order, which differs between
    # machines. That order decides which schema supplies each apiset's host and
    # the order of versions in each row, so sort it for reproducible output.
    for schemafile in sorted(input_dir.glob('*.json')):
        schema = ApisetSchema(schemafile)
        # Skip Win11 for now
        if schema.version != 'APISET_WIN11':
            schemas.add(schema)

    output_file.write(OUTPUT_HEADER.encode('utf-8'))
    schemas.generate(output_file)
    output_file.write(OUTPUT_FOOTER.encode('utf-8'))


def usage():
    """Print the command-line help for this script to stdout."""
    print(
        'Usage: update.py <apisetschema folder>',
        '    where <apisetschema folder> is the folder with all apisetschema json files',
        sep='\n',
    )

def main(args):
    """Regenerate ``apisets.table.c`` next to this script from a schema folder.

    ``args`` are the command-line arguments without the program name. Returns a
    process exit status: 0 on success, 1 when the arguments are invalid (after
    printing usage).
    """
    import io

    if not args:
        usage()
        return 1

    apisetschemas = Path(args[0])
    if not apisetschemas.is_dir():
        usage()
        return 1

    output = Path(__file__).parent / 'apisets.table.c'

    # Build the table in memory first. Opening the output with 'wb' up front
    # would truncate the existing table even if a schema fails to parse, and
    # write_bytes() closes the file deterministically when done.
    buffer = io.BytesIO()
    process_apisetschemas(apisetschemas, buffer)
    output.write_bytes(buffer.getvalue())
    return 0


if __name__ == '__main__':
    # Use main()'s return value as the process exit status (None maps to 0).
    sys.exit(main(sys.argv[1:]))