diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs
index 47e17918978..68067cfb16d 100644
--- a/lightning-net-tokio/src/lib.rs
+++ b/lightning-net-tokio/src/lib.rs
@@ -44,13 +44,13 @@ pub struct Connection {
 }
 impl Connection {
 	fn schedule_read(peer_manager: Arc, Arc>>, us: Arc>, reader: futures::stream::SplitStream>) {
-		let us_ref = us.clone();
-		let us_close_ref = us.clone();
-		let peer_manager_ref = peer_manager.clone();
+		let connection = us.clone();
+		let connection_close = us.clone();
+		let peer_manager_close = peer_manager.clone();
 		tokio::spawn(reader.for_each(move |b| {
 			let pending_read = b.to_vec();
 			{
-				let mut lock = us_ref.lock().unwrap();
+				let mut lock = connection.lock().unwrap();
 				assert!(lock.pending_read.is_empty());
 				if lock.read_paused {
 					lock.pending_read = pending_read;
@@ -60,22 +60,22 @@ impl Connection {
 				}
 			}
 			//TODO: There's a race where we don't meet the requirements of disconnect_socket if its
-			//called right here, after we release the us_ref lock in the scope above, but before we
+			//called right here, after we release the connection lock in the scope above, but before we
			//call read_event!
-			match peer_manager.read_event(&mut SocketDescriptor::new(us_ref.clone(), peer_manager.clone()), pending_read) {
+			match peer_manager.read_event(&mut SocketDescriptor::new(connection.clone(), peer_manager.clone()), pending_read) {
 				Ok(pause_read) => {
 					if pause_read {
-						let mut lock = us_ref.lock().unwrap();
+						let mut lock = connection.lock().unwrap();
 						lock.read_paused = true;
 					}
 				},
 				Err(e) => {
-					us_ref.lock().unwrap().need_disconnect = false;
+					connection.lock().unwrap().need_disconnect = false;
 					return future::Either::B(future::result(Err(std::io::Error::new(std::io::ErrorKind::InvalidData, e))));
 				}
 			}
-			if let Err(e) = us_ref.lock().unwrap().event_notify.try_send(()) {
+			if let Err(e) = connection.lock().unwrap().event_notify.try_send(()) {
 				// Ignore full errors as we just need them to poll after this point, so if the user
 				// hasn't received the last send yet, it doesn't matter.
 				assert!(e.is_full());
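Note on the hunk above: the `event_notify.try_send(())` / `assert!(e.is_full())` pair is a wakeup-coalescing idiom. The receiver only needs one buffered `()` to know it must poll again, so a full channel is success rather than an error. A minimal standalone sketch of the idiom, using std's bounded channel instead of the futures mpsc sender used here (names are illustrative):

use std::sync::mpsc::{sync_channel, SyncSender, TrySendError};

// Wake the consumer; a full buffer means a wakeup is already pending,
// which is all we need, so it is deliberately not treated as a failure.
fn notify(tx: &SyncSender<()>) {
    match tx.try_send(()) {
        Ok(()) | Err(TrySendError::Full(())) => {},
        Err(TrySendError::Disconnected(())) => panic!("event loop hung up"),
    }
}

fn main() {
    let (tx, rx) = sync_channel(1);
    notify(&tx); // buffers one wakeup
    notify(&tx); // buffer full: coalesced into the pending wakeup
    assert!(rx.try_recv().is_ok());
}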
@@ -83,8 +83,8 @@ impl Connection {
 			future::Either::B(future::result(Ok(())))
 		}).then(move |_| {
-			if us_close_ref.lock().unwrap().need_disconnect {
-				peer_manager_ref.disconnect_event(&SocketDescriptor::new(us_close_ref, peer_manager_ref.clone()));
+			if connection_close.lock().unwrap().need_disconnect {
+				peer_manager_close.disconnect_event(&SocketDescriptor::new(connection_close, peer_manager_close.clone()));
 				println!("Peer disconnected!");
 			} else {
 				println!("We disconnected peer!");
@@ -101,9 +101,9 @@ impl Connection {
 		})).then(|_| {
 			future::result(Ok(()))
 		}));
-		let us = Arc::new(Mutex::new(Self { writer: Some(send_sink), event_notify, pending_read: Vec::new(), read_blocker: None, read_paused: false, need_disconnect: true, id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) }));
+		let connection = Arc::new(Mutex::new(Self { writer: Some(send_sink), event_notify, pending_read: Vec::new(), read_blocker: None, read_paused: false, need_disconnect: true, id: ID_COUNTER.fetch_add(1, Ordering::AcqRel) }));

-		(reader, us)
+		(reader, connection)
 	}

 	/// Process incoming messages and feed outgoing messages on the provided socket generated by
@@ -112,10 +112,10 @@ impl Connection {
 	/// You should poll the Receive end of event_notify and call get_and_clear_pending_events() on
 	/// ChannelManager and ChannelMonitor objects.
 	pub fn setup_inbound(peer_manager: Arc, Arc>>, event_notify: mpsc::Sender<()>, stream: TcpStream) {
-		let (reader, us) = Self::new(event_notify, stream);
+		let (reader, connection) = Self::new(event_notify, stream);

-		if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(us.clone(), peer_manager.clone())) {
-			Self::schedule_read(peer_manager, us, reader);
+		if let Ok(_) = peer_manager.new_inbound_connection(SocketDescriptor::new(connection.clone(), peer_manager.clone())) {
+			Self::schedule_read(peer_manager, connection, reader);
 		}
 	}

@@ -126,11 +126,11 @@ impl Connection {
 	/// You should poll the Receive end of event_notify and call get_and_clear_pending_events() on
 	/// ChannelManager and ChannelMonitor objects.
 	pub fn setup_outbound(peer_manager: Arc, Arc>>, event_notify: mpsc::Sender<()>, their_node_id: PublicKey, stream: TcpStream) {
-		let (reader, us) = Self::new(event_notify, stream);
+		let (reader, connection) = Self::new(event_notify, stream);

-		if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(us.clone(), peer_manager.clone())) {
-			if SocketDescriptor::new(us.clone(), peer_manager.clone()).send_data(&initial_send, true) == initial_send.len() {
-				Self::schedule_read(peer_manager, us, reader);
+		if let Ok(initial_send) = peer_manager.new_outbound_connection(their_node_id, SocketDescriptor::new(connection.clone(), peer_manager.clone())) {
+			if SocketDescriptor::new(connection.clone(), peer_manager.clone()).send_data(&initial_send, true) == initial_send.len() {
+				Self::schedule_read(peer_manager, connection, reader);
 			} else {
 				println!("Failed to write first full message to socket!");
 			}
@@ -172,16 +172,16 @@ impl SocketDescriptor {
 impl peer_handler::SocketDescriptor for SocketDescriptor {
 	fn send_data(&mut self, data: &[u8], resume_read: bool) -> usize {
 		macro_rules! schedule_read {
-			($us_ref: expr) => {
+			($descriptor: expr) => {
 				tokio::spawn(future::lazy(move || -> Result<(), ()> {
 					let mut read_data = Vec::new();
 					{
-						let mut us = $us_ref.conn.lock().unwrap();
-						mem::swap(&mut read_data, &mut us.pending_read);
+						let mut connection = $descriptor.conn.lock().unwrap();
+						mem::swap(&mut read_data, &mut connection.pending_read);
 					}
 					if !read_data.is_empty() {
-						let mut us_clone = $us_ref.clone();
-						match $us_ref.peer_manager.read_event(&mut us_clone, read_data) {
+						//let mut us_clone = $descriptor.clone();
+						match $descriptor.peer_manager.read_event(&mut $descriptor.clone(), read_data) {
 							Ok(pause_read) => {
 								if pause_read { return Ok(()); }
 							},
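Note on the `schedule_read!` change above: passing `&mut $descriptor.clone()` to `read_event` instead of binding a named `us_clone` first works because, as the `impl Clone for SocketDescriptor` later in this file suggests, clones are shallow: they share the connection behind an `Arc`, so anything `read_event` does through the temporary is visible through the original descriptor. A hedged sketch of that ownership shape (the `conn` field mirrors the code above; the rest is illustrative):

use std::sync::{Arc, Mutex};

// Hypothetical stand-ins for the crate's Connection and SocketDescriptor.
struct Connection { read_paused: bool }
struct Descriptor { conn: Arc<Mutex<Connection>> }

impl Clone for Descriptor {
    fn clone(&self) -> Self {
        // Shallow clone: bump the refcount and share the connection state.
        Descriptor { conn: Arc::clone(&self.conn) }
    }
}

fn main() {
    let d = Descriptor { conn: Arc::new(Mutex::new(Connection { read_paused: false })) };
    d.clone().conn.lock().unwrap().read_paused = true; // mutate via a temporary clone
    assert!(d.conn.lock().unwrap().read_paused);       // visible through the original
}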
@@ -191,12 +191,12 @@ impl peer_handler::SocketDescriptor for SocketDescri
 							}
 						}
 					}
-					let mut us = $us_ref.conn.lock().unwrap();
-					if let Some(sender) = us.read_blocker.take() {
+					let mut connection = $descriptor.conn.lock().unwrap();
+					if let Some(sender) = connection.read_blocker.take() {
 						sender.send(Ok(())).unwrap();
 					}
-					us.read_paused = false;
-					if let Err(e) = us.event_notify.try_send(()) {
+					connection.read_paused = false;
+					if let Err(e) = connection.event_notify.try_send(()) {
 						// Ignore full errors as we just need them to poll after this point, so if the user
 						// hasn't received the last send yet, it doesn't matter.
 						assert!(e.is_full());
@@ -206,20 +206,20 @@ impl peer_handler::SocketDescriptor for SocketDescri
 			}
 		}

-		let mut us = self.conn.lock().unwrap();
+		let mut connection = self.conn.lock().unwrap();
 		if resume_read {
-			let us_ref = self.clone();
-			schedule_read!(us_ref);
+			let descriptor = self.clone();
+			schedule_read!(descriptor);
 		}
 		if data.is_empty() { return 0; }
-		if us.writer.is_none() {
-			us.read_paused = true;
+		if connection.writer.is_none() {
+			connection.read_paused = true;
 			return 0;
 		}
 		let mut bytes = bytes::BytesMut::with_capacity(data.len());
 		bytes.put(data);
-		let write_res = us.writer.as_mut().unwrap().start_send(bytes.freeze());
+		let write_res = connection.writer.as_mut().unwrap().start_send(bytes.freeze());
 		match write_res {
 			Ok(res) => {
 				match res {
@@ -227,15 +227,15 @@ impl peer_handler::SocketDescriptor for SocketDescri
 						data.len()
 					},
 					AsyncSink::NotReady(_) => {
-						us.read_paused = true;
-						let us_ref = self.clone();
-						tokio::spawn(us.writer.take().unwrap().flush().then(move |writer_res| -> Result<(), ()> {
+						connection.read_paused = true;
+						let descriptor = self.clone();
+						tokio::spawn(connection.writer.take().unwrap().flush().then(move |writer_res| -> Result<(), ()> {
 							if let Ok(writer) = writer_res {
 								{
-									let mut us = us_ref.conn.lock().unwrap();
-									us.writer = Some(writer);
+									let mut connection = descriptor.conn.lock().unwrap();
+									connection.writer = Some(writer);
 								}
-								schedule_read!(us_ref);
+								schedule_read!(descriptor);
 							} // we'll fire the disconnect event on the socket reader end
 							Ok(())
 						}));
@@ -251,9 +251,9 @@ impl peer_handler::SocketDescriptor for SocketDescri
 	}

 	fn disconnect_socket(&mut self) {
-		let mut us = self.conn.lock().unwrap();
-		us.need_disconnect = true;
-		us.read_paused = true;
+		let mut connection = self.conn.lock().unwrap();
+		connection.need_disconnect = true;
+		connection.read_paused = true;
 	}
 }
 impl Clone for SocketDescriptor {
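That is the whole lightning-net-tokio side of the patch: a mechanical rename of `us`/`us_ref`/`us_close_ref` to `connection`/`descriptor`/`connection_close` so each local names the thing it points at, with no behavioral change (the disconnect_socket race noted in the TODO is untouched). The contract the renamed code implements is: while `read_paused` is set, incoming bytes are parked in `pending_read` rather than fed to `read_event`, and a later `send_data(.., resume_read: true)` replays them via `schedule_read!`. A stripped-down sketch of the parking step, with hypothetical names:

use std::sync::{Arc, Mutex};

// Hypothetical minimal connection state mirroring the pattern above.
struct ConnState {
    read_paused: bool,
    pending_read: Vec<u8>,
}

// Returns true if the bytes should be fed to read_event now, false if they
// were parked for replay once reading resumes.
fn on_bytes(conn: &Arc<Mutex<ConnState>>, bytes: &[u8]) -> bool {
    let mut lock = conn.lock().unwrap();
    if lock.read_paused {
        assert!(lock.pending_read.is_empty()); // at most one read in flight
        lock.pending_read = bytes.to_vec();
        false
    } else {
        true
    }
}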
diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs
index 5838b782f4d..7cebc133ebf 100644
--- a/lightning/src/ln/peer_handler.rs
+++ b/lightning/src/ln/peer_handler.rs
@@ -18,7 +18,7 @@ use util::byte_utils;
 use util::events::{MessageSendEvent, MessageSendEventsProvider};
 use util::logger::Logger;

-use std::collections::{HashMap,hash_map,HashSet,LinkedList};
+use std::collections::{HashMap,hash_map,LinkedList};
 use std::sync::{Arc, Mutex};
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::{cmp,error,hash,fmt};
@@ -120,6 +120,10 @@ struct Peer {
 	sync_status: InitSyncTracker,

 	awaiting_pong: bool,
+
+	/// Indicates do_read_event() pushed a message into pending_outbound_buffer but didn't call
+	/// do_attempt_write_data() to avoid reentrancy. Cleared in process_events().
+	needing_send: bool,
 }

 impl Peer {
@@ -140,9 +144,6 @@ impl Peer {
 struct PeerHolder {
 	peers: HashMap,
-	/// Added to by do_read_event for cases where we pushed a message onto the send buffer but
-	/// didn't call do_attempt_write_data to avoid reentrancy. Cleared in process_events()
-	peers_needing_send: HashSet,
 	/// Only add to this set when noise completes:
 	node_id_to_descriptor: HashMap,
 }
@@ -228,7 +229,6 @@ impl PeerManager where
 			message_handler: message_handler,
 			peers: Mutex::new(PeerHolder {
 				peers: HashMap::new(),
-				peers_needing_send: HashSet::new(),
 				node_id_to_descriptor: HashMap::new()
 			}),
 			our_node_secret: our_node_secret,
@@ -299,6 +299,7 @@ impl PeerManager where
 			sync_status: InitSyncTracker::NoSyncRequested,

 			awaiting_pong: false,
+			needing_send: false,
 		}).is_some() {
 			panic!("PeerManager driver duplicated descriptors!");
 		};
@@ -336,6 +337,7 @@ impl PeerManager where
 			sync_status: InitSyncTracker::NoSyncRequested,

 			awaiting_pong: false,
+			needing_send: false,
 		}).is_some() {
 			panic!("PeerManager driver duplicated descriptors!");
 		};
@@ -485,7 +487,7 @@ impl PeerManager where
 						{
 							log_trace!(self, "Encoding and sending message of type {} to {}", $msg_code, log_pubkey!(peer.their_node_id.unwrap()));
 							peer.pending_outbound_buffer.push_back(peer.channel_encryptor.encrypt_message(&encode_msg!($msg, $msg_code)[..]));
-							peers.peers_needing_send.insert(peer_descriptor.clone());
+							peer.needing_send = true;
 						}
 					}
 				}
@@ -644,7 +646,7 @@ impl PeerManager where

 									if msg.features.initial_routing_sync() {
 										peer.sync_status = InitSyncTracker::ChannelsSyncing(0);
-										peers.peers_needing_send.insert(peer_descriptor.clone());
+										peer.needing_send = true;
 									}

 									if !peer.outbound {
@@ -1029,7 +1031,6 @@ impl PeerManager where
 				match *action {
 					msgs::ErrorAction::DisconnectPeer { ref msg } => {
 						if let Some(mut descriptor) = peers.node_id_to_descriptor.remove(node_id) {
-							peers.peers_needing_send.remove(&descriptor);
 							if let Some(mut peer) = peers.peers.remove(&descriptor) {
 								if let Some(ref msg) = *msg {
 									log_trace!(self, "Handling DisconnectPeer HandleError event in peer_handler for node {} with message {}",
@@ -1063,11 +1064,10 @@ impl PeerManager where
 				}
 			}

-			for mut descriptor in peers.peers_needing_send.drain() {
-				match peers.peers.get_mut(&descriptor) {
-					Some(peer) => self.do_attempt_write_data(&mut descriptor, peer),
-					None => panic!("Inconsistent peers set state!"),
-				}
+			let peers_needing_send = peers.peers.iter_mut().filter(|(_, peer)| peer.needing_send);
+			for (descriptor, peer) in peers_needing_send {
+				peer.needing_send = false;
+				self.do_attempt_write_data(&mut descriptor.clone(), peer)
 			}
 		}
 	}
@@ -1084,9 +1084,7 @@ impl PeerManager where

 	fn disconnect_event_internal(&self, descriptor: &Descriptor, no_connection_possible: bool) {
 		let mut peers = self.peers.lock().unwrap();
-		peers.peers_needing_send.remove(descriptor);
-		let peer_option = peers.peers.remove(descriptor);
-		match peer_option {
+		match peers.peers.remove(descriptor) {
 			None => panic!("Descriptor for disconnect_event is not already known to PeerManager"),
 			Some(peer) => {
 				match peer.their_node_id {
@@ -1108,13 +1106,11 @@ impl PeerManager where
 		let mut peers_lock = self.peers.lock().unwrap();
 		{
 			let peers = &mut *peers_lock;
-			let peers_needing_send = &mut peers.peers_needing_send;
 			let node_id_to_descriptor = &mut peers.node_id_to_descriptor;
 			let peers = &mut peers.peers;
 			peers.retain(|descriptor, peer| {
 				if peer.awaiting_pong == true {
-					peers_needing_send.remove(descriptor);
 					match peer.their_node_id {
 						Some(node_id) => {
 							node_id_to_descriptor.remove(&node_id);
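The peer_handler.rs side replaces the `peers_needing_send` side set with a per-peer `needing_send` flag. The win is invariant-by-construction: the old code had to update the set at every peer insert/remove site and panicked with "Inconsistent peers set state!" if the two collections ever disagreed, while a flag stored on the `Peer` itself cannot name a peer the map no longer contains. The cost is that `process_events()` now scans all peers to find the flagged ones. A toy model of the new drain loop, with hypothetical types:

use std::collections::HashMap;

// Hypothetical minimal model of the refactor: the flag lives on the peer,
// so no separate HashSet can drift out of sync with the peer map.
struct Peer { needing_send: bool }

fn process_events(peers: &mut HashMap<u64, Peer>) {
    // Equivalent of the new loop above: visit flagged peers, clear the flag,
    // then attempt the write (elided here).
    for (_id, peer) in peers.iter_mut().filter(|(_, p)| p.needing_send) {
        peer.needing_send = false;
        // self.do_attempt_write_data(&mut descriptor.clone(), peer);
    }
}

fn main() {
    let mut peers = HashMap::new();
    peers.insert(1u64, Peer { needing_send: true });
    peers.insert(2u64, Peer { needing_send: false });
    process_events(&mut peers);
    assert!(peers.values().all(|p| !p.needing_send));
}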