Skip to content

Commit a53313c

Browse files
authored
Merge pull request #280 from bluss/axpy-fixup
BUG: Fix blas + axpy by only using it for certain cases
2 parents e1decd8 + 2b4d215 commit a53313c

File tree

5 files changed

+55
-138
lines changed

5 files changed

+55
-138
lines changed

benches/bench1.rs

Lines changed: 2 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -457,26 +457,11 @@ fn iadd_2d_strided(bench: &mut test::Bencher)
457457
});
458458
}
459459

460-
const SCALE_ADD_SZ: usize = 64;
461-
462460
#[bench]
463461
fn scaled_add_2d_f32_regular(bench: &mut test::Bencher)
464462
{
465-
let mut av = Array::<f32, _>::zeros((SCALE_ADD_SZ, SCALE_ADD_SZ));
466-
let bv = Array::<f32, _>::zeros((SCALE_ADD_SZ, SCALE_ADD_SZ));
467-
let scalar = 3.1415926535;
468-
bench.iter(|| {
469-
av.scaled_add(scalar, &bv);
470-
});
471-
}
472-
473-
#[bench]
474-
fn scaled_add_2d_f32_stride(bench: &mut test::Bencher)
475-
{
476-
let mut av = Array::<f32, _>::zeros((SCALE_ADD_SZ, 2 * SCALE_ADD_SZ));
477-
let bv = Array::<f32, _>::zeros((SCALE_ADD_SZ, 2 * SCALE_ADD_SZ));
478-
let mut av = av.slice_mut(s![.., ..;2]);
479-
let bv = bv.slice(s![.., ..;2]);
463+
let mut av = Array::<f32, _>::zeros((64, 64));
464+
let bv = Array::<f32, _>::zeros((64, 64));
480465
let scalar = 3.1415926535;
481466
bench.iter(|| {
482467
av.scaled_add(scalar, &bv);

src/dimension/dimension_trait.rs

Lines changed: 6 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -269,34 +269,20 @@ pub unsafe trait Dimension : Clone + Eq + Debug + Send + Sync + Default +
269269
return true;
270270
}
271271
if dim.ndim() == 1 { return false; }
272-
273-
match Self::equispaced_stride(dim, strides) {
274-
Some(1) => true,
275-
_ => false,
276-
}
277-
}
278-
279-
/// Return the equispaced stride between all the array elements.
280-
///
281-
/// Returns `Some(n)` if the strides in all dimensions are equispaced. Returns `None` if not.
282-
#[doc(hidden)]
283-
fn equispaced_stride(dim: &Self, strides: &Self) -> Option<isize> {
284272
let order = strides._fastest_varying_stride_order();
285-
let base_stride = strides[order[0]];
273+
let strides = strides.slice();
286274

287275
// FIXME: Negative strides
288276
let dim_slice = dim.slice();
289-
let mut next_stride = base_stride;
290-
let strides = strides.slice();
277+
let mut cstride = 1;
291278
for &i in order.slice() {
292279
// a dimension of length 1 can have unequal strides
293-
if dim_slice[i] != 1 && strides[i] != next_stride {
294-
return None;
280+
if dim_slice[i] != 1 && strides[i] != cstride {
281+
return false;
295282
}
296-
next_stride *= dim_slice[i];
283+
cstride *= dim_slice[i];
297284
}
298-
299-
Some(base_stride as isize)
285+
true
300286
}
301287

302288
/// Return the axis ordering corresponding to the fastest variation

src/linalg/impl_linalg.rs

Lines changed: 0 additions & 89 deletions
Original file line number | Diff line number | Diff line change
@@ -292,72 +292,9 @@ impl<A, S, D> ArrayBase<S, D>
292292
S2: Data<Elem=A>,
293293
A: LinalgScalar,
294294
E: Dimension,
295-
{
296-
self.scaled_add_impl(alpha, rhs);
297-
}
298-
299-
fn scaled_add_generic<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
300-
where S: DataMut,
301-
S2: Data<Elem=A>,
302-
A: LinalgScalar,
303-
E: Dimension,
304295
{
305296
self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
306297
}
307-
308-
#[cfg(not(feature = "blas"))]
309-
fn scaled_add_impl<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
310-
where S: DataMut,
311-
S2: Data<Elem=A>,
312-
A: LinalgScalar,
313-
E: Dimension,
314-
{
315-
self.scaled_add_generic(alpha, rhs);
316-
}
317-
318-
#[cfg(feature = "blas")]
319-
fn scaled_add_impl<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
320-
where S: DataMut,
321-
S2: Data<Elem=A>,
322-
A: LinalgScalar,
323-
E: Dimension,
324-
{
325-
debug_assert_eq!(self.len(), rhs.len());
326-
assert!(self.len() == rhs.len());
327-
{
328-
macro_rules! axpy {
329-
($ty:ty, $func:ident) => {{
330-
if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) {
331-
let order = Dimension::_fastest_varying_stride_order(&self.strides);
332-
let incx = self.strides()[order[0]];
333-
334-
let order = Dimension::_fastest_varying_stride_order(&rhs.strides);
335-
let incy = self.strides()[order[0]];
336-
337-
unsafe {
338-
let (lhs_ptr, n, incx) = blas_1d_params(self.ptr,
339-
self.len(),
340-
incx);
341-
let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr,
342-
rhs.len(),
343-
incy);
344-
blas_sys::c::$func(
345-
n,
346-
cast_as(&alpha),
347-
rhs_ptr as *const $ty,
348-
incy,
349-
lhs_ptr as *mut $ty,
350-
incx);
351-
return;
352-
}
353-
}
354-
}}
355-
}
356-
axpy!{f32, cblas_saxpy};
357-
axpy!{f64, cblas_daxpy};
358-
}
359-
self.scaled_add_generic(alpha, rhs);
360-
}
361298
}
362299

363300
// mat_mul_impl uses ArrayView arguments to send all array kinds into
@@ -594,32 +531,6 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
594531
true
595532
}
596533

597-
#[cfg(feature="blas")]
598-
fn blas_compat<A, S, D>(a: &ArrayBase<S, D>) -> bool
599-
where S: Data,
600-
A: 'static,
601-
S::Elem: 'static,
602-
D: Dimension,
603-
{
604-
if !same_type::<A, S::Elem>() {
605-
return false;
606-
}
607-
608-
match D::equispaced_stride(&a.raw_dim(), &a.strides) {
609-
Some(stride) => {
610-
if a.len() as isize * stride > blas_index::max_value() as isize ||
611-
stride < blas_index::min_value() as isize {
612-
return false;
613-
}
614-
},
615-
None => {
616-
return false;
617-
}
618-
}
619-
620-
true
621-
}
622-
623534
#[cfg(feature="blas")]
624535
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
625536
where S: Data,

tests/dimension.rs

Lines changed: 0 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -42,18 +42,6 @@ fn dyn_dimension()
4242
assert_eq!(z.shape(), &dim[..]);
4343
}
4444

45-
#[test]
46-
fn equidistance_strides() {
47-
let strides = Dim([4,2,1]);
48-
assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), Some(1));
49-
50-
let strides = Dim([8,4,2]);
51-
assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), Some(2));
52-
53-
let strides = Dim([16,4,1]);
54-
assert_eq!(Dimension::equispaced_stride(&Dim([2,2,2]), &strides), None);
55-
}
56-
5745
#[test]
5846
fn fastest_varying_order() {
5947
let strides = Dim([2, 8, 4, 1]);

tests/oper.rs

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ use ndarray::prelude::*;
55
use ndarray::{rcarr1, rcarr2};
66
use ndarray::{LinalgScalar, Data};
77
use ndarray::linalg::general_mat_mul;
8+
use ndarray::Si;
89

910
use std::fmt;
1011
use num_traits::Float;
@@ -494,6 +495,52 @@ fn scaled_add_2() {
494495
}
495496
}
496497

498+
#[test]
499+
fn scaled_add_3() {
500+
let beta = -2.3;
501+
let sizes = vec![(4, 4, 1, 4),
502+
(8, 8, 1, 8),
503+
(17, 15, 17, 15),
504+
(4, 17, 4, 17),
505+
(17, 3, 1, 3),
506+
(19, 18, 19, 18),
507+
(16, 17, 16, 17),
508+
(15, 16, 15, 16),
509+
(67, 63, 1, 63),
510+
];
511+
// test different strides
512+
for &s1 in &[1, 2, -1, -2] {
513+
for &s2 in &[1, 2, -1, -2] {
514+
for &(m, k, n, q) in &sizes {
515+
let mut a = range_mat64(m, k);
516+
let mut answer = a.clone();
517+
let cdim = if n == 1 {
518+
vec![q]
519+
} else {
520+
vec![n, q]
521+
};
522+
let cslice = if n == 1 {
523+
vec![Si(0, None, s2)]
524+
} else {
525+
vec![Si(0, None, s1), Si(0, None, s2)]
526+
};
527+
528+
let c = range_mat64(n, q).into_shape(cdim).unwrap();
529+
530+
{
531+
let mut av = a.slice_mut(s![..;s1, ..;s2]);
532+
let c = c.slice(&cslice);
533+
534+
let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
535+
answerv += &(beta * &c);
536+
av.scaled_add(beta, &c);
537+
}
538+
assert_close(a.view(), answer.view());
539+
}
540+
}
541+
}
542+
}
543+
497544

498545
#[test]
499546
fn gen_mat_mul() {

0 commit comments

Comments
 (0)