@@ -861,6 +861,7 @@ where
861
861
862
862
#[ cfg( feature = "blas" ) ]
863
863
#[ derive( Copy , Clone ) ]
864
+ #[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
864
865
enum MemoryOrder
865
866
{
866
867
C ,
@@ -885,24 +886,34 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
885
886
let ( m, n) = dim. into_pattern ( ) ;
886
887
let s0 = stride[ 0 ] as isize ;
887
888
let s1 = stride[ 1 ] as isize ;
888
- let ( inner_stride, outer_dim) = match order {
889
- MemoryOrder :: C => ( s1, n) ,
890
- MemoryOrder :: F => ( s0, m) ,
889
+ let ( inner_stride, outer_stride , inner_dim , outer_dim) = match order {
890
+ MemoryOrder :: C => ( s1, s0 , m , n) ,
891
+ MemoryOrder :: F => ( s0, s1 , n , m) ,
891
892
} ;
893
+
892
894
if !( inner_stride == 1 || outer_dim == 1 ) {
893
895
return false ;
894
896
}
897
+
895
898
if s0 < 1 || s1 < 1 {
896
899
return false ;
897
900
}
901
+
898
902
if ( s0 > blas_index:: MAX as isize || s0 < blas_index:: MIN as isize )
899
903
|| ( s1 > blas_index:: MAX as isize || s1 < blas_index:: MIN as isize )
900
904
{
901
905
return false ;
902
906
}
907
+
908
+ // leading stride must >= the dimension (no broadcasting/aliasing)
909
+ if inner_dim > 1 && ( outer_stride as usize ) < outer_dim {
910
+ return false ;
911
+ }
912
+
903
913
if m > blas_index:: MAX as usize || n > blas_index:: MAX as usize {
904
914
return false ;
905
915
}
916
+
906
917
true
907
918
}
908
919
@@ -1066,8 +1077,26 @@ mod blas_tests
1066
1077
}
1067
1078
1068
1079
#[ test]
1069
- fn test ( )
1080
+ fn blas_too_short_stride ( )
1070
1081
{
1071
- //WIP test that stride is larger than other dimension
1082
+ // leading stride must be longer than the other dimension
1083
+ // Example, in a 5 x 5 matrix, the leading stride must be >= 5 for BLAS.
1084
+
1085
+ const N : usize = 5 ;
1086
+ const MAXSTRIDE : usize = N + 2 ;
1087
+ let mut data = [ 0 ; MAXSTRIDE * N ] ;
1088
+ let mut iter = 0 ..data. len ( ) ;
1089
+ data. fill_with ( || iter. next ( ) . unwrap ( ) ) ;
1090
+
1091
+ for stride in 1 ..=MAXSTRIDE {
1092
+ let m = ArrayView :: from_shape ( ( N , N ) . strides ( ( stride, 1 ) ) , & data) . unwrap ( ) ;
1093
+ eprintln ! ( "{:?}" , m) ;
1094
+
1095
+ if stride < N {
1096
+ assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
1097
+ } else {
1098
+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1099
+ }
1100
+ }
1072
1101
}
1073
1102
}
0 commit comments