diff --git a/lightning/src/blinded_path/message.rs b/lightning/src/blinded_path/message.rs index 164cfcfb1ad..ba121f2a32d 100644 --- a/lightning/src/blinded_path/message.rs +++ b/lightning/src/blinded_path/message.rs @@ -92,6 +92,7 @@ impl BlindedMessagePath { recipient_node_id, context, &blinding_secret, + [41; 32], // TODO: Pass this in ) .map_err(|_| ())?, })) @@ -514,18 +515,19 @@ pub(crate) const MESSAGE_PADDING_ROUND_OFF: usize = 100; pub(super) fn blinded_hops( secp_ctx: &Secp256k1, intermediate_nodes: &[MessageForwardNode], recipient_node_id: PublicKey, context: MessageContext, session_priv: &SecretKey, + local_node_receive_key: [u8; 32], ) -> Result, secp256k1::Error> { let pks = intermediate_nodes .iter() - .map(|node| node.node_id) - .chain(core::iter::once(recipient_node_id)); + .map(|node| (node.node_id, None)) + .chain(core::iter::once((recipient_node_id, Some(local_node_receive_key)))); let is_compact = intermediate_nodes.iter().any(|node| node.short_channel_id.is_some()); let tlvs = pks .clone() .skip(1) // The first node's TLVs contains the next node's pubkey .zip(intermediate_nodes.iter().map(|node| node.short_channel_id)) - .map(|(pubkey, scid)| match scid { + .map(|((pubkey, _), scid)| match scid { Some(scid) => NextMessageHop::ShortChannelId(scid), None => NextMessageHop::NodeId(pubkey), }) diff --git a/lightning/src/blinded_path/payment.rs b/lightning/src/blinded_path/payment.rs index 95ad76c3644..96913ef3c62 100644 --- a/lightning/src/blinded_path/payment.rs +++ b/lightning/src/blinded_path/payment.rs @@ -664,8 +664,10 @@ pub(super) fn blinded_hops( secp_ctx: &Secp256k1, intermediate_nodes: &[PaymentForwardNode], payee_node_id: PublicKey, payee_tlvs: ReceiveTlvs, session_priv: &SecretKey, ) -> Result, secp256k1::Error> { - let pks = - intermediate_nodes.iter().map(|node| node.node_id).chain(core::iter::once(payee_node_id)); + let pks = intermediate_nodes + .iter() + .map(|node| (node.node_id, None)) + .chain(core::iter::once((payee_node_id, None))); let tlvs = intermediate_nodes .iter() .map(|node| BlindedPaymentTlvsRef::Forward(&node.tlvs)) diff --git a/lightning/src/blinded_path/utils.rs b/lightning/src/blinded_path/utils.rs index b17fa01bbcf..05caae688c5 100644 --- a/lightning/src/blinded_path/utils.rs +++ b/lightning/src/blinded_path/utils.rs @@ -17,7 +17,8 @@ use bitcoin::secp256k1::{self, PublicKey, Scalar, Secp256k1, SecretKey}; use super::message::BlindedMessagePath; use super::{BlindedHop, BlindedPath}; -use crate::crypto::streams::ChaChaPolyWriteAdapter; +use crate::crypto::chacha20poly1305rfc::ChaCha20Poly1305RFC; +use crate::crypto::streams::chachapoly_encrypt_with_swapped_aad; use crate::io; use crate::ln::onion_utils; use crate::onion_message::messenger::Destination; @@ -105,7 +106,6 @@ macro_rules! build_keys_helper { }; } -#[inline] pub(crate) fn construct_keys_for_onion_message<'a, T, I, F>( secp_ctx: &Secp256k1, unblinded_path: I, destination: Destination, session_priv: &SecretKey, mut callback: F, @@ -113,9 +113,13 @@ pub(crate) fn construct_keys_for_onion_message<'a, T, I, F>( where T: secp256k1::Signing + secp256k1::Verification, I: Iterator, - F: FnMut(PublicKey, SharedSecret, PublicKey, [u8; 32], Option, Option>), + F: FnMut(SharedSecret, PublicKey, [u8; 32], Option, Option>), { - build_keys_helper!(session_priv, secp_ctx, callback); + let mut callback_wrapper = + |_, ss, pk, encrypted_payload_rho, unblinded_hop_data, encrypted_payload| { + callback(ss, pk, encrypted_payload_rho, unblinded_hop_data, encrypted_payload); + }; + build_keys_helper!(session_priv, secp_ctx, callback_wrapper); for pk in unblinded_path { build_keys_in_loop!(pk, false, None); @@ -133,8 +137,7 @@ where Ok(()) } -#[inline] -pub(super) fn construct_keys_for_blinded_path<'a, T, I, F, H>( +fn construct_keys_for_blinded_path<'a, T, I, F, H>( secp_ctx: &Secp256k1, unblinded_path: I, session_priv: &SecretKey, mut callback: F, ) -> Result<(), secp256k1::Error> where @@ -145,7 +148,8 @@ where { build_keys_helper!(session_priv, secp_ctx, callback); - for pk in unblinded_path { + let mut iter = unblinded_path.peekable(); + while let Some(pk) = iter.next() { build_keys_in_loop!(pk, false, None); } Ok(()) @@ -153,6 +157,7 @@ where struct PublicKeyWithTlvs { pubkey: PublicKey, + hop_recv_key: Option<[u8; 32]>, tlvs: W, } @@ -167,20 +172,26 @@ pub(crate) fn construct_blinded_hops<'a, T, I, W>( ) -> Result, secp256k1::Error> where T: secp256k1::Signing + secp256k1::Verification, - I: Iterator, + I: Iterator), W)>, W: Writeable, { let mut blinded_hops = Vec::with_capacity(unblinded_path.size_hint().0); construct_keys_for_blinded_path( secp_ctx, - unblinded_path.map(|(pubkey, tlvs)| PublicKeyWithTlvs { pubkey, tlvs }), + unblinded_path.map(|((pubkey, hop_recv_key), tlvs)| PublicKeyWithTlvs { + pubkey, + hop_recv_key, + tlvs, + }), session_priv, |blinded_node_id, _, _, encrypted_payload_rho, unblinded_hop_data, _| { + let hop_data = unblinded_hop_data.unwrap(); blinded_hops.push(BlindedHop { blinded_node_id, encrypted_payload: encrypt_payload( - unblinded_hop_data.unwrap().tlvs, + hop_data.tlvs, encrypted_payload_rho, + hop_data.hop_recv_key, ), }); }, @@ -189,9 +200,19 @@ where } /// Encrypt TLV payload to be used as a [`crate::blinded_path::BlindedHop::encrypted_payload`]. -fn encrypt_payload(payload: P, encrypted_tlvs_rho: [u8; 32]) -> Vec { - let write_adapter = ChaChaPolyWriteAdapter::new(encrypted_tlvs_rho, &payload); - write_adapter.encode() +fn encrypt_payload( + payload: P, encrypted_tlvs_rho: [u8; 32], hop_recv_key: Option<[u8; 32]>, +) -> Vec { + let mut payload_data = payload.encode(); + if let Some(hop_recv_key) = hop_recv_key { + chachapoly_encrypt_with_swapped_aad(payload_data, encrypted_tlvs_rho, hop_recv_key) + } else { + let mut chacha = ChaCha20Poly1305RFC::new(&encrypted_tlvs_rho, &[0; 12], &[]); + let mut tag = [0; 16]; + chacha.encrypt_full_message_in_place(&mut payload_data, &mut tag); + payload_data.extend_from_slice(&tag); + payload_data + } } /// A data structure used exclusively to pad blinded path payloads, ensuring they are of diff --git a/lightning/src/crypto/chacha20poly1305rfc.rs b/lightning/src/crypto/chacha20poly1305rfc.rs index f1c261cb1f1..839fad9ce6c 100644 --- a/lightning/src/crypto/chacha20poly1305rfc.rs +++ b/lightning/src/crypto/chacha20poly1305rfc.rs @@ -10,249 +10,148 @@ // This is a port of Andrew Moons poly1305-donna // https://github.com/floodyberry/poly1305-donna -#[cfg(not(fuzzing))] -mod real_chachapoly { - use super::super::chacha20::ChaCha20; - use super::super::fixed_time_eq; - use super::super::poly1305::Poly1305; - - #[derive(Clone, Copy)] - pub struct ChaCha20Poly1305RFC { - cipher: ChaCha20, - mac: Poly1305, - finished: bool, - data_len: usize, - aad_len: u64, - } - - impl ChaCha20Poly1305RFC { - #[inline] - fn pad_mac_16(mac: &mut Poly1305, len: usize) { - if len % 16 != 0 { - mac.input(&[0; 16][0..16 - (len % 16)]); - } - } - pub fn new(key: &[u8], nonce: &[u8], aad: &[u8]) -> ChaCha20Poly1305RFC { - assert!(key.len() == 16 || key.len() == 32); - assert!(nonce.len() == 12); - - // Ehh, I'm too lazy to *also* tweak ChaCha20 to make it RFC-compliant - assert!(nonce[0] == 0 && nonce[1] == 0 && nonce[2] == 0 && nonce[3] == 0); - - let mut cipher = ChaCha20::new(key, &nonce[4..]); - let mut mac_key = [0u8; 64]; - let zero_key = [0u8; 64]; - cipher.process(&zero_key, &mut mac_key); - - let mut mac = Poly1305::new(&mac_key[..32]); - mac.input(aad); - ChaCha20Poly1305RFC::pad_mac_16(&mut mac, aad.len()); - - ChaCha20Poly1305RFC { - cipher, - mac, - finished: false, - data_len: 0, - aad_len: aad.len() as u64, - } - } - - pub fn encrypt(&mut self, input: &[u8], output: &mut [u8], out_tag: &mut [u8]) { - assert!(input.len() == output.len()); - assert!(!self.finished); - self.cipher.process(input, output); - self.data_len += input.len(); - self.mac.input(output); - ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); - self.finished = true; - self.mac.input(&self.aad_len.to_le_bytes()); - self.mac.input(&(self.data_len as u64).to_le_bytes()); - self.mac.raw_result(out_tag); - } - - pub fn encrypt_full_message_in_place( - &mut self, input_output: &mut [u8], out_tag: &mut [u8], - ) { - self.encrypt_in_place(input_output); - self.finish_and_get_tag(out_tag); - } - - // Encrypt `input_output` in-place. To finish and calculate the tag, use `finish_and_get_tag` - // below. - pub(in super::super) fn encrypt_in_place(&mut self, input_output: &mut [u8]) { - debug_assert!(!self.finished); - self.cipher.process_in_place(input_output); - self.data_len += input_output.len(); - self.mac.input(input_output); - } - - // If we were previously encrypting with `encrypt_in_place`, this method can be used to finish - // encrypting and calculate the tag. - pub(in super::super) fn finish_and_get_tag(&mut self, out_tag: &mut [u8]) { - debug_assert!(!self.finished); - ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); - self.finished = true; - self.mac.input(&self.aad_len.to_le_bytes()); - self.mac.input(&(self.data_len as u64).to_le_bytes()); - self.mac.raw_result(out_tag); - } - - /// Decrypt the `input`, checking the given `tag` prior to writing the decrypted contents - /// into `output`. Note that, because `output` is not touched until the `tag` is checked, - /// this decryption is *variable time*. - pub fn variable_time_decrypt( - &mut self, input: &[u8], output: &mut [u8], tag: &[u8], - ) -> Result<(), ()> { - assert!(input.len() == output.len()); - assert!(!self.finished); - - self.finished = true; - - self.mac.input(input); - - self.data_len += input.len(); - ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); - self.mac.input(&self.aad_len.to_le_bytes()); - self.mac.input(&(self.data_len as u64).to_le_bytes()); - - let mut calc_tag = [0u8; 16]; - self.mac.raw_result(&mut calc_tag); - if fixed_time_eq(&calc_tag, tag) { - self.cipher.process(input, output); - Ok(()) - } else { - Err(()) - } - } - - pub fn check_decrypt_in_place( - &mut self, input_output: &mut [u8], tag: &[u8], - ) -> Result<(), ()> { - self.decrypt_in_place(input_output); - if self.finish_and_check_tag(tag) { - Ok(()) - } else { - Err(()) - } - } - - /// Decrypt in place, without checking the tag. Use `finish_and_check_tag` to check it - /// later when decryption finishes. - /// - /// Should never be `pub` because the public API should always enforce tag checking. - pub(in super::super) fn decrypt_in_place(&mut self, input_output: &mut [u8]) { - debug_assert!(!self.finished); - self.mac.input(input_output); - self.data_len += input_output.len(); - self.cipher.process_in_place(input_output); - } - - /// If we were previously decrypting with `just_decrypt_in_place`, this method must be used - /// to check the tag. Returns whether or not the tag is valid. - pub(in super::super) fn finish_and_check_tag(&mut self, tag: &[u8]) -> bool { - debug_assert!(!self.finished); - self.finished = true; - ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); - self.mac.input(&self.aad_len.to_le_bytes()); - self.mac.input(&(self.data_len as u64).to_le_bytes()); +use super::chacha20::ChaCha20; +use super::fixed_time_eq; +use super::poly1305::Poly1305; + +pub struct ChaCha20Poly1305RFC { + cipher: ChaCha20, + mac: Poly1305, + finished: bool, + data_len: usize, + aad_len: u64, +} - let mut calc_tag = [0u8; 16]; - self.mac.raw_result(&mut calc_tag); - if fixed_time_eq(&calc_tag, tag) { - true - } else { - false - } +impl ChaCha20Poly1305RFC { + #[inline] + fn pad_mac_16(mac: &mut Poly1305, len: usize) { + if len % 16 != 0 { + mac.input(&[0; 16][0..16 - (len % 16)]); } } -} -#[cfg(not(fuzzing))] -pub use self::real_chachapoly::ChaCha20Poly1305RFC; - -#[cfg(fuzzing)] -mod fuzzy_chachapoly { - #[derive(Clone, Copy)] - pub struct ChaCha20Poly1305RFC { - tag: [u8; 16], - finished: bool, + pub fn new(key: &[u8], nonce: &[u8], aad: &[u8]) -> ChaCha20Poly1305RFC { + assert!(key.len() == 16 || key.len() == 32); + assert!(nonce.len() == 12); + + // Ehh, I'm too lazy to *also* tweak ChaCha20 to make it RFC-compliant + assert!(nonce[0] == 0 && nonce[1] == 0 && nonce[2] == 0 && nonce[3] == 0); + + let mut cipher = ChaCha20::new(key, &nonce[4..]); + let mut mac_key = [0u8; 64]; + let zero_key = [0u8; 64]; + cipher.process(&zero_key, &mut mac_key); + + #[cfg(not(fuzzing))] + let mut mac = Poly1305::new(&mac_key[..32]); + #[cfg(fuzzing)] + let mut mac = Poly1305::new(&key); + mac.input(aad); + ChaCha20Poly1305RFC::pad_mac_16(&mut mac, aad.len()); + + ChaCha20Poly1305RFC { cipher, mac, finished: false, data_len: 0, aad_len: aad.len() as u64 } } - impl ChaCha20Poly1305RFC { - pub fn new(key: &[u8], nonce: &[u8], _aad: &[u8]) -> ChaCha20Poly1305RFC { - assert!(key.len() == 16 || key.len() == 32); - assert!(nonce.len() == 12); - - // Ehh, I'm too lazy to *also* tweak ChaCha20 to make it RFC-compliant - assert!(nonce[0] == 0 && nonce[1] == 0 && nonce[2] == 0 && nonce[3] == 0); - let mut tag = [0; 16]; - tag.copy_from_slice(&key[0..16]); + pub fn encrypt(&mut self, input: &[u8], output: &mut [u8], out_tag: &mut [u8]) { + assert!(input.len() == output.len()); + assert!(!self.finished); + self.cipher.process(input, output); + self.data_len += input.len(); + self.mac.input(output); + ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); + self.finished = true; + self.mac.input(&self.aad_len.to_le_bytes()); + self.mac.input(&(self.data_len as u64).to_le_bytes()); + out_tag.copy_from_slice(&self.mac.result()); + } - ChaCha20Poly1305RFC { tag, finished: false } - } + pub fn encrypt_full_message_in_place(&mut self, input_output: &mut [u8], out_tag: &mut [u8]) { + self.encrypt_in_place(input_output); + self.finish_and_get_tag(out_tag); + } - pub fn encrypt(&mut self, input: &[u8], output: &mut [u8], out_tag: &mut [u8]) { - assert!(input.len() == output.len()); - assert!(self.finished == false); + // Encrypt `input_output` in-place. To finish and calculate the tag, use `finish_and_get_tag` + // below. + pub(in super::super) fn encrypt_in_place(&mut self, input_output: &mut [u8]) { + debug_assert!(!self.finished); + self.cipher.process_in_place(input_output); + self.data_len += input_output.len(); + self.mac.input(input_output); + } - output.copy_from_slice(&input); - out_tag.copy_from_slice(&self.tag); - self.finished = true; - } + // If we were previously encrypting with `encrypt_in_place`, this method can be used to finish + // encrypting and calculate the tag. + pub(in super::super) fn finish_and_get_tag(&mut self, out_tag: &mut [u8]) { + debug_assert!(!self.finished); + ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); + self.finished = true; + self.mac.input(&self.aad_len.to_le_bytes()); + self.mac.input(&(self.data_len as u64).to_le_bytes()); + out_tag.copy_from_slice(&self.mac.result()); + } - pub fn encrypt_full_message_in_place( - &mut self, input_output: &mut [u8], out_tag: &mut [u8], - ) { - self.encrypt_in_place(input_output); - self.finish_and_get_tag(out_tag); - } + /// Decrypt the `input`, checking the given `tag` prior to writing the decrypted contents + /// into `output`. Note that, because `output` is not touched until the `tag` is checked, + /// this decryption is *variable time*. + pub fn variable_time_decrypt( + &mut self, input: &[u8], output: &mut [u8], tag: &[u8], + ) -> Result<(), ()> { + assert!(input.len() == output.len()); + assert!(!self.finished); - pub(in super::super) fn encrypt_in_place(&mut self, _input_output: &mut [u8]) { - assert!(self.finished == false); - } + self.finished = true; - pub(in super::super) fn finish_and_get_tag(&mut self, out_tag: &mut [u8]) { - assert!(self.finished == false); - out_tag.copy_from_slice(&self.tag); - self.finished = true; - } + self.mac.input(input); - pub fn variable_time_decrypt( - &mut self, input: &[u8], output: &mut [u8], tag: &[u8], - ) -> Result<(), ()> { - assert!(input.len() == output.len()); - assert!(self.finished == false); + self.data_len += input.len(); + ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); + self.mac.input(&self.aad_len.to_le_bytes()); + self.mac.input(&(self.data_len as u64).to_le_bytes()); - if tag[..] != self.tag[..] { - return Err(()); - } - output.copy_from_slice(input); - self.finished = true; + let calc_tag = self.mac.result(); + if fixed_time_eq(&calc_tag, tag) { + self.cipher.process(input, output); Ok(()) + } else { + Err(()) } + } - pub fn check_decrypt_in_place( - &mut self, input_output: &mut [u8], tag: &[u8], - ) -> Result<(), ()> { - self.decrypt_in_place(input_output); - if self.finish_and_check_tag(tag) { - Ok(()) - } else { - Err(()) - } + pub fn check_decrypt_in_place( + &mut self, input_output: &mut [u8], tag: &[u8], + ) -> Result<(), ()> { + self.decrypt_in_place(input_output); + if self.finish_and_check_tag(tag) { + Ok(()) + } else { + Err(()) } + } - pub(in super::super) fn decrypt_in_place(&mut self, _input: &mut [u8]) { - assert!(self.finished == false); - } + /// Decrypt in place, without checking the tag. Use `finish_and_check_tag` to check it + /// later when decryption finishes. + /// + /// Should never be `pub` because the public API should always enforce tag checking. + pub(in super::super) fn decrypt_in_place(&mut self, input_output: &mut [u8]) { + debug_assert!(!self.finished); + self.mac.input(input_output); + self.data_len += input_output.len(); + self.cipher.process_in_place(input_output); + } - pub(in super::super) fn finish_and_check_tag(&mut self, tag: &[u8]) -> bool { - if tag[..] != self.tag[..] { - return false; - } - self.finished = true; + /// If we were previously decrypting with `just_decrypt_in_place`, this method must be used + /// to check the tag. Returns whether or not the tag is valid. + pub(in super::super) fn finish_and_check_tag(&mut self, tag: &[u8]) -> bool { + debug_assert!(!self.finished); + self.finished = true; + ChaCha20Poly1305RFC::pad_mac_16(&mut self.mac, self.data_len); + self.mac.input(&self.aad_len.to_le_bytes()); + self.mac.input(&(self.data_len as u64).to_le_bytes()); + + let calc_tag = self.mac.result(); + if fixed_time_eq(&calc_tag, tag) { true + } else { + false } } } -#[cfg(fuzzing)] -pub use self::fuzzy_chachapoly::ChaCha20Poly1305RFC; diff --git a/lightning/src/crypto/mod.rs b/lightning/src/crypto/mod.rs index 4dc851a3683..478918a49a8 100644 --- a/lightning/src/crypto/mod.rs +++ b/lightning/src/crypto/mod.rs @@ -1,9 +1,14 @@ #[cfg(not(fuzzing))] -use bitcoin::hashes::cmp::fixed_time_eq; +pub(crate) use bitcoin::hashes::cmp::fixed_time_eq; + +#[cfg(fuzzing)] +fn fixed_time_eq(a: &[u8], b: &[u8]) -> bool { + assert_eq!(a.len(), b.len()); + a == b +} pub(crate) mod chacha20; pub(crate) mod chacha20poly1305rfc; -#[cfg(not(fuzzing))] pub(crate) mod poly1305; pub(crate) mod streams; pub(crate) mod utils; diff --git a/lightning/src/crypto/poly1305.rs b/lightning/src/crypto/poly1305.rs index 6ac1c6c9694..c7306863e9e 100644 --- a/lightning/src/crypto/poly1305.rs +++ b/lightning/src/crypto/poly1305.rs @@ -7,385 +7,426 @@ // This is a port of Andrew Moons poly1305-donna // https://github.com/floodyberry/poly1305-donna -use core::cmp::min; - -use crate::prelude::*; - -#[derive(Clone, Copy)] -pub struct Poly1305 { - r: [u32; 5], - h: [u32; 5], - pad: [u32; 4], - leftover: usize, - buffer: [u8; 16], - finalized: bool, -} - -impl Poly1305 { - pub fn new(key: &[u8]) -> Poly1305 { - assert!(key.len() == 32); - let mut poly = Poly1305 { - r: [0u32; 5], - h: [0u32; 5], - pad: [0u32; 4], - leftover: 0, - buffer: [0u8; 16], - finalized: false, - }; - - // r &= 0xffffffc0ffffffc0ffffffc0fffffff - poly.r[0] = (u32::from_le_bytes(key[0..4].try_into().expect("len is 4"))) & 0x3ffffff; - poly.r[1] = (u32::from_le_bytes(key[3..7].try_into().expect("len is 4")) >> 2) & 0x3ffff03; - poly.r[2] = (u32::from_le_bytes(key[6..10].try_into().expect("len is 4")) >> 4) & 0x3ffc0ff; - poly.r[3] = (u32::from_le_bytes(key[9..13].try_into().expect("len is 4")) >> 6) & 0x3f03fff; - poly.r[4] = - (u32::from_le_bytes(key[12..16].try_into().expect("len is 4")) >> 8) & 0x00fffff; - - poly.pad[0] = u32::from_le_bytes(key[16..20].try_into().expect("len is 4")); - poly.pad[1] = u32::from_le_bytes(key[20..24].try_into().expect("len is 4")); - poly.pad[2] = u32::from_le_bytes(key[24..28].try_into().expect("len is 4")); - poly.pad[3] = u32::from_le_bytes(key[28..32].try_into().expect("len is 4")); - - poly +#[cfg(not(fuzzing))] +mod fuzzy_poly1305 { + use core::cmp::min; + + #[derive(Clone, Copy)] + pub struct Poly1305 { + r: [u32; 5], + h: [u32; 5], + pad: [u32; 4], + leftover: usize, + buffer: [u8; 16], + finalized: bool, } - fn block(&mut self, m: &[u8]) { - let hibit: u32 = if self.finalized { 0 } else { 1 << 24 }; - - let r0 = self.r[0]; - let r1 = self.r[1]; - let r2 = self.r[2]; - let r3 = self.r[3]; - let r4 = self.r[4]; - - let s1 = r1 * 5; - let s2 = r2 * 5; - let s3 = r3 * 5; - let s4 = r4 * 5; - - let mut h0 = self.h[0]; - let mut h1 = self.h[1]; - let mut h2 = self.h[2]; - let mut h3 = self.h[3]; - let mut h4 = self.h[4]; - - // h += m - h0 += (u32::from_le_bytes(m[0..4].try_into().expect("len is 4"))) & 0x3ffffff; - h1 += (u32::from_le_bytes(m[3..7].try_into().expect("len is 4")) >> 2) & 0x3ffffff; - h2 += (u32::from_le_bytes(m[6..10].try_into().expect("len is 4")) >> 4) & 0x3ffffff; - h3 += (u32::from_le_bytes(m[9..13].try_into().expect("len is 4")) >> 6) & 0x3ffffff; - h4 += (u32::from_le_bytes(m[12..16].try_into().expect("len is 4")) >> 8) | hibit; - - // h *= r - let d0 = (h0 as u64 * r0 as u64) - + (h1 as u64 * s4 as u64) - + (h2 as u64 * s3 as u64) - + (h3 as u64 * s2 as u64) - + (h4 as u64 * s1 as u64); - let mut d1 = (h0 as u64 * r1 as u64) - + (h1 as u64 * r0 as u64) - + (h2 as u64 * s4 as u64) - + (h3 as u64 * s3 as u64) - + (h4 as u64 * s2 as u64); - let mut d2 = (h0 as u64 * r2 as u64) - + (h1 as u64 * r1 as u64) - + (h2 as u64 * r0 as u64) - + (h3 as u64 * s4 as u64) - + (h4 as u64 * s3 as u64); - let mut d3 = (h0 as u64 * r3 as u64) - + (h1 as u64 * r2 as u64) - + (h2 as u64 * r1 as u64) - + (h3 as u64 * r0 as u64) - + (h4 as u64 * s4 as u64); - let mut d4 = (h0 as u64 * r4 as u64) - + (h1 as u64 * r3 as u64) - + (h2 as u64 * r2 as u64) - + (h3 as u64 * r1 as u64) - + (h4 as u64 * r0 as u64); - - // (partial) h %= p - let mut c: u32; - c = (d0 >> 26) as u32; - h0 = d0 as u32 & 0x3ffffff; - d1 += c as u64; - c = (d1 >> 26) as u32; - h1 = d1 as u32 & 0x3ffffff; - d2 += c as u64; - c = (d2 >> 26) as u32; - h2 = d2 as u32 & 0x3ffffff; - d3 += c as u64; - c = (d3 >> 26) as u32; - h3 = d3 as u32 & 0x3ffffff; - d4 += c as u64; - c = (d4 >> 26) as u32; - h4 = d4 as u32 & 0x3ffffff; - h0 += c * 5; - c = h0 >> 26; - h0 &= 0x3ffffff; - h1 += c; - - self.h[0] = h0; - self.h[1] = h1; - self.h[2] = h2; - self.h[3] = h3; - self.h[4] = h4; - } + impl Poly1305 { + pub fn new(key: &[u8]) -> Poly1305 { + assert!(key.len() == 32); + let mut poly = Poly1305 { + r: [0u32; 5], + h: [0u32; 5], + pad: [0u32; 4], + leftover: 0, + buffer: [0u8; 16], + finalized: false, + }; + + // r &= 0xffffffc0ffffffc0ffffffc0fffffff + poly.r[0] = (u32::from_le_bytes(key[0..4].try_into().expect("len is 4"))) & 0x3ffffff; + poly.r[1] = + (u32::from_le_bytes(key[3..7].try_into().expect("len is 4")) >> 2) & 0x3ffff03; + poly.r[2] = + (u32::from_le_bytes(key[6..10].try_into().expect("len is 4")) >> 4) & 0x3ffc0ff; + poly.r[3] = + (u32::from_le_bytes(key[9..13].try_into().expect("len is 4")) >> 6) & 0x3f03fff; + poly.r[4] = + (u32::from_le_bytes(key[12..16].try_into().expect("len is 4")) >> 8) & 0x00fffff; + + poly.pad[0] = u32::from_le_bytes(key[16..20].try_into().expect("len is 4")); + poly.pad[1] = u32::from_le_bytes(key[20..24].try_into().expect("len is 4")); + poly.pad[2] = u32::from_le_bytes(key[24..28].try_into().expect("len is 4")); + poly.pad[3] = u32::from_le_bytes(key[28..32].try_into().expect("len is 4")); + + poly + } + + fn block(&mut self, m: &[u8]) { + let hibit: u32 = if self.finalized { 0 } else { 1 << 24 }; + + let r0 = self.r[0]; + let r1 = self.r[1]; + let r2 = self.r[2]; + let r3 = self.r[3]; + let r4 = self.r[4]; + + let s1 = r1 * 5; + let s2 = r2 * 5; + let s3 = r3 * 5; + let s4 = r4 * 5; + + let mut h0 = self.h[0]; + let mut h1 = self.h[1]; + let mut h2 = self.h[2]; + let mut h3 = self.h[3]; + let mut h4 = self.h[4]; + + // h += m + h0 += (u32::from_le_bytes(m[0..4].try_into().expect("len is 4"))) & 0x3ffffff; + h1 += (u32::from_le_bytes(m[3..7].try_into().expect("len is 4")) >> 2) & 0x3ffffff; + h2 += (u32::from_le_bytes(m[6..10].try_into().expect("len is 4")) >> 4) & 0x3ffffff; + h3 += (u32::from_le_bytes(m[9..13].try_into().expect("len is 4")) >> 6) & 0x3ffffff; + h4 += (u32::from_le_bytes(m[12..16].try_into().expect("len is 4")) >> 8) | hibit; + + // h *= r + let d0 = (h0 as u64 * r0 as u64) + + (h1 as u64 * s4 as u64) + + (h2 as u64 * s3 as u64) + + (h3 as u64 * s2 as u64) + + (h4 as u64 * s1 as u64); + let mut d1 = (h0 as u64 * r1 as u64) + + (h1 as u64 * r0 as u64) + + (h2 as u64 * s4 as u64) + + (h3 as u64 * s3 as u64) + + (h4 as u64 * s2 as u64); + let mut d2 = (h0 as u64 * r2 as u64) + + (h1 as u64 * r1 as u64) + + (h2 as u64 * r0 as u64) + + (h3 as u64 * s4 as u64) + + (h4 as u64 * s3 as u64); + let mut d3 = (h0 as u64 * r3 as u64) + + (h1 as u64 * r2 as u64) + + (h2 as u64 * r1 as u64) + + (h3 as u64 * r0 as u64) + + (h4 as u64 * s4 as u64); + let mut d4 = (h0 as u64 * r4 as u64) + + (h1 as u64 * r3 as u64) + + (h2 as u64 * r2 as u64) + + (h3 as u64 * r1 as u64) + + (h4 as u64 * r0 as u64); + + // (partial) h %= p + let mut c: u32; + c = (d0 >> 26) as u32; + h0 = d0 as u32 & 0x3ffffff; + d1 += c as u64; + c = (d1 >> 26) as u32; + h1 = d1 as u32 & 0x3ffffff; + d2 += c as u64; + c = (d2 >> 26) as u32; + h2 = d2 as u32 & 0x3ffffff; + d3 += c as u64; + c = (d3 >> 26) as u32; + h3 = d3 as u32 & 0x3ffffff; + d4 += c as u64; + c = (d4 >> 26) as u32; + h4 = d4 as u32 & 0x3ffffff; + h0 += c * 5; + c = h0 >> 26; + h0 &= 0x3ffffff; + h1 += c; + + self.h[0] = h0; + self.h[1] = h1; + self.h[2] = h2; + self.h[3] = h3; + self.h[4] = h4; + } - pub fn finish(&mut self) { - if self.leftover > 0 { - self.buffer[self.leftover] = 1; - for i in self.leftover + 1..16 { - self.buffer[i] = 0; + pub fn finish(&mut self) { + if self.leftover > 0 { + self.buffer[self.leftover] = 1; + for i in self.leftover + 1..16 { + self.buffer[i] = 0; + } + self.finalized = true; + let tmp = self.buffer; + self.block(&tmp); } - self.finalized = true; - let tmp = self.buffer; - self.block(&tmp); + + // fully carry h + let mut h0 = self.h[0]; + let mut h1 = self.h[1]; + let mut h2 = self.h[2]; + let mut h3 = self.h[3]; + let mut h4 = self.h[4]; + + let mut c: u32; + c = h1 >> 26; + h1 &= 0x3ffffff; + h2 += c; + c = h2 >> 26; + h2 &= 0x3ffffff; + h3 += c; + c = h3 >> 26; + h3 &= 0x3ffffff; + h4 += c; + c = h4 >> 26; + h4 &= 0x3ffffff; + h0 += c * 5; + c = h0 >> 26; + h0 &= 0x3ffffff; + h1 += c; + + // compute h + -p + let mut g0 = h0.wrapping_add(5); + c = g0 >> 26; + g0 &= 0x3ffffff; + let mut g1 = h1.wrapping_add(c); + c = g1 >> 26; + g1 &= 0x3ffffff; + let mut g2 = h2.wrapping_add(c); + c = g2 >> 26; + g2 &= 0x3ffffff; + let mut g3 = h3.wrapping_add(c); + c = g3 >> 26; + g3 &= 0x3ffffff; + let mut g4 = h4.wrapping_add(c).wrapping_sub(1 << 26); + + // select h if h < p, or h + -p if h >= p + let mut mask = (g4 >> (32 - 1)).wrapping_sub(1); + g0 &= mask; + g1 &= mask; + g2 &= mask; + g3 &= mask; + g4 &= mask; + mask = !mask; + h0 = (h0 & mask) | g0; + h1 = (h1 & mask) | g1; + h2 = (h2 & mask) | g2; + h3 = (h3 & mask) | g3; + h4 = (h4 & mask) | g4; + + // h = h % (2^128) + h0 = ((h0) | (h1 << 26)) & 0xffffffff; + h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff; + h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff; + h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff; + + // h = mac = (h + pad) % (2^128) + let mut f: u64; + f = h0 as u64 + self.pad[0] as u64; + h0 = f as u32; + f = h1 as u64 + self.pad[1] as u64 + (f >> 32); + h1 = f as u32; + f = h2 as u64 + self.pad[2] as u64 + (f >> 32); + h2 = f as u32; + f = h3 as u64 + self.pad[3] as u64 + (f >> 32); + h3 = f as u32; + + self.h[0] = h0; + self.h[1] = h1; + self.h[2] = h2; + self.h[3] = h3; } - // fully carry h - let mut h0 = self.h[0]; - let mut h1 = self.h[1]; - let mut h2 = self.h[2]; - let mut h3 = self.h[3]; - let mut h4 = self.h[4]; - - let mut c: u32; - c = h1 >> 26; - h1 &= 0x3ffffff; - h2 += c; - c = h2 >> 26; - h2 &= 0x3ffffff; - h3 += c; - c = h3 >> 26; - h3 &= 0x3ffffff; - h4 += c; - c = h4 >> 26; - h4 &= 0x3ffffff; - h0 += c * 5; - c = h0 >> 26; - h0 &= 0x3ffffff; - h1 += c; - - // compute h + -p - let mut g0 = h0.wrapping_add(5); - c = g0 >> 26; - g0 &= 0x3ffffff; - let mut g1 = h1.wrapping_add(c); - c = g1 >> 26; - g1 &= 0x3ffffff; - let mut g2 = h2.wrapping_add(c); - c = g2 >> 26; - g2 &= 0x3ffffff; - let mut g3 = h3.wrapping_add(c); - c = g3 >> 26; - g3 &= 0x3ffffff; - let mut g4 = h4.wrapping_add(c).wrapping_sub(1 << 26); - - // select h if h < p, or h + -p if h >= p - let mut mask = (g4 >> (32 - 1)).wrapping_sub(1); - g0 &= mask; - g1 &= mask; - g2 &= mask; - g3 &= mask; - g4 &= mask; - mask = !mask; - h0 = (h0 & mask) | g0; - h1 = (h1 & mask) | g1; - h2 = (h2 & mask) | g2; - h3 = (h3 & mask) | g3; - h4 = (h4 & mask) | g4; - - // h = h % (2^128) - h0 = ((h0) | (h1 << 26)) & 0xffffffff; - h1 = ((h1 >> 6) | (h2 << 20)) & 0xffffffff; - h2 = ((h2 >> 12) | (h3 << 14)) & 0xffffffff; - h3 = ((h3 >> 18) | (h4 << 8)) & 0xffffffff; - - // h = mac = (h + pad) % (2^128) - let mut f: u64; - f = h0 as u64 + self.pad[0] as u64; - h0 = f as u32; - f = h1 as u64 + self.pad[1] as u64 + (f >> 32); - h1 = f as u32; - f = h2 as u64 + self.pad[2] as u64 + (f >> 32); - h2 = f as u32; - f = h3 as u64 + self.pad[3] as u64 + (f >> 32); - h3 = f as u32; - - self.h[0] = h0; - self.h[1] = h1; - self.h[2] = h2; - self.h[3] = h3; - } + pub fn input(&mut self, data: &[u8]) { + assert!(!self.finalized); + let mut m = data; - pub fn input(&mut self, data: &[u8]) { - assert!(!self.finalized); - let mut m = data; + if self.leftover > 0 { + let want = min(16 - self.leftover, m.len()); + for i in 0..want { + self.buffer[self.leftover + i] = m[i]; + } + m = &m[want..]; + self.leftover += want; - if self.leftover > 0 { - let want = min(16 - self.leftover, m.len()); - for i in 0..want { - self.buffer[self.leftover + i] = m[i]; - } - m = &m[want..]; - self.leftover += want; + if self.leftover < 16 { + return; + } - if self.leftover < 16 { - return; + // self.block(self.buffer[..]); + let tmp = self.buffer; + self.block(&tmp); + + self.leftover = 0; } - // self.block(self.buffer[..]); - let tmp = self.buffer; - self.block(&tmp); + while m.len() >= 16 { + self.block(&m[0..16]); + m = &m[16..]; + } - self.leftover = 0; + for i in 0..m.len() { + self.buffer[i] = m[i]; + } + self.leftover = m.len(); } - while m.len() >= 16 { - self.block(&m[0..16]); - m = &m[16..]; + pub fn result(&mut self) -> [u8; 16] { + if !self.finalized { + self.finish(); + } + let mut output = [0; 16]; + output[0..4].copy_from_slice(&self.h[0].to_le_bytes()); + output[4..8].copy_from_slice(&self.h[1].to_le_bytes()); + output[8..12].copy_from_slice(&self.h[2].to_le_bytes()); + output[12..16].copy_from_slice(&self.h[3].to_le_bytes()); + output } + } - for i in 0..m.len() { - self.buffer[i] = m[i]; + #[cfg(test)] + mod test { + use core::iter::repeat; + + use super::Poly1305; + + fn poly1305(key: &[u8], msg: &[u8], mac: &mut [u8; 16]) { + let mut poly = Poly1305::new(key); + poly.input(msg); + *mac = poly.result(); } - self.leftover = m.len(); - } - pub fn raw_result(&mut self, output: &mut [u8]) { - assert!(output.len() >= 16); - if !self.finalized { - self.finish(); + #[test] + fn test_nacl_vector() { + let key = [ + 0xee, 0xa6, 0xa7, 0x25, 0x1c, 0x1e, 0x72, 0x91, 0x6d, 0x11, 0xc2, 0xcb, 0x21, 0x4d, + 0x3c, 0x25, 0x25, 0x39, 0x12, 0x1d, 0x8e, 0x23, 0x4e, 0x65, 0x2d, 0x65, 0x1f, 0xa4, + 0xc8, 0xcf, 0xf8, 0x80, + ]; + + let msg = [ + 0x8e, 0x99, 0x3b, 0x9f, 0x48, 0x68, 0x12, 0x73, 0xc2, 0x96, 0x50, 0xba, 0x32, 0xfc, + 0x76, 0xce, 0x48, 0x33, 0x2e, 0xa7, 0x16, 0x4d, 0x96, 0xa4, 0x47, 0x6f, 0xb8, 0xc5, + 0x31, 0xa1, 0x18, 0x6a, 0xc0, 0xdf, 0xc1, 0x7c, 0x98, 0xdc, 0xe8, 0x7b, 0x4d, 0xa7, + 0xf0, 0x11, 0xec, 0x48, 0xc9, 0x72, 0x71, 0xd2, 0xc2, 0x0f, 0x9b, 0x92, 0x8f, 0xe2, + 0x27, 0x0d, 0x6f, 0xb8, 0x63, 0xd5, 0x17, 0x38, 0xb4, 0x8e, 0xee, 0xe3, 0x14, 0xa7, + 0xcc, 0x8a, 0xb9, 0x32, 0x16, 0x45, 0x48, 0xe5, 0x26, 0xae, 0x90, 0x22, 0x43, 0x68, + 0x51, 0x7a, 0xcf, 0xea, 0xbd, 0x6b, 0xb3, 0x73, 0x2b, 0xc0, 0xe9, 0xda, 0x99, 0x83, + 0x2b, 0x61, 0xca, 0x01, 0xb6, 0xde, 0x56, 0x24, 0x4a, 0x9e, 0x88, 0xd5, 0xf9, 0xb3, + 0x79, 0x73, 0xf6, 0x22, 0xa4, 0x3d, 0x14, 0xa6, 0x59, 0x9b, 0x1f, 0x65, 0x4c, 0xb4, + 0x5a, 0x74, 0xe3, 0x55, 0xa5, + ]; + + let expected = [ + 0xf3, 0xff, 0xc7, 0x70, 0x3f, 0x94, 0x00, 0xe5, 0x2a, 0x7d, 0xfb, 0x4b, 0x3d, 0x33, + 0x05, 0xd9, + ]; + + let mut mac = [0u8; 16]; + poly1305(&key, &msg, &mut mac); + assert_eq!(&mac[..], &expected[..]); + + let mut poly = Poly1305::new(&key); + poly.input(&msg[0..32]); + poly.input(&msg[32..96]); + poly.input(&msg[96..112]); + poly.input(&msg[112..120]); + poly.input(&msg[120..124]); + poly.input(&msg[124..126]); + poly.input(&msg[126..127]); + poly.input(&msg[127..128]); + poly.input(&msg[128..129]); + poly.input(&msg[129..130]); + poly.input(&msg[130..131]); + let mac = poly.result(); + assert_eq!(&mac[..], &expected[..]); } - output[0..4].copy_from_slice(&self.h[0].to_le_bytes()); - output[4..8].copy_from_slice(&self.h[1].to_le_bytes()); - output[8..12].copy_from_slice(&self.h[2].to_le_bytes()); - output[12..16].copy_from_slice(&self.h[3].to_le_bytes()); - } -} -#[cfg(test)] -mod test { - use core::iter::repeat; + #[test] + fn donna_self_test() { + let wrap_key = [ + 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, + ]; - use super::Poly1305; + let wrap_msg = [ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, + ]; - fn poly1305(key: &[u8], msg: &[u8], mac: &mut [u8]) { - let mut poly = Poly1305::new(key); - poly.input(msg); - poly.raw_result(mac); - } + let wrap_mac = [ + 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ]; - #[test] - fn test_nacl_vector() { - let key = [ - 0xee, 0xa6, 0xa7, 0x25, 0x1c, 0x1e, 0x72, 0x91, 0x6d, 0x11, 0xc2, 0xcb, 0x21, 0x4d, - 0x3c, 0x25, 0x25, 0x39, 0x12, 0x1d, 0x8e, 0x23, 0x4e, 0x65, 0x2d, 0x65, 0x1f, 0xa4, - 0xc8, 0xcf, 0xf8, 0x80, - ]; - - let msg = [ - 0x8e, 0x99, 0x3b, 0x9f, 0x48, 0x68, 0x12, 0x73, 0xc2, 0x96, 0x50, 0xba, 0x32, 0xfc, - 0x76, 0xce, 0x48, 0x33, 0x2e, 0xa7, 0x16, 0x4d, 0x96, 0xa4, 0x47, 0x6f, 0xb8, 0xc5, - 0x31, 0xa1, 0x18, 0x6a, 0xc0, 0xdf, 0xc1, 0x7c, 0x98, 0xdc, 0xe8, 0x7b, 0x4d, 0xa7, - 0xf0, 0x11, 0xec, 0x48, 0xc9, 0x72, 0x71, 0xd2, 0xc2, 0x0f, 0x9b, 0x92, 0x8f, 0xe2, - 0x27, 0x0d, 0x6f, 0xb8, 0x63, 0xd5, 0x17, 0x38, 0xb4, 0x8e, 0xee, 0xe3, 0x14, 0xa7, - 0xcc, 0x8a, 0xb9, 0x32, 0x16, 0x45, 0x48, 0xe5, 0x26, 0xae, 0x90, 0x22, 0x43, 0x68, - 0x51, 0x7a, 0xcf, 0xea, 0xbd, 0x6b, 0xb3, 0x73, 0x2b, 0xc0, 0xe9, 0xda, 0x99, 0x83, - 0x2b, 0x61, 0xca, 0x01, 0xb6, 0xde, 0x56, 0x24, 0x4a, 0x9e, 0x88, 0xd5, 0xf9, 0xb3, - 0x79, 0x73, 0xf6, 0x22, 0xa4, 0x3d, 0x14, 0xa6, 0x59, 0x9b, 0x1f, 0x65, 0x4c, 0xb4, - 0x5a, 0x74, 0xe3, 0x55, 0xa5, - ]; - - let expected = [ - 0xf3, 0xff, 0xc7, 0x70, 0x3f, 0x94, 0x00, 0xe5, 0x2a, 0x7d, 0xfb, 0x4b, 0x3d, 0x33, - 0x05, 0xd9, - ]; - - let mut mac = [0u8; 16]; - poly1305(&key, &msg, &mut mac); - assert_eq!(&mac[..], &expected[..]); - - let mut poly = Poly1305::new(&key); - poly.input(&msg[0..32]); - poly.input(&msg[32..96]); - poly.input(&msg[96..112]); - poly.input(&msg[112..120]); - poly.input(&msg[120..124]); - poly.input(&msg[124..126]); - poly.input(&msg[126..127]); - poly.input(&msg[127..128]); - poly.input(&msg[128..129]); - poly.input(&msg[129..130]); - poly.input(&msg[130..131]); - poly.raw_result(&mut mac); - assert_eq!(&mac[..], &expected[..]); - } + let mut mac = [0u8; 16]; + poly1305(&wrap_key, &wrap_msg, &mut mac); + assert_eq!(&mac[..], &wrap_mac[..]); + + let total_key = [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, + ]; + + let total_mac = [ + 0x64, 0xaf, 0xe2, 0xe8, 0xd6, 0xad, 0x7b, 0xbd, 0xd2, 0x87, 0xf9, 0x7c, 0x44, 0x62, + 0x3d, 0x39, + ]; + + let mut tpoly = Poly1305::new(&total_key); + for i in 0..256 { + let key: Vec = repeat(i as u8).take(32).collect(); + let msg: Vec = repeat(i as u8).take(256).collect(); + let mut mac = [0u8; 16]; + poly1305(&key[..], &msg[0..i], &mut mac); + tpoly.input(&mac); + } + let mac = tpoly.result(); + assert_eq!(&mac[..], &total_mac[..]); + } - #[test] - fn donna_self_test() { - let wrap_key = [ - 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - ]; - - let wrap_msg = [ - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0xff, 0xff, - ]; - - let wrap_mac = [ - 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, - ]; - - let mut mac = [0u8; 16]; - poly1305(&wrap_key, &wrap_msg, &mut mac); - assert_eq!(&mac[..], &wrap_mac[..]); - - let total_key = [ - 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0xff, 0xfe, 0xfd, 0xfc, 0xfb, 0xfa, 0xf9, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, - 0x00, 0x00, 0x00, 0x00, - ]; - - let total_mac = [ - 0x64, 0xaf, 0xe2, 0xe8, 0xd6, 0xad, 0x7b, 0xbd, 0xd2, 0x87, 0xf9, 0x7c, 0x44, 0x62, - 0x3d, 0x39, - ]; - - let mut tpoly = Poly1305::new(&total_key); - for i in 0..256 { - let key: Vec = repeat(i as u8).take(32).collect(); - let msg: Vec = repeat(i as u8).take(256).collect(); + #[test] + fn test_tls_vectors() { + // from http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04 + let key = b"this is 32-byte key for Poly1305"; + let msg = [0u8; 32]; + let expected = [ + 0x49, 0xec, 0x78, 0x09, 0x0e, 0x48, 0x1e, 0xc6, 0xc2, 0x6b, 0x33, 0xb9, 0x1c, 0xcc, + 0x03, 0x07, + ]; let mut mac = [0u8; 16]; - poly1305(&key[..], &msg[0..i], &mut mac); - tpoly.input(&mac); + poly1305(key, &msg, &mut mac); + assert_eq!(&mac[..], &expected[..]); + + let msg = b"Hello world!"; + let expected = [ + 0xa6, 0xf7, 0x45, 0x00, 0x8f, 0x81, 0xc9, 0x16, 0xa2, 0x0d, 0xcc, 0x74, 0xee, 0xf2, + 0xb2, 0xf0, + ]; + poly1305(key, msg, &mut mac); + assert_eq!(&mac[..], &expected[..]); } - tpoly.raw_result(&mut mac); - assert_eq!(&mac[..], &total_mac[..]); } +} +pub use fuzzy_poly1305::*; + +#[cfg(fuzzing)] +mod fuzzy_poly1305 { + #[derive(Clone, Copy)] + pub struct Poly1305 { + tag: [u8; 16], + finalized: bool, + } + + impl Poly1305 { + pub fn new(key: &[u8]) -> Poly1305 { + assert_eq!(key.len(), 32); + let mut poly = Poly1305 { tag: [0; 16], finalized: false }; + poly.tag.copy_from_slice(&key[..16]); - #[test] - fn test_tls_vectors() { - // from http://tools.ietf.org/html/draft-agl-tls-chacha20poly1305-04 - let key = b"this is 32-byte key for Poly1305"; - let msg = [0u8; 32]; - let expected = [ - 0x49, 0xec, 0x78, 0x09, 0x0e, 0x48, 0x1e, 0xc6, 0xc2, 0x6b, 0x33, 0xb9, 0x1c, 0xcc, - 0x03, 0x07, - ]; - let mut mac = [0u8; 16]; - poly1305(key, &msg, &mut mac); - assert_eq!(&mac[..], &expected[..]); - - let msg = b"Hello world!"; - let expected = [ - 0xa6, 0xf7, 0x45, 0x00, 0x8f, 0x81, 0xc9, 0x16, 0xa2, 0x0d, 0xcc, 0x74, 0xee, 0xf2, - 0xb2, 0xf0, - ]; - poly1305(key, msg, &mut mac); - assert_eq!(&mac[..], &expected[..]); + poly + } + + pub fn finish(&mut self) { + self.finalized = true; + } + + pub fn input(&mut self, _data: &[u8]) { + assert!(!self.finalized); + } + + pub fn result(&mut self) -> [u8; 16] { + if !self.finalized { + self.finish(); + } + self.tag + } } } +pub use fuzzy_poly1305::*; diff --git a/lightning/src/crypto/streams.rs b/lightning/src/crypto/streams.rs index 82680131f1d..7f30245325b 100644 --- a/lightning/src/crypto/streams.rs +++ b/lightning/src/crypto/streams.rs @@ -1,5 +1,7 @@ use crate::crypto::chacha20::ChaCha20; use crate::crypto::chacha20poly1305rfc::ChaCha20Poly1305RFC; +use crate::crypto::fixed_time_eq; +use crate::crypto::poly1305::Poly1305; use crate::io::{self, Read, Write}; use crate::ln::msgs::DecodeError; @@ -7,6 +9,8 @@ use crate::util::ser::{ FixedLengthReader, LengthLimitedRead, LengthReadableArgs, Readable, Writeable, Writer, }; +use alloc::vec::Vec; + pub(crate) struct ChaChaReader<'a, R: io::Read> { pub chacha: &'a mut ChaCha20, pub read: R, @@ -49,6 +53,132 @@ impl<'a, T: Writeable> Writeable for ChaChaPolyWriteAdapter<'a, T> { } } +/// Encrypts the provided plaintext with the given key using ChaCha20Poly1305 in the modified +/// with-AAD form used in [`ChaChaDualPolyReadAdapter`]. +pub(crate) fn chachapoly_encrypt_with_swapped_aad( + mut plaintext: Vec, key: [u8; 32], aad: [u8; 32], +) -> Vec { + let mut chacha = ChaCha20::new(&key[..], &[0; 12]); + let mut mac_key = [0u8; 64]; + chacha.process_in_place(&mut mac_key); + + let mut mac = Poly1305::new(&mac_key[..32]); + chacha.process_in_place(&mut plaintext[..]); + mac.input(&plaintext[..]); + + if plaintext.len() % 16 != 0 { + mac.input(&[0; 16][0..16 - (plaintext.len() % 16)]); + } + + mac.input(&aad[..]); + // Note that we don't need to pad the AAD since its a multiple of 16 bytes + + mac.input(&(plaintext.len() as u64).to_le_bytes()); + mac.input(&32u64.to_le_bytes()); + + plaintext.extend_from_slice(&mac.result()); + plaintext +} + +/// Enables the use of the serialization macros for objects that need to be simultaneously decrypted +/// and deserialized. This allows us to avoid an intermediate Vec allocation. +/// +/// This variant of [`ChaChaPolyReadAdapter`] calculates Poly1305 tags twice, once using the given +/// key and once with the given 32-byte AAD appended after the encrypted stream, accepting either +/// being correct as sufficient. +/// +/// Note that we do *not* use the provided AAD as the standard ChaCha20Poly1305 AAD as that would +/// require placing it first and prevent us from avoiding redundant Poly1305 rounds. Instead, the +/// ChaCha20Poly1305 MAC check is tweaked to move the AAD to *after* the the contents being +/// checked, effectively treating the contents as the AAD for the AAD-containing MAC but behaving +/// like classic ChaCha20Poly1305 for the non-AAD-containing MAC. +pub(crate) struct ChaChaDualPolyReadAdapter { + pub readable: R, + pub used_aad: bool, +} + +impl LengthReadableArgs<([u8; 32], [u8; 32])> for ChaChaDualPolyReadAdapter { + // Simultaneously read and decrypt an object from a LengthLimitedRead storing it in + // Self::readable. LengthLimitedRead must be used instead of std::io::Read because we need the + // total length to separate out the tag at the end. + fn read( + r: &mut R, params: ([u8; 32], [u8; 32]), + ) -> Result { + if r.remaining_bytes() < 16 { + return Err(DecodeError::InvalidValue); + } + let (key, aad) = params; + + let mut chacha = ChaCha20::new(&key[..], &[0; 12]); + let mut mac_key = [0u8; 64]; + chacha.process_in_place(&mut mac_key); + + #[cfg(not(fuzzing))] + let mut mac = Poly1305::new(&mac_key[..32]); + #[cfg(fuzzing)] + let mut mac = Poly1305::new(&key); + + let decrypted_len = r.remaining_bytes() - 16; + let s = FixedLengthReader::new(r, decrypted_len); + let mut chacha_stream = + ChaChaDualPolyReader { chacha: &mut chacha, poly: &mut mac, read_len: 0, read: s }; + + let readable: T = Readable::read(&mut chacha_stream)?; + chacha_stream.read.eat_remaining()?; + + let read_len = chacha_stream.read_len; + + if read_len % 16 != 0 { + mac.input(&[0; 16][0..16 - (read_len % 16)]); + } + + let mut mac_aad = mac; + + mac_aad.input(&aad[..]); + // Note that we don't need to pad the AAD since its a multiple of 16 bytes + + // For the AAD-containing MAC, swap the AAD and the read data, effectively. + mac_aad.input(&(read_len as u64).to_le_bytes()); + mac_aad.input(&32u64.to_le_bytes()); + + // For the non-AAD-containing MAC, leave the data and AAD where they belong. + mac.input(&0u64.to_le_bytes()); + mac.input(&(read_len as u64).to_le_bytes()); + + let mut tag = [0 as u8; 16]; + r.read_exact(&mut tag)?; + if fixed_time_eq(&mac.result(), &tag) { + Ok(Self { readable, used_aad: false }) + } else if fixed_time_eq(&mac_aad.result(), &tag) { + Ok(Self { readable, used_aad: true }) + } else { + return Err(DecodeError::InvalidValue); + } + } +} + +struct ChaChaDualPolyReader<'a, R: Read> { + chacha: &'a mut ChaCha20, + poly: &'a mut Poly1305, + read_len: usize, + pub read: R, +} + +impl<'a, R: Read> Read for ChaChaDualPolyReader<'a, R> { + // Decrypt bytes from Self::read into `dest`. + // `ChaCha20Poly1305RFC::finish_and_check_tag` must be called to check the tag after all reads + // complete. + fn read(&mut self, dest: &mut [u8]) -> Result { + let res = self.read.read(dest)?; + if res > 0 { + self.poly.input(&dest[0..res]); + self.chacha.process_in_place(&mut dest[0..res]); + self.read_len += res; + } + Ok(res) + } +} + /// Enables the use of the serialization macros for objects that need to be simultaneously decrypted and /// deserialized. This allows us to avoid an intermediate Vec allocation. pub(crate) struct ChaChaPolyReadAdapter { diff --git a/lightning/src/ln/blinded_payment_tests.rs b/lightning/src/ln/blinded_payment_tests.rs index d65fa53a544..f36172ff777 100644 --- a/lightning/src/ln/blinded_payment_tests.rs +++ b/lightning/src/ln/blinded_payment_tests.rs @@ -1556,17 +1556,23 @@ fn route_blinding_spec_test_vector() { let blinding_override = PublicKey::from_secret_key(&secp_ctx, &dave_eve_session_priv); assert_eq!(blinding_override, pubkey_from_hex("031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f")); // Can't use the public API here as the encrypted payloads contain unknown TLVs. - let path = [(dave_node_id, WithoutLength(&dave_unblinded_tlvs)), (eve_node_id, WithoutLength(&eve_unblinded_tlvs))]; + let path = [ + ((dave_node_id, None), WithoutLength(&dave_unblinded_tlvs)), + ((eve_node_id, None), WithoutLength(&eve_unblinded_tlvs)), + ]; let mut dave_eve_blinded_hops = blinded_path::utils::construct_blinded_hops( - &secp_ctx, path.into_iter(), &dave_eve_session_priv + &secp_ctx, path.into_iter(), &dave_eve_session_priv, ).unwrap(); // Concatenate an additional Bob -> Carol blinded path to the Eve -> Dave blinded path. let bob_carol_session_priv = secret_from_hex("0202020202020202020202020202020202020202020202020202020202020202"); let bob_blinding_point = PublicKey::from_secret_key(&secp_ctx, &bob_carol_session_priv); - let path = [(bob_node_id, WithoutLength(&bob_unblinded_tlvs)), (carol_node_id, WithoutLength(&carol_unblinded_tlvs))]; + let path = [ + ((bob_node_id, None), WithoutLength(&bob_unblinded_tlvs)), + ((carol_node_id, None), WithoutLength(&carol_unblinded_tlvs)), + ]; let bob_carol_blinded_hops = blinded_path::utils::construct_blinded_hops( - &secp_ctx, path.into_iter(), &bob_carol_session_priv + &secp_ctx, path.into_iter(), &bob_carol_session_priv, ).unwrap(); let mut blinded_hops = bob_carol_blinded_hops; @@ -2026,9 +2032,9 @@ fn do_test_trampoline_single_hop_receive(success: bool) { let payee_tlvs = payee_tlvs.authenticate(nonce, &expanded_key); let carol_unblinded_tlvs = payee_tlvs.encode(); - let path = [(carol_node_id, WithoutLength(&carol_unblinded_tlvs))]; + let path = [((carol_node_id, None), WithoutLength(&carol_unblinded_tlvs))]; blinded_path::utils::construct_blinded_hops( - &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv + &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv, ).unwrap() } else { let payee_tlvs = blinded_path::payment::TrampolineForwardTlvs { @@ -2047,9 +2053,9 @@ fn do_test_trampoline_single_hop_receive(success: bool) { }; let carol_unblinded_tlvs = payee_tlvs.encode(); - let path = [(carol_node_id, WithoutLength(&carol_unblinded_tlvs))]; + let path = [((carol_node_id, None), WithoutLength(&carol_unblinded_tlvs))]; blinded_path::utils::construct_blinded_hops( - &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv + &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv, ).unwrap() }; @@ -2249,11 +2255,11 @@ fn test_trampoline_unblinded_receive() { }; let carol_unblinded_tlvs = payee_tlvs.encode(); - let path = [(carol_node_id, WithoutLength(&carol_unblinded_tlvs))]; + let path = [((carol_node_id, None), WithoutLength(&carol_unblinded_tlvs))]; let carol_alice_trampoline_session_priv = secret_from_hex("a0f4b8d7b6c2d0ffdfaf718f76e9decaef4d9fb38a8c4addb95c4007cc3eee03"); let carol_blinding_point = PublicKey::from_secret_key(&secp_ctx, &carol_alice_trampoline_session_priv); let carol_blinded_hops = blinded_path::utils::construct_blinded_hops( - &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv + &secp_ctx, path.into_iter(), &carol_alice_trampoline_session_priv, ).unwrap(); let route = Route { diff --git a/lightning/src/onion_message/messenger.rs b/lightning/src/onion_message/messenger.rs index fd98f78350e..f8d68a38e56 100644 --- a/lightning/src/onion_message/messenger.rs +++ b/lightning/src/onion_message/messenger.rs @@ -1068,11 +1068,12 @@ where }, } }; + let receiving_context_auth_key = [41; 32]; // TODO: pass this in let next_hop = onion_utils::decode_next_untagged_hop( onion_decode_ss, &msg.onion_routing_packet.hop_data[..], msg.onion_routing_packet.hmac, - (control_tlvs_ss, custom_handler.deref(), logger.deref()), + (control_tlvs_ss, custom_handler.deref(), receiving_context_auth_key, logger.deref()), ); match next_hop { Ok(( @@ -1080,6 +1081,7 @@ where message, control_tlvs: ReceiveControlTlvs::Unblinded(ReceiveTlvs { context }), reply_path, + control_tlvs_authenticated, }, None, )) => match (message, context) { @@ -1108,6 +1110,8 @@ where Ok(PeeledOnion::DNSResolver(msg, None, reply_path)) }, _ => { + // Hide the "`control_tlvs_authenticated` is unused warning". We'll use it here soon + let _ = control_tlvs_authenticated; log_trace!( logger, "Received message was sent on a blinded path with wrong or missing context." @@ -1116,10 +1120,11 @@ where }, }, Ok(( - Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs { - next_hop, - next_blinding_override, - })), + Payload::Forward { + control_tlvs: + ForwardControlTlvs::Unblinded(ForwardTlvs { next_hop, next_blinding_override }), + control_tlvs_authenticated, + }, Some((next_hop_hmac, new_packet_bytes)), )) => { // TODO: we need to check whether `next_hop` is our node, in which case this is a dummy @@ -1127,6 +1132,17 @@ where // unwrapping the onion layers to get to the final payload. Since we don't have the option // of creating blinded paths with dummy hops currently, we should be ok to not handle this // for now. + if control_tlvs_authenticated { + // TODO: When we start adding dummy hops, we should require the use of + // authenticated control TLVs, as it prevents a DoS attack where someone builds a + // blinded path to us which requires we decode hundreds of dummy hops only to find + // that we don't actually have a message inside to read. + // However, we should never accept a `control_tlvs_authenticated` packet which is + // *not* a dummy blinded hop we added, though it shouldn't be possible to reach in + // any case. + log_trace!(logger, "Received an authenticated to-forward onion message"); + return Err(()); + } let packet_pubkey = msg.onion_routing_packet.public_key; let new_pubkey_opt = onion_utils::next_hop_pubkey(&secp_ctx, packet_pubkey, &onion_decode_ss); @@ -2243,19 +2259,17 @@ fn packet_payloads_and_keys< unblinded_path.into_iter(), destination, session_priv, - |_, - onion_packet_ss, - ephemeral_pubkey, - control_tlvs_ss, - unblinded_pk_opt, - enc_payload_opt| { + |onion_packet_ss, ephemeral_pubkey, control_tlvs_ss, unblinded_pk_opt, enc_payload_opt| { if num_unblinded_hops != 0 && unblinded_path_idx < num_unblinded_hops { if let Some(ss) = prev_control_tlvs_ss.take() { payloads.push(( - Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs { - next_hop: NextMessageHop::NodeId(unblinded_pk_opt.unwrap()), - next_blinding_override: None, - })), + Payload::Forward { + control_tlvs: ForwardControlTlvs::Unblinded(ForwardTlvs { + next_hop: NextMessageHop::NodeId(unblinded_pk_opt.unwrap()), + next_blinding_override: None, + }), + control_tlvs_authenticated: false, + }, ss, )); } @@ -2264,17 +2278,23 @@ fn packet_payloads_and_keys< } else if let Some((intro_node_id, blinding_pt)) = intro_node_id_blinding_pt.take() { if let Some(control_tlvs_ss) = prev_control_tlvs_ss.take() { payloads.push(( - Payload::Forward(ForwardControlTlvs::Unblinded(ForwardTlvs { - next_hop: NextMessageHop::NodeId(intro_node_id), - next_blinding_override: Some(blinding_pt), - })), + Payload::Forward { + control_tlvs: ForwardControlTlvs::Unblinded(ForwardTlvs { + next_hop: NextMessageHop::NodeId(intro_node_id), + next_blinding_override: Some(blinding_pt), + }), + control_tlvs_authenticated: false, + }, control_tlvs_ss, )); } } if blinded_path_idx < num_blinded_hops.saturating_sub(1) && enc_payload_opt.is_some() { payloads.push(( - Payload::Forward(ForwardControlTlvs::Blinded(enc_payload_opt.unwrap())), + Payload::Forward { + control_tlvs: ForwardControlTlvs::Blinded(enc_payload_opt.unwrap()), + control_tlvs_authenticated: false, + }, control_tlvs_ss, )); blinded_path_idx += 1; @@ -2299,7 +2319,12 @@ fn packet_payloads_and_keys< if let Some(control_tlvs) = final_control_tlvs { payloads.push(( - Payload::Receive { control_tlvs, reply_path: reply_path.take(), message }, + Payload::Receive { + control_tlvs, + reply_path: reply_path.take(), + message, + control_tlvs_authenticated: false, + }, prev_control_tlvs_ss.unwrap(), )); } else { @@ -2308,6 +2333,7 @@ fn packet_payloads_and_keys< control_tlvs: ReceiveControlTlvs::Unblinded(ReceiveTlvs { context: None }), reply_path: reply_path.take(), message, + control_tlvs_authenticated: false, }, prev_control_tlvs_ss.unwrap(), )); diff --git a/lightning/src/onion_message/packet.rs b/lightning/src/onion_message/packet.rs index 632cbc9c8a3..7d80be6dae8 100644 --- a/lightning/src/onion_message/packet.rs +++ b/lightning/src/onion_message/packet.rs @@ -18,7 +18,7 @@ use super::dns_resolution::DNSResolverMessage; use super::messenger::CustomOnionMessageHandler; use super::offers::OffersMessage; use crate::blinded_path::message::{BlindedMessagePath, ForwardTlvs, NextMessageHop, ReceiveTlvs}; -use crate::crypto::streams::{ChaChaPolyReadAdapter, ChaChaPolyWriteAdapter}; +use crate::crypto::streams::{ChaChaDualPolyReadAdapter, ChaChaPolyWriteAdapter}; use crate::ln::msgs::DecodeError; use crate::ln::onion_utils; use crate::util::logger::Logger; @@ -110,9 +110,25 @@ impl LengthReadable for Packet { /// message content itself, such as an invoice request. pub(super) enum Payload { /// This payload is for an intermediate hop. - Forward(ForwardControlTlvs), + Forward { + /// The [`ReceiveControlTlvs`] were authenticated with the additional key which was + /// provided to [`ReadableArgs::read`]. + /// + /// This should not happen for blinded paths built by any node but us (i.e. messages + /// forwarded through us to other nodes), but can be used to authenticate extra hops added + /// to a blinded path as padding. + control_tlvs_authenticated: bool, + control_tlvs: ForwardControlTlvs, + }, /// This payload is for the final hop. - Receive { control_tlvs: ReceiveControlTlvs, reply_path: Option, message: T }, + Receive { + /// The [`ReceiveControlTlvs`] were authenticated with the additional key which was + /// provided to [`ReadableArgs::read`]. + control_tlvs_authenticated: bool, + control_tlvs: ReceiveControlTlvs, + reply_path: Option, + message: T, + }, } /// The contents of an [`OnionMessage`] as read from the wire. @@ -216,13 +232,17 @@ pub(super) enum ReceiveControlTlvs { impl Writeable for (Payload, [u8; 32]) { fn write(&self, w: &mut W) -> Result<(), io::Error> { match &self.0 { - Payload::Forward(ForwardControlTlvs::Blinded(encrypted_bytes)) => { + Payload::Forward { + control_tlvs: ForwardControlTlvs::Blinded(encrypted_bytes), + control_tlvs_authenticated: _, + } => { _encode_varint_length_prefixed_tlv!(w, { (4, encrypted_bytes, required_vec) }) }, Payload::Receive { control_tlvs: ReceiveControlTlvs::Blinded(encrypted_bytes), reply_path, message, + control_tlvs_authenticated: _, } => { _encode_varint_length_prefixed_tlv!(w, { (2, reply_path, option), @@ -230,7 +250,10 @@ impl Writeable for (Payload, [u8; 32]) { (message.tlv_type(), message, required) }) }, - Payload::Forward(ForwardControlTlvs::Unblinded(control_tlvs)) => { + Payload::Forward { + control_tlvs: ForwardControlTlvs::Unblinded(control_tlvs), + control_tlvs_authenticated: _, + } => { let write_adapter = ChaChaPolyWriteAdapter::new(self.1, &control_tlvs); _encode_varint_length_prefixed_tlv!(w, { (4, write_adapter, required) }) }, @@ -238,6 +261,7 @@ impl Writeable for (Payload, [u8; 32]) { control_tlvs: ReceiveControlTlvs::Unblinded(control_tlvs), reply_path, message, + control_tlvs_authenticated: _, } => { let write_adapter = ChaChaPolyWriteAdapter::new(self.1, &control_tlvs); _encode_varint_length_prefixed_tlv!(w, { @@ -252,22 +276,25 @@ impl Writeable for (Payload, [u8; 32]) { } // Uses the provided secret to simultaneously decode and decrypt the control TLVs and data TLV. -impl ReadableArgs<(SharedSecret, &H, &L)> +impl + ReadableArgs<(SharedSecret, &H, [u8; 32], &L)> for Payload::CustomMessage>> { - fn read(r: &mut R, args: (SharedSecret, &H, &L)) -> Result { - let (encrypted_tlvs_ss, handler, logger) = args; + fn read( + r: &mut R, args: (SharedSecret, &H, [u8; 32], &L), + ) -> Result { + let (encrypted_tlvs_ss, handler, receive_tlvs_key, logger) = args; let v: BigSize = Readable::read(r)?; let mut rd = FixedLengthReader::new(r, v.0); let mut reply_path: Option = None; - let mut read_adapter: Option> = None; + let mut read_adapter: Option> = None; let rho = onion_utils::gen_rho_from_shared_secret(&encrypted_tlvs_ss.secret_bytes()); let mut message_type: Option = None; let mut message = None; decode_tlv_stream_with_custom_tlv_decode!(&mut rd, { (2, reply_path, option), - (4, read_adapter, (option: LengthReadableArgs, rho)), + (4, read_adapter, (option: LengthReadableArgs, (rho, receive_tlvs_key))), }, |msg_type, msg_reader| { if msg_type < 64 { return Ok(false) } // Don't allow reading more than one data TLV from an onion message. @@ -304,17 +331,21 @@ impl ReadableArgs<(Sh match read_adapter { None => return Err(DecodeError::InvalidValue), - Some(ChaChaPolyReadAdapter { readable: ControlTlvs::Forward(tlvs) }) => { - if message_type.is_some() { + Some(ChaChaDualPolyReadAdapter { readable: ControlTlvs::Forward(tlvs), used_aad }) => { + if used_aad || message_type.is_some() { return Err(DecodeError::InvalidValue); } - Ok(Payload::Forward(ForwardControlTlvs::Unblinded(tlvs))) + Ok(Payload::Forward { + control_tlvs: ForwardControlTlvs::Unblinded(tlvs), + control_tlvs_authenticated: used_aad, + }) }, - Some(ChaChaPolyReadAdapter { readable: ControlTlvs::Receive(tlvs) }) => { + Some(ChaChaDualPolyReadAdapter { readable: ControlTlvs::Receive(tlvs), used_aad }) => { Ok(Payload::Receive { control_tlvs: ReceiveControlTlvs::Unblinded(tlvs), reply_path, message: message.ok_or(DecodeError::InvalidValue)?, + control_tlvs_authenticated: used_aad, }) }, }