diff --git a/lightning/src/ln/channel.rs b/lightning/src/ln/channel.rs index 930511c0dd4..24b1eafafcd 100644 --- a/lightning/src/ln/channel.rs +++ b/lightning/src/ln/channel.rs @@ -5525,7 +5525,7 @@ mod tests { use bitcoin::hashes::hex::FromHex; use hex; use ln::{PaymentPreimage, PaymentHash}; - use ln::channelmanager::HTLCSource; + use ln::channelmanager::{HTLCSource, MppId}; use ln::channel::{Channel,InboundHTLCOutput,OutboundHTLCOutput,InboundHTLCState,OutboundHTLCState,HTLCOutputInCommitment,HTLCCandidate,HTLCInitiator,TxCreationKeys}; use ln::channel::MAX_FUNDING_SATOSHIS; use ln::features::InitFeatures; @@ -5699,6 +5699,7 @@ mod tests { path: Vec::new(), session_priv: SecretKey::from_slice(&hex::decode("0fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff").unwrap()[..]).unwrap(), first_hop_htlc_msat: 548, + mpp_id: MppId([42; 32]), } }); diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index b772c5e9ac7..f985976a73d 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -173,6 +173,22 @@ struct ClaimableHTLC { onion_payload: OnionPayload, } +/// A payment identifier used to correlate an MPP payment's per-path HTLC sources internally. +#[derive(Hash, Copy, Clone, PartialEq, Eq, Debug)] +pub(crate) struct MppId(pub [u8; 32]); + +impl Writeable for MppId { + fn write(&self, w: &mut W) -> Result<(), io::Error> { + self.0.write(w) + } +} + +impl Readable for MppId { + fn read(r: &mut R) -> Result { + let buf: [u8; 32] = Readable::read(r)?; + Ok(MppId(buf)) + } +} /// Tracks the inbound corresponding to an outbound HTLC #[derive(Clone, PartialEq)] pub(crate) enum HTLCSource { @@ -183,6 +199,7 @@ pub(crate) enum HTLCSource { /// Technically we can recalculate this from the route, but we cache it here to avoid /// doing a double-pass on route when we get a failure back first_hop_htlc_msat: u64, + mpp_id: MppId, }, } #[cfg(test)] @@ -192,6 +209,7 @@ impl HTLCSource { path: Vec::new(), session_priv: SecretKey::from_slice(&[1; 32]).unwrap(), first_hop_htlc_msat: 0, + mpp_id: MppId([2; 32]), } } } @@ -473,8 +491,11 @@ pub struct ChannelManager>, + pending_outbound_payments: Mutex>>, our_network_key: SecretKey, our_network_pubkey: PublicKey, @@ -1138,7 +1159,7 @@ impl ChannelMana pending_msg_events: Vec::new(), }), pending_inbound_payments: Mutex::new(HashMap::new()), - pending_outbound_payments: Mutex::new(HashSet::new()), + pending_outbound_payments: Mutex::new(HashMap::new()), our_network_key: keys_manager.get_node_secret(), our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &keys_manager.get_node_secret()), @@ -1820,7 +1841,7 @@ impl ChannelMana } // Only public for testing, this should otherwise never be called direcly - pub(crate) fn send_payment_along_path(&self, path: &Vec, payment_hash: &PaymentHash, payment_secret: &Option, total_value: u64, cur_height: u32, keysend_preimage: &Option) -> Result<(), APIError> { + pub(crate) fn send_payment_along_path(&self, path: &Vec, payment_hash: &PaymentHash, payment_secret: &Option, total_value: u64, cur_height: u32, mpp_id: MppId, keysend_preimage: &Option) -> Result<(), APIError> { log_trace!(self.logger, "Attempting to send payment for path with next hop {}", path.first().unwrap().short_channel_id); let prng_seed = self.keys_manager.get_secure_random_bytes(); let session_priv_bytes = self.keys_manager.get_secure_random_bytes(); @@ -1835,7 +1856,9 @@ impl ChannelMana let onion_packet = onion_utils::construct_onion_packet(onion_payloads, onion_keys, prng_seed, payment_hash); let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(&self.total_consistency_lock, &self.persistence_notifier); - assert!(self.pending_outbound_payments.lock().unwrap().insert(session_priv_bytes)); + let mut pending_outbounds = self.pending_outbound_payments.lock().unwrap(); + let sessions = pending_outbounds.entry(mpp_id).or_insert(HashSet::new()); + assert!(sessions.insert(session_priv_bytes)); let err: Result<(), _> = loop { let mut channel_lock = self.channel_state.lock().unwrap(); @@ -1857,6 +1880,7 @@ impl ChannelMana path: path.clone(), session_priv: session_priv.clone(), first_hop_htlc_msat: htlc_msat, + mpp_id, }, onion_packet, &self.logger), channel_state, chan) } { Some((update_add, commitment_signed, monitor_update)) => { @@ -1956,6 +1980,7 @@ impl ChannelMana let mut total_value = 0; let our_node_id = self.get_our_node_id(); let mut path_errs = Vec::with_capacity(route.paths.len()); + let mpp_id = MppId(self.keys_manager.get_secure_random_bytes()); 'path_check: for path in route.paths.iter() { if path.len() < 1 || path.len() > 20 { path_errs.push(Err(APIError::RouteError{err: "Path didn't go anywhere/had bogus size"})); @@ -1977,7 +2002,7 @@ impl ChannelMana let cur_height = self.best_block.read().unwrap().height() + 1; let mut results = Vec::new(); for path in route.paths.iter() { - results.push(self.send_payment_along_path(&path, &payment_hash, payment_secret, total_value, cur_height, &keysend_preimage)); + results.push(self.send_payment_along_path(&path, &payment_hash, payment_secret, total_value, cur_height, mpp_id, &keysend_preimage)); } let mut has_ok = false; let mut has_err = false; @@ -2812,23 +2837,28 @@ impl ChannelMana self.fail_htlc_backwards_internal(channel_state, htlc_src, &payment_hash, HTLCFailReason::Reason { failure_code, data: onion_failure_data}); }, - HTLCSource::OutboundRoute { session_priv, .. } => { - if { - let mut session_priv_bytes = [0; 32]; - session_priv_bytes.copy_from_slice(&session_priv[..]); - self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes) - } { - self.pending_events.lock().unwrap().push( - events::Event::PaymentFailed { - payment_hash, - rejected_by_dest: false, - network_update: None, -#[cfg(test)] - error_code: None, -#[cfg(test)] - error_data: None, + HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => { + let mut session_priv_bytes = [0; 32]; + session_priv_bytes.copy_from_slice(&session_priv[..]); + let mut outbounds = self.pending_outbound_payments.lock().unwrap(); + if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) { + if sessions.get_mut().remove(&session_priv_bytes) { + self.pending_events.lock().unwrap().push( + events::Event::PaymentFailed { + payment_hash, + rejected_by_dest: false, + network_update: None, + all_paths_failed: sessions.get().len() == 0, + #[cfg(test)] + error_code: None, + #[cfg(test)] + error_data: None, + } + ); + if sessions.get().len() == 0 { + sessions.remove(); } - ) + } } else { log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0)); } @@ -2853,12 +2883,21 @@ impl ChannelMana // from block_connected which may run during initialization prior to the chain_monitor // being fully configured. See the docs for `ChannelManagerReadArgs` for more. match source { - HTLCSource::OutboundRoute { ref path, session_priv, .. } => { - if { - let mut session_priv_bytes = [0; 32]; - session_priv_bytes.copy_from_slice(&session_priv[..]); - !self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes) - } { + HTLCSource::OutboundRoute { ref path, session_priv, mpp_id, .. } => { + let mut session_priv_bytes = [0; 32]; + session_priv_bytes.copy_from_slice(&session_priv[..]); + let mut outbounds = self.pending_outbound_payments.lock().unwrap(); + let mut all_paths_failed = false; + if let hash_map::Entry::Occupied(mut sessions) = outbounds.entry(mpp_id) { + if !sessions.get_mut().remove(&session_priv_bytes) { + log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0)); + return; + } + if sessions.get().len() == 0 { + all_paths_failed = true; + sessions.remove(); + } + } else { log_trace!(self.logger, "Received duplicative fail for HTLC with payment_hash {}", log_bytes!(payment_hash.0)); return; } @@ -2878,6 +2917,7 @@ impl ChannelMana payment_hash: payment_hash.clone(), rejected_by_dest: !payment_retryable, network_update, + all_paths_failed, #[cfg(test)] error_code: onion_error_code, #[cfg(test)] @@ -2903,6 +2943,7 @@ impl ChannelMana payment_hash: payment_hash.clone(), rejected_by_dest: path.len() == 1, network_update: None, + all_paths_failed, #[cfg(test)] error_code: Some(*failure_code), #[cfg(test)] @@ -3099,17 +3140,18 @@ impl ChannelMana fn claim_funds_internal(&self, mut channel_state_lock: MutexGuard>, source: HTLCSource, payment_preimage: PaymentPreimage, forwarded_htlc_value_msat: Option, from_onchain: bool) { match source { - HTLCSource::OutboundRoute { session_priv, .. } => { + HTLCSource::OutboundRoute { session_priv, mpp_id, .. } => { mem::drop(channel_state_lock); - if { - let mut session_priv_bytes = [0; 32]; - session_priv_bytes.copy_from_slice(&session_priv[..]); - self.pending_outbound_payments.lock().unwrap().remove(&session_priv_bytes) - } { - let mut pending_events = self.pending_events.lock().unwrap(); - pending_events.push(events::Event::PaymentSent { - payment_preimage - }); + let mut session_priv_bytes = [0; 32]; + session_priv_bytes.copy_from_slice(&session_priv[..]); + let mut outbounds = self.pending_outbound_payments.lock().unwrap(); + let found_payment = if let Some(mut sessions) = outbounds.remove(&mpp_id) { + sessions.remove(&session_priv_bytes) + } else { false }; + if found_payment { + self.pending_events.lock().unwrap().push( + events::Event::PaymentSent { payment_preimage } + ); } else { log_trace!(self.logger, "Received duplicative fulfill for HTLC with payment_preimage {}", log_bytes!(payment_preimage.0)); } @@ -4911,14 +4953,60 @@ impl Readable for ClaimableHTLC { } } -impl_writeable_tlv_based_enum!(HTLCSource, - (0, OutboundRoute) => { - (0, session_priv, required), - (2, first_hop_htlc_msat, required), - (4, path, vec_type), - }, ; - (1, PreviousHopData) -); +impl Readable for HTLCSource { + fn read(reader: &mut R) -> Result { + let id: u8 = Readable::read(reader)?; + match id { + 0 => { + let mut session_priv: ::util::ser::OptionDeserWrapper = ::util::ser::OptionDeserWrapper(None); + let mut first_hop_htlc_msat: u64 = 0; + let mut path = Some(Vec::new()); + let mut mpp_id = None; + read_tlv_fields!(reader, { + (0, session_priv, required), + (1, mpp_id, option), + (2, first_hop_htlc_msat, required), + (4, path, vec_type), + }); + if mpp_id.is_none() { + // For backwards compat, if there was no mpp_id written, use the session_priv bytes + // instead. + mpp_id = Some(MppId(*session_priv.0.unwrap().as_ref())); + } + Ok(HTLCSource::OutboundRoute { + session_priv: session_priv.0.unwrap(), + first_hop_htlc_msat: first_hop_htlc_msat, + path: path.unwrap(), + mpp_id: mpp_id.unwrap(), + }) + } + 1 => Ok(HTLCSource::PreviousHopData(Readable::read(reader)?)), + _ => Err(DecodeError::UnknownRequiredFeature), + } + } +} + +impl Writeable for HTLCSource { + fn write(&self, writer: &mut W) -> Result<(), ::io::Error> { + match self { + HTLCSource::OutboundRoute { ref session_priv, ref first_hop_htlc_msat, ref path, mpp_id } => { + 0u8.write(writer)?; + let mpp_id_opt = Some(mpp_id); + write_tlv_fields!(writer, { + (0, session_priv, required), + (1, mpp_id_opt, option), + (2, first_hop_htlc_msat, required), + (4, path, vec_type), + }); + } + HTLCSource::PreviousHopData(ref field) => { + 1u8.write(writer)?; + field.write(writer)?; + } + } + Ok(()) + } +} impl_writeable_tlv_based_enum!(HTLCFailReason, (0, LightningError) => { @@ -5039,12 +5127,21 @@ impl Writeable f } let pending_outbound_payments = self.pending_outbound_payments.lock().unwrap(); - (pending_outbound_payments.len() as u64).write(writer)?; - for session_priv in pending_outbound_payments.iter() { - session_priv.write(writer)?; + // For backwards compat, write the session privs and their total length. + let mut num_pending_outbounds_compat: u64 = 0; + for (_, outbounds) in pending_outbound_payments.iter() { + num_pending_outbounds_compat += outbounds.len() as u64; + } + num_pending_outbounds_compat.write(writer)?; + for (_, outbounds) in pending_outbound_payments.iter() { + for outbound in outbounds.iter() { + outbound.write(writer)?; + } } - write_tlv_fields!(writer, {}); + write_tlv_fields!(writer, { + (1, pending_outbound_payments, required), + }); Ok(()) } @@ -5297,15 +5394,23 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> } } - let pending_outbound_payments_count: u64 = Readable::read(reader)?; - let mut pending_outbound_payments: HashSet<[u8; 32]> = HashSet::with_capacity(cmp::min(pending_outbound_payments_count as usize, MAX_ALLOC_SIZE/32)); - for _ in 0..pending_outbound_payments_count { - if !pending_outbound_payments.insert(Readable::read(reader)?) { - return Err(DecodeError::InvalidValue); - } + let pending_outbound_payments_count_compat: u64 = Readable::read(reader)?; + let mut pending_outbound_payments_compat: HashMap> = + HashMap::with_capacity(cmp::min(pending_outbound_payments_count_compat as usize, MAX_ALLOC_SIZE/32)); + for _ in 0..pending_outbound_payments_count_compat { + let session_priv = Readable::read(reader)?; + if pending_outbound_payments_compat.insert(MppId(session_priv), [session_priv].iter().cloned().collect()).is_some() { + return Err(DecodeError::InvalidValue) + }; } - read_tlv_fields!(reader, {}); + let mut pending_outbound_payments = None; + read_tlv_fields!(reader, { + (1, pending_outbound_payments, option), + }); + if pending_outbound_payments.is_none() { + pending_outbound_payments = Some(pending_outbound_payments_compat); + } let mut secp_ctx = Secp256k1::new(); secp_ctx.seeded_randomize(&args.keys_manager.get_secure_random_bytes()); @@ -5326,7 +5431,7 @@ impl<'a, Signer: Sign, M: Deref, T: Deref, K: Deref, F: Deref, L: Deref> pending_msg_events: Vec::new(), }), pending_inbound_payments: Mutex::new(pending_inbound_payments), - pending_outbound_payments: Mutex::new(pending_outbound_payments), + pending_outbound_payments: Mutex::new(pending_outbound_payments.unwrap()), our_network_key: args.keys_manager.get_node_secret(), our_network_pubkey: PublicKey::from_secret_key(&secp_ctx, &args.keys_manager.get_node_secret()), @@ -5364,7 +5469,7 @@ mod tests { use bitcoin::hashes::sha256::Hash as Sha256; use core::time::Duration; use ln::{PaymentPreimage, PaymentHash, PaymentSecret}; - use ln::channelmanager::PaymentSendFailure; + use ln::channelmanager::{MppId, PaymentSendFailure}; use ln::features::{InitFeatures, InvoiceFeatures}; use ln::functional_test_utils::*; use ln::msgs; @@ -5515,10 +5620,11 @@ mod tests { let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; let route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph, &nodes[1].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &Vec::new(), 100_000, TEST_FINAL_CLTV, &logger).unwrap(); let (payment_preimage, our_payment_hash, payment_secret) = get_payment_preimage_hash!(&nodes[1]); + let mpp_id = MppId([42; 32]); // Use the utility function send_payment_along_path to send the payment with MPP data which // indicates there are more HTLCs coming. let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match. - nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, &None).unwrap(); + nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -5548,7 +5654,7 @@ mod tests { expect_payment_failed!(nodes[0], our_payment_hash, true); // Send the second half of the original MPP payment. - nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, &None).unwrap(); + nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200_000, cur_height, mpp_id, &None).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -5586,7 +5692,8 @@ mod tests { nodes[0].node.handle_revoke_and_ack(&nodes[1].node.get_our_node_id(), &bs_third_raa); check_added_monitors!(nodes[0], 1); - // There's an existing bug that generates a PaymentSent event for each MPP path, so handle that here. + // Note that successful MPP payments will generate 1 event upon the first path's success. No + // further events will be generated for subsequence path successes. let events = nodes[0].node.get_and_clear_pending_events(); match events[0] { Event::PaymentSent { payment_preimage: ref preimage } => { @@ -5594,12 +5701,6 @@ mod tests { }, _ => panic!("Unexpected event"), } - match events[1] { - Event::PaymentSent { payment_preimage: ref preimage } => { - assert_eq!(payment_preimage, *preimage); - }, - _ => panic!("Unexpected event"), - } } #[test] diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 6824fb043a9..dd1fedd9733 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -1041,7 +1041,7 @@ macro_rules! expect_payment_failed_with_update { let events = $node.node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); match events[0] { - Event::PaymentFailed { ref payment_hash, rejected_by_dest, ref network_update, ref error_code, ref error_data } => { + Event::PaymentFailed { ref payment_hash, rejected_by_dest, ref network_update, ref error_code, ref error_data, .. } => { assert_eq!(*payment_hash, $expected_payment_hash, "unexpected payment_hash"); assert_eq!(rejected_by_dest, $rejected_by_dest, "unexpected rejected_by_dest value"); assert!(error_code.is_some(), "expected error_code.is_some() = true"); @@ -1070,7 +1070,7 @@ macro_rules! expect_payment_failed { let events = $node.node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); match events[0] { - Event::PaymentFailed { ref payment_hash, rejected_by_dest, network_update: _, ref error_code, ref error_data } => { + Event::PaymentFailed { ref payment_hash, rejected_by_dest, network_update: _, ref error_code, ref error_data, .. } => { assert_eq!(*payment_hash, $expected_payment_hash, "unexpected payment_hash"); assert_eq!(rejected_by_dest, $rejected_by_dest, "unexpected rejected_by_dest value"); assert!(error_code.is_some(), "expected error_code.is_some() = true"); @@ -1242,9 +1242,11 @@ pub fn claim_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, exp if !skip_last { last_update_fulfill_dance!(origin_node, expected_route.first().unwrap()); - expect_payment_sent!(origin_node, our_payment_preimage); } } + if !skip_last { + expect_payment_sent!(origin_node, our_payment_preimage); + } } pub fn claim_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], our_payment_preimage: PaymentPreimage) { @@ -1287,77 +1289,97 @@ pub fn send_payment<'a, 'b, 'c>(origin: &Node<'a, 'b, 'c>, expected_route: &[&No claim_payment(&origin, expected_route, our_payment_preimage); } -pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], skip_last: bool, our_payment_hash: PaymentHash) { - assert!(expected_route.last().unwrap().node.fail_htlc_backwards(&our_payment_hash)); - expect_pending_htlcs_forwardable!(expected_route.last().unwrap()); - check_added_monitors!(expected_route.last().unwrap(), 1); +pub fn fail_payment_along_route<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_paths_slice: &[&[&Node<'a, 'b, 'c>]], skip_last: bool, our_payment_hash: PaymentHash) { + let mut expected_paths: Vec<_> = expected_paths_slice.iter().collect(); + for path in expected_paths.iter() { + assert_eq!(path.last().unwrap().node.get_our_node_id(), expected_paths[0].last().unwrap().node.get_our_node_id()); + } + assert!(expected_paths[0].last().unwrap().node.fail_htlc_backwards(&our_payment_hash)); + expect_pending_htlcs_forwardable!(expected_paths[0].last().unwrap()); + check_added_monitors!(expected_paths[0].last().unwrap(), expected_paths.len()); - let mut next_msgs: Option<(msgs::UpdateFailHTLC, msgs::CommitmentSigned)> = None; - macro_rules! update_fail_dance { - ($node: expr, $prev_node: expr, $last_node: expr) => { - { - $node.node.handle_update_fail_htlc(&$prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0); - commitment_signed_dance!($node, $prev_node, next_msgs.as_ref().unwrap().1, !$last_node); - if skip_last && $last_node { - expect_pending_htlcs_forwardable!($node); + let mut per_path_msgs: Vec<((msgs::UpdateFailHTLC, msgs::CommitmentSigned), PublicKey)> = Vec::with_capacity(expected_paths.len()); + let events = expected_paths[0].last().unwrap().node.get_and_clear_pending_msg_events(); + assert_eq!(events.len(), expected_paths.len()); + for ev in events.iter() { + let (update_fail, commitment_signed, node_id) = match ev { + &MessageSendEvent::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref update_fee, ref commitment_signed } } => { + assert!(update_add_htlcs.is_empty()); + assert!(update_fulfill_htlcs.is_empty()); + assert_eq!(update_fail_htlcs.len(), 1); + assert!(update_fail_malformed_htlcs.is_empty()); + assert!(update_fee.is_none()); + (update_fail_htlcs[0].clone(), commitment_signed.clone(), node_id.clone()) + }, + _ => panic!("Unexpected event"), + }; + per_path_msgs.push(((update_fail, commitment_signed), node_id)); + } + per_path_msgs.sort_unstable_by(|(_, node_id_a), (_, node_id_b)| node_id_a.cmp(node_id_b)); + expected_paths.sort_unstable_by(|path_a, path_b| path_a[path_a.len() - 2].node.get_our_node_id().cmp(&path_b[path_b.len() - 2].node.get_our_node_id())); + + for (i, (expected_route, (path_msgs, next_hop))) in expected_paths.iter().zip(per_path_msgs.drain(..)).enumerate() { + let mut next_msgs = Some(path_msgs); + let mut expected_next_node = next_hop; + let mut prev_node = expected_route.last().unwrap(); + + for (idx, node) in expected_route.iter().rev().enumerate().skip(1) { + assert_eq!(expected_next_node, node.node.get_our_node_id()); + let update_next_node = !skip_last || idx != expected_route.len() - 1; + if next_msgs.is_some() { + node.node.handle_update_fail_htlc(&prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0); + commitment_signed_dance!(node, prev_node, next_msgs.as_ref().unwrap().1, update_next_node); + if !update_next_node { + expect_pending_htlcs_forwardable!(node); } } - } - } + let events = node.node.get_and_clear_pending_msg_events(); + if update_next_node { + assert_eq!(events.len(), 1); + match events[0] { + MessageSendEvent::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref update_fee, ref commitment_signed } } => { + assert!(update_add_htlcs.is_empty()); + assert!(update_fulfill_htlcs.is_empty()); + assert_eq!(update_fail_htlcs.len(), 1); + assert!(update_fail_malformed_htlcs.is_empty()); + assert!(update_fee.is_none()); + expected_next_node = node_id.clone(); + next_msgs = Some((update_fail_htlcs[0].clone(), commitment_signed.clone())); + }, + _ => panic!("Unexpected event"), + } + } else { + assert!(events.is_empty()); + } + if !skip_last && idx == expected_route.len() - 1 { + assert_eq!(expected_next_node, origin_node.node.get_our_node_id()); + } - let mut expected_next_node = expected_route.last().unwrap().node.get_our_node_id(); - let mut prev_node = expected_route.last().unwrap(); - for (idx, node) in expected_route.iter().rev().enumerate() { - assert_eq!(expected_next_node, node.node.get_our_node_id()); - if next_msgs.is_some() { - // We may be the "last node" for the purpose of the commitment dance if we're - // skipping the last node (implying it is disconnected) and we're the - // second-to-last node! - update_fail_dance!(node, prev_node, skip_last && idx == expected_route.len() - 1); + prev_node = node; } - let events = node.node.get_and_clear_pending_msg_events(); - if !skip_last || idx != expected_route.len() - 1 { + if !skip_last { + let prev_node = expected_route.first().unwrap(); + origin_node.node.handle_update_fail_htlc(&prev_node.node.get_our_node_id(), &next_msgs.as_ref().unwrap().0); + check_added_monitors!(origin_node, 0); + assert!(origin_node.node.get_and_clear_pending_msg_events().is_empty()); + commitment_signed_dance!(origin_node, prev_node, next_msgs.as_ref().unwrap().1, false); + let events = origin_node.node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); match events[0] { - MessageSendEvent::UpdateHTLCs { ref node_id, updates: msgs::CommitmentUpdate { ref update_add_htlcs, ref update_fulfill_htlcs, ref update_fail_htlcs, ref update_fail_malformed_htlcs, ref update_fee, ref commitment_signed } } => { - assert!(update_add_htlcs.is_empty()); - assert!(update_fulfill_htlcs.is_empty()); - assert_eq!(update_fail_htlcs.len(), 1); - assert!(update_fail_malformed_htlcs.is_empty()); - assert!(update_fee.is_none()); - expected_next_node = node_id.clone(); - next_msgs = Some((update_fail_htlcs[0].clone(), commitment_signed.clone())); + Event::PaymentFailed { payment_hash, rejected_by_dest, all_paths_failed, .. } => { + assert_eq!(payment_hash, our_payment_hash); + assert!(rejected_by_dest); + assert_eq!(all_paths_failed, i == expected_paths.len() - 1); }, _ => panic!("Unexpected event"), } - } else { - assert!(events.is_empty()); - } - if !skip_last && idx == expected_route.len() - 1 { - assert_eq!(expected_next_node, origin_node.node.get_our_node_id()); - } - - prev_node = node; - } - - if !skip_last { - update_fail_dance!(origin_node, expected_route.first().unwrap(), true); - - let events = origin_node.node.get_and_clear_pending_events(); - assert_eq!(events.len(), 1); - match events[0] { - Event::PaymentFailed { payment_hash, rejected_by_dest, .. } => { - assert_eq!(payment_hash, our_payment_hash); - assert!(rejected_by_dest); - }, - _ => panic!("Unexpected event"), } } } -pub fn fail_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_route: &[&Node<'a, 'b, 'c>], our_payment_hash: PaymentHash) { - fail_payment_along_route(origin_node, expected_route, false, our_payment_hash); +pub fn fail_payment<'a, 'b, 'c>(origin_node: &Node<'a, 'b, 'c>, expected_path: &[&Node<'a, 'b, 'c>], our_payment_hash: PaymentHash) { + fail_payment_along_route(origin_node, &[&expected_path[..]], false, our_payment_hash); } pub fn create_chanmon_cfgs(node_count: usize) -> Vec { diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index e1a57cd7cc1..85eef456c21 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -19,7 +19,7 @@ use chain::transaction::OutPoint; use chain::keysinterface::BaseSign; use ln::{PaymentPreimage, PaymentSecret, PaymentHash}; use ln::channel::{COMMITMENT_TX_BASE_WEIGHT, COMMITMENT_TX_WEIGHT_PER_HTLC}; -use ln::channelmanager::{ChannelManager, ChannelManagerReadArgs, RAACommitmentOrder, PaymentSendFailure, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA}; +use ln::channelmanager::{ChannelManager, ChannelManagerReadArgs, MppId, RAACommitmentOrder, PaymentSendFailure, BREAKDOWN_TIMEOUT, MIN_CLTV_EXPIRY_DELTA}; use ln::channel::{Channel, ChannelError}; use ln::{chan_utils, onion_utils}; use ln::chan_utils::HTLC_SUCCESS_TX_WEIGHT; @@ -3308,7 +3308,7 @@ fn test_simple_peer_disconnect() { nodes[1].node.peer_disconnected(&nodes[0].node.get_our_node_id(), false); claim_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], true, payment_preimage_3); - fail_payment_along_route(&nodes[0], &[&nodes[1], &nodes[2]], true, payment_hash_5); + fail_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[2]]], true, payment_hash_5); reconnect_nodes(&nodes[0], &nodes[1], (false, false), (0, 0), (0, 0), (0, 0), (1, 0), (1, 0), (false, false)); { @@ -3886,7 +3886,8 @@ fn do_test_htlc_timeout(send_partial_mpp: bool) { // Use the utility function send_payment_along_path to send the payment with MPP data which // indicates there are more HTLCs coming. let cur_height = CHAN_CONFIRM_DEPTH + 1; // route_payment calls send_payment, which adds 1 to the current height. So we do the same here to match. - nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200000, cur_height, &None).unwrap(); + let mpp_id = MppId([42; 32]); + nodes[0].node.send_payment_along_path(&route.paths[0], &our_payment_hash, &Some(payment_secret), 200000, cur_height, mpp_id, &None).unwrap(); check_added_monitors!(nodes[0], 1); let mut events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); @@ -4083,6 +4084,34 @@ fn test_no_txn_manager_serialize_deserialize() { send_payment(&nodes[0], &[&nodes[1]], 1000000); } +#[test] +fn mpp_failure() { + let chanmon_cfgs = create_chanmon_cfgs(4); + let node_cfgs = create_node_cfgs(4, &chanmon_cfgs); + let node_chanmgrs = create_node_chanmgrs(4, &node_cfgs, &[None, None, None, None]); + let nodes = create_network(4, &node_cfgs, &node_chanmgrs); + + let chan_1_id = create_announced_chan_between_nodes(&nodes, 0, 1, InitFeatures::known(), InitFeatures::known()).0.contents.short_channel_id; + let chan_2_id = create_announced_chan_between_nodes(&nodes, 0, 2, InitFeatures::known(), InitFeatures::known()).0.contents.short_channel_id; + let chan_3_id = create_announced_chan_between_nodes(&nodes, 1, 3, InitFeatures::known(), InitFeatures::known()).0.contents.short_channel_id; + let chan_4_id = create_announced_chan_between_nodes(&nodes, 2, 3, InitFeatures::known(), InitFeatures::known()).0.contents.short_channel_id; + let logger = test_utils::TestLogger::new(); + + let (_, payment_hash, payment_secret) = get_payment_preimage_hash!(&nodes[3]); + let net_graph_msg_handler = &nodes[0].net_graph_msg_handler; + let mut route = get_route(&nodes[0].node.get_our_node_id(), &net_graph_msg_handler.network_graph, &nodes[3].node.get_our_node_id(), Some(InvoiceFeatures::known()), None, &[], 100000, TEST_FINAL_CLTV, &logger).unwrap(); + let path = route.paths[0].clone(); + route.paths.push(path); + route.paths[0][0].pubkey = nodes[1].node.get_our_node_id(); + route.paths[0][0].short_channel_id = chan_1_id; + route.paths[0][1].short_channel_id = chan_3_id; + route.paths[1][0].pubkey = nodes[2].node.get_our_node_id(); + route.paths[1][0].short_channel_id = chan_2_id; + route.paths[1][1].short_channel_id = chan_4_id; + send_along_route_with_secret(&nodes[0], route, &[&[&nodes[1], &nodes[3]], &[&nodes[2], &nodes[3]]], 200_000, payment_hash, payment_secret); + fail_payment_along_route(&nodes[0], &[&[&nodes[1], &nodes[3]], &[&nodes[2], &nodes[3]]], false, payment_hash); +} + #[test] fn test_dup_htlc_onchain_fails_on_reload() { // When a Channel is closed, any outbound HTLCs which were relayed through it are simply @@ -5913,9 +5942,10 @@ fn test_fail_holding_cell_htlc_upon_free() { let events = nodes[0].node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); match &events[0] { - &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref error_code, ref error_data } => { + &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref error_code, ref error_data, ref all_paths_failed } => { assert_eq!(our_payment_hash.clone(), *payment_hash); assert_eq!(*rejected_by_dest, false); + assert_eq!(*all_paths_failed, true); assert_eq!(*network_update, None); assert_eq!(*error_code, None); assert_eq!(*error_data, None); @@ -5999,9 +6029,10 @@ fn test_free_and_fail_holding_cell_htlcs() { let events = nodes[0].node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); match &events[0] { - &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref error_code, ref error_data } => { + &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref error_code, ref error_data, ref all_paths_failed } => { assert_eq!(payment_hash_2.clone(), *payment_hash); assert_eq!(*rejected_by_dest, false); + assert_eq!(*all_paths_failed, true); assert_eq!(*network_update, None); assert_eq!(*error_code, None); assert_eq!(*error_data, None); diff --git a/lightning/src/ln/onion_route_tests.rs b/lightning/src/ln/onion_route_tests.rs index bcc05fdb030..70d5a0c0bdf 100644 --- a/lightning/src/ln/onion_route_tests.rs +++ b/lightning/src/ln/onion_route_tests.rs @@ -163,8 +163,9 @@ fn run_onion_failure_test_with_fail_intercept(_name: &str, test_case: let events = nodes[0].node.get_and_clear_pending_events(); assert_eq!(events.len(), 1); - if let &Event::PaymentFailed { payment_hash:_, ref rejected_by_dest, ref network_update, ref error_code, error_data: _ } = &events[0] { + if let &Event::PaymentFailed { payment_hash:_, ref rejected_by_dest, ref network_update, ref error_code, error_data: _, ref all_paths_failed } = &events[0] { assert_eq!(*rejected_by_dest, !expected_retryable); + assert_eq!(*all_paths_failed, true); assert_eq!(*error_code, expected_error_code); if expected_channel_update.is_some() { match network_update { diff --git a/lightning/src/ln/onion_utils.rs b/lightning/src/ln/onion_utils.rs index f6c62cb83ca..ee3ed96b5ef 100644 --- a/lightning/src/ln/onion_utils.rs +++ b/lightning/src/ln/onion_utils.rs @@ -332,7 +332,7 @@ pub(super) fn build_first_hop_failure_packet(shared_secret: &[u8], failure_type: /// Returns update, a boolean indicating that the payment itself failed, and the error code. #[inline] pub(super) fn process_onion_failure(secp_ctx: &Secp256k1, logger: &L, htlc_source: &HTLCSource, mut packet_decrypted: Vec) -> (Option, bool, Option, Option>) where L::Target: Logger { - if let &HTLCSource::OutboundRoute { ref path, ref session_priv, ref first_hop_htlc_msat } = htlc_source { + if let &HTLCSource::OutboundRoute { ref path, ref session_priv, ref first_hop_htlc_msat, .. } = htlc_source { let mut res = None; let mut htlc_msat = *first_hop_htlc_msat; let mut error_code_ret = None; diff --git a/lightning/src/routing/network_graph.rs b/lightning/src/routing/network_graph.rs index ca656039d12..16ff80189e9 100644 --- a/lightning/src/routing/network_graph.rs +++ b/lightning/src/routing/network_graph.rs @@ -1728,6 +1728,7 @@ mod tests { net_graph_msg_handler.handle_event(&Event::PaymentFailed { payment_hash: PaymentHash([0; 32]), rejected_by_dest: false, + all_paths_failed: true, network_update: Some(NetworkUpdate::ChannelUpdateMessage { msg: valid_channel_update, }), @@ -1750,6 +1751,7 @@ mod tests { net_graph_msg_handler.handle_event(&Event::PaymentFailed { payment_hash: PaymentHash([0; 32]), rejected_by_dest: false, + all_paths_failed: true, network_update: Some(NetworkUpdate::ChannelClosed { short_channel_id, is_permanent: false, @@ -1771,6 +1773,7 @@ mod tests { net_graph_msg_handler.handle_event(&Event::PaymentFailed { payment_hash: PaymentHash([0; 32]), rejected_by_dest: false, + all_paths_failed: true, network_update: Some(NetworkUpdate::ChannelClosed { short_channel_id, is_permanent: true, diff --git a/lightning/src/util/events.rs b/lightning/src/util/events.rs index d63dd88b761..df4d9037307 100644 --- a/lightning/src/util/events.rs +++ b/lightning/src/util/events.rs @@ -112,8 +112,11 @@ pub enum Event { /// payment is to pay an invoice or to send a spontaneous payment. purpose: PaymentPurpose, }, - /// Indicates an outbound payment we made succeeded (ie it made it all the way to its target + /// Indicates an outbound payment we made succeeded (i.e. it made it all the way to its target /// and we got back the payment preimage for it). + /// + /// Note for MPP payments: in rare cases, this event may be preceded by a `PaymentFailed` event. + /// In this situation, you SHOULD treat this payment as having succeeded. PaymentSent { /// The preimage to the hash given to ChannelManager::send_payment. /// Note that this serves as a payment receipt, if you wish to have such a thing, you must @@ -138,6 +141,10 @@ pub enum Event { /// [`NetworkGraph`]: crate::routing::network_graph::NetworkGraph /// [`NetGraphMsgHandler`]: crate::routing::network_graph::NetGraphMsgHandler network_update: Option, + /// For both single-path and multi-path payments, this is set if all paths of the payment have + /// failed. This will be set to false if (1) this is an MPP payment and (2) other parts of the + /// larger MPP payment were still in flight when this event was generated. + all_paths_failed: bool, #[cfg(test)] error_code: Option, #[cfg(test)] @@ -221,7 +228,7 @@ impl Writeable for Event { (0, payment_preimage, required), }); }, - &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, + &Event::PaymentFailed { ref payment_hash, ref rejected_by_dest, ref network_update, ref all_paths_failed, #[cfg(test)] ref error_code, #[cfg(test)] @@ -236,6 +243,7 @@ impl Writeable for Event { (0, payment_hash, required), (1, network_update, option), (2, rejected_by_dest, required), + (3, all_paths_failed, required), }); }, &Event::PendingHTLCsForwardable { time_forwardable: _ } => { @@ -319,15 +327,18 @@ impl MaybeReadable for Event { let mut payment_hash = PaymentHash([0; 32]); let mut rejected_by_dest = false; let mut network_update = None; + let mut all_paths_failed = Some(true); read_tlv_fields!(reader, { (0, payment_hash, required), (1, network_update, ignorable), (2, rejected_by_dest, required), + (3, all_paths_failed, option), }); Ok(Some(Event::PaymentFailed { payment_hash, rejected_by_dest, network_update, + all_paths_failed: all_paths_failed.unwrap(), #[cfg(test)] error_code, #[cfg(test)] diff --git a/lightning/src/util/ser.rs b/lightning/src/util/ser.rs index 0b5036bd749..c76b701817b 100644 --- a/lightning/src/util/ser.rs +++ b/lightning/src/util/ser.rs @@ -528,6 +528,36 @@ impl Readable for HashMap } } +// HashSet +impl Writeable for HashSet +where T: Writeable + Eq + Hash +{ + #[inline] + fn write(&self, w: &mut W) -> Result<(), io::Error> { + (self.len() as u16).write(w)?; + for item in self.iter() { + item.write(w)?; + } + Ok(()) + } +} + +impl Readable for HashSet +where T: Readable + Eq + Hash +{ + #[inline] + fn read(r: &mut R) -> Result { + let len: u16 = Readable::read(r)?; + let mut ret = HashSet::with_capacity(len as usize); + for _ in 0..len { + if !ret.insert(T::read(r)?) { + return Err(DecodeError::InvalidValue) + } + } + Ok(ret) + } +} + // Vectors impl Writeable for Vec { #[inline]