/*
 * PE Fixup Utility
 * Copyright (C) 2005 Filip Navara
 * Copyright (C) 2020 Mark Jansen
 *
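 * The purpose of this utility is to fix PE binaries generated by binutils and
 * to manipulate flags that can't be set by binutils.
 *
 * Currently the following features are implemented:
 *
 * - Updating the PE header to use a LOAD_CONFIG,
 *   when the struct is exported with the name '_load_config_used'
 * - Fixing up section characteristics (paged / non-paged / discardable)
 *   for kernel-mode images: drivers, WDM drivers, kernel-mode DLLs and kernels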
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// host_includes
#include <typedefs.h>
#include <pecoff.h>
#include "../../dll/win32/dbghelp/compat.h"

/* Program name as given in argv[0]; printed in usage text and diagnostics. */
static const char *g_ApplicationName;
/* Path of the image file being fixed up; echoed in diagnostics. */
static const char *g_Target;

/* Operation selected by the first command-line argument. */
enum fixup_mode
{
    MODE_LOADCONFIG = 0,    /* --loadconfig */
    MODE_KERNELDRIVER = 1,  /* --kernelmodedriver */
    MODE_WDMDRIVER = 2,     /* --wdmdriver */
    MODE_KERNELDLL = 3,     /* --kerneldll */
    MODE_KERNEL = 4         /* --kernel */
};

/*
 * Translate a relative virtual address into a pointer inside the raw file
 * image held in 'buffer'.
 *
 * The section table of 'nt_header' is searched for the section whose virtual
 * range [VirtualAddress, VirtualAddress + VirtualSize) contains 'rva'. The
 * offset into that section is then applied to its PointerToRawData.
 *
 * Returns NULL in two cases:
 *  - no section contains 'rva';
 *  - 'rva' falls in the part of a section that is not stored in the file
 *    (at or beyond SizeOfRawData).
 *
 * NOTE: the result is not checked against the allocated size of 'buffer';
 * callers are expected to have validated the section table.
 */
void *rva_to_ptr(unsigned char *buffer, PIMAGE_NT_HEADERS nt_header, DWORD rva)
{
    unsigned int i;
    PIMAGE_SECTION_HEADER section_header = IMAGE_FIRST_SECTION(nt_header);

    for (i = 0; i < nt_header->FileHeader.NumberOfSections; i++, section_header++)
    {
        DWORD offset;

        if (rva < section_header->VirtualAddress)
            continue;

        /* Compare the offset instead of computing VirtualAddress + VirtualSize,
           which can wrap around for bogus headers */
        offset = rva - section_header->VirtualAddress;
        if (offset >= section_header->Misc.VirtualSize)
            continue;

        /* Bytes past SizeOfRawData have no backing data in the file */
        if (offset >= section_header->SizeOfRawData)
            return NULL;

        return buffer + section_header->PointerToRawData + offset;
    }

    return NULL;
}

/*
 * Print a printf-style diagnostic to stderr.
 *
 * The message is prefixed with "<application> ERROR: '<target>': ", taken
 * from g_ApplicationName and g_Target. No newline is appended, so callers
 * must include it in 'message'.
 */
static void error(const char* message, ...)
{
    va_list args;

    fprintf(stderr, "%s ERROR: '%s': ", g_ApplicationName, g_Target);

    va_start(args, message);
    /* A va_list must be forwarded with the v* variant; handing it to
       fprintf as a single argument is undefined behaviour */
    vfprintf(stderr, message, args);
    va_end(args);
}

/*
 * Recompute the PE image checksum over the first 'len' bytes of 'buffer'
 * and store it in OptionalHeader.CheckSum.
 *
 * How the sum is computed:
 *  - The CheckSum field is zeroed first. It lives inside 'buffer' and must
 *    not contribute to its own value.
 *  - The data is summed as little-endian 16-bit words, folding the carry
 *    back into the low 16 bits after each addition.
 *  - The file length is added at the end.
 *
 * If 'len' is odd, the last byte is treated as the low half of a final word
 * whose high half is zero.
 */
static void fix_checksum(unsigned char *buffer, long len, PIMAGE_NT_HEADERS nt_header)
{
    unsigned int checksum = 0;
    long n;

    nt_header->OptionalHeader.CheckSum = 0;

    for (n = 0; n < len; n += 2)
    {
        /* Build the word byte-wise rather than through an unsigned short*
           cast: this is independent of aliasing rules, alignment and host
           byte order, and never reads past 'len' */
        unsigned int word = buffer[n];
        if (n + 1 < len)
            word |= (unsigned int)buffer[n + 1] << 8;

        checksum += word;
        checksum = (checksum + (checksum >> 16)) & 0xffff;
    }

    checksum += (unsigned int)len;
    nt_header->OptionalHeader.CheckSum = checksum;
}

/*
 * Point the LOAD_CONFIG data directory at the structure exported under the
 * name '_load_config_used'.
 *
 * The export directory is searched by name. The matching export's RVA
 * becomes DataDirectory[IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG].VirtualAddress.
 * The first DWORD of the exported structure (its size field) becomes the
 * directory Size.
 *
 * Every RVA and the export ordinal are validated before use, so a malformed
 * export table yields an error instead of a NULL or out-of-range access.
 *
 * Returns 0 on success, 1 on failure (after printing a diagnostic via
 * error()).
 */
static int add_loadconfig(unsigned char *buffer, PIMAGE_NT_HEADERS nt_header)
{
    PIMAGE_DATA_DIRECTORY export_dir;
    PIMAGE_EXPORT_DIRECTORY export_directory;
    DWORD *name_ptr, *function_ptr, n;
    WORD *ordinal_ptr;

    export_dir = &nt_header->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_EXPORT];
    if (export_dir->Size == 0)
    {
        error("No export directory\n");
        return 1;
    }

    export_directory = rva_to_ptr(buffer, nt_header, export_dir->VirtualAddress);
    if (export_directory == NULL)
    {
        error("Invalid rva for export directory\n");
        return 1;
    }

    name_ptr = rva_to_ptr(buffer, nt_header, export_directory->AddressOfNames);
    ordinal_ptr = rva_to_ptr(buffer, nt_header, export_directory->AddressOfNameOrdinals);
    function_ptr = rva_to_ptr(buffer, nt_header, export_directory->AddressOfFunctions);

    /* The tables are only dereferenced when there is at least one name */
    if (export_directory->NumberOfNames != 0 &&
        (name_ptr == NULL || ordinal_ptr == NULL || function_ptr == NULL))
    {
        error("Invalid rva for export tables\n");
        return 1;
    }

    for (n = 0; n < export_directory->NumberOfNames; n++)
    {
        const char* name = rva_to_ptr(buffer, nt_header, name_ptr[n]);
        PIMAGE_DATA_DIRECTORY load_config_dir;
        DWORD load_config_rva;
        DWORD* load_config_ptr;
        WORD ordinal;

        /* An unresolvable name cannot be the one we are looking for */
        if (name == NULL || strcmp(name, "_load_config_used") != 0)
            continue;

        /* The name-ordinal table indexes the function (address) table */
        ordinal = ordinal_ptr[n];
        if (ordinal >= export_directory->NumberOfFunctions)
        {
            error("Invalid ordinal %u for '_load_config_used'\n", (unsigned)ordinal);
            return 1;
        }

        load_config_rva = function_ptr[ordinal];
        load_config_ptr = rva_to_ptr(buffer, nt_header, load_config_rva);
        if (load_config_ptr == NULL)
        {
            error("Invalid rva for '_load_config_used'\n");
            return 1;
        }

        /* Update the DataDirectory pointer / size
           The first entry of the LOAD_CONFIG struct is the size, use that as DataDirectory.Size */
        load_config_dir = &nt_header->OptionalHeader.DataDirectory[IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG];
        load_config_dir->VirtualAddress = load_config_rva;
        load_config_dir->Size = *load_config_ptr;

        return 0;
    }

    error("Export '_load_config_used' not found\n");
    return 1;
}

/*
 * Adjust the section flags of a kernel-mode image.
 *
 * GNU LD just doesn't know what a driver is, and has notably no idea of
 * paged vs non-paged sections. For every section this:
 *  - clears the IMAGE_SCN_ALIGN_* bits;
 *  - drops IMAGE_SCN_CNT_INITIALIZED_DATA from code sections;
 *  - removes IMAGE_SCN_MEM_WRITE from .rsrc and leaves it pageable;
 *  - marks INIT* sections as discardable;
 *  - leaves PAGE*, .edata and .reloc sections pageable;
 *  - marks every other non-discardable section IMAGE_SCN_MEM_NOT_PAGED.
 *
 * Section names are matched by prefix, case-insensitively.
 * 'mode' and 'buffer' are currently unused: every kernel-mode mode gets the
 * same treatment. Always returns 0.
 */
static int driver_fixup(int mode, unsigned char *buffer, PIMAGE_NT_HEADERS nt_header)
{
    PIMAGE_SECTION_HEADER Section = IMAGE_FIRST_SECTION(nt_header);

    /* Unused for now; cast to void to silence unused-parameter warnings */
    (void)mode;
    (void)buffer;

    for (unsigned i = 0; i < nt_header->FileHeader.NumberOfSections; i++, Section++)
    {
        const char *name = (const char *)Section->Name;

        /* LD puts alignment crap that nobody asked for */
        Section->Characteristics &= ~IMAGE_SCN_ALIGN_MASK;

        /* LD overdoes it and puts the initialized flag everywhere */
        if (Section->Characteristics & IMAGE_SCN_CNT_CODE)
            Section->Characteristics &= ~IMAGE_SCN_CNT_INITIALIZED_DATA;

        /* For some reason, .rsrc is made writable by windres.
           .rsrc is also pageable, so it never gets IMAGE_SCN_MEM_NOT_PAGED. */
        if (strncasecmp(name, ".rsrc", 5) == 0)
        {
            Section->Characteristics &= ~IMAGE_SCN_MEM_WRITE;
            continue;
        }

        /* Known sections which can be discarded */
        if (strncasecmp(name, "INIT", 4) == 0)
        {
            Section->Characteristics |= IMAGE_SCN_MEM_DISCARDABLE;
            continue;
        }

        /* Other known sections which can be paged (.rsrc is handled above) */
        if (strncasecmp(name, "PAGE", 4) == 0 ||
            strncasecmp(name, ".edata", 6) == 0 ||
            strncasecmp(name, ".reloc", 6) == 0)
        {
            continue;
        }

        /* If it's discardable, don't set the flag */
        if (Section->Characteristics & IMAGE_SCN_MEM_DISCARDABLE)
            continue;

        Section->Characteristics |= IMAGE_SCN_MEM_NOT_PAGED;
    }

    return 0;
}

/*
 * Print the command-line syntax and the list of supported modes to stdout.
 */
static
void
print_usage(void)
{
    printf("Usage: %s <mode> <filename>\n", g_ApplicationName);
    printf("Where <mode> is one of the following:\n");
    printf("  --loadconfig          Fix the LOAD_CONFIG directory entry\n");
    printf("  --kernelmodedriver    Fix code and data sections for driver images\n");
    printf("  --wdmdriver           Fix code and data sections for WDM drivers\n");
    printf("  --kerneldll           Fix code and data sections for Kernel-Mode DLLs\n");
    printf("  --kernel              Fix code and data sections for kernels\n");
}

int main(int argc, char **argv)
{
    FILE* file;
    long len;
    unsigned char *buffer;
    PIMAGE_DOS_HEADER dos_header;
    int result = 1;
    enum fixup_mode mode;

    g_ApplicationName = argv[0];

    if (argc != 3)
    {
        print_usage();
        return 1;
    }

    if (strcmp(argv[1], "--loadconfig") == 0)
    {
        mode = MODE_LOADCONFIG;
    }
    else if (strcmp(argv[1], "--kernelmodedriver") == 0)
    {
        mode = MODE_KERNELDRIVER;
    }
    else if (strcmp(argv[1], "--wdmdriver") == 0)
    {
        mode = MODE_WDMDRIVER;
    }
    else if (strcmp(argv[1], "--kerneldll") == 0)
    {
        mode = MODE_KERNELDLL;
    }
    else if (strcmp(argv[1], "--kernel") == 0)
    {
        mode = MODE_KERNEL;
    }
    else
    {
        print_usage();
        return 1;
    }

    g_Target = argv[2];

    /* Read the whole file to memory. */
    file = fopen(g_Target, "r+b");
    if (!file)
    {
        fprintf(stderr, "%s ERROR: Can't open '%s'.\n", g_ApplicationName, g_Target);
        return 1;
    }

    fseek(file, 0, SEEK_END);
    len = ftell(file);
    if (len < sizeof(IMAGE_DOS_HEADER))
    {
        fclose(file);
        error("Image size too small to be a PE image\n");
        return 1;
    }

    /* Add one byte extra for the case where the input file size is odd.
       We rely on this in our crc calculation */
    buffer = calloc(len + 1, 1);
    if (buffer == NULL)
    {
        fclose(file);
        error("Not enough memory available: (Needed %u bytes).\n", len + 1);
        return 1;
    }

    /* Read the whole input file into a buffer */
    fseek(file, 0, SEEK_SET);
    fread(buffer, 1, len, file);

    /* Check the headers and save pointers to them. */
    dos_header = (PIMAGE_DOS_HEADER)buffer;
    if (dos_header->e_magic == IMAGE_DOS_SIGNATURE)
    {
        PIMAGE_NT_HEADERS nt_header;

        nt_header = (PIMAGE_NT_HEADERS)(buffer + dos_header->e_lfanew);

        if (nt_header->Signature == IMAGE_NT_SIGNATURE)
        {
            if (mode == MODE_LOADCONFIG)
                result = add_loadconfig(buffer, nt_header);
            else
                result = driver_fixup(mode, buffer, nt_header);

            if (!result)
            {
                /* Success. Fix checksum and write to file */
                fix_checksum(buffer, len, nt_header);

                /* We could 'optimize by only writing the changed parts, but keep it simple for now */
                fseek(file, 0, SEEK_SET);
                fwrite(buffer, 1, len, file);
            }
        }
        else
        {
            error("Invalid PE signature: %x\n", nt_header->Signature);
        }
    }
    else
    {
        error("Invalid DOS signature: %x\n", dos_header->e_magic);
    }

    free(buffer);
    fclose(file);

    return result;
}