//
// This file is part of libdebug Python library (https://github.com/libdebug/libdebug).
// Copyright (c) 2024 Roberto Alessandro Bertolini. All rights reserved.
// Licensed under the MIT license. See LICENSE file in the project root for details.
//

// This program autogenerates the layout of the xsave area for the current CPU
// and prints it to stdout as a C++ header, to be saved as the generated header file.

#include <cpuid.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <signal.h>
#include <sys/ptrace.h>
#include <sys/types.h>
#include <sys/uio.h>
#include <sys/wait.h>
#include <unistd.h>

// Regset type of the x86 extended (XSAVE) state, passed as the `addr`
// argument of PTRACE_GETREGSET in main().
#define NT_X86_XSTATE 0x202

/* The extended state feature IDs in the state component bitmap.  */
/* Each ID is used both as the bit index tested in XCR0 and as the CPUID
   leaf 0xD sub-leaf queried for the component's size and offset.  */
/* X87 and SSE state live in the legacy 512-byte region at the start of
   the area.  */
#define X86_XSTATE_X87_ID	0
#define X86_XSTATE_SSE_ID	1
/* AVX: emitted as the `ymm0` member (16 x Reg128).  */
#define X86_XSTATE_AVX_ID	2
/* MPX bound registers and bound configuration.  */
#define X86_XSTATE_BNDREGS_ID	3
#define X86_XSTATE_BNDCFG_ID	4
/* AVX-512: opmask (`kmask`), `zmm0` (16 x Reg256) and `zmm1` (16 x Reg512).  */
#define X86_XSTATE_K_ID		5
#define X86_XSTATE_ZMM_H_ID	6
#define X86_XSTATE_ZMM_ID	7
/* Protection-key rights register.  */
#define X86_XSTATE_PKRU_ID	9
/* The IDs below are defined for completeness; this generator does not
   emit members for them.  */
#define X86_XSTATE_TILECFG_ID	17
#define X86_XSTATE_TILEDATA_ID	18
#define X86_XSTATE_APX_F_ID	19

/*
 * Report whether the CPU supports the XSAVE instruction family.
 *
 * Checks CPUID leaf 1, ECX bit 26 (bit_XSAVE). Leaf 0xD must not be used for
 * this check: its contents are only defined when XSAVE is supported, and it
 * also requires an explicit sub-leaf in ECX.
 *
 * Returns 1 if XSAVE is supported, 0 otherwise.
 */
int has_xsave()
{
    unsigned int eax, ebx, ecx, edx;

    // __get_cpuid() returns 0 when the requested leaf is not available
    if (!__get_cpuid(1, &eax, &ebx, &ecx, &edx)) {
        return 0;
    }

    return (ecx & bit_XSAVE) != 0;
}

/*
 * Byte offset of XSAVE state component `element` inside the XSAVE area,
 * as reported in EBX by CPUID leaf 0xD, sub-leaf `element`.
 */
int xsave_element_offset(int element)
{
    uint32_t size, offset, flags, reserved;

    // CPUID.(EAX=0xD, ECX=element): EAX = component size, EBX = component offset
    __cpuid_count(0x0d, element, size, offset, flags, reserved);

    return offset;
}

/*
 * Size in bytes of XSAVE state component `element`, as reported in EAX by
 * CPUID leaf 0xD, sub-leaf `element`.
 */
int xsave_element_size(int element)
{
    uint32_t size, offset, flags, reserved;

    // CPUID.(EAX=0xD, ECX=element): EAX = component size, EBX = component offset
    __cpuid_count(0x0d, element, size, offset, flags, reserved);

    return size;
}

/*
 * Total size in bytes of the XSAVE area for the features currently enabled
 * in XCR0, as reported in EBX by CPUID leaf 0xD, sub-leaf 0.
 */
int xsave_area_size()
{
    uint32_t supported_lo, enabled_size, supported_size, supported_hi;

    // CPUID.(EAX=0xD, ECX=0): EBX = size needed by the XCR0-enabled features
    __cpuid_count(0x0d, 0, supported_lo, enabled_size, supported_size, supported_hi);

    return enabled_size;
}

/*
 * Print the preamble of the generated header (license banner, #pragma once,
 * nanobind include and namespace alias) to stdout, followed by blank lines.
 */
void dump_file_header()
{
    // One entry per output line; puts() supplies each terminating newline
    static const char *const header_lines[] = {
        "//",
        "// This file is part of libdebug Python library (https://github.com/libdebug/libdebug).",
        "// It was autogenerated by libdebug/ptrace/native/xsave/autogenerate_xsave_layout.c.",
        "// Licensed under the MIT license. See LICENSE file in the project root for details.",
        "//",
        "",
        "// This source file contains the layout of the xsave area for the current CPU.",
        "// Along with the necessary nanobind bindings to access it.",
        "",
        "#pragma once",
        "",
        "#include <nanobind/nanobind.h>",
        "",
        "namespace nb = nanobind;",
        "",
        "",
    };
    const size_t line_count = sizeof(header_lines) / sizeof(header_lines[0]);

    for (size_t line = 0; line < line_count; line++) {
        puts(header_lines[line]);
    }
}

// Size of the buffer handed to PTRACE_GETREGSET. Only the XCR0 word is read
// back from it; the kernel copies at most this many bytes of the xsave dump.
#define XSAVE_DUMP_SIZE 4088

// Byte offset, inside the NT_X86_XSTATE dump, of the 8-byte XCR0 value that
// the kernel stores in the software-reserved bytes of the FXSAVE region.
#define XSAVE_XCR0_OFFSET 464

/*
 * Fork a child that stops itself under ptrace, dump its NT_X86_XSTATE regset
 * and extract XCR0 from it. Once the child has stopped, it is always killed
 * and reaped before returning.
 *
 * On return, *xsave_available is 1 if the regset could be read (and *xcr0 is
 * valid), 0 otherwise.
 *
 * Returns 0 on success (including "xsave not available"), -1 on fatal errors.
 */
static int read_xcr0_from_child(uint64_t *xcr0, int *xsave_available)
{
    *xcr0 = 0;
    *xsave_available = 0;

    pid_t pid = fork();
    if (pid == -1) {
        fprintf(stderr, "Failed to fork\n");
        return -1;
    }

    if (pid == 0) {
        // child: become a tracee and stop so that the parent can inspect us
        if (ptrace(PTRACE_TRACEME, 0, NULL, NULL) == -1) {
            fprintf(stderr, "Failed to trace me\n");
            _exit(1);
        }

        raise(SIGSTOP);

        // The parent kills us while we are stopped; never fall through into
        // the generator logic if we are somehow resumed.
        _exit(0);
    }

    int status;
    if (waitpid(pid, &status, 0) == -1) {
        fprintf(stderr, "Failed to wait for child\n");
        kill(pid, SIGKILL);
        return -1;
    }

    if (!WIFSTOPPED(status)) {
        // The child terminated instead of stopping (e.g. PTRACE_TRACEME
        // failed); waitpid() has already reaped it.
        fprintf(stderr, "Child did not stop as expected\n");
        return -1;
    }

    int result = 0;

    // Allocated as uint64_t so the XCR0 word can be read with an aligned load
    uint64_t *dump = malloc(XSAVE_DUMP_SIZE);
    if (dump == NULL) {
        fprintf(stderr, "Failed to allocate memory\n");
        result = -1;
    } else {
        struct iovec iov = {
            .iov_base = dump,
            .iov_len = XSAVE_DUMP_SIZE
        };

        // ptrace() is variadic: pass addr/data as pointers, not plain ints
        if (ptrace(PTRACE_GETREGSET, pid, (void *)(uintptr_t)NT_X86_XSTATE, &iov) == -1) {
            // this probably means that the CPU (or kernel) doesn't support xsave
            // we can still get the fp regs through GETFPREGS
            fprintf(stderr, "Failed to get xsave area\n");
        } else if (iov.iov_len < XSAVE_XCR0_OFFSET + sizeof(uint64_t)) {
            // the kernel filled fewer bytes than needed to reach the XCR0 word
            fprintf(stderr, "xsave area is too small to contain xcr0\n");
        } else {
            *xcr0 = dump[XSAVE_XCR0_OFFSET / sizeof(uint64_t)];
            *xsave_available = 1;
        }

        free(dump);
    }

    // kill the child and wait for it to die, whatever happened above
    kill(pid, SIGKILL);
    if (waitpid(pid, NULL, 0) == -1) {
        fprintf(stderr, "Failed to wait for child\n");
        result = -1;
    }

    return result;
}

/*
 * One XSAVE state component that may be emitted into PtraceFPRegsStruct.
 */
struct xsave_component {
    int id;               // XCR0 bit index / CPUID leaf 0xD sub-leaf
    const char *label;    // name used in error messages
    const char *padding;  // name of the filler array emitted before it, if needed
    const char *members;  // member declaration line(s), newline separated
};

/*
 * Emit the struct members for every XSAVE component enabled in `xcr0`,
 * inserting padding so that each one starts at the offset reported by CPUID.
 *
 * *current_size is the number of xsave-area bytes already described on entry
 * and is advanced to offset + CPUID size after each emitted component.
 *
 * Returns 0 on success, -1 if a component would overlap already emitted data.
 */
static int emit_xsave_components(uint64_t xcr0, int *current_size)
{
    // Listed in the order they must appear in the struct. Padding and member
    // names are kept exactly as the generated header has always used them.
    // NOTE(review): the PKRU entry emits 64 bytes while *current_size only
    // advances by the CPUID-reported size; confirm the tail padding is intended.
    static const struct xsave_component components[] = {
        { X86_XSTATE_AVX_ID,     "AVX",     "padding2",
          "    std::array<Reg128, 16> ymm0;" },
        { X86_XSTATE_BNDREGS_ID, "MPX",     "padding3",
          "    std::array<Reg128, 4> bndregs;" },
        { X86_XSTATE_BNDCFG_ID,  "MPX",     "padding3",
          "    Reg128 bndcfg;\n    unsigned char padding4[48];" },
        { X86_XSTATE_K_ID,       "AVX-512", "padding4",
          "    unsigned long long kmask[8];" },
        { X86_XSTATE_ZMM_H_ID,   "AVX-512", "padding4",
          "    std::array<Reg256, 16> zmm0;" },
        { X86_XSTATE_ZMM_ID,     "AVX-512", "padding5",
          "    std::array<Reg512, 16> zmm1;" },
        { X86_XSTATE_PKRU_ID,    "PKRU",    "padding6",
          "    unsigned int pkru;\n    unsigned char padding7[60];" },
    };
    const size_t component_count = sizeof(components) / sizeof(components[0]);

    for (size_t i = 0; i < component_count; i++) {
        const struct xsave_component *component = &components[i];

        if (!(xcr0 & (1ULL << component->id))) {
            continue;
        }

        int offset = xsave_element_offset(component->id);
        int size = xsave_element_size(component->id);

        if (offset < *current_size) {
            fprintf(stderr, "%s offset is less than current size\n", component->label);
            return -1;
        }

        if (offset > *current_size) {
            printf("    unsigned char %s[%d];\n", component->padding, offset - *current_size);
        }

        puts(component->members);

        *current_size = offset + size;
    }

    return 0;
}

/*
 * Generate, on stdout, a C++ header describing PtraceFPRegsStruct for the
 * current CPU plus the nanobind registration function for it.
 *
 * XCR0 is obtained from the NT_X86_XSTATE dump of a stopped, traced child.
 * If that dump is unavailable, only the legacy FXSAVE layout is emitted.
 * Exits with 1 on fatal errors (reported on stderr).
 */
int main(int argc, char *argv[])
{
    (void)argc;
    (void)argv;

    uint64_t xcr0 = 0;
    int xsave_available = 0;

    // Done before anything is printed, so the child inherits no pending output
    if (read_xcr0_from_child(&xcr0, &xsave_available) == -1) {
        return 1;
    }

    dump_file_header();

    if (xsave_available) {
        // Only the low 32 bits are printed, as the generated header always did
        printf("// Detected XSAVE feature max = %d\n", (int)(uint32_t)xcr0);
    } else {
        puts("// XSAVE not available, falling back to the legacy FXSAVE layout");
    }

    // type/dirty/fresh header fields, then the legacy 512-byte FXSAVE region,
    // which is always present
    puts("#pragma pack(push, 1)");
    puts("struct PtraceFPRegsStruct");
    puts("{");
    puts("    unsigned long type;");
    puts("    bool dirty;");
    puts("    bool fresh;");
    puts("    unsigned char bool_padding[6];");
    puts("    unsigned char padding0[32];");
    puts("    std::array<Reg128, 8> mmx;");
    puts("    std::array<Reg128, 16> xmm0;");
    puts("    unsigned char padding1[96];");

    int current_size = 512;

    if (xsave_available && emit_xsave_components(xcr0, &current_size) == -1) {
        return 1;
    }

    int has_avx = xsave_available && ((xcr0 & (1ULL << X86_XSTATE_AVX_ID)) != 0);
    int has_avx512 = xsave_available && ((xcr0 & (1ULL << X86_XSTATE_ZMM_ID)) != 0);

    puts("};");
    puts("#pragma pack(pop)");

    puts("");

    printf("// Size of struct fp_regs_struct = %d\n", current_size);
    // CPUID leaf 0xD is only meaningful when xsave is available
    printf("// Expected size of struct fp_regs_struct = %d\n",
           xsave_available ? xsave_area_size() : current_size);

    puts("");

    puts(xsave_available ? "#define HAS_XSAVE 1" : "#define HAS_XSAVE 0");

    int fpregs_type;
    if (!has_avx && !has_avx512) {
        fpregs_type = 0;
    } else if (has_avx && !has_avx512) {
        fpregs_type = 1;
    } else if (has_avx && has_avx512) {
        fpregs_type = 2;
    } else {
        // AVX-512 without AVX is not handled by this generator; report on
        // stderr so the message does not end up inside the generated header
        fprintf(stderr, "Bad state detected!\n");
        return 1;
    }

    printf("#define FPREGS_TYPE %d\n", fpregs_type);

    puts("");

    // This guard is needed to avoid redefinition of the function
    // when including the header file in multiple source files
    puts("#ifdef DECLARE_NANOBIND");
    puts("");

    // Now we need to dump the nanobind function that will register the class definition
    puts("void init_fpregs_struct(nanobind::module_ &m)");
    puts("{");
    puts("    nb::class_<PtraceFPRegsStruct>(m, \"PtraceFPRegsStruct\")");
    puts("        .def_ro(\"type\", &PtraceFPRegsStruct::type)");
    puts("        .def_rw(\"dirty\", &PtraceFPRegsStruct::dirty)");
    puts("        .def_rw(\"fresh\", &PtraceFPRegsStruct::fresh)");

    switch (fpregs_type) {
    case 0:
        puts("        .def_ro(\"mmx\", &PtraceFPRegsStruct::mmx);");
        break;
    case 1:
        puts("        .def_ro(\"mmx\", &PtraceFPRegsStruct::mmx)");
        puts("        .def_ro(\"xmm0\", &PtraceFPRegsStruct::xmm0)");
        puts("        .def_ro(\"ymm0\", &PtraceFPRegsStruct::ymm0);");
        break;
    default:
        puts("        .def_ro(\"mmx\", &PtraceFPRegsStruct::mmx)");
        puts("        .def_ro(\"xmm0\", &PtraceFPRegsStruct::xmm0)");
        puts("        .def_ro(\"ymm0\", &PtraceFPRegsStruct::ymm0)");
        puts("        .def_ro(\"zmm0\", &PtraceFPRegsStruct::zmm0)");
        puts("        .def_ro(\"zmm1\", &PtraceFPRegsStruct::zmm1);");
        break;
    }

    puts("}");
    puts("");

    puts("#endif");
    puts("");

    return 0;
}
