use alloc::boxed::Box; use core::cmp::min; use crate::crypto::cipher::{InboundOpaqueMessage, MessageDecrypter, MessageEncrypter}; use crate::error::Error; use crate::log::trace; use crate::msgs::message::{InboundPlainMessage, OutboundOpaqueMessage, OutboundPlainMessage}; #[derive(PartialEq)] enum DirectionState { /// No keying material. Invalid, /// Keying material present, but not yet in use. Prepared, /// Keying material in use. Active, } /// Record layer that tracks decryption and encryption keys. pub(crate) struct RecordLayer { message_encrypter: Box, message_decrypter: Box, write_seq_max: u64, write_seq: u64, read_seq: u64, has_decrypted: bool, encrypt_state: DirectionState, decrypt_state: DirectionState, // Message encrypted with other keys may be encountered, so failures // should be swallowed by the caller. This struct tracks the amount // of message size this is allowed for. trial_decryption_len: Option, } impl RecordLayer { /// Create new record layer with no keys. pub(crate) fn new() -> Self { Self { message_encrypter: ::invalid(), message_decrypter: ::invalid(), write_seq_max: 0, write_seq: 0, read_seq: 0, has_decrypted: false, encrypt_state: DirectionState::Invalid, decrypt_state: DirectionState::Invalid, trial_decryption_len: None, } } /// Decrypt a TLS message. /// /// `encr` is a decoded message allegedly received from the peer. /// If it can be decrypted, its decryption is returned. Otherwise, /// an error is returned. pub(crate) fn decrypt_incoming<'a>( &mut self, encr: InboundOpaqueMessage<'a>, ) -> Result>, Error> { if self.decrypt_state != DirectionState::Active { return Ok(Some(Decrypted { want_close_before_decrypt: false, plaintext: encr.into_plain_message(), })); } // Set to `true` if the peer appears to getting close to encrypting // too many messages with this key. // // Perhaps if we send an alert well before their counter wraps, a // buggy peer won't make a terrible mistake here? // // Note that there's no reason to refuse to decrypt: the security // failure has already happened. let want_close_before_decrypt = self.read_seq == SEQ_SOFT_LIMIT; let encrypted_len = encr.payload.len(); match self .message_decrypter .decrypt(encr, self.read_seq) { Ok(plaintext) => { self.read_seq += 1; if !self.has_decrypted { self.has_decrypted = true; } Ok(Some(Decrypted { want_close_before_decrypt, plaintext, })) } Err(Error::DecryptError) if self.doing_trial_decryption(encrypted_len) => { trace!("Dropping undecryptable message after aborted early_data"); Ok(None) } Err(err) => Err(err), } } /// Encrypt a TLS message. /// /// `plain` is a TLS message we'd like to send. This function /// panics if the requisite keying material hasn't been established yet. pub(crate) fn encrypt_outgoing( &mut self, plain: OutboundPlainMessage<'_>, ) -> OutboundOpaqueMessage { debug_assert!(self.encrypt_state == DirectionState::Active); assert!(self.next_pre_encrypt_action() != PreEncryptAction::Refuse); let seq = self.write_seq; self.write_seq += 1; self.message_encrypter .encrypt(plain, seq) .unwrap() } /// Prepare to use the given `MessageEncrypter` for future message encryption. /// It is not used until you call `start_encrypting`. pub(crate) fn prepare_message_encrypter( &mut self, cipher: Box, max_messages: u64, ) { self.message_encrypter = cipher; self.write_seq = 0; self.write_seq_max = min(SEQ_SOFT_LIMIT, max_messages); self.encrypt_state = DirectionState::Prepared; } /// Prepare to use the given `MessageDecrypter` for future message decryption. /// It is not used until you call `start_decrypting`. pub(crate) fn prepare_message_decrypter(&mut self, cipher: Box) { self.message_decrypter = cipher; self.read_seq = 0; self.decrypt_state = DirectionState::Prepared; } /// Start using the `MessageEncrypter` previously provided to the previous /// call to `prepare_message_encrypter`. pub(crate) fn start_encrypting(&mut self) { debug_assert!(self.encrypt_state == DirectionState::Prepared); self.encrypt_state = DirectionState::Active; } /// Start using the `MessageDecrypter` previously provided to the previous /// call to `prepare_message_decrypter`. pub(crate) fn start_decrypting(&mut self) { debug_assert!(self.decrypt_state == DirectionState::Prepared); self.decrypt_state = DirectionState::Active; } /// Set and start using the given `MessageEncrypter` for future outgoing /// message encryption. pub(crate) fn set_message_encrypter( &mut self, cipher: Box, max_messages: u64, ) { self.prepare_message_encrypter(cipher, max_messages); self.start_encrypting(); } /// Set and start using the given `MessageDecrypter` for future incoming /// message decryption. pub(crate) fn set_message_decrypter(&mut self, cipher: Box) { self.prepare_message_decrypter(cipher); self.start_decrypting(); self.trial_decryption_len = None; } /// Set and start using the given `MessageDecrypter` for future incoming /// message decryption, and enable "trial decryption" mode for when TLS1.3 /// 0-RTT is attempted but rejected by the server. pub(crate) fn set_message_decrypter_with_trial_decryption( &mut self, cipher: Box, max_length: usize, ) { self.prepare_message_decrypter(cipher); self.start_decrypting(); self.trial_decryption_len = Some(max_length); } pub(crate) fn finish_trial_decryption(&mut self) { self.trial_decryption_len = None; } pub(crate) fn next_pre_encrypt_action(&self) -> PreEncryptAction { self.pre_encrypt_action(0) } /// Return a remedial action when we are near to encrypting too many messages. /// /// `add` is added to the current sequence number. `add` as `0` means /// "the next message processed by `encrypt_outgoing`" pub(crate) fn pre_encrypt_action(&self, add: u64) -> PreEncryptAction { match self.write_seq.saturating_add(add) { v if v == self.write_seq_max => PreEncryptAction::RefreshOrClose, SEQ_HARD_LIMIT.. => PreEncryptAction::Refuse, _ => PreEncryptAction::Nothing, } } pub(crate) fn is_encrypting(&self) -> bool { self.encrypt_state == DirectionState::Active } /// Return true if we have ever decrypted a message. This is used in place /// of checking the read_seq since that will be reset on key updates. pub(crate) fn has_decrypted(&self) -> bool { self.has_decrypted } pub(crate) fn write_seq(&self) -> u64 { self.write_seq } pub(crate) fn read_seq(&self) -> u64 { self.read_seq } pub(crate) fn encrypted_len(&self, payload_len: usize) -> usize { self.message_encrypter .encrypted_payload_len(payload_len) } fn doing_trial_decryption(&mut self, requested: usize) -> bool { match self .trial_decryption_len .and_then(|value| value.checked_sub(requested)) { Some(remaining) => { self.trial_decryption_len = Some(remaining); true } _ => false, } } } /// Result of decryption. #[derive(Debug)] pub(crate) struct Decrypted<'a> { /// Whether the peer appears to be getting close to encrypting too many messages with this key. pub(crate) want_close_before_decrypt: bool, /// The decrypted message. pub(crate) plaintext: InboundPlainMessage<'a>, } #[derive(Debug, Eq, PartialEq)] pub(crate) enum PreEncryptAction { /// No action is needed before calling `encrypt_outgoing` Nothing, /// A `key_update` request should be sent ASAP. /// /// If that is not possible (for example, the connection is TLS1.2), a `close_notify` /// alert should be sent instead. RefreshOrClose, /// Do not call `encrypt_outgoing` further, it will panic rather than /// over-use the key. Refuse, } const SEQ_SOFT_LIMIT: u64 = 0xffff_ffff_ffff_0000u64; const SEQ_HARD_LIMIT: u64 = 0xffff_ffff_ffff_fffeu64; #[cfg(test)] mod tests { use super::*; #[test] fn test_has_decrypted() { use crate::{ContentType, ProtocolVersion}; struct PassThroughDecrypter; impl MessageDecrypter for PassThroughDecrypter { fn decrypt<'a>( &mut self, m: InboundOpaqueMessage<'a>, _: u64, ) -> Result, Error> { Ok(m.into_plain_message()) } } // A record layer starts out invalid, having never decrypted. let mut record_layer = RecordLayer::new(); assert!(matches!( record_layer.decrypt_state, DirectionState::Invalid )); assert_eq!(record_layer.read_seq, 0); assert!(!record_layer.has_decrypted()); // Preparing the record layer should update the decrypt state, but shouldn't affect whether it // has decrypted. record_layer.prepare_message_decrypter(Box::new(PassThroughDecrypter)); assert!(matches!( record_layer.decrypt_state, DirectionState::Prepared )); assert_eq!(record_layer.read_seq, 0); assert!(!record_layer.has_decrypted()); // Starting decryption should update the decrypt state, but not affect whether it has decrypted. record_layer.start_decrypting(); assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); assert_eq!(record_layer.read_seq, 0); assert!(!record_layer.has_decrypted()); // Decrypting a message should update the read_seq and track that we have now performed // a decryption. record_layer .decrypt_incoming(InboundOpaqueMessage::new( ContentType::Handshake, ProtocolVersion::TLSv1_2, &mut [0xC0, 0xFF, 0xEE], )) .unwrap(); assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); assert_eq!(record_layer.read_seq, 1); assert!(record_layer.has_decrypted()); // Resetting the record layer message decrypter (as if a key update occurred) should reset // the read_seq number, but not our knowledge of whether we have decrypted previously. record_layer.set_message_decrypter(Box::new(PassThroughDecrypter)); assert!(matches!(record_layer.decrypt_state, DirectionState::Active)); assert_eq!(record_layer.read_seq, 0); assert!(record_layer.has_decrypted()); } }