Skip to content

Commit ebb2649

Browse files
committed
blas: Fix to skip array with too short stride
If we have a matrix of dimension say 5 x 5, BLAS requires the leading stride to be >= 5. Smaller cases are possible for read-only array views in ndarray(broadcasting and custom strides). In this case we mark the array as not BLAS compatible
1 parent 771bc16 commit ebb2649

File tree

1 file changed

+34
-5
lines changed

1 file changed

+34
-5
lines changed

src/linalg/impl_linalg.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ where
861861

862862
#[cfg(feature = "blas")]
863863
#[derive(Copy, Clone)]
864+
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
864865
enum MemoryOrder
865866
{
866867
C,
@@ -885,24 +886,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
885886
let (m, n) = dim.into_pattern();
886887
let s0 = stride[0] as isize;
887888
let s1 = stride[1] as isize;
888-
let (inner_stride, outer_dim) = match order {
889-
MemoryOrder::C => (s1, n),
890-
MemoryOrder::F => (s0, m),
889+
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
890+
MemoryOrder::C => (s1, s0, m, n),
891+
MemoryOrder::F => (s0, s1, n, m),
891892
};
893+
892894
if !(inner_stride == 1 || outer_dim == 1) {
893895
return false;
894896
}
897+
895898
if s0 < 1 || s1 < 1 {
896899
return false;
897900
}
901+
898902
if (s0 > blas_index::MAX as isize || s0 < blas_index::MIN as isize)
899903
|| (s1 > blas_index::MAX as isize || s1 < blas_index::MIN as isize)
900904
{
901905
return false;
902906
}
907+
908+
// leading stride must >= the dimension (no broadcasting/aliasing)
909+
if inner_dim > 1 && (outer_stride as usize) < outer_dim {
910+
return false;
911+
}
912+
903913
if m > blas_index::MAX as usize || n > blas_index::MAX as usize {
904914
return false;
905915
}
916+
906917
true
907918
}
908919

@@ -1066,8 +1077,26 @@ mod blas_tests
10661077
}
10671078

10681079
#[test]
1069-
fn test()
1080+
fn blas_too_short_stride()
10701081
{
1071-
//WIP test that stride is larger than other dimension
1082+
// leading stride must be longer than the other dimension
1083+
// Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.
1084+
1085+
const N: usize = 5;
1086+
const MAXSTRIDE: usize = N + 2;
1087+
let mut data = [0; MAXSTRIDE * N];
1088+
let mut iter = 0..data.len();
1089+
data.fill_with(|| iter.next().unwrap());
1090+
1091+
for stride in 1..=MAXSTRIDE {
1092+
let m = ArrayView::from_shape((N, N).strides((stride, 1)), &data).unwrap();
1093+
eprintln!("{:?}", m);
1094+
1095+
if stride < N {
1096+
assert_eq!(get_blas_compatible_layout(&m), None);
1097+
} else {
1098+
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
1099+
}
1100+
}
10721101
}
10731102
}

0 commit comments

Comments
 (0)