From b39145379087e265245048e859d3e4abd885e1bd Mon Sep 17 00:00:00 2001
From: Jamie Cunliffe
Date: Fri, 22 Oct 2021 17:21:41 +0100
Subject: [PATCH] Do not emit undefined lshr/ashr for Neon shifts

In LLVM, the lshr and ashr instructions are undefined when the shift
amount equals the element size, but the Neon shift intrinsics allow
such shifts.

For signed values, when a shift by the element size is requested,
shift by size - 1 instead: arithmetically shifting a signed value
right by its size is equivalent to shifting it by size - 1.

For unsigned values, when a shift by the element size is requested,
return a vector of all 0's, as that is what the shift would produce.
---
 crates/core_arch/src/aarch64/neon/mod.rs | 14 ++++--
 .../src/arm_shared/neon/generated.rs      | 48 ++++++++++++-------
 crates/stdarch-gen/neon.spec              |  3 +-
 crates/stdarch-gen/src/main.rs            | 22 +++++++++
 4 files changed, 66 insertions(+), 21 deletions(-)

diff --git a/crates/core_arch/src/aarch64/neon/mod.rs b/crates/core_arch/src/aarch64/neon/mod.rs
index 82d76bd203..d22a9b1b56 100644
--- a/crates/core_arch/src/aarch64/neon/mod.rs
+++ b/crates/core_arch/src/aarch64/neon/mod.rs
@@ -2772,7 +2772,8 @@ pub unsafe fn vshld_n_u64<const N: i32>(a: u64) -> u64 {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a >> N
+    let n: i32 = if N == 64 { 63 } else { N };
+    a >> n
 }
 
 /// Unsigned shift right
@@ -2782,7 +2783,12 @@ pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a >> N
+    let n: i32 = if N == 64 {
+        return 0;
+    } else {
+        N
+    };
+    a >> n
 }
 
 /// Signed shift right and accumulate
@@ -2792,7 +2798,7 @@ pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a + (b >> N)
+    a + vshrd_n_s64::<N>(b)
 }
 
 /// Unsigned shift right and accumulate
@@ -2802,7 +2808,7 @@ pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn vsrad_n_u64<const N: i32>(a: u64, b: u64) -> u64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a + (b >> N)
+    a + vshrd_n_u64::<N>(b)
 }
 
 /// Shift Left and Insert (immediate)
diff --git a/crates/core_arch/src/arm_shared/neon/generated.rs b/crates/core_arch/src/arm_shared/neon/generated.rs
index b5457bcfb1..2e98d473bf 100644
--- a/crates/core_arch/src/arm_shared/neon/generated.rs
+++ b/crates/core_arch/src/arm_shared/neon/generated.rs
@@ -21987,7 +21987,8 @@ pub unsafe fn vshll_n_u32<const N: i32>(a: uint32x2_t) -> uint64x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdup_n_s8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { 7 } else { N };
+    simd_shr(a, vdup_n_s8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -21999,7 +22000,8 @@ pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdupq_n_s8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { 7 } else { N };
+    simd_shr(a, vdupq_n_s8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22011,7 +22013,8 @@ pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdup_n_s16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { 15 } else { N };
+    simd_shr(a, vdup_n_s16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22023,7 +22026,8 @@ pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdupq_n_s16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { 15 } else { N };
+    simd_shr(a, vdupq_n_s16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22035,7 +22039,8 @@ pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdup_n_s32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { 31 } else { N };
+    simd_shr(a, vdup_n_s32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22047,7 +22052,8 @@ pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdupq_n_s32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { 31 } else { N };
+    simd_shr(a, vdupq_n_s32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22059,7 +22065,8 @@ pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdup_n_s64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { 63 } else { N };
+    simd_shr(a, vdup_n_s64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22071,7 +22078,8 @@ pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdupq_n_s64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { 63 } else { N };
+    simd_shr(a, vdupq_n_s64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22083,7 +22091,8 @@ pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdup_n_u8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { return vdup_n_u8(0); } else { N };
+    simd_shr(a, vdup_n_u8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22095,7 +22104,8 @@ pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdupq_n_u8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { return vdupq_n_u8(0); } else { N };
+    simd_shr(a, vdupq_n_u8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22107,7 +22117,8 @@ pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdup_n_u16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { return vdup_n_u16(0); } else { N };
+    simd_shr(a, vdup_n_u16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22119,7 +22130,8 @@ pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdupq_n_u16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { return vdupq_n_u16(0); } else { N };
+    simd_shr(a, vdupq_n_u16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22131,7 +22143,8 @@ pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdup_n_u32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { return vdup_n_u32(0); } else { N };
+    simd_shr(a, vdup_n_u32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22143,7 +22156,8 @@ pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdupq_n_u32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { return vdupq_n_u32(0); } else { N };
+    simd_shr(a, vdupq_n_u32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22155,7 +22169,8 @@ pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdup_n_u64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { return vdup_n_u64(0); } else { N };
+    simd_shr(a, vdup_n_u64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22167,7 +22182,8 @@ pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u64<const N: i32>(a: uint64x2_t) -> uint64x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdupq_n_u64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { return vdupq_n_u64(0); } else { N };
+    simd_shr(a, vdupq_n_u64(n.try_into().unwrap()))
 }
 
 /// Shift right narrow
diff --git a/crates/stdarch-gen/neon.spec b/crates/stdarch-gen/neon.spec
index 8fe91b7090..2e1e1b36ec 100644
--- a/crates/stdarch-gen/neon.spec
+++ b/crates/stdarch-gen/neon.spec
@@ -6785,7 +6785,8 @@ name = vshr
 n-suffix
 constn = N
 multi_fn = static_assert-N-1-bits
-multi_fn = simd_shr, a, {vdup-nself-noext, N.try_into().unwrap()}
+multi_fn = fix_right_shift_imm-N-bits
+multi_fn = simd_shr, a, {vdup-nself-noext, n.try_into().unwrap()}
 a = 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64
 n = 2
 validate 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
diff --git a/crates/stdarch-gen/src/main.rs b/crates/stdarch-gen/src/main.rs
index 99165d8706..961ab9d27b 100644
--- a/crates/stdarch-gen/src/main.rs
+++ b/crates/stdarch-gen/src/main.rs
@@ -2664,6 +2664,28 @@ fn get_call(
             );
         }
     }
+    if fn_name.starts_with("fix_right_shift_imm") {
+        let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
+        let lim = if fn_format[2] == "bits" {
+            type_bits(in_t[1]).to_string()
+        } else {
+            fn_format[2].clone()
+        };
+        let fixed = if in_t[1].starts_with('u') {
+            format!("return vdup{nself}(0);", nself = type_to_n_suffix(in_t[1]))
+        } else {
+            (lim.parse::<i32>().unwrap() - 1).to_string()
+        };
+
+        return format!(
+            r#"let {name}: i32 = if {const_name} == {upper} {{ {fixed} }} else {{ N }};"#,
+            name = fn_format[1].to_lowercase(),
+            const_name = fn_format[1],
+            upper = lim,
+            fixed = fixed,
+        );
+    }
+
     if fn_name.starts_with("matchn") {
         let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
         let len = match &*fn_format[1] {
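
Illustrative note (not part of the patch): the rewrite relies on two scalar facts, sketched below for the 64-bit case. The helper names shr_s64_imm and shr_u64_imm are hypothetical and do not exist in core_arch; they only mirror the logic the patch generates.

// Standalone sketch of the scalar behaviour the patch encodes.
fn shr_s64_imm(a: i64, n: i32) -> i64 {
    // Signed case: clamp a requested shift of 64 down to 63. An arithmetic
    // shift by 63 already fills the result with copies of the sign bit,
    // which is what a shift by the element size would conceptually produce.
    let n = if n == 64 { 63 } else { n };
    a >> n
}

fn shr_u64_imm(a: u64, n: i32) -> u64 {
    // Unsigned case: shifting by the full width would clear every bit, so
    // return 0 directly instead of emitting an undefined lshr.
    if n == 64 {
        return 0;
    }
    a >> n
}

fn main() {
    assert_eq!(shr_s64_imm(-1, 64), -1);       // sign bit replicated
    assert_eq!(shr_s64_imm(i64::MIN, 64), -1); // same result as a shift by 63
    assert_eq!(shr_u64_imm(u64::MAX, 64), 0);  // all bits shifted out
    assert_eq!(shr_u64_imm(0x80, 4), 0x8);     // ordinary shifts are unchanged
}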