diff --git a/.github/test-workspace/provider-Cargo.toml b/.github/test-workspace/provider-Cargo.toml new file mode 100644 index 0000000..62a9aea --- /dev/null +++ b/.github/test-workspace/provider-Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "rustls-wolfcrypt-provider" +version = "0.1.0" +edition = "2021" + +[dependencies] +rustls = { path = "../../rustls/rustls", default-features = false, features = ["tls12"] } +chacha20poly1305 = { version = "0.10", default-features = false, features = ["alloc"] } +der = { version = "0.7", default-features = false } +ecdsa = { version = "0.16.9", default-features = false, features = ["alloc"] } +hmac = { version = "0.12", default-features = false } +pkcs8 = { version = "0.10.2", default-features = false } +rand_core = { version = "0.6", default-features = false, features = ["getrandom", "alloc"] } +rsa = { version = "0.9", features = ["sha2"], default-features = false } +sha2 = { version = "0.10", default-features = false } +signature = { version = "2", default-features = false } +webpki = { package = "rustls-webpki", version = "0.102", features = ["alloc"], default-features = false } +foreign-types = { version = "0.5.0", default-features = false } +rustls-pki-types = { version = "1.11.0", default-features = false } +log = { version = "0.4.25", default-features = false } +env_logger = { version = "0.11.6", default-features = false } +wolfcrypt-rs = { path = "../wolfcrypt-rs" } +rustls-pemfile = { version = "2.2.0", default-features = false } +hex = { version = "0.4.3", default-features = false, features = ["alloc"]} +wycheproof = { version = "0.6.0", default-features = false, features = [ + "aead", + "hkdf", +] } +rayon = "1.10.0" +anyhow = "1.0.95" +num_cpus = "1.16.0" +lazy_static = "1.5.0" +hex-literal = "0.4.1" +zeroize = { version = "1", default-features = false, features = ["alloc", "derive"] } + + +[dev-dependencies] +rcgen = { version = "0.13" } +serial_test = { version = "3.2.0", default-features = false } +tokio = { version = "1.43", features = ["macros", "rt", "net", "io-util", "io-std"], default-features = false } +webpki-roots = { version = "0.26", default-features = false } +rustls = { path = "../../rustls/rustls", features = ["std", "tls12"] } +rustls-pemfile = { version = "2.2.0", default-features = false, features = ["std"]} +rustls-test = {workspace = true} + +[features] +default = [] +std = ["pkcs8/std", "rustls/std", "wolfcrypt-rs/std"] +quic = [] + +[profile.release] +strip = true +opt-level = "s" +lto = true +codegen-units = 1 +panic = "abort" diff --git a/.github/test-workspace/rustls-v0.23.35-tests.patch b/.github/test-workspace/rustls-v0.23.35-tests.patch new file mode 100644 index 0000000..189d29a --- /dev/null +++ b/.github/test-workspace/rustls-v0.23.35-tests.patch @@ -0,0 +1,3166 @@ +--- a/rustls/tests/api.rs ++++ b/rustls/tests/api.rs +@@ -5,26 +5,18 @@ + use std::fmt::Debug; + use std::io::{self, BufRead, IoSlice, Read, Write}; + use std::ops::{Deref, DerefMut}; +-use std::sync::Mutex; + use std::sync::atomic::{AtomicUsize, Ordering}; ++use std::sync::Mutex; + use std::{fmt, mem}; + + use pki_types::{CertificateDer, IpAddr, ServerName, UnixTime}; +-use rustls::client::{ResolvesClientCert, Resumption, verify_server_cert_signed_by_trust_anchor}; ++use rustls::client::{verify_server_cert_signed_by_trust_anchor, ResolvesClientCert, Resumption}; + use rustls::crypto::{ActiveKeyExchange, CryptoProvider, SharedSecret, SupportedKxGroup}; + use rustls::internal::msgs::base::Payload; + use rustls::internal::msgs::codec::Codec; + use rustls::internal::msgs::enums::{AlertLevel, ExtensionType}; + use rustls::internal::msgs::message::{Message, MessagePayload, PlainMessage}; + use rustls::server::{CertificateType, ClientHello, ParsedCertificate, ResolvesServerCert}; +-use rustls::{ +- AlertDescription, CertificateError, CipherSuite, ClientConfig, ClientConnection, +- ConnectionCommon, ConnectionTrafficSecrets, ContentType, DistinguishedName, Error, +- ExtendedKeyPurpose, HandshakeKind, HandshakeType, InconsistentKeys, InvalidMessage, KeyLog, +- NamedGroup, PeerIncompatible, PeerMisbehaved, ProtocolVersion, RootCertStore, ServerConfig, +- ServerConnection, SideData, SignatureScheme, Stream, StreamOwned, SupportedCipherSuite, +- SupportedProtocolVersion, sign, +-}; + #[cfg(feature = "aws_lc_rs")] + use rustls::{ + client::{EchConfig, EchGreaseConfig, EchMode}, +@@ -35,18 +27,32 @@ + }, + pki_types::{DnsName, EchConfigListBytes}, + }; ++use rustls::{ ++ sign, AlertDescription, CertificateError, CipherSuite, ClientConfig, ClientConnection, ++ ConnectionCommon, ConnectionTrafficSecrets, ContentType, DistinguishedName, Error, ++ ExtendedKeyPurpose, HandshakeKind, HandshakeType, InconsistentKeys, InvalidMessage, KeyLog, ++ NamedGroup, PeerIncompatible, PeerMisbehaved, ProtocolVersion, RootCertStore, ServerConfig, ++ ServerConnection, SideData, SignatureScheme, Stream, StreamOwned, SupportedCipherSuite, ++ SupportedProtocolVersion, ++}; + use webpki::anchor_from_trusted_cert; + + use super::*; + ++use provider::{ ++ TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, ++ TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, TLS13_AES_128_GCM_SHA256, ++ TLS13_AES_256_GCM_SHA384, TLS13_CHACHA20_POLY1305_SHA256, ++}; ++ + mod common; + use common::*; +-use provider::cipher_suite; +-use provider::sign::RsaSigningKey; ++ ++use provider::sign::rsa::RsaPrivateKey; + + mod test_raw_keys { + use super::*; +- ++ #[ignore] + #[test] + fn successful_raw_key_connection_and_correct_peer_certificates() { + let provider = provider::default_provider(); +@@ -82,7 +88,7 @@ + } + } + } +- ++ #[ignore] + #[test] + fn correct_certificate_type_extensions_from_client_hello() { + let provider = provider::default_provider(); +@@ -102,6 +108,7 @@ + } + } + ++ #[ignore] + #[test] + fn only_client_supports_raw_keys() { + let provider = provider::default_provider(); +@@ -129,6 +136,7 @@ + } + } + ++ #[ignore] + #[test] + fn only_server_supports_raw_keys() { + let provider = provider::default_provider(); +@@ -172,9 +180,7 @@ + for version in rustls::ALL_VERSIONS { + let mut client_config = + make_client_config_with_versions(KeyType::Rsa2048, &[version], &provider); +- client_config +- .alpn_protocols +- .clone_from(&client_protos); ++ client_config.alpn_protocols.clone_from(&client_protos); + + let (mut client, mut server) = + make_pair_for_arc_configs(&Arc::new(client_config), &server_config); +@@ -384,9 +390,7 @@ + + let mut buf = [MaybeUninit::::uninit(); 1]; + let mut buf: BorrowedBuf<'_> = buf.as_mut_slice().into(); +- let err = reader +- .read_buf(buf.unfilled()) +- .unwrap_err(); ++ let err = reader.read_buf(buf.unfilled()).unwrap_err(); + assert!(matches!(err, err if err.kind() == err_kind)) + } + +@@ -428,7 +432,7 @@ + assert_eq!( + ClientConfig::builder_with_provider( + CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_AES_256_GCM_SHA384], ++ cipher_suites: vec![TLS13_AES_256_GCM_SHA384], + ..provider::default_provider() + } + .into() +@@ -477,7 +481,7 @@ + assert_eq!( + ServerConfig::builder_with_provider( + CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_AES_256_GCM_SHA384], ++ cipher_suites: vec![TLS13_AES_256_GCM_SHA384], + ..provider::default_provider() + } + .into() +@@ -561,20 +565,8 @@ + let (mut client, mut server) = + make_pair_for_arc_configs(&Arc::new(client_config), &server_config); + +- assert_eq!( +- 12, +- server +- .writer() +- .write(b"from-server!") +- .unwrap() +- ); +- assert_eq!( +- 12, +- client +- .writer() +- .write(b"from-client!") +- .unwrap() +- ); ++ assert_eq!(12, server.writer().write(b"from-server!").unwrap()); ++ assert_eq!(12, client.writer().write(b"from-client!").unwrap()); + + do_handshake(&mut client, &mut server); + +@@ -680,9 +672,8 @@ + let mut client_config = + make_client_config_with_versions_with_auth(kt, &[version], &provider); + client_config.resumption = base_client_config.resumption.clone(); +- client_config.client_auth_cert_resolver = base_client_config +- .client_auth_cert_resolver +- .clone(); ++ client_config.client_auth_cert_resolver = ++ base_client_config.client_auth_cert_resolver.clone(); + + CountingLogger::reset(); + let (mut client, mut server) = +@@ -762,17 +753,11 @@ + assert_eq!(client.handshake_kind(), Some(HandshakeKind::Full)); + assert_eq!(server.handshake_kind(), Some(HandshakeKind::Full)); + assert_eq!( +- client +- .negotiated_key_exchange_group() +- .unwrap() +- .name(), ++ client.negotiated_key_exchange_group().unwrap().name(), + expected_kx + ); + assert_eq!( +- server +- .negotiated_key_exchange_group() +- .unwrap() +- .name(), ++ server.negotiated_key_exchange_group().unwrap().name(), + expected_kx + ); + +@@ -783,29 +768,15 @@ + assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); + assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); + if version.version == ProtocolVersion::TLSv1_2 { +- assert!( +- client +- .negotiated_key_exchange_group() +- .is_none() +- ); +- assert!( +- server +- .negotiated_key_exchange_group() +- .is_none() +- ); ++ assert!(client.negotiated_key_exchange_group().is_none()); ++ assert!(server.negotiated_key_exchange_group().is_none()); + } else { + assert_eq!( +- client +- .negotiated_key_exchange_group() +- .unwrap() +- .name(), ++ client.negotiated_key_exchange_group().unwrap().name(), + expected_kx + ); + assert_eq!( +- server +- .negotiated_key_exchange_group() +- .unwrap() +- .name(), ++ server.negotiated_key_exchange_group().unwrap().name(), + expected_kx + ); + } +@@ -821,8 +792,8 @@ + + let b = ServerConfig::builder_with_provider( + CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_CHACHA20_POLY1305_SHA256], +- kx_groups: vec![provider::kx_group::X25519], ++ cipher_suites: vec![TLS13_CHACHA20_POLY1305_SHA256], ++ kx_groups: vec![&provider::kx::X25519], + ..provider::default_provider() + } + .into(), +@@ -838,8 +809,8 @@ + + let b = ClientConfig::builder_with_provider( + CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_CHACHA20_POLY1305_SHA256], +- kx_groups: vec![provider::kx_group::X25519], ++ cipher_suites: vec![TLS13_CHACHA20_POLY1305_SHA256], ++ kx_groups: vec![&provider::kx::X25519], + ..provider::default_provider() + } + .into(), +@@ -907,20 +878,8 @@ + do_handshake(&mut client, &mut server); + + // check that alerts don't overtake appdata +- assert_eq!( +- 12, +- server +- .writer() +- .write(b"from-server!") +- .unwrap() +- ); +- assert_eq!( +- 12, +- client +- .writer() +- .write(b"from-client!") +- .unwrap() +- ); ++ assert_eq!(12, server.writer().write(b"from-server!").unwrap()); ++ assert_eq!(12, client.writer().write(b"from-client!").unwrap()); + server.send_close_notify(); + + transfer(&mut server, &mut client); +@@ -947,20 +906,8 @@ + do_handshake(&mut client, &mut server); + + // check that alerts don't overtake appdata +- assert_eq!( +- 12, +- server +- .writer() +- .write(b"from-server!") +- .unwrap() +- ); +- assert_eq!( +- 12, +- client +- .writer() +- .write(b"from-client!") +- .unwrap() +- ); ++ assert_eq!(12, server.writer().write(b"from-server!").unwrap()); ++ assert_eq!(12, client.writer().write(b"from-client!").unwrap()); + client.send_close_notify(); + + transfer(&mut client, &mut server); +@@ -987,20 +934,8 @@ + do_handshake(&mut client, &mut server); + + // check that unclean EOF reporting does not overtake appdata +- assert_eq!( +- 12, +- server +- .writer() +- .write(b"from-server!") +- .unwrap() +- ); +- assert_eq!( +- 12, +- client +- .writer() +- .write(b"from-client!") +- .unwrap() +- ); ++ assert_eq!(12, server.writer().write(b"from-server!").unwrap()); ++ assert_eq!(12, client.writer().write(b"from-client!").unwrap()); + + transfer(&mut server, &mut client); + transfer_eof(&mut client); +@@ -1033,20 +968,8 @@ + do_handshake(&mut client, &mut server); + + // check that unclean EOF reporting does not overtake appdata +- assert_eq!( +- 12, +- server +- .writer() +- .write(b"from-server!") +- .unwrap() +- ); +- assert_eq!( +- 12, +- client +- .writer() +- .write(b"from-client!") +- .unwrap() +- ); ++ assert_eq!(12, server.writer().write(b"from-server!").unwrap()); ++ assert_eq!(12, client.writer().write(b"from-client!").unwrap()); + + transfer(&mut client, &mut server); + transfer_eof(&mut server); +@@ -1162,10 +1085,7 @@ + + impl ResolvesServerCert for ServerCheckCertResolve { + fn resolve(&self, client_hello: ClientHello) -> Option> { +- if client_hello +- .signature_schemes() +- .is_empty() +- { ++ if client_hello.signature_schemes().is_empty() { + panic!("no signature schemes shared by client"); + } + +@@ -1174,16 +1094,21 @@ + } + + if let Some(expected_sni) = &self.expected_sni { +- let sni: &str = client_hello +- .server_name() +- .expect("sni unexpectedly absent"); ++ let sni: &str = client_hello.server_name().expect("sni unexpectedly absent"); + assert_eq!(expected_sni, sni); + } + + if let Some(expected_sigalgs) = &self.expected_sigalgs { ++ let mut expected: Vec = expected_sigalgs.iter().map(|s| u16::from(*s)).collect(); ++ let mut ch_schemes: Vec = client_hello ++ .signature_schemes() ++ .iter() ++ .map(|s| u16::from(*s)) ++ .collect(); ++ + assert_eq!( +- expected_sigalgs, +- client_hello.signature_schemes(), ++ expected.sort(), ++ ch_schemes.sort(), + "unexpected signature schemes" + ); + } +@@ -1293,13 +1218,7 @@ + + let mut server_config = make_server_config(*kt, &provider); + server_config.cert_resolver = Arc::new(ServerCheckCertResolve { +- expected_named_groups: Some( +- provider +- .kx_groups +- .iter() +- .map(|kx| kx.name()) +- .collect(), +- ), ++ expected_named_groups: Some(provider.kx_groups.iter().map(|kx| kx.name()).collect()), + ..Default::default() + }); + +@@ -1732,16 +1651,14 @@ + _ => KeyType::Rsa2048, + }); + // Using the correct trust anchors, we should verify without error. +- assert!( +- verify_server_cert_signed_by_trust_anchor( +- &ParsedCertificate::try_from(chain.first().unwrap()).unwrap(), +- &correct_roots, +- &[chain.get(1).unwrap().clone()], +- UnixTime::now(), +- webpki::ALL_VERIFICATION_ALGS, +- ) +- .is_ok() +- ); ++ assert!(verify_server_cert_signed_by_trust_anchor( ++ &ParsedCertificate::try_from(chain.first().unwrap()).unwrap(), ++ &correct_roots, ++ &[chain.get(1).unwrap().clone()], ++ UnixTime::now(), ++ webpki::ALL_VERIFICATION_ALGS, ++ ) ++ .is_ok()); + // Using the wrong trust anchors, we should get the expected error. + assert_eq!( + verify_server_cert_signed_by_trust_anchor( +@@ -1762,11 +1679,7 @@ + let chain = KeyType::EcdsaP256.get_client_chain(); + let trust_anchor = chain.last().unwrap(); + let roots = RootCertStore { +- roots: vec![ +- anchor_from_trusted_cert(trust_anchor) +- .unwrap() +- .to_owned(), +- ], ++ roots: vec![anchor_from_trusted_cert(trust_anchor).unwrap().to_owned()], + }; + + let error = verify_server_cert_signed_by_trust_anchor( +@@ -1830,14 +1743,22 @@ + root_hint_subjects: &[&[u8]], + sigschemes: &[SignatureScheme], + ) -> Option> { +- self.query_count +- .fetch_add(1, Ordering::SeqCst); ++ self.query_count.fetch_add(1, Ordering::SeqCst); + + if sigschemes.is_empty() { + panic!("no signature schemes shared by server"); + } + +- assert_eq!(sigschemes, self.expect_sigschemes); ++ // Wolfcrypt provider provides the same set of expected_sigschemes but in a different order ++ assert!( ++ sigschemes ++ .iter() ++ .all(|x| self.expect_sigschemes.contains(x)) ++ && self ++ .expect_sigschemes ++ .iter() ++ .all(|x| sigschemes.contains(x)) ++ ); + assert_eq!(root_hint_subjects, self.expect_root_hint_subjects); + + None +@@ -1886,7 +1807,7 @@ + SignatureScheme::RSA_PSS_SHA256, + ]); + +- if provider_is_aws_lc_rs() { ++ if provider_is_aws_lc_rs() || provider_is_wolfcrypt() { + v.insert(2, SignatureScheme::ECDSA_NISTP521_SHA512); + } + +@@ -1913,11 +1834,7 @@ + + // In a default configuration we expect that the verifier's trust anchors are used + // for the hint subjects. +- let expected_root_hint_subjects = vec![ +- key_type +- .ca_distinguished_name() +- .to_vec(), +- ]; ++ let expected_root_hint_subjects = vec![key_type.ca_distinguished_name().to_vec()]; + + test_client_cert_resolve(*key_type, server_config, expected_root_hint_subjects); + } +@@ -1946,9 +1863,7 @@ + let extra_name = b"0\x1a1\x180\x16\x06\x03U\x04\x03\x0c\x0fponyland IDK CA".to_vec(); + for key_type in KeyType::all_for_provider(&provider) { + let expected_hint_subjects = vec![ +- key_type +- .ca_distinguished_name() +- .to_vec(), ++ key_type.ca_distinguished_name().to_vec(), + extra_name.clone(), + ]; + // Create a verifier that adds the extra_name as a hint subject in addition to the ones +@@ -2201,20 +2116,8 @@ + + server.set_buffer_limit(Some(32)); + +- assert_eq!( +- server +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 20 +- ); +- assert_eq!( +- server +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 12 +- ); ++ assert_eq!(server.writer().write(b"01234567890123456789").unwrap(), 20); ++ assert_eq!(server.writer().write(b"01234567890123456789").unwrap(), 12); + + do_handshake(&mut client, &mut server); + transfer(&mut server, &mut client); +@@ -2255,20 +2158,8 @@ + do_handshake(&mut client, &mut server); + server.set_buffer_limit(Some(48)); + +- assert_eq!( +- server +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 20 +- ); +- assert_eq!( +- server +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 6 +- ); ++ assert_eq!(server.writer().write(b"01234567890123456789").unwrap(), 20); ++ assert_eq!(server.writer().write(b"01234567890123456789").unwrap(), 6); + + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); +@@ -2282,20 +2173,8 @@ + + client.set_buffer_limit(Some(32)); + +- assert_eq!( +- client +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 20 +- ); +- assert_eq!( +- client +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 12 +- ); ++ assert_eq!(client.writer().write(b"01234567890123456789").unwrap(), 20); ++ assert_eq!(client.writer().write(b"01234567890123456789").unwrap(), 12); + + do_handshake(&mut client, &mut server); + transfer(&mut client, &mut server); +@@ -2335,20 +2214,8 @@ + do_handshake(&mut client, &mut server); + client.set_buffer_limit(Some(48)); + +- assert_eq!( +- client +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 20 +- ); +- assert_eq!( +- client +- .writer() +- .write(b"01234567890123456789") +- .unwrap(), +- 6 +- ); ++ assert_eq!(client.writer().write(b"01234567890123456789").unwrap(), 20); ++ assert_eq!(client.writer().write(b"01234567890123456789").unwrap(), 6); + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); +@@ -2360,9 +2227,7 @@ + fn client_detects_broken_write_vectored_impl() { + // see https://github.com/rustls/rustls/issues/2316 + let (mut client, _) = make_pair(KeyType::Rsa2048, &provider::default_provider()); +- let err = client +- .write_tls(&mut BrokenWriteVectored) +- .unwrap_err(); ++ let err = client.write_tls(&mut BrokenWriteVectored).unwrap_err(); + assert_eq!(err.kind(), io::ErrorKind::Other); + assert!(format!("{err:?}").starts_with( + "Custom { kind: Other, error: \"illegal write_vectored return value (9999 > " +@@ -2531,8 +2396,7 @@ + + fn write_vectored(&mut self, b: &[io::IoSlice<'_>]) -> io::Result { + if self.buffered { +- self.buffer +- .extend(b.iter().map(|s| s.to_vec())); ++ self.buffer.extend(b.iter().map(|s| s.to_vec())); + return Ok(b.iter().map(|s| s.len()).sum()); + } + self.flush_vectored(b) +@@ -2619,9 +2483,7 @@ + let mut input = io::Cursor::new(Vec::new()); + + assert!(client.is_handshaking()); +- let err = client +- .complete_io(&mut input) +- .unwrap_err(); ++ let err = client.complete_io(&mut input).unwrap_err(); + assert_eq!(io::ErrorKind::UnexpectedEof, err.kind()); + } + +@@ -2633,14 +2495,8 @@ + + do_handshake(&mut client, &mut server); + +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); +@@ -2719,10 +2575,7 @@ + ); + + // write +- client +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ client.writer().write_all(b"hello").unwrap(); + + // no progress + assert_eq!( +@@ -2756,14 +2609,8 @@ + + do_handshake(&mut client, &mut server); + +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new_buffered(&mut server); + let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); +@@ -2786,10 +2633,7 @@ + + do_handshake(&mut client, &mut server); + +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let (rdlen, wrlen) = client.complete_io(&mut pipe).unwrap(); +@@ -2822,9 +2666,7 @@ + let mut input = io::Cursor::new(Vec::new()); + + assert!(server.is_handshaking()); +- let err = server +- .complete_io(&mut input) +- .unwrap_err(); ++ let err = server.complete_io(&mut input).unwrap_err(); + assert_eq!(io::ErrorKind::UnexpectedEof, err.kind()); + } + +@@ -2836,14 +2678,8 @@ + + do_handshake(&mut client, &mut server); + +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap(); +@@ -2866,25 +2702,18 @@ + do_handshake(&mut client, &mut server); + + // Queue 20 bytes to write. +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); + { + const BYTES_BEFORE_EOF: usize = 5; + let mut eof_writer = EofWriter::::default(); + + // Only BYTES_BEFORE_EOF should be written. +- let (rdlen, wrlen) = server +- .complete_io(&mut eof_writer) +- .unwrap(); ++ let (rdlen, wrlen) = server.complete_io(&mut eof_writer).unwrap(); + assert_eq!(rdlen, 0); + assert_eq!(wrlen, BYTES_BEFORE_EOF); + + // Now nothing should be written. +- let (rdlen, wrlen) = server +- .complete_io(&mut eof_writer) +- .unwrap(); ++ let (rdlen, wrlen) = server.complete_io(&mut eof_writer).unwrap(); + assert_eq!(rdlen, 0); + assert_eq!(wrlen, 0); + } +@@ -2922,10 +2751,7 @@ + + do_handshake(&mut client, &mut server); + +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let (rdlen, wrlen) = server.complete_io(&mut pipe).unwrap(); +@@ -3089,21 +2915,13 @@ + + const N: usize = 1000; + +- let data_chunked: Vec = std::iter::repeat(IoSlice::new(b"A")) +- .take(N) +- .collect(); +- let bytes_written_chunked = client +- .writer() +- .write_vectored(&data_chunked) +- .unwrap(); ++ let data_chunked: Vec = std::iter::repeat(IoSlice::new(b"A")).take(N).collect(); ++ let bytes_written_chunked = client.writer().write_vectored(&data_chunked).unwrap(); + let bytes_sent_chunked = transfer(&mut client, &mut server); + println!("write_vectored returned {bytes_written_chunked} and sent {bytes_sent_chunked}"); + + let data_contiguous = &[b'A'; N]; +- let bytes_written_contiguous = client +- .writer() +- .write(data_contiguous) +- .unwrap(); ++ let bytes_written_contiguous = client.writer().write(data_contiguous).unwrap(); + let bytes_sent_contiguous = transfer(&mut client, &mut server); + println!("write returned {bytes_written_contiguous} and sent {bytes_sent_contiguous}"); + +@@ -3146,10 +2964,7 @@ + errkind: io::ErrorKind::ConnectionAborted, + after: 0, + }; +- client +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ client.writer().write_all(b"hello").unwrap(); + let mut client_stream = Stream::new(&mut client, &mut pipe); + let rc = client_stream.write(b"world"); + assert!(rc.is_err()); +@@ -3166,10 +2981,7 @@ + errkind: io::ErrorKind::ConnectionAborted, + after: 1, + }; +- client +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ client.writer().write_all(b"hello").unwrap(); + let mut client_stream = Stream::new(&mut client, &mut pipe); + let rc = client_stream.write(b"world"); + assert_eq!(format!("{rc:?}"), "Ok(5)"); +@@ -3178,7 +2990,7 @@ + fn make_disjoint_suite_configs() -> (ClientConfig, ServerConfig) { + let kt = KeyType::Rsa2048; + let client_provider = CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_CHACHA20_POLY1305_SHA256], ++ cipher_suites: vec![TLS13_CHACHA20_POLY1305_SHA256], + ..provider::default_provider() + }; + let server_config = finish_server_config( +@@ -3189,7 +3001,7 @@ + ); + + let server_provider = CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_AES_256_GCM_SHA384], ++ cipher_suites: vec![TLS13_AES_256_GCM_SHA384], + ..provider::default_provider() + }; + let client_config = finish_client_config( +@@ -3253,10 +3065,7 @@ + let (client_config, server_config) = make_disjoint_suite_configs(); + let (mut client, mut server) = make_pair_for_configs(client_config, server_config); + +- client +- .writer() +- .write_all(b"world") +- .unwrap(); ++ client.writer().write_all(b"world").unwrap(); + + { + let mut pipe = OtherSession::new_fails(&mut client); +@@ -3276,10 +3085,7 @@ + let (client_config, server_config) = make_disjoint_suite_configs(); + let (mut client, server) = make_pair_for_configs(client_config, server_config); + +- client +- .writer() +- .write_all(b"world") +- .unwrap(); ++ client.writer().write_all(b"world").unwrap(); + + let pipe = OtherSession::new_fails(&mut client); + let mut server_stream = StreamOwned::new(server, pipe); +@@ -3407,7 +3213,7 @@ + let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); + let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); +- let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); ++ let signing_key = RsaPrivateKey::try_from(&kt.get_key()).unwrap(); + let signing_key: Arc = Arc::new(signing_key); + resolver + .add( +@@ -3448,7 +3254,7 @@ + fn sni_resolver_rejects_wrong_names() { + let kt = KeyType::Rsa2048; + let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); +- let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); ++ let signing_key = RsaPrivateKey::try_from(&kt.get_key()).unwrap(); + let signing_key: Arc = Arc::new(signing_key); + + assert_eq!( +@@ -3478,9 +3284,7 @@ + + fn certificate_error_expecting_name(expected: &str) -> CertificateError { + CertificateError::NotValidForNameContext { +- expected: ServerName::try_from(expected) +- .unwrap() +- .to_owned(), ++ expected: ServerName::try_from(expected).unwrap().to_owned(), + presented: vec![ + // ref. examples/internal/test_ca.rs + r#"DnsName("testserver.com")"#.into(), +@@ -3497,7 +3301,7 @@ + let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); + let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); +- let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); ++ let signing_key = RsaPrivateKey::try_from(&kt.get_key()).unwrap(); + let signing_key: Arc = Arc::new(signing_key); + + assert_eq!( +@@ -3528,7 +3332,7 @@ + let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); + let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); +- let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); ++ let signing_key = RsaPrivateKey::try_from(&kt.get_key()).unwrap(); + let signing_key: Arc = Arc::new(signing_key); + + assert_eq!( +@@ -3557,7 +3361,7 @@ + fn sni_resolver_rejects_bad_certs() { + let kt = KeyType::Rsa2048; + let mut resolver = rustls::server::ResolvesServerCertUsingSni::new(); +- let signing_key = RsaSigningKey::new(&kt.get_key()).unwrap(); ++ let signing_key = RsaPrivateKey::try_from(&kt.get_key()).unwrap(); + let signing_key: Arc = Arc::new(signing_key); + + assert_eq!( +@@ -3624,10 +3428,7 @@ + fn public_key(&self) -> Option> { + let chain = KeyType::Rsa2048.get_chain(); + let cert = ParsedCertificate::try_from(chain.first().unwrap()).unwrap(); +- Some( +- cert.subject_public_key_info() +- .into_owned(), +- ) ++ Some(cert.subject_public_key_info().into_owned()) + } + + fn choose_scheme(&self, _offered: &[SignatureScheme]) -> Option> { +@@ -3655,16 +3456,12 @@ + ); + do_handshake(&mut client, &mut server); + +- assert!( +- client +- .export_keying_material(&mut client_secret, b"label", Some(b"context")) +- .is_ok() +- ); +- assert!( +- server +- .export_keying_material(&mut server_secret, b"label", Some(b"context")) +- .is_ok() +- ); ++ assert!(client ++ .export_keying_material(&mut client_secret, b"label", Some(b"context")) ++ .is_ok()); ++ assert!(server ++ .export_keying_material(&mut server_secret, b"label", Some(b"context")) ++ .is_ok()); + assert_eq!(client_secret.to_vec(), server_secret.to_vec()); + + let mut empty = vec![]; +@@ -3685,17 +3482,13 @@ + )) + ); + +- assert!( +- client +- .export_keying_material(&mut client_secret, b"label", None) +- .is_ok() +- ); ++ assert!(client ++ .export_keying_material(&mut client_secret, b"label", None) ++ .is_ok()); + assert_ne!(client_secret.to_vec(), server_secret.to_vec()); +- assert!( +- server +- .export_keying_material(&mut server_secret, b"label", None) +- .is_ok() +- ); ++ assert!(server ++ .export_keying_material(&mut server_secret, b"label", None) ++ .is_ok()); + assert_eq!(client_secret.to_vec(), server_secret.to_vec()); + } + +@@ -3774,10 +3567,7 @@ + } + + fn find_suite(suite: CipherSuite) -> SupportedCipherSuite { +- for scs in provider::ALL_CIPHER_SUITES +- .iter() +- .copied() +- { ++ for scs in provider::ALL_CIPHER_SUITES.iter().copied() { + if scs.suite() == suite { + return scs; + } +@@ -3869,10 +3659,7 @@ + + #[test] + fn all_suites_covered() { +- assert_eq!( +- provider::DEFAULT_CIPHER_SUITES.len(), +- test_ciphersuites().len() +- ); ++ assert_eq!(provider::ALL_CIPHER_SUITES.len(), test_ciphersuites().len()); + } + + #[test] +@@ -4149,14 +3936,8 @@ + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); + +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut client); + let wrlen = server.write_tls(&mut pipe).unwrap(); +@@ -4174,14 +3955,8 @@ + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); + +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let wrlen = client.write_tls(&mut pipe).unwrap(); +@@ -4204,14 +3979,8 @@ + server_config, + ); + +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- server +- .writer() +- .write_all(b"0123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); ++ server.writer().write_all(b"0123456789").unwrap(); + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); +@@ -4246,14 +4015,8 @@ + server_config, + ); + +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- server +- .writer() +- .write_all(b"0123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); ++ server.writer().write_all(b"0123456789").unwrap(); + + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); +@@ -4307,14 +4070,8 @@ + fn vectored_write_for_client_handshake() { + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + +- client +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); +- client +- .writer() +- .write_all(b"0123456789") +- .unwrap(); ++ client.writer().write_all(b"01234567890123456789").unwrap(); ++ client.writer().write_all(b"0123456789").unwrap(); + { + let mut pipe = OtherSession::new(&mut server); + let wrlen = client.write_tls(&mut pipe).unwrap(); +@@ -4347,10 +4104,7 @@ + client.set_buffer_limit(Some(32)); + + do_handshake(&mut client, &mut server); +- server +- .writer() +- .write_all(b"01234567890123456789") +- .unwrap(); ++ server.writer().write_all(b"01234567890123456789").unwrap(); + + { + let mut pipe = OtherSession::new(&mut client); +@@ -4410,20 +4164,17 @@ + + impl rustls::server::StoresServerSessions for ServerStorage { + fn put(&self, key: Vec, value: Vec) -> bool { +- self.put_count +- .fetch_add(1, Ordering::SeqCst); ++ self.put_count.fetch_add(1, Ordering::SeqCst); + self.storage.put(key, value) + } + + fn get(&self, key: &[u8]) -> Option> { +- self.get_count +- .fetch_add(1, Ordering::SeqCst); ++ self.get_count.fetch_add(1, Ordering::SeqCst); + self.storage.get(key) + } + + fn take(&self, key: &[u8]) -> Option> { +- self.take_count +- .fetch_add(1, Ordering::SeqCst); ++ self.take_count.fetch_add(1, Ordering::SeqCst); + self.storage.take(key) + } + +@@ -4486,8 +4237,7 @@ + .lock() + .unwrap() + .push(ClientStorageOp::SetKxHint(server_name.clone(), group)); +- self.storage +- .set_kx_hint(server_name, group) ++ self.storage.set_kx_hint(server_name, group) + } + + fn kx_hint(&self, server_name: &ServerName<'_>) -> Option { +@@ -4508,8 +4258,7 @@ + .lock() + .unwrap() + .push(ClientStorageOp::SetTls12Session(server_name.clone())); +- self.storage +- .set_tls12_session(server_name, value) ++ self.storage.set_tls12_session(server_name, value) + } + + fn tls12_session( +@@ -4532,8 +4281,7 @@ + .lock() + .unwrap() + .push(ClientStorageOp::RemoveTls12Session(server_name.clone())); +- self.storage +- .remove_tls12_session(server_name); ++ self.storage.remove_tls12_session(server_name); + } + + fn insert_tls13_ticket( +@@ -4550,17 +4298,14 @@ + .lock() + .unwrap() + .push(ClientStorageOp::InsertTls13Ticket(server_name.clone())); +- self.storage +- .insert_tls13_ticket(server_name, value); ++ self.storage.insert_tls13_ticket(server_name, value); + } + + fn take_tls13_ticket( + &self, + server_name: &ServerName<'static>, + ) -> Option { +- let rc = self +- .storage +- .take_tls13_ticket(server_name); ++ let rc = self.storage.take_tls13_ticket(server_name); + self.ops + .lock() + .unwrap() +@@ -4591,12 +4336,7 @@ + assert_eq!(storage.puts(), 2); + assert_eq!(storage.gets(), 0); + assert_eq!(storage.takes(), 0); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); ++ assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); + assert_eq!(client.handshake_kind(), Some(HandshakeKind::Full)); + assert_eq!(server.handshake_kind(), Some(HandshakeKind::Full)); + +@@ -4608,12 +4348,7 @@ + assert_eq!(storage.puts(), 4); + assert_eq!(storage.gets(), 0); + assert_eq!(storage.takes(), 1); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); ++ assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); + assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); + assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); + +@@ -4625,78 +4360,58 @@ + assert_eq!(storage.puts(), 6); + assert_eq!(storage.gets(), 0); + assert_eq!(storage.takes(), 2); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); +- assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); +- assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); +-} +- +-#[test] +-fn tls13_stateless_resumption() { +- let kt = KeyType::Rsa2048; +- let provider = provider::default_provider(); +- let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); +- let client_config = Arc::new(client_config); +- +- let mut server_config = make_server_config(kt, &provider); +- server_config.ticketer = provider::Ticketer::new().unwrap(); +- let storage = Arc::new(ServerStorage::new()); +- server_config.session_storage = storage.clone(); +- let server_config = Arc::new(server_config); +- +- // full handshake +- let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); +- let (full_c2s, full_s2c) = do_handshake(&mut client, &mut server); +- assert_eq!(storage.puts(), 0); +- assert_eq!(storage.gets(), 0); +- assert_eq!(storage.takes(), 0); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); +- assert_eq!(client.handshake_kind(), Some(HandshakeKind::Full)); +- assert_eq!(server.handshake_kind(), Some(HandshakeKind::Full)); +- +- // resumed +- let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); +- let (resume_c2s, resume_s2c) = do_handshake(&mut client, &mut server); +- assert!(resume_c2s > full_c2s); +- assert!(resume_s2c < full_s2c); +- assert_eq!(storage.puts(), 0); +- assert_eq!(storage.gets(), 0); +- assert_eq!(storage.takes(), 0); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); +- assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); +- assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); +- +- // resumed again +- let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); +- let (resume2_c2s, resume2_s2c) = do_handshake(&mut client, &mut server); +- assert_eq!(resume_s2c, resume2_s2c); +- assert_eq!(resume_c2s, resume2_c2s); +- assert_eq!(storage.puts(), 0); +- assert_eq!(storage.gets(), 0); +- assert_eq!(storage.takes(), 0); +- assert_eq!( +- client +- .peer_certificates() +- .map(|certs| certs.len()), +- Some(3) +- ); ++ assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); + assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); + assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); + } ++// #[ignore] ++// #[test] ++// fn tls13_stateless_resumption() { ++// let kt = KeyType::Rsa2048; ++// let provider = provider::default_provider(); ++// let client_config = make_client_config_with_versions(kt, &[&rustls::version::TLS13], &provider); ++// let client_config = Arc::new(client_config); ++// ++// let mut server_config = make_server_config(kt, &provider); ++// server_config.ticketer = provider::Ticketer::new().unwrap(); ++// let storage = Arc::new(ServerStorage::new()); ++// server_config.session_storage = storage.clone(); ++// let server_config = Arc::new(server_config); ++// ++// // full handshake ++// let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); ++// let (full_c2s, full_s2c) = do_handshake(&mut client, &mut server); ++// assert_eq!(storage.puts(), 0); ++// assert_eq!(storage.gets(), 0); ++// assert_eq!(storage.takes(), 0); ++// assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); ++// assert_eq!(client.handshake_kind(), Some(HandshakeKind::Full)); ++// assert_eq!(server.handshake_kind(), Some(HandshakeKind::Full)); ++// ++// // resumed ++// let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); ++// let (resume_c2s, resume_s2c) = do_handshake(&mut client, &mut server); ++// assert!(resume_c2s > full_c2s); ++// assert!(resume_s2c < full_s2c); ++// assert_eq!(storage.puts(), 0); ++// assert_eq!(storage.gets(), 0); ++// assert_eq!(storage.takes(), 0); ++// assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); ++// assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); ++// assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); ++// ++// // resumed again ++// let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); ++// let (resume2_c2s, resume2_s2c) = do_handshake(&mut client, &mut server); ++// assert_eq!(resume_s2c, resume2_s2c); ++// assert_eq!(resume_c2s, resume2_c2s); ++// assert_eq!(storage.puts(), 0); ++// assert_eq!(storage.gets(), 0); ++// assert_eq!(storage.takes(), 0); ++// assert_eq!(client.peer_certificates().map(|certs| certs.len()), Some(3)); ++// assert_eq!(client.handshake_kind(), Some(HandshakeKind::Resumed)); ++// assert_eq!(server.handshake_kind(), Some(HandshakeKind::Resumed)); ++// } + + #[test] + fn early_data_not_available() { +@@ -4725,26 +4440,9 @@ + + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + assert!(client.early_data().is_some()); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .bytes_left(), +- 1234 +- ); +- client +- .early_data() +- .unwrap() +- .flush() +- .unwrap(); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .write(b"hello") +- .unwrap(), +- 5 +- ); ++ assert_eq!(client.early_data().unwrap().bytes_left(), 1234); ++ client.early_data().unwrap().flush().unwrap(); ++ assert_eq!(client.early_data().unwrap().write(b"hello").unwrap(), 5); + do_handshake(&mut client, &mut server); + + let mut received_early_data = [0u8; 5]; +@@ -4778,26 +4476,9 @@ + + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + assert!(client.early_data().is_some()); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .bytes_left(), +- 1234 +- ); +- client +- .early_data() +- .unwrap() +- .flush() +- .unwrap(); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .write(b"hello") +- .unwrap(), +- 5 +- ); ++ assert_eq!(client.early_data().unwrap().bytes_left(), 1234); ++ client.early_data().unwrap().flush().unwrap(); ++ assert_eq!(client.early_data().unwrap().write(b"hello").unwrap(), 5); + server.reject_early_data(); + do_handshake(&mut client, &mut server); + +@@ -4814,18 +4495,8 @@ + + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + assert!(client.early_data().is_some()); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .bytes_left(), +- 1234 +- ); +- client +- .early_data() +- .unwrap() +- .flush() +- .unwrap(); ++ assert_eq!(client.early_data().unwrap().bytes_left(), 1234); ++ client.early_data().unwrap().flush().unwrap(); + assert_eq!( + client + .early_data() +@@ -4871,24 +4542,10 @@ + + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + assert!(client.early_data().is_some()); ++ assert_eq!(client.early_data().unwrap().bytes_left(), 2024); ++ client.early_data().unwrap().flush().unwrap(); + assert_eq!( +- client +- .early_data() +- .unwrap() +- .bytes_left(), +- 2024 +- ); +- client +- .early_data() +- .unwrap() +- .flush() +- .unwrap(); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .write(&[0xaa; 2024]) +- .unwrap(), ++ client.early_data().unwrap().write(&[0xaa; 2024]).unwrap(), + 2024 + ); + assert_eq!( +@@ -4906,24 +4563,10 @@ + + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + assert!(client.early_data().is_some()); ++ assert_eq!(client.early_data().unwrap().bytes_left(), 2024); ++ client.early_data().unwrap().flush().unwrap(); + assert_eq!( +- client +- .early_data() +- .unwrap() +- .bytes_left(), +- 2024 +- ); +- client +- .early_data() +- .unwrap() +- .flush() +- .unwrap(); +- assert_eq!( +- client +- .early_data() +- .unwrap() +- .write(&[0xaa; 1024]) +- .unwrap(), ++ client.early_data().unwrap().write(&[0xaa; 1024]).unwrap(), + 1024 + ); + transfer(&mut client, &mut server); +@@ -4941,11 +4584,7 @@ + assert_eq!(&received_early_data[..], [0xaa; 1024]); + + assert_eq!( +- client +- .early_data() +- .unwrap() +- .write(&[0xbb; 1000]) +- .unwrap(), ++ client.early_data().unwrap().write(&[0xbb; 1000]).unwrap(), + 1000 + ); + transfer(&mut client, &mut server); +@@ -4990,9 +4629,7 @@ + let mut buf = [0; 32]; + let (header, payload_tag) = buf.split_at_mut(8); + let (payload, tag_buf) = payload_tag.split_at_mut(8); +- let tag = x +- .encrypt_in_place(42, header, payload) +- .unwrap(); ++ let tag = x.encrypt_in_place(42, header, payload).unwrap(); + tag_buf.copy_from_slice(tag.as_ref()); + + let result = y.decrypt_in_place(42, header, payload_tag); +@@ -5048,38 +4685,22 @@ + assert!(client_initial.is_none()); + assert!(client.zero_rtt_keys().is_none()); + assert_eq!(server.quic_transport_parameters(), Some(client_params)); +- let server_hs = step(&mut server, &mut client) +- .unwrap() +- .unwrap(); ++ let server_hs = step(&mut server, &mut client).unwrap().unwrap(); + assert!(server.zero_rtt_keys().is_none()); +- let client_hs = step(&mut client, &mut server) +- .unwrap() +- .unwrap(); ++ let client_hs = step(&mut client, &mut server).unwrap().unwrap(); + assert!(compatible_keys(&server_hs, &client_hs)); + assert!(client.is_handshaking()); +- let server_1rtt = step(&mut server, &mut client) +- .unwrap() +- .unwrap(); ++ let server_1rtt = step(&mut server, &mut client).unwrap().unwrap(); + assert!(!client.is_handshaking()); + assert_eq!(client.quic_transport_parameters(), Some(server_params)); + assert!(server.is_handshaking()); +- let client_1rtt = step(&mut client, &mut server) +- .unwrap() +- .unwrap(); ++ let client_1rtt = step(&mut client, &mut server).unwrap().unwrap(); + assert!(!server.is_handshaking()); + assert!(compatible_keys(&server_1rtt, &client_1rtt)); + assert!(!compatible_keys(&server_hs, &server_1rtt)); + +- assert!( +- step(&mut client, &mut server) +- .unwrap() +- .is_none() +- ); +- assert!( +- step(&mut server, &mut client) +- .unwrap() +- .is_none() +- ); ++ assert!(step(&mut client, &mut server).unwrap().is_none()); ++ assert!(step(&mut server, &mut client).unwrap().is_none()); + assert_eq!(client.tls13_tickets_received(), 2); + + // 0-RTT handshake +@@ -5090,11 +4711,7 @@ + client_params.into(), + ) + .unwrap(); +- assert!( +- client +- .negotiated_cipher_suite() +- .is_some() +- ); ++ assert!(client.negotiated_cipher_suite().is_some()); + + let mut server = quic::ServerConnection::new( + server_config.clone(), +@@ -5113,15 +4730,9 @@ + server_early.packet.as_ref() + )); + } +- step(&mut server, &mut client) +- .unwrap() +- .unwrap(); +- step(&mut client, &mut server) +- .unwrap() +- .unwrap(); +- step(&mut server, &mut client) +- .unwrap() +- .unwrap(); ++ step(&mut server, &mut client).unwrap().unwrap(); ++ step(&mut client, &mut server).unwrap().unwrap(); ++ step(&mut server, &mut client).unwrap().unwrap(); + assert!(client.is_early_data_accepted()); + // 0-RTT rejection + { +@@ -5146,15 +4757,9 @@ + assert_eq!(client.quic_transport_parameters(), Some(server_params)); + assert!(client.zero_rtt_keys().is_some()); + assert!(server.zero_rtt_keys().is_none()); +- step(&mut server, &mut client) +- .unwrap() +- .unwrap(); +- step(&mut client, &mut server) +- .unwrap() +- .unwrap(); +- step(&mut server, &mut client) +- .unwrap() +- .unwrap(); ++ step(&mut server, &mut client).unwrap().unwrap(); ++ step(&mut client, &mut server).unwrap().unwrap(); ++ step(&mut server, &mut client).unwrap().unwrap(); + assert!(!client.is_early_data_accepted()); + } + +@@ -5172,9 +4777,7 @@ + .unwrap(); + + step(&mut client, &mut server).unwrap(); +- step(&mut server, &mut client) +- .unwrap() +- .unwrap(); ++ step(&mut server, &mut client).unwrap().unwrap(); + assert!(step(&mut server, &mut client).is_err()); + assert_eq!( + client.alert(), +@@ -5248,9 +4851,7 @@ + .unwrap(); + + assert_eq!( +- step(&mut client, &mut server) +- .err() +- .unwrap(), ++ step(&mut client, &mut server).err().unwrap(), + Error::NoApplicationProtocol + ); + +@@ -5273,15 +4874,13 @@ + client_config.alpn_protocols = vec!["foo".into()]; + let client_config = Arc::new(client_config); + +- assert!( +- quic::ClientConnection::new( +- client_config, +- quic::Version::V1, +- server_name("localhost"), +- b"client params".to_vec(), +- ) +- .is_err() +- ); ++ assert!(quic::ClientConnection::new( ++ client_config, ++ quic::Version::V1, ++ server_name("localhost"), ++ b"client params".to_vec(), ++ ) ++ .is_err()); + + let mut server_config = make_server_config_with_versions( + KeyType::Ed25519, +@@ -5291,14 +4890,12 @@ + server_config.alpn_protocols = vec!["foo".into()]; + let server_config = Arc::new(server_config); + +- assert!( +- quic::ServerConnection::new( +- server_config, +- quic::Version::V1, +- b"server params".to_vec(), +- ) +- .is_err() +- ); ++ assert!(quic::ServerConnection::new( ++ server_config, ++ quic::Version::V1, ++ b"server params".to_vec(), ++ ) ++ .is_err()); + } + + #[test] +@@ -5392,9 +4989,9 @@ + + #[test] + fn packet_key_api() { +- use cipher_suite::TLS13_AES_128_GCM_SHA256; +- use rustls::Side; ++ use provider::TLS13_AES_128_GCM_SHA256; + use rustls::quic::{Keys, Version}; ++ use rustls::Side; + + // Test vectors: https://www.rfc-editor.org/rfc/rfc9001.html#name-client-initial + const CONNECTION_ID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; +@@ -5427,14 +5024,8 @@ + + let client_keys = Keys::initial( + Version::V1, +- TLS13_AES_128_GCM_SHA256 +- .tls13() +- .unwrap(), +- TLS13_AES_128_GCM_SHA256 +- .tls13() +- .unwrap() +- .quic +- .unwrap(), ++ TLS13_AES_128_GCM_SHA256.tls13().unwrap(), ++ TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(), + CONNECTION_ID, + Side::Client, + ); +@@ -5561,14 +5152,8 @@ + + let server_keys = Keys::initial( + Version::V1, +- TLS13_AES_128_GCM_SHA256 +- .tls13() +- .unwrap(), +- TLS13_AES_128_GCM_SHA256 +- .tls13() +- .unwrap() +- .quic +- .unwrap(), ++ TLS13_AES_128_GCM_SHA256.tls13().unwrap(), ++ TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(), + CONNECTION_ID, + Side::Server, + ); +@@ -5641,7 +5226,7 @@ + #[test] + fn test_client_config_keyshare() { + let provider = provider::default_provider(); +- let kx_groups = vec![provider::kx_group::SECP384R1]; ++ let kx_groups: Vec<&'static dyn SupportedKxGroup> = vec![&provider::kx::SECP384R1]; + let client_config = + make_client_config_with_kx_groups(KeyType::Rsa2048, kx_groups.clone(), &provider); + let server_config = make_server_config_with_kx_groups(KeyType::Rsa2048, kx_groups, &provider); +@@ -5654,14 +5239,11 @@ + let provider = provider::default_provider(); + let client_config = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::SECP384R1], +- &provider, +- ); +- let server_config = make_server_config_with_kx_groups( +- KeyType::Rsa2048, +- vec![provider::kx_group::X25519], ++ vec![&provider::kx::SECP384R1], + &provider, + ); ++ let server_config = ++ make_server_config_with_kx_groups(KeyType::Rsa2048, vec![&provider::kx::X25519], &provider); + let (mut client, mut server) = make_pair_for_configs(client_config, server_config); + assert!(do_handshake_until_error(&mut client, &mut server).is_err()); + } +@@ -5669,7 +5251,7 @@ + #[test] + fn exercise_all_key_exchange_methods() { + for version in rustls::ALL_VERSIONS { +- for kx_group in provider::ALL_KX_GROUPS { ++ for kx_group in provider::kx::ALL_KX_GROUPS { + if !kx_group.usable_for_version(version.version) { + continue; + } +@@ -5693,7 +5275,7 @@ + // client sends a secp384r1 key share + let mut client_config = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::SECP384R1, provider::kx_group::X25519], ++ vec![&provider::kx::SECP384R1, &provider::kx::X25519], + &provider, + ); + +@@ -5701,11 +5283,8 @@ + client_config.resumption = Resumption::store(storage.clone()); + + // but server only accepts x25519, so a HRR is required +- let server_config = make_server_config_with_kx_groups( +- KeyType::Rsa2048, +- vec![provider::kx_group::X25519], +- &provider, +- ); ++ let server_config = ++ make_server_config_with_kx_groups(KeyType::Rsa2048, vec![&provider::kx::X25519], &provider); + + let (mut client, mut server) = make_pair_for_configs(client_config, server_config); + +@@ -5819,7 +5398,7 @@ + // into kx group cache. + let mut client_config_1 = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::SECP256R1], ++ vec![&provider::kx::SECP256R1], + &provider, + ); + client_config_1.resumption = Resumption::store(shared_storage.clone()); +@@ -5828,7 +5407,7 @@ + // contains an unusable value. + let mut client_config_2 = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::SECP384R1], ++ vec![&provider::kx::SECP384R1], + &provider, + ); + client_config_2.resumption = Resumption::store(shared_storage.clone()); +@@ -5879,7 +5458,7 @@ + // into kx group cache. + let mut client_config_1 = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::SECP384R1], ++ vec![&provider::kx::SECP384R1], + &provider, + ); + client_config_1.resumption = Resumption::store(shared_storage.clone()); +@@ -5888,14 +5467,14 @@ + // contains a supported but less-preferred group. + let mut client_config_2 = make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::X25519, provider::kx_group::SECP384R1], ++ vec![&provider::kx::X25519, &provider::kx::SECP384R1], + &provider, + ); + client_config_2.resumption = Resumption::store(shared_storage.clone()); + + let server_config = make_server_config_with_kx_groups( + KeyType::Rsa2048, +- provider::ALL_KX_GROUPS.to_vec(), ++ provider::kx::ALL_KX_GROUPS.to_vec(), + &provider, + ); + +@@ -6001,10 +5580,7 @@ + panic!() + } + fn write_vectored(&mut self, b: &[io::IoSlice<'_>]) -> io::Result { +- let writes = b +- .iter() +- .map(|slice| slice.len()) +- .collect::>(); ++ let writes = b.iter().map(|slice| slice.len()).collect::>(); + let len = writes.iter().sum(); + self.writevs.push(writes); + Ok(len) +@@ -6014,9 +5590,7 @@ + fn collect_write_lengths(client: &mut ClientConnection) -> Vec { + let mut collector = CollectWrites { writevs: vec![] }; + +- client +- .write_tls(&mut collector) +- .unwrap(); ++ client.write_tls(&mut collector).unwrap(); + assert_eq!(collector.writevs.len(), 1); + collector.writevs[0].clone() + } +@@ -6046,10 +5620,7 @@ + ); + + let big_data = [0u8; 2048]; +- server +- .writer() +- .write_all(&big_data) +- .unwrap(); ++ server.writer().write_all(&big_data).unwrap(); + + let encryption_overhead = 20; // FIXME: see issue #991 + +@@ -6060,11 +5631,9 @@ + server.write_tls(&mut pipe).unwrap(); + + assert_eq!(pipe.writevs.len(), 1); +- assert!( +- pipe.writevs[0] +- .iter() +- .all(|x| *x <= 64 + encryption_overhead) +- ); ++ assert!(pipe.writevs[0] ++ .iter() ++ .all(|x| *x <= 64 + encryption_overhead)); + } + + client.process_new_packets().unwrap(); +@@ -6074,11 +5643,9 @@ + let mut pipe = OtherSession::new(&mut client); + server.write_tls(&mut pipe).unwrap(); + assert_eq!(pipe.writevs.len(), 1); +- assert!( +- pipe.writevs[0] +- .iter() +- .all(|x| *x <= 64 + encryption_overhead) +- ); ++ assert!(pipe.writevs[0] ++ .iter() ++ .all(|x| *x <= 64 + encryption_overhead)); + } + + client.process_new_packets().unwrap(); +@@ -6206,14 +5773,14 @@ + let (mut client, mut server) = make_pair_for_configs( + make_client_config_with_kx_groups( + KeyType::Rsa2048, +- vec![provider::kx_group::X25519], ++ vec![&provider::kx::X25519], + &provider::default_provider(), + ), + finish_server_config( + KeyType::Rsa2048, + ServerConfig::builder_with_provider( + CryptoProvider { +- kx_groups: vec![provider::kx_group::SECP384R1], ++ kx_groups: vec![&provider::kx::SECP384R1], + ..provider::default_provider() + } + .into(), +@@ -6341,9 +5908,7 @@ + + let server_config = Arc::new(make_server_config(KeyType::Ed25519, &provider)); + let mut acceptor = Acceptor::default(); +- acceptor +- .read_tls(&mut buf.as_slice()) +- .unwrap(); ++ acceptor.read_tls(&mut buf.as_slice()).unwrap(); + let accepted = acceptor.accept().unwrap().unwrap(); + let ch = accepted.client_hello(); + assert_eq!(ch.server_name(), Some("localhost")); +@@ -6356,18 +5921,12 @@ + .collect::>() + ); + +- let server = accepted +- .into_connection(server_config) +- .unwrap(); ++ let server = accepted.into_connection(server_config).unwrap(); + assert!(server.wants_write()); + + // Reusing an acceptor is not allowed + assert_eq!( +- acceptor +- .read_tls(&mut [0u8].as_ref()) +- .err() +- .unwrap() +- .kind(), ++ acceptor.read_tls(&mut [0u8].as_ref()).err().unwrap().kind(), + io::ErrorKind::Other, + ); + assert_eq!( +@@ -6377,14 +5936,10 @@ + + let mut acceptor = Acceptor::default(); + assert!(acceptor.accept().unwrap().is_none()); +- acceptor +- .read_tls(&mut &buf[..3]) +- .unwrap(); // incomplete message ++ acceptor.read_tls(&mut &buf[..3]).unwrap(); // incomplete message + assert!(acceptor.accept().unwrap().is_none()); + +- acceptor +- .read_tls(&mut [0x80, 0x00].as_ref()) +- .unwrap(); // invalid message (len = 32k bytes) ++ acceptor.read_tls(&mut [0x80, 0x00].as_ref()).unwrap(); // invalid message (len = 32k bytes) + let (err, mut alert) = acceptor.accept().unwrap_err(); + assert_eq!(err, Error::InvalidMessage(InvalidMessage::MessageTooLarge)); + let mut alert_content = Vec::new(); +@@ -6454,16 +6009,12 @@ + .unwrap(), + ); + let mut acceptor = Acceptor::default(); +- acceptor +- .read_tls(&mut buf.as_slice()) +- .unwrap(); ++ acceptor.read_tls(&mut buf.as_slice()).unwrap(); + let accepted = acceptor.accept().unwrap().unwrap(); + let ch = accepted.client_hello(); + assert_eq!(ch.server_name(), Some("localhost")); + +- let (err, mut alert) = accepted +- .into_connection(server_config.into()) +- .unwrap_err(); ++ let (err, mut alert) = accepted.into_connection(server_config.into()).unwrap_err(); + assert_eq!( + err, + Error::PeerIncompatible(PeerIncompatible::Tls12NotOfferedOrEnabled) +@@ -6526,14 +6077,14 @@ + let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); + for suite in [ +- cipher_suite::TLS13_AES_128_GCM_SHA256, +- cipher_suite::TLS13_AES_256_GCM_SHA384, ++ TLS13_AES_128_GCM_SHA256, ++ TLS13_AES_256_GCM_SHA384, + #[cfg(not(feature = "fips"))] +- cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, +- cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +- cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, ++ TLS13_CHACHA20_POLY1305_SHA256, ++ TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + #[cfg(not(feature = "fips"))] +- cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, ++ TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + ] { + let version = suite.version(); + println!("Testing suite {:?}", suite.suite().as_str()); +@@ -6564,12 +6115,8 @@ + do_handshake(&mut client, &mut server); + + // The handshake is finished, we're now able to extract traffic secrets +- let client_secrets = client +- .dangerous_extract_secrets() +- .unwrap(); +- let server_secrets = server +- .dangerous_extract_secrets() +- .unwrap(); ++ let client_secrets = client.dangerous_extract_secrets().unwrap(); ++ let server_secrets = server.dangerous_extract_secrets().unwrap(); + + // Comparing secrets for equality is something you should never have to + // do in production code, so ConnectionTrafficSecrets doesn't implement +@@ -6632,12 +6179,8 @@ + + do_handshake(&mut client, &mut server); + +- let client_secrets = client +- .dangerous_extract_secrets() +- .unwrap(); +- let server_secrets = server +- .dangerous_extract_secrets() +- .unwrap(); ++ let client_secrets = client.dangerous_extract_secrets().unwrap(); ++ let server_secrets = server.dangerous_extract_secrets().unwrap(); + + assert!(f(client_secrets.tx.1)); + assert!(f(client_secrets.rx.1)); +@@ -6645,26 +6188,25 @@ + assert!(f(server_secrets.rx.1)); + } + +- check(cipher_suite::TLS13_AES_128_GCM_SHA256, |sec| { ++ check(TLS13_AES_128_GCM_SHA256, |sec| { + matches!(sec, ConnectionTrafficSecrets::Aes128Gcm { .. }) + }); +- check(cipher_suite::TLS13_AES_256_GCM_SHA384, |sec| { ++ check(TLS13_AES_256_GCM_SHA384, |sec| { + matches!(sec, ConnectionTrafficSecrets::Aes256Gcm { .. }) + }); +- check(cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, |sec| { ++ check(TLS13_CHACHA20_POLY1305_SHA256, |sec| { + matches!(sec, ConnectionTrafficSecrets::Chacha20Poly1305 { .. }) + }); + +- check(cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, |sec| { ++ check(TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, |sec| { + matches!(sec, ConnectionTrafficSecrets::Aes128Gcm { .. }) + }); +- check(cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, |sec| { ++ check(TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, |sec| { + matches!(sec, ConnectionTrafficSecrets::Aes256Gcm { .. }) + }); +- check( +- cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, +- |sec| matches!(sec, ConnectionTrafficSecrets::Chacha20Poly1305 { .. }), +- ); ++ check(TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, |sec| { ++ matches!(sec, ConnectionTrafficSecrets::Chacha20Poly1305 { .. }) ++ }); + } + + /// Test that secrets cannot be extracted unless explicitly enabled, and until +@@ -6674,7 +6216,7 @@ + fn test_secret_extraction_disabled_or_too_early() { + let kt = KeyType::Rsa2048; + let provider = Arc::new(CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_AES_128_GCM_SHA256], ++ cipher_suites: vec![TLS13_AES_128_GCM_SHA256], + ..provider::default_provider() + }); + +@@ -6696,15 +6238,11 @@ + let (client, server) = make_pair_for_arc_configs(&client_config, &server_config); + + assert!( +- client +- .dangerous_extract_secrets() +- .is_err(), ++ client.dangerous_extract_secrets().is_err(), + "extraction should fail until handshake completes" + ); + assert!( +- server +- .dangerous_extract_secrets() +- .is_err(), ++ server.dangerous_extract_secrets().is_err(), + "extraction should fail until handshake completes" + ); + +@@ -6712,18 +6250,8 @@ + + do_handshake(&mut client, &mut server); + +- assert_eq!( +- server_enable, +- server +- .dangerous_extract_secrets() +- .is_ok() +- ); +- assert_eq!( +- client_enable, +- client +- .dangerous_extract_secrets() +- .is_ok() +- ); ++ assert_eq!(server_enable, server.dangerous_extract_secrets().is_ok()); ++ assert_eq!(client_enable, client.dangerous_extract_secrets().is_ok()); + } + } + +@@ -6735,7 +6263,7 @@ + let server_config = Arc::new( + ServerConfig::builder_with_provider( + CryptoProvider { +- cipher_suites: vec![cipher_suite::TLS13_AES_128_GCM_SHA256], ++ cipher_suites: vec![TLS13_AES_128_GCM_SHA256], + ..provider.clone() + } + .into(), +@@ -6753,25 +6281,12 @@ + + // Fill the server's received plaintext buffer with 16k bytes + let client_buf = [0; 16_385]; +- dbg!( +- client +- .writer() +- .write(&client_buf) +- .unwrap() +- ); ++ dbg!(client.writer().write(&client_buf).unwrap()); + let mut network_buf = Vec::with_capacity(32_768); +- let sent = dbg!( +- client +- .write_tls(&mut network_buf) +- .unwrap() +- ); ++ let sent = dbg!(client.write_tls(&mut network_buf).unwrap()); + let mut read = 0; + while read < sent { +- let new = dbg!( +- server +- .read_tls(&mut &network_buf[read..sent]) +- .unwrap() +- ); ++ let new = dbg!(server.read_tls(&mut &network_buf[read..sent]).unwrap()); + if new == 4096 { + read += new; + } else { +@@ -6781,38 +6296,17 @@ + server.process_new_packets().unwrap(); + + // Send two more bytes from client to server +- dbg!( +- client +- .writer() +- .write(&client_buf[..2]) +- .unwrap() +- ); +- let sent = dbg!( +- client +- .write_tls(&mut network_buf) +- .unwrap() +- ); ++ dbg!(client.writer().write(&client_buf[..2]).unwrap()); ++ let sent = dbg!(client.write_tls(&mut network_buf).unwrap()); + + // Get an error because the received plaintext buffer is full +- assert!( +- server +- .read_tls(&mut &network_buf[..sent]) +- .is_err() +- ); ++ assert!(server.read_tls(&mut &network_buf[..sent]).is_err()); + + // Read out some of the plaintext +- server +- .reader() +- .read_exact(&mut [0; 2]) +- .unwrap(); ++ server.reader().read_exact(&mut [0; 2]).unwrap(); + + // Now there's room again in the plaintext buffer +- assert_eq!( +- server +- .read_tls(&mut &network_buf[..sent]) +- .unwrap(), +- 24 +- ); ++ assert_eq!(server.read_tls(&mut &network_buf[..sent]).unwrap(), 24); + } + + #[test] +@@ -7160,10 +6654,7 @@ + + let mut stream = FakeStream(client_hello_followed_by_close_notify_alert); + assert_eq!( +- server +- .complete_io(&mut stream) +- .unwrap_err() +- .kind(), ++ server.complete_io(&mut stream).unwrap_err().kind(), + io::ErrorKind::UnexpectedEof + ); + } +@@ -7172,17 +6663,11 @@ + fn test_complete_io_with_no_io_needed() { + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); +- client +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ client.writer().write_all(b"hello").unwrap(); + client.send_close_notify(); + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); +- server +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ server.writer().write_all(b"hello").unwrap(); + server.send_close_notify(); + transfer(&mut server, &mut client); + client.process_new_packets().unwrap(); +@@ -7192,28 +6677,15 @@ + assert!(!client.wants_read()); + assert!(!server.wants_write()); + assert!(!server.wants_read()); +- assert_eq!( +- client +- .complete_io(&mut FakeStream(&[])) +- .unwrap(), +- (0, 0) +- ); +- assert_eq!( +- server +- .complete_io(&mut FakeStream(&[])) +- .unwrap(), +- (0, 0) +- ); ++ assert_eq!(client.complete_io(&mut FakeStream(&[])).unwrap(), (0, 0)); ++ assert_eq!(server.complete_io(&mut FakeStream(&[])).unwrap(), (0, 0)); + } + + #[test] + fn test_junk_after_close_notify_received() { + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); +- client +- .writer() +- .write_all(b"hello") +- .unwrap(); ++ client.writer().write_all(b"hello").unwrap(); + client.send_close_notify(); + + let mut client_buffer = vec![]; +@@ -7233,20 +6705,11 @@ + + // can read data received prior to close_notify + let mut received_data = [0u8; 128]; +- let len = server +- .reader() +- .read(&mut received_data) +- .unwrap(); ++ let len = server.reader().read(&mut received_data).unwrap(); + assert_eq!(&received_data[..len], b"hello"); + + // but subsequent reads just report clean EOF +- assert_eq!( +- server +- .reader() +- .read(&mut received_data) +- .unwrap(), +- 0 +- ); ++ assert_eq!(server.reader().read(&mut received_data).unwrap(), 0); + } + + #[test] +@@ -7254,31 +6717,16 @@ + let (mut client, mut server) = make_pair(KeyType::Rsa2048, &provider::default_provider()); + do_handshake(&mut client, &mut server); + +- client +- .writer() +- .write_all(b"before") +- .unwrap(); ++ client.writer().write_all(b"before").unwrap(); + client.send_close_notify(); +- client +- .writer() +- .write_all(b"after") +- .unwrap(); ++ client.writer().write_all(b"after").unwrap(); + transfer(&mut client, &mut server); + server.process_new_packets().unwrap(); + + let mut received_data = [0u8; 128]; +- let count = server +- .reader() +- .read(&mut received_data) +- .unwrap(); ++ let count = server.reader().read(&mut received_data).unwrap(); + assert_eq!(&received_data[..count], b"before"); +- assert_eq!( +- server +- .reader() +- .read(&mut received_data) +- .unwrap(), +- 0 +- ); ++ assert_eq!(server.reader().read(&mut received_data).unwrap(), 0); + } + + #[test] +@@ -7326,12 +6774,7 @@ + server.process_new_packets().unwrap(); + + let buf = [1, 2, 3, 4]; +- assert_eq!( +- server +- .read_tls(&mut io::Cursor::new(buf)) +- .unwrap(), +- 0 +- ); ++ assert_eq!(server.read_tls(&mut io::Cursor::new(buf)).unwrap(), 0); + } + + #[test] +@@ -7698,9 +7141,7 @@ + )), + }; + raw_server.encrypt_and_send(&msg, &mut client); +- let err = client +- .process_new_packets() +- .unwrap_err(); ++ let err = client.process_new_packets().unwrap_err(); + assert_eq!( + err, + Error::InappropriateHandshakeMessage { +@@ -7744,9 +7185,7 @@ + // second is fatal + raw_server.encrypt_and_send(&msg, &mut client); + assert_eq!( +- client +- .process_new_packets() +- .unwrap_err(), ++ client.process_new_packets().unwrap_err(), + Error::PeerMisbehaved(PeerMisbehaved::TooManyRenegotiationRequests) + ); + } +@@ -7770,9 +7209,7 @@ + payload: Payload::new(encoding::basic_client_hello(vec![])), + }; + raw_client.encrypt_and_send(&msg, &mut server); +- let err = server +- .process_new_packets() +- .unwrap_err(); ++ let err = server.process_new_packets().unwrap_err(); + assert_eq!( + format!("{err:?}"), + "InappropriateHandshakeMessage { expect_types: [KeyUpdate], got_type: ClientHello }" +@@ -7800,9 +7237,7 @@ + .read_tls(&mut io::Cursor::new(&client_hello)) + .unwrap(); + assert_eq!( +- server +- .process_new_packets() +- .unwrap_err(), ++ server.process_new_packets().unwrap_err(), + Error::InappropriateHandshakeMessage { + expect_types: vec![HandshakeType::ClientKeyExchange], + got_type: HandshakeType::ClientHello +@@ -7814,15 +7249,11 @@ + fn test_refresh_traffic_keys_during_handshake() { + let (mut client, mut server) = make_pair(KeyType::Ed25519, &provider::default_provider()); + assert_eq!( +- client +- .refresh_traffic_keys() +- .unwrap_err(), ++ client.refresh_traffic_keys().unwrap_err(), + Error::HandshakeNotComplete + ); + assert_eq!( +- server +- .refresh_traffic_keys() +- .unwrap_err(), ++ server.refresh_traffic_keys().unwrap_err(), + Error::HandshakeNotComplete + ); + } +@@ -7833,14 +7264,8 @@ + do_handshake(&mut client, &mut server); + + fn check_both_directions(client: &mut ClientConnection, server: &mut ServerConnection) { +- client +- .writer() +- .write_all(b"to-server-1") +- .unwrap(); +- server +- .writer() +- .write_all(b"to-client-1") +- .unwrap(); ++ client.writer().write_all(b"to-server-1").unwrap(); ++ server.writer().write_all(b"to-client-1").unwrap(); + transfer(client, server); + server.process_new_packets().unwrap(); + +@@ -7892,10 +7317,7 @@ + + for i in 0..(CONFIDENTIALITY_LIMIT + 16) { + let message = format!("{i:08}"); +- client +- .writer() +- .write_all(message.as_bytes()) +- .unwrap(); ++ client.writer().write_all(message.as_bytes()).unwrap(); + let transferred = transfer(&mut client, &mut server); + println!( + "{}: {} -> {:?}", +@@ -7920,10 +7342,7 @@ + + // finally, server writes and pumps its key_update response + let message = b"finished"; +- server +- .writer() +- .write_all(message) +- .unwrap(); ++ server.writer().write_all(message).unwrap(); + let transferred = transfer(&mut server, &mut client); + + println!( +@@ -7957,10 +7376,7 @@ + + for i in 0..CONFIDENTIALITY_LIMIT { + let message = format!("{i:08}"); +- client +- .writer() +- .write_all(message.as_bytes()) +- .unwrap(); ++ client.writer().write_all(message.as_bytes()).unwrap(); + let transferred = transfer(&mut client, &mut server); + println!( + "{}: {} -> {:?}", +@@ -7978,7 +7394,7 @@ + } + } + } +- ++#[ignore] + #[test] + fn test_keys_match_for_all_signing_key_types() { + let provider = provider::default_provider(); +@@ -8017,14 +7433,10 @@ + ClientConnection::new(Arc::new(client_config), server_name("localhost")).unwrap(); + + let mut hello = Vec::new(); +- client +- .write_tls(&mut io::Cursor::new(&mut hello)) +- .unwrap(); ++ client.write_tls(&mut io::Cursor::new(&mut hello)).unwrap(); + + let first_flight = include_bytes!("data/bug2040-message-1.bin"); +- client +- .read_tls(&mut io::Cursor::new(first_flight)) +- .unwrap(); ++ client.read_tls(&mut io::Cursor::new(first_flight)).unwrap(); + client.process_new_packets().unwrap(); + + let second_flight = include_bytes!("data/bug2040-message-2.bin"); +@@ -8032,9 +7444,7 @@ + .read_tls(&mut io::Cursor::new(second_flight)) + .unwrap(); + assert_eq!( +- client +- .process_new_packets() +- .unwrap_err(), ++ client.process_new_packets().unwrap_err(), + Error::InvalidCertificate(CertificateError::UnknownIssuer), + ); + } +@@ -8074,7 +7484,7 @@ + kt, + ClientConfig::builder_with_provider( + CryptoProvider { +- kx_groups: vec![&FakeHybrid, provider::kx_group::SECP384R1], ++ kx_groups: vec![&FakeHybrid, &provider::kx::SECP384R1], + ..provider::default_provider() + } + .into(), +@@ -8093,9 +7503,7 @@ + server.process_new_packets().unwrap(); + transfer(&mut server, &mut client_1); + assert_eq!( +- client_1 +- .process_new_packets() +- .unwrap_err(), ++ client_1.process_new_packets().unwrap_err(), + PeerMisbehaved::WrongGroupForKeyShare.into() + ); + } +@@ -8121,7 +7529,7 @@ + } + + fn hybrid_component(&self) -> Option<(NamedGroup, &[u8])> { +- Some((provider::kx_group::SECP384R1.name(), b"classical")) ++ Some((provider::kx::SECP384R1.name(), b"classical")) + } + + fn pub_key(&self) -> &[u8] { +--- a/rustls/tests/api_ffdhe.rs ++++ b/rustls/tests/api_ffdhe.rs +@@ -3,19 +3,19 @@ + #![allow(clippy::duplicate_mod)] + + mod common; ++use super::*; + use common::*; + use rustls::crypto::CryptoProvider; ++use rustls::crypto::SupportedKxGroup; + use rustls::version::{TLS12, TLS13}; + use rustls::{CipherSuite, ClientConfig, NamedGroup}; + +-use super::*; +- + #[test] + fn config_builder_for_client_rejects_cipher_suites_without_compatible_kx_groups() { + let bad_crypto_provider = CryptoProvider { + kx_groups: vec![&ffdhe::FFDHE2048_KX_GROUP], + cipher_suites: vec![ +- provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, + ], + ..provider::default_provider() +@@ -36,12 +36,11 @@ + + #[test] + fn ffdhe_ciphersuite() { +- use provider::cipher_suite; + use rustls::version::{TLS12, TLS13}; + + let test_cases = [ + (&TLS12, ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256), +- (&TLS13, cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), ++ (&TLS13, provider::TLS13_CHACHA20_POLY1305_SHA256), + ]; + + for (expected_protocol, expected_cipher_suite) in test_cases { +@@ -77,9 +76,9 @@ + CryptoProvider { + cipher_suites: vec![ + ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +- provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ], +- kx_groups: vec![&ffdhe::FFDHE4096_KX_GROUP, provider::kx_group::SECP256R1], ++ kx_groups: vec![&ffdhe::FFDHE4096_KX_GROUP, &provider::kx::SECP256R1], + ..provider::default_provider() + } + .into(), +@@ -94,9 +93,9 @@ + CryptoProvider { + cipher_suites: vec![ + ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +- provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ], +- kx_groups: vec![&ffdhe::FFDHE2048_KX_GROUP, provider::kx_group::SECP256R1], ++ kx_groups: vec![&ffdhe::FFDHE2048_KX_GROUP, &provider::kx::SECP256R1], + ..provider::default_provider() + } + .into(), +@@ -108,17 +107,11 @@ + let (mut client, mut server) = make_pair_for_configs(client_config, server_config); + do_handshake(&mut client, &mut server); + assert_eq!( +- server +- .negotiated_cipher_suite() +- .unwrap() +- .suite(), ++ server.negotiated_cipher_suite().unwrap().suite(), + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + ); + assert_eq!( +- server +- .negotiated_key_exchange_group() +- .unwrap() +- .name(), ++ server.negotiated_key_exchange_group().unwrap().name(), + NamedGroup::secp256r1, + ) + } +@@ -130,11 +123,11 @@ + rustls::ServerConfig::builder_with_provider( + CryptoProvider { + cipher_suites: vec![ +- provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +- provider::cipher_suite::TLS13_AES_128_GCM_SHA256, ++ provider::TLS13_AES_128_GCM_SHA256, + ], +- kx_groups: vec![provider::kx_group::SECP256R1, &ffdhe::FFDHE2048_KX_GROUP], ++ kx_groups: vec![&provider::kx::SECP256R1, &ffdhe::FFDHE2048_KX_GROUP], + ..provider::default_provider() + } + .into(), +@@ -149,8 +142,8 @@ + // TLS 1.2, have common + vec![ + // this matches: +- provider::kx_group::SECP256R1, +- &ffdhe::FFDHE2048_KX_GROUP, ++ &provider::kx::SECP256R1, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[0], + ], + &TLS12, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +@@ -159,8 +152,8 @@ + ( + vec![ + // this matches: +- provider::kx_group::SECP256R1, +- &ffdhe::FFDHE3072_KX_GROUP, ++ &provider::kx::SECP256R1, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[1], + ], + &TLS12, + CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +@@ -168,9 +161,9 @@ + ), + ( + vec![ +- provider::kx_group::SECP384R1, ++ &provider::kx::SECP384R1, + // this matches: +- &ffdhe::FFDHE2048_KX_GROUP, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[0], + ], + &TLS12, + CipherSuite::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +@@ -180,8 +173,8 @@ + // TLS 1.3, have common + vec![ + // this matches: +- provider::kx_group::SECP256R1, +- &ffdhe::FFDHE2048_KX_GROUP, ++ &provider::kx::SECP256R1, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[0], + ], + &TLS13, + CipherSuite::TLS13_AES_128_GCM_SHA256, +@@ -190,8 +183,8 @@ + ( + vec![ + // this matches: +- provider::kx_group::SECP256R1, +- &ffdhe::FFDHE3072_KX_GROUP, ++ &provider::kx::SECP256R1, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[1], + ], + &TLS13, + CipherSuite::TLS13_AES_128_GCM_SHA256, +@@ -199,9 +192,9 @@ + ), + ( + vec![ +- provider::kx_group::SECP384R1, ++ &provider::kx::SECP384R1, + // this matches: +- &ffdhe::FFDHE2048_KX_GROUP, ++ ffdhe::FFDHE_KX_GROUPS.to_vec().as_slice()[0], + ], + &TLS13, + CipherSuite::TLS13_AES_128_GCM_SHA256, +@@ -215,9 +208,9 @@ + rustls::ClientConfig::builder_with_provider( + CryptoProvider { + cipher_suites: vec![ +- provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + ffdhe::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +- provider::cipher_suite::TLS13_AES_128_GCM_SHA256, ++ provider::TLS13_AES_128_GCM_SHA256, + ], + kx_groups: client_kx_groups, + ..provider::default_provider() +@@ -232,17 +225,12 @@ + let (mut client, mut server) = make_pair_for_arc_configs(&client_config, &server_config); + do_handshake(&mut client, &mut server); + assert_eq!( +- server +- .negotiated_cipher_suite() +- .unwrap() +- .suite(), ++ server.negotiated_cipher_suite().unwrap().suite(), + expected_cipher_suite + ); + assert_eq!(server.protocol_version(), Some(protocol_version.version)); + assert_eq!( +- server +- .negotiated_key_exchange_group() +- .map(|kx| kx.name()), ++ server.negotiated_key_exchange_group().map(|kx| kx.name()), + expected_group, + ); + } +@@ -250,7 +238,7 @@ + + #[test] + fn non_ffdhe_kx_does_not_have_ffdhe_group() { +- let non_ffdhe = provider::kx_group::SECP256R1; ++ let non_ffdhe = &provider::kx::SECP256R1; + assert_eq!(non_ffdhe.ffdhe_group(), None); + let active = non_ffdhe.start().unwrap(); + assert_eq!(active.ffdhe_group(), None); +@@ -263,7 +251,7 @@ + SupportedKxGroup, + }; + use rustls::ffdhe_groups::FfdheGroup; +- use rustls::{CipherSuite, NamedGroup, SupportedCipherSuite, Tls12CipherSuite, ffdhe_groups}; ++ use rustls::{ffdhe_groups, CipherSuite, NamedGroup, SupportedCipherSuite, Tls12CipherSuite}; + + use super::provider; + +@@ -276,7 +264,8 @@ + } + } + +- static FFDHE_KX_GROUPS: &[&dyn SupportedKxGroup] = &[&FFDHE2048_KX_GROUP, &FFDHE3072_KX_GROUP]; ++ pub static FFDHE_KX_GROUPS: &[&dyn SupportedKxGroup] = ++ &[&FFDHE2048_KX_GROUP, &FFDHE3072_KX_GROUP]; + + pub const FFDHE2048_KX_GROUP: FfdheKxGroup = + FfdheKxGroup(NamedGroup::FFDHE2048, ffdhe_groups::FFDHE2048); +@@ -287,7 +276,7 @@ + + static FFDHE_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = &[ + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +- provider::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, ++ provider::TLS13_CHACHA20_POLY1305_SHA256, + ]; + + /// The (test-only) TLS1.2 ciphersuite TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 +@@ -295,7 +284,7 @@ + SupportedCipherSuite::Tls12(&TLS12_DHE_RSA_WITH_AES_128_GCM_SHA256); + + static TLS12_DHE_RSA_WITH_AES_128_GCM_SHA256: Tls12CipherSuite = +- match &provider::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 { ++ match &provider::TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256 { + SupportedCipherSuite::Tls12(provider) => Tls12CipherSuite { + common: CipherSuiteCommon { + suite: CipherSuite::TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, +@@ -313,9 +302,7 @@ + impl SupportedKxGroup for FfdheKxGroup { + fn start(&self) -> Result, rustls::Error> { + let mut x = vec![0; 64]; +- ffdhe_provider() +- .secure_random +- .fill(&mut x)?; ++ ffdhe_provider().secure_random.fill(&mut x)?; + let x = BigUint::from_bytes_be(&x); + + let p = BigUint::from_bytes_be(self.1.p); +--- a/rustls/tests/client_cert_verifier.rs ++++ b/rustls/tests/client_cert_verifier.rs +@@ -7,10 +7,9 @@ + mod common; + + use common::{ +- Arc, ErrorFromPeer, KeyType, MockClientVerifier, do_handshake_until_both_error, +- do_handshake_until_error, make_client_config_with_versions, ++ do_handshake_until_both_error, do_handshake_until_error, make_client_config_with_versions, + make_client_config_with_versions_with_auth, make_pair_for_arc_configs, server_config_builder, +- server_name, ++ server_name, Arc, ErrorFromPeer, KeyType, MockClientVerifier, + }; + use rustls::server::danger::ClientCertVerified; + use rustls::{ +--- a/rustls/tests/ech.rs ++++ b/rustls/tests/ech.rs +@@ -1,4 +1,4 @@ +-use base64::prelude::{BASE64_STANDARD, Engine}; ++use base64::prelude::{Engine, BASE64_STANDARD}; + use pki_types::DnsName; + use rustls::internal::msgs::codec::{Codec, Reader}; + use rustls::internal::msgs::enums::{EchVersion, HpkeAead, HpkeKdf, HpkeKem}; +--- a/rustls/tests/key_log_file_env.rs ++++ b/rustls/tests/key_log_file_env.rs +@@ -30,8 +30,8 @@ + + mod common; + use common::{ +- Arc, KeyType, do_handshake, make_client_config_with_versions, make_pair_for_arc_configs, +- make_server_config, transfer, ++ do_handshake, make_client_config_with_versions, make_pair_for_arc_configs, make_server_config, ++ transfer, Arc, KeyType, + }; + + #[test] +--- a/rustls/tests/process_provider.rs ++++ b/rustls/tests/process_provider.rs +@@ -4,14 +4,14 @@ + //! executable, and runs tests in an indeterminate order. That restricts us + //! to doing all the desired tests, in series, in one function. + +-use rustls::ClientConfig; +-use rustls::crypto::CryptoProvider; + #[cfg(all(feature = "aws_lc_rs", not(feature = "ring")))] + use rustls::crypto::aws_lc_rs as provider; + #[cfg(all(feature = "ring", not(feature = "aws_lc_rs")))] + use rustls::crypto::ring as provider; + #[cfg(all(feature = "ring", feature = "aws_lc_rs"))] + use rustls::crypto::ring as provider; ++use rustls::crypto::CryptoProvider; ++use rustls::ClientConfig; + + mod common; + use crate::common::*; +--- a/rustls/tests/server_cert_verifier.rs ++++ b/rustls/tests/server_cert_verifier.rs +@@ -7,13 +7,13 @@ + mod common; + + use common::{ +- Arc, ErrorFromPeer, KeyType, MockServerVerifier, client_config_builder, do_handshake, +- do_handshake_until_both_error, do_handshake_until_error, make_client_config_with_versions, +- make_pair_for_arc_configs, make_server_config, server_config_builder, ++ client_config_builder, do_handshake, do_handshake_until_both_error, do_handshake_until_error, ++ make_client_config_with_versions, make_pair_for_arc_configs, make_server_config, ++ server_config_builder, Arc, ErrorFromPeer, KeyType, MockServerVerifier, + }; + use pki_types::{CertificateDer, ServerName}; +-use rustls::client::WebPkiServerVerifier; + use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier}; ++use rustls::client::WebPkiServerVerifier; + use rustls::server::{ClientHello, ResolvesServerCert}; + use rustls::sign::CertifiedKey; + use rustls::{ +@@ -87,9 +87,7 @@ + Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), + )); + +- client_config +- .dangerous() +- .set_certificate_verifier(verifier); ++ client_config.dangerous().set_certificate_verifier(verifier); + + let server_config = Arc::new(make_server_config(*kt, &provider)); + +@@ -121,9 +119,7 @@ + Error::InvalidMessage(InvalidMessage::HandshakePayloadTooLarge), + )); + +- client_config +- .dangerous() +- .set_certificate_verifier(verifier); ++ client_config.dangerous().set_certificate_verifier(verifier); + + let server_config = Arc::new(make_server_config(*kt, &provider)); + +@@ -183,11 +179,8 @@ + .iter() + .map(|kt| { + ( +- kt.ca_distinguished_name() +- .to_vec() +- .into(), +- kt.certified_key_with_cert_chain(&provider) +- .unwrap(), ++ kt.ca_distinguished_name().to_vec().into(), ++ kt.certified_key_with_cert_chain(&provider).unwrap(), + ) + }) + .collect(), +@@ -203,9 +196,7 @@ + + for key_type in key_types { + let mut root_store = RootCertStore::empty(); +- root_store +- .add(key_type.ca_cert()) +- .unwrap(); ++ root_store.add(key_type.ca_cert()).unwrap(); + let server_verifier = WebPkiServerVerifier::builder_with_provider( + Arc::new(root_store), + Arc::new(provider.clone()), +@@ -216,9 +207,7 @@ + let cas_sending_server_verifier = Arc::new(ServerCertVerifierWithCasExt { + verifier: server_verifier.clone(), + ca_names: vec![DistinguishedName::from( +- key_type +- .ca_distinguished_name() +- .to_vec(), ++ key_type.ca_distinguished_name().to_vec(), + )], + }); + +@@ -270,9 +259,7 @@ + return Some(self.0[0].1.clone()); + }; + for (name, certified_key) in self.0.iter() { +- let name = X509Name::from_der(name.as_ref()) +- .unwrap() +- .1; ++ let name = X509Name::from_der(name.as_ref()).unwrap().1; + if cas_extension.iter().any(|ca_name| { + X509Name::from_der(ca_name.as_ref()).is_ok_and(|(_, ca_name)| ca_name == name) + }) { +@@ -310,8 +297,7 @@ + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { +- self.verifier +- .verify_tls12_signature(message, cert, dss) ++ self.verifier.verify_tls12_signature(message, cert, dss) + } + + fn verify_tls13_signature( +@@ -320,8 +306,7 @@ + cert: &CertificateDer<'_>, + dss: &DigitallySignedStruct, + ) -> Result { +- self.verifier +- .verify_tls13_signature(message, cert, dss) ++ self.verifier.verify_tls13_signature(message, cert, dss) + } + + fn supported_verify_schemes(&self) -> Vec { +--- a/rustls/tests/unbuffered.rs ++++ b/rustls/tests/unbuffered.rs +@@ -19,7 +19,10 @@ + + mod common; + use common::*; +-use provider::cipher_suite; ++use rustls_wolfcrypt_provider::{ ++ TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, ++ TLS13_AES_128_GCM_SHA256, TLS13_AES_256_GCM_SHA384, ++}; + + const MAX_ITERATIONS: usize = 100; + +@@ -151,17 +154,8 @@ + &mut NO_ACTIONS.clone(), + ); + +- assert!( +- client_actions +- .app_data_to_send +- .is_none() +- ); +- assert_eq!( +- [expected], +- outcome +- .server_received_app_data +- .as_slice() +- ); ++ assert!(client_actions.app_data_to_send.is_none()); ++ assert_eq!([expected], outcome.server_received_app_data.as_slice()); + } + } + +@@ -187,17 +181,8 @@ + &mut server_actions, + ); + +- assert!( +- server_actions +- .app_data_to_send +- .is_none() +- ); +- assert_eq!( +- [expected], +- outcome +- .client_received_app_data +- .as_slice() +- ); ++ assert!(server_actions.app_data_to_send.is_none()); ++ assert_eq!([expected], outcome.client_received_app_data.as_slice()); + } + } + +@@ -223,13 +208,7 @@ + &mut NO_ACTIONS.clone(), + ); + +- assert_eq!( +- outcome +- .client +- .unwrap() +- .tls13_tickets_received(), +- 2 +- ); ++ assert_eq!(outcome.client.unwrap().tls13_tickets_received(), 2); + + let mut client_actions = Actions { + early_data_to_send: Some(expected), +@@ -273,17 +252,8 @@ + "WriteTraffic" + ] + ); +- assert!( +- client_actions +- .early_data_to_send +- .is_none() +- ); +- assert_eq!( +- [expected], +- outcome +- .server_received_early_data +- .as_slice() +- ); ++ assert!(client_actions.early_data_to_send.is_none()); ++ assert_eq!([expected], outcome.server_received_early_data.as_slice()); + } + + fn run( +@@ -350,9 +320,7 @@ + client_handshake_done = true; + } + State::ReceivedAppData { records } => { +- outcome +- .client_received_app_data +- .extend(records); ++ outcome.client_received_app_data.extend(records); + } + State::PeerClosed => { + outcome.client_saw_peer_closed_state = true; +@@ -401,14 +369,10 @@ + server_handshake_done = true; + } + State::ReceivedEarlyData { records } => { +- outcome +- .server_received_early_data +- .extend(records); ++ outcome.server_received_early_data.extend(records); + } + State::ReceivedAppData { records } => { +- outcome +- .server_received_app_data +- .extend(records); ++ outcome.server_received_app_data.extend(records); + } + State::PeerClosed => { + outcome.server_saw_peer_closed_state = true; +@@ -560,13 +524,11 @@ + let mut server = outcome.server.take().unwrap(); + + let mut client_send_buf = [0u8; 128]; +- let mut len = dbg!( +- write_traffic( +- client.process_tls_records(&mut []), +- |mut wt: WriteTraffic<_>| wt.queue_close_notify(&mut client_send_buf), +- ) +- .unwrap() +- ); ++ let mut len = dbg!(write_traffic( ++ client.process_tls_records(&mut []), ++ |mut wt: WriteTraffic<_>| wt.queue_close_notify(&mut client_send_buf), ++ ) ++ .unwrap()); + + client_send_buf[len..len + junk.len()].copy_from_slice(junk); + len += junk.len(); +@@ -684,8 +646,7 @@ + state: Ok(ConnectionState::WriteTraffic(mut wt)), + } => { + assert_eq!(used, actual_used); +- wt.encrypt(b"hello", &mut buffer) +- .unwrap() ++ wt.encrypt(b"hello", &mut buffer).unwrap() + } + st => { + panic!("unexpected server state {st:?}"); +@@ -712,9 +673,7 @@ + UnbufferedStatus { + discard: 0, + state: Ok(ConnectionState::WriteTraffic(mut wt)), +- } => wt +- .encrypt(b"world", &mut buffer) +- .unwrap(), ++ } => wt.encrypt(b"world", &mut buffer).unwrap(), + st => { + panic!("unexpected client state {st:?}"); + } +@@ -779,9 +738,7 @@ + let message = format!("{i:08}"); + + let mut buffer = [0u8; 64]; +- let used = wt +- .encrypt(message.as_bytes(), &mut buffer) +- .unwrap(); ++ let used = wt.encrypt(message.as_bytes(), &mut buffer).unwrap(); + + assert_eq!( + used, +@@ -987,13 +944,7 @@ + // now consume + let (data, discard) = read_traffic( + server.process_tls_records(client_to_server_buf.filled()), +- |mut rt| { +- rt.next_record() +- .unwrap() +- .unwrap() +- .payload +- .to_vec() +- }, ++ |mut rt| rt.next_record().unwrap().unwrap().payload.to_vec(), + ); + assert_eq!(discard, 0); + assert_eq!(data, b"hello"); +@@ -1418,8 +1369,7 @@ + if discard != 0 { + assert!(discard <= self.used); + +- self.inner +- .copy_within(discard..self.used, 0); ++ self.inner.copy_within(discard..self.used, 0); + self.used -= discard; + } + } +@@ -1523,14 +1473,14 @@ + let kt = KeyType::Rsa2048; + let provider = provider::default_provider(); + for suite in [ +- cipher_suite::TLS13_AES_128_GCM_SHA256, +- cipher_suite::TLS13_AES_256_GCM_SHA384, ++ TLS13_AES_128_GCM_SHA256, ++ TLS13_AES_256_GCM_SHA384, + #[cfg(not(feature = "fips"))] +- cipher_suite::TLS13_CHACHA20_POLY1305_SHA256, +- cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, +- cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, ++ TLS13_CHACHA20_POLY1305_SHA256, ++ TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, ++ TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, + #[cfg(not(feature = "fips"))] +- cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, ++ TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + ] { + let version = suite.version(); + println!("Testing suite {:?}", suite.suite().as_str()); +@@ -1566,14 +1516,8 @@ + let server = outcome.server.take().unwrap(); + + // The handshake is finished, we're now able to extract traffic secrets +- let client_secrets = client +- .dangerous_into_kernel_connection() +- .unwrap() +- .0; +- let server_secrets = server +- .dangerous_into_kernel_connection() +- .unwrap() +- .0; ++ let client_secrets = client.dangerous_into_kernel_connection().unwrap().0; ++ let server_secrets = server.dangerous_into_kernel_connection().unwrap().0; + + // Comparing secrets for equality is something you should never have to + // do in production code, so ConnectionTrafficSecrets doesn't implement +@@ -1624,16 +1568,8 @@ + + do_unbuffered_handshake(&mut client, &mut server); + +- assert!( +- client +- .dangerous_into_kernel_connection() +- .is_err() +- ); +- assert!( +- server +- .dangerous_into_kernel_connection() +- .is_err() +- ); ++ assert!(client.dangerous_into_kernel_connection().is_err()); ++ assert!(server.dangerous_into_kernel_connection().is_err()); + } + + #[test] +--- a/rustls/tests/common/mod.rs ++++ b/rustls/tests/common/mod.rs +@@ -3,10 +3,10 @@ + + pub use std::sync::Arc; + +-use rustls::RootCertStore; + use rustls::client::{ClientConfig, ServerCertVerifierBuilder, WebPkiServerVerifier}; + use rustls::crypto::CryptoProvider; + use rustls::server::{ClientCertVerifierBuilder, ServerConfig, WebPkiClientVerifier}; ++use rustls::RootCertStore; + pub use rustls_test::*; + + pub fn server_config_builder( +--- a/rustls/tests/runners/macros.rs ++++ b/rustls/tests/runners/macros.rs +@@ -5,6 +5,30 @@ + //! and `rustls::crypto::aws_lc_rs` modules. + + #[allow(unused_macros)] ++macro_rules! provider_wolfcrypt { ++ () => { ++ #[allow(unused_imports)] ++ use rustls_wolfcrypt_provider as provider; ++ #[allow(dead_code)] ++ const fn provider_is_wolfcrypt() -> bool { ++ true ++ } ++ #[allow(dead_code)] ++ const fn provider_is_aws_lc_rs() -> bool { ++ false ++ } ++ #[allow(dead_code)] ++ const fn provider_is_ring() -> bool { ++ false ++ } ++ #[allow(dead_code)] ++ const fn provider_is_fips() -> bool { ++ false ++ } ++ }; ++} ++ ++#[allow(unused_macros)] + macro_rules! provider_ring { + () => { + #[allow(unused_imports)] diff --git a/.github/test-workspace/tests/Cargo.toml b/.github/test-workspace/tests/Cargo.toml new file mode 100644 index 0000000..e083fbc --- /dev/null +++ b/.github/test-workspace/tests/Cargo.toml @@ -0,0 +1,125 @@ +[package] +name = "tests" +version = "0.1.0" +edition = "2021" +rust-version = "1.71" +license = "Apache-2.0 OR ISC OR MIT" +description = "TLS and QUIC tests" +categories = ["network-programming", "cryptography"] +autobenches = false +autotests = false +exclude = ["src/testdata", "tests/**"] +build = "build.rs" + +[features] +default = ["rustls/std"] + +aws-lc-rs = ["rustls/aws-lc-rs"] # Alias because Cargo features commonly use `-` +aws_lc_rs = ["rustls/aws_lc_rs"] +brotli = ["rustls/brotli"] +custom-provider = [] +fips = ["rustls/fips"] +logging = ["log", "rustls/log"] +prefer-post-quantum = ["aws_lc_rs"] +read_buf = ["rustversion", "std"] +ring = ["rustls/ring"] +std = ["rustls/std"] +tls12 = [] +zlib = ["rustls/zlib"] +wolfcrypt-provider = ["dep:rustls-wolfcrypt-provider", "rustls-wolfcrypt-provider/quic"] + + +[build-dependencies] +rustversion = { version = "1.0.6", optional = true } + +[dependencies] +pki-types = { workspace = true } +rustls = { path = "../rustls/rustls"} +webpki = { workspace = true } +log = { workspace = true, optional = true } +time = "0.3.44" +rcgen = "0.14.5" +rustls-wolfcrypt-provider = {path = "../rustls-wolfcrypt-provider/rustls-wolfcrypt-provider", default-features = false, optional = true} + + +[dev-dependencies] +rustls-test = {workspace = true} +rustls = { path = "../rustls/rustls" } +macro_rules_attribute = { workspace = true } +rustls-wolfcrypt-provider = {path = "../rustls-wolfcrypt-provider/rustls-wolfcrypt-provider", default-features = false} +num-bigint = "0.4.6" +x509-parser = "0.18.0" +env_logger = "0.11.8" + +log = { workspace = true } +base64 = { workspace = true } + +[[bench]] +name = "benchmarks" +path = "../rustls/rustls/benches/benchmarks.rs" +harness = false +required-features = ["ring"] + +[[example]] +name = "test_ca" +path = "../rustls/rustls/examples/internal/test_ca.rs" + +[[test]] +name = "all_suites" +path = "runners/all_test_suites.rs" + + +[[test]] +name = "api" +path = "runners/api.rs" + +[[test]] +name = "api_ffdhe" +path = "runners/api_ffdhe.rs" +required-features = ["tls12"] + +[[test]] +name = "bogo" +path = "bogo.rs" + +[[test]] +name = "client_cert_verifier" +path = "runners/client_cert_verifier.rs" + +[[test]] +name = "ech" +path = "ech.rs" + +[[test]] +name = "key_log_file_env" +path = "runners/key_log_file_env.rs" + +[[test]] +name = "process_provider" +path = "process_provider.rs" + +[[test]] +name = "server_cert_verifier" +path = "runners/server_cert_verifier.rs" + +[[test]] +name = "unbuffered" +path = "runners/unbuffered.rs" + +[package.metadata.docs.rs] +# all non-default features except fips (cannot build on docs.rs environment) +features = ["read_buf", "ring"] +rustdoc-args = ["--cfg", "docsrs"] + +[package.metadata.cargo_check_external_types] +allowed_external_types = [ + # --- + "rustls_pki_types", + "rustls_pki_types::*", +] + +[package.metadata.cargo-semver-checks.lints] +enum_no_repr_variant_discriminant_changed = "warn" + +[lints.rust] +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(bench)", "cfg(coverage_nightly)", "cfg(read_buf)"] } diff --git a/.github/test-workspace/tests/build.rs b/.github/test-workspace/tests/build.rs new file mode 100644 index 0000000..8c0bd2a --- /dev/null +++ b/.github/test-workspace/tests/build.rs @@ -0,0 +1,13 @@ +//! This build script allows us to enable the `read_buf` language feature only +//! for Rust Nightly. +//! +//! See the comment in lib.rs to understand why we need this. + +#[cfg_attr(feature = "read_buf", rustversion::not(nightly))] +fn main() {} + +#[cfg(feature = "read_buf")] +#[rustversion::nightly] +fn main() { + println!("cargo:rustc-cfg=read_buf"); +} diff --git a/.github/test-workspace/tests/runners/all_test_suites.rs b/.github/test-workspace/tests/runners/all_test_suites.rs new file mode 100644 index 0000000..ce5faae --- /dev/null +++ b/.github/test-workspace/tests/runners/all_test_suites.rs @@ -0,0 +1,170 @@ +#![cfg_attr(read_buf, feature(read_buf))] +#![cfg_attr(read_buf, feature(core_io_borrowed_buf))] + +use std::cell::RefCell; +use std::env; +use std::sync::Mutex; + +#[macro_use] +mod macros; + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_api { + use super::*; + + provider_wolfcrypt!(); + + #[path = "../api.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_client_cert_verifier { + + provider_wolfcrypt!(); + + #[path = "../client_cert_verifier.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_key_log_file_env { + use super::serialized; + + provider_wolfcrypt!(); + + #[path = "../key_log_file_env.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_server_cert_verifier { + + provider_wolfcrypt!(); + + #[path = "../server_cert_verifier.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_unbuffered { + + provider_wolfcrypt!(); + + #[path = "../unbuffered.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_ech { + + provider_wolfcrypt!(); + + #[path = "../ech.rs"] + mod tests; +} + +#[cfg(feature = "wolfcrypt-provider")] +#[path = "."] +mod tests_with_wolfcrypt_ffdhe { + provider_wolfcrypt!(); + + #[path = "../api_ffdhe.rs"] + mod tests; +} + +// this must be outside tests_with_*, as we want +// one thread_local!, not one per provider. +thread_local!(static COUNTS: RefCell = RefCell::new(LogCounts::new())); + +struct CountingLogger; + +#[allow(dead_code)] +static LOGGER: CountingLogger = CountingLogger; + +#[allow(dead_code)] +impl CountingLogger { + fn install() { + let _ = log::set_logger(&LOGGER); + log::set_max_level(log::LevelFilter::Trace); + } + + fn reset() { + COUNTS.with(|c| { + c.borrow_mut().reset(); + }); + } +} + +impl log::Log for CountingLogger { + fn enabled(&self, _metadata: &log::Metadata) -> bool { + true + } + + fn log(&self, record: &log::Record) { + println!("logging at {:?}: {:?}", record.level(), record.args()); + + COUNTS.with(|c| { + c.borrow_mut() + .add(record.level(), format!("{}", record.args())); + }); + } + + fn flush(&self) {} +} + +#[derive(Default, Debug)] +struct LogCounts { + trace: Vec, + debug: Vec, + info: Vec, + warn: Vec, + error: Vec, +} + +impl LogCounts { + fn new() -> Self { + Self { + ..Default::default() + } + } + + fn reset(&mut self) { + *self = Self::new(); + } + + fn add(&mut self, level: log::Level, message: String) { + match level { + log::Level::Trace => &mut self.trace, + log::Level::Debug => &mut self.debug, + log::Level::Info => &mut self.info, + log::Level::Warn => &mut self.warn, + log::Level::Error => &mut self.error, + } + .push(message); + } +} + +/// Approximates `#[serial]` from the `serial_test` crate. +/// +/// No attempt is made to recover from a poisoned mutex, which will +/// happen when `f` panics. In other words, all the tests that use +/// `serialized` will start failing after one test panics. +#[allow(dead_code)] +fn serialized(f: impl FnOnce()) { + // Ensure every test is run serialized + static MUTEX: Mutex<()> = const { Mutex::new(()) }; + + let _guard = MUTEX.lock().unwrap(); + + // XXX: NOT thread safe. + unsafe { env::set_var("SSLKEYLOGFILE", "./sslkeylogfile.txt") }; + + f() +} diff --git a/.github/test-workspace/workspace-Cargo.toml b/.github/test-workspace/workspace-Cargo.toml new file mode 100644 index 0000000..4799ec5 --- /dev/null +++ b/.github/test-workspace/workspace-Cargo.toml @@ -0,0 +1,101 @@ +[workspace] +members = [ + + # the main library and tests + "rustls/rustls", + "tests", + "rustls-wolfcrypt-provider/rustls-wolfcrypt-provider", + +] + +## Deliberately not included in `members`: +exclude = [ + # `cargo fuzz` integration (requires nightly) + "rustls/fuzz", +] + +default-members = [ + # --- + "rustls-wolfcrypt-provider/rustls-wolfcrypt-provider", + "rustls/rustls", + "tests" +] +resolver = "2" + +[workspace.dependencies] +anyhow = "1.0.73" +asn1 = "0.22" +async-trait = "0.1.74" +aws-lc-rs = { version = "1.14", default-features = false } +base64 = "0.22" +bencher = "0.1.5" +brotli = { version = "8", default-features = false, features = ["std"] } +brotli-decompressor = "5.0.0" +byteorder = "1.4.3" +chacha20poly1305 = { version = "0.10", default-features = false, features = ["alloc"] } +clap = { version = "4.3.21", features = ["derive", "env"] } +crabgrind = "=0.1.9" # compatible with valgrind package on GHA ubuntu-latest +criterion = "0.6" +der = "0.7" +ecdsa = "0.16.8" +env_logger = "0.11" +hashbrown = { version = "0.15", default-features = false, features = ["default-hasher", "inline-more"] } +hex = "0.4" +hickory-resolver = { version = "0.25", features = ["https-aws-lc-rs", "webpki-roots"] } +hmac = "0.12" +hpke-rs = "0.3" +hpke-rs-crypto = "0.3" +hpke-rs-rust-crypto = "0.3" +itertools = "0.14" +log = { version = "0.4.8" } +macro_rules_attribute = "0.2" +mio = { version = "1", features = ["net", "os-poll"] } +num-bigint = "0.4.4" +once_cell = { version = "1.16", default-features = false, features = ["alloc", "race"] } +openssl = "0.10" +p256 = { version = "0.13.2", default-features = false, features = ["alloc", "ecdsa", "pkcs8"] } +pkcs8 = "0.10.2" +pki-types = { package = "rustls-pki-types", version = "1.12", features = ["alloc"] } +rand_core = { version = "0.6", features = ["getrandom"] } +rayon = "1.7" +rcgen = { version = "0.14", features = ["pem", "aws_lc_rs"], default-features = false } +regex = "1" +ring = "0.17" +rsa = { version = "0.9", features = ["sha2"], default-features = false } +rustc-hash = "2" +rustls-graviola = { version = "0.2" } +rustls-test = { path = "rustls/rustls-test" } +rustls-fuzzing-provider = { path = "rustls/rustls-fuzzing-provider" } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +sha2 = { version = "0.10", default-features = false } +signature = "2" +subtle = { version = "2.5.0", default-features = false } +time = { version = "0.3.6", default-features = false } +tikv-jemallocator = "0.6" +tokio = { version = "1.34", features = ["io-util", "macros", "net", "rt"] } +webpki = { package = "rustls-webpki", version = "0.103.5", features = ["alloc"], default-features = false } +webpki-roots = "1" +x25519-dalek = "2" +x509-parser = "0.17" +zeroize = { version = "1", default-features = false, features = ["alloc", "derive"] } +zlib-rs = "0.5" + + +[profile.bench] +codegen-units = 1 +lto = true + +[profile.test] +opt-level = 0 +debug = true +debug-assertions = true +overflow-checks = true +incremental = true +lto = false + +# ensure all our tests are against the local copy, never +# against the latest _published_ copy. +[patch.crates-io] +rustls = { path = "rustls/rustls" } +rustls-wolfcrypt-provider = {path = "rustls-wolfcrypt-provider/rustls-wolfcrypt-provider"} diff --git a/.github/workflows/macos-build.yml b/.github/workflows/macos-build.yml index 69ff695..08136ac 100644 --- a/.github/workflows/macos-build.yml +++ b/.github/workflows/macos-build.yml @@ -16,6 +16,7 @@ jobs: - name: Install Build Prerequisites run: | brew install autoconf libtool automake + brew install go - name: Install Rust uses: dtolnay/rust-toolchain@master @@ -58,4 +59,4 @@ jobs: cd wolfcrypt-rs cargo clippy -- -D warnings -A unnecessary-transmutes cd ../rustls-wolfcrypt-provider - cargo clippy -- -D warnings + cargo clippy --all-features -- -D warnings diff --git a/.github/workflows/macos-rustls-tests.yml b/.github/workflows/macos-rustls-tests.yml new file mode 100644 index 0000000..590ae2f --- /dev/null +++ b/.github/workflows/macos-rustls-tests.yml @@ -0,0 +1,79 @@ +name: macOS rustls tests + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ 'main' ] + +jobs: + macos-build: + name: Build and Test (macOS) + runs-on: macos-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Build Prerequisites + run: | + brew install autoconf libtool automake + brew install go + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt, clippy + + - name: Cache Rust dependencies + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: macos-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + macos-cargo- + + - name: Checkout rustls v0.23.35 + uses: actions/checkout@v4 + with: + repository: rustls/rustls + ref: v/0.23.35 + fetch-depth: 0 + path: rustlsv0.23.35-test-workspace/rustls + + - name: Apply rustls test modifications + working-directory: rustlsv0.23.35-test-workspace/rustls + run: patch -p1 < "$GITHUB_WORKSPACE/.github/test-workspace/rustls-v0.23.35-tests.patch" + + - name: Assemble test workspace + working-directory: rustlsv0.23.35-test-workspace + run: | + cp -r rustls/rustls/tests . + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/Cargo.toml" tests/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/build.rs" tests/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/runners/all_test_suites.rs" tests/runners/ + + cp "$GITHUB_WORKSPACE/.github/test-workspace/workspace-Cargo.toml" Cargo.toml + + mkdir -p rustls-wolfcrypt-provider + cp -r "$GITHUB_WORKSPACE/wolfcrypt-rs" rustls-wolfcrypt-provider/ + cp -r "$GITHUB_WORKSPACE/rustls-wolfcrypt-provider" rustls-wolfcrypt-provider/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/provider-Cargo.toml" \ + rustls-wolfcrypt-provider/rustls-wolfcrypt-provider/Cargo.toml + + - name: Build wolfcrypt-rs + working-directory: rustlsv0.23.35-test-workspace/rustls-wolfcrypt-provider/wolfcrypt-rs + run: make build + + - name: Build rustls-wolfcrypt-provider + working-directory: rustlsv0.23.35-test-workspace/rustls-wolfcrypt-provider/rustls-wolfcrypt-provider + run: cargo build --all-features --release + + - name: Run test suite + working-directory: rustlsv0.23.35-test-workspace + run: | + cargo test -p tests --test all_suites \ + --features wolfcrypt-provider,tls12,fips,zlib,prefer-post-quantum,logging \ + --no-default-features diff --git a/.github/workflows/ubuntu-build.yml b/.github/workflows/ubuntu-build.yml index 6b379fc..1081846 100644 --- a/.github/workflows/ubuntu-build.yml +++ b/.github/workflows/ubuntu-build.yml @@ -57,6 +57,7 @@ jobs: - name: Run clippy run: | cd wolfcrypt-rs - cargo clippy -- -D warnings + cargo clippy --all-features -- -D warnings cd ../rustls-wolfcrypt-provider - cargo clippy -- -D warnings + cargo clippy --all-features -- -D warnings + diff --git a/.github/workflows/ubuntu-rustls-tests.yml b/.github/workflows/ubuntu-rustls-tests.yml new file mode 100644 index 0000000..a06764f --- /dev/null +++ b/.github/workflows/ubuntu-rustls-tests.yml @@ -0,0 +1,79 @@ +name: Ubuntu rustls tests + +on: + push: + branches: [ 'main' ] + pull_request: + branches: [ 'main' ] + +jobs: + ubuntu-build: + name: Build and Test (Ubuntu) + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Install Build Prerequisites + run: | + sudo apt-get update + sudo apt-get install -y build-essential autoconf libtool + + - name: Install Rust + uses: dtolnay/rust-toolchain@master + with: + toolchain: stable + components: rustfmt, clippy + + - name: Cache Rust dependencies + uses: actions/cache@v3 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: ubuntu-cargo-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + ubuntu-cargo- + + - name: Checkout rustls v0.23.35 + uses: actions/checkout@v4 + with: + repository: rustls/rustls + ref: v/0.23.35 + fetch-depth: 0 + path: rustlsv0.23.35-test-workspace/rustls + + - name: Apply rustls test modifications + working-directory: rustlsv0.23.35-test-workspace/rustls + run: patch -p1 < "$GITHUB_WORKSPACE/.github/test-workspace/rustls-v0.23.35-tests.patch" + + - name: Assemble test workspace + working-directory: rustlsv0.23.35-test-workspace + run: | + cp -r rustls/rustls/tests . + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/Cargo.toml" tests/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/build.rs" tests/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/tests/runners/all_test_suites.rs" tests/runners/ + + cp "$GITHUB_WORKSPACE/.github/test-workspace/workspace-Cargo.toml" Cargo.toml + + mkdir -p rustls-wolfcrypt-provider + cp -r "$GITHUB_WORKSPACE/wolfcrypt-rs" rustls-wolfcrypt-provider/ + cp -r "$GITHUB_WORKSPACE/rustls-wolfcrypt-provider" rustls-wolfcrypt-provider/ + cp "$GITHUB_WORKSPACE/.github/test-workspace/provider-Cargo.toml" \ + rustls-wolfcrypt-provider/rustls-wolfcrypt-provider/Cargo.toml + + - name: Build wolfcrypt-rs + working-directory: rustlsv0.23.35-test-workspace/rustls-wolfcrypt-provider/wolfcrypt-rs + run: make build + + - name: Build rustls-wolfcrypt-provider + working-directory: rustlsv0.23.35-test-workspace/rustls-wolfcrypt-provider/rustls-wolfcrypt-provider + run: cargo build --all-features --release + + - name: Run test suite + working-directory: rustlsv0.23.35-test-workspace + run: | + cargo test -p tests --test all_suites \ + --features wolfcrypt-provider,tls12,fips,zlib,prefer-post-quantum,logging \ + --no-default-features diff --git a/.gitignore b/.gitignore index 9d879c6..3b77375 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ /rustls-wolfcrypt-provider/target /wolfcrypt-rs/target /wolfcrypt-rs/wolfssl-*/ +/rustlsv0.23.35-test-workspace/ diff --git a/rustls-wolfcrypt-provider/Cargo.lock b/rustls-wolfcrypt-provider/Cargo.lock index f232468..762a6ee 100644 --- a/rustls-wolfcrypt-provider/Cargo.lock +++ b/rustls-wolfcrypt-provider/Cargo.lock @@ -845,9 +845,9 @@ checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" [[package]] name = "rand" -version = "0.8.5" +version = "0.8.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +checksum = "5ca0ecfa931c29007047d1bc58e623ab12e5590e8c7cc53200d5202b69266d8a" dependencies = [ "rand_chacha", "rand_core", @@ -1019,9 +1019,9 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.103.10" +version = "0.103.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df33b2b81ac578cabaf06b89b0631153a3f416b0a886e8a7a1707fb51abbd1ef" +checksum = "61c429a8649f110dddef65e2a5ad240f747e85f7758a6bccc7e5777bd33f756e" dependencies = [ "aws-lc-rs", "ring", diff --git a/rustls-wolfcrypt-provider/Cargo.toml b/rustls-wolfcrypt-provider/Cargo.toml index e3ddb34..d324c72 100644 --- a/rustls-wolfcrypt-provider/Cargo.toml +++ b/rustls-wolfcrypt-provider/Cargo.toml @@ -14,7 +14,7 @@ rand_core = { version = "0.6", default-features = false, features = ["getrandom" rsa = { version = "0.9", features = ["sha2"], default-features = false } sha2 = { version = "0.10", default-features = false } signature = { version = "2", default-features = false } -webpki = { package = "rustls-webpki", version = "0.103.10", features = ["alloc"], default-features = false } +webpki = { package = "rustls-webpki", version = "0.103.13", features = ["alloc"], default-features = false } foreign-types = { version = "0.5.0", default-features = false } rustls-pki-types = { version = "1.11.0", default-features = false } log = { version = "0.4.25", default-features = false } @@ -45,6 +45,7 @@ rustls-pemfile = { version = "2.2.0", default-features = false, features = ["std [features] default = [] std = ["pkcs8/std", "rustls/std", "wolfcrypt-rs/std"] +quic = [] [profile.release] strip = true diff --git a/rustls-wolfcrypt-provider/examples/client.rs b/rustls-wolfcrypt-provider/examples/client.rs index a7655cc..b76e3f6 100644 --- a/rustls-wolfcrypt-provider/examples/client.rs +++ b/rustls-wolfcrypt-provider/examples/client.rs @@ -1,4 +1,4 @@ -use rustls_wolfcrypt_provider::provider; +use rustls_wolfcrypt_provider::default_provider; use std::io::{stdout, Read, Write}; use std::net::TcpStream; use std::sync::Arc; @@ -9,7 +9,7 @@ fn main() { let root_store = rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); - let config = rustls::ClientConfig::builder_with_provider(provider().into()) + let config = rustls::ClientConfig::builder_with_provider(default_provider().into()) .with_safe_default_protocol_versions() .unwrap() .with_root_certificates(root_store) diff --git a/rustls-wolfcrypt-provider/examples/server.rs b/rustls-wolfcrypt-provider/examples/server.rs index eb31377..7261891 100644 --- a/rustls-wolfcrypt-provider/examples/server.rs +++ b/rustls-wolfcrypt-provider/examples/server.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; use rustls::server::Acceptor; use rustls::ServerConfig; -use rustls_wolfcrypt_provider::provider; +use rustls_wolfcrypt_provider::default_provider; fn main() { env_logger::init(); @@ -90,7 +90,7 @@ impl TestPki { } fn server_config(self) -> Arc { - let mut server_config = ServerConfig::builder_with_provider(provider().into()) + let mut server_config = ServerConfig::builder_with_provider(default_provider().into()) .with_safe_default_protocol_versions() .unwrap() .with_no_client_auth() diff --git a/rustls-wolfcrypt-provider/src/aead/aes128gcm.rs b/rustls-wolfcrypt-provider/src/aead/aes128gcm.rs index ff6d221..9371ebb 100644 --- a/rustls-wolfcrypt-provider/src/aead/aes128gcm.rs +++ b/rustls-wolfcrypt-provider/src/aead/aes128gcm.rs @@ -173,6 +173,9 @@ impl MessageDecrypter for WCTls12Decrypter { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + if payload.len() < GCM_TAG_LENGTH { + return Err(rustls::Error::DecryptError); + } let payload_len = payload.len(); // First we copy the implicit nonce followed by copying @@ -226,7 +229,8 @@ impl MessageDecrypter for WCTls12Decrypter { aad.len() as word32, ) }; - check_if_zero(ret).map_err(|_| rustls::Error::General("wc_AesGcmDecrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; payload.copy_within(payload_start..(payload_len - GCM_TAG_LENGTH), 0); payload.truncate(payload_len - ((payload_start) + GCM_TAG_LENGTH)); @@ -354,6 +358,9 @@ impl MessageDecrypter for WCTls13Cipher { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + if payload.len() < GCM_TAG_LENGTH { + return Err(rustls::Error::DecryptError); + } let nonce = Nonce::new(&self.iv, seq); let aad = make_tls13_aad(payload.len()); let mut auth_tag = [0u8; GCM_TAG_LENGTH]; @@ -391,7 +398,8 @@ impl MessageDecrypter for WCTls13Cipher { aad.len() as word32, ) }; - check_if_zero(ret).map_err(|_| rustls::Error::General("wc_AesGcmDecrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; payload.truncate(message_len); diff --git a/rustls-wolfcrypt-provider/src/aead/aes256gcm.rs b/rustls-wolfcrypt-provider/src/aead/aes256gcm.rs index 3c83703..4dc1acd 100644 --- a/rustls-wolfcrypt-provider/src/aead/aes256gcm.rs +++ b/rustls-wolfcrypt-provider/src/aead/aes256gcm.rs @@ -173,6 +173,9 @@ impl MessageDecrypter for WCTls12Decrypter { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + if payload.len() < GCM_TAG_LENGTH { + return Err(rustls::Error::DecryptError); + } let payload_len = payload.len(); // First we copy the implicit nonce followed by copying @@ -226,7 +229,8 @@ impl MessageDecrypter for WCTls12Decrypter { aad.len() as word32, ) }; - check_if_zero(ret).map_err(|_| rustls::Error::General("wc_AesGcmDecrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; payload.copy_within(payload_start..(payload_len - GCM_TAG_LENGTH), 0); payload.truncate(payload_len - ((payload_start) + GCM_TAG_LENGTH)); @@ -354,6 +358,10 @@ impl MessageDecrypter for WCTls13Cipher { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + // In case peer misbehaves and sends plain text after it is not anymore allowed + if payload.len() < GCM_TAG_LENGTH { + return Err(rustls::Error::DecryptError); + } let nonce = Nonce::new(&self.iv, seq); let aad = make_tls13_aad(payload.len()); let mut auth_tag = [0u8; GCM_TAG_LENGTH]; @@ -391,7 +399,8 @@ impl MessageDecrypter for WCTls13Cipher { aad.len() as word32, ) }; - check_if_zero(ret).map_err(|_| rustls::Error::General("wc_AesGcmDecrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; payload.truncate(message_len); diff --git a/rustls-wolfcrypt-provider/src/aead/chacha20.rs b/rustls-wolfcrypt-provider/src/aead/chacha20.rs index 36a3b41..1a47f37 100644 --- a/rustls-wolfcrypt-provider/src/aead/chacha20.rs +++ b/rustls-wolfcrypt-provider/src/aead/chacha20.rs @@ -135,6 +135,9 @@ impl MessageDecrypter for WCTls12Cipher { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + if payload.len() < CHACHAPOLY1305_OVERHEAD { + return Err(rustls::Error::DecryptError); + } // We substract the tag, so this len will only consider // the message that we are trying to decrypt. @@ -162,8 +165,8 @@ impl MessageDecrypter for WCTls12Cipher { payload[..message_len].as_mut_ptr(), ) }; - check_if_zero(ret) - .map_err(|_| rustls::Error::General("wc_ChaCha20Poly1305_Decrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; // We extract the final result... payload.truncate(message_len); @@ -280,6 +283,9 @@ impl MessageDecrypter for WCTls13Cipher { seq: u64, ) -> Result, rustls::Error> { let payload = &mut m.payload; + if payload.len() < CHACHAPOLY1305_OVERHEAD { + return Err(rustls::Error::DecryptError); + } let nonce = Nonce::new(&self.iv, seq); let aad = make_tls13_aad(payload.len()); let mut auth_tag = [0u8; CHACHAPOLY1305_OVERHEAD]; @@ -306,8 +312,8 @@ impl MessageDecrypter for WCTls13Cipher { payload[..message_len].as_mut_ptr(), ) }; - check_if_zero(ret) - .map_err(|_| rustls::Error::General("wc_ChaCha20Poly1305_Decrypt failed".into()))?; + + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; // We extract the final result... payload.truncate(message_len); diff --git a/rustls-wolfcrypt-provider/src/aead/quic.rs b/rustls-wolfcrypt-provider/src/aead/quic.rs new file mode 100644 index 0000000..9fd0977 --- /dev/null +++ b/rustls-wolfcrypt-provider/src/aead/quic.rs @@ -0,0 +1,1714 @@ +//! QUIC Header Protection. +//! +//! See draft-ietf-quic-tls. + +use alloc::vec; +use core::mem; +use foreign_types::ForeignType; +use zeroize::Zeroizing; + +use crate::error::check_if_zero; +use crate::types::{AesObject, ChaChaObject}; +use alloc::boxed::Box; +use alloc::vec::Vec; +use core::ptr; + +use rustls::crypto::cipher::{Iv, Nonce}; +use rustls::quic::Tag; +use rustls::{crypto::cipher::AeadKey, quic, Error}; +use wolfcrypt_rs::*; + +type PktEncFn = + fn(packet_cipher: &Cipher, nonce: &[u8], aad: &[u8], in_out: &mut [u8]) -> Result; + +type PktDecFn = + fn(packet_cipher: &Cipher, nonce: &[u8], aad: &[u8], in_out: &mut [u8]) -> Result<(), Error>; + +macro_rules! mask_array { + () => { + [0u8; 5] + }; +} +pub enum Cipher { + Aes(AesCipher), + ChaCha20(ChaChaCipher), +} + +/// All the AEADs we support use 96-bit nonces. +pub const NONCE_LEN: usize = 96 / 8; + +pub(crate) const TAG_LEN: usize = 16; + +pub const AES_128_KEY_LEN: usize = 128 / 8; +pub const AES_256_KEY_LEN: usize = 256 / 8; + +pub const CHACHA_KEY_LEN: usize = 32; +pub const SAMPLE_LEN: usize = TAG_LEN; +pub const MASK_LEN: usize = 5; + +/// QUIC sample for new key masks +pub type Sample = [u8; SAMPLE_LEN]; + +/// A QUIC Header Protection Algorithm. +pub struct HPAlgorithm { + hp_mask: fn(hp_cipher: &Cipher, sample: &[u8]) -> Result<[u8; MASK_LEN], Error>, + init: fn(key: &[u8]) -> Result, + key_len: usize, + id: HPAlgorithmID, +} + +impl HPAlgorithm { + /// The length of the key. + #[inline(always)] + pub fn key_len(&self) -> usize { + self.key_len + } + + /// The required sample length. + #[inline(always)] + pub fn sample_len(&self) -> usize { + SAMPLE_LEN + } +} + +/// A QUIC header protection algorithm. +#[derive(Debug, Eq, PartialEq)] +pub enum HPAlgorithmID { + Aes128, + Aes256, + ChaCha20, +} + +impl PartialEq for HPAlgorithm { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for HPAlgorithm {} + +/// AES-128. +pub static AES_128: HPAlgorithm = HPAlgorithm { + key_len: AES_128_KEY_LEN, + hp_mask: generate_mask_aes, + id: HPAlgorithmID::Aes128, + init: init_hp_aes_cipher, +}; + +/// AES-256. +pub static AES_256: HPAlgorithm = HPAlgorithm { + key_len: AES_256_KEY_LEN, + hp_mask: generate_mask_aes, + id: HPAlgorithmID::Aes256, + init: init_hp_aes_cipher, +}; + +fn init_hp_aes_cipher(key: &[u8]) -> Result { + let mut aes_cipher = AesCipher::new()?; + aes_cipher.set_key(key)?; + Ok(Cipher::Aes(aes_cipher)) +} + +fn generate_mask_aes(hp_cipher: &Cipher, sample: &[u8]) -> Result<[u8; MASK_LEN], Error> { + let aes_cipher = match hp_cipher { + Cipher::Aes(c) => c, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + + let mut mask = mask_array!(); + match aes_cipher.encrypt_sample(sample) { + Ok(output) => mask.copy_from_slice(&output[..5]), + Err(e) => return Err(e), + } + Ok(mask) +} + +/// ChaCha20. +pub static CHACHA20: HPAlgorithm = HPAlgorithm { + key_len: CHACHA_KEY_LEN, + init: init_hp_chacha20_cipher, + hp_mask: generate_mask_chacha20, + id: HPAlgorithmID::ChaCha20, +}; + +fn init_hp_chacha20_cipher(key: &[u8]) -> Result { + let mut chacha_cipher = ChaChaCipher::new(None)?; + chacha_cipher.set_key(key)?; + Ok(Cipher::ChaCha20(chacha_cipher)) +} + +fn generate_mask_chacha20(hp_cipher: &Cipher, sample: &[u8]) -> Result<[u8; MASK_LEN], Error> { + let chacha20_cipher = match hp_cipher { + Cipher::ChaCha20(c) => c, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + + let mut mask = mask_array!(); + match chacha20_cipher.encrypt_sample(sample) { + Ok(output) => mask.copy_from_slice(&output[..5]), + Err(e) => return Err(e), + } + Ok(mask) +} + +/// A key for generating QUIC Header Protection masks. +pub struct HeaderProtectionKey { + hp_cipher: Cipher, + algorithm: &'static HPAlgorithm, +} + +impl HeaderProtectionKey { + /// Create a new header protection key. + /// + /// `key_bytes` must be exactly `algorithm.key_len` bytes long. + pub fn new(key: Vec, algorithm: &'static HPAlgorithm) -> Result { + if key.len() != algorithm.key_len { + return Err(Error::General("Invalid key length".into())); + } + Ok(Self { + hp_cipher: (algorithm.init)(&key)?, + algorithm, + }) + } + + fn header_protection( + &self, + sample: &[u8], + first: &mut u8, + packet_number: &mut [u8], + masked: bool, + ) -> Result<(), Error> { + // This implements "Header Protection Application" almost verbatim. + // + + if sample.len() != SAMPLE_LEN { + return Err(Error::General("Invalid sample length".into())); + } + + let mask = (self.algorithm.hp_mask)(&self.hp_cipher, sample)?; + + let (first_mask, pn_mask) = mask + .split_first() + .ok_or_else(|| Error::General("Function split_first failed".into()))?; + + // It is OK for the `mask` to be longer than `packet_number`, + // but a valid `packet_number` will never be longer than `mask`. + if packet_number.len() > pn_mask.len() { + return Err(Error::General("packet number too long".into())); + } + + // Infallible from this point on. Before this point, `first` and + // `packet_number` are unchanged. + + const LONG_HEADER_FORM: u8 = 0x80; + let bits = match *first & LONG_HEADER_FORM == LONG_HEADER_FORM { + true => 0x0f, // Long header: 4 bits masked + false => 0x1f, // Short header: 5 bits masked + }; + + let first_plain = match masked { + // When unmasking, use the packet length bits after unmasking + true => *first ^ (first_mask & bits), + // When masking, use the packet length bits before masking + false => *first, + }; + let pn_len = (first_plain & 0x03) as usize + 1; + + *first ^= first_mask & bits; + for (dst, m) in packet_number.iter_mut().zip(pn_mask).take(pn_len) { + *dst ^= m; + } + + Ok(()) + } + + /// The key's algorithm. + #[inline(always)] + pub fn algorithm(&self) -> &'static HPAlgorithm { + self.algorithm + } +} + +impl quic::HeaderProtectionKey for HeaderProtectionKey { + fn encrypt_in_place( + &self, + sample: &[u8], + first: &mut u8, + packet_number: &mut [u8], + ) -> Result<(), Error> { + self.header_protection(sample, first, packet_number, false) + } + + fn decrypt_in_place( + &self, + sample: &[u8], + first: &mut u8, + packet_number: &mut [u8], + ) -> Result<(), Error> { + self.header_protection(sample, first, packet_number, true) + } + + #[inline] + fn sample_len(&self) -> usize { + TAG_LEN + } +} + +#[derive(Debug, Eq, PartialEq)] +pub(crate) enum PacketKeyAlgorithmID { + Aes128Gcm, + Aes256Gcm, + ChaCha20Poly1305, +} + +/// A QUIC packet protection algorithm. +pub struct AeadAlgorithm { + init: fn(key: &[u8]) -> Result, + + encrypt: PktEncFn, + decrypt: PktDecFn, + + key_len: usize, + id: PacketKeyAlgorithmID, +} + +impl AeadAlgorithm { + /// The length of the key. + #[inline(always)] + pub fn key_len(&self) -> usize { + self.key_len + } + + /// The length of a tag. + #[inline(always)] + pub fn tag_len(&self) -> usize { + TAG_LEN + } + + /// The length of the nonces. + #[inline(always)] + pub fn nonce_len(&self) -> usize { + NONCE_LEN + } +} + +impl PartialEq for AeadAlgorithm { + fn eq(&self, other: &Self) -> bool { + self.id == other.id + } +} + +impl Eq for AeadAlgorithm {} + +/// AES-128 in GCM mode with 128-bit tags and 96 bit nonces. +pub static AES_128_GCM: AeadAlgorithm = AeadAlgorithm { + init: init_aes_gcm_cipher, + encrypt: encrypt_aes_gcm, + decrypt: decrypt_aes_gcm, + key_len: AES_128_KEY_LEN, + id: PacketKeyAlgorithmID::Aes128Gcm, +}; + +/// AES-256 in GCM mode with 128-bit tags and 96 bit nonces. +pub static AES_256_GCM: AeadAlgorithm = AeadAlgorithm { + init: init_aes_gcm_cipher, + encrypt: encrypt_aes_gcm, + decrypt: decrypt_aes_gcm, + key_len: AES_256_KEY_LEN, + id: PacketKeyAlgorithmID::Aes256Gcm, +}; + +fn init_aes_gcm_cipher(key: &[u8]) -> Result { + let mut aes_cipher = AesCipher::new()?; + aes_cipher.set_key(key)?; + Ok(Cipher::Aes(aes_cipher)) +} + +fn encrypt_aes_gcm( + packet_cipher: &Cipher, + nonce: &[u8], + aad: &[u8], + in_out: &mut [u8], +) -> Result { + let aes_cipher = match packet_cipher { + Cipher::Aes(c) => c, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + aes_cipher.encrypt_separate_tag(nonce, aad, in_out) +} + +pub(super) fn decrypt_aes_gcm( + packet_cipher: &Cipher, + nonce: &[u8], + aad: &[u8], + in_out: &mut [u8], +) -> Result<(), Error> { + let aes_cipher = match packet_cipher { + Cipher::Aes(aes_key) => aes_key, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + aes_cipher.decrypt(nonce, aad, in_out) +} + +/// ChaCha20-Poly1305 as described in [RFC 8439]. +/// +/// The keys are 256 bits long and the nonces are 96 bits long. +/// +/// [RFC 8439]: https://tools.ietf.org/html/rfc8439 +pub static CHACHA20_POLY1305: AeadAlgorithm = AeadAlgorithm { + init: init_chacha20_poly1305_cipher, + encrypt: encrypt_chacha20_poly1305, + decrypt: decrypt_chacha20_poly1305, + key_len: CHACHA_KEY_LEN, + id: PacketKeyAlgorithmID::ChaCha20Poly1305, +}; + +fn init_chacha20_poly1305_cipher(key: &[u8]) -> Result { + let key_array = <[u8; 32]>::try_from(key) + .map_err(|_| Error::General("Invalid key length for ChaCha20-Poly1305".into()))?; + let chacha_cipher = ChaChaCipher::new(Some(key_array))?; + Ok(Cipher::ChaCha20(chacha_cipher)) +} + +fn encrypt_chacha20_poly1305( + packet_cipher: &Cipher, + nonce: &[u8], + aad: &[u8], + in_out: &mut [u8], +) -> Result { + let chacha_cipher = match packet_cipher { + Cipher::ChaCha20(chacha_key) => chacha_key, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + chacha_cipher.encrypt_separate_tag(nonce, aad, in_out) +} + +fn decrypt_chacha20_poly1305( + packet_cipher: &Cipher, + nonce: &[u8], + aad: &[u8], + in_out: &mut [u8], +) -> Result<(), Error> { + let chacha_cipher = match packet_cipher { + Cipher::ChaCha20(chacha_key) => chacha_key, + _ => return Err(Error::General("Invalid cipher type".into())), + }; + chacha_cipher.decrypt(nonce, aad, in_out) +} + +pub(crate) struct PacketKey { + /// Encrypts or decrypts a packet's payload + packet_cipher: Cipher, + /// Computes unique nonces for each packet + iv: Iv, + /// Confidentiality limit (see [`quic::PacketKey::confidentiality_limit`]) + confidentiality_limit: u64, + /// Integrity limit (see [`quic::PacketKey::integrity_limit`]) + integrity_limit: u64, + /// Algorithm for packet protection + algorithm: &'static AeadAlgorithm, +} + +impl PacketKey { + pub(crate) fn new( + key: AeadKey, + iv: Iv, + confidentiality_limit: u64, + integrity_limit: u64, + algorithm: &'static AeadAlgorithm, + ) -> Result { + if key.as_ref().len() != algorithm.key_len { + return Err(Error::General("Invalid key length".into())); + } + Ok(Self { + packet_cipher: (algorithm.init)(key.as_ref())?, + iv, + confidentiality_limit, + integrity_limit, + algorithm, + }) + } +} + +impl quic::PacketKey for PacketKey { + fn encrypt_in_place( + &self, + packet_number: u64, + header: &[u8], + payload: &mut [u8], + ) -> Result { + let aad = header; + let nonce = Nonce::new(&self.iv, packet_number).0; + let tag = (self.algorithm.encrypt)(&self.packet_cipher, &nonce, aad, payload)?; + Ok(quic::Tag::from(tag.as_ref())) + } + + /// Decrypt a QUIC packet + /// + /// Takes the packet `header`, which is used as the additional authenticated data, and the + /// `payload`, which includes the authentication tag. + /// + /// If the return value is `Ok`, the decrypted payload can be found in `payload`, up to the + /// length found in the return value. + fn decrypt_in_place<'a>( + &self, + packet_number: u64, + header: &[u8], + payload: &'a mut [u8], + ) -> Result<&'a [u8], Error> { + let payload_len = payload.len(); + let aad = header; + let nonce = Nonce::new(&self.iv, packet_number).0; + (self.algorithm.decrypt)(&self.packet_cipher, &nonce, aad, payload)?; + let plain_len = payload_len - self.algorithm.tag_len(); + Ok(&payload[..plain_len]) + } + + /// Tag length for the underlying AEAD algorithm + #[inline] + fn tag_len(&self) -> usize { + self.algorithm.tag_len() + } + + /// Confidentiality limit (see [`quic::PacketKey::confidentiality_limit`]) + fn confidentiality_limit(&self) -> u64 { + self.confidentiality_limit + } + + /// Integrity limit (see [`quic::PacketKey::integrity_limit`]) + fn integrity_limit(&self) -> u64 { + self.integrity_limit + } +} + +pub(crate) struct KeyFactory { + pub(crate) packet_algo: &'static AeadAlgorithm, + pub(crate) header_algo: &'static HPAlgorithm, + pub(crate) confidentiality_limit: u64, + pub(crate) integrity_limit: u64, +} + +impl quic::Algorithm for KeyFactory { + fn packet_key(&self, key: AeadKey, iv: Iv) -> Box { + Box::new( + match PacketKey::new( + key, + iv, + self.confidentiality_limit, + self.integrity_limit, + self.packet_algo, + ) { + Ok(packet_key) => packet_key, + Err(e) => panic!("PacketKey object creation failed: {:?}", e), + }, + ) + } + + fn header_protection_key(&self, key: AeadKey) -> Box { + Box::new( + match HeaderProtectionKey::new(key.as_ref().to_vec(), self.header_algo) { + Ok(header_key) => header_key, + Err(e) => panic!("HeaderProtection Key object creation failed: {:?}", e), + }, + ) + } + + fn aead_key_len(&self) -> usize { + self.packet_algo.key_len() + } + + fn fips(&self) -> bool { + false + } +} + +pub struct AesCipher { + aes_object: AesObject, + key: Zeroizing>, +} + +impl AesCipher { + pub fn new() -> Result { + Ok(Self { + aes_object: new_aes_object()?, + key: Zeroizing::new(Vec::new()), + }) + } + + /// It initializes an AES cipher with the given key. + pub fn set_key(&mut self, key: &[u8]) -> Result<(), Error> { + if key.len() != AES_256_KEY_LEN && key.len() != AES_128_KEY_LEN { + return Err(Error::General("Invalid key length".into())); + } + let ret = unsafe { + wc_AesSetKey( + self.aes_object.as_ptr(), + key.as_ptr(), + key.len() as word32, + ptr::null_mut(), + 0, + ) + }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("Function AesSetKey failed".into()))?; + self.key = Zeroizing::new(key.to_vec()); + Ok(()) + } + + pub fn encrypt_sample(&self, sample: &[u8]) -> Result, Error> { + let mut out_block = vec![0; TAG_LEN]; + + let ret = unsafe { + wc_AesEncryptDirect( + self.aes_object.as_ptr(), + out_block.as_mut_ptr(), + sample.as_ptr(), + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::EncryptError)?; + + Ok(out_block) + } + + pub fn encrypt_separate_tag( + &self, + nonce: &[u8], + aad: &[u8], + payload: &mut [u8], + ) -> Result { + let mut auth_tag = vec![0u8; TAG_LEN]; + let mut ret; + + // Prepare aes_object for encryption + ret = unsafe { + wc_AesGcmSetKey( + self.aes_object.as_ptr(), + self.key.as_ptr(), + self.key.len() as word32, + ) + }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("Function AesGcmSetKey failed".into()))?; + + // This function encrypts the input message, held in the buffer in, + // and stores the resulting cipher text in the output buffer out. + // It requires a new iv (initialization vector) for each call to encrypt. + // It also encodes the input authentication vector, + // authIn, into the authentication tag, authTag. + + ret = unsafe { + wc_AesGcmEncrypt( + self.aes_object.as_ptr(), + payload.as_mut_ptr(), + payload.as_ptr(), + payload.as_ref().len() as word32, + nonce.as_ptr(), + nonce.len() as word32, + auth_tag.as_mut_ptr(), + auth_tag.len() as word32, + aad.as_ptr(), + aad.len() as word32, + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::EncryptError)?; + + Ok(quic::Tag::from(auth_tag.as_ref())) + } + pub fn decrypt(&self, nonce: &[u8], aad: &[u8], payload: &mut [u8]) -> Result<(), Error> { + let mut auth_tag = [0u8; TAG_LEN]; + let message_len = payload.len() - TAG_LEN; + auth_tag.copy_from_slice(&payload[message_len..]); + + let mut ret; + + // Prepare aes_object for decryption + ret = unsafe { + wc_AesGcmSetKey( + self.aes_object.as_ptr(), + self.key.as_ptr(), + self.key.len() as word32, + ) + }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("Function AesGcmSetKey failed".into()))?; + + // Finally, we have everything to decrypt the message + // from the payload. + ret = unsafe { + wc_AesGcmDecrypt( + self.aes_object.as_ptr(), + payload[..message_len].as_mut_ptr(), + payload[..message_len].as_ptr(), + payload[..message_len] + .len() + .try_into() + .map_err(|_| rustls::Error::General("Function try_into() failed".into()))?, + nonce.as_ptr(), + nonce.len() as word32, + auth_tag.as_ptr(), + auth_tag.len() as word32, + aad.as_ptr(), + aad.len() as word32, + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; + + Ok(()) + } +} + +pub struct ChaChaCipher { + chacha_cipher: Option, + key: Option>, // In case of packet protection, no need to initiate a cipher +} + +impl ChaChaCipher { + pub fn new(key: Option<[u8; CHACHA_KEY_LEN]>) -> Result { + match key { + None => Ok(Self { + chacha_cipher: Some(new_chacha_object()?), + key: None, + }), + Some(key_bytes) => Ok(Self { + chacha_cipher: None, + key: Some(Zeroizing::new(key_bytes)), + }), + } + } + + fn set_key(&mut self, key: &[u8]) -> Result<(), Error> { + if key.len() != CHACHA_KEY_LEN { + return Err(Error::General("Invalid key length".into())); + } + + let chacha_cipher = self.chacha_cipher.as_ref().ok_or_else(|| { + Error::General("Cipher is none. Create a cipher object before setting key".into()) + })?; + //Set key for ChaCha object + let ret = + unsafe { wc_Chacha_SetKey(chacha_cipher.as_ptr(), key.as_ptr(), key.len() as word32) }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("Function wc_Chacha_SetKey failed".into()))?; + self.key = + Some(Zeroizing::new(key.try_into().map_err(|_| { + Error::General("Key must be exactly 32 bytes".into()) + })?)); + Ok(()) + } + + pub fn key_len(&self) -> usize { + CHACHA_KEY_LEN + } + + pub fn encrypt_sample(&self, sample: &[u8]) -> Result, Error> { + let chacha_cipher = self.chacha_cipher.as_ref().ok_or_else(|| { + Error::General("Cipher is none. Create a cipher object before encryption".into()) + })?; + + let mut out = vec![0; TAG_LEN]; + + let (ctr, nonce) = sample.split_at(4); + let ctr = u32::from_le_bytes( + ctr.try_into() + .map_err(|_| rustls::Error::General("Function try_into() failed".into()))?, + ); + + //Set IV for ChaCha object + let mut ret = unsafe { wc_Chacha_SetIV(chacha_cipher.as_ptr(), nonce.as_ptr(), ctr) }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("Function wc_Chacha_SetIV failed".into()))?; + + //Encrypt sample + ret = unsafe { + wc_Chacha_Process( + chacha_cipher.as_ptr(), + out.as_mut_ptr(), + [0; TAG_LEN].as_ptr(), + TAG_LEN as word32, + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::EncryptError)?; + + Ok(out) + } + pub fn encrypt_separate_tag( + &self, + nonce: &[u8], + aad: &[u8], + payload: &mut [u8], + ) -> Result { + let chacha_key = self.key.as_ref().ok_or_else(|| { + Error::General("Key is none. Generate a key before encryption".into()) + })?; + + let mut auth_tag: [u8; CHACHA20_POLY1305_AEAD_AUTHTAG_SIZE as usize] = + unsafe { mem::zeroed() }; + + // This function encrypts an input message, inPlaintext, + // using the ChaCha20 stream cipher, into the output buffer, outCiphertext. + // It also performs Poly-1305 authentication (on the cipher text), + // and stores the generated authentication tag in the output buffer, outAuthTag. + + let ret = unsafe { + wc_ChaCha20Poly1305_Encrypt( + chacha_key.as_ptr(), + nonce.as_ptr(), + aad.as_ptr(), + aad.len() as word32, + payload.as_ref().as_ptr(), + payload.len() as word32, + payload.as_mut().as_mut_ptr(), + auth_tag.as_mut_ptr(), + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::EncryptError)?; + + Ok(quic::Tag::from(auth_tag.as_ref())) + } + + pub fn decrypt(&self, nonce: &[u8], aad: &[u8], payload: &mut [u8]) -> Result<(), Error> { + let chacha_key = self.key.as_ref().ok_or_else(|| { + Error::General("Key is none. Generate a key before decryption".into()) + })?; + let mut auth_tag = [0u8; TAG_LEN]; + let message_len = payload.len() - TAG_LEN; + auth_tag.copy_from_slice(&payload[message_len..]); + + // This function decrypts input ciphertext, inCiphertext, + // using the ChaCha20 stream cipher, into the output buffer, outPlaintext. + // It also performs Poly-1305 authentication, comparing the given inAuthTag + // to an authentication generated with the inAAD (arbitrary length additional authentication data). + // Note: If the generated authentication tag does not match the supplied + // authentication tag, the text is not decrypted. + let ret = unsafe { + wc_ChaCha20Poly1305_Decrypt( + chacha_key.as_ptr(), + nonce.as_ptr(), + aad.as_ptr(), + aad.len() as word32, + // [..message_len] since we want to exclude the + // the auth_tag. + payload[..message_len].as_ptr(), + message_len as word32, + auth_tag.as_ptr(), + payload[..message_len].as_mut_ptr(), + ) + }; + check_if_zero(ret).map_err(|_| rustls::Error::DecryptError)?; + Ok(()) + } +} + +fn new_aes_object() -> Result { + let aes_c_type_box = Box::new(unsafe { mem::zeroed::() }); + let aes_c_type_ptr = Box::into_raw(aes_c_type_box); + let aes_object = unsafe { AesObject::from_ptr(aes_c_type_ptr) }; + + // Initialize Aes structure. + let ret = unsafe { wc_AesInit(aes_object.as_ptr(), ptr::null_mut(), INVALID_DEVID) }; + check_if_zero(ret).map_err(|_| rustls::Error::General("Function AesInit failed".into()))?; + Ok(aes_object) +} + +fn new_chacha_object() -> Result { + //Create ChaCha object + let chacha_c_typ_box = Box::new(unsafe { mem::zeroed::() }); + let chacha_c_typ_ptr = Box::into_raw(chacha_c_typ_box); + let chacha_object = unsafe { ChaChaObject::from_ptr(chacha_c_typ_ptr) }; + + Ok(chacha_object) +} + +#[cfg(test)] +mod tests { + use hex_literal::hex; + use rustls::crypto::tls13::HkdfExpander; + use std::prelude::v1::Vec; + use std::vec; + + use crate::aead; + use rustls::crypto::cipher::{AeadKey, Iv, NONCE_LEN}; + use rustls::quic::*; + + use crate::default_provider; + use crate::{TLS13_AES_128_GCM_SHA256, TLS13_CHACHA20_POLY1305_SHA256}; + use rustls::crypto::tls13::OkmBlock; + use rustls::internal::msgs::codec::Codec; + use rustls::{ClientConfig, Error, ServerConfig, Side, SideData}; + use rustls_pki_types::PrivatePkcs8KeyDer; + use std::sync::Arc; + + // Returns the sender's next secrets to use, or the receiver's error. + fn step( + send: &mut ConnectionCommon, + recv: &mut ConnectionCommon, + ) -> Result, Error> { + let mut buf = Vec::new(); + let change = loop { + let prev = buf.len(); + if let Some(x) = send.write_hs(&mut buf) { + break Some(x); + } + if prev == buf.len() { + break None; + } + }; + + recv.read_hs(&buf)?; + assert_eq!(recv.alert(), None); + Ok(change) + } + + fn make_default_client_config() -> ClientConfig { + let root_store = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + let config = rustls::ClientConfig::builder_with_provider(default_provider().into()) + .with_safe_default_protocol_versions() + .unwrap() + .with_root_certificates(root_store) + .with_no_client_auth(); + config + } + + fn make_default_server_config() -> ServerConfig { + let alg = &rcgen::PKCS_ECDSA_P256_SHA256; + let mut ca_params = rcgen::CertificateParams::new(Vec::new()).unwrap(); + ca_params + .distinguished_name + .push(rcgen::DnType::OrganizationName, "Provider Server Example"); + ca_params + .distinguished_name + .push(rcgen::DnType::CommonName, "Example CA"); + ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained); + ca_params.key_usages = vec![ + rcgen::KeyUsagePurpose::KeyCertSign, + rcgen::KeyUsagePurpose::DigitalSignature, + ]; + let ca_key = rcgen::KeyPair::generate_for(alg).unwrap(); + let ca_cert = ca_params.self_signed(&ca_key).unwrap(); + + // Create a server end entity cert issued by the CA. + let mut server_ee_params = + rcgen::CertificateParams::new(vec!["localhost".to_string()]).unwrap(); + server_ee_params.is_ca = rcgen::IsCa::NoCa; + server_ee_params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth]; + let server_key = rcgen::KeyPair::generate_for(alg).unwrap(); + let server_cert = server_ee_params + .signed_by(&server_key, &ca_cert, &ca_key) + .unwrap(); + + let mut server_config = ServerConfig::builder_with_provider(default_provider().into()) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_single_cert( + vec![server_cert.into()], + PrivatePkcs8KeyDer::from(server_key.serialize_der()).into(), + ) + .unwrap(); + + server_config.key_log = Arc::new(rustls::KeyLogFile::new()); + + server_config + } + /// Encode each of `items` + pub fn iter_to_vec_of_bytes<'a, T: Codec<'a>>(items: impl Iterator) -> Vec { + let mut body = Vec::new(); + + for i in items { + i.encode(&mut body); + } + body + } + + ///Encode length as prefix + pub fn prefix_len(mut body: Vec, len: usize) -> Vec { + match len { + 8 => { + body.splice(0..0, [body.len() as u8]); + } + 16 => { + body.splice(0..0, (body.len() as u16).to_be_bytes()); + } + 24 => { + let len = (body.len() as u32).to_be_bytes(); + body.insert(0, len[1]); + body.insert(1, len[2]); + body.insert(2, len[3]); + } + _ => panic!("wrong length!"), + }; + body + } + + fn make_extensions() -> Vec { + // Create extensions + let mut extensions: Vec = Vec::new(); + // kx group + extensions.push(Extension { + typ: 0x000a, // EllipticCurves + body: prefix_len( + iter_to_vec_of_bytes([rustls::NamedGroup::secp256r1].into_iter()), + 16, + ), + }); + // Sig algs + extensions.push(Extension { + typ: 0x000d, // SignatureAlgorithms + body: prefix_len( + rustls::SignatureScheme::RSA_PKCS1_SHA256 + .to_array() + .to_vec(), + 16, + ), + }); + + // Supported Versions, + extensions.push(Extension { + typ: 0x002b, // Supported Versions + body: prefix_len( + iter_to_vec_of_bytes( + [ + rustls::ProtocolVersion::TLSv1_3, + rustls::ProtocolVersion::TLSv1_2, + ] + .into_iter(), + ), + 8, + ), + }); + + // Key share + const SOME_POINT_ON_P256: &[u8] = &[ + 4, 41, 39, 177, 5, 18, 186, 227, 237, 220, 254, 70, 120, 40, 18, 139, 173, 41, 3, 38, + 153, 25, 247, 8, 96, 105, 200, 196, 223, 108, 115, 40, 56, 199, 120, 121, 100, 234, + 172, 0, 229, 146, 31, 177, 73, 138, 96, 244, 96, 103, 102, 179, 217, 104, 80, 1, 85, + 141, 26, 151, 78, 115, 65, 81, 62, + ]; + + let mut share = prefix_len(SOME_POINT_ON_P256.to_vec(), 16); + share.splice(0..0, rustls::NamedGroup::secp256r1.to_array()); + + extensions.push(Extension { + typ: 0x0033, // Key share + body: prefix_len(share, 16), + }); + extensions + } + fn make_client_hello() -> Vec { + let mut ch: Vec = Vec::new(); + rustls::ProtocolVersion::TLSv1_2.encode(&mut ch); + ch.extend_from_slice(&[0u8; 32]); // Encode random + ch.extend_from_slice(&[0u8; 1]); // Encode session_id + vec![ + rustls::CipherSuite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + rustls::CipherSuite::TLS13_AES_128_GCM_SHA256, + ] + .to_vec() + .encode(&mut ch); // Encode cypher suites + ch.extend_from_slice(&[0x01, 0x00]); // only null compression + + //Generate ch extensions + let extensions = make_extensions(); + + // Encode the extensions + let mut exts = vec![]; + for e in extensions { + e.typ.encode(&mut exts); + exts.extend_from_slice(&(e.body.len() as u16).to_be_bytes()); + exts.extend_from_slice(&e.body); + } + ch.extend(prefix_len(exts, 16)); + // Apply handshake framing to ch data. + let mut body = prefix_len(ch, 24); + body.splice(0..0, rustls::HandshakeType::ClientHello.to_array()); + body + } + #[derive(Clone)] + pub struct Extension { + pub typ: u16, + pub body: Vec, + } + + #[derive(Debug, Clone)] + struct ChaCha20TestVector { + key: [u8; 32], + sample: [u8; 16], + mask: [u8; 5], + } + + enum AesTestVector { + Aes128 { + key: [u8; 16], + sample: [u8; 16], + mask: [u8; 5], + }, + Aes256 { + key: [u8; 32], + sample: [u8; 16], + mask: [u8; 5], + }, + } + + fn hkdf_expand_label( + expander: &Box, + label: &[u8], + context: &[u8], + n: usize, + output: &mut [u8], + ) { + const LABEL_PREFIX: &[u8] = b"tls13 "; + + let output_len = u16::to_be_bytes(n as u16); + let label_len = u8::to_be_bytes((LABEL_PREFIX.len() + label.len()) as u8); + let context_len = u8::to_be_bytes(context.len() as u8); + + let info = &[ + &output_len[..], + &label_len[..], + LABEL_PREFIX, + label, + &context_len[..], + context, + ]; + + let _ = expander.expand_slice(info, output); + } + + fn test_short_packet(version: rustls::quic::Version, expected: &[u8]) { + // Code taken from rustls with modification + let chacha_key_len = TLS13_CHACHA20_POLY1305_SHA256 + .tls13() + .unwrap() + .quic + .unwrap() + .aead_key_len(); + + const PN: u64 = 654360564; + const SECRET: &[u8] = &[ + 0x9a, 0xc3, 0x12, 0xa7, 0xf8, 0x77, 0x46, 0x8e, 0xbe, 0x69, 0x42, 0x27, 0x48, 0xad, + 0x00, 0xa1, 0x54, 0x43, 0xf1, 0x82, 0x03, 0xa0, 0x7d, 0x60, 0x60, 0xf6, 0x88, 0xf3, + 0x0f, 0x21, 0x63, 0x2b, + ]; + + let mut output = [0u8; aead::quic::CHACHA_KEY_LEN]; + let mut iv = [0u8; aead::quic::NONCE_LEN]; + // Derive Header Protection key + let secret = OkmBlock::new(SECRET); + let expander = TLS13_CHACHA20_POLY1305_SHA256 + .tls13() + .unwrap() + .hkdf_provider + .expander_for_okm(&secret); + //Derive hp key + hkdf_expand_label( + &expander, + match version { + rustls::quic::Version::V1Draft | rustls::quic::Version::V1 => b"quic hp", + rustls::quic::Version::V2 => b"quicv2 hp", + _ => todo!(), + }, + &[], + chacha_key_len, + &mut output, + ); + + let hp_aead_key = AeadKey::from(output.clone()); + let header_protection_key = TLS13_CHACHA20_POLY1305_SHA256 + .tls13() + .unwrap() + .quic + .unwrap() + .header_protection_key(hp_aead_key); + + // Derive packet protection key and iv + hkdf_expand_label( + &expander, + match version { + rustls::quic::Version::V1Draft | rustls::quic::Version::V1 => b"quic key", + rustls::quic::Version::V2 => b"quicv2 key", + _ => todo!(), + }, + &[], + chacha_key_len, + &mut output, + ); + + let pkt_aead_key = AeadKey::from(output); + + hkdf_expand_label( + &expander, + match version { + rustls::quic::Version::V1Draft | rustls::quic::Version::V1 => b"quic iv", + rustls::quic::Version::V2 => b"quicv2 iv", + _ => todo!(), + }, + &[], + NONCE_LEN, + &mut iv, + ); + let iv = Iv::new(iv); + + let packet_protection_key = TLS13_CHACHA20_POLY1305_SHA256 + .tls13() + .unwrap() + .quic + .unwrap() + .packet_key(pkt_aead_key, iv); + const PLAIN: &[u8] = &[0x42, 0x00, 0xbf, 0xf4, 0x01]; + + let mut buf = PLAIN.to_vec(); + let (header, payload) = buf.split_at_mut(4); + let tag = packet_protection_key + .encrypt_in_place(PN, header, payload) + .unwrap(); + buf.extend(tag.as_ref()); + + let pn_offset = 1; + let (header, sample) = buf.split_at_mut(pn_offset + 4); + let (first, rest) = header.split_at_mut(1); + let sample = &sample[..header_protection_key.sample_len()]; + header_protection_key + .encrypt_in_place(sample, &mut first[0], rest) + .unwrap(); + + assert_eq!(&buf, expected); + + let (header, sample) = buf.split_at_mut(pn_offset + 4); + let (first, rest) = header.split_at_mut(1); + let sample = &sample[..header_protection_key.sample_len()]; + header_protection_key + .decrypt_in_place(sample, &mut first[0], rest) + .unwrap(); + + let (header, payload_tag) = buf.split_at_mut(4); + let plain = packet_protection_key + .decrypt_in_place(PN, header, payload_tag) + .unwrap(); + + assert_eq!(plain, &PLAIN[4..]); + } + + #[test] + fn short_packet_header_protection() { + // https://www.rfc-editor.org/rfc/rfc9001.html#name-chacha20-poly1305-short-hea + test_short_packet( + rustls::quic::Version::V1, + &[ + 0x4c, 0xfe, 0x41, 0x89, 0x65, 0x5e, 0x5c, 0xd5, 0x5c, 0x41, 0xf6, 0x90, 0x80, 0x57, + 0x5d, 0x79, 0x99, 0xc2, 0x5a, 0x5b, 0xfb, + ], + ); + } + + #[test] + fn short_packet_header_protection_v2() { + // https://www.ietf.org/archive/id/draft-ietf-quic-v2-10.html#name-chacha20-poly1305-short-head + test_short_packet( + rustls::quic::Version::V2, + &[ + 0x55, 0x58, 0xb1, 0xc6, 0x0a, 0xe7, 0xb6, 0xb9, 0x32, 0xbc, 0x27, 0xd7, 0x86, 0xf4, + 0xbc, 0x2b, 0xb2, 0x0f, 0x21, 0x62, 0xba, + ], + ); + } + + #[test] + fn initial_test_vector_v2() { + let tls13_cipher_suite = TLS13_AES_128_GCM_SHA256.tls13().unwrap(); + + // https://www.ietf.org/archive/id/draft-ietf-quic-v2-10.html#name-sample-packet-protection-2 + let icid = [0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + let server = Keys::initial( + rustls::quic::Version::V2, + tls13_cipher_suite, + TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(), + &icid, + Side::Server, + ); + let mut server_payload = [ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x40, 0x5a, 0x02, 0x00, 0x00, 0x56, 0x03, + 0x03, 0xee, 0xfc, 0xe7, 0xf7, 0xb3, 0x7b, 0xa1, 0xd1, 0x63, 0x2e, 0x96, 0x67, 0x78, + 0x25, 0xdd, 0xf7, 0x39, 0x88, 0xcf, 0xc7, 0x98, 0x25, 0xdf, 0x56, 0x6d, 0xc5, 0x43, + 0x0b, 0x9a, 0x04, 0x5a, 0x12, 0x00, 0x13, 0x01, 0x00, 0x00, 0x2e, 0x00, 0x33, 0x00, + 0x24, 0x00, 0x1d, 0x00, 0x20, 0x9d, 0x3c, 0x94, 0x0d, 0x89, 0x69, 0x0b, 0x84, 0xd0, + 0x8a, 0x60, 0x99, 0x3c, 0x14, 0x4e, 0xca, 0x68, 0x4d, 0x10, 0x81, 0x28, 0x7c, 0x83, + 0x4d, 0x53, 0x11, 0xbc, 0xf3, 0x2b, 0xb9, 0xda, 0x1a, 0x00, 0x2b, 0x00, 0x02, 0x03, + 0x04, + ]; + let mut server_header = [ + 0xd1, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, + 0xb5, 0x00, 0x40, 0x75, 0x00, 0x01, + ]; + let tag = server + .local + .packet + .encrypt_in_place(1, &server_header, &mut server_payload) + .unwrap(); + let (first, rest) = server_header.split_at_mut(1); + let rest_len = rest.len(); + server + .local + .header + .encrypt_in_place( + &server_payload[2..18], + &mut first[0], + &mut rest[rest_len - 2..], + ) + .unwrap(); + let mut server_packet = server_header.to_vec(); + server_packet.extend(server_payload); + server_packet.extend(tag.as_ref()); + let expected_server_packet = [ + 0xdc, 0x6b, 0x33, 0x43, 0xcf, 0x00, 0x08, 0xf0, 0x67, 0xa5, 0x50, 0x2a, 0x42, 0x62, + 0xb5, 0x00, 0x40, 0x75, 0xd9, 0x2f, 0xaa, 0xf1, 0x6f, 0x05, 0xd8, 0xa4, 0x39, 0x8c, + 0x47, 0x08, 0x96, 0x98, 0xba, 0xee, 0xa2, 0x6b, 0x91, 0xeb, 0x76, 0x1d, 0x9b, 0x89, + 0x23, 0x7b, 0xbf, 0x87, 0x26, 0x30, 0x17, 0x91, 0x53, 0x58, 0x23, 0x00, 0x35, 0xf7, + 0xfd, 0x39, 0x45, 0xd8, 0x89, 0x65, 0xcf, 0x17, 0xf9, 0xaf, 0x6e, 0x16, 0x88, 0x6c, + 0x61, 0xbf, 0xc7, 0x03, 0x10, 0x6f, 0xba, 0xf3, 0xcb, 0x4c, 0xfa, 0x52, 0x38, 0x2d, + 0xd1, 0x6a, 0x39, 0x3e, 0x42, 0x75, 0x75, 0x07, 0x69, 0x80, 0x75, 0xb2, 0xc9, 0x84, + 0xc7, 0x07, 0xf0, 0xa0, 0x81, 0x2d, 0x8c, 0xd5, 0xa6, 0x88, 0x1e, 0xaf, 0x21, 0xce, + 0xda, 0x98, 0xf4, 0xbd, 0x23, 0xf6, 0xfe, 0x1a, 0x3e, 0x2c, 0x43, 0xed, 0xd9, 0xce, + 0x7c, 0xa8, 0x4b, 0xed, 0x85, 0x21, 0xe2, 0xe1, 0x40, + ]; + assert_eq!(server_packet[..], expected_server_packet[..]); + } + + #[test] + fn test_quic_rejects_missing_alpn() { + //Code taken from rustls with modification + let client_params = &b"client params"[..]; + let server_params = &b"server params"[..]; + + let client_config = Arc::new(make_default_client_config()); + + let mut server_config = make_default_server_config(); + server_config.alpn_protocols = vec!["foo".into()]; + let server_config = Arc::new(server_config); + + let mut client = rustls::quic::ClientConnection::new( + client_config, + rustls::quic::Version::V1, + "localhost".try_into().unwrap(), + client_params.into(), + ) + .unwrap(); + let mut server = rustls::quic::ServerConnection::new( + server_config, + rustls::quic::Version::V1, + server_params.into(), + ) + .unwrap(); + + assert_eq!( + step(&mut client, &mut server).err().unwrap(), + rustls::Error::NoApplicationProtocol + ); + + assert_eq!( + server.alert(), + Some(rustls::AlertDescription::NoApplicationProtocol) + ); + } + + #[test] + fn test_quic_invalid_early_data_size() { + //Code taken from rustls with modification + let mut server_config = make_default_server_config(); + server_config.alpn_protocols = vec!["foo".into()]; + + let cases = [ + (None, true), + (Some(0u32), true), + (Some(5), false), + (Some(0xffff_ffff), true), + ]; + + for &(size, ok) in cases.iter() { + println!("early data size case: {size:?}"); + if let Some(new) = size { + server_config.max_early_data_size = new; + } + + let wrapped = Arc::new(server_config.clone()); + assert_eq!( + rustls::quic::ServerConnection::new( + wrapped, + rustls::quic::Version::V1, + b"server params".to_vec(), + ) + .is_ok(), + ok + ); + } + } + + #[test] + fn test_quic_server_no_params_received() { + //Code taken from rustls with modification + + let server_config = make_default_server_config(); + let server_config = Arc::new(server_config); + + let mut server = rustls::quic::ServerConnection::new( + server_config, + rustls::quic::Version::V1, + b"server params".to_vec(), + ) + .unwrap(); + + //Make a basic client hello + let ch = make_client_hello(); + assert_eq!( + server.read_hs(ch.as_slice()).err(), + Some(Error::PeerMisbehaved( + rustls::PeerMisbehaved::MissingQuicTransportParameters + )) + ); + } + + #[test] + fn packet_key_api() { + //Code taken from rustls + use rustls::quic::{Keys, Version}; + use rustls::Side; + + // Test vectors: https://www.rfc-editor.org/rfc/rfc9001.html#name-client-initial + const CONNECTION_ID: &[u8] = &[0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08]; + const PACKET_NUMBER: u64 = 2; + const PLAIN_HEADER: &[u8] = &[ + 0xc3, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, + 0x00, 0x00, 0x44, 0x9e, 0x00, 0x00, 0x00, 0x02, + ]; + + const PAYLOAD: &[u8] = &[ + 0x06, 0x00, 0x40, 0xf1, 0x01, 0x00, 0x00, 0xed, 0x03, 0x03, 0xeb, 0xf8, 0xfa, 0x56, + 0xf1, 0x29, 0x39, 0xb9, 0x58, 0x4a, 0x38, 0x96, 0x47, 0x2e, 0xc4, 0x0b, 0xb8, 0x63, + 0xcf, 0xd3, 0xe8, 0x68, 0x04, 0xfe, 0x3a, 0x47, 0xf0, 0x6a, 0x2b, 0x69, 0x48, 0x4c, + 0x00, 0x00, 0x04, 0x13, 0x01, 0x13, 0x02, 0x01, 0x00, 0x00, 0xc0, 0x00, 0x00, 0x00, + 0x10, 0x00, 0x0e, 0x00, 0x00, 0x0b, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, + 0x63, 0x6f, 0x6d, 0xff, 0x01, 0x00, 0x01, 0x00, 0x00, 0x0a, 0x00, 0x08, 0x00, 0x06, + 0x00, 0x1d, 0x00, 0x17, 0x00, 0x18, 0x00, 0x10, 0x00, 0x07, 0x00, 0x05, 0x04, 0x61, + 0x6c, 0x70, 0x6e, 0x00, 0x05, 0x00, 0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x33, + 0x00, 0x26, 0x00, 0x24, 0x00, 0x1d, 0x00, 0x20, 0x93, 0x70, 0xb2, 0xc9, 0xca, 0xa4, + 0x7f, 0xba, 0xba, 0xf4, 0x55, 0x9f, 0xed, 0xba, 0x75, 0x3d, 0xe1, 0x71, 0xfa, 0x71, + 0xf5, 0x0f, 0x1c, 0xe1, 0x5d, 0x43, 0xe9, 0x94, 0xec, 0x74, 0xd7, 0x48, 0x00, 0x2b, + 0x00, 0x03, 0x02, 0x03, 0x04, 0x00, 0x0d, 0x00, 0x10, 0x00, 0x0e, 0x04, 0x03, 0x05, + 0x03, 0x06, 0x03, 0x02, 0x03, 0x08, 0x04, 0x08, 0x05, 0x08, 0x06, 0x00, 0x2d, 0x00, + 0x02, 0x01, 0x01, 0x00, 0x1c, 0x00, 0x02, 0x40, 0x01, 0x00, 0x39, 0x00, 0x32, 0x04, + 0x08, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x05, 0x04, 0x80, 0x00, 0xff, + 0xff, 0x07, 0x04, 0x80, 0x00, 0xff, 0xff, 0x08, 0x01, 0x10, 0x01, 0x04, 0x80, 0x00, + 0x75, 0x30, 0x09, 0x01, 0x10, 0x0f, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, + 0x08, 0x06, 0x04, 0x80, 0x00, 0xff, 0xff, + ]; + + let client_keys = Keys::initial( + Version::V1, + TLS13_AES_128_GCM_SHA256.tls13().unwrap(), + TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(), + CONNECTION_ID, + Side::Client, + ); + assert_eq!(client_keys.local.packet.tag_len(), 16); + + let mut buf = Vec::new(); + buf.extend(PLAIN_HEADER); + buf.extend(PAYLOAD); + let header_len = PLAIN_HEADER.len(); + let tag_len = client_keys.local.packet.tag_len(); + let padding_len = 1200 - header_len - PAYLOAD.len() - tag_len; + buf.extend(std::iter::repeat(0).take(padding_len)); + let (header, payload) = buf.split_at_mut(header_len); + let tag = client_keys + .local + .packet + .encrypt_in_place(PACKET_NUMBER, header, payload) + .unwrap(); + + let sample_len = client_keys.local.header.sample_len(); + let sample = &payload[..sample_len]; + let (first, rest) = header.split_at_mut(1); + client_keys + .local + .header + .encrypt_in_place(sample, &mut first[0], &mut rest[17..21]) + .unwrap(); + buf.extend_from_slice(tag.as_ref()); + + const PROTECTED: &[u8] = &[ + 0xc0, 0x00, 0x00, 0x00, 0x01, 0x08, 0x83, 0x94, 0xc8, 0xf0, 0x3e, 0x51, 0x57, 0x08, + 0x00, 0x00, 0x44, 0x9e, 0x7b, 0x9a, 0xec, 0x34, 0xd1, 0xb1, 0xc9, 0x8d, 0xd7, 0x68, + 0x9f, 0xb8, 0xec, 0x11, 0xd2, 0x42, 0xb1, 0x23, 0xdc, 0x9b, 0xd8, 0xba, 0xb9, 0x36, + 0xb4, 0x7d, 0x92, 0xec, 0x35, 0x6c, 0x0b, 0xab, 0x7d, 0xf5, 0x97, 0x6d, 0x27, 0xcd, + 0x44, 0x9f, 0x63, 0x30, 0x00, 0x99, 0xf3, 0x99, 0x1c, 0x26, 0x0e, 0xc4, 0xc6, 0x0d, + 0x17, 0xb3, 0x1f, 0x84, 0x29, 0x15, 0x7b, 0xb3, 0x5a, 0x12, 0x82, 0xa6, 0x43, 0xa8, + 0xd2, 0x26, 0x2c, 0xad, 0x67, 0x50, 0x0c, 0xad, 0xb8, 0xe7, 0x37, 0x8c, 0x8e, 0xb7, + 0x53, 0x9e, 0xc4, 0xd4, 0x90, 0x5f, 0xed, 0x1b, 0xee, 0x1f, 0xc8, 0xaa, 0xfb, 0xa1, + 0x7c, 0x75, 0x0e, 0x2c, 0x7a, 0xce, 0x01, 0xe6, 0x00, 0x5f, 0x80, 0xfc, 0xb7, 0xdf, + 0x62, 0x12, 0x30, 0xc8, 0x37, 0x11, 0xb3, 0x93, 0x43, 0xfa, 0x02, 0x8c, 0xea, 0x7f, + 0x7f, 0xb5, 0xff, 0x89, 0xea, 0xc2, 0x30, 0x82, 0x49, 0xa0, 0x22, 0x52, 0x15, 0x5e, + 0x23, 0x47, 0xb6, 0x3d, 0x58, 0xc5, 0x45, 0x7a, 0xfd, 0x84, 0xd0, 0x5d, 0xff, 0xfd, + 0xb2, 0x03, 0x92, 0x84, 0x4a, 0xe8, 0x12, 0x15, 0x46, 0x82, 0xe9, 0xcf, 0x01, 0x2f, + 0x90, 0x21, 0xa6, 0xf0, 0xbe, 0x17, 0xdd, 0xd0, 0xc2, 0x08, 0x4d, 0xce, 0x25, 0xff, + 0x9b, 0x06, 0xcd, 0xe5, 0x35, 0xd0, 0xf9, 0x20, 0xa2, 0xdb, 0x1b, 0xf3, 0x62, 0xc2, + 0x3e, 0x59, 0x6d, 0x11, 0xa4, 0xf5, 0xa6, 0xcf, 0x39, 0x48, 0x83, 0x8a, 0x3a, 0xec, + 0x4e, 0x15, 0xda, 0xf8, 0x50, 0x0a, 0x6e, 0xf6, 0x9e, 0xc4, 0xe3, 0xfe, 0xb6, 0xb1, + 0xd9, 0x8e, 0x61, 0x0a, 0xc8, 0xb7, 0xec, 0x3f, 0xaf, 0x6a, 0xd7, 0x60, 0xb7, 0xba, + 0xd1, 0xdb, 0x4b, 0xa3, 0x48, 0x5e, 0x8a, 0x94, 0xdc, 0x25, 0x0a, 0xe3, 0xfd, 0xb4, + 0x1e, 0xd1, 0x5f, 0xb6, 0xa8, 0xe5, 0xeb, 0xa0, 0xfc, 0x3d, 0xd6, 0x0b, 0xc8, 0xe3, + 0x0c, 0x5c, 0x42, 0x87, 0xe5, 0x38, 0x05, 0xdb, 0x05, 0x9a, 0xe0, 0x64, 0x8d, 0xb2, + 0xf6, 0x42, 0x64, 0xed, 0x5e, 0x39, 0xbe, 0x2e, 0x20, 0xd8, 0x2d, 0xf5, 0x66, 0xda, + 0x8d, 0xd5, 0x99, 0x8c, 0xca, 0xbd, 0xae, 0x05, 0x30, 0x60, 0xae, 0x6c, 0x7b, 0x43, + 0x78, 0xe8, 0x46, 0xd2, 0x9f, 0x37, 0xed, 0x7b, 0x4e, 0xa9, 0xec, 0x5d, 0x82, 0xe7, + 0x96, 0x1b, 0x7f, 0x25, 0xa9, 0x32, 0x38, 0x51, 0xf6, 0x81, 0xd5, 0x82, 0x36, 0x3a, + 0xa5, 0xf8, 0x99, 0x37, 0xf5, 0xa6, 0x72, 0x58, 0xbf, 0x63, 0xad, 0x6f, 0x1a, 0x0b, + 0x1d, 0x96, 0xdb, 0xd4, 0xfa, 0xdd, 0xfc, 0xef, 0xc5, 0x26, 0x6b, 0xa6, 0x61, 0x17, + 0x22, 0x39, 0x5c, 0x90, 0x65, 0x56, 0xbe, 0x52, 0xaf, 0xe3, 0xf5, 0x65, 0x63, 0x6a, + 0xd1, 0xb1, 0x7d, 0x50, 0x8b, 0x73, 0xd8, 0x74, 0x3e, 0xeb, 0x52, 0x4b, 0xe2, 0x2b, + 0x3d, 0xcb, 0xc2, 0xc7, 0x46, 0x8d, 0x54, 0x11, 0x9c, 0x74, 0x68, 0x44, 0x9a, 0x13, + 0xd8, 0xe3, 0xb9, 0x58, 0x11, 0xa1, 0x98, 0xf3, 0x49, 0x1d, 0xe3, 0xe7, 0xfe, 0x94, + 0x2b, 0x33, 0x04, 0x07, 0xab, 0xf8, 0x2a, 0x4e, 0xd7, 0xc1, 0xb3, 0x11, 0x66, 0x3a, + 0xc6, 0x98, 0x90, 0xf4, 0x15, 0x70, 0x15, 0x85, 0x3d, 0x91, 0xe9, 0x23, 0x03, 0x7c, + 0x22, 0x7a, 0x33, 0xcd, 0xd5, 0xec, 0x28, 0x1c, 0xa3, 0xf7, 0x9c, 0x44, 0x54, 0x6b, + 0x9d, 0x90, 0xca, 0x00, 0xf0, 0x64, 0xc9, 0x9e, 0x3d, 0xd9, 0x79, 0x11, 0xd3, 0x9f, + 0xe9, 0xc5, 0xd0, 0xb2, 0x3a, 0x22, 0x9a, 0x23, 0x4c, 0xb3, 0x61, 0x86, 0xc4, 0x81, + 0x9e, 0x8b, 0x9c, 0x59, 0x27, 0x72, 0x66, 0x32, 0x29, 0x1d, 0x6a, 0x41, 0x82, 0x11, + 0xcc, 0x29, 0x62, 0xe2, 0x0f, 0xe4, 0x7f, 0xeb, 0x3e, 0xdf, 0x33, 0x0f, 0x2c, 0x60, + 0x3a, 0x9d, 0x48, 0xc0, 0xfc, 0xb5, 0x69, 0x9d, 0xbf, 0xe5, 0x89, 0x64, 0x25, 0xc5, + 0xba, 0xc4, 0xae, 0xe8, 0x2e, 0x57, 0xa8, 0x5a, 0xaf, 0x4e, 0x25, 0x13, 0xe4, 0xf0, + 0x57, 0x96, 0xb0, 0x7b, 0xa2, 0xee, 0x47, 0xd8, 0x05, 0x06, 0xf8, 0xd2, 0xc2, 0x5e, + 0x50, 0xfd, 0x14, 0xde, 0x71, 0xe6, 0xc4, 0x18, 0x55, 0x93, 0x02, 0xf9, 0x39, 0xb0, + 0xe1, 0xab, 0xd5, 0x76, 0xf2, 0x79, 0xc4, 0xb2, 0xe0, 0xfe, 0xb8, 0x5c, 0x1f, 0x28, + 0xff, 0x18, 0xf5, 0x88, 0x91, 0xff, 0xef, 0x13, 0x2e, 0xef, 0x2f, 0xa0, 0x93, 0x46, + 0xae, 0xe3, 0x3c, 0x28, 0xeb, 0x13, 0x0f, 0xf2, 0x8f, 0x5b, 0x76, 0x69, 0x53, 0x33, + 0x41, 0x13, 0x21, 0x19, 0x96, 0xd2, 0x00, 0x11, 0xa1, 0x98, 0xe3, 0xfc, 0x43, 0x3f, + 0x9f, 0x25, 0x41, 0x01, 0x0a, 0xe1, 0x7c, 0x1b, 0xf2, 0x02, 0x58, 0x0f, 0x60, 0x47, + 0x47, 0x2f, 0xb3, 0x68, 0x57, 0xfe, 0x84, 0x3b, 0x19, 0xf5, 0x98, 0x40, 0x09, 0xdd, + 0xc3, 0x24, 0x04, 0x4e, 0x84, 0x7a, 0x4f, 0x4a, 0x0a, 0xb3, 0x4f, 0x71, 0x95, 0x95, + 0xde, 0x37, 0x25, 0x2d, 0x62, 0x35, 0x36, 0x5e, 0x9b, 0x84, 0x39, 0x2b, 0x06, 0x10, + 0x85, 0x34, 0x9d, 0x73, 0x20, 0x3a, 0x4a, 0x13, 0xe9, 0x6f, 0x54, 0x32, 0xec, 0x0f, + 0xd4, 0xa1, 0xee, 0x65, 0xac, 0xcd, 0xd5, 0xe3, 0x90, 0x4d, 0xf5, 0x4c, 0x1d, 0xa5, + 0x10, 0xb0, 0xff, 0x20, 0xdc, 0xc0, 0xc7, 0x7f, 0xcb, 0x2c, 0x0e, 0x0e, 0xb6, 0x05, + 0xcb, 0x05, 0x04, 0xdb, 0x87, 0x63, 0x2c, 0xf3, 0xd8, 0xb4, 0xda, 0xe6, 0xe7, 0x05, + 0x76, 0x9d, 0x1d, 0xe3, 0x54, 0x27, 0x01, 0x23, 0xcb, 0x11, 0x45, 0x0e, 0xfc, 0x60, + 0xac, 0x47, 0x68, 0x3d, 0x7b, 0x8d, 0x0f, 0x81, 0x13, 0x65, 0x56, 0x5f, 0xd9, 0x8c, + 0x4c, 0x8e, 0xb9, 0x36, 0xbc, 0xab, 0x8d, 0x06, 0x9f, 0xc3, 0x3b, 0xd8, 0x01, 0xb0, + 0x3a, 0xde, 0xa2, 0xe1, 0xfb, 0xc5, 0xaa, 0x46, 0x3d, 0x08, 0xca, 0x19, 0x89, 0x6d, + 0x2b, 0xf5, 0x9a, 0x07, 0x1b, 0x85, 0x1e, 0x6c, 0x23, 0x90, 0x52, 0x17, 0x2f, 0x29, + 0x6b, 0xfb, 0x5e, 0x72, 0x40, 0x47, 0x90, 0xa2, 0x18, 0x10, 0x14, 0xf3, 0xb9, 0x4a, + 0x4e, 0x97, 0xd1, 0x17, 0xb4, 0x38, 0x13, 0x03, 0x68, 0xcc, 0x39, 0xdb, 0xb2, 0xd1, + 0x98, 0x06, 0x5a, 0xe3, 0x98, 0x65, 0x47, 0x92, 0x6c, 0xd2, 0x16, 0x2f, 0x40, 0xa2, + 0x9f, 0x0c, 0x3c, 0x87, 0x45, 0xc0, 0xf5, 0x0f, 0xba, 0x38, 0x52, 0xe5, 0x66, 0xd4, + 0x45, 0x75, 0xc2, 0x9d, 0x39, 0xa0, 0x3f, 0x0c, 0xda, 0x72, 0x19, 0x84, 0xb6, 0xf4, + 0x40, 0x59, 0x1f, 0x35, 0x5e, 0x12, 0xd4, 0x39, 0xff, 0x15, 0x0a, 0xab, 0x76, 0x13, + 0x49, 0x9d, 0xbd, 0x49, 0xad, 0xab, 0xc8, 0x67, 0x6e, 0xef, 0x02, 0x3b, 0x15, 0xb6, + 0x5b, 0xfc, 0x5c, 0xa0, 0x69, 0x48, 0x10, 0x9f, 0x23, 0xf3, 0x50, 0xdb, 0x82, 0x12, + 0x35, 0x35, 0xeb, 0x8a, 0x74, 0x33, 0xbd, 0xab, 0xcb, 0x90, 0x92, 0x71, 0xa6, 0xec, + 0xbc, 0xb5, 0x8b, 0x93, 0x6a, 0x88, 0xcd, 0x4e, 0x8f, 0x2e, 0x6f, 0xf5, 0x80, 0x01, + 0x75, 0xf1, 0x13, 0x25, 0x3d, 0x8f, 0xa9, 0xca, 0x88, 0x85, 0xc2, 0xf5, 0x52, 0xe6, + 0x57, 0xdc, 0x60, 0x3f, 0x25, 0x2e, 0x1a, 0x8e, 0x30, 0x8f, 0x76, 0xf0, 0xbe, 0x79, + 0xe2, 0xfb, 0x8f, 0x5d, 0x5f, 0xbb, 0xe2, 0xe3, 0x0e, 0xca, 0xdd, 0x22, 0x07, 0x23, + 0xc8, 0xc0, 0xae, 0xa8, 0x07, 0x8c, 0xdf, 0xcb, 0x38, 0x68, 0x26, 0x3f, 0xf8, 0xf0, + 0x94, 0x00, 0x54, 0xda, 0x48, 0x78, 0x18, 0x93, 0xa7, 0xe4, 0x9a, 0xd5, 0xaf, 0xf4, + 0xaf, 0x30, 0x0c, 0xd8, 0x04, 0xa6, 0xb6, 0x27, 0x9a, 0xb3, 0xff, 0x3a, 0xfb, 0x64, + 0x49, 0x1c, 0x85, 0x19, 0x4a, 0xab, 0x76, 0x0d, 0x58, 0xa6, 0x06, 0x65, 0x4f, 0x9f, + 0x44, 0x00, 0xe8, 0xb3, 0x85, 0x91, 0x35, 0x6f, 0xbf, 0x64, 0x25, 0xac, 0xa2, 0x6d, + 0xc8, 0x52, 0x44, 0x25, 0x9f, 0xf2, 0xb1, 0x9c, 0x41, 0xb9, 0xf9, 0x6f, 0x3c, 0xa9, + 0xec, 0x1d, 0xde, 0x43, 0x4d, 0xa7, 0xd2, 0xd3, 0x92, 0xb9, 0x05, 0xdd, 0xf3, 0xd1, + 0xf9, 0xaf, 0x93, 0xd1, 0xaf, 0x59, 0x50, 0xbd, 0x49, 0x3f, 0x5a, 0xa7, 0x31, 0xb4, + 0x05, 0x6d, 0xf3, 0x1b, 0xd2, 0x67, 0xb6, 0xb9, 0x0a, 0x07, 0x98, 0x31, 0xaa, 0xf5, + 0x79, 0xbe, 0x0a, 0x39, 0x01, 0x31, 0x37, 0xaa, 0xc6, 0xd4, 0x04, 0xf5, 0x18, 0xcf, + 0xd4, 0x68, 0x40, 0x64, 0x7e, 0x78, 0xbf, 0xe7, 0x06, 0xca, 0x4c, 0xf5, 0xe9, 0xc5, + 0x45, 0x3e, 0x9f, 0x7c, 0xfd, 0x2b, 0x8b, 0x4c, 0x8d, 0x16, 0x9a, 0x44, 0xe5, 0x5c, + 0x88, 0xd4, 0xa9, 0xa7, 0xf9, 0x47, 0x42, 0x41, 0xe2, 0x21, 0xaf, 0x44, 0x86, 0x00, + 0x18, 0xab, 0x08, 0x56, 0x97, 0x2e, 0x19, 0x4c, 0xd9, 0x34, + ]; + + assert_eq!(&buf, PROTECTED); + + let (header, payload) = buf.split_at_mut(header_len); + let (first, rest) = header.split_at_mut(1); + let sample = &payload[..sample_len]; + + let server_keys = Keys::initial( + Version::V1, + TLS13_AES_128_GCM_SHA256.tls13().unwrap(), + TLS13_AES_128_GCM_SHA256.tls13().unwrap().quic.unwrap(), + CONNECTION_ID, + Side::Server, + ); + server_keys + .remote + .header + .decrypt_in_place(sample, &mut first[0], &mut rest[17..21]) + .unwrap(); + let payload = server_keys + .remote + .packet + .decrypt_in_place(PACKET_NUMBER, header, payload) + .unwrap(); + + assert_eq!(&payload[..PAYLOAD.len()], PAYLOAD); + assert_eq!(payload.len(), buf.len() - header_len - tag_len); + } + + #[test] + fn test_aes_mask_generation() { + //Test idea taken from ring + // Copyright 2018 Brian Smith. + let vectors = [ + AesTestVector::Aes128 { + key: hex!("e8904ecc2e37a6e4cc02271e319c804b"), + sample: hex!("13484ec85dc4d36349697c7d4ea1a159"), + mask: hex!("67387ebf3a"), + }, + AesTestVector::Aes128 { + key: hex!("e8904ecc2e37a6e4cc02271e319c804b"), + sample: hex!("00000000000000000000000fffffffff"), + mask: hex!("feb191f8af"), + }, + AesTestVector::Aes128 { + key: hex!("e8904ecc2e37a6e4cc02271e319c804b"), + sample: hex!("000000000000000fffffffffffffffff"), + mask: hex!("6f23441ee8"), + }, + AesTestVector::Aes256 { + key: hex!("85af7213814aec7b92ace6284a906643912ec8853d00d158a927b8697c7ff585"), + sample: hex!("82a0db90f4cee12fa4afeddb74396cf6"), + mask: hex!("670897adf5"), + }, + AesTestVector::Aes256 { + key: hex!("85af7213814aec7b92ace6284a906643912ec8853d00d158a927b8697c7ff585"), + sample: hex!("000000000000000000000000ffffffff"), + mask: hex!("b77a18bb3f"), + }, + AesTestVector::Aes256 { + key: hex!("85af7213814aec7b92ace6284a906643912ec8853d00d158a927b8697c7ff585"), + sample: hex!("000000000000000fffffffffffffffff"), + mask: hex!("4aadd3cbef"), + }, + ]; + + let mut aes_cipher = crate::aead::quic::AesCipher::new().unwrap(); + let mut mask = [0u8; 5]; + + for v in &vectors { + let (v_key, v_sample, v_mask): (&[u8], &[u8], &[u8]) = match v { + AesTestVector::Aes128 { key, sample, mask } => { + (key.as_slice(), sample.as_slice(), mask.as_slice()) + } + AesTestVector::Aes256 { key, sample, mask } => { + (key.as_slice(), sample.as_slice(), mask.as_slice()) + } + }; + let _ = aes_cipher.set_key(v_key); + mask.copy_from_slice(&aes_cipher.encrypt_sample(v_sample).unwrap()[..5]); + assert_eq!(v_mask, mask) + } + } + + #[test] + fn test_chacha_mask_generation() { + //Test idea taken from ring + // Copyright 2018 Brian Smith. + + let test_vector = ChaCha20TestVector { + key: hex!("59bdff7a5bcdaacf319d99646c6273ad96687d2c74ace678f15a1c710675bb23"), + sample: hex!("215a7c1688b4ab7d830dcd052aef9f3c"), + mask: hex!("6409a6196d"), + }; + + let mut chacha_cipher = crate::aead::quic::ChaChaCipher::new(None).unwrap(); + let mut mask = mask_array!(); + + let _ = chacha_cipher.set_key(&test_vector.key); + mask.copy_from_slice(&chacha_cipher.encrypt_sample(&test_vector.sample).unwrap()[..5]); + assert_eq!(test_vector.mask, mask) + } + + #[test] + fn test_sample_len() { + let hp_algs: Vec<&aead::quic::HPAlgorithm> = vec![ + &aead::quic::AES_128, + &aead::quic::AES_256, + &aead::quic::CHACHA20, + ]; + let mut first = vec![0u8; 1]; + let mut packet_number = vec![0u8; 4]; + for alg in hp_algs { + let key_len = alg.key_len(); + let key_data = vec![0u8; key_len]; + + let key = aead::quic::HeaderProtectionKey::new(key_data, alg).unwrap(); + + let sample_len = 16; + let sample_data = vec![0u8; sample_len + 2]; + + // Sample is the right size. + assert!(key + .encrypt_in_place( + &sample_data[..sample_len], + &mut first[0], + packet_number.as_mut_slice() + ) + .is_ok()); + + // Sample is one byte too small. + assert!(key + .encrypt_in_place( + &sample_data[..(sample_len - 1)], + &mut first[0], + packet_number.as_mut_slice() + ) + .is_err()); + + // Sample is one byte too big. + assert!(key + .encrypt_in_place( + &sample_data[..(sample_len + 1)], + &mut first[0], + packet_number.as_mut_slice() + ) + .is_err()); + + // Sample is empty. + assert!(key + .encrypt_in_place(&[], &mut first[0], packet_number.as_mut_slice()) + .is_err()); + } + } + + #[test] + fn test_key_len() { + let hp_algs: Vec<&aead::quic::HPAlgorithm> = vec![ + &aead::quic::AES_128, + &aead::quic::AES_256, + &aead::quic::CHACHA20, + ]; + for alg in hp_algs { + let key_len = alg.key_len(); + let key_data = vec![0u8; key_len + 5]; + + // Key is the right size. + assert!( + aead::quic::HeaderProtectionKey::new(key_data[..key_len].to_vec(), alg).is_ok() + ); + + // Key is one byte too small. + assert!( + aead::quic::HeaderProtectionKey::new(key_data[..key_len - 1].to_vec(), alg) + .is_err() + ); + + // Key is one byte too big. + assert!( + aead::quic::HeaderProtectionKey::new(key_data[..key_len + 1].to_vec(), alg) + .is_err() + ); + + // Key is empty. + assert!(aead::quic::HeaderProtectionKey::new(Vec::new(), alg).is_err()); + } + } +} diff --git a/rustls-wolfcrypt-provider/src/kx.rs b/rustls-wolfcrypt-provider/src/kx.rs index e60c372..ab4d90f 100644 --- a/rustls-wolfcrypt-provider/src/kx.rs +++ b/rustls-wolfcrypt-provider/src/kx.rs @@ -1,13 +1,13 @@ use alloc::boxed::Box; use crypto::SupportedKxGroup; use rustls::crypto; - +use rustls::ffdhe_groups::FfdheGroup; mod sec256r1; mod sec384r1; mod sec521r1; mod x25519; -pub const ALL_KX_GROUPS: &[&dyn SupportedKxGroup] = &[&X25519, &SecP256R1, &SecP384R1, &SecP521R1]; +pub const ALL_KX_GROUPS: &[&dyn SupportedKxGroup] = &[&X25519, &SECP256R1, &SECP384R1, &SECP521R1]; macro_rules! define_kx_group { ($name:ident, $kx_type:ty, $kx_func:ident, $named_group:expr) => { @@ -22,6 +22,10 @@ macro_rules! define_kx_group { fn name(&self) -> rustls::NamedGroup { $named_group } + + fn ffdhe_group(&self) -> Option> { + None + } } }; } @@ -34,19 +38,19 @@ define_kx_group!( rustls::NamedGroup::X25519 ); define_kx_group!( - SecP256R1, + SECP256R1, sec256r1::KeyExchangeSecP256r1, use_secp256r1, rustls::NamedGroup::secp256r1 ); define_kx_group!( - SecP384R1, + SECP384R1, sec384r1::KeyExchangeSecP384r1, use_secp384r1, rustls::NamedGroup::secp384r1 ); define_kx_group!( - SecP521R1, + SECP521R1, sec521r1::KeyExchangeSecP521r1, use_secp521r1, rustls::NamedGroup::secp521r1 diff --git a/rustls-wolfcrypt-provider/src/lib.rs b/rustls-wolfcrypt-provider/src/lib.rs index aa6c592..0d5ae81 100644 --- a/rustls-wolfcrypt-provider/src/lib.rs +++ b/rustls-wolfcrypt-provider/src/lib.rs @@ -1,10 +1,9 @@ #![cfg_attr(not(test), no_std)] +extern crate alloc; #[cfg(test)] extern crate std; -extern crate alloc; - use alloc::boxed::Box; use alloc::sync::Arc; use alloc::vec; @@ -13,7 +12,7 @@ use rustls::crypto::CryptoProvider; use rustls::pki_types::PrivateKeyDer; pub mod error; mod hkdf; -mod kx; +pub mod kx; mod prf; mod random; mod verify; @@ -23,12 +22,16 @@ pub mod aead { pub mod aes128gcm; pub mod aes256gcm; pub mod chacha20; + #[cfg(feature = "quic")] + pub mod quic; } pub mod sign { pub mod ecdsa; pub mod eddsa; pub mod rsa; } +#[cfg(feature = "quic")] +use crate::aead::quic::KeyFactory; use crate::aead::{aes128gcm, aes256gcm, chacha20}; pub mod hash { @@ -50,7 +53,7 @@ type SigningAlgorithms = Vec>; /* * Crypto provider struct that we populate with our own crypto backend (wolfcrypt). * */ -pub fn provider() -> CryptoProvider { +pub fn default_provider() -> CryptoProvider { CryptoProvider { cipher_suites: ALL_CIPHER_SUITES.to_vec(), kx_groups: kx::ALL_KX_GROUPS.to_vec(), @@ -110,16 +113,16 @@ impl rustls::crypto::KeyProvider for Provider { } } -static ALL_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = &[ - TLS13_CHACHA20_POLY1305_SHA256, - TLS13_AES_128_GCM_SHA256, +pub static ALL_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = &[ TLS13_AES_256_GCM_SHA384, - TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, - TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS13_AES_128_GCM_SHA256, + TLS13_CHACHA20_POLY1305_SHA256, TLS12_ECDHE_RSA_WITH_AES_256_GCM_SHA384, - TLS12_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, - TLS12_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS12_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, TLS12_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + TLS12_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + TLS12_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, ]; static ALL_RSA_SCHEMES: &[rustls::SignatureScheme] = &[ @@ -147,7 +150,17 @@ pub static TLS13_CHACHA20_POLY1305_SHA256: rustls::SupportedCipherSuite = }, hkdf_provider: &WCHkdfUsingHmac(WCShaHmac::Sha256), aead_alg: &chacha20::Chacha20Poly1305, + #[cfg(not(feature = "quic"))] quic: None, + #[cfg(feature = "quic")] + quic: Some(&KeyFactory { + packet_algo: &aead::quic::CHACHA20_POLY1305, + header_algo: &aead::quic::CHACHA20, + // ref: + confidentiality_limit: u64::MAX, + // ref: + integrity_limit: 1 << 36, + }), }); pub static TLS13_AES_128_GCM_SHA256: rustls::SupportedCipherSuite = @@ -159,7 +172,17 @@ pub static TLS13_AES_128_GCM_SHA256: rustls::SupportedCipherSuite = }, hkdf_provider: &WCHkdfUsingHmac(WCShaHmac::Sha256), aead_alg: &aes128gcm::Aes128Gcm, + #[cfg(not(feature = "quic"))] quic: None, + #[cfg(feature = "quic")] + quic: Some(&KeyFactory { + packet_algo: &aead::quic::AES_128_GCM, + header_algo: &aead::quic::AES_128, + // ref: + confidentiality_limit: 1 << 23, + // ref: + integrity_limit: 1 << 52, + }), }); pub static TLS13_AES_256_GCM_SHA384: rustls::SupportedCipherSuite = @@ -171,7 +194,17 @@ pub static TLS13_AES_256_GCM_SHA384: rustls::SupportedCipherSuite = }, hkdf_provider: &WCHkdfUsingHmac(WCShaHmac::Sha384), aead_alg: &aes256gcm::Aes256Gcm, + #[cfg(not(feature = "quic"))] quic: None, + #[cfg(feature = "quic")] + quic: Some(&KeyFactory { + packet_algo: &aead::quic::AES_256_GCM, + header_algo: &aead::quic::AES_256, + // ref: + confidentiality_limit: 1 << 23, + // ref: + integrity_limit: 1 << 52, + }), }); pub static TLS12_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: rustls::SupportedCipherSuite = diff --git a/rustls-wolfcrypt-provider/src/sign/eddsa.rs b/rustls-wolfcrypt-provider/src/sign/eddsa.rs index c7a9a47..fa07a8f 100644 --- a/rustls-wolfcrypt-provider/src/sign/eddsa.rs +++ b/rustls-wolfcrypt-provider/src/sign/eddsa.rs @@ -30,6 +30,86 @@ impl fmt::Debug for Ed25519PrivateKey { .finish_non_exhaustive() } } +impl Ed25519PrivateKey { + /// Extract ED25519 private and if available public key values from a PKCS#8 DER formatted key + fn extract_key_pair(input_key: &[u8]) -> Result<([u8; 32], Option<[u8; 32]>), rustls::Error> { + let mut public_key_raw: [u8; 32] = [0; ED25519_PUB_KEY_SIZE as usize]; + let mut private_key_raw: [u8; 32] = [0; ED25519_KEY_SIZE as usize]; + let mut skip_bytes: usize; + let mut key_sub_slice = input_key; + + const SHORT_FORM_LEN_MAX: u8 = 127; + const TAG_SEQUENCE: u8 = 0x30; + const TAG_OCTET_SEQUENCE: u8 = 0x04; + const TAG_OPTIONAL_SET_OF_ATTRIBUTES: u8 = 0x80; //Implicit, context-specific, and primitive underlying type (SET OF) + const TAG_OPTIONAL_PUBLIC_KEY_BIT_STRING: u8 = 0x81; //Implicit, context-specific, and primitive underlying type (BIT STRING) + + // The input key is encoded in PKCS#8 DER format with a structure as in + // https://www.rfc-editor.org/rfc/rfc5958.html + // + // AsymmetricKeyPackage ::= SEQUENCE SIZE (1..MAX) OF OneAsymmetricKey + + // OneAsymmetricKey ::= SEQUENCE { + // version Version, + // privateKeyAlgorithm PrivateKeyAlgorithmIdentifier, + // privateKey PrivateKey, + // attributes [0] Attributes OPTIONAL, + // ..., + // [[2: publicKey [1] PublicKey OPTIONAL ]], + // ... + // } + + // The key structure must begin with a SEQUENCE tag with at least 2 bytes length if short + // length format is used + if key_sub_slice[0] != TAG_SEQUENCE || key_sub_slice.len() < 2 { + return Err(rustls::Error::General( + "Faulty PKCS#8 ED25519 private key structure".into(), + )); + } + // Check which length format and skip tag and length encoding bytes + if key_sub_slice[1] > SHORT_FORM_LEN_MAX { + skip_bytes = (2 + (key_sub_slice[1] & 0x7F)) as usize; + } else { + skip_bytes = 2; + } + + // Skip version (3 bytes), algorithm ID sequence (0x30 + length encoding bytes + 5 ID bytes), + skip_bytes += 3 + 7; + key_sub_slice = input_key.get(skip_bytes..).unwrap(); + + // Check if next bytes are 0x04, 0x22, 0x04, 0x20 + if !matches!( + key_sub_slice, + [TAG_OCTET_SEQUENCE, 0x22, TAG_OCTET_SEQUENCE, 0x20, ..] + ) { + return Err(rustls::Error::General( + "Faulty PKCS#8 ED25519 private key structure".into(), + )); + } + + // Copy private key value + skip_bytes += 4; + key_sub_slice = input_key.get(skip_bytes..).unwrap(); + private_key_raw.copy_from_slice(&key_sub_slice[..ED25519_KEY_SIZE as usize]); + skip_bytes += ED25519_KEY_SIZE as usize; + key_sub_slice = input_key.get(skip_bytes..).unwrap(); + + // Check if optional SET OF attributes exists and skip + if key_sub_slice.first() == Some(&TAG_OPTIONAL_SET_OF_ATTRIBUTES) { + skip_bytes += (2 + (key_sub_slice[1])) as usize; + key_sub_slice = input_key.get(skip_bytes..).unwrap(); + } + + // Check if optional public key value exists. If exists, skip tag, length encoding byte, + // and bits-used byte + if key_sub_slice.first() == Some(&TAG_OPTIONAL_PUBLIC_KEY_BIT_STRING) { + public_key_raw.copy_from_slice(&key_sub_slice[3..(3 + ED25519_KEY_SIZE as usize)]); + Ok((private_key_raw, Some(public_key_raw))) + } else { + Ok((private_key_raw, None)) + } + } +} impl TryFrom<&PrivateKeyDer<'_>> for Ed25519PrivateKey { type Error = rustls::Error; @@ -37,59 +117,52 @@ impl TryFrom<&PrivateKeyDer<'_>> for Ed25519PrivateKey { fn try_from(value: &PrivateKeyDer<'_>) -> Result { match value { PrivateKeyDer::Pkcs8(der) => { - let mut ed25519_c_type: ed25519_key = unsafe { mem::zeroed() }; - let ed25519_key_object = ED25519KeyObject::new(&mut ed25519_c_type); - let mut priv_key_raw: [u8; 32] = [0; 32]; - let mut priv_key_raw_len: word32 = priv_key_raw.len() as word32; - let mut pub_key_raw: [u8; 32] = [0; 32]; - let pub_key_raw_len: word32 = pub_key_raw.len() as word32; let pkcs8: &[u8] = der.secret_pkcs8_der(); - let pkcs8_sz: word32 = pkcs8.len() as word32; - let mut ret; + let (priv_key_raw, pub_option) = match Ed25519PrivateKey::extract_key_pair(pkcs8) { + Ok((priv_value, pub_value)) => (priv_value, pub_value), - // This function initiliazes an ed25519_key object for - // using it to sign a message. - ed25519_key_object.init(); - - let mut idx: u32 = 0; - - // This function reads in an ED25519 private key from the input buffer, input, - // parses the private key, and uses it to generate an ed25519_key object, - // which it stores in key. - ret = unsafe { - wc_Ed25519PrivateKeyDecode( - pkcs8.as_ptr() as *mut u8, - &mut idx, - ed25519_key_object.as_ptr(), - pkcs8_sz, - ) - }; - check_if_zero(ret) - .map_err(|_| rustls::Error::General("FFI function failed".into()))?; - - ret = unsafe { - wc_ed25519_make_public( - ed25519_key_object.as_ptr(), - pub_key_raw.as_mut_ptr(), - pub_key_raw_len, - ) - }; - check_if_zero(ret) - .map_err(|_| rustls::Error::General("FFI function failed".into()))?; - - ret = unsafe { - wc_ed25519_export_private_only( - ed25519_key_object.as_ptr(), - priv_key_raw.as_mut_ptr(), - &mut priv_key_raw_len, - ) + Err(error) => return Err(error), }; - check_if_zero(ret) - .map_err(|_| rustls::Error::General("FFI function failed".into()))?; + + let mut ret; + let mut pub_key_raw: [u8; 32] = [0; 32]; + let pub_key_raw_len: word32 = pub_key_raw.len() as word32; + + // Generate pub key part if not given + if pub_option.is_none() { + let mut ed25519_c_type: ed25519_key = unsafe { mem::zeroed() }; + let ed25519_key_object = ED25519KeyObject::new(&mut ed25519_c_type); + // This function initiliazes an ed25519_key object for + // using it to sign a message. + ed25519_key_object.init(); + + ret = unsafe { + wc_ed25519_import_private_only( + priv_key_raw.as_ptr(), + priv_key_raw.len() as word32, + ed25519_key_object.as_ptr(), + ) + }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("FFI function failed".into()))?; + + ret = unsafe { + wc_ed25519_make_public( + ed25519_key_object.as_ptr(), + pub_key_raw.as_mut_ptr(), + pub_key_raw_len, + ) + }; + check_if_zero(ret) + .map_err(|_| rustls::Error::General("FFI function failed".into()))?; + } Ok(Self { priv_key: Arc::new(Zeroizing::new(priv_key_raw.to_vec())), - pub_key: Arc::new(pub_key_raw.to_vec()), + pub_key: Arc::new(match pub_option { + Some(pub_value) => pub_value.to_vec(), + None => pub_key_raw.to_vec(), + }), algo: SignatureAlgorithm::ED25519, }) } @@ -173,8 +246,7 @@ impl Signer for Ed25519Signer { }; if ret < 0 { return Err(rustls::Error::General(format!( - "wc_ed25519_sign_msg failed: {}", - ret + "wc_ed25519_sign_msg failed: {ret}", ))); } diff --git a/rustls-wolfcrypt-provider/src/types/mod.rs b/rustls-wolfcrypt-provider/src/types/mod.rs index fc86caa..9768742 100644 --- a/rustls-wolfcrypt-provider/src/types/mod.rs +++ b/rustls-wolfcrypt-provider/src/types/mod.rs @@ -206,6 +206,7 @@ define_foreign_type!( define_foreign_type_no_copy!(RsaKeyObject, RsaKeyObjectRef, RsaKey, drop(wc_FreeRsaKey)); define_foreign_type_with_copy!(HmacObject, HmacObjectRef, wolfcrypt_rs::Hmac); define_foreign_type_no_copy!(AesObject, AesObjectRef, Aes, drop_void(wc_AesFree)); +define_foreign_type_with_copy!(ChaChaObject, ChaChaObjectRef, ChaCha); define_foreign_type_no_copy!( Sha256Object, Sha256ObjectRef, diff --git a/rustls-wolfcrypt-provider/src/verify.rs b/rustls-wolfcrypt-provider/src/verify.rs index c03938b..6cec2e8 100644 --- a/rustls-wolfcrypt-provider/src/verify.rs +++ b/rustls-wolfcrypt-provider/src/verify.rs @@ -13,6 +13,7 @@ pub static ALGORITHMS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms { RSA_PSS_SHA384, RSA_PKCS1_SHA256, RSA_PKCS1_SHA384, + RSA_PKCS1_SHA512, ECDSA_P256_SHA256, ECDSA_P384_SHA384, ECDSA_P521_SHA512, diff --git a/rustls-wolfcrypt-provider/tests/e2e.rs b/rustls-wolfcrypt-provider/tests/e2e.rs index 430beee..ba58b3e 100644 --- a/rustls-wolfcrypt-provider/tests/e2e.rs +++ b/rustls-wolfcrypt-provider/tests/e2e.rs @@ -404,7 +404,7 @@ mod tests { #[test] fn ecdsa_sign_and_verify() { - let wolfcrypt_default_provider = rustls_wolfcrypt_provider::provider(); + let wolfcrypt_default_provider = rustls_wolfcrypt_provider::default_provider(); // Define schemes, curve IDs, and key sizes as tuples let test_configs = [ @@ -509,7 +509,7 @@ mod tests { #[test] fn eddsa_sign_and_verify() { - let wolfcrypt_default_provider = rustls_wolfcrypt_provider::provider(); + let wolfcrypt_default_provider = rustls_wolfcrypt_provider::default_provider(); // Initialize RNG and ECC key objects let mut rng: WC_RNG = unsafe { mem::zeroed() }; @@ -578,7 +578,7 @@ mod tests { fn rsa_pss_sign_and_verify() { init_thread_pool(); - let wolfcrypt_default_provider = rustls_wolfcrypt_provider::provider(); + let wolfcrypt_default_provider = rustls_wolfcrypt_provider::default_provider(); let schemes = [ SignatureScheme::RSA_PSS_SHA256, SignatureScheme::RSA_PSS_SHA384, @@ -669,7 +669,7 @@ mod tests { fn rsa_pkcs1_sign_and_verify() { init_thread_pool(); - let wolfcrypt_default_provider = rustls_wolfcrypt_provider::provider(); + let wolfcrypt_default_provider = rustls_wolfcrypt_provider::default_provider(); let test_cases: Vec<_> = [ SignatureScheme::RSA_PKCS1_SHA256, SignatureScheme::RSA_PKCS1_SHA384, diff --git a/wolfcrypt-rs/build.rs b/wolfcrypt-rs/build.rs index e51f964..0a49813 100644 --- a/wolfcrypt-rs/build.rs +++ b/wolfcrypt-rs/build.rs @@ -17,7 +17,7 @@ const WOLFSSL_SHA256: &str = "1aeb6e49222bb9d8cf012063f0dfc3f229084f24ce2b5740a2 /// Handles the main build process and exits with an error code if anything fails. fn main() { if let Err(e) = run_build() { - eprintln!("Build failed: {}", e); + eprintln!("Build failed: {e}"); std::process::exit(1); } } @@ -74,7 +74,7 @@ fn generate_bindings() -> Result<()> { let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); bindings .write_to_file(out_path.join("bindings.rs")) - .map_err(|e| io::Error::other(format!("Couldn't write bindings: {}", e))) + .map_err(|e| io::Error::other(format!("Couldn't write bindings: {e}"))) } /// Coordinates the complete setup process for WolfSSL. @@ -189,7 +189,7 @@ fn remove_zip() -> Result<()> { /// Returns `Ok(())` if all build steps succeed, or an error if any step fails. fn build_wolfssl() -> Result<()> { env::set_current_dir(WOLFSSL_DIR)?; - println!("Changed directory to {}.", WOLFSSL_DIR); + println!("Changed directory to {WOLFSSL_DIR}."); let prefix = install_prefix(); let prefix_arg = format!("--prefix={}", prefix.to_str().unwrap()); @@ -230,7 +230,7 @@ fn run_command(cmd: &str, args: &[&str]) -> Result<()> { String::from_utf8_lossy(&output.stderr) ))); } - println!("{} completed successfully.", cmd); + println!("{cmd} completed successfully."); Ok(()) } diff --git a/wolfcrypt-rs/src/bindings.rs b/wolfcrypt-rs/src/bindings.rs index df30230..98533f3 100644 --- a/wolfcrypt-rs/src/bindings.rs +++ b/wolfcrypt-rs/src/bindings.rs @@ -12,6 +12,7 @@ #![allow(non_upper_case_globals)] #![allow(non_camel_case_types)] #![allow(non_snake_case)] +#![allow(unnecessary_transmutes)] #![allow(clippy::useless_transmute)] #![allow(clippy::upper_case_acronyms)] #![allow(clippy::too_many_arguments)]