Loading src/crypto/tls.cpp +283 −50 Original line number Diff line number Diff line Loading @@ -476,8 +476,22 @@ namespace netplus { handshakeStarted = true; is_tls13 = false; chosenSuite = 0; if (hasSuite(0x002F)) chosenSuite = 0x002F; else if (hasSuite(0x0035)) chosenSuite = 0x0035; // Prefer ECDHE-GCM suites (work with both RSA and ECDSA certs) bool have_ec_cert = (selected_cert_bundle && selected_cert_bundle->has_ec_key); bool have_rsa_cert = (selected_cert_bundle && selected_cert_bundle->rsa_key); if (have_ec_cert) { // ECDSA certificate: only ECDHE_ECDSA suites if (hasSuite(0xC02C)) chosenSuite = 0xC02C; // ECDHE_ECDSA_AES_256_GCM_SHA384 else if (hasSuite(0xC02B)) chosenSuite = 0xC02B; // ECDHE_ECDSA_AES_128_GCM_SHA256 } else if (have_rsa_cert) { // RSA certificate: prefer ECDHE_RSA, fall back to static RSA if (hasSuite(0xC030)) chosenSuite = 0xC030; // ECDHE_RSA_AES_256_GCM_SHA384 else if (hasSuite(0xC02F)) chosenSuite = 0xC02F; // ECDHE_RSA_AES_128_GCM_SHA256 else if (hasSuite(0x002F)) chosenSuite = 0x002F; // RSA_AES_128_CBC_SHA else if (hasSuite(0x0035)) chosenSuite = 0x0035; // RSA_AES_256_CBC_SHA } if (chosenSuite == 0) throwSSL(NetException::Error, "No supported cipher suite"); Loading Loading @@ -857,6 +871,61 @@ namespace netplus { } sendHandshake(0x0b, certMsg); // --- ServerKeyExchange for ECDHE suites --- bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C || chosenSuite == 0xC02F || chosenSuite == 0xC030); if (isECDHE) { // Generate ephemeral ECDHE P-256 key pair gen_tls13_p256_scalar(tls12_ecdhe_priv); netplus::P256Point pub = netplus::scalar_mul_G(tls12_ecdhe_priv); std::vector<uint8_t> pubBytes = netplus::encode_tls_point(pub); // Build ServerKeyExchange params (ECParameters + ECPoint) // ECParameters: curve_type(1) = named_curve(3), named_curve(2) = secp256r1(0x0017) std::vector<uint8_t> params; params.push_back(0x03); // named_curve params.push_back(0x00); params.push_back(0x17); // secp256r1 params.push_back((uint8_t)pubBytes.size()); // point length params.insert(params.end(), pubBytes.begin(), pubBytes.end()); // Sign: SHA-256(client_random + server_random + params) std::vector<uint8_t> toSign; toSign.insert(toSign.end(), clientRandom.begin(), clientRandom.end()); toSign.insert(toSign.end(), serverRandom.begin(), serverRandom.end()); toSign.insert(toSign.end(), params.begin(), params.end()); std::vector<uint8_t> sig; uint16_t sigAlg; if (chosenSuite == 0xC02B || chosenSuite == 0xC02C) { // ECDHE_ECDSA: sign with ECDSA if (!selected_cert_bundle->has_ec_key) throwSSL(NetException::Error, "ECDHE_ECDSA requires EC key"); std::vector<uint8_t> ecKey(selected_cert_bundle->ecPrivateKey, selected_cert_bundle->ecPrivateKey + 32); sig = sign_ecdsa_sha256(toSign, ecKey); sigAlg = 0x0403; // ecdsa_secp256r1_sha256 } else { // ECDHE_RSA: sign with RSA PKCS#1 v1.5 if (!selected_cert_bundle->rsa_key) throwSSL(NetException::Error, "ECDHE_RSA requires RSA key"); sig = sign_rsa_sha256_pkcs15(toSign, selected_cert_bundle->rsa_key); sigAlg = 0x0401; // rsa_pkcs1_sha256 } // Build ServerKeyExchange body: params + sig_algorithm(2) + sig_len(2) + sig std::vector<uint8_t> skeBody; skeBody.insert(skeBody.end(), params.begin(), params.end()); skeBody.push_back((sigAlg >> 8) & 0xFF); skeBody.push_back(sigAlg & 0xFF); skeBody.push_back((uint16_t(sig.size()) >> 8) & 0xFF); skeBody.push_back(uint16_t(sig.size()) & 0xFF); skeBody.insert(skeBody.end(), sig.begin(), sig.end()); sendHandshake(0x0c, skeBody); // ServerKeyExchange } sendHandshake(0x0e, {}); serverFlightQueued = true; Loading Loading @@ -938,7 +1007,7 @@ namespace netplus { uint16_t finVer = ((uint16_t)finRec[1] << 8) | (uint16_t)finRec[2]; std::vector<uint8_t> finFrag(finRec.begin() + 5, finRec.end()); std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag); std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag); if (finPT.size() < 4 + 12 || finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) { Loading Loading @@ -1047,17 +1116,44 @@ namespace netplus { if (msg.size() != 4 + hlen) throwSSL(NetException::Error, "ClientKeyExchange length mismatch"); std::vector<uint8_t> preMaster; bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C || chosenSuite == 0xC02F || chosenSuite == 0xC030); if (isECDHE) { // ECDHE CKE: body = point_length(1) + uncompressed_point(65) size_t off = 4; if (hlen < 1) throwSSL(NetException::Error, "ECDHE CKE body too short"); uint8_t ptLen = msg[off]; off++; if (ptLen != 65 || off + ptLen > msg.size()) throwSSL(NetException::Error, "ECDHE CKE invalid point length"); if (msg[off] != 0x04) throwSSL(NetException::Error, "ECDHE CKE point not uncompressed"); netplus::P256Point clientPub; if (!netplus::decode_tls_point(clientPub, msg.data() + off, 65)) throwSSL(NetException::Error, "ECDHE CKE invalid EC point"); if (!netplus::is_on_curve(clientPub)) throwSSL(NetException::Error, "ECDHE CKE point not on curve"); // Compute shared secret preMaster.resize(32); if (!netplus::ecdh_shared_secret(preMaster.data(), tls12_ecdhe_priv, clientPub)) throwSSL(NetException::Error, "ECDHE shared secret computation failed"); tls12_is_gcm = true; } else { // RSA CKE body: uint16 encLen + ciphertext if (hlen < 2) throwSSL(NetException::Error, "ClientKeyExchange body too short"); throwSSL(NetException::Error, "RSA CKE body too short"); size_t off = 4; uint16_t encLen = (uint16_t(msg[off]) << 8) | msg[off + 1]; off += 2; if (off + encLen > msg.size()) throwSSL(NetException::Error, "ClientKeyExchange truncated ciphertext"); throwSSL(NetException::Error, "RSA CKE truncated ciphertext"); const size_t kBytes = (selected_cert_bundle->rsa_key.n.bitLength() + 7) / 8; if (encLen != kBytes) { Loading @@ -1066,47 +1162,62 @@ namespace netplus { " kBytes=" + std::to_string(kBytes)); } #if SSL_DEBUG #endif rsa::bigInt cipher = rsa::bigIntFromBytesBE(msg.data() + off, encLen); rsa::bigInt plainBI = selected_cert_bundle->rsa_key.decrypt(cipher); #if SSL_DEBUG #endif std::vector<uint8_t> pkcs1 = rsa::bigIntToBytesBE(plainBI, kBytes); std::vector<uint8_t> preMaster = extractPreMasterFromPkcs1(pkcs1); #if SSL_DEBUG #endif preMaster = extractPreMasterFromPkcs1(pkcs1); } // master_secret = PRF(PMS, "master secret", client_random || server_random) std::vector<uint8_t> msSeed = clientRandom; msSeed.insert(msSeed.end(), serverRandom.begin(), serverRandom.end()); masterSecret = prf(preMaster, "master secret", msSeed, 48); #if SSL_DEBUG #endif // key_block = PRF(master, "key expansion", server_random || client_random) std::vector<uint8_t> kbSeed = serverRandom; kbSeed.insert(kbSeed.end(), clientRandom.begin(), clientRandom.end()); std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, 72); #if SSL_DEBUG #endif if (tls12_is_gcm) { // GCM key block: client_write_key + server_write_key + client_write_IV(4) + server_write_IV(4) size_t keyLen = (chosenSuite == 0xC02C || chosenSuite == 0xC030) ? 32 : 16; size_t kbLen = keyLen + keyLen + 4 + 4; std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen); size_t k = 0; std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::memcpy(tls12_client_write_iv, keyBlock.data() + k, 4); k += 4; std::memcpy(tls12_server_write_iv, keyBlock.data() + k, 4); if (keyLen == 32) { aes_recv = std::make_unique<aes256>(clientKey); aes = std::make_unique<aes256>(serverKey); } else { aes_recv = std::make_unique<aes128>(clientKey); aes = std::make_unique<aes128>(serverKey); } } else { // CBC key block: client_mac + server_mac + client_key + server_key size_t keyLen = (chosenSuite == 0x0035) ? 32 : 16; size_t kbLen = 20 + 20 + keyLen + keyLen; std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen); size_t k = 0; client_mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20; mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20; std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); k += 16; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); if (keyLen == 32) { aes_recv = std::make_unique<aes256>(clientKey); aes = std::make_unique<aes256>(serverKey); } else { aes_recv = std::make_unique<aes128>(clientKey); aes = std::make_unique<aes128>(serverKey); } } hs_state = HsState::WAIT_CCS; return; Loading Loading @@ -2238,7 +2349,7 @@ namespace netplus { std::vector<uint8_t> finFrag(rec.begin() + 5, rec.end()); // Decrypt the record std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag); std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag); if (finPT.size() < 4 + 12 || finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) { Loading Loading @@ -2368,7 +2479,7 @@ namespace netplus { if (!aes_recv) throwSSL(NetException::Error, "aes_recv missing"); std::vector<uint8_t> plain = decryptRecordCBC(outer_type, ver, frag); std::vector<uint8_t> plain = decryptTLS12Record(outer_type, ver, frag); if (plain.empty()) continue; Loading Loading @@ -3166,6 +3277,8 @@ namespace netplus { // sendTLS12Record - Encrypt and send TLS 1.2 record // ============================================================================ void tls::sendTLS12Record(uint8_t type, const std::vector<uint8_t>& plain) { if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain); return; } const uint16_t ver = 0x0303; const size_t macLen = 20; const size_t blockSize = 16; Loading Loading @@ -3236,6 +3349,8 @@ namespace netplus { } void tls::sendTLS12Record(uint8_t type, const uint8_t* plain_data, size_t plain_len) { if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain_data, plain_len); return; } const uint16_t ver = 0x0303; const size_t macLen = 20; const size_t blockSize = 16; Loading Loading @@ -3304,6 +3419,122 @@ namespace netplus { send_seq++; } // ============================================================================ // decryptRecordGCM - Decrypt TLS 1.2 GCM record // ============================================================================ std::vector<uint8_t> tls::decryptRecordGCM( uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload) { if (!aes_recv) { NetException e; e[NetException::Error] << "tls: decryptRecordGCM called but aes_recv is NULL"; throw e; } // GCM record payload: explicit_nonce(8) + ciphertext + tag(16) if (payload.size() < 8 + 16) throwSSL(NetException::Error, "GCM record too short"); const uint8_t* explicit_nonce = payload.data(); size_t ct_len = payload.size() - 8 - 16; const uint8_t* ct = payload.data() + 8; const uint8_t* tag = payload.data() + 8 + ct_len; // Nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes uint8_t nonce[12]; const uint8_t* recv_iv = is_client ? tls12_server_write_iv : tls12_client_write_iv; std::memcpy(nonce, recv_iv, 4); std::memcpy(nonce + 4, explicit_nonce, 8); // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes uint8_t aad[13]; seqToBytes(recv_seq, aad); aad[8] = type; aad[9] = uint8_t(ver >> 8); aad[10] = uint8_t(ver & 0xFF); aad[11] = uint8_t((ct_len >> 8) & 0xFF); aad[12] = uint8_t(ct_len & 0xFF); std::vector<uint8_t> plaintext(ct_len); uint8_t tag_copy[16]; std::memcpy(tag_copy, tag, 16); if (!aes_recv->aes_gcm_decrypt(nonce, aad, 13, ct, ct_len, tag_copy, plaintext.data())) throwSSL(NetException::Error, "GCM tag verification failed"); recv_seq++; return plaintext; } // ============================================================================ // decryptTLS12Record - Dispatch to CBC or GCM decrypt // ============================================================================ std::vector<uint8_t> tls::decryptTLS12Record( uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload) { if (tls12_is_gcm) return decryptRecordGCM(type, ver, payload); return decryptRecordCBC(type, ver, payload); } // ============================================================================ // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (vector overload) // ============================================================================ void tls::sendTLS12RecordGCM(uint8_t type, const std::vector<uint8_t>& plain) { sendTLS12RecordGCM(type, plain.data(), plain.size()); } // ============================================================================ // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (pointer overload) // ============================================================================ void tls::sendTLS12RecordGCM(uint8_t type, const uint8_t* plain_data, size_t plain_len) { const uint16_t ver = 0x0303; if (!aes) { throwSSL(NetException::Error, "sendTLS12RecordGCM: AES cipher not initialized"); } // Build nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes // Use send_seq as explicit nonce uint8_t explicit_nonce[8]; seqToBytes(send_seq, explicit_nonce); uint8_t nonce[12]; const uint8_t* send_iv = is_client ? tls12_client_write_iv : tls12_server_write_iv; std::memcpy(nonce, send_iv, 4); std::memcpy(nonce + 4, explicit_nonce, 8); // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes uint8_t aad[13]; seqToBytes(send_seq, aad); aad[8] = type; aad[9] = uint8_t(ver >> 8); aad[10] = uint8_t(ver & 0xFF); aad[11] = uint8_t((plain_len >> 8) & 0xFF); aad[12] = uint8_t(plain_len & 0xFF); // Encrypt std::vector<uint8_t> ct(plain_len); uint8_t tag[16]; aes->aes_gcm_encrypt(nonce, aad, 13, plain_data, plain_len, ct.data(), tag); // Record payload: explicit_nonce(8) + ciphertext + tag(16) size_t payload_len = 8 + plain_len + 16; uint8_t recHdr[5]; recHdr[0] = type; recHdr[1] = uint8_t(ver >> 8); recHdr[2] = uint8_t(ver & 0xFF); recHdr[3] = uint8_t((payload_len >> 8) & 0xFF); recHdr[4] = uint8_t(payload_len & 0xFF); queueRaw(recHdr, 5); queueRaw(explicit_nonce, 8); queueRaw(ct.data(), ct.size()); queueRaw(tag, 16); send_seq++; } // ============================================================================ // sendTLS13Record - Encrypt and send TLS 1.3 application data record // ============================================================================ Loading Loading @@ -3568,7 +3799,7 @@ namespace netplus { if (ccs_received && aes_recv) { uint16_t ver = (uint16_t(rec[1]) << 8) | uint16_t(rec[2]); fragment = decryptRecordCBC(recType, ver, fragment); fragment = decryptTLS12Record(recType, ver, fragment); } rx_handshake_buf.insert(rx_handshake_buf.end(), Loading Loading @@ -3613,8 +3844,10 @@ namespace netplus { throw e; }; // Sanity: allow only TLS1.2 suites here if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035) { // Sanity: allow TLS1.2 suites (RSA + ECDHE-GCM) if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035 && chosenSuiteArg != 0xC02B && chosenSuiteArg != 0xC02C && chosenSuiteArg != 0xC02F && chosenSuiteArg != 0xC030) { throwSSL(netplus::NetException::Error, "Invalid TLS1.2 cipher suite selected"); } Loading src/crypto/tls.h +29 −0 Original line number Diff line number Diff line Loading @@ -280,6 +280,18 @@ namespace netplus { const std::vector<uint8_t>& frag ); std::vector<uint8_t> decryptRecordGCM( uint8_t recType, uint16_t ver, const std::vector<uint8_t>& frag ); std::vector<uint8_t> decryptTLS12Record( uint8_t recType, uint16_t ver, const std::vector<uint8_t>& frag ); void sendTLS12Record( uint8_t recType, const std::vector<uint8_t>& content Loading @@ -289,6 +301,15 @@ namespace netplus { const uint8_t* data, size_t len ); void sendTLS12RecordGCM( uint8_t recType, const std::vector<uint8_t>& content ); void sendTLS12RecordGCM( uint8_t recType, const uint8_t* data, size_t len ); // --- TLS 1.2 Handshake Operations --- bool popHandshakeMsg( std::vector<uint8_t>& out, Loading Loading @@ -453,6 +474,14 @@ namespace netplus { std::vector<uint8_t> mac_key; std::vector<uint8_t> client_mac_key; // --- TLS 1.2 GCM implicit IVs (4 bytes each) --- uint8_t tls12_client_write_iv[4] = {0}; uint8_t tls12_server_write_iv[4] = {0}; bool tls12_is_gcm = false; // --- TLS 1.2 ECDHE state --- uint8_t tls12_ecdhe_priv[32] = {0}; // Server ephemeral P-256 private key // --- TLS 1.2 Record Layer Encryption/Decryption --- std::unique_ptr<netplus::aes> aes = nullptr; std::unique_ptr<netplus::aes> aes_recv = nullptr; Loading Loading
src/crypto/tls.cpp +283 −50 Original line number Diff line number Diff line Loading @@ -476,8 +476,22 @@ namespace netplus { handshakeStarted = true; is_tls13 = false; chosenSuite = 0; if (hasSuite(0x002F)) chosenSuite = 0x002F; else if (hasSuite(0x0035)) chosenSuite = 0x0035; // Prefer ECDHE-GCM suites (work with both RSA and ECDSA certs) bool have_ec_cert = (selected_cert_bundle && selected_cert_bundle->has_ec_key); bool have_rsa_cert = (selected_cert_bundle && selected_cert_bundle->rsa_key); if (have_ec_cert) { // ECDSA certificate: only ECDHE_ECDSA suites if (hasSuite(0xC02C)) chosenSuite = 0xC02C; // ECDHE_ECDSA_AES_256_GCM_SHA384 else if (hasSuite(0xC02B)) chosenSuite = 0xC02B; // ECDHE_ECDSA_AES_128_GCM_SHA256 } else if (have_rsa_cert) { // RSA certificate: prefer ECDHE_RSA, fall back to static RSA if (hasSuite(0xC030)) chosenSuite = 0xC030; // ECDHE_RSA_AES_256_GCM_SHA384 else if (hasSuite(0xC02F)) chosenSuite = 0xC02F; // ECDHE_RSA_AES_128_GCM_SHA256 else if (hasSuite(0x002F)) chosenSuite = 0x002F; // RSA_AES_128_CBC_SHA else if (hasSuite(0x0035)) chosenSuite = 0x0035; // RSA_AES_256_CBC_SHA } if (chosenSuite == 0) throwSSL(NetException::Error, "No supported cipher suite"); Loading Loading @@ -857,6 +871,61 @@ namespace netplus { } sendHandshake(0x0b, certMsg); // --- ServerKeyExchange for ECDHE suites --- bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C || chosenSuite == 0xC02F || chosenSuite == 0xC030); if (isECDHE) { // Generate ephemeral ECDHE P-256 key pair gen_tls13_p256_scalar(tls12_ecdhe_priv); netplus::P256Point pub = netplus::scalar_mul_G(tls12_ecdhe_priv); std::vector<uint8_t> pubBytes = netplus::encode_tls_point(pub); // Build ServerKeyExchange params (ECParameters + ECPoint) // ECParameters: curve_type(1) = named_curve(3), named_curve(2) = secp256r1(0x0017) std::vector<uint8_t> params; params.push_back(0x03); // named_curve params.push_back(0x00); params.push_back(0x17); // secp256r1 params.push_back((uint8_t)pubBytes.size()); // point length params.insert(params.end(), pubBytes.begin(), pubBytes.end()); // Sign: SHA-256(client_random + server_random + params) std::vector<uint8_t> toSign; toSign.insert(toSign.end(), clientRandom.begin(), clientRandom.end()); toSign.insert(toSign.end(), serverRandom.begin(), serverRandom.end()); toSign.insert(toSign.end(), params.begin(), params.end()); std::vector<uint8_t> sig; uint16_t sigAlg; if (chosenSuite == 0xC02B || chosenSuite == 0xC02C) { // ECDHE_ECDSA: sign with ECDSA if (!selected_cert_bundle->has_ec_key) throwSSL(NetException::Error, "ECDHE_ECDSA requires EC key"); std::vector<uint8_t> ecKey(selected_cert_bundle->ecPrivateKey, selected_cert_bundle->ecPrivateKey + 32); sig = sign_ecdsa_sha256(toSign, ecKey); sigAlg = 0x0403; // ecdsa_secp256r1_sha256 } else { // ECDHE_RSA: sign with RSA PKCS#1 v1.5 if (!selected_cert_bundle->rsa_key) throwSSL(NetException::Error, "ECDHE_RSA requires RSA key"); sig = sign_rsa_sha256_pkcs15(toSign, selected_cert_bundle->rsa_key); sigAlg = 0x0401; // rsa_pkcs1_sha256 } // Build ServerKeyExchange body: params + sig_algorithm(2) + sig_len(2) + sig std::vector<uint8_t> skeBody; skeBody.insert(skeBody.end(), params.begin(), params.end()); skeBody.push_back((sigAlg >> 8) & 0xFF); skeBody.push_back(sigAlg & 0xFF); skeBody.push_back((uint16_t(sig.size()) >> 8) & 0xFF); skeBody.push_back(uint16_t(sig.size()) & 0xFF); skeBody.insert(skeBody.end(), sig.begin(), sig.end()); sendHandshake(0x0c, skeBody); // ServerKeyExchange } sendHandshake(0x0e, {}); serverFlightQueued = true; Loading Loading @@ -938,7 +1007,7 @@ namespace netplus { uint16_t finVer = ((uint16_t)finRec[1] << 8) | (uint16_t)finRec[2]; std::vector<uint8_t> finFrag(finRec.begin() + 5, finRec.end()); std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag); std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag); if (finPT.size() < 4 + 12 || finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) { Loading Loading @@ -1047,17 +1116,44 @@ namespace netplus { if (msg.size() != 4 + hlen) throwSSL(NetException::Error, "ClientKeyExchange length mismatch"); std::vector<uint8_t> preMaster; bool isECDHE = (chosenSuite == 0xC02B || chosenSuite == 0xC02C || chosenSuite == 0xC02F || chosenSuite == 0xC030); if (isECDHE) { // ECDHE CKE: body = point_length(1) + uncompressed_point(65) size_t off = 4; if (hlen < 1) throwSSL(NetException::Error, "ECDHE CKE body too short"); uint8_t ptLen = msg[off]; off++; if (ptLen != 65 || off + ptLen > msg.size()) throwSSL(NetException::Error, "ECDHE CKE invalid point length"); if (msg[off] != 0x04) throwSSL(NetException::Error, "ECDHE CKE point not uncompressed"); netplus::P256Point clientPub; if (!netplus::decode_tls_point(clientPub, msg.data() + off, 65)) throwSSL(NetException::Error, "ECDHE CKE invalid EC point"); if (!netplus::is_on_curve(clientPub)) throwSSL(NetException::Error, "ECDHE CKE point not on curve"); // Compute shared secret preMaster.resize(32); if (!netplus::ecdh_shared_secret(preMaster.data(), tls12_ecdhe_priv, clientPub)) throwSSL(NetException::Error, "ECDHE shared secret computation failed"); tls12_is_gcm = true; } else { // RSA CKE body: uint16 encLen + ciphertext if (hlen < 2) throwSSL(NetException::Error, "ClientKeyExchange body too short"); throwSSL(NetException::Error, "RSA CKE body too short"); size_t off = 4; uint16_t encLen = (uint16_t(msg[off]) << 8) | msg[off + 1]; off += 2; if (off + encLen > msg.size()) throwSSL(NetException::Error, "ClientKeyExchange truncated ciphertext"); throwSSL(NetException::Error, "RSA CKE truncated ciphertext"); const size_t kBytes = (selected_cert_bundle->rsa_key.n.bitLength() + 7) / 8; if (encLen != kBytes) { Loading @@ -1066,47 +1162,62 @@ namespace netplus { " kBytes=" + std::to_string(kBytes)); } #if SSL_DEBUG #endif rsa::bigInt cipher = rsa::bigIntFromBytesBE(msg.data() + off, encLen); rsa::bigInt plainBI = selected_cert_bundle->rsa_key.decrypt(cipher); #if SSL_DEBUG #endif std::vector<uint8_t> pkcs1 = rsa::bigIntToBytesBE(plainBI, kBytes); std::vector<uint8_t> preMaster = extractPreMasterFromPkcs1(pkcs1); #if SSL_DEBUG #endif preMaster = extractPreMasterFromPkcs1(pkcs1); } // master_secret = PRF(PMS, "master secret", client_random || server_random) std::vector<uint8_t> msSeed = clientRandom; msSeed.insert(msSeed.end(), serverRandom.begin(), serverRandom.end()); masterSecret = prf(preMaster, "master secret", msSeed, 48); #if SSL_DEBUG #endif // key_block = PRF(master, "key expansion", server_random || client_random) std::vector<uint8_t> kbSeed = serverRandom; kbSeed.insert(kbSeed.end(), clientRandom.begin(), clientRandom.end()); std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, 72); #if SSL_DEBUG #endif if (tls12_is_gcm) { // GCM key block: client_write_key + server_write_key + client_write_IV(4) + server_write_IV(4) size_t keyLen = (chosenSuite == 0xC02C || chosenSuite == 0xC030) ? 32 : 16; size_t kbLen = keyLen + keyLen + 4 + 4; std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen); size_t k = 0; std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::memcpy(tls12_client_write_iv, keyBlock.data() + k, 4); k += 4; std::memcpy(tls12_server_write_iv, keyBlock.data() + k, 4); if (keyLen == 32) { aes_recv = std::make_unique<aes256>(clientKey); aes = std::make_unique<aes256>(serverKey); } else { aes_recv = std::make_unique<aes128>(clientKey); aes = std::make_unique<aes128>(serverKey); } } else { // CBC key block: client_mac + server_mac + client_key + server_key size_t keyLen = (chosenSuite == 0x0035) ? 32 : 16; size_t kbLen = 20 + 20 + keyLen + keyLen; std::vector<uint8_t> keyBlock = prf(masterSecret, "key expansion", kbSeed, kbLen); size_t k = 0; client_mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20; mac_key.assign(keyBlock.begin() + k, keyBlock.begin() + k + 20); k += 20; std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); k += 16; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + 16); std::vector<uint8_t> clientKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); k += keyLen; std::vector<uint8_t> serverKey(keyBlock.begin() + k, keyBlock.begin() + k + keyLen); if (keyLen == 32) { aes_recv = std::make_unique<aes256>(clientKey); aes = std::make_unique<aes256>(serverKey); } else { aes_recv = std::make_unique<aes128>(clientKey); aes = std::make_unique<aes128>(serverKey); } } hs_state = HsState::WAIT_CCS; return; Loading Loading @@ -2238,7 +2349,7 @@ namespace netplus { std::vector<uint8_t> finFrag(rec.begin() + 5, rec.end()); // Decrypt the record std::vector<uint8_t> finPT = decryptRecordCBC(0x16, finVer, finFrag); std::vector<uint8_t> finPT = decryptTLS12Record(0x16, finVer, finFrag); if (finPT.size() < 4 + 12 || finPT[0] != 0x14 || finPT[1] != 0x00 || finPT[2] != 0x00 || finPT[3] != 0x0c) { Loading Loading @@ -2368,7 +2479,7 @@ namespace netplus { if (!aes_recv) throwSSL(NetException::Error, "aes_recv missing"); std::vector<uint8_t> plain = decryptRecordCBC(outer_type, ver, frag); std::vector<uint8_t> plain = decryptTLS12Record(outer_type, ver, frag); if (plain.empty()) continue; Loading Loading @@ -3166,6 +3277,8 @@ namespace netplus { // sendTLS12Record - Encrypt and send TLS 1.2 record // ============================================================================ void tls::sendTLS12Record(uint8_t type, const std::vector<uint8_t>& plain) { if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain); return; } const uint16_t ver = 0x0303; const size_t macLen = 20; const size_t blockSize = 16; Loading Loading @@ -3236,6 +3349,8 @@ namespace netplus { } void tls::sendTLS12Record(uint8_t type, const uint8_t* plain_data, size_t plain_len) { if (tls12_is_gcm) { sendTLS12RecordGCM(type, plain_data, plain_len); return; } const uint16_t ver = 0x0303; const size_t macLen = 20; const size_t blockSize = 16; Loading Loading @@ -3304,6 +3419,122 @@ namespace netplus { send_seq++; } // ============================================================================ // decryptRecordGCM - Decrypt TLS 1.2 GCM record // ============================================================================ std::vector<uint8_t> tls::decryptRecordGCM( uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload) { if (!aes_recv) { NetException e; e[NetException::Error] << "tls: decryptRecordGCM called but aes_recv is NULL"; throw e; } // GCM record payload: explicit_nonce(8) + ciphertext + tag(16) if (payload.size() < 8 + 16) throwSSL(NetException::Error, "GCM record too short"); const uint8_t* explicit_nonce = payload.data(); size_t ct_len = payload.size() - 8 - 16; const uint8_t* ct = payload.data() + 8; const uint8_t* tag = payload.data() + 8 + ct_len; // Nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes uint8_t nonce[12]; const uint8_t* recv_iv = is_client ? tls12_server_write_iv : tls12_client_write_iv; std::memcpy(nonce, recv_iv, 4); std::memcpy(nonce + 4, explicit_nonce, 8); // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes uint8_t aad[13]; seqToBytes(recv_seq, aad); aad[8] = type; aad[9] = uint8_t(ver >> 8); aad[10] = uint8_t(ver & 0xFF); aad[11] = uint8_t((ct_len >> 8) & 0xFF); aad[12] = uint8_t(ct_len & 0xFF); std::vector<uint8_t> plaintext(ct_len); uint8_t tag_copy[16]; std::memcpy(tag_copy, tag, 16); if (!aes_recv->aes_gcm_decrypt(nonce, aad, 13, ct, ct_len, tag_copy, plaintext.data())) throwSSL(NetException::Error, "GCM tag verification failed"); recv_seq++; return plaintext; } // ============================================================================ // decryptTLS12Record - Dispatch to CBC or GCM decrypt // ============================================================================ std::vector<uint8_t> tls::decryptTLS12Record( uint8_t type, uint16_t ver, const std::vector<uint8_t>& payload) { if (tls12_is_gcm) return decryptRecordGCM(type, ver, payload); return decryptRecordCBC(type, ver, payload); } // ============================================================================ // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (vector overload) // ============================================================================ void tls::sendTLS12RecordGCM(uint8_t type, const std::vector<uint8_t>& plain) { sendTLS12RecordGCM(type, plain.data(), plain.size()); } // ============================================================================ // sendTLS12RecordGCM - Encrypt and send TLS 1.2 GCM record (pointer overload) // ============================================================================ void tls::sendTLS12RecordGCM(uint8_t type, const uint8_t* plain_data, size_t plain_len) { const uint16_t ver = 0x0303; if (!aes) { throwSSL(NetException::Error, "sendTLS12RecordGCM: AES cipher not initialized"); } // Build nonce: implicit_iv(4) || explicit_nonce(8) = 12 bytes // Use send_seq as explicit nonce uint8_t explicit_nonce[8]; seqToBytes(send_seq, explicit_nonce); uint8_t nonce[12]; const uint8_t* send_iv = is_client ? tls12_client_write_iv : tls12_server_write_iv; std::memcpy(nonce, send_iv, 4); std::memcpy(nonce + 4, explicit_nonce, 8); // AAD: seq(8) + type(1) + version(2) + plaintext_length(2) = 13 bytes uint8_t aad[13]; seqToBytes(send_seq, aad); aad[8] = type; aad[9] = uint8_t(ver >> 8); aad[10] = uint8_t(ver & 0xFF); aad[11] = uint8_t((plain_len >> 8) & 0xFF); aad[12] = uint8_t(plain_len & 0xFF); // Encrypt std::vector<uint8_t> ct(plain_len); uint8_t tag[16]; aes->aes_gcm_encrypt(nonce, aad, 13, plain_data, plain_len, ct.data(), tag); // Record payload: explicit_nonce(8) + ciphertext + tag(16) size_t payload_len = 8 + plain_len + 16; uint8_t recHdr[5]; recHdr[0] = type; recHdr[1] = uint8_t(ver >> 8); recHdr[2] = uint8_t(ver & 0xFF); recHdr[3] = uint8_t((payload_len >> 8) & 0xFF); recHdr[4] = uint8_t(payload_len & 0xFF); queueRaw(recHdr, 5); queueRaw(explicit_nonce, 8); queueRaw(ct.data(), ct.size()); queueRaw(tag, 16); send_seq++; } // ============================================================================ // sendTLS13Record - Encrypt and send TLS 1.3 application data record // ============================================================================ Loading Loading @@ -3568,7 +3799,7 @@ namespace netplus { if (ccs_received && aes_recv) { uint16_t ver = (uint16_t(rec[1]) << 8) | uint16_t(rec[2]); fragment = decryptRecordCBC(recType, ver, fragment); fragment = decryptTLS12Record(recType, ver, fragment); } rx_handshake_buf.insert(rx_handshake_buf.end(), Loading Loading @@ -3613,8 +3844,10 @@ namespace netplus { throw e; }; // Sanity: allow only TLS1.2 suites here if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035) { // Sanity: allow TLS1.2 suites (RSA + ECDHE-GCM) if (chosenSuiteArg != 0x002F && chosenSuiteArg != 0x0035 && chosenSuiteArg != 0xC02B && chosenSuiteArg != 0xC02C && chosenSuiteArg != 0xC02F && chosenSuiteArg != 0xC030) { throwSSL(netplus::NetException::Error, "Invalid TLS1.2 cipher suite selected"); } Loading
src/crypto/tls.h +29 −0 Original line number Diff line number Diff line Loading @@ -280,6 +280,18 @@ namespace netplus { const std::vector<uint8_t>& frag ); std::vector<uint8_t> decryptRecordGCM( uint8_t recType, uint16_t ver, const std::vector<uint8_t>& frag ); std::vector<uint8_t> decryptTLS12Record( uint8_t recType, uint16_t ver, const std::vector<uint8_t>& frag ); void sendTLS12Record( uint8_t recType, const std::vector<uint8_t>& content Loading @@ -289,6 +301,15 @@ namespace netplus { const uint8_t* data, size_t len ); void sendTLS12RecordGCM( uint8_t recType, const std::vector<uint8_t>& content ); void sendTLS12RecordGCM( uint8_t recType, const uint8_t* data, size_t len ); // --- TLS 1.2 Handshake Operations --- bool popHandshakeMsg( std::vector<uint8_t>& out, Loading Loading @@ -453,6 +474,14 @@ namespace netplus { std::vector<uint8_t> mac_key; std::vector<uint8_t> client_mac_key; // --- TLS 1.2 GCM implicit IVs (4 bytes each) --- uint8_t tls12_client_write_iv[4] = {0}; uint8_t tls12_server_write_iv[4] = {0}; bool tls12_is_gcm = false; // --- TLS 1.2 ECDHE state --- uint8_t tls12_ecdhe_priv[32] = {0}; // Server ephemeral P-256 private key // --- TLS 1.2 Record Layer Encryption/Decryption --- std::unique_ptr<netplus::aes> aes = nullptr; std::unique_ptr<netplus::aes> aes_recv = nullptr; Loading