diff --git a/benches/bench1.rs b/benches/bench1.rs index 7d948a6a7..218dbefca 100644 --- a/benches/bench1.rs +++ b/benches/bench1.rs @@ -457,11 +457,26 @@ fn iadd_2d_strided(bench: &mut test::Bencher) }); } +const SCALE_ADD_SZ: usize = 64; + #[bench] fn scaled_add_2d_f32_regular(bench: &mut test::Bencher) { - let mut av = Array::::zeros((64, 64)); - let bv = Array::::zeros((64, 64)); + let mut av = Array::::zeros((SCALE_ADD_SZ, SCALE_ADD_SZ)); + let bv = Array::::zeros((SCALE_ADD_SZ, SCALE_ADD_SZ)); + let scalar = 3.1415926535; + bench.iter(|| { + av.scaled_add(scalar, &bv); + }); +} + +#[bench] +fn scaled_add_2d_f32_stride(bench: &mut test::Bencher) +{ + let mut av = Array::::zeros((SCALE_ADD_SZ, 2 * SCALE_ADD_SZ)); + let bv = Array::::zeros((SCALE_ADD_SZ, 2 * SCALE_ADD_SZ)); + let mut av = av.slice_mut(s![.., ..;2]); + let bv = bv.slice(s![.., ..;2]); let scalar = 3.1415926535; bench.iter(|| { av.scaled_add(scalar, &bv); diff --git a/src/dimension/dimension_trait.rs b/src/dimension/dimension_trait.rs index 88b806658..dabeb1cc4 100644 --- a/src/dimension/dimension_trait.rs +++ b/src/dimension/dimension_trait.rs @@ -269,20 +269,34 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default + return true; } if dim.ndim() == 1 { return false; } + + match Self::equispaced_stride(dim, strides) { + Some(1) => true, + _ => false, + } + } + + /// Return the equispaced stride between all the array elements. + /// + /// Returns `Some(n)` if the strides in all dimensions are equispaced. Returns `None` if not. + #[doc(hidden)] + fn equispaced_stride(dim: &Self, strides: &Self) -> Option { let order = strides._fastest_varying_stride_order(); - let strides = strides.slice(); + let base_stride = strides[order[0]]; // FIXME: Negative strides let dim_slice = dim.slice(); - let mut cstride = 1; + let mut next_stride = base_stride; + let strides = strides.slice(); for &i in order.slice() { // a dimension of length 1 can have unequal strides - if dim_slice[i] != 1 && strides[i] != cstride { - return false; + if dim_slice[i] != 1 && strides[i] != next_stride { + return None; } - cstride *= dim_slice[i]; + next_stride *= dim_slice[i]; } - true + + Some(base_stride as isize) } /// Return the axis ordering corresponding to the fastest variation diff --git a/src/linalg/impl_linalg.rs b/src/linalg/impl_linalg.rs index 1662f5c37..be1c32046 100644 --- a/src/linalg/impl_linalg.rs +++ b/src/linalg/impl_linalg.rs @@ -292,9 +292,72 @@ impl ArrayBase S2: Data, A: LinalgScalar, E: Dimension, + { + self.scaled_add_impl(alpha, rhs); + } + + fn scaled_add_generic(&mut self, alpha: A, rhs: &ArrayBase) + where S: DataMut, + S2: Data, + A: LinalgScalar, + E: Dimension, { self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x)); } + + #[cfg(not(feature = "blas"))] + fn scaled_add_impl(&mut self, alpha: A, rhs: &ArrayBase) + where S: DataMut, + S2: Data, + A: LinalgScalar, + E: Dimension, + { + self.scaled_add_generic(alpha, rhs); + } + + #[cfg(feature = "blas")] + fn scaled_add_impl(&mut self, alpha: A, rhs: &ArrayBase) + where S: DataMut, + S2: Data, + A: LinalgScalar, + E: Dimension, + { + debug_assert_eq!(self.len(), rhs.len()); + assert!(self.len() == rhs.len()); + { + macro_rules! axpy { + ($ty:ty, $func:ident) => {{ + if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) { + let order = Dimension::_fastest_varying_stride_order(&self.strides); + let incx = self.strides()[order[0]]; + + let order = Dimension::_fastest_varying_stride_order(&rhs.strides); + let incy = self.strides()[order[0]]; + + unsafe { + let (lhs_ptr, n, incx) = blas_1d_params(self.ptr, + self.len(), + incx); + let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr, + rhs.len(), + incy); + blas_sys::c::$func( + n, + cast_as(&alpha), + rhs_ptr as *const $ty, + incy, + lhs_ptr as *mut $ty, + incx); + return; + } + } + }} + } + axpy!{f32, cblas_saxpy}; + axpy!{f64, cblas_daxpy}; + } + self.scaled_add_generic(alpha, rhs); + } } // mat_mul_impl uses ArrayView arguments to send all array kinds into @@ -531,6 +594,32 @@ fn blas_compat_1d(a: &ArrayBase) -> bool true } +#[cfg(feature="blas")] +fn blas_compat(a: &ArrayBase) -> bool + where S: Data, + A: 'static, + S::Elem: 'static, + D: Dimension, +{ + if !same_type::() { + return false; + } + + match D::equispaced_stride(&a.raw_dim(), &a.strides) { + Some(stride) => { + if a.len() as isize * stride > blas_index::max_value() as isize || + stride < blas_index::min_value() as isize { + return false; + } + }, + None => { + return false; + } + } + + true +} + #[cfg(feature="blas")] fn blas_row_major_2d(a: &ArrayBase) -> bool where S: Data, diff --git a/tests/dimension.rs b/tests/dimension.rs index 37d366700..b68fa32de 100644 --- a/tests/dimension.rs +++ b/tests/dimension.rs @@ -42,6 +42,18 @@ fn dyn_dimension() assert_eq!(z.shape(), &dim[..]); } +#[test] +fn equidistance_strides() { + let strides = Dim([4,2,1]); + assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), Some(1)); + + let strides = Dim([8,4,2]); + assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), Some(2)); + + let strides = Dim([16,4,1]); + assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), None); +} + #[test] fn fastest_varying_order() { let strides = Dim([2, 8, 4, 1]);