diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs
index fa8a0b2163d..d26baba8151 100644
--- a/lightning/src/ln/channelmanager.rs
+++ b/lightning/src/ln/channelmanager.rs
@@ -2020,9 +2020,9 @@ where
 	///
 	/// See `ChannelManager` struct-level documentation for lock order requirements.
 	#[cfg(not(any(test, feature = "_test_utils")))]
-	per_peer_state: FairRwLock<HashMap<PublicKey, Mutex<PeerState<SP>>>>,
+	per_peer_state: FairRwLock<HashMap<PublicKey, FairRwLock<PeerState<SP>>>>,
 	#[cfg(any(test, feature = "_test_utils"))]
-	pub(super) per_peer_state: FairRwLock<HashMap<PublicKey, Mutex<PeerState<SP>>>>,
+	pub(super) per_peer_state: FairRwLock<HashMap<PublicKey, FairRwLock<PeerState<SP>>>>,
 
 	/// The set of events which we need to give to the user to handle. In some cases an event may
 	/// require some further action after the user handles it (currently only blocking a monitor
@@ -2700,8 +2700,8 @@ macro_rules! handle_error {
 				if let Some(msg_event) = msg_event {
 					let per_peer_state = $self.per_peer_state.read().unwrap();
-					if let Some(peer_state_mutex) = per_peer_state.get(&$counterparty_node_id) {
-						let mut peer_state = peer_state_mutex.lock().unwrap();
+					if let Some(peer_state_rwlock) = per_peer_state.get(&$counterparty_node_id) {
+						let mut peer_state = peer_state_rwlock.write().unwrap();
 						peer_state.pending_msg_events.push(msg_event);
 					}
 				}
@@ -2944,8 +2944,8 @@ macro_rules! handle_monitor_update_completion {
 			let per_peer_state = $self.per_peer_state.read().unwrap();
 			let mut batch_funding_tx = None;
 			for (channel_id, counterparty_node_id, _) in removed_batch_state {
-				if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-					let mut peer_state = peer_state_mutex.lock().unwrap();
+				if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+					let mut peer_state = peer_state_rwlock.write().unwrap();
 					if let Some(ChannelPhase::Funded(chan)) = peer_state.channel_by_id.get_mut(&channel_id) {
 						batch_funding_tx = batch_funding_tx.or_else(|| chan.context.unbroadcasted_funding());
 						chan.set_batch_ready();
@@ -3250,10 +3250,10 @@ where
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
 
-		let peer_state_mutex = per_peer_state.get(&their_network_key)
+		let peer_state_rwlock = per_peer_state.get(&their_network_key)
 			.ok_or_else(|| APIError::APIMisuseError{ err: format!("Not connected to node: {}", their_network_key) })?;
 
-		let mut peer_state = peer_state_mutex.lock().unwrap();
+		let mut peer_state = peer_state_rwlock.write().unwrap();
 
 		if let Some(temporary_channel_id) = temporary_channel_id {
 			if peer_state.channel_by_id.contains_key(&temporary_channel_id) {
@@ -3308,9 +3308,9 @@ where
 		{
 			let best_block_height = self.best_block.read().unwrap().height;
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-				let peer_state = &mut *peer_state_lock;
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+				let peer_state_lock = peer_state_rwlock.read().unwrap();
+				let peer_state = &*peer_state_lock;
 				res.extend(peer_state.channel_by_id.iter()
 					.filter_map(|(chan_id, phase)| match phase {
 						// Only `Channels` in the `ChannelPhase::Funded` phase can be considered funded.
@@ -3341,9 +3341,9 @@ where
 		{
 			let best_block_height = self.best_block.read().unwrap().height;
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-				let peer_state = &mut *peer_state_lock;
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+				let peer_state_lock = peer_state_rwlock.read().unwrap();
+				let peer_state = &*peer_state_lock;
 				for context in peer_state.channel_by_id.iter().map(|(_, phase)| phase.context()) {
 					let details = ChannelDetails::from_channel_context(context, best_block_height,
 						peer_state.latest_features.clone(), &self.fee_estimator);
@@ -3372,9 +3372,9 @@ where
 		let best_block_height = self.best_block.read().unwrap().height;
 		let per_peer_state = self.per_peer_state.read().unwrap();
 
-		if let Some(peer_state_mutex) = per_peer_state.get(counterparty_node_id) {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-			let peer_state = &mut *peer_state_lock;
+		if let Some(peer_state_rwlock) = per_peer_state.get(counterparty_node_id) {
+			let peer_state_lock = peer_state_rwlock.read().unwrap();
+			let peer_state = &*peer_state_lock;
 			let features = &peer_state.latest_features;
 			let context_to_details = |context| {
 				ChannelDetails::from_channel_context(context, best_block_height, features.clone(), &self.fee_estimator)
@@ -3433,10 +3433,10 @@ where
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
 
-			let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+			let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 				.ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id) })?;
 
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 
 			match peer_state.channel_by_id.entry(channel_id.clone()) {
@@ -3586,8 +3586,8 @@ where
 		let per_peer_state = self.per_peer_state.read().unwrap();
 		let mut has_uncompleted_channel = None;
 		for (channel_id, counterparty_node_id, state) in affected_channels {
-			if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-				let mut peer_state = peer_state_mutex.lock().unwrap();
+			if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+				let mut peer_state = peer_state_rwlock.write().unwrap();
 				if let Some(mut chan) = peer_state.channel_by_id.remove(&channel_id) {
 					update_maps_on_chan_removal!(self, &chan.context());
 					shutdown_results.push(chan.context_mut().force_shutdown(false, ClosureReason::FundingBatchClosure));
@@ -3628,10 +3628,10 @@ where
 	fn force_close_channel_with_peer(&self, channel_id: &ChannelId, peer_node_id: &PublicKey, peer_msg: Option<&String>, broadcast: bool)
 	-> Result<PublicKey, APIError> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(peer_node_id)
+		let peer_state_rwlock = per_peer_state.get(peer_node_id)
 			.ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", peer_node_id) })?;
 		let (update_opt, counterparty_node_id) = {
-			let mut peer_state = peer_state_mutex.lock().unwrap();
+			let mut peer_state = peer_state_rwlock.write().unwrap();
 			let closure_reason = if let Some(peer_msg) = peer_msg {
 				ClosureReason::CounterpartyForceClosed { peer_msg: UntrustedString(peer_msg.to_string()) }
 			} else {
@@ -3687,8 +3687,8 @@ where
 		match self.force_close_channel_with_peer(channel_id, counterparty_node_id, None, broadcast) {
 			Ok(counterparty_node_id) => {
 				let per_peer_state = self.per_peer_state.read().unwrap();
-				if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-					let mut peer_state = peer_state_mutex.lock().unwrap();
+				if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+					let mut peer_state = peer_state_rwlock.write().unwrap();
 					peer_state.pending_msg_events.push(
 						events::MessageSendEvent::HandleError {
 							node_id: counterparty_node_id,
@@ -3794,11 +3794,11 @@ where
 			Some((cp_id, id)) => (cp_id, id),
 		};
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
-		if peer_state_mutex_opt.is_none() {
+		let peer_state_rwlock_opt = per_peer_state.get(&counterparty_node_id);
+		if peer_state_rwlock_opt.is_none() {
 			return None;
 		}
-		let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.get_mut(&channel_id).and_then(
 			|chan_phase| if let ChannelPhase::Funded(chan) = chan_phase { Some(chan) } else { None }
@@ -4109,9 +4109,9 @@ where
 			payment_hash, path.hops.first().unwrap().short_channel_id);
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(&counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(&counterparty_node_id)
 			.ok_or_else(|| APIError::ChannelUnavailable{err: "No peer matching the path's first hop found!".to_owned() })?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		if let hash_map::Entry::Occupied(mut chan_phase_entry) = peer_state.channel_by_id.entry(id) {
 			match chan_phase_entry.get_mut() {
@@ -4473,10 +4473,10 @@ where
 		mut find_funding_output: FundingOutput,
 	) -> Result<(), APIError> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id) })?;
 
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		let funding_txo;
 		let (mut chan, msg_opt) = match peer_state.channel_by_id.remove(temporary_channel_id) {
@@ -4705,7 +4705,7 @@ where
 			let per_peer_state = self.per_peer_state.read().unwrap();
 			for (channel_id, counterparty_node_id) in channels_to_remove {
 				per_peer_state.get(&counterparty_node_id)
-					.map(|peer_state_mutex| peer_state_mutex.lock().unwrap())
+					.map(|peer_state_rwlock| peer_state_rwlock.write().unwrap())
 					.and_then(|mut peer_state| peer_state.channel_by_id.remove(&channel_id))
 					.map(|mut chan| {
 						update_maps_on_chan_removal!(self, &chan.context());
@@ -4755,9 +4755,9 @@ where
 		let _persistence_guard = PersistenceNotifierGuard::notify_on_drop(self);
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id) })?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		for channel_id in channel_ids {
@@ -4857,10 +4857,10 @@ where
 		let next_hop_scid = {
 			let peer_state_lock = self.per_peer_state.read().unwrap();
-			let peer_state_mutex = peer_state_lock.get(&next_node_id)
+			let peer_state_rwlock = peer_state_lock.get(&next_node_id)
 				.ok_or_else(|| APIError::ChannelUnavailable { err: format!("Can't find a peer matching the passed counterparty node_id {}", next_node_id) })?;
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-			let peer_state = &mut *peer_state_lock;
+			let peer_state_lock = peer_state_rwlock.read().unwrap();
+			let peer_state = &*peer_state_lock;
 			match peer_state.channel_by_id.get(next_hop_channel_id) {
 				Some(ChannelPhase::Funded(chan)) => {
 					if !chan.context.is_usable() {
@@ -5226,12 +5226,12 @@ where
 					};
 					forwarding_counterparty = Some(counterparty_node_id);
 					let per_peer_state = self.per_peer_state.read().unwrap();
-					let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
-					if peer_state_mutex_opt.is_none() {
+					let peer_state_rwlock_opt = per_peer_state.get(&counterparty_node_id);
+					if peer_state_rwlock_opt.is_none() {
 						forwarding_channel_not_found!();
 						continue;
 					}
-					let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+					let mut peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 					let peer_state = &mut *peer_state_lock;
 					if let Some(ChannelPhase::Funded(ref mut chan)) = peer_state.channel_by_id.get_mut(&forward_chan_id) {
 						let logger = WithChannelContext::from(&self.logger, &chan.context);
@@ -5625,8 +5625,8 @@ where
 				let mut updated_chan = false;
 				{
 					let per_peer_state = self.per_peer_state.read().unwrap();
-					if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-						let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+					if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+						let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 						let peer_state = &mut *peer_state_lock;
 						match peer_state.channel_by_id.entry(channel_id) {
 							hash_map::Entry::Occupied(mut chan_phase) => {
@@ -5649,8 +5649,8 @@ where
 			},
 			BackgroundEvent::MonitorUpdatesComplete { counterparty_node_id, channel_id } => {
 				let per_peer_state = self.per_peer_state.read().unwrap();
-				if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-					let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+				if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+					let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 					let peer_state = &mut *peer_state_lock;
 					if let Some(ChannelPhase::Funded(chan)) = peer_state.channel_by_id.get_mut(&channel_id) {
 						handle_monitor_update_completion!(self, peer_state_lock, peer_state, per_peer_state, chan);
@@ -5709,8 +5709,8 @@ where
 		let anchor_feerate = self.fee_estimator.bounded_sat_per_1000_weight(ConfirmationTarget::AnchorChannelFee);
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			for (chan_id, chan) in peer_state.channel_by_id.iter_mut().filter_map(
 				|(chan_id, phase)| if let ChannelPhase::Funded(chan) = phase { Some((chan_id, chan)) } else { None }
@@ -5793,8 +5793,8 @@ where
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (counterparty_node_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			for (counterparty_node_id, peer_state_rwlock) in per_peer_state.iter() {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				let pending_msg_events = &mut peer_state.pending_msg_events;
 				let counterparty_node_id = *counterparty_node_id;
@@ -5932,7 +5932,7 @@ where
 					// Remove the entry if the peer is still disconnected and we still
 					// have no channels to the peer.
 					let remove_entry = {
-						let peer_state = entry.get().lock().unwrap();
+						let peer_state = entry.get().write().unwrap();
 						peer_state.ok_to_remove(true)
 					};
 					if remove_entry {
@@ -6121,8 +6121,8 @@ where
 	) {
 		let (failure_code, onion_failure_data) = {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			if let Some(peer_state_mutex) = per_peer_state.get(counterparty_node_id) {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			if let Some(peer_state_rwlock) = per_peer_state.get(counterparty_node_id) {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				match peer_state.channel_by_id.entry(channel_id) {
 					hash_map::Entry::Occupied(chan_phase_entry) => {
@@ -6427,7 +6427,7 @@ where
 
 		let peer_state_opt = counterparty_node_id_opt.as_ref().map(
 			|counterparty_node_id| per_peer_state.get(counterparty_node_id)
-				.map(|peer_mutex| peer_mutex.lock().unwrap())
+				.map(|peer_mutex| peer_mutex.write().unwrap())
 		).unwrap_or(None);
 
 		if peer_state_opt.is_some() {
@@ -6484,8 +6484,8 @@ where
 						"Duplicate claims should always free another channel immediately");
 					return Ok(());
 				};
-				if let Some(peer_state_mtx) = per_peer_state.get(&node_id) {
-					let mut peer_state = peer_state_mtx.lock().unwrap();
+				if let Some(peer_state_rwlock) = per_peer_state.get(&node_id) {
+					let mut peer_state = peer_state_rwlock.write().unwrap();
 					if let Some(blockers) = peer_state
 						.actions_blocking_raa_monitor_updates
 						.get_mut(&channel_id)
@@ -6843,9 +6843,9 @@ where
 		};
 		let per_peer_state = self.per_peer_state.read().unwrap();
 		let mut peer_state_lock;
-		let peer_state_mutex_opt = per_peer_state.get(&counterparty_node_id);
-		if peer_state_mutex_opt.is_none() { return }
-		peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+		let peer_state_rwlock_opt = per_peer_state.get(&counterparty_node_id);
+		if peer_state_rwlock_opt.is_none() { return }
+		peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		let channel =
 			if let Some(ChannelPhase::Funded(chan)) = peer_state.channel_by_id.get_mut(channel_id) {
@@ -6923,14 +6923,14 @@ where
 		let peers_without_funded_channels =
 			self.peers_without_funded_channels(|peer| { peer.total_channel_count() > 0 });
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				let err_str = format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id);
 				log_error!(logger, "{}", err_str);
 				APIError::ChannelUnavailable { err: err_str }
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		let is_only_peer_channel = peer_state.total_channel_count() == 1;
@@ -7028,7 +7028,7 @@ where
 		{
 			let peer_state_lock = self.per_peer_state.read().unwrap();
 			for (_, peer_mtx) in peer_state_lock.iter() {
-				let peer = peer_mtx.lock().unwrap();
+				let peer = peer_mtx.write().unwrap();
 				if !maybe_count_peer(&*peer) { continue; }
 				let num_unfunded_channels = Self::unfunded_channel_count(&peer, best_block_height);
 				if num_unfunded_channels == peer.total_channel_count() {
@@ -7104,14 +7104,14 @@ where
 			self.peers_without_funded_channels(|node| node.total_channel_count() > 0);
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(
 					format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id),
 					msg.common_fields.temporary_channel_id.clone())
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 
 		// If this peer already has some channels, a new channel won't increase our number of peers
@@ -7205,12 +7205,12 @@ where
 		// likely to be lost on restart!
 		let (value, output_script, user_id) = {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+			let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 				.ok_or_else(|| {
 					debug_assert!(false);
 					MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.common_fields.temporary_channel_id)
 				})?;
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			match peer_state.channel_by_id.entry(msg.common_fields.temporary_channel_id) {
 				hash_map::Entry::Occupied(mut phase) => {
@@ -7242,13 +7242,13 @@ where
 		let best_block = *self.best_block.read().unwrap();
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.temporary_channel_id)
 			})?;
 
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		let (mut chan, funding_msg_opt, monitor) =
 			match peer_state.channel_by_id.remove(&msg.temporary_channel_id) {
@@ -7338,13 +7338,13 @@ where
 	fn internal_funding_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::FundingSigned) -> Result<(), MsgHandleErrInternal> {
 		let best_block = *self.best_block.read().unwrap();
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
 
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(chan_phase_entry) => {
@@ -7399,12 +7399,12 @@ where
 		// Note that the ChannelManager is NOT re-persisted on disk after this (unless we error
 		// closing a channel), so any changes are likely to be lost on restart!
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7455,12 +7455,12 @@ where
 		let mut finish_shutdown = None;
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+			let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 				.ok_or_else(|| {
 					debug_assert!(false);
 					MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 				})?;
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			if let hash_map::Entry::Occupied(mut chan_phase_entry) = peer_state.channel_by_id.entry(msg.channel_id.clone()) {
 				let phase = chan_phase_entry.get_mut();
@@ -7527,13 +7527,13 @@ where
 	fn internal_closing_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::ClosingSigned) -> Result<(), MsgHandleErrInternal> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
 		let (tx, chan_option, shutdown_result) = {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			match peer_state.channel_by_id.entry(msg.channel_id.clone()) {
 				hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7597,12 +7597,12 @@ where
 		let decoded_hop_res = self.decode_update_add_htlc_onion(msg, counterparty_node_id);
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7669,12 +7669,12 @@ where
 		let next_user_channel_id;
 		let (htlc_source, forwarded_htlc_value, skimmed_fee_msat) = {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+			let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 				.ok_or_else(|| {
 					debug_assert!(false);
 					MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 				})?;
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			match peer_state.channel_by_id.entry(msg.channel_id) {
 				hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7718,12 +7718,12 @@ where
 		// Note that the ChannelManager is NOT re-persisted on disk after this (unless we error
 		// closing a channel), so any changes are likely to be lost on restart!
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7743,12 +7743,12 @@ where
 		// Note that the ChannelManager is NOT re-persisted on disk after this (unless we error
 		// closing a channel), so any changes are likely to be lost on restart!
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7770,12 +7770,12 @@ where
 	fn internal_commitment_signed(&self, counterparty_node_id: &PublicKey, msg: &msgs::CommitmentSigned) -> Result<(), MsgHandleErrInternal> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7944,8 +7944,8 @@ where
 		counterparty_node_id: PublicKey, channel_id: ChannelId
 	) -> bool {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		if let Some(peer_state_mtx) = per_peer_state.get(&counterparty_node_id) {
-			let mut peer_state_lck = peer_state_mtx.lock().unwrap();
+		if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+			let mut peer_state_lck = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lck;
 			if let Some(chan) = peer_state.channel_by_id.get(&channel_id) {
@@ -7963,7 +7963,7 @@ where
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
-			}).map(|mtx| mtx.lock().unwrap())?;
+			}).map(|mtx| mtx.write().unwrap())?;
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -7998,12 +7998,12 @@ where
 	fn internal_update_fee(&self, counterparty_node_id: &PublicKey, msg: &msgs::UpdateFee) -> Result<(), MsgHandleErrInternal> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -8022,12 +8022,12 @@ where
 	fn internal_announcement_signatures(&self, counterparty_node_id: &PublicKey, msg: &msgs::AnnouncementSignatures) -> Result<(), MsgHandleErrInternal> {
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+		let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 			.ok_or_else(|| {
 				debug_assert!(false);
 				MsgHandleErrInternal::send_err_msg_no_close(format!("Can't find a peer matching the passed counterparty node_id {}", counterparty_node_id), msg.channel_id)
 			})?;
-		let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(msg.channel_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -8065,11 +8065,11 @@ where
 			}
 		};
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		let peer_state_mutex_opt = per_peer_state.get(&chan_counterparty_node_id);
-		if peer_state_mutex_opt.is_none() {
+		let peer_state_rwlock_opt = per_peer_state.get(&chan_counterparty_node_id);
+		if peer_state_rwlock_opt.is_none() {
 			return Ok(NotifyOption::SkipPersistNoEvents)
 		}
-		let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+		let mut peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 		let peer_state = &mut *peer_state_lock;
 		match peer_state.channel_by_id.entry(chan_id) {
 			hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -8111,7 +8111,7 @@ where
 		let need_lnd_workaround = {
 			let per_peer_state = self.per_peer_state.read().unwrap();
 
-			let peer_state_mutex = per_peer_state.get(counterparty_node_id)
+			let peer_state_rwlock = per_peer_state.get(counterparty_node_id)
 				.ok_or_else(|| {
 					debug_assert!(false);
 					MsgHandleErrInternal::send_err_msg_no_close(
@@ -8120,7 +8120,7 @@ where
 					)
 				})?;
 			let logger = WithContext::from(&self.logger, Some(*counterparty_node_id), Some(msg.channel_id));
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			match peer_state.channel_by_id.entry(msg.channel_id) {
 				hash_map::Entry::Occupied(mut chan_phase_entry) => {
@@ -8242,8 +8242,8 @@ where
 		};
 		if let Some(counterparty_node_id) = counterparty_node_id_opt {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				let pending_msg_events = &mut peer_state.pending_msg_events;
 				if let hash_map::Entry::Occupied(chan_phase_entry) = peer_state.channel_by_id.entry(channel_id) {
@@ -8307,9 +8307,9 @@ where
 		// manage to go through all our peers without finding a single channel to update.
 		'peer_loop: loop {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
 				'chan_loop: loop {
-					let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+					let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 					let peer_state: &mut PeerState<_> = &mut *peer_state_lock;
 					for (channel_id, chan) in peer_state.channel_by_id.iter_mut().filter_map(
 						|(chan_id, phase)| if let ChannelPhase::Funded(chan) = phase { Some((chan_id, chan)) } else { None }
@@ -8389,16 +8389,16 @@ where
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
 		if let Some((counterparty_node_id, channel_id)) = channel_opt {
-			if let Some(peer_state_mutex) = per_peer_state.get(&counterparty_node_id) {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				if let Some(chan) = peer_state.channel_by_id.get_mut(&channel_id) {
 					unblock_chan(chan, &mut peer_state.pending_msg_events);
 				}
 			}
 		} else {
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				for (_, chan) in peer_state.channel_by_id.iter_mut() {
 					unblock_chan(chan, &mut peer_state.pending_msg_events);
@@ -8417,8 +8417,8 @@ where
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				let pending_msg_events = &mut peer_state.pending_msg_events;
 				peer_state.channel_by_id.retain(|channel_id, phase| {
@@ -8962,7 +8962,7 @@ where
 		let peers = self.per_peer_state.read().unwrap()
 			.iter()
-			.filter(|(_, peer)| peer.lock().unwrap().latest_features.supports_onion_messages())
+			.filter(|(_, peer)| peer.write().unwrap().latest_features.supports_onion_messages())
 			.map(|(node_id, _)| *node_id)
 			.collect::<Vec<_>>();
 
@@ -9045,9 +9045,9 @@ where
 		let mut inflight_htlcs = InFlightHtlcs::new();
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-			let peer_state = &mut *peer_state_lock;
+		for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+			let peer_state_lock = peer_state_rwlock.read().unwrap();
+			let peer_state = &*peer_state_lock;
 			for chan in peer_state.channel_by_id.values().filter_map(
 				|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 			) {
@@ -9105,8 +9105,8 @@ where
 		);
 		loop {
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			if let Some(peer_state_mtx) = per_peer_state.get(&counterparty_node_id) {
-				let mut peer_state_lck = peer_state_mtx.lock().unwrap();
+			if let Some(peer_state_rwlock) = per_peer_state.get(&counterparty_node_id) {
+				let mut peer_state_lck = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lck;
 				if let Some(blocker) = completed_blocker.take() {
 					// Only do this on the first iteration of the loop.
@@ -9225,8 +9225,8 @@ where
 		let mut is_any_peer_connected = false;
 		let mut pending_events = Vec::new();
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			if peer_state.pending_msg_events.len() > 0 {
 				pending_events.append(&mut peer_state.pending_msg_events);
@@ -9387,9 +9387,9 @@ where
 	fn get_relevant_txids(&self) -> Vec<(Txid, u32, Option<BlockHash>)> {
 		let mut res = Vec::with_capacity(self.short_to_chan_info.read().unwrap().len());
-		for (_cp_id, peer_state_mutex) in self.per_peer_state.read().unwrap().iter() {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-			let peer_state = &mut *peer_state_lock;
+		for (_cp_id, peer_state_rwlock) in self.per_peer_state.read().unwrap().iter() {
+			let peer_state_lock = peer_state_rwlock.read().unwrap();
+			let peer_state = &*peer_state_lock;
 			for chan in peer_state.channel_by_id.values().filter_map(|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }) {
 				let txid_opt = chan.context.get_funding_txo();
 				let height_opt = chan.context.get_funding_tx_confirmation_height();
@@ -9440,8 +9440,8 @@ where
 		let mut timed_out_htlcs = Vec::new();
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
-			for (_cp_id, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+			for (_cp_id, peer_state_rwlock) in per_peer_state.iter() {
+				let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				let pending_msg_events = &mut peer_state.pending_msg_events;
@@ -9902,8 +9902,8 @@ where
 			"Marking channels with {} disconnected and generating channel_updates.",
 			log_pubkey!(counterparty_node_id)
 		);
-		if let Some(peer_state_mutex) = per_peer_state.get(counterparty_node_id) {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		if let Some(peer_state_rwlock) = per_peer_state.get(counterparty_node_id) {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			let pending_msg_events = &mut peer_state.pending_msg_events;
 			peer_state.channel_by_id.retain(|_, phase| {
@@ -10035,7 +10035,7 @@ where
 					res = Err(());
 					return NotifyOption::SkipPersistNoEvents;
 				}
-				e.insert(Mutex::new(PeerState {
+				e.insert(FairRwLock::new(PeerState {
 					channel_by_id: new_hash_map(),
 					inbound_channel_request_by_id: new_hash_map(),
 					latest_features: init_msg.features.clone(),
@@ -10047,7 +10047,7 @@ where
 				}));
 			},
 			hash_map::Entry::Occupied(e) => {
-				let mut peer_state = e.get().lock().unwrap();
+				let mut peer_state = e.get().write().unwrap();
 				peer_state.latest_features = init_msg.features.clone();
 
 				let best_block_height = self.best_block.read().unwrap().height;
@@ -10068,8 +10068,8 @@ where
 		log_debug!(logger, "Generating channel_reestablish events for {}", log_pubkey!(counterparty_node_id));
 
 		let per_peer_state = self.per_peer_state.read().unwrap();
-		if let Some(peer_state_mutex) = per_peer_state.get(counterparty_node_id) {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		if let Some(peer_state_rwlock) = per_peer_state.get(counterparty_node_id) {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			let pending_msg_events = &mut peer_state.pending_msg_events;
@@ -10141,9 +10141,9 @@ where
 			self, || -> NotifyOption {
 				let per_peer_state = self.per_peer_state.read().unwrap();
-				let peer_state_mutex_opt = per_peer_state.get(counterparty_node_id);
-				if peer_state_mutex_opt.is_none() { return NotifyOption::SkipPersistNoEvents; }
-				let mut peer_state = peer_state_mutex_opt.unwrap().lock().unwrap();
+				let peer_state_rwlock_opt = per_peer_state.get(counterparty_node_id);
+				if peer_state_rwlock_opt.is_none() { return NotifyOption::SkipPersistNoEvents; }
+				let mut peer_state = peer_state_rwlock_opt.unwrap().write().unwrap();
 				if let Some(ChannelPhase::Funded(chan)) = peer_state.channel_by_id.get(&msg.channel_id) {
 					if let Some(msg) = chan.get_outbound_shutdown() {
 						peer_state.pending_msg_events.push(events::MessageSendEvent::SendShutdown {
@@ -10179,9 +10179,9 @@ where
 		if msg.channel_id.is_zero() {
 			let channel_ids: Vec<ChannelId> = {
 				let per_peer_state = self.per_peer_state.read().unwrap();
-				let peer_state_mutex_opt = per_peer_state.get(counterparty_node_id);
-				if peer_state_mutex_opt.is_none() { return; }
-				let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+				let peer_state_rwlock_opt = per_peer_state.get(counterparty_node_id);
+				if peer_state_rwlock_opt.is_none() { return; }
+				let mut peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				// Note that we don't bother generating any events for pre-accept channels -
 				// they're not considered "channels" yet from the PoV of our events interface.
@@ -10196,9 +10196,9 @@ where
 			{
 				// First check if we can advance the channel type and try again.
 				let per_peer_state = self.per_peer_state.read().unwrap();
-				let peer_state_mutex_opt = per_peer_state.get(counterparty_node_id);
-				if peer_state_mutex_opt.is_none() { return; }
-				let mut peer_state_lock = peer_state_mutex_opt.unwrap().lock().unwrap();
+				let peer_state_rwlock_opt = per_peer_state.get(counterparty_node_id);
+				if peer_state_rwlock_opt.is_none() { return; }
+				let mut peer_state_lock = peer_state_rwlock_opt.unwrap().write().unwrap();
 				let peer_state = &mut *peer_state_lock;
 				match peer_state.channel_by_id.get_mut(&msg.channel_id) {
 					Some(ChannelPhase::UnfundedOutboundV1(ref mut chan)) => {
@@ -11019,9 +11019,9 @@ where
 		{
 			let per_peer_state = self.per_peer_state.read().unwrap();
 			let mut number_of_funded_channels = 0;
-			for (_, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-				let peer_state = &mut *peer_state_lock;
+			for (_, peer_state_rwlock) in per_peer_state.iter() {
+				let peer_state_lock = peer_state_rwlock.read().unwrap();
+				let peer_state = &*peer_state_lock;
 				if !peer_state.ok_to_remove(false) {
 					serializable_peer_count += 1;
 				}
@@ -11033,9 +11033,9 @@ where
 
 			(number_of_funded_channels as u64).write(writer)?;
 
-			for (_, peer_state_mutex) in per_peer_state.iter() {
-				let mut peer_state_lock = peer_state_mutex.lock().unwrap();
-				let peer_state = &mut *peer_state_lock;
+			for (_, peer_state_rwlock) in per_peer_state.iter() {
+				let peer_state_lock = peer_state_rwlock.read().unwrap();
+				let peer_state = &*peer_state_lock;
 				for channel in peer_state.channel_by_id.iter().filter_map(
 					|(_, phase)| if let ChannelPhase::Funded(channel) = phase {
 						if channel.context.is_funding_broadcast() { Some(channel) } else { None }
@@ -11085,11 +11085,11 @@ where
 		let mut monitor_update_blocked_actions_per_peer = None;
 		let mut peer_states = Vec::new();
-		for (_, peer_state_mutex) in per_peer_state.iter() {
+		for (_, peer_state_rwlock) in per_peer_state.iter() {
 			// Because we're holding the owning `per_peer_state` write lock here there's no chance
 			// of a lockorder violation deadlock - no other thread can be holding any
 			// per_peer_state lock at all.
-			peer_states.push(peer_state_mutex.unsafe_well_ordered_double_lock_self());
+			peer_states.push(peer_state_rwlock.unsafe_well_ordered_double_lock_self());
 		}
 
 		(serializable_peer_count).write(writer)?;
@@ -11611,13 +11611,13 @@ where
 		};
 
 		let peer_count: u64 = Readable::read(reader)?;
-		let mut per_peer_state = hash_map_with_capacity(cmp::min(peer_count as usize, MAX_ALLOC_SIZE/mem::size_of::<(PublicKey, Mutex<PeerState<SP>>)>()));
+		let mut per_peer_state = hash_map_with_capacity(cmp::min(peer_count as usize, MAX_ALLOC_SIZE/mem::size_of::<(PublicKey, FairRwLock<PeerState<SP>>)>()));
 		for _ in 0..peer_count {
 			let peer_pubkey = Readable::read(reader)?;
 			let peer_chans = funded_peer_channels.remove(&peer_pubkey).unwrap_or(new_hash_map());
 			let mut peer_state = peer_state_from_chans(peer_chans);
 			peer_state.latest_features = Readable::read(reader)?;
-			per_peer_state.insert(peer_pubkey, Mutex::new(peer_state));
+			per_peer_state.insert(peer_pubkey, FairRwLock::new(peer_state));
 		}
 
 		let event_count: u64 = Readable::read(reader)?;
@@ -11776,8 +11776,8 @@ where
 				}
 			}
 		}
-		for (counterparty_id, peer_state_mtx) in per_peer_state.iter_mut() {
-			let mut peer_state_lock = peer_state_mtx.lock().unwrap();
+		for (counterparty_id, peer_state_rwlock) in per_peer_state.iter_mut() {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 			let peer_state = &mut *peer_state_lock;
 			for phase in peer_state.channel_by_id.values() {
 				if let ChannelPhase::Funded(chan) = phase {
@@ -11825,10 +11825,10 @@ where
 					// Now that we've removed all the in-flight monitor updates for channels that are
 					// still open, we need to replay any monitor updates that are for closed channels,
 					// creating the neccessary peer_state entries as we go.
-					let peer_state_mutex = per_peer_state.entry(counterparty_id).or_insert_with(|| {
-						Mutex::new(peer_state_from_chans(new_hash_map()))
+					let peer_state_rwlock = per_peer_state.entry(counterparty_id).or_insert_with(|| {
+						FairRwLock::new(peer_state_from_chans(new_hash_map()))
 					});
-					let mut peer_state = peer_state_mutex.lock().unwrap();
+					let mut peer_state = peer_state_rwlock.write().unwrap();
 					handle_in_flight_updates!(counterparty_id, chan_in_flight_updates,
 						funding_txo, monitor, peer_state, logger, "closed ");
 				} else {
@@ -12098,8 +12098,8 @@ where
 		}
 
 		let mut outbound_scid_aliases = new_hash_set();
-		for (_peer_node_id, peer_state_mutex) in per_peer_state.iter_mut() {
-			let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+		for (_peer_node_id, peer_state_rwlock) in per_peer_state.iter_mut() {
+			let mut peer_state_lock = peer_state_rwlock.write().unwrap();
			let peer_state = &mut *peer_state_lock;
 			for (chan_id, phase) in peer_state.channel_by_id.iter_mut() {
 				if let ChannelPhase::Funded(chan) = phase {
@@ -12169,8 +12169,8 @@ where
 							// restart.
 							let previous_channel_id = claimable_htlc.prev_hop.channel_id;
 							if let Some(peer_node_id) = outpoint_to_peer.get(&claimable_htlc.prev_hop.outpoint) {
-								let peer_state_mutex = per_peer_state.get(peer_node_id).unwrap();
-								let mut peer_state_lock = peer_state_mutex.lock().unwrap();
+								let peer_state_rwlock = per_peer_state.get(peer_node_id).unwrap();
+								let mut peer_state_lock = peer_state_rwlock.write().unwrap();
 								let peer_state = &mut *peer_state_lock;
 								if let Some(ChannelPhase::Funded(channel)) = peer_state.channel_by_id.get_mut(&previous_channel_id) {
 									let logger = WithChannelContext::from(&args.logger, &channel.context);
@@ -12206,7 +12206,7 @@ where
 							log_trace!(logger, "Holding the next revoke_and_ack from {} until the preimage is durably persisted in the inbound edge's ChannelMonitor", blocked_channel_id);
-							blocked_peer_state.lock().unwrap().actions_blocking_raa_monitor_updates
+							blocked_peer_state.write().unwrap().actions_blocking_raa_monitor_updates
 								.entry(*blocked_channel_id)
 								.or_insert_with(Vec::new).push(blocking_action.clone());
 						} else {
@@ -12222,7 +12222,7 @@ where
 					}
 				}
 			}
-			peer_state.lock().unwrap().monitor_update_blocked_actions = monitor_update_blocked_actions;
+			peer_state.write().unwrap().monitor_update_blocked_actions = monitor_update_blocked_actions;
 		} else {
 			log_error!(WithContext::from(&args.logger, Some(node_id), None), "Got blocked actions without a per-peer-state for {}", node_id);
 			return Err(DecodeError::InvalidValue);
diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs
index 3a506b57fe2..e057ba1a2bf 100644
--- a/lightning/src/ln/functional_test_utils.rs
+++ b/lightning/src/ln/functional_test_utils.rs
@@ -491,7 +491,7 @@ impl<'a, 'b, 'c> Node<'a, 'b, 'c> {
 		log_debug!(self.logger, "Setting channel signer for {} as available={}", chan_id, available);
 
 		let per_peer_state = self.node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(peer_id).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(peer_id).unwrap().write().unwrap();
 
 		let mut channel_keys_id = None;
 		if let Some(chan) = chan_lock.channel_by_id.get(chan_id).map(|phase| phase.context()) {
@@ -930,7 +930,7 @@ macro_rules! get_channel_ref {
 	($node: expr, $counterparty_node: expr, $per_peer_state_lock: ident, $peer_state_lock: ident, $channel_id: expr) => {
 		{
 			$per_peer_state_lock = $node.node.per_peer_state.read().unwrap();
-			$peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+			$peer_state_lock = $per_peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().write().unwrap();
 			$peer_state_lock.channel_by_id.get_mut(&$channel_id).unwrap()
 		}
 	}
@@ -1556,8 +1556,8 @@ macro_rules! check_warn_msg {
 
 /// Checks if at least one peer is connected.
 fn is_any_peer_connected(node: &Node) -> bool {
 	let peer_state = node.node.per_peer_state.read().unwrap();
-	for (_, peer_mutex) in peer_state.iter() {
-		let peer = peer_mutex.lock().unwrap();
+	for (_, peer_rwlock) in peer_state.iter() {
+		let peer = peer_rwlock.read().unwrap();
 		if peer.is_connected { return true; }
 	}
 	false
@@ -2018,8 +2018,8 @@ pub fn do_commitment_signed_dance(node_a: &Node<'_, '_, '_>, node_b: &Node<'_, '
 		let node_a_per_peer_state = node_a.node.per_peer_state.read().unwrap();
 		let mut number_of_msg_events = 0;
-		for (cp_id, peer_state_mutex) in node_a_per_peer_state.iter() {
-			let peer_state = peer_state_mutex.lock().unwrap();
+		for (cp_id, peer_state) in node_a_per_peer_state.iter() {
+			let peer_state = peer_state.write().unwrap();
 			let cp_pending_msg_events = &peer_state.pending_msg_events;
 			number_of_msg_events += cp_pending_msg_events.len();
 			if cp_pending_msg_events.len() == 1 {
@@ -2853,7 +2853,7 @@ pub fn pass_claimed_payment_along_route<'a, 'b, 'c, 'd>(args: ClaimAlongRouteArg
 				let (base_fee, prop_fee) = {
 					let per_peer_state = $node.node.per_peer_state.read().unwrap();
 					let peer_state = per_peer_state.get(&$prev_node.node.get_our_node_id())
-						.unwrap().lock().unwrap();
+						.unwrap().write().unwrap();
 					let channel = peer_state.channel_by_id.get(&next_msgs.as_ref().unwrap().0.channel_id).unwrap();
 					if let Some(prev_config) = channel.context().prev_config() {
 						(prev_config.forwarding_fee_base_msat as u64,
@@ -3458,7 +3458,7 @@ pub fn get_announce_close_broadcast_events<'a, 'b, 'c>(nodes: &Vec<Node<'a, 'b,
 	($node: expr, $counterparty_node: expr, $channel_id: expr) => {{
 		let peer_state_lock = $node.node.per_peer_state.read().unwrap();
-		let chan_lock = peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = peer_state_lock.get(&$counterparty_node.node.get_our_node_id()).unwrap().write().unwrap();
 		let chan = chan_lock.channel_by_id.get(&$channel_id).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs
index 5ea3e6372c0..4acb975a6a7 100644
--- a/lightning/src/ln/functional_tests.rs
+++ b/lightning/src/ln/functional_tests.rs
@@ -699,7 +699,7 @@ fn test_update_fee_that_funder_cannot_afford() {
 	// needed to sign the new commitment tx and (2) sign the new commitment tx.
 	let (local_revocation_basepoint, local_htlc_basepoint, local_funding) = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let local_chan = chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -710,7 +710,7 @@ fn test_update_fee_that_funder_cannot_afford() {
 	};
 	let (remote_delayed_payment_basepoint, remote_htlc_basepoint,remote_point, remote_funding) = {
 		let per_peer_state = nodes[1].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().write().unwrap();
 		let remote_chan = chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -727,7 +727,7 @@ fn test_update_fee_that_funder_cannot_afford() {
 
 	let res = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let local_chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let local_chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let local_chan = local_chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -1429,7 +1429,7 @@ fn test_fee_spike_violation_fails_htlc() {
 	// needed to sign the new commitment tx and (2) sign the new commitment tx.
 	let (local_revocation_basepoint, local_htlc_basepoint, local_secret, next_local_point, local_funding) = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let local_chan = chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -1445,7 +1445,7 @@ fn test_fee_spike_violation_fails_htlc() {
 	};
 	let (remote_delayed_payment_basepoint, remote_htlc_basepoint, remote_point, remote_funding) = {
 		let per_peer_state = nodes[1].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().write().unwrap();
 		let remote_chan = chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -1476,7 +1476,7 @@ fn test_fee_spike_violation_fails_htlc() {
 
 	let res = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let local_chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let local_chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let local_chan = local_chan_lock.channel_by_id.get(&chan.2).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap();
@@ -3236,7 +3236,7 @@ fn do_test_commitment_revoked_fail_backward_exhaustive(deliver_bs_raa: bool, use
 		// The dust limit applied to HTLC outputs considers the fee of the HTLC transaction as
 		// well, so HTLCs at exactly the dust limit will not be included in commitment txn.
 		nodes[2].node.per_peer_state.read().unwrap().get(&nodes[1].node.get_our_node_id())
-			.unwrap().lock().unwrap().channel_by_id.get(&chan_2.2).unwrap().context().holder_dust_limit_satoshis * 1000
+			.unwrap().write().unwrap().channel_by_id.get(&chan_2.2).unwrap().context().holder_dust_limit_satoshis * 1000
 	} else { 3000000 };
 
 	let (_, first_payment_hash, ..) = route_payment(&nodes[0], &[&nodes[1], &nodes[2]], value);
@@ -5209,7 +5209,7 @@ fn do_test_fail_backwards_unrevoked_remote_announce(deliver_last_raa: bool, anno
 	assert_eq!(get_local_commitment_txn!(nodes[3], chan_2_3.2)[0].output.len(), 2);
 
 	let ds_dust_limit = nodes[3].node.per_peer_state.read().unwrap().get(&nodes[2].node.get_our_node_id())
-		.unwrap().lock().unwrap().channel_by_id.get(&chan_2_3.2).unwrap().context().holder_dust_limit_satoshis;
+		.unwrap().write().unwrap().channel_by_id.get(&chan_2_3.2).unwrap().context().holder_dust_limit_satoshis;
 	// 0th HTLC:
 	let (_, payment_hash_1, ..) = route_payment(&nodes[0], &[&nodes[2], &nodes[3], &nodes[4]], ds_dust_limit*1000); // not added < dust limit + HTLC tx fee
 	// 1st HTLC:
@@ -6344,7 +6344,7 @@ fn test_update_add_htlc_bolt2_sender_exceed_max_htlc_num_and_htlc_id_increment()
 	let mut nodes = create_network(2, &node_cfgs, &node_chanmgrs);
 	let chan = create_announced_chan_between_nodes_with_value(&nodes, 0, 1, 1000000, 0);
 	let max_accepted_htlcs = nodes[1].node.per_peer_state.read().unwrap().get(&nodes[0].node.get_our_node_id())
-		.unwrap().lock().unwrap().channel_by_id.get(&chan.2).unwrap().context().counterparty_max_accepted_htlcs as u64;
+		.unwrap().write().unwrap().channel_by_id.get(&chan.2).unwrap().context().counterparty_max_accepted_htlcs as u64;
 
 	// Fetch a route in advance as we will be unable to once we're unable to send.
 	let (route, our_payment_hash, _, our_payment_secret) = get_route_and_payment_hash!(nodes[0], nodes[1], 100000);
@@ -6415,7 +6415,7 @@ fn test_update_add_htlc_bolt2_receiver_check_amount_received_more_than_min() {
 	let htlc_minimum_msat: u64;
 	{
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let channel = chan_lock.channel_by_id.get(&chan.2).unwrap();
 		htlc_minimum_msat = channel.context().get_holder_htlc_minimum_msat();
 	}
@@ -7021,7 +7021,7 @@ fn do_test_failure_delay_dust_htlc_local_commitment(announce_latest: bool) {
 	let chan =create_announced_chan_between_nodes(&nodes, 0, 1);
 
 	let bs_dust_limit = nodes[1].node.per_peer_state.read().unwrap().get(&nodes[0].node.get_our_node_id())
-		.unwrap().lock().unwrap().channel_by_id.get(&chan.2).unwrap().context().holder_dust_limit_satoshis;
+		.unwrap().write().unwrap().channel_by_id.get(&chan.2).unwrap().context().holder_dust_limit_satoshis;
 
 	// We route 2 dust-HTLCs between A and B
 	let (_, payment_hash_1, ..) = route_payment(&nodes[0], &[&nodes[1]], bs_dust_limit*1000);
@@ -7114,7 +7114,7 @@ fn do_test_sweep_outbound_htlc_failure_update(revoked: bool, local: bool) {
 	let chan = create_announced_chan_between_nodes(&nodes, 0, 1);
 
 	let bs_dust_limit = nodes[1].node.per_peer_state.read().unwrap().get(&nodes[0].node.get_our_node_id())
-		.unwrap().lock().unwrap().channel_by_id.get(&chan.2).unwrap().context().holder_dust_limit_satoshis;
+		.unwrap().write().unwrap().channel_by_id.get(&chan.2).unwrap().context().holder_dust_limit_satoshis;
 
 	let (_payment_preimage_1, dust_hash, ..) = route_payment(&nodes[0], &[&nodes[1]], bs_dust_limit*1000);
 	let (_payment_preimage_2, non_dust_hash, ..) = route_payment(&nodes[0], &[&nodes[1]], 1000000);
@@ -7796,7 +7796,7 @@ fn test_counterparty_raa_skip_no_crash() {
 	let next_per_commitment_point;
 	{
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let mut guard = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let mut guard = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let keys = guard.channel_by_id.get_mut(&channel_id).map(
 			|phase| if let ChannelPhase::Funded(chan) = phase { Some(chan) } else { None }
 		).flatten().unwrap().get_signer();
@@ -9227,7 +9227,7 @@ fn test_duplicate_chan_id() {
 	let funding_created = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let mut a_peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let mut a_peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		// Once we call `get_funding_created` the channel has a duplicate channel_id as
 		// another channel in the ChannelManager - an invalid state. Thus, we'd panic later when we
 		// try to create another channel. Instead, we drop the channel entirely here (leaving the
@@ -9942,7 +9942,7 @@ fn do_test_max_dust_htlc_exposure(dust_outbound_balance: bool, exposure_breach_e
 
 	let (dust_buffer_feerate, max_dust_htlc_exposure_msat) = {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		let chan = chan_lock.channel_by_id.get(&channel_id).unwrap();
 		(chan.context().get_dust_buffer_feerate(None) as u64,
 		chan.context().get_max_dust_htlc_exposure_msat(&LowerBoundedFeeEstimator(nodes[0].fee_estimator)))
@@ -10440,7 +10440,7 @@ fn test_remove_expired_outbound_unfunded_channels() {
 	// Asserts the outbound channel has been removed from a nodes[0]'s peer state map.
 	let check_outbound_channel_existence = |should_exist: bool| {
 		let per_peer_state = nodes[0].node.per_peer_state.read().unwrap();
-		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap();
+		let chan_lock = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap();
 		assert_eq!(chan_lock.channel_by_id.contains_key(&temp_channel_id), should_exist);
 	};
@@ -10491,7 +10491,7 @@ fn test_remove_expired_inbound_unfunded_channels() {
 	// Asserts the inbound channel has been removed from a nodes[1]'s peer state map.
let check_inbound_channel_existence = |should_exist: bool| { let per_peer_state = nodes[1].node.per_peer_state.read().unwrap(); - let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().lock().unwrap(); + let chan_lock = per_peer_state.get(&nodes[0].node.get_our_node_id()).unwrap().write().unwrap(); assert_eq!(chan_lock.channel_by_id.contains_key(&temp_channel_id), should_exist); }; diff --git a/lightning/src/ln/onion_route_tests.rs b/lightning/src/ln/onion_route_tests.rs index aeb175bc626..f7f7c55d0f6 100644 --- a/lightning/src/ln/onion_route_tests.rs +++ b/lightning/src/ln/onion_route_tests.rs @@ -515,7 +515,7 @@ fn test_onion_failure() { let short_channel_id = channels[1].0.contents.short_channel_id; let amt_to_forward = nodes[1].node.per_peer_state.read().unwrap().get(&nodes[2].node.get_our_node_id()) - .unwrap().lock().unwrap().channel_by_id.get(&channels[1].2).unwrap() + .unwrap().write().unwrap().channel_by_id.get(&channels[1].2).unwrap() .context().get_counterparty_htlc_minimum_msat() - 1; let mut bogus_route = route.clone(); let route_len = bogus_route.paths[0].hops.len(); diff --git a/lightning/src/ln/payment_tests.rs b/lightning/src/ln/payment_tests.rs index a75120797ca..988cd117dd2 100644 --- a/lightning/src/ln/payment_tests.rs +++ b/lightning/src/ln/payment_tests.rs @@ -796,7 +796,7 @@ fn do_retry_with_no_persist(confirm_before_reload: bool) { { let per_peer_state = nodes[1].node.per_peer_state.read().unwrap(); let mut peer_state = per_peer_state.get(&nodes[2].node.get_our_node_id()) - .unwrap().lock().unwrap(); + .unwrap().write().unwrap(); let mut channel = peer_state.channel_by_id.get_mut(&chan_id_2).unwrap(); let mut new_config = channel.context().config(); new_config.forwarding_fee_base_msat += 100_000; diff --git a/lightning/src/ln/reorg_tests.rs b/lightning/src/ln/reorg_tests.rs index 62c82b01f59..a11ac7cf5bb 100644 --- a/lightning/src/ln/reorg_tests.rs +++ b/lightning/src/ln/reorg_tests.rs @@ -258,7 +258,7 @@ fn do_test_unconf_chan(reload_node: bool, reorg_after_reload: bool, use_funding_ { let per_peer_state = nodes[0].node.per_peer_state.read().unwrap(); - let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap(); + let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap(); assert_eq!(peer_state.channel_by_id.len(), 1); assert_eq!(nodes[0].node.short_to_chan_info.read().unwrap().len(), 2); } @@ -294,7 +294,7 @@ fn do_test_unconf_chan(reload_node: bool, reorg_after_reload: bool, use_funding_ { let per_peer_state = nodes[0].node.per_peer_state.read().unwrap(); - let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap(); + let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap(); assert_eq!(peer_state.channel_by_id.len(), 0); assert_eq!(nodes[0].node.short_to_chan_info.read().unwrap().len(), 0); } @@ -340,7 +340,7 @@ fn do_test_unconf_chan(reload_node: bool, reorg_after_reload: bool, use_funding_ { let per_peer_state = nodes[0].node.per_peer_state.read().unwrap(); - let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().lock().unwrap(); + let peer_state = per_peer_state.get(&nodes[1].node.get_our_node_id()).unwrap().write().unwrap(); assert_eq!(peer_state.channel_by_id.len(), 0); assert_eq!(nodes[0].node.short_to_chan_info.read().unwrap().len(), 0); }
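For context on the access pattern these test hunks exercise: the per-peer state map now holds each peer's state behind a RwLock rather than a Mutex, so call sites pick between a shared read() guard and an exclusive write() guard. The sketch below is a minimal stand-alone illustration of that pattern only; the PeerState struct and its channel_by_id map here are simplified placeholders, not the real LDK types.

    use std::collections::HashMap;
    use std::sync::RwLock;

    // Simplified stand-in for a per-peer state struct; the real one holds
    // channels, pending messages, and more.
    struct PeerState {
        channel_by_id: HashMap<u64, u64>, // hypothetical: channel id -> dust limit (sats)
    }

    fn main() {
        let peer_state = RwLock::new(PeerState { channel_by_id: HashMap::new() });

        // Mutation requires the exclusive guard: where a Mutex used
        // `.lock().unwrap()`, a RwLock uses `.write().unwrap()`.
        peer_state.write().unwrap().channel_by_id.insert(42, 546);

        // Read-only access can take the shared guard instead, allowing
        // multiple readers to proceed concurrently. Both guards deref to
        // `PeerState`, so call sites differ only in the method name.
        let dust_limit = peer_state.read().unwrap().channel_by_id.get(&42).copied();
        assert_eq!(dust_limit, Some(546));
    }

Note that several of the test hunks above take the write() guard even for purely read-only inspection (e.g. fetching a channel's dust limit); read() would also suffice there, and the choice only matters for contention, which single-threaded tests do not have.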