Use axpy for scaled_add #278
```rust
@@ -292,9 +292,72 @@ impl<A, S, D> ArrayBase<S, D>
              S2: Data<Elem=A>,
              A: LinalgScalar,
              E: Dimension,
    {
        self.scaled_add_impl(alpha, rhs);
    }

    fn scaled_add_generic<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
        where S: DataMut,
              S2: Data<Elem=A>,
              A: LinalgScalar,
              E: Dimension,
    {
        self.zip_mut_with(rhs, move |y, &x| *y = *y + (alpha * x));
    }

    #[cfg(not(feature = "blas"))]
    fn scaled_add_impl<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
        where S: DataMut,
              S2: Data<Elem=A>,
              A: LinalgScalar,
              E: Dimension,
    {
        self.scaled_add_generic(alpha, rhs);
    }

    #[cfg(feature = "blas")]
    fn scaled_add_impl<S2, E>(&mut self, alpha: A, rhs: &ArrayBase<S2, E>)
        where S: DataMut,
              S2: Data<Elem=A>,
              A: LinalgScalar,
              E: Dimension,
    {
        debug_assert_eq!(self.len(), rhs.len());
        assert!(self.len() == rhs.len());
        {
            macro_rules! axpy {
                ($ty:ty, $func:ident) => {{
                    if blas_compat::<$ty, _, _>(self) && blas_compat::<$ty, _, _>(rhs) {
                        let order = Dimension::_fastest_varying_stride_order(&self.strides);
                        let incx = self.strides()[order[0]];

                        let order = Dimension::_fastest_varying_stride_order(&rhs.strides);
                        let incy = rhs.strides()[order[0]];

                        unsafe {
                            let (lhs_ptr, n, incx) = blas_1d_params(self.ptr,
                                                                    self.len(),
                                                                    incx);
                            let (rhs_ptr, _, incy) = blas_1d_params(rhs.ptr,
                                                                    rhs.len(),
                                                                    incy);
                            blas_sys::c::$func(
                                n,
                                cast_as(&alpha),
                                rhs_ptr as *const $ty,
                                incy,
                                lhs_ptr as *mut $ty,
                                incx);
                            return;
                        }
                    }
                }}
            }
            axpy!{f32, cblas_saxpy};
            axpy!{f64, cblas_daxpy};
        }
        self.scaled_add_generic(alpha, rhs);
    }
}

// mat_mul_impl uses ArrayView arguments to send all array kinds into
```
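Both paths in the hunk above compute the same update: `y = y + alpha * x`, elementwise. A minimal standalone sketch of that axpy semantics over plain slices (`scaled_axpy` is a hypothetical helper for illustration, not part of ndarray or BLAS):

```rust
// y[i] = y[i] + alpha * x[i] for every element pair — the operation
// that the BLAS path delegates to cblas_saxpy/cblas_daxpy and the
// generic path expresses with zip_mut_with.
fn scaled_axpy(alpha: f64, x: &[f64], y: &mut [f64]) {
    assert_eq!(x.len(), y.len());
    for (yi, &xi) in y.iter_mut().zip(x) {
        *yi += alpha * xi;
    }
}

fn main() {
    let x = [1.0, 2.0, 3.0];
    let mut y = [10.0, 20.0, 30.0];
    scaled_axpy(2.0, &x, &mut y);
    assert_eq!(y, [12.0, 24.0, 36.0]);
    println!("{:?}", y); // prints [12.0, 24.0, 36.0]
}
```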
```rust
@@ -531,6 +594,32 @@ fn blas_compat_1d<A, S>(a: &ArrayBase<S, Ix1>) -> bool
    true
}

#[cfg(feature="blas")]
fn blas_compat<A, S, D>(a: &ArrayBase<S, D>) -> bool
    where S: Data,
          A: 'static,
          S::Elem: 'static,
          D: Dimension,
{
    if !same_type::<A, S::Elem>() {
        return false;
    }

    match D::equispaced_stride(&a.raw_dim(), &a.strides) {
        Some(stride) => {
            if a.len() as isize * stride > blas_index::max_value() as isize ||
                stride < blas_index::min_value() as isize {
                return false;
            }
        },
        None => {
            return false;
        }
    }

    true
}

#[cfg(feature="blas")]
fn blas_row_major_2d<A, S>(a: &ArrayBase<S, Ix2>) -> bool
    where S: Data,
```

Review thread on the `equispaced_stride` check:

> **Comment:** Is it possible that this stride is 0 here? I don't think BLAS supports that.

> **Reply:** It depends: can the stride ever be 0 in any dimension, for any kind of vector? An n-dimensional vector with one dimension set to 0 doesn't really make sense, so we should make sure that such a state can never be reached. A special case is a 0-dimensional array; its stride looks like it is set to 0.

> **Reply:** Yes, it's very common — that's how broadcasting works (same as in numpy):
>
> ```python
> >>> import numpy as np
> >>> np.broadcast_to(np.arange(9.), (9, 9)).strides
> (0, 8)
> ```

> **Reply:** Ok, so the way the …
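Besides the stride-0 question discussed above, `blas_compat` also rejects arrays whose extent or stride can't be represented in the BLAS integer type. A standalone sketch of that range check, assuming the BLAS index type is `i32` (`c_int`); `is_blas_representable` is a made-up name for illustration:

```rust
// Hypothetical standalone version of the range check inside blas_compat,
// assuming blas_index = i32. The furthest element offset, len * stride,
// must fit in the BLAS index type, and the stride must not underflow it.
type BlasIndex = i32;

fn is_blas_representable(len: usize, stride: isize) -> bool {
    let extent = len as isize * stride;
    extent <= BlasIndex::MAX as isize && stride >= BlasIndex::MIN as isize
}

fn main() {
    // A small contiguous-ish array is fine.
    assert!(is_blas_representable(9, 8));
    // Negative strides are allowed by BLAS (reverse traversal).
    assert!(is_blas_representable(9, -8));
    // An extent past i32::MAX cannot be passed to a 32-bit BLAS.
    assert!(!is_blas_representable(BlasIndex::MAX as usize + 1, 1));
}
```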
Review thread on the `scaled_add` changes:

> **Comment:** Lol, I missed a lot. This is not what we want to do.

> **Reply:** We don't? :( I basically took this from `dot_impl`:

> **Reply:** `rhs` is broadcast to `self`'s dim if they mismatch. Yeah, `dot` doesn't broadcast.

> **Reply:** Sorry, the mess is on my side. I've added more rigorous tests so that it has something to aspire to.

> **Reply:** Right, the assertion there doesn't make any sense. With broadcasting, one stride would be `0`, so that `equispaced_stride` returns `None` (as per the argument above), which would skip `axpy` and fall through to the default implementation. The assertion would simply prevent broadcasting with the generic `scaled_add` from triggering.
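The dispatch described in the last comment can be modeled in a few lines. This is a deliberately simplified stand-in — the free function below is not ndarray's `Dimension::equispaced_stride`, which also verifies that the strides describe a single equispaced walk over the whole array — but it captures the fallback behavior: a broadcast axis has stride 0, the check reports `None`, and the caller uses the generic loop instead of `axpy`.

```rust
// Simplified model of the BLAS-eligibility decision: any zero stride
// (a broadcast axis, as in np.broadcast_to) disqualifies the array,
// so the caller falls back to the generic scaled_add loop.
fn equispaced_stride(strides: &[isize]) -> Option<isize> {
    if strides.iter().any(|&s| s == 0) {
        return None; // broadcast axis: not a fixed-stride 1-D walk
    }
    strides.last().copied()
}

fn main() {
    // Contiguous 1-D view: eligible for the axpy fast path.
    assert_eq!(equispaced_stride(&[1]), Some(1));
    // Broadcast view (strides like numpy's (0, 8)): skip axpy.
    assert_eq!(equispaced_stride(&[0, 1]), None);
    println!("broadcast views fall through to scaled_add_generic");
}
```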