Commit fcf4fdb5 authored by jan.koester's avatar jan.koester
Browse files

test

parent 2e98fc6e
Loading
Loading
Loading
Loading
+283 −50
Original line number Diff line number Diff line
@@ -476,8 +476,22 @@ namespace netplus {
                handshakeStarted = true;
                is_tls13 = false;
                chosenSuite = 0;
                if (hasSuite(0x002F))      chosenSuite = 0x002F;
                else if (hasSuite(0x0035)) chosenSuite = 0x0035;

                // Prefer ECDHE-GCM suites (work with both RSA and ECDSA certs)
                bool have_ec_cert = (selected_cert_bundle && selected_cert_bundle->has_ec_key);
                bool have_rsa_cert = (selected_cert_bundle && selected_cert_bundle->rsa_key);

                if (have_ec_cert) {
                    // ECDSA certificate: only ECDHE_ECDSA suites
                    if (hasSuite(0xC02C))      chosenSuite = 0xC02C; // ECDHE_ECDSA_AES_256_GCM_SHA384
                    else if (hasSuite(0xC02B)) chosenSuite = 0xC02B; // ECDHE_ECDSA_AES_128_GCM_SHA256
                } else if (have_rsa_cert) {
                    // RSA certificate: prefer ECDHE_RSA, fall back to static RSA
                    if (hasSuite(0xC030))      chosenSuite = 0xC030; // ECDHE_RSA_AES_256_GCM_SHA384
                    else if (hasSuite(0xC02F)) chosenSuite = 0xC02F; // ECDHE_RSA_AES_128_GCM_SHA256
                    else if (hasSuite(0x002F)) chosenSuite = 0x002F; // RSA_AES_128_CBC_SHA
                    else if (hasSuite(0x0035)) chosenSuite = 0x0035; // RSA_AES_256_CBC_SHA
                }

                if (chosenSuite == 0)
                    throwSSL(NetException::Error, "No supported cipher suite");
@@ -857,6 +871,61 @@ namespace netplus {
                    }

                    sendHandshake(0x0b, certMsg);

                    // --- ServerKeyExchange for ECDHE suites ---
                    bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C ||
                                    chosenSuite == 0xC02F || chosenSuite == 0xC030);
                    if (isECDHE) {
                        // Generate ephemeral ECDHE P-256 key pair
                        gen_tls13_p256_scalar(tls12_ecdhe_priv);
                        netplus::P256Point pub = netplus::scalar_mul_G(tls12_ecdhe_priv);
                        std::vector<uint8_t> pubBytes = netplus::encode_tls_point(pub);

                        // Build ServerKeyExchange params (ECParameters + ECPoint)
                        // ECParameters: curve_type(1) = named_curve(3), named_curve(2) = secp256r1(0x0017)
                        std::vector<uint8_t> params;
                        params.push_back(0x03); // named_curve
                        params.push_back(0x00); params.push_back(0x17); // secp256r1
                        params.push_back((uint8_t)pubBytes.size()); // point length
                        params.insert(params.end(), pubBytes.begin(), pubBytes.end());

                        // Sign: SHA-256(client_random + server_random + params)
                        std::vector<uint8_t> toSign;
                        toSign.insert(toSign.end(), clientRandom.begin(), clientRandom.end());
                        toSign.insert(toSign.end(), serverRandom.begin(), serverRandom.end());
                        toSign.insert(toSign.end(), params.begin(), params.end());

                        std::vector<uint8_t> sig;
                        uint16_t sigAlg;

                        if (chosenSuite == 0xC02B || chosenSuite == 0xC02C) {
                            // ECDHE_ECDSA: sign with ECDSA
                            if (!selected_cert_bundle->has_ec_key)
                                throwSSL(NetException::Error, "ECDHE_ECDSA requires EC key");
                            std::vector<uint8_t> ecKey(selected_cert_bundle->ecPrivateKey,
                                                       selected_cert_bundle->ecPrivateKey + 32);
                            sig = sign_ecdsa_sha256(toSign, ecKey);
                            sigAlg = 0x0403; // ecdsa_secp256r1_sha256
                        } else {
                            // ECDHE_RSA: sign with RSA PKCS#1 v1.5
                            if (!selected_cert_bundle->rsa_key)
                                throwSSL(NetException::Error, "ECDHE_RSA requires RSA key");
                            sig = sign_rsa_sha256_pkcs15(toSign, selected_cert_bundle->rsa_key);
                            sigAlg = 0x0401; // rsa_pkcs1_sha256
                        }

                        // Build ServerKeyExchange body: params + sig_algorithm(2) + sig_len(2) + sig
                        std::vector<uint8_t> skeBody;
                        skeBody.insert(skeBody.end(), params.begin(), params.end());
                        skeBody.push_back((sigAlg >> 8) & 0xFF);
                        skeBody.push_back(sigAlg & 0xFF);
                        skeBody.push_back((uint16_t(sig.size()) >> 8) & 0xFF);
                        skeBody.push_back(uint16_t(sig.size()) & 0xFF);
                        skeBody.insert(skeBody.end(), sig.begin(), sig.end());

                        sendHandshake(0x0c, skeBody); // ServerKeyExchange
                    }

                    sendHandshake(0x0e, {});

                    serverFlightQueued = true;
@@ -938,7 +1007,7 @@ namespace netplus {
                uint16_t finVer = ((uint16_t)finRec[1] << 8) | (uint16_t)finRec[2];
                std::vector<uint8_t> finFrag(finRec.begin() + 5, finRec.end());

                std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag);
                std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag);

                if (finPT.size() < 4 + 12 ||
                    finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) {
@@ -1047,17 +1116,44 @@ namespace netplus {
                if (msg.size() != 4 + hlen)
                    throwSSL(NetException::Error, "ClientKeyExchange length mismatch");

                std::vector<uint8_t> preMaster;
                bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C ||
                                chosenSuite == 0xC02F || chosenSuite == 0xC030);

                if (isECDHE) {
                    // ECDHE CKE: body = point_length(1) + uncompressed_point(65)
                    size_t off = 4;
                    if (hlen < 1)
                        throwSSL(NetException::Error, "ECDHE CKE body too short");
                    uint8_t ptLen = msg[off]; off++;
                    if (ptLen != 65 || off + ptLen > msg.size())
                        throwSSL(NetException::Error, "ECDHE CKE invalid point length");
                    if (msg[off] != 0x04)
                        throwSSL(NetException::Error, "ECDHE CKE point not uncompressed");

                    netplus::P256Point clientPub;
                    if (!netplus::decode_tls_point(clientPub, msg.data() + off, 65))
                        throwSSL(NetException::Error, "ECDHE CKE invalid EC point");
                    if (!netplus::is_on_curve(clientPub))
                        throwSSL(NetException::Error, "ECDHE CKE point not on curve");

                    // Compute shared secret
                    preMaster.resize(32);
                    if (!netplus::ecdh_shared_secret(preMaster.data(), tls12_ecdhe_priv, clientPub))
                        throwSSL(NetException::Error, "ECDHE shared secret computation failed");

                    tls12_is_gcm = true;
                } else {
                    // RSA CKE body: uint16 encLen + ciphertext
                    if (hlen < 2)
                    throwSSL(NetException::Error, "ClientKeyExchange body too short");
                        throwSSL(NetException::Error, "RSA CKE body too short");

                    size_t off = 4;

                    uint16_t encLen = (uint16_t(msg[off]) << 8) | msg[off + 1];
                    off += 2;

                    if (off + encLen > msg.size())
                    throwSSL(NetException::Error, "ClientKeyExchange truncated ciphertext");
                        throwSSL(NetException::Error, "RSA CKE truncated ciphertext");

                    const size_t kBytes = (selected_cert_bundle->rsa_key.n.bitLength() + 7) / 8;
                    if (encLen != kBytes) {
@@ -1066,47 +1162,62 @@ namespace netplus {
                            " kBytes=" + std::to_string(kBytes));
                    }

    #if SSL_DEBUG
    #endif

                    rsa::bigInt cipher = rsa::bigIntFromBytesBE(msg.data() + off, encLen);
                    rsa::bigInt plainBI = selected_cert_bundle->rsa_key.decrypt(cipher);

    #if SSL_DEBUG
    #endif

                    std::vector<uint8_t> pkcs1 = rsa::bigIntToBytesBE(plainBI, kBytes);
                std::vector<uint8_t> preMaster = extractPreMasterFromPkcs1(pkcs1);

    #if SSL_DEBUG
    #endif
                    preMaster = extractPreMasterFromPkcs1(pkcs1);
                }

                // master_secret = PRF(PMS, "master secret", client_random || server_random)
                std::vector<uint8_t> msSeed = clientRandom;
                msSeed.insert(msSeed.end(), serverRandom.begin(), serverRandom.end());
                masterSecret = prf(preMaster, "master secret", msSeed, 48);

    #if SSL_DEBUG
    #endif

                // key_block = PRF(master, "key expansion", server_random || client_random)
                std::vector<uint8_t> kbSeed = serverRandom;
                kbSeed.insert(kbSeed.end(), clientRandom.begin(), clientRandom.end());
                std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, 72);

    #if SSL_DEBUG
    #endif
                if (tls12_is_gcm) {
                    // GCM key block: client_write_key + server_write_key + client_write_IV(4) + server_write_IV(4)
                    size_t keyLen = (chosenSuite == 0xC02C || chosenSuite == 0xC030) ? 32 : 16;
                    size_t kbLen = keyLen + keyLen + 4 + 4;
                    std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen);

                    size_t k = 0;
                    std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                    std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                    std::memcpy(tls12_client_write_iv, keyBlock.data() + k, 4); k += 4;
                    std::memcpy(tls12_server_write_iv, keyBlock.data() + k, 4);

                    if (keyLen == 32) {
                        aes_recv = std::make_unique<aes256>(clientKey);
                        aes      = std::make_unique<aes256>(serverKey);
                    } else {
                        aes_recv = std::make_unique<aes128>(clientKey);
                        aes      = std::make_unique<aes128>(serverKey);
                    }
                } else {
                    // CBC key block: client_mac + server_mac + client_key + server_key
                    size_t keyLen = (chosenSuite == 0x0035) ? 32 : 16;
                    size_t kbLen = 20 + 20 + keyLen + keyLen;
                    std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen);

                    size_t k = 0;
                    client_mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20;
                    mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20);        k += 20;

                std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); k += 16;
                std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + 16);
                    std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen;
                    std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen);

                    if (keyLen == 32) {
                        aes_recv = std::make_unique<aes256>(clientKey);
                        aes      = std::make_unique<aes256>(serverKey);
                    } else {
                        aes_recv = std::make_unique<aes128>(clientKey);
                        aes      = std::make_unique<aes128>(serverKey);

                    }
                }

                hs_state = HsState::WAIT_CCS;
                return;
@@ -2238,7 +2349,7 @@ namespace netplus {
                    std::vector<uint8_t> finFrag(rec.begin() + 5, rec.end());

                    // Decrypt the record
                    std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag);
                    std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag);

                    if (finPT.size() < 4 + 12 ||
                        finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) {
@@ -2368,7 +2479,7 @@ namespace netplus {
                if (!aes_recv)
                    throwSSL(NetException::Error, "aes_recv missing");

                std::vector<uint8_t> plain = decryptRecordCBC(outer_type, ver, frag);
                std::vector<uint8_t> plain = decryptTLS12Record(outer_type, ver, frag);

                if (plain.empty())
                    continue;
@@ -3166,6 +3277,8 @@ namespace netplus {
    // sendTLS12Record - Encrypt and send TLS 1.2 record
    // ============================================================================
    void tls::sendTLS12Record(uint8_t type, const std::vector<uint8_t>& plain) {
        if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain); return; }

        const uint16_t ver = 0x0303;
        const size_t macLen = 20;
        const size_t blockSize = 16;
@@ -3236,6 +3349,8 @@ namespace netplus {
    }

    void tls::sendTLS12Record(uint8_t type, const uint8_t* plain_data, size_t plain_len) {
        if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain_data, plain_len); return; }

        const uint16_t ver = 0x0303;
        const size_t macLen = 20;
        const size_t blockSize = 16;
@@ -3304,6 +3419,122 @@ namespace netplus {
        send_seq++;
    }

    // ============================================================================
    // decryptRecordGCM - Decrypt TLS 1.2 GCM record
    // ============================================================================
    std::vector<uint8_t> tls::decryptRecordGCM(
        uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload) 
    {
        if (!aes_recv) {
            NetException e;
            e[NetException::Error] << "tls: decryptRecordGCM called but aes_recv is NULL";
            throw e;
        }

        // GCM record payload: explicit_nonce(8) + ciphertext + tag(16)
        if (payload.size() < 8 + 16)
            throwSSL(NetException::Error, "GCM record too short");

        const uint8_t* explicit_nonce = payload.data();
        size_t ct_len = payload.size() - 8 - 16;
        const uint8_t* ct = payload.data() + 8;
        const uint8_t* tag = payload.data() + 8 + ct_len;

        // Nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes
        uint8_t nonce[12];
        const uint8_t* recv_iv = is_client ? tls12_server_write_iv : tls12_client_write_iv;
        std::memcpy(nonce, recv_iv, 4);
        std::memcpy(nonce + 4, explicit_nonce, 8);

        // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes
        uint8_t aad[13];
        seqToBytes(recv_seq, aad);
        aad[8] = type;
        aad[9] = uint8_t(ver >> 8);
        aad[10] = uint8_t(ver & 0xFF);
        aad[11] = uint8_t((ct_len >> 8) & 0xFF);
        aad[12] = uint8_t(ct_len & 0xFF);

        std::vector<uint8_t> plaintext(ct_len);
        uint8_t tag_copy[16];
        std::memcpy(tag_copy, tag, 16);

        if (!aes_recv->aes_gcm_decrypt(nonce, aad, 13, ct, ct_len, tag_copy, plaintext.data()))
            throwSSL(NetException::Error, "GCM tag verification failed");

        recv_seq++;
        return plaintext;
    }

    // ============================================================================
    // decryptTLS12Record - Dispatch to CBC or GCM decrypt
    // ============================================================================
    std::vector<uint8_t> tls::decryptTLS12Record(
        uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload)
    {
        if (tls12_is_gcm)
            return decryptRecordGCM(type, ver, payload);
        return decryptRecordCBC(type, ver, payload);
    }

    // ============================================================================
    // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (vector overload)
    // ============================================================================
    void tls::sendTLS12RecordGCM(uint8_t type, const std::vector<uint8_t>& plain) {
        sendTLS12RecordGCM(type, plain.data(), plain.size());
    }

    // ============================================================================
    // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (pointer overload)
    // ============================================================================
    void tls::sendTLS12RecordGCM(uint8_t type, const uint8_t* plain_data, size_t plain_len) {
        const uint16_t ver = 0x0303;

        if (!aes) {
            throwSSL(NetException::Error, "sendTLS12RecordGCM: AES cipher not initialized");
        }

        // Build nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes
        // Use send_seq as explicit nonce
        uint8_t explicit_nonce[8];
        seqToBytes(send_seq, explicit_nonce);

        uint8_t nonce[12];
        const uint8_t* send_iv = is_client ? tls12_client_write_iv : tls12_server_write_iv;
        std::memcpy(nonce, send_iv, 4);
        std::memcpy(nonce + 4, explicit_nonce, 8);

        // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes
        uint8_t aad[13];
        seqToBytes(send_seq, aad);
        aad[8] = type;
        aad[9] = uint8_t(ver >> 8);
        aad[10] = uint8_t(ver & 0xFF);
        aad[11] = uint8_t((plain_len >> 8) & 0xFF);
        aad[12] = uint8_t(plain_len & 0xFF);

        // Encrypt
        std::vector<uint8_t> ct(plain_len);
        uint8_t tag[16];
        aes->aes_gcm_encrypt(nonce, aad, 13, plain_data, plain_len, ct.data(), tag);

        // Record payload: explicit_nonce(8) + ciphertext + tag(16)
        size_t payload_len = 8 + plain_len + 16;
        uint8_t recHdr[5];
        recHdr[0] = type;
        recHdr[1] = uint8_t(ver >> 8);
        recHdr[2] = uint8_t(ver & 0xFF);
        recHdr[3] = uint8_t((payload_len >> 8) & 0xFF);
        recHdr[4] = uint8_t(payload_len & 0xFF);

        queueRaw(recHdr, 5);
        queueRaw(explicit_nonce, 8);
        queueRaw(ct.data(), ct.size());
        queueRaw(tag, 16);

        send_seq++;
    }

    // ============================================================================
    // sendTLS13Record - Encrypt and send TLS 1.3 application data record
    // ============================================================================
@@ -3568,7 +3799,7 @@ namespace netplus {

            if (ccs_received && aes_recv) {
                uint16_t ver = (uint16_t(rec[1]) << 8) | uint16_t(rec[2]);
                fragment = decryptRecordCBC(recType, ver, fragment);
                fragment = decryptTLS12Record(recType, ver, fragment);
            }

            rx_handshake_buf.insert(rx_handshake_buf.end(), 
@@ -3613,8 +3844,10 @@ namespace netplus {
            throw e;
        };

        // Sanity: allow only TLS1.2 suites here
        if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035) {
        // Sanity: allow TLS1.2 suites (RSA + ECDHE-GCM)
        if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035 &&
            chosenSuiteArg != 0xC02B && chosenSuiteArg != 0xC02C &&
            chosenSuiteArg != 0xC02F && chosenSuiteArg != 0xC030) {
            throwSSL(netplus::NetException::Error, "Invalid TLS1.2 cipher suite selected");
        }

+29 −0
Original line number Diff line number Diff line
@@ -280,6 +280,18 @@ namespace netplus {
            const std::vector<uint8_t>& frag
        );

        std::vector<uint8_t> decryptRecordGCM(
            uint8_t recType,
            uint16_t ver,
            const std::vector<uint8_t>& frag
        );

        std::vector<uint8_t> decryptTLS12Record(
            uint8_t recType,
            uint16_t ver,
            const std::vector<uint8_t>& frag
        );

        void sendTLS12Record(
            uint8_t recType,
            const std::vector<uint8_t>& content
@@ -289,6 +301,15 @@ namespace netplus {
            const uint8_t* data, size_t len
        );

        void sendTLS12RecordGCM(
            uint8_t recType,
            const std::vector<uint8_t>& content
        );
        void sendTLS12RecordGCM(
            uint8_t recType,
            const uint8_t* data, size_t len
        );

        // --- TLS 1.2 Handshake Operations ---
        bool popHandshakeMsg(
            std::vector<uint8_t>& out,
@@ -453,6 +474,14 @@ namespace netplus {
        std::vector<uint8_t> mac_key;
        std::vector<uint8_t> client_mac_key;

        // --- TLS 1.2 GCM implicit IVs (4 bytes each) ---
        uint8_t tls12_client_write_iv[4] = {0};
        uint8_t tls12_server_write_iv[4] = {0};
        bool tls12_is_gcm = false;

        // --- TLS 1.2 ECDHE state ---
        uint8_t tls12_ecdhe_priv[32] = {0};  // Server ephemeral P-256 private key

        // --- TLS 1.2 Record Layer Encryption/Decryption ---
        std::unique_ptr<netplus::aes> aes = nullptr;
        std::unique_ptr<netplus::aes> aes_recv = nullptr;