diff --git a/crates/core_arch/src/aarch64/neon/mod.rs b/crates/core_arch/src/aarch64/neon/mod.rs
index 82d76bd203..d22a9b1b56 100644
--- a/crates/core_arch/src/aarch64/neon/mod.rs
+++ b/crates/core_arch/src/aarch64/neon/mod.rs
@@ -2772,7 +2772,8 @@ pub unsafe fn vshld_n_u64<const N: i32>(a: u64) -> u64 {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a >> N
+    let n: i32 = if N == 64 { 63 } else { N };
+    a >> n
 }
 
 /// Unsigned shift right
@@ -2782,7 +2783,12 @@ pub unsafe fn vshrd_n_s64<const N: i32>(a: i64) -> i64 {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a >> N
+    let n: i32 = if N == 64 {
+        return 0;
+    } else {
+        N
+    };
+    a >> n
 }
 
 /// Signed shift right and accumulate
@@ -2792,7 +2798,7 @@ pub unsafe fn vshrd_n_u64<const N: i32>(a: u64) -> u64 {
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a + (b >> N)
+    a + vshrd_n_s64::<N>(b)
 }
 
 /// Unsigned shift right and accumulate
@@ -2802,7 +2808,7 @@ pub unsafe fn vsrad_n_s64<const N: i32>(a: i64, b: i64) -> i64 {
 #[rustc_legacy_const_generics(2)]
 pub unsafe fn vsrad_n_u64<const N: i32>(a: u64, b: u64) -> u64 {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    a + (b >> N)
+    a + vshrd_n_u64::<N>(b)
 }
 
 /// Shift Left and Insert (immediate)
diff --git a/crates/core_arch/src/arm_shared/neon/generated.rs b/crates/core_arch/src/arm_shared/neon/generated.rs
index b5457bcfb1..2e98d473bf 100644
--- a/crates/core_arch/src/arm_shared/neon/generated.rs
+++ b/crates/core_arch/src/arm_shared/neon/generated.rs
@@ -21987,7 +21987,8 @@ pub unsafe fn vshll_n_u32<const N: i32>(a: uint32x2_t) -> uint64x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdup_n_s8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { 7 } else { N };
+    simd_shr(a, vdup_n_s8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -21999,7 +22000,8 @@ pub unsafe fn vshr_n_s8<const N: i32>(a: int8x8_t) -> int8x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdupq_n_s8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { 7 } else { N };
+    simd_shr(a, vdupq_n_s8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22011,7 +22013,8 @@ pub unsafe fn vshrq_n_s8<const N: i32>(a: int8x16_t) -> int8x16_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdup_n_s16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { 15 } else { N };
+    simd_shr(a, vdup_n_s16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22023,7 +22026,8 @@ pub unsafe fn vshr_n_s16<const N: i32>(a: int16x4_t) -> int16x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdupq_n_s16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { 15 } else { N };
+    simd_shr(a, vdupq_n_s16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22035,7 +22039,8 @@ pub unsafe fn vshrq_n_s16<const N: i32>(a: int16x8_t) -> int16x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdup_n_s32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { 31 } else { N };
+    simd_shr(a, vdup_n_s32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22047,7 +22052,8 @@ pub unsafe fn vshr_n_s32<const N: i32>(a: int32x2_t) -> int32x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdupq_n_s32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { 31 } else { N };
+    simd_shr(a, vdupq_n_s32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22059,7 +22065,8 @@ pub unsafe fn vshrq_n_s32<const N: i32>(a: int32x4_t) -> int32x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdup_n_s64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { 63 } else { N };
+    simd_shr(a, vdup_n_s64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22071,7 +22078,8 @@ pub unsafe fn vshr_n_s64<const N: i32>(a: int64x1_t) -> int64x1_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdupq_n_s64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { 63 } else { N };
+    simd_shr(a, vdupq_n_s64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22083,7 +22091,8 @@ pub unsafe fn vshrq_n_s64<const N: i32>(a: int64x2_t) -> int64x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdup_n_u8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { return vdup_n_u8(0); } else { N };
+    simd_shr(a, vdup_n_u8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22095,7 +22104,8 @@ pub unsafe fn vshr_n_u8<const N: i32>(a: uint8x8_t) -> uint8x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
     static_assert!(N : i32 where N >= 1 && N <= 8);
-    simd_shr(a, vdupq_n_u8(N.try_into().unwrap()))
+    let n: i32 = if N == 8 { return vdupq_n_u8(0); } else { N };
+    simd_shr(a, vdupq_n_u8(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22107,7 +22117,8 @@ pub unsafe fn vshrq_n_u8<const N: i32>(a: uint8x16_t) -> uint8x16_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdup_n_u16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { return vdup_n_u16(0); } else { N };
+    simd_shr(a, vdup_n_u16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22119,7 +22130,8 @@ pub unsafe fn vshr_n_u16<const N: i32>(a: uint16x4_t) -> uint16x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
     static_assert!(N : i32 where N >= 1 && N <= 16);
-    simd_shr(a, vdupq_n_u16(N.try_into().unwrap()))
+    let n: i32 = if N == 16 { return vdupq_n_u16(0); } else { N };
+    simd_shr(a, vdupq_n_u16(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22131,7 +22143,8 @@ pub unsafe fn vshrq_n_u16<const N: i32>(a: uint16x8_t) -> uint16x8_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdup_n_u32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { return vdup_n_u32(0); } else { N };
+    simd_shr(a, vdup_n_u32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22143,7 +22156,8 @@ pub unsafe fn vshr_n_u32<const N: i32>(a: uint32x2_t) -> uint32x2_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
     static_assert!(N : i32 where N >= 1 && N <= 32);
-    simd_shr(a, vdupq_n_u32(N.try_into().unwrap()))
+    let n: i32 = if N == 32 { return vdupq_n_u32(0); } else { N };
+    simd_shr(a, vdupq_n_u32(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22155,7 +22169,8 @@ pub unsafe fn vshrq_n_u32<const N: i32>(a: uint32x4_t) -> uint32x4_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdup_n_u64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { return vdup_n_u64(0); } else { N };
+    simd_shr(a, vdup_n_u64(n.try_into().unwrap()))
 }
 
 /// Shift right
@@ -22167,7 +22182,8 @@ pub unsafe fn vshr_n_u64<const N: i32>(a: uint64x1_t) -> uint64x1_t {
 #[rustc_legacy_const_generics(1)]
 pub unsafe fn vshrq_n_u64<const N: i32>(a: uint64x2_t) -> uint64x2_t {
     static_assert!(N : i32 where N >= 1 && N <= 64);
-    simd_shr(a, vdupq_n_u64(N.try_into().unwrap()))
+    let n: i32 = if N == 64 { return vdupq_n_u64(0); } else { N };
+    simd_shr(a, vdupq_n_u64(n.try_into().unwrap()))
 }
 
 /// Shift right narrow
diff --git a/crates/stdarch-gen/neon.spec b/crates/stdarch-gen/neon.spec
index 8fe91b7090..2e1e1b36ec 100644
--- a/crates/stdarch-gen/neon.spec
+++ b/crates/stdarch-gen/neon.spec
@@ -6785,7 +6785,8 @@ name = vshr
 n-suffix
 constn = N
 multi_fn = static_assert-N-1-bits
-multi_fn = simd_shr, a, {vdup-nself-noext, N.try_into().unwrap()}
+multi_fn = fix_right_shift_imm-N-bits
+multi_fn = simd_shr, a, {vdup-nself-noext, n.try_into().unwrap()}
 a = 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64
 n = 2
 validate 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
diff --git a/crates/stdarch-gen/src/main.rs b/crates/stdarch-gen/src/main.rs
index 99165d8706..961ab9d27b 100644
--- a/crates/stdarch-gen/src/main.rs
+++ b/crates/stdarch-gen/src/main.rs
@@ -2664,6 +2664,28 @@ fn get_call(
             );
         }
     }
+    if fn_name.starts_with("fix_right_shift_imm") {
+        let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
+        let lim = if fn_format[2] == "bits" {
+            type_bits(in_t[1]).to_string()
+        } else {
+            fn_format[2].clone()
+        };
+        let fixed = if in_t[1].starts_with('u') {
+            format!("return vdup{nself}(0);", nself = type_to_n_suffix(in_t[1]))
+        } else {
+            (lim.parse::<i32>().unwrap() - 1).to_string()
+        };
+
+        return format!(
+            r#"let {name}: i32 = if {const_name} == {upper} {{ {fixed} }} else {{ N }};"#,
+            name = fn_format[1].to_lowercase(),
+            const_name = fn_format[1],
+            upper = lim,
+            fixed = fixed,
+        );
+    }
+
 if fn_name.starts_with("matchn") {
 let fn_format: Vec<_> = fn_name.split('-').map(|v| v.to_string()).collect();
 let len = match &*fn_format[1] {