Commit 80f7bb8f authored by jan.koester's avatar jan.koester
Browse files

test

parent 36761cfe
Loading
Loading
Loading
Loading
+135 −9
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@

#include "tls.h"
#include "sha.h"
#include "base64.h"
#include "rsa_pss_sha256.h"
#include "ecc_p256.h"
#include "ecc_u256.h"
@@ -4020,10 +4021,16 @@ namespace netplus {
            try {
                selected_cert_bundle->rsa_key = netplus::rsa(selected_cert_bundle->privateKeyDer);
            } catch (const std::exception& e) {
                // Warning: Failed to load RSA key
                // Not RSA — check for EC key in bundle
            }
        }

        // Load EC key from bundle if available
        if (selected_cert_bundle && selected_cert_bundle->has_ec_key && !has_ec_key) {
            std::memcpy(ec_priv, selected_cert_bundle->ecPrivateKey, 32);
            has_ec_key = true;
        }
        
        if (selected_cert_bundle && selected_cert_bundle->rsa_key) {
            // For TLS 1.3, we MUST use RSA-PSS-RSAE-SHA256 (0x0804)
            sig = sign_rsa_pss_sha256(toSign, selected_cert_bundle->rsa_key);
@@ -4392,6 +4399,129 @@ namespace netplus {
        tls13_send_handshake(0x14, verify_data, handshake_keys);
    }

    // Helper: extract EC P-256 private scalar (32 bytes) from DER key data.
    // Handles both PKCS#8 (-----BEGIN PRIVATE KEY-----) and SEC1
    // (-----BEGIN EC PRIVATE KEY-----) formats.
    static bool extractECPrivateKey(const std::vector<uint8_t>& der,
                                     uint8_t outKey[32]) {
        if (der.empty() || der[0] != 0x30) return false;

        auto readLen = [](const uint8_t* d, size_t max, size_t& len, size_t& hdr) -> bool {
            if (max < 1) return false;
            len = d[0]; hdr = 1;
            if (len & 0x80) {
                int n = len & 0x7f;
                if (n <= 0 || n > 4 || max < 1u + n) return false;
                len = 0;
                for (int i = 0; i < n; i++) len = (len << 8) | d[1 + i];
                hdr = 1 + n;
            }
            return true;
        };

        const uint8_t* p = der.data();
        size_t sz = der.size();
        size_t pos = 1; // skip outer SEQUENCE tag

        size_t outerLen, hdr;
        if (!readLen(p + pos, sz - pos, outerLen, hdr)) return false;
        pos += hdr;

        // PKCS#8: version INTEGER, then algorithm SEQUENCE
        if (p[pos] == 0x02) { // INTEGER (version)
            // Skip version
            pos++;
            size_t vlen; if (!readLen(p + pos, sz - pos, vlen, hdr)) return false;
            pos += hdr + vlen;

            // Algorithm SEQUENCE
            if (pos >= sz || p[pos] != 0x30) return false;
            pos++;
            size_t algLen; if (!readLen(p + pos, sz - pos, algLen, hdr)) return false;
            pos += hdr;

            // Check for EC OID (1.2.840.10045.2.1) = 06 07 2A 86 48 CE 3D 02 01
            static const uint8_t EC_OID[] = {0x06, 0x07, 0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01};
            if (pos + sizeof(EC_OID) > sz) return false;
            if (std::memcmp(p + pos, EC_OID, sizeof(EC_OID)) != 0) return false;
            pos += algLen; // skip entire algorithm sequence content

            // OCTET STRING wrapping the inner EC key
            if (pos >= sz || p[pos] != 0x04) return false;
            pos++;
            size_t octLen; if (!readLen(p + pos, sz - pos, octLen, hdr)) return false;
            pos += hdr;

            // Now at inner ECPrivateKey SEQUENCE (SEC1 format)
            if (pos >= sz || p[pos] != 0x30) return false;
        }

        // SEC1 ECPrivateKey: SEQUENCE { INTEGER 1, OCTET STRING(32), ... }
        if (p[pos] != 0x30) return false;
        pos++;
        size_t innerLen; if (!readLen(p + pos, sz - pos, innerLen, hdr)) return false;
        pos += hdr;

        // version INTEGER = 1
        if (pos >= sz || p[pos] != 0x02) return false;
        pos++;
        size_t vlen2; if (!readLen(p + pos, sz - pos, vlen2, hdr)) return false;
        pos += hdr + vlen2;

        // OCTET STRING containing the 32-byte private scalar
        if (pos >= sz || p[pos] != 0x04) return false;
        pos++;
        size_t keyLen; if (!readLen(p + pos, sz - pos, keyLen, hdr)) return false;
        pos += hdr;

        if (keyLen != 32 || pos + 32 > sz) return false;
        std::memcpy(outKey, p + pos, 32);
        return true;
    }

    // Helper: try RSA first, fall back to EC P-256 key extraction.
    static void loadPrivateKey(const std::vector<uint8_t>& keyData,
                                tls::CertificateBundle& bundle) {
        bundle.privateKeyDer = keyData;
        try {
            bundle.rsa_key = netplus::rsa(keyData);
        } catch (...) {
            // Not RSA — try EC P-256
            // If keyData is PEM, pemKeyToDer already produced DER inside rsa ctor;
            // but we need DER for EC extraction too.  Re-decode if PEM.
            std::vector<uint8_t> der;
            if (keyData.size() > 11 &&
                std::memcmp(keyData.data(), "-----BEGIN ", 11) == 0) {
                // Reuse the same PEM-stripping logic as rsa.cpp's pemKeyToDer
                std::string pem(keyData.begin(), keyData.end());
                auto b = pem.find("-----BEGIN ");
                auto bEnd = pem.find("-----", b + 11);
                if (bEnd != std::string::npos) {
                    bEnd += 5;
                    auto e = pem.find("-----END ", bEnd);
                    if (e != std::string::npos) {
                        std::string body = pem.substr(bEnd, e - bEnd);
                        std::string b64;
                        b64.reserve(body.size());
                        for (unsigned char c : body) {
                            if (std::isalnum(c) || c == '+' || c == '/' || c == '=')
                                b64.push_back(static_cast<char>(c));
                        }
                        netplus::base64 dec;
                        dec << b64;
                        std::vector<char> decoded;
                        dec >> decoded;
                        der.assign(decoded.begin(), decoded.end());
                    }
                }
            } else {
                der = keyData;
            }
            if (extractECPrivateKey(der, bundle.ecPrivateKey))
                bundle.has_ec_key = true;
        }
    }

    // ---- CertificateBundle::loadFromFile ----
    // Auto-detect format:
    //   Inline PEM  → certPath/keyPath start with "-----BEGIN"
@@ -4435,8 +4565,7 @@ namespace netplus {
            // Key: use keyPath if given, otherwise look for key in certPath
            const std::string& keySource = keyPath.empty() ? certPath : keyPath;
            std::vector<uint8_t> keyData(keySource.begin(), keySource.end());
            privateKeyDer = keyData;
            rsa_key = netplus::rsa(keyData);  // auto-detects PEM vs DER
            loadPrivateKey(keyData, *this);
            return true;
        }

@@ -4464,8 +4593,7 @@ namespace netplus {
            chain = std::move(p12.chainDer);

            if (!p12.keyDer.empty()) {
                privateKeyDer = p12.keyDer;
                rsa_key = netplus::rsa(p12.keyDer);
                loadPrivateKey(p12.keyDer, *this);
            }
            return true;
        }
@@ -4488,8 +4616,7 @@ namespace netplus {
        if (kp.size() > 11 &&
            kp.compare(0, 11, "-----BEGIN ") == 0) {
            std::vector<uint8_t> keyData(kp.begin(), kp.end());
            privateKeyDer = keyData;
            rsa_key = netplus::rsa(keyData);
            loadPrivateKey(keyData, *this);
            return true;
        }

@@ -4504,8 +4631,7 @@ namespace netplus {
        if (!keyFile.read(reinterpret_cast<char*>(keyData.data()), ksize))
            return false;

        privateKeyDer = keyData;
        rsa_key = netplus::rsa(keyData);  // auto-detects PEM vs DER
        loadPrivateKey(keyData, *this);
        return true;
    }

+2 −0
Original line number Diff line number Diff line
@@ -113,6 +113,8 @@ namespace netplus {
            netplus::x509cert    cert;
            std::vector<uint8_t> privateKeyDer;
            netplus::rsa         rsa_key;  // Pre-loaded RSA private key
            uint8_t              ecPrivateKey[32] = {0};  // EC P-256 private scalar
            bool                 has_ec_key = false;
            std::vector<std::vector<uint8_t>> chain;  // CA chain certs in DER (per-SNI)

            // Load cert+key from file(s) or inline PEM data.