Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 84 additions & 16 deletions crates/attested-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -663,20 +663,13 @@ impl ServerCertVerifier for AttestedCertificateVerifier {
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
if let Some(server_inner) = &self.server_inner {
match server_inner.verify_server_cert(
server_inner.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
) {
Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)) => {
// handle self-signed certs differently
Self::verify_server_cert_constraints(end_entity, server_name, now)?;
}
Err(err) => return Err(err),
Ok(_) => {}
}
)?;
} else {
Self::verify_server_cert_constraints(end_entity, server_name, now)?;
}
Expand Down Expand Up @@ -741,13 +734,7 @@ impl ClientCertVerifier for AttestedCertificateVerifier {
now: UnixTime,
) -> Result<ClientCertVerified, rustls::Error> {
if let Some(client_inner) = &self.client_inner {
match client_inner.verify_client_cert(end_entity, intermediates, now) {
Err(rustls::Error::InvalidCertificate(rustls::CertificateError::UnknownIssuer)) => {
Self::verify_cert_time_validity(end_entity, now)?;
}
Err(err) => return Err(err),
Ok(_) => {}
}
client_inner.verify_client_cert(end_entity, intermediates, now)?;
} else {
Self::verify_cert_time_validity(end_entity, now)?;
}
Expand Down Expand Up @@ -853,6 +840,20 @@ mod tests {
)
}

/// Test helper to verify a client certificate
fn verify_client_cert_direct(
verifier: &AttestedCertificateVerifier,
end_entity: &CertificateDer<'_>,
now: UnixTime,
) -> Result<ClientCertVerified, rustls::Error> {
rustls::server::danger::ClientCertVerifier::verify_client_cert(
verifier,
end_entity,
&[],
now,
)
}

#[tokio::test(flavor = "multi_thread")]
async fn certificate_resolver_creates_initial_certificate() {
let provider: Arc<CryptoProvider> = aws_lc_rs::default_provider().into();
Expand Down Expand Up @@ -1190,6 +1191,73 @@ mod tests {
assert_eq!(result.unwrap_err(), Error::InvalidCertificate(CertificateError::BadEncoding));
}

#[tokio::test(flavor = "multi_thread")]
async fn private_ca_verifier_rejects_untrusted_self_signed_attested_server_cert() {
let provider: Arc<CryptoProvider> = aws_lc_rs::default_provider().into();
let ca = test_ca();
let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap();
let resolver = AttestedCertificateResolver::new_with_provider(
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
None,
"foo".to_string(),
vec![],
provider.clone(),
Duration::from_secs(4),
)
.await
.unwrap();
let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone();

let mut roots = RootCertStore::empty();
roots.add(ca_cert).unwrap();
let verifier = AttestedCertificateVerifier::new_with_provider(
Some(roots),
AttestationVerifier::mock(),
provider,
)
.unwrap();

let result = verify_server_cert_direct(
&verifier,
&cert,
&ServerName::try_from("foo").unwrap(),
UnixTime::now(),
);

assert_eq!(result.unwrap_err(), Error::InvalidCertificate(CertificateError::UnknownIssuer));
}

#[tokio::test(flavor = "multi_thread")]
async fn private_ca_verifier_rejects_untrusted_self_signed_attested_client_cert() {
let provider: Arc<CryptoProvider> = aws_lc_rs::default_provider().into();
let ca = test_ca();
let ca_cert = CertificateDer::from_pem_slice(ca.pem_cert.as_bytes()).unwrap();
let resolver = AttestedCertificateResolver::new_with_provider(
AttestationGenerator::new(AttestationType::DcapTdx, None).unwrap(),
None,
"client".to_string(),
vec![],
provider.clone(),
Duration::from_secs(4),
)
.await
.unwrap();
let cert = resolver.state.certificate.read().unwrap().first().unwrap().clone();

let mut roots = RootCertStore::empty();
roots.add(ca_cert).unwrap();
let verifier = AttestedCertificateVerifier::new_with_provider(
Some(roots),
AttestationVerifier::mock(),
provider,
)
.unwrap();

let result = verify_client_cert_direct(&verifier, &cert, UnixTime::now());

assert_eq!(result.unwrap_err(), Error::InvalidCertificate(CertificateError::UnknownIssuer));
}

#[tokio::test(flavor = "multi_thread")]
async fn self_signed_attested_certificate_with_wrong_name_is_rejected() {
let provider: Arc<CryptoProvider> = aws_lc_rs::default_provider().into();
Expand Down
Loading