From 260c25044f67373775b9233edb2e9927271b8b05 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Mon, 3 Apr 2023 12:49:17 -0700 Subject: [PATCH 1/2] Avoid some extra bounds checks in `read_{u8,u16}` --- compiler/rustc_serialize/src/lib.rs | 1 + compiler/rustc_serialize/src/opaque.rs | 81 +++++++++++++++--------- compiler/rustc_serialize/tests/opaque.rs | 8 +-- 3 files changed, 56 insertions(+), 34 deletions(-) diff --git a/compiler/rustc_serialize/src/lib.rs b/compiler/rustc_serialize/src/lib.rs index 1f8d2336c4e58..eb288200cf82d 100644 --- a/compiler/rustc_serialize/src/lib.rs +++ b/compiler/rustc_serialize/src/lib.rs @@ -11,6 +11,7 @@ Core encoding and decoding interfaces. )] #![feature(never_type)] #![feature(associated_type_bounds)] +#![feature(iter_advance_by)] #![feature(min_specialization)] #![feature(core_intrinsics)] #![feature(maybe_uninit_slice)] diff --git a/compiler/rustc_serialize/src/opaque.rs b/compiler/rustc_serialize/src/opaque.rs index 0e0ebc79eb2e3..d00f935868f48 100644 --- a/compiler/rustc_serialize/src/opaque.rs +++ b/compiler/rustc_serialize/src/opaque.rs @@ -535,34 +535,55 @@ impl Encoder for FileEncoder { // ----------------------------------------------------------------------------- pub struct MemDecoder<'a> { + // Previously this type stored `position: usize`, but because it's staying + // safe code, that meant that reading `n` bytes meant a bounds check both + // for `position + n` *and* `position`, since there's nothing saying that + // the additions didn't wrap. Storing an iterator like this instead means + // there's no offsetting needed to get to the data, and the iterator instead + // of a slice means only increasing the start pointer on reads, rather than + // also needing to decrease the count in a slice. + // This field is first because it's touched more than `data`. + reader: std::slice::Iter<'a, u8>, pub data: &'a [u8], - position: usize, } impl<'a> MemDecoder<'a> { #[inline] pub fn new(data: &'a [u8], position: usize) -> MemDecoder<'a> { - MemDecoder { data, position } + let reader = data[position..].iter(); + MemDecoder { data, reader } } #[inline] pub fn position(&self) -> usize { - self.position + self.data.len() - self.reader.len() } #[inline] pub fn set_position(&mut self, pos: usize) { - self.position = pos + self.reader = self.data[pos..].iter(); } #[inline] pub fn advance(&mut self, bytes: usize) { - self.position += bytes; + self.reader.advance_by(bytes).unwrap(); + } + + #[cold] + fn panic_insufficient_data(&self) -> ! { + let pos = self.position(); + let len = self.data.len(); + panic!("Insufficient remaining data at position {pos} (length {len})"); } } macro_rules! read_leb128 { - ($dec:expr, $fun:ident) => {{ leb128::$fun($dec.data, &mut $dec.position) }}; + ($dec:expr, $fun:ident) => {{ + let mut position = 0_usize; + let val = leb128::$fun($dec.reader.as_slice(), &mut position); + let _ = $dec.reader.advance_by(position); + val + }}; } impl<'a> Decoder for MemDecoder<'a> { @@ -583,17 +604,14 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_u16(&mut self) -> u16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = u16::from_le_bytes(bytes); - self.position += 2; - value + let bytes = self.read_raw_bytes(2); + u16::from_le_bytes(bytes.try_into().unwrap()) } #[inline] fn read_u8(&mut self) -> u8 { - let value = self.data[self.position]; - self.position += 1; - value + let bytes = self.read_raw_bytes(1); + bytes[0] } #[inline] @@ -618,17 +636,12 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_i16(&mut self) -> i16 { - let bytes = [self.data[self.position], self.data[self.position + 1]]; - let value = i16::from_le_bytes(bytes); - self.position += 2; - value + self.read_u16() as i16 } #[inline] fn read_i8(&mut self) -> i8 { - let value = self.data[self.position]; - self.position += 1; - value as i8 + self.read_u8() as i8 } #[inline] @@ -663,20 +676,28 @@ impl<'a> Decoder for MemDecoder<'a> { #[inline] fn read_str(&mut self) -> &'a str { let len = self.read_usize(); - let sentinel = self.data[self.position + len]; - assert!(sentinel == STR_SENTINEL); - let s = unsafe { - std::str::from_utf8_unchecked(&self.data[self.position..self.position + len]) - }; - self.position += len + 1; - s + + // This cannot reuse `read_raw_bytes` as that runs into lifetime issues + // where the slice gets tied to `'b` instead of just to `'a`. + if self.reader.len() <= len { + self.panic_insufficient_data(); + } + let slice = self.reader.as_slice(); + assert!(slice[len] == STR_SENTINEL); + self.reader.advance_by(len + 1).unwrap(); + unsafe { + std::str::from_utf8_unchecked(&slice[..len]) + } } #[inline] fn read_raw_bytes(&mut self, bytes: usize) -> &'a [u8] { - let start = self.position; - self.position += bytes; - &self.data[start..self.position] + if self.reader.len() < bytes { + self.panic_insufficient_data(); + } + let slice = self.reader.as_slice(); + self.reader.advance_by(bytes).unwrap(); + &slice[..bytes] } } diff --git a/compiler/rustc_serialize/tests/opaque.rs b/compiler/rustc_serialize/tests/opaque.rs index 3a695d0714ee1..032853ac640cf 100644 --- a/compiler/rustc_serialize/tests/opaque.rs +++ b/compiler/rustc_serialize/tests/opaque.rs @@ -55,7 +55,7 @@ fn test_unit() { #[test] fn test_u8() { let mut vec = vec![]; - for i in u8::MIN..u8::MAX { + for i in u8::MIN..=u8::MAX { vec.push(i); } check_round_trip(vec); @@ -63,7 +63,7 @@ fn test_u8() { #[test] fn test_u16() { - for i in u16::MIN..u16::MAX { + for i in u16::MIN..=u16::MAX { check_round_trip(vec![1, 2, 3, i, i, i]); } } @@ -86,7 +86,7 @@ fn test_usize() { #[test] fn test_i8() { let mut vec = vec![]; - for i in i8::MIN..i8::MAX { + for i in i8::MIN..=i8::MAX { vec.push(i); } check_round_trip(vec); @@ -94,7 +94,7 @@ fn test_i8() { #[test] fn test_i16() { - for i in i16::MIN..i16::MAX { + for i in i16::MIN..=i16::MAX { check_round_trip(vec![-1, 2, -3, i, i, i, 2]); } } From 23a77dcd5b614ba87eddd9616daad3172f5b7872 Mon Sep 17 00:00:00 2001 From: Scott McMurray Date: Mon, 3 Apr 2023 13:36:00 -0700 Subject: [PATCH 2/2] Also move read-LEB to using iterators --- compiler/rustc_serialize/src/leb128.rs | 23 ++++++++++++----------- compiler/rustc_serialize/src/opaque.rs | 15 +++++++-------- compiler/rustc_serialize/tests/leb128.rs | 12 ++++++------ 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/compiler/rustc_serialize/src/leb128.rs b/compiler/rustc_serialize/src/leb128.rs index 7dad9aa01fafd..407cce969c405 100644 --- a/compiler/rustc_serialize/src/leb128.rs +++ b/compiler/rustc_serialize/src/leb128.rs @@ -49,25 +49,25 @@ impl_write_unsigned_leb128!(write_usize_leb128, usize); macro_rules! impl_read_unsigned_leb128 { ($fn_name:ident, $int_ty:ty) => { + // This returns `Option` to avoid needing to emit the panic paths here. + // Letting the caller do it instead helps keep our code size small. #[inline] - pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty { + pub fn $fn_name(slice: &mut std::slice::Iter<'_, u8>) -> Option<$int_ty> { // The first iteration of this loop is unpeeled. This is a // performance win because this code is hot and integer values less // than 128 are very common, typically occurring 50-80% or more of // the time, even for u64 and u128. - let byte = slice[*position]; - *position += 1; + let byte = *(slice.next()?); if (byte & 0x80) == 0 { - return byte as $int_ty; + return Some(byte as $int_ty); } let mut result = (byte & 0x7F) as $int_ty; let mut shift = 7; loop { - let byte = slice[*position]; - *position += 1; + let byte = *(slice.next()?); if (byte & 0x80) == 0 { result |= (byte as $int_ty) << shift; - return result; + return Some(result); } else { result |= ((byte & 0x7F) as $int_ty) << shift; } @@ -126,15 +126,16 @@ impl_write_signed_leb128!(write_isize_leb128, isize); macro_rules! impl_read_signed_leb128 { ($fn_name:ident, $int_ty:ty) => { + // This returns `Option` to avoid needing to emit the panic paths here. + // Letting the caller do it instead helps keep our code size small. #[inline] - pub fn $fn_name(slice: &[u8], position: &mut usize) -> $int_ty { + pub fn $fn_name(slice: &mut std::slice::Iter<'_, u8>) -> Option<$int_ty> { let mut result = 0; let mut shift = 0; let mut byte; loop { - byte = slice[*position]; - *position += 1; + byte = *(slice.next()?); result |= <$int_ty>::from(byte & 0x7F) << shift; shift += 7; @@ -148,7 +149,7 @@ macro_rules! impl_read_signed_leb128 { result |= (!0 << shift); } - result + Some(result) } }; } diff --git a/compiler/rustc_serialize/src/opaque.rs b/compiler/rustc_serialize/src/opaque.rs index d00f935868f48..fa174066225d4 100644 --- a/compiler/rustc_serialize/src/opaque.rs +++ b/compiler/rustc_serialize/src/opaque.rs @@ -538,7 +538,7 @@ pub struct MemDecoder<'a> { // Previously this type stored `position: usize`, but because it's staying // safe code, that meant that reading `n` bytes meant a bounds check both // for `position + n` *and* `position`, since there's nothing saying that - // the additions didn't wrap. Storing an iterator like this instead means + // the additions didn't wrap. Storing an iterator like this instead means // there's no offsetting needed to get to the data, and the iterator instead // of a slice means only increasing the start pointer on reads, rather than // also needing to decrease the count in a slice. @@ -579,10 +579,11 @@ impl<'a> MemDecoder<'a> { macro_rules! read_leb128 { ($dec:expr, $fun:ident) => {{ - let mut position = 0_usize; - let val = leb128::$fun($dec.reader.as_slice(), &mut position); - let _ = $dec.reader.advance_by(position); - val + if let Some(val) = leb128::$fun(&mut $dec.reader) { + val + } else { + $dec.panic_insufficient_data() + } }}; } @@ -685,9 +686,7 @@ impl<'a> Decoder for MemDecoder<'a> { let slice = self.reader.as_slice(); assert!(slice[len] == STR_SENTINEL); self.reader.advance_by(len + 1).unwrap(); - unsafe { - std::str::from_utf8_unchecked(&slice[..len]) - } + unsafe { std::str::from_utf8_unchecked(&slice[..len]) } } #[inline] diff --git a/compiler/rustc_serialize/tests/leb128.rs b/compiler/rustc_serialize/tests/leb128.rs index 314c07db981da..3f847951e90d5 100644 --- a/compiler/rustc_serialize/tests/leb128.rs +++ b/compiler/rustc_serialize/tests/leb128.rs @@ -28,12 +28,12 @@ macro_rules! impl_test_unsigned_leb128 { stream.extend($write_fn_name(&mut buf, x)); } - let mut position = 0; + let mut reader = stream.iter(); for &expected in &values { - let actual = $read_fn_name(&stream, &mut position); + let actual = $read_fn_name(&mut reader).unwrap(); assert_eq!(expected, actual); } - assert_eq!(stream.len(), position); + assert_eq!(reader.len(), 0); } }; } @@ -74,12 +74,12 @@ macro_rules! impl_test_signed_leb128 { stream.extend($write_fn_name(&mut buf, x)); } - let mut position = 0; + let mut reader = stream.iter(); for &expected in &values { - let actual = $read_fn_name(&stream, &mut position); + let actual = $read_fn_name(&mut reader).unwrap(); assert_eq!(expected, actual); } - assert_eq!(stream.len(), position); + assert_eq!(reader.len(), 0); } }; }