#include <dirent.h>
#include <err.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <regex.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <unistd.h>
#include <sys/socket.h>
#include <netinet/in.h>

#include <mbedtls/pk.h>
#include <mbedtls/entropy.h>
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/gcm.h"

// UDP port the log server binds to (on all IPv4 interfaces, see main()).
const unsigned short PORT = 58132;

// PEM-encoded RSA private key, parsed in main() and used to decrypt the AES session
// keys carried by type-0 (session setup) packets.
// NOTE(review): embedding the private key in the source ships it in every copy of the
// binary and repository; consider loading it from a protected file at startup instead.
static const unsigned char privkey[] =
"-----BEGIN RSA PRIVATE KEY-----\n"
"MIIEpAIBAAKCAQEAwWlNmLHOOzZpdrfp+EAANwqab0FhQwCyZ/u+ySBW5XxPf6mo\n"
"bySvtJrLdsWzdwnup/UfwZiEhJk/4wpD4Qf/2+syuJi3Rf7L+Jfh//Qf9uXAS80+\n"
"LYad7dW0c1r5nt+F9Can5fBn7futnd8n672T+y8QpHRwX9GtaILvYQe5GQac8cfq\n"
"2rGUd5iYj5KSdcaIZnJ4YgnjLHg2PMbtEJwqcV+2oAkcOPzTAJoNE7XjLZTwXmLl\n"
"FgL/2cN4uJZBDnwv3RZSAhpdYF4KOJmE2GFs2jdvRUrYO7WSl8fM16vRH4vz5MNN\n"
"caprg2MlXheVTPQa+WMdcz7dyQx8s9kRVPPUSwIDAQABAoIBAH1KD8A4flYRO2Ry\n"
"YxgzrW/6aGxlt/HFg8ykYcS8NE5Yps8WQkwtQb0HAYKhM06LmpQm0DmC6WVUOPSE\n"
"c9BUdEQsKiE2nJK1KcCR8w7xP7uavWTdQcgQCkJFS63mYwmt1oKAgAcOIuUhQiig\n"
"pKWrmy7+IBPIcftAQssO9q6uaBNy+ONu6KU/FYd4UAoEt07MzuIAl5rybROOWCrA\n"
"cjuOKK830Q2Mi2ESwwlO0+3Vyz9VhSha5wW2WwFv9zx8YQblaVXxfXC6O5XcXb0j\n"
"O2rTpgHMmOVik3Zrg3XoRXsNZCvFQqbIevwGhRNEFQTnzakh/5g0VDa54RAK/APC\n"
"t3ABEAECgYEA+rnqDqmXQc/xMCLOhJduxRFvZbT6EbUVJG5+l6XtATTa4LWzqx1B\n"
"QDXS/Dixxc9FA1vSH2rAW//6wbr8KXihSeI7YxgIfWyrSIxbS3Dd1qwddvniYIpx\n"
"ms9vYpQKAewv2p310nf7fyuURES8YikhpSuff3DhzXBEi7s3Y7gIIEsCgYEAxXrD\n"
"6xjgLSgbbdyqxKeNk8ADMifbj8ZoiNIHkJ6sShFaTEbyHweOJ4RY7OmTBUDhgzUY\n"
"1oynztilADFaq8KhsiMqzI3DgZ3/2OElGYReEAbXljvgudL3tmiBWc5j9S7XNLBe\n"
"u9f3WYAYDu7/BVqmQWa/QtD9JquoZ1xgV2D6nAECgYEA29w9t9/VSJvM9xX+nNyi\n"
"AON6GOjrRK3TPWA7WEXjH+S2bshHJi0ANAs+2Xfpw/kunnRdPLmCtuowfMO4LbGf\n"
"VcexpgLEJyAszvBteikeDwpcyCD19wxP9J4kIYCJiggQKpfLoWUfP/P6DydrPnSt\n"
"EUbAlaNqDpl9Mj7YonQVhCMCgYEAv0c9M5+hrDuX7d76zYaZvI4UymUO54E/yZ7e\n"
"UvdOXGPYed+SL/oKeD5aQAeyLzl79bHdgBs3g0QW9kvXzly0cC5eC0oZH5hhs7nI\n"
"TKII1i86bLtM3dD5vQYWnF0sNtWK/+8Bo6L5ZAiNxRE7lP0L4ndaNKbnPaixcoRo\n"
"kNpPhAECgYAg7jmNlw+7VVurzR36LKKE+d6veraF5ltpJiboDb3j38RGe3982LLq\n"
"WaBKm1gkHfXgBjkNzS4r2kyRijw0GQ9JgmWooR7BXFH30HkPNl4gFTSsrOG5zGLi\n"
"0aexkDpXQuKsgBzqU0Wn94GZMM+RhuOYec/56JFT+8U5Tcntb26wwA==\n"
"-----END RSA PRIVATE KEY-----\n";

// Personalization string passed to mbedtls_ctr_drbg_seed() in main() (sizeof(pp), so the
// trailing NUL is included); the DRBG is still seeded from mbedtls_entropy_func.
// Contains the non-ASCII 'Ö', whose bytes depend on the source file encoding.
static const unsigned char pp[] = "IJUHZGFDXTZKHJKHGFDHZLUÖDRTFGHHJGHH";

// Idle timeout in seconds (compared against now() readings): session log files idle
// for longer than this are closed by close_files(). 20 minutes.
#define ACCESS_TIMEOUT (20*60)

/*
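   1 byte type: 0 allocate session, 1 log to session, any other value: discarded
   8 bytes session id
   type 0:
      256 bytes RSA-2048 ciphertext yielding the 16 byte AES-128 session key
   type 1:
      16 bytes GCM IV
      16 bytes GCM tag
      rest: ciphertext (the session id is authenticated as GCM additional data)

   File name format: 2020-10-13-23-42-05-SSSSSSSSSSSSSSSS-KKKKKKKKKKKKKKKKKKKKKKKKKKKKKKKK.log
   (S: session id as 16 hex digits, K: AES key as 32 hex digits)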

*/

// Protocol field sizes in bytes.
enum { SESSION_ID_LENGTH = 8, AES_KEY_LENGTH = 16, GCM_IV_LENGTH = 16, GCM_TAG_LENGTH = 16 };
// One logging session: its id, its AES key and its (lazily opened) log file.
typedef struct {
    // Must stay the first member: session_compare() memcmp()s the start of a SESSION
    // against a bare session id when used with bsearch().
    uint8_t session_id[SESSION_ID_LENGTH];
    uint8_t key[AES_KEY_LENGTH];   // AES key, used as an 8 * AES_KEY_LENGTH bit GCM key
    int     fd; // open log file descriptor, or -1 if the file is not currently open
    time_t  last_access; // now() of the last successful log write; 0 if none yet
    char    file_name[80]; // NUL-terminated name, relative to the working directory
} SESSION;
// All known sessions, kept sorted by session_id so bsearch() can find them.
static SESSION *g_sessions = 0;
static size_t g_session_count = 0;

/*
 * qsort()/bsearch() comparator ordering by the leading 8-byte session id.
 * Either argument may point at a SESSION or at a bare session id, which works
 * because session_id is the first member of SESSION.
 */
int session_compare(const void *a, const void *b)
{
    const unsigned char *lhs = a;
    const unsigned char *rhs = b;
    return memcmp(lhs, rhs, 8);
}

// Offsets into a 73-character session file name: "YYYY-MM-DD-HH-MM-SS-" is 20 characters
// (start of the 16 hex digit session id); the id plus its trailing '-' puts the key at 37.
enum { SIDOFFS = 20, KEYOFFS = 37 };

/* Value (0..15) of one hexadecimal digit in either case; any other character maps to 0. */
static uint8_t hex2nyble(char c)
{
    if (c >= '0' && c <= '9')
        return (uint8_t)(c - '0');
    if (c >= 'a' && c <= 'f')
        return (uint8_t)(10 + (c - 'a'));
    if (c >= 'A' && c <= 'F')
        return (uint8_t)(10 + (c - 'A'));
    return 0;
}

static void import_sessions(const char *dirname) {
    DIR * dirp = opendir(dirname);
    if (!dirp)
        errx(-1, "Fatal: Can't open dir %s\n", dirname);

    free(g_sessions);
    g_session_count = 0;

    size_t allocated = 1024;
    g_sessions = malloc(allocated * sizeof(SESSION));
    if (!g_sessions)
        errx(-1, "Fatal: Out of memory");

    regex_t regex;
    if (regcomp(&regex, "^[[:digit:]]{4}-[[:digit:]][[:digit:]]-[[:digit:]][[:digit:]]-[[:digit:]][[:digit:]]-[[:digit:]][[:digit:]]-"
                         "[[:digit:]][[:digit:]]-[[:xdigit:]]{16}-[[:xdigit:]]{32}.log$", REG_EXTENDED))
        errx(-1, "Fatal: Can't compile re");

    struct dirent * entry;
    while ((entry = readdir(dirp)) != NULL) {
        // We expect a very specific format
        if (entry->d_type != DT_REG || entry->d_namlen != 73) {
            // fprintf(stderr, "Skipping wrong length file: %*s\n", entry->d_namlen, entry->d_name);
            continue;
        }

        SESSION *ns = g_sessions + g_session_count;

        memcpy(ns->file_name, entry->d_name, entry->d_namlen);
        ns->file_name[entry->d_namlen] = 0;
        ns->fd = -1;
        ns->last_access = 0;

        if (regexec(&regex, ns->file_name, (size_t) 0, NULL, 0)) {
            fprintf(stderr, "Skipping non-re-matching file: %s\n", ns->file_name);
            continue;
        }

        if (sscanf(ns->file_name + SIDOFFS, "%" SCNx64, (uint64_t*)&ns->session_id) != 1) {
            fprintf(stderr, "Skipping non-parsable file: %s\n", ns->file_name);
            continue;
        }
        const char * hexkey = ns->file_name + KEYOFFS;
        for (int i=0; i<16; ++i)
            ns->key[i] = (hex2nyble(hexkey[2*i]) << 4 ) | hex2nyble(hexkey[2*i+1]);

        // Need more memory?
        if (++g_session_count > allocated) {
            allocated += 1024;
            g_sessions = realloc(g_sessions, allocated * sizeof(SESSION));
            if (!g_sessions)
                errx(-1, "Fatal: Out of memory");
        }
    }
    closedir(dirp);
    regfree(&regex);

    qsort(g_sessions, g_session_count, sizeof(SESSION), session_compare);
}

/*
 * Register a new session and create its (empty) log file.
 *
 * session_id: SESSION_ID_LENGTH raw bytes; may be unaligned (e.g. inside a packet buffer).
 * aes_key:    AES_KEY_LENGTH raw key bytes.
 *
 * A session id that is already known is ignored, so existing sessions are never
 * overwritten. The file name encodes creation time, session id and key in the fixed
 * 73-character layout that import_sessions() parses, which is what lets a session
 * survive a restart. g_sessions stays sorted by session id. Exits via errx() when out
 * of memory; failure to create the file is reported but the session is still registered.
 * NOTE(review): the AES key is stored in clear in the file name, so anyone who can
 * list the directory can decrypt that session's traffic.
 */
static void add_session(uint8_t *session_id, uint8_t aes_key[AES_KEY_LENGTH]) {
    // We don't overwrite existing sessions
    if (bsearch(session_id, g_sessions, g_session_count, sizeof(SESSION), session_compare))
        return;

    SESSION * sessions_realloc = realloc(g_sessions, sizeof(SESSION) * (1 + g_session_count));
    if (!sessions_realloc)
        errx(-1, "Fatal: Out of memory");
    g_sessions = sessions_realloc;

    // Insert at the sorted position by shifting the tail up one slot; this keeps the
    // array bsearch()-able without a full re-sort (and without BSD-only mergesort()).
    size_t pos = g_session_count;
    while (pos > 0 && session_compare(&g_sessions[pos - 1], session_id) > 0)
        --pos;
    memmove(&g_sessions[pos + 1], &g_sessions[pos], (g_session_count - pos) * sizeof(SESSION));
    g_session_count++;

    SESSION * new_session = &g_sessions[pos];
    memcpy(new_session->session_id, session_id, SESSION_ID_LENGTH);
    memcpy(new_session->key, aes_key, AES_KEY_LENGTH);
    new_session->fd = -1;
    new_session->last_access = 0;

    // Local-time prefix "YYYY-MM-DD-HH-MM-SS"; on failure use an all-zero timestamp that
    // keeps the same shape instead of leaving tprefix uninitialized.
    char tprefix[32];
    time_t t = time(NULL);
    struct tm * jetzt = localtime(&t);
    if (!jetzt || strftime(tprefix, sizeof(tprefix), "%F-%H-%M-%S", jetzt) == 0)
        snprintf(tprefix, sizeof(tprefix), "0000-00-00-00-00-00");

    char hexkey[2*AES_KEY_LENGTH + 1];
    for (int i = 0; i < AES_KEY_LENGTH; ++i)
        snprintf(hexkey + 2 * i, 3, "%02x", aes_key[i]);

    // Copy the id out rather than dereferencing a possibly misaligned uint64_t pointer,
    // and zero-pad to 16 digits so the name always has the fixed length import expects.
    uint64_t sid;
    memcpy(&sid, session_id, sizeof sid);
    snprintf(new_session->file_name, sizeof(new_session->file_name),
             "%s-%016" PRIx64 "-%s.log", tprefix, sid, hexkey);

    // Create the file right away so the key is persisted even before the first record.
    int fd = open(new_session->file_name, O_WRONLY | O_CREAT, 0644);
    if (fd < 0)
        warn("Error: Can't create file %s", new_session->file_name);
    else
        close(fd);
}

/*
 * Whole seconds read from CLOCK_MONOTONIC. Only the difference between two calls is
 * meaningful; it is used for the idle-file timeout bookkeeping.
 */
static time_t now()
{
    struct timespec mono;
    clock_gettime(CLOCK_MONOTONIC, &mono);
    time_t seconds = mono.tv_sec;
    return seconds;
}

static void log_to_session(const uint8_t *packet, size_t len) {
    // First check if the packet holds enough space for session id, iv and at least one gcm block
    if (len < SESSION_ID_LENGTH + GCM_IV_LENGTH + GCM_TAG_LENGTH) {
        fprintf(stderr, "Error: Short packet, size %zd\n", len);
        return;
    }

    const uint8_t *session_id = packet;
    const uint8_t *iv = packet + SESSION_ID_LENGTH;
    const uint8_t *tag = packet + SESSION_ID_LENGTH + GCM_IV_LENGTH;
    packet += SESSION_ID_LENGTH + GCM_IV_LENGTH + GCM_TAG_LENGTH;
    len -= SESSION_ID_LENGTH + GCM_IV_LENGTH + GCM_TAG_LENGTH;

    SESSION * s = bsearch(session_id, g_sessions, g_session_count, sizeof(SESSION), session_compare);
    if (!s) {
        fprintf(stderr, "Error: Can't log to unknown session 0x%" PRIx64 "\n", *(uint64_t*)session_id);
        return;
    }

    // Create output file if it doesn't exist
    if (s->fd < 0)
        s->fd = open(s->file_name, O_WRONLY | O_APPEND | O_CREAT, 0755);
    if (s->fd < 0) {
        fprintf(stderr, "Error: Can't create file %s for session 0x%" PRIx64 "\n", s->file_name, *(uint64_t*)packet);
        return;
    }

    // Now decrypt
    mbedtls_gcm_context ctx;
    mbedtls_gcm_init(&ctx);
    mbedtls_gcm_setkey(&ctx, MBEDTLS_CIPHER_ID_AES, s->key, 8 * AES_KEY_LENGTH);

    // Prepare output scratch space
    uint8_t *output = alloca(len);

    // This should fail on invalid input sizes
    switch(mbedtls_gcm_auth_decrypt(&ctx, len, iv, GCM_IV_LENGTH, session_id, SESSION_ID_LENGTH, tag, GCM_TAG_LENGTH, packet, output))
    {
        case 0:
            write(s->fd, output, len);
            s->last_access = now();
            break;
        case MBEDTLS_ERR_GCM_BAD_INPUT:
            fprintf(stderr, "Error: Invalid log data\n");
            break;
        case MBEDTLS_ERR_GCM_AUTH_FAILED :
            fprintf(stderr, "Error: Can't decrypt\n");
            break;
        default:
            fprintf(stderr, "Error: Unknown gcm error\n");
    }

    mbedtls_gcm_free(&ctx);
}

/*
 * Close session log files that have been idle for more than ACCESS_TIMEOUT seconds
 * (measured with now()). A closed file is reopened on demand by log_to_session().
 *
 * An open fd whose last_access is still 0 (opened but never written successfully) is
 * treated as last used at monotonic time 0. A later check therefore closes it instead
 * of leaking the descriptor for the lifetime of the process.
 */
void close_files() {
    const time_t current = now();
    for (size_t i = 0; i < g_session_count; ++i) {
        if (g_sessions[i].fd >= 0 && current - g_sessions[i].last_access > ACCESS_TIMEOUT) {
            close(g_sessions[i].fd);
            g_sessions[i].fd = -1;
        }
    }
}

/*
 * UDP log server entry point.
 *
 * Startup:
 *   - changes into ./sessions and seeds the DRBG;
 *   - parses the embedded RSA key;
 *   - binds a UDP socket to PORT on all IPv4 interfaces;
 *   - loads the existing session files.
 * It then handles datagrams forever:
 *   type 0: RSA-decrypt the AES key and register the session (add_session)
 *   type 1: decrypt and append a log record (log_to_session)
 *   other:  ignored
 * After each datagram, idle log files are closed if more than 60 s have passed
 * since the last check.
 * Setup failures are fatal; per-packet problems are reported on stderr.
 */
int main() {
    mbedtls_ctr_drbg_context ctr_drbg;
    mbedtls_entropy_context entropy;
    mbedtls_pk_context pk;
    int ret = 0;

    // Session files are opened by bare name, so the working directory must be right.
    if (chdir("sessions") != 0)
        err(-1, "Fatal: Can't change to directory sessions");

    mbedtls_pk_init( &pk );
    mbedtls_entropy_init( &entropy );
    mbedtls_ctr_drbg_init( &ctr_drbg );
    // The DRBG is handed to mbedtls_pk_decrypt() below; don't continue with an unseeded one.
    if ((ret = mbedtls_ctr_drbg_seed(&ctr_drbg, mbedtls_entropy_func, &entropy, pp, sizeof(pp))) != 0)
        errx(-1, "Fatal: mbedtls_ctr_drbg_seed returned -0x%04x", -ret);

    if ((ret = mbedtls_pk_parse_key(&pk, privkey, sizeof(privkey), NULL, 0) ) != 0 )
        errx(-1, "Fatal: mbedtls_pk_parse_key returned -0x%04x\n", -ret );

    int sock = socket(AF_INET, SOCK_DGRAM, 0);
    if (sock < 0)
        errx(-1, "Fatal: Can not make socket\n");

    // Zero the whole address so sin_zero (and sin_len where it exists) hold no stack garbage.
    struct sockaddr_in servaddr;
    memset(&servaddr, 0, sizeof(servaddr));
    servaddr.sin_family = AF_INET;
    servaddr.sin_addr.s_addr = htonl(INADDR_ANY);
    servaddr.sin_port = htons(PORT);
    if (bind(sock, (const struct sockaddr *)&servaddr, sizeof(servaddr)) < 0)
        errx(-1, "Fatal: Can't bind to port %d\n", PORT);

    import_sessions(".");

    time_t last_close_check = now();
    while (1) {
        uint8_t packet[16*1024];
        struct sockaddr_in peer;
        socklen_t peer_len = sizeof(peer);

        // MSG_WAITALL has no effect on datagram sockets, so no flags are passed.
        ssize_t len = recvfrom(sock, packet, sizeof(packet), 0, (struct sockaddr *) &peer, &peer_len);
        if (len < 0) {
            if (errno == EINTR)
                continue;
            errx(-1, "Fatal: recvfrom yields %zd\n", len);
        }
        // A zero-length datagram is valid UDP and has no type byte: ignore it, don't exit.
        if (len == 0)
            continue;

        switch (packet[0]) {
        case 0: { // Session setup: type | session id | RSA ciphertext of the AES key
            if ((size_t)len <= 1 + SESSION_ID_LENGTH) {
                fprintf(stderr, "Error: Short session setup packet, size %zd\n", len);
                break;
            }
            // packet + 1 is not aligned for uint64_t; copy the id out for printing.
            uint64_t sid;
            memcpy(&sid, packet + 1, sizeof sid);

            uint8_t rsa_plain_text[AES_KEY_LENGTH];
            size_t olen = 0;
            ret = mbedtls_pk_decrypt(&pk, packet + 1 + SESSION_ID_LENGTH, (size_t)len - 1 - SESSION_ID_LENGTH,
                                     rsa_plain_text, &olen, sizeof(rsa_plain_text), mbedtls_ctr_drbg_random, &ctr_drbg);
            if (ret != 0) {
                fprintf(stderr, "Error: Failed to decrypt key (error %d) for session 0x%" PRIx64 "\n", -ret, sid);
            } else if (olen != sizeof(rsa_plain_text)) {
                // A shorter plaintext would leave part of the key buffer uninitialized.
                fprintf(stderr, "Error: Wrong key length %zu for session 0x%" PRIx64 "\n", olen, sid);
            } else {
                add_session(packet + 1, rsa_plain_text);
            }
            break;
        }
        case 1:
            log_to_session(packet + 1, (size_t)len - 1);
            break;
        default:
            break;
        }

        time_t jetzt = now();
        if (jetzt > last_close_check + 60) {
            close_files();
            last_close_check = jetzt;
        }
    }
}