Commit ff2980d3 authored by jan.koester's avatar jan.koester
Browse files

test



Co-authored-by: default avatarCopilot <copilot@github.com>
parent 21f176ca
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -27,6 +27,9 @@ public:
    virtual std::vector<uint8_t> encrypt(const std::vector<uint8_t>& plaintext) = 0;
    virtual std::vector<uint8_t> decrypt(const std::vector<uint8_t>& ciphertext) = 0;

    // Single-block ECB encrypt (no allocation) — for header protection
    virtual void encrypt_ecb(const uint8_t in[16], uint8_t out[16]) = 0;

    virtual std::vector<uint8_t> encryptCBC(const std::vector<uint8_t>& plaintext,
                                            const std::vector<uint8_t>& iv) = 0;
    virtual std::vector<uint8_t> decryptCBC(const std::vector<uint8_t>& ciphertext,
@@ -160,6 +163,7 @@ public:

    std::vector<uint8_t> encrypt(const std::vector<uint8_t>& plaintext) override;
    std::vector<uint8_t> decrypt(const std::vector<uint8_t>& ciphertext) override;
    void encrypt_ecb(const uint8_t in[16], uint8_t out[16]) override { encrypt_block(in, out); }

    std::vector<uint8_t> encryptCBC(const std::vector<uint8_t>& plaintext,
                                    const std::vector<uint8_t>& iv) override;
@@ -281,6 +285,7 @@ public:

    std::vector<uint8_t> encrypt(const std::vector<uint8_t>& plaintext) override;
    std::vector<uint8_t> decrypt(const std::vector<uint8_t>& ciphertext) override;
    void encrypt_ecb(const uint8_t in[16], uint8_t out[16]) override { encrypt_block(in, out); }

    std::vector<uint8_t> encryptCBC(const std::vector<uint8_t>& plaintext,
                                    const std::vector<uint8_t>& iv) override;
+373 −153

File changed.

Preview size limit exceeded, changes collapsed.

+14 −1
Original line number Diff line number Diff line
@@ -704,6 +704,11 @@ namespace netplus {
		// Send packet to peer (uses sendto for child connections)
		ssize_t sendPacket(const uint8_t* data, size_t len);

		// Batched packet sending: accumulate packets, then flush via udp class
		void    batchPacket(const uint8_t* data, size_t len);
		ssize_t flushBatch();
		static constexpr size_t BATCH_MAX = 64;

		// Retry token validation
		bool validateRetryToken(const std::vector<uint8_t>& token);

@@ -760,6 +765,8 @@ namespace netplus {
		uint64_t _handshake_pn_recv = 0;
		uint64_t _app_pn_send = 0;
		uint64_t _app_pn_recv = 0;
		uint32_t _app_unacked_count = 0;  // Delayed ACK counter
		bool     _app_ack_pending = false; // Deferred ACK flag

		// Received packet-number ranges (inclusive) for proper ACK encoding
		// Each pair is [lo, hi] representing contiguous received PNs.
@@ -868,7 +875,7 @@ namespace netplus {

		// Transport parameters
		uint64_t _max_idle_timeout = 30000;    // 30s
		uint64_t _max_udp_payload = 1350;      // Safe post-handshake default (Ethernet-safe)
		uint64_t _max_udp_payload = 1472;      // Default: Ethernet-safe (1500 - IP/UDP headers)
		uint64_t _active_connection_id_limit = 2;

		// Receive/send queues
@@ -958,6 +965,12 @@ namespace netplus {
		// Persistent receive buffer for pumpNetwork() to avoid 65KB alloc/free per call
		buffer _pump_buf{65535};

		// Batch send buffer for flushBatch()
		struct BatchEntry {
			std::vector<uint8_t> data;
		};
		std::vector<BatchEntry> _send_batch;

		// Mutex for thread safety (recursive: musl returns EDEADLK on
		// double-lock by same thread, which happens when processFrame
		// callbacks re-enter locking functions like sendStreamData)
+7 −0
Original line number Diff line number Diff line
@@ -122,3 +122,10 @@ endif()
add_executable(aes_cts aes_cts.cpp)
target_link_libraries(aes_cts netplus-static)
add_test(NAME aes_cts_test COMMAND aes_cts)

add_executable(benchmark_quic benchmark_quic.cpp)
if(WIN32)
    target_link_libraries(benchmark_quic netplus-static ws2_32)
else()
    target_link_libraries(benchmark_quic netplus-static)
endif()
+521 −0
Original line number Diff line number Diff line
#include <iostream>
#include <iomanip>
#include <string>
#include <cstring>
#include <vector>
#include <thread>
#include <atomic>
#include <chrono>
#include <numeric>
#include <poll.h>

#include "connection.h"
#include "eventapi.h"
#include "socket.h"
#include "exception.h"

#include "https_certs.h"
#include "https_ca_cert.h"

using namespace netplus;

static std::atomic<bool> g_server_ready(false);
static std::atomic<bool> g_tls_server_ready(false);
static std::atomic<bool> g_shutdown(false);

// ============================================================================
// Echo server: receives stream data and sends it back (echo)
// ============================================================================

class EchoServer : public event {
public:
    EchoServer(std::vector<netplus::socket*> socks, int timeout = 500)
        : event(socks, timeout) {}

    void RequestEvent(con& curcon, const int tid, ULONG_PTR args) override {
        netplus::quic* q = dynamic_cast<netplus::quic*>(curcon.slots[0].csock.get());
        if (!q) return;

        // Echo back whatever was received
        if (!curcon.RecvData.empty()) {
            curcon.SendData.append(curcon.RecvData.data(), curcon.RecvData.size());
            curcon.RecvData.clear();
        }
    }

    void ResponseEvent(con& curcon, const int tid, ULONG_PTR args) override {}

    void ConnectEvent(con& curcon, const int tid, ULONG_PTR args) override {}

    void DisconnectEvent(con& curcon, const int tid, ULONG_PTR args) override {}

    void CreateConnection(std::shared_ptr<con>& res) override {
        res = std::make_shared<con>(this);
    }
};

// ============================================================================
// Server thread
// ============================================================================

static void run_server(std::map<std::string, netplus::ssl::CertificateBundle>& certs) {
    try {
        quic serverSock(certs, "127.0.0.1", 9443, 64, -1);

        // Set a simple echo callback: when stream data arrives, queue it for echo
        serverSock.setStreamCallback([](netplus::socket* sock, uint64_t stream_id,
                                        const std::vector<uint8_t>& data, bool fin) {
            netplus::quic* q = dynamic_cast<netplus::quic*>(sock);
            if (!q || data.empty()) return;
            q->sendStreamData(stream_id, data, fin);
        });

        EchoServer srv({&serverSock});
        g_server_ready.store(true);

        srv.runEventloop();
    } catch (std::exception& e) {
        std::cerr << "[Server] Error: " << e.what() << std::endl;
        g_server_ready.store(true); // unblock client even on failure
    }
}

// ============================================================================
// Timing helpers
// ============================================================================

using hrc = std::chrono::high_resolution_clock;

static double elapsed_ms(hrc::time_point start, hrc::time_point end) {
    return std::chrono::duration<double, std::milli>(end - start).count();
}

static void print_separator() {
    std::cout << std::string(78, '-') << std::endl;
}

// ============================================================================
// TLS Echo server: TCP+TLS echo for comparison
// ============================================================================

class TlsEchoServer : public event {
public:
    TlsEchoServer(std::vector<netplus::socket*> socks, int timeout = 500)
        : event(socks, timeout) {}

    void RequestEvent(con& curcon, const int tid, ULONG_PTR args) override {
        if (!curcon.RecvData.empty()) {
            curcon.SendData.append(curcon.RecvData.data(), curcon.RecvData.size());
            curcon.RecvData.clear();
        }
    }

    void ResponseEvent(con& curcon, const int tid, ULONG_PTR args) override {}
    void ConnectEvent(con& curcon, const int tid, ULONG_PTR args) override {}
    void DisconnectEvent(con& curcon, const int tid, ULONG_PTR args) override {}

    void CreateConnection(std::shared_ptr<con>& res) override {
        res = std::make_shared<con>(this);
    }
};

static void run_tls_server(std::map<std::string, netplus::ssl::CertificateBundle>& certs) {
    try {
        ssl serverSock(certs, "127.0.0.1", 9444, 64, -1);
        TlsEchoServer srv({&serverSock});
        g_tls_server_ready.store(true);
        srv.runEventloop();
    } catch (std::exception& e) {
        std::cerr << "[TLS Server] Error: " << e.what() << std::endl;
        g_tls_server_ready.store(true);
    }
}

// ============================================================================
// TLS Benchmark: Handshake latency
// ============================================================================

static double benchmark_tls_handshake(int iterations,
                                       std::map<std::string, netplus::ssl::CertificateBundle>& certs) {
    std::vector<double> times;
    times.reserve(iterations);

    for (int i = 0; i < iterations; i++) {
        auto t0 = hrc::now();
        ssl client(certs);
        client.connect("127.0.0.1", 9444);
        auto t1 = hrc::now();

        times.push_back(elapsed_ms(t0, t1));
        client.close();

        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }

    double avg = std::accumulate(times.begin(), times.end(), 0.0) / times.size();
    double min_t = *std::min_element(times.begin(), times.end());
    double max_t = *std::max_element(times.begin(), times.end());

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Handshake     | " << std::setw(4) << iterations << " iterations"
              << " | Avg: " << std::setw(8) << avg << " ms"
              << " | Min: " << std::setw(8) << min_t << " ms"
              << " | Max: " << std::setw(8) << max_t << " ms"
              << std::endl;
    return avg;
}

// ============================================================================
// TLS Benchmark: Transfer throughput
// ============================================================================

static void benchmark_tls_transfer(size_t payload_size, int iterations,
                                    std::map<std::string, netplus::ssl::CertificateBundle>& certs) {
    ssl client(certs);
    client.connect("127.0.0.1", 9444);  // blocking handshake

    std::vector<char> payload(payload_size, 0xBB);
    std::vector<char> recv_buf(65536);

    socketwait waiter;

    // Warmup (still blocking)
    {
        size_t warmup_size = std::min(payload_size, size_t(4096));
        buffer snd(payload.data(), warmup_size);
        client.sendData(snd);
        waiter.waitRead(client, 2000);
        buffer rcv(recv_buf.data(), recv_buf.size());
        try { client.recvData(rcv); } catch (...) {}
    }

    // Switch to non-blocking for the transfer loop
    client.setNonBlock();

    size_t total_sent = 0;
    size_t total_recv = 0;
    size_t target_bytes = size_t(payload_size) * iterations;
    int send_idx = 0;
    size_t send_off = 0;  // offset within current payload iteration

    struct pollfd pfd;
    pfd.fd = client.fd();

    auto t0 = hrc::now();

    while (total_recv < target_bytes) {
        // Poll for readable or writable (or both)
        pfd.events = POLLIN;
        if (send_idx < iterations) pfd.events |= POLLOUT;
        pfd.revents = 0;

        int pr = poll(&pfd, 1, (send_idx >= iterations) ? 500 : 10);
        if (pr < 0) break;

        // Try to send (writable or just attempt it)
        if ((pfd.revents & POLLOUT) && send_idx < iterations) {
            size_t chunk = std::min(payload_size - send_off, size_t(16384));
            buffer snd(payload.data() + send_off, chunk);
            try {
                size_t sent = client.sendData(snd);
                total_sent += sent;
                send_off += sent;
                if (send_off >= payload_size) {
                    send_off = 0;
                    send_idx++;
                }
            } catch (NetException& e) {
                // Note = would block, just try again later
                if (e.getErrorType() != NetException::Note) throw;
            }
        }

        // Try to receive (readable)
        if (pfd.revents & POLLIN) {
            // Drain all available data
            bool more = true;
            while (more) {
                buffer rcv(recv_buf.data(), recv_buf.size());
                try {
                    size_t got = client.recvData(rcv);
                    total_recv += got;
                } catch (NetException& e) {
                    more = false;
                    if (e.getErrorType() != NetException::Note) throw;
                }
            }
        }

        if (elapsed_ms(t0, hrc::now()) > 10000.0) break;
    }
    auto t1 = hrc::now();

    double total_ms = elapsed_ms(t0, t1);
    double send_tp = (double(total_sent) / (1024.0 * 1024.0)) / (total_ms / 1000.0);
    double recv_tp = (double(total_recv) / (1024.0 * 1024.0)) / (total_ms / 1000.0);

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Transfer      | "
              << std::setw(7) << payload_size << " B x " << std::setw(4) << iterations
              << " | Sent: " << std::setw(6) << (total_sent / 1024) << " KB"
              << " | Recv: " << std::setw(6) << (total_recv / 1024) << " KB"
              << " | " << std::setw(8) << total_ms << " ms"
              << std::endl;
    std::cout << "  Throughput  |"
              << " Send: " << std::setw(8) << send_tp << " MB/s"
              << " | Recv: " << std::setw(8) << recv_tp << " MB/s"
              << std::endl;

    client.close();
}

// ============================================================================
// Benchmark: QUIC Connection + Handshake latency
// ============================================================================

static double benchmark_handshake(int iterations) {
    std::vector<double> times;
    times.reserve(iterations);

    for (int i = 0; i < iterations; i++) {
        auto t0 = hrc::now();
        quic client;
        client.connect("127.0.0.1", 9443);
        auto t1 = hrc::now();

        times.push_back(elapsed_ms(t0, t1));
        client.close();

        // Small delay between connections
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }

    double avg = std::accumulate(times.begin(), times.end(), 0.0) / times.size();
    double min_t = *std::min_element(times.begin(), times.end());
    double max_t = *std::max_element(times.begin(), times.end());

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Handshake     | " << std::setw(4) << iterations << " iterations"
              << " | Avg: " << std::setw(8) << avg << " ms"
              << " | Min: " << std::setw(8) << min_t << " ms"
              << " | Max: " << std::setw(8) << max_t << " ms"
              << std::endl;
    return avg;
}

// ============================================================================
// Benchmark: Stream data transfer throughput
// ============================================================================

static void benchmark_transfer(size_t payload_size, int iterations) {
    quic client;
    client.connect("127.0.0.1", 9443);

    uint64_t stream_id = client.openStream(true);

    // Prepare payload
    std::vector<uint8_t> payload(payload_size, 0xBB);
    std::vector<uint8_t> recv_buf(payload_size + 1024);

    // Warmup: send one block and wait for echo
    client.sendStreamData(stream_id, payload.data(), payload.size(), false);
    {
        auto deadline = hrc::now() + std::chrono::seconds(2);
        size_t warmup_recv = 0;
        while (warmup_recv < payload_size && hrc::now() < deadline) {
            client.pumpNetwork();
            if (client.hasStreamData(stream_id)) {
                warmup_recv += client.recvStreamData(stream_id, recv_buf.data(), recv_buf.size());
            } else {
                std::this_thread::sleep_for(std::chrono::microseconds(100));
            }
        }
    }

    // Benchmark: interleave send and receive to avoid flow-control stalls
    size_t total_sent = 0;
    size_t total_recv = 0;
    size_t target_bytes = size_t(payload_size) * iterations;

    auto t0 = hrc::now();
    int send_idx = 0;

    while (total_recv < target_bytes) {
        // Send next chunk if we haven't sent everything yet
        if (send_idx < iterations) {
            bool fin = (send_idx == iterations - 1);
            size_t sent = client.sendStreamData(stream_id, payload.data(), payload.size(), fin);
            total_sent += sent;
            send_idx++;
        }

        // Pump and receive — drain all available data
        client.pumpNetwork(MSG_DONTWAIT);
        while (client.hasStreamData(stream_id)) {
            size_t got = client.recvStreamData(stream_id, recv_buf.data(), recv_buf.size());
            if (got == 0) break;
            total_recv += got;
        }

        // Timeout safety
        if (elapsed_ms(t0, hrc::now()) > 10000.0) break;
    }
    auto t1 = hrc::now();

    double total_ms = elapsed_ms(t0, t1);
    double send_tp = (double(total_sent) / (1024.0 * 1024.0)) / (total_ms / 1000.0);
    double recv_tp = (double(total_recv) / (1024.0 * 1024.0)) / (total_ms / 1000.0);

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Transfer      | "
              << std::setw(7) << payload_size << " B x " << std::setw(4) << iterations
              << " | Sent: " << std::setw(6) << (total_sent / 1024) << " KB"
              << " | Recv: " << std::setw(6) << (total_recv / 1024) << " KB"
              << " | " << std::setw(8) << total_ms << " ms"
              << std::endl;
    std::cout << "  Throughput  |"
              << " Send: " << std::setw(8) << send_tp << " MB/s"
              << " | Recv: " << std::setw(8) << recv_tp << " MB/s"
              << std::endl;

    client.close();
}

// ============================================================================
// Benchmark: Multiple stream creation
// ============================================================================

static void benchmark_stream_creation(int count) {
    quic client;
    client.connect("127.0.0.1", 9443);

    auto t0 = hrc::now();
    std::vector<uint64_t> streams;
    streams.reserve(count);
    for (int i = 0; i < count; i++) {
        streams.push_back(client.openStream(true));
    }
    auto t1 = hrc::now();

    double total_ms = elapsed_ms(t0, t1);
    double per_stream_us = (total_ms * 1000.0) / count;

    std::cout << std::fixed << std::setprecision(2);
    std::cout << "Stream open   | " << std::setw(4) << count << " streams"
              << "      | Total: " << std::setw(8) << total_ms << " ms"
              << " | Per stream: " << std::setw(8) << per_stream_us << " us"
              << std::endl;

    // Clean up
    for (auto sid : streams) {
        client.closeStream(sid);
    }
    client.close();
}

// ============================================================================
// Main
// ============================================================================

int main() {
    try {
        // Load certificates
        x509cert cert;
        if (!cert.loadFromBuffer(test_cert_der)) {
            std::cerr << "Failed to load certificate" << std::endl;
            return 1;
        }

        std::map<std::string, netplus::ssl::CertificateBundle> certs;
        netplus::ssl::CertificateBundle bundle;
        bundle.cert = cert;
        bundle.privateKeyDer = std::vector<uint8_t>(test_key_der.begin(), test_key_der.end());
        bundle.rsa_key = netplus::rsa(bundle.privateKeyDer);
        bundle.chain.push_back(std::vector<uint8_t>(MKCERT_ROOT_CA_DER,
            MKCERT_ROOT_CA_DER + MKCERT_ROOT_CA_DER_LEN));

        certs["localhost"] = bundle;
        certs["127.0.0.1"] = bundle;

        // Start QUIC server
        std::thread server_thread(run_server, std::ref(certs));
        server_thread.detach();
        while (!g_server_ready.load())
            std::this_thread::sleep_for(std::chrono::milliseconds(10));

        // Start TLS server
        std::thread tls_thread(run_tls_server, std::ref(certs));
        tls_thread.detach();
        while (!g_tls_server_ready.load())
            std::this_thread::sleep_for(std::chrono::milliseconds(10));

        // Extra settle time
        std::this_thread::sleep_for(std::chrono::milliseconds(100));

        std::cout << "QUIC vs TCP+TLS Performance Comparison" << std::endl;
        std::cout << "(loopback echo, 127.0.0.1)" << std::endl;
        print_separator();

        // ================================================================
        // 1. Handshake latency comparison
        // ================================================================
        std::cout << "\n[Handshake Latency - QUIC (UDP+TLS 1.3)]\n";
        print_separator();
        double quic_hs = benchmark_handshake(10);
        print_separator();

        std::cout << "\n[Handshake Latency - TCP+TLS 1.3]\n";
        print_separator();
        double tls_hs = benchmark_tls_handshake(10, certs);
        print_separator();

        std::cout << "\n  => QUIC/TLS handshake ratio: " << std::fixed << std::setprecision(2)
                  << (quic_hs / tls_hs) << "x\n";

        // ================================================================
        // 2. Stream creation (QUIC only — TCP has no equivalent)
        // ================================================================
        std::cout << "\n[Stream Creation - QUIC only]\n";
        print_separator();
        benchmark_stream_creation(100);
        benchmark_stream_creation(1000);
        print_separator();

        // ================================================================
        // 3. Transfer throughput comparison
        // ================================================================
        struct TransferTest { size_t size; int iters; };
        TransferTest tests[] = {
            {1024, 100}, {16384, 100}, {65536, 50}, {262144, 20}
        };

        std::cout << "\n[Transfer Throughput - QUIC (UDP+TLS 1.3)]\n";
        print_separator();
        for (auto& t : tests) {
            benchmark_transfer(t.size, t.iters);
            std::cout << std::endl;
        }
        print_separator();

        std::cout << "\n[Transfer Throughput - TCP+TLS 1.3]\n";
        print_separator();
        for (auto& t : tests) {
            benchmark_tls_transfer(t.size, t.iters, certs);
            std::cout << std::endl;
        }
        print_separator();

        g_shutdown.store(true);
        std::this_thread::sleep_for(std::chrono::milliseconds(200));

        std::cout << "\nAll benchmarks complete." << std::endl;
        return 0;

    } catch (netplus::NetException& e) {
        std::cerr << "NetException: " << e.what() << std::endl;
        return 1;
    } catch (std::exception& e) {
        std::cerr << "Exception: " << e.what() << std::endl;
        return 1;
    }
}
Loading