@@ -24,13 +24,11 @@ use num_complex::{Complex32 as c32, Complex64 as c64};
24
24
25
25
#[ cfg( feature = "blas" ) ]
26
26
use libc:: c_int;
27
- #[ cfg( feature = "blas" ) ]
28
- use std:: mem:: swap;
29
27
30
28
#[ cfg( feature = "blas" ) ]
31
29
use cblas_sys as blas_sys;
32
30
#[ cfg( feature = "blas" ) ]
33
- use cblas_sys:: { CblasNoTrans , CblasRowMajor , CblasTrans , CBLAS_LAYOUT } ;
31
+ use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT } ;
34
32
35
33
/// len of vector before we use blas
36
34
#[ cfg( feature = "blas" ) ]
@@ -377,93 +375,65 @@ use self::mat_mul_general as mat_mul_impl;
377
375
#[ cfg( feature = "blas" ) ]
378
376
fn mat_mul_impl < A > (
379
377
alpha : A ,
380
- lhs : & ArrayView2 < ' _ , A > ,
381
- rhs : & ArrayView2 < ' _ , A > ,
378
+ a : & ArrayView2 < ' _ , A > ,
379
+ b : & ArrayView2 < ' _ , A > ,
382
380
beta : A ,
383
381
c : & mut ArrayViewMut2 < ' _ , A > ,
384
382
) where
385
383
A : LinalgScalar ,
386
384
{
387
385
// size cutoff for using BLAS
388
386
let cut = GEMM_BLAS_CUTOFF ;
389
- let ( ( mut m, k) , ( k2, mut n) ) = ( lhs . dim ( ) , rhs . dim ( ) ) ;
387
+ let ( ( m, k) , ( k2, n) ) = ( a . dim ( ) , b . dim ( ) ) ;
390
388
debug_assert_eq ! ( k, k2) ;
391
389
if !( m > cut || n > cut || k > cut)
392
390
|| !( same_type :: < A , f32 > ( )
393
391
|| same_type :: < A , f64 > ( )
394
392
|| same_type :: < A , c32 > ( )
395
393
|| same_type :: < A , c64 > ( ) )
396
394
{
397
- return mat_mul_general ( alpha, lhs , rhs , beta, c) ;
395
+ return mat_mul_general ( alpha, a , b , beta, c) ;
398
396
}
399
397
400
398
#[ allow( clippy:: never_loop) ] // MSRV Rust 1.64 does not have break from block
401
399
' blas_block: loop {
402
- let mut a = lhs. view ( ) ;
403
- let mut b = rhs. view ( ) ;
404
- let mut c = c. view_mut ( ) ;
405
-
406
- let c_layout = get_blas_compatible_layout ( & c) ;
407
- let c_layout_is_c = matches ! ( c_layout, Some ( MemoryOrder :: C ) ) ;
408
- let c_layout_is_f = matches ! ( c_layout, Some ( MemoryOrder :: F ) ) ;
409
-
410
400
// Compute A B -> C
411
- // we require for BLAS compatibility that:
412
- // A, B are contiguous (stride=1) in their fastest dimension.
413
- // C is c-contiguous in one dimension (stride=1 in Axis(1))
401
+ // We require for BLAS compatibility that:
402
+ // A, B, C are contiguous (stride=1) in their fastest dimension,
403
+ // but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
414
404
//
415
- // If C is f-contiguous, use transpose equivalency
416
- // to translate to the C-contiguous case:
417
- // A^t B^t = C^t => B A = C
418
-
419
- let ( a_layout, b_layout) =
420
- match ( get_blas_compatible_layout ( & a) , get_blas_compatible_layout ( & b) ) {
421
- ( Some ( a_layout) , Some ( b_layout) ) if c_layout_is_c => {
422
- // normal case
423
- ( a_layout, b_layout)
405
+ // The "normal case" is CblasRowMajor for cblas.
406
+ // Select CblasRowMajor, CblasColMajor to fit C's memory order.
407
+ //
408
+ // Apply transpose to A, B as needed if they differ from the normal case.
409
+ // If C is CblasColMajor then transpose both A, B (again!)
410
+
411
+ let ( a_layout, a_axis, b_layout, b_axis, c_layout) =
412
+ match ( get_blas_compatible_layout ( a) ,
413
+ get_blas_compatible_layout ( b) ,
414
+ get_blas_compatible_layout ( c) )
415
+ {
416
+ ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: C ) ) => {
417
+ ( a_layout, a_layout. lead_axis ( ) ,
418
+ b_layout, b_layout. lead_axis ( ) , c_layout)
424
419
} ,
425
- ( Some ( a_layout) , Some ( b_layout) ) if c_layout_is_f => {
426
- // Transpose equivalency
427
- // A^t B^t = C^t => B A = C
428
- //
429
- // A^t becomes the new B
430
- // B^t becomes the new A
431
- let a_t = a. reversed_axes ( ) ;
432
- a = b. reversed_axes ( ) ;
433
- b = a_t;
434
- c = c. reversed_axes ( ) ;
435
- // Assign (n, k, m) -> (m, k, n) effectively
436
- swap ( & mut m, & mut n) ;
437
-
438
- // Continue using the already computed memory layouts
439
- ( b_layout. opposite ( ) , a_layout. opposite ( ) )
420
+ ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: F ) ) => {
421
+ // CblasColMajor is the "other case"
422
+ // Mark a, b as having layouts opposite of what they were detected as, which
423
+ // ends up with the correct transpose setting w.r.t col major
424
+ ( a_layout. opposite ( ) , a_layout. lead_axis ( ) ,
425
+ b_layout. opposite ( ) , b_layout. lead_axis ( ) , c_layout)
440
426
} ,
441
- _otherwise => {
442
- break ' blas_block;
443
- }
427
+ _ => break ' blas_block,
444
428
} ;
445
429
446
- let a_trans;
447
- let b_trans;
448
- let lda; // Stride of a
449
- let ldb; // Stride of b
430
+ let a_trans = a_layout. to_cblas_transpose ( ) ;
431
+ let lda = blas_stride ( & a, a_axis) ;
450
432
451
- if let MemoryOrder :: C = a_layout {
452
- lda = blas_stride ( & a, 0 ) ;
453
- a_trans = CblasNoTrans ;
454
- } else {
455
- lda = blas_stride ( & a, 1 ) ;
456
- a_trans = CblasTrans ;
457
- }
433
+ let b_trans = b_layout. to_cblas_transpose ( ) ;
434
+ let ldb = blas_stride ( & b, b_axis) ;
458
435
459
- if let MemoryOrder :: C = b_layout {
460
- ldb = blas_stride ( & b, 0 ) ;
461
- b_trans = CblasNoTrans ;
462
- } else {
463
- ldb = blas_stride ( & b, 1 ) ;
464
- b_trans = CblasTrans ;
465
- }
466
- let ldc = blas_stride ( & c, 0 ) ;
436
+ let ldc = blas_stride ( & c, c_layout. lead_axis ( ) ) ;
467
437
468
438
macro_rules! gemm_scalar_cast {
469
439
( f32 , $var: ident) => {
@@ -487,7 +457,7 @@ fn mat_mul_impl<A>(
487
457
// Where Op is notrans/trans/conjtrans
488
458
unsafe {
489
459
blas_sys:: $gemm(
490
- CblasRowMajor ,
460
+ c_layout . to_cblas_layout ( ) ,
491
461
a_trans,
492
462
b_trans,
493
463
m as blas_index, // m, rows of Op(a)
@@ -507,14 +477,15 @@ fn mat_mul_impl<A>(
507
477
}
508
478
} ;
509
479
}
480
+
510
481
gemm ! ( f32 , cblas_sgemm) ;
511
482
gemm ! ( f64 , cblas_dgemm) ;
512
-
513
483
gemm ! ( c32, cblas_cgemm) ;
514
484
gemm ! ( c64, cblas_zgemm) ;
485
+
515
486
break ' blas_block;
516
487
}
517
- mat_mul_general ( alpha, lhs , rhs , beta, c)
488
+ mat_mul_general ( alpha, a , b , beta, c)
518
489
}
519
490
520
491
/// C ← α A B + β C
@@ -873,13 +844,41 @@ enum MemoryOrder
873
844
#[ cfg( feature = "blas" ) ]
874
845
impl MemoryOrder
875
846
{
847
+ #[ inline]
848
+ /// Axis of leading stride (opposite of contiguous axis)
849
+ fn lead_axis ( self ) -> usize
850
+ {
851
+ match self {
852
+ MemoryOrder :: C => 0 ,
853
+ MemoryOrder :: F => 1 ,
854
+ }
855
+ }
856
+
857
+ /// Get opposite memory order
858
+ #[ inline]
876
859
fn opposite ( self ) -> Self
877
860
{
878
861
match self {
879
862
MemoryOrder :: C => MemoryOrder :: F ,
880
863
MemoryOrder :: F => MemoryOrder :: C ,
881
864
}
882
865
}
866
+
867
+ fn to_cblas_transpose ( self ) -> cblas_sys:: CBLAS_TRANSPOSE
868
+ {
869
+ match self {
870
+ MemoryOrder :: C => CblasNoTrans ,
871
+ MemoryOrder :: F => CblasTrans ,
872
+ }
873
+ }
874
+
875
+ fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
876
+ {
877
+ match self {
878
+ MemoryOrder :: C => CBLAS_LAYOUT :: CblasRowMajor ,
879
+ MemoryOrder :: F => CBLAS_LAYOUT :: CblasColMajor ,
880
+ }
881
+ }
883
882
}
884
883
885
884
#[ cfg( feature = "blas" ) ]
0 commit comments