From f2e3fe2ce0517b4ed573efd47e58cb808d1927c6 Mon Sep 17 00:00:00 2001 From: Baojun Wang Date: Fri, 20 May 2022 15:29:45 -0400 Subject: [PATCH 1/2] Add Seek instance for io::Take --- library/std/src/io/mod.rs | 49 ++++++++++++++++++++++++++++++++++++- library/std/src/io/tests.rs | 32 ++++++++++++++++++++++++ 2 files changed, 80 insertions(+), 1 deletion(-) diff --git a/library/std/src/io/mod.rs b/library/std/src/io/mod.rs index 94812e3fe3b2c..753b182df878c 100644 --- a/library/std/src/io/mod.rs +++ b/library/std/src/io/mod.rs @@ -988,7 +988,7 @@ pub trait Read { where Self: Sized, { - Take { inner: self, limit } + Take { inner: self, limit, cursor: 0 } } } @@ -2408,6 +2408,7 @@ impl SizeHint for Chain { pub struct Take { inner: T, limit: u64, + cursor: u64, } impl Take { @@ -2559,6 +2560,7 @@ impl Read for Take { let max = cmp::min(buf.len() as u64, self.limit) as usize; let n = self.inner.read(&mut buf[..max])?; + self.cursor += n as u64; self.limit -= n as u64; Ok(n) } @@ -2601,10 +2603,12 @@ impl Read for Take { buf.add_filled(filled); + self.cursor += filled as u64; self.limit -= filled as u64; } else { self.inner.read_buf(buf)?; + self.cursor += buf.filled_len().saturating_sub(prev_filled) as u64; //inner may unfill self.limit -= buf.filled_len().saturating_sub(prev_filled) as u64; } @@ -2623,17 +2627,60 @@ impl BufRead for Take { let buf = self.inner.fill_buf()?; let cap = cmp::min(buf.len() as u64, self.limit) as usize; + self.cursor = cap as u64; Ok(&buf[..cap]) } fn consume(&mut self, amt: usize) { // Don't let callers reset the limit by passing an overlarge value let amt = cmp::min(amt as u64, self.limit) as usize; + self.cursor += amt as u64; self.limit -= amt as u64; self.inner.consume(amt); } } +impl Seek for Take { + fn seek(&mut self, pos: SeekFrom) -> Result { + let stream_end = self.cursor + self.limit; + let position = match pos { + SeekFrom::Start(k) => Some(cmp::min(k, stream_end)), + SeekFrom::Current(k) if k < 0 => { + if -k as u64 > 
self.cursor { + None + } else { + Some(self.cursor - (-k as u64)) + } + } + SeekFrom::Current(k) => Some(cmp::min(self.cursor + k as u64, stream_end)), + SeekFrom::End(k) if k >= 0 => Some(stream_end), + SeekFrom::End(k) => { + if -k as u64 > stream_end { + None + } else { + Some(stream_end - (-k) as u64) + } + } + }; + + match position { + None => Err(ErrorKind::InvalidInput.into()), + Some(pos) => { + let rel = pos as i64 - self.cursor as i64; + self.inner.seek(SeekFrom::Current(rel))?; + if rel >= 0 { + self.cursor += rel as u64; + self.limit -= rel as u64; + } else { + self.cursor -= -rel as u64; + self.limit += -rel as u64; + } + Ok(pos) + } + } + } +} + impl SizeHint for Take { #[inline] fn lower_bound(&self) -> usize { diff --git a/library/std/src/io/tests.rs b/library/std/src/io/tests.rs index eb62634856462..39fff22704134 100644 --- a/library/std/src/io/tests.rs +++ b/library/std/src/io/tests.rs @@ -602,3 +602,35 @@ fn bench_take_read_buf(b: &mut test::Bencher) { [255; 128].take(64).read_buf(&mut rbuf).unwrap(); }); } + +#[test] +fn test_io_take_seek() { + let mut buf = Cursor::new(b"....0123456789abcdef"); + buf.set_position(4); + let mut stream = buf.take(8); + assert_eq!(stream.seek(SeekFrom::End(0)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(4)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(-8)).unwrap(), 0); + assert_eq!(stream.seek(SeekFrom::End(-9)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + let mut bytes: [u8; 2] = [0; 2]; + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"01"); + assert_eq!(stream.stream_position().unwrap(), 2); + assert_eq!(stream.seek(SeekFrom::Current(2)).unwrap(), 4); + assert_eq!(stream.seek(SeekFrom::Current(-5)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + assert_eq!(stream.seek(SeekFrom::Start(1)).unwrap(), 1); + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"12"); + assert_eq!(stream.seek(SeekFrom::Current(3)).unwrap(), 6); + 
assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"67"); + assert_eq!(stream.stream_position().unwrap(), 8); + // reached end of file. + assert!(stream.read_exact(&mut bytes).is_err()); + + assert_eq!(stream.seek(SeekFrom::Current(-3)).unwrap(), 5); + let mut res = Vec::new(); + assert!(stream.read_to_end(&mut res).is_ok()); + assert_eq!(&res, b"567"); + assert_eq!(stream.stream_position().unwrap(), 8); +} From bf24ec0b73fc027dba8ce5d946c6ef1ec7b753ce Mon Sep 17 00:00:00 2001 From: Baojun Wang Date: Mon, 6 Jun 2022 17:22:41 -0400 Subject: [PATCH 2/2] Do not assume take(n) actually gets n bytes --- library/std/src/io/mod.rs | 13 +++++- library/std/src/io/tests.rs | 82 +++++++++++++++++++++++++------------ 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/library/std/src/io/mod.rs b/library/std/src/io/mod.rs index 753b182df878c..8bb6c3508bccb 100644 --- a/library/std/src/io/mod.rs +++ b/library/std/src/io/mod.rs @@ -988,7 +988,7 @@ pub trait Read { where Self: Sized, { - Take { inner: self, limit, cursor: 0 } + Take { inner: self, limit, cursor: 0, seek_once: false } } } @@ -2409,6 +2409,7 @@ pub struct Take { inner: T, limit: u64, cursor: u64, + seek_once: bool, } impl Take { @@ -2640,8 +2641,18 @@ impl BufRead for Take { } } +#[stable(feature = "rust1", since = "1.0.0")] impl Seek for Take { fn seek(&mut self, pos: SeekFrom) -> Result { + if !self.seek_once { + let old_pos = self.inner.stream_position()?; + let end = self.inner.seek(SeekFrom::End(0))?; + if end != old_pos { + self.inner.seek(SeekFrom::Start(old_pos))?; + } + self.seek_once = true; + self.limit = cmp::min(self.limit, end - old_pos); + } let stream_end = self.cursor + self.limit; let position = match pos { SeekFrom::Start(k) => Some(cmp::min(k, stream_end)), diff --git a/library/std/src/io/tests.rs b/library/std/src/io/tests.rs index 39fff22704134..51d1ca5089865 100644 --- a/library/std/src/io/tests.rs +++ b/library/std/src/io/tests.rs @@ -607,30 +607,60 @@ fn 
bench_take_read_buf(b: &mut test::Bencher) { fn test_io_take_seek() { let mut buf = Cursor::new(b"....0123456789abcdef"); buf.set_position(4); - let mut stream = buf.take(8); - assert_eq!(stream.seek(SeekFrom::End(0)).unwrap(), 8); - assert_eq!(stream.seek(SeekFrom::End(4)).unwrap(), 8); - assert_eq!(stream.seek(SeekFrom::End(-8)).unwrap(), 0); - assert_eq!(stream.seek(SeekFrom::End(-9)).unwrap_err().kind(), io::ErrorKind::InvalidInput); - let mut bytes: [u8; 2] = [0; 2]; - assert!(stream.read_exact(&mut bytes).is_ok()); - assert_eq!(bytes, *b"01"); - assert_eq!(stream.stream_position().unwrap(), 2); - assert_eq!(stream.seek(SeekFrom::Current(2)).unwrap(), 4); - assert_eq!(stream.seek(SeekFrom::Current(-5)).unwrap_err().kind(), io::ErrorKind::InvalidInput); - assert_eq!(stream.seek(SeekFrom::Start(1)).unwrap(), 1); - assert!(stream.read_exact(&mut bytes).is_ok()); - assert_eq!(bytes, *b"12"); - assert_eq!(stream.seek(SeekFrom::Current(3)).unwrap(), 6); - assert!(stream.read_exact(&mut bytes).is_ok()); - assert_eq!(bytes, *b"67"); - assert_eq!(stream.stream_position().unwrap(), 8); - // reached end of file. 
- assert!(stream.read_exact(&mut bytes).is_err()); - - assert_eq!(stream.seek(SeekFrom::Current(-3)).unwrap(), 5); - let mut res = Vec::new(); - assert!(stream.read_to_end(&mut res).is_ok()); - assert_eq!(&res, b"567"); - assert_eq!(stream.stream_position().unwrap(), 8); + { + let mut stream = buf.by_ref().take(8); + assert_eq!(stream.seek(SeekFrom::End(0)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(4)).unwrap(), 8); + assert_eq!(stream.seek(SeekFrom::End(-8)).unwrap(), 0); + assert_eq!(stream.seek(SeekFrom::End(-9)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + let mut bytes: [u8; 2] = [0; 2]; + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"01"); + assert_eq!(stream.stream_position().unwrap(), 2); + assert_eq!(stream.seek(SeekFrom::Current(2)).unwrap(), 4); + assert_eq!(stream.seek(SeekFrom::Current(-5)).unwrap_err().kind(), io::ErrorKind::InvalidInput); + assert_eq!(stream.seek(SeekFrom::Start(1)).unwrap(), 1); + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"12"); + assert_eq!(stream.seek(SeekFrom::Current(3)).unwrap(), 6); + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"67"); + assert_eq!(stream.stream_position().unwrap(), 8); + // reached end of file. + assert!(stream.read_exact(&mut bytes).is_err()); + assert_eq!(stream.seek(SeekFrom::Current(-3)).unwrap(), 5); + let mut res = Vec::new(); + assert!(stream.read_to_end(&mut res).is_ok()); + assert_eq!(&res, b"567"); + assert_eq!(stream.stream_position().unwrap(), 8); + } + assert_eq!(buf.stream_position().unwrap(), 12); +} + +#[test] +fn test_io_take_seek_insufficient_bytes() { + let mut buf = Cursor::new(b"....0123456789abcdef"); + buf.set_position(16); + { + // only four bytes are available. 
+ let mut stream = buf.by_ref().take(8); + assert_eq!(stream.seek(SeekFrom::Start(10)).unwrap(), 4); + assert_eq!(stream.seek(SeekFrom::End(-4)).unwrap(), 0); + assert_eq!(stream.seek(SeekFrom::End(1)).unwrap(), 4); + assert!(stream.seek(SeekFrom::Current(-5)).is_err()); + assert_eq!(stream.seek(SeekFrom::Current(-4)).unwrap(), 0); + let mut bytes: [u8; 2] = [0; 2]; + assert!(stream.read_exact(&mut bytes).is_ok()); + assert_eq!(bytes, *b"cd"); + assert_eq!(stream.stream_position().unwrap(), 2); + let mut res = Vec::new(); + assert!(stream.read_to_end(&mut res).is_ok()); + assert_eq!(&res, b"ef"); + assert_eq!(stream.stream_position().unwrap(), 4); + assert!(stream.seek(SeekFrom::Current(-1)).is_ok()); + assert_eq!(stream.read_exact(&mut bytes).unwrap_err().kind(), io::ErrorKind::UnexpectedEof); + assert_eq!(stream.read_to_end(&mut res).unwrap(), 0); + assert_eq!(stream.stream_position().unwrap(), 4); + } + assert_eq!(buf.stream_position().unwrap(), 20); }