Skip to content

Commit a5d6eaf

Browse files
committed
blas: Simplify layout logic for gemm
Using cblas we can simplify this further to a more satisfying translation (from ndarray to BLAS), much simpler logic. Avoids creating and handling an extra layer of array views.
1 parent 5c8b9de commit a5d6eaf

File tree

2 files changed

+67
-68
lines changed

2 files changed

+67
-68
lines changed

crates/blas-tests/tests/oper.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ fn gen_mat_mul()
253253
for &(m, k, n) in &sizes {
254254
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
255255
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
256-
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
256+
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build() * 0.5;
257257
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
258258
let mut c = ArrayBuilder::new((m, n)).memory_order(ord3).build();
259259

src/linalg/impl_linalg.rs

Lines changed: 66 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,11 @@ use num_complex::{Complex32 as c32, Complex64 as c64};
2424

2525
#[cfg(feature = "blas")]
2626
use libc::c_int;
27-
#[cfg(feature = "blas")]
28-
use std::mem::swap;
2927

3028
#[cfg(feature = "blas")]
3129
use cblas_sys as blas_sys;
3230
#[cfg(feature = "blas")]
33-
use cblas_sys::{CblasNoTrans, CblasRowMajor, CblasTrans, CBLAS_LAYOUT};
31+
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT};
3432

3533
/// len of vector before we use blas
3634
#[cfg(feature = "blas")]
@@ -377,93 +375,65 @@ use self::mat_mul_general as mat_mul_impl;
377375
#[cfg(feature = "blas")]
378376
fn mat_mul_impl<A>(
379377
alpha: A,
380-
lhs: &ArrayView2<'_, A>,
381-
rhs: &ArrayView2<'_, A>,
378+
a: &ArrayView2<'_, A>,
379+
b: &ArrayView2<'_, A>,
382380
beta: A,
383381
c: &mut ArrayViewMut2<'_, A>,
384382
) where
385383
A: LinalgScalar,
386384
{
387385
// size cutoff for using BLAS
388386
let cut = GEMM_BLAS_CUTOFF;
389-
let ((mut m, k), (k2, mut n)) = (lhs.dim(), rhs.dim());
387+
let ((m, k), (k2, n)) = (a.dim(), b.dim());
390388
debug_assert_eq!(k, k2);
391389
if !(m > cut || n > cut || k > cut)
392390
|| !(same_type::<A, f32>()
393391
|| same_type::<A, f64>()
394392
|| same_type::<A, c32>()
395393
|| same_type::<A, c64>())
396394
{
397-
return mat_mul_general(alpha, lhs, rhs, beta, c);
395+
return mat_mul_general(alpha, a, b, beta, c);
398396
}
399397

400398
#[allow(clippy::never_loop)] // MSRV Rust 1.64 does not have break from block
401399
'blas_block: loop {
402-
let mut a = lhs.view();
403-
let mut b = rhs.view();
404-
let mut c = c.view_mut();
405-
406-
let c_layout = get_blas_compatible_layout(&c);
407-
let c_layout_is_c = matches!(c_layout, Some(MemoryOrder::C));
408-
let c_layout_is_f = matches!(c_layout, Some(MemoryOrder::F));
409-
410400
// Compute A B -> C
411-
// we require for BLAS compatibility that:
412-
// A, B are contiguous (stride=1) in their fastest dimension.
413-
// C is c-contiguous in one dimension (stride=1 in Axis(1))
401+
// We require for BLAS compatibility that:
402+
// A, B, C are contiguous (stride=1) in their fastest dimension,
403+
// but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
414404
//
415-
// If C is f-contiguous, use transpose equivalency
416-
// to translate to the C-contiguous case:
417-
// A^t B^t = C^t => B A = C
418-
419-
let (a_layout, b_layout) =
420-
match (get_blas_compatible_layout(&a), get_blas_compatible_layout(&b)) {
421-
(Some(a_layout), Some(b_layout)) if c_layout_is_c => {
422-
// normal case
423-
(a_layout, b_layout)
405+
// The "normal case" is CblasRowMajor for cblas.
406+
// Select CblasRowMajor, CblasColMajor to fit C's memory order.
407+
//
408+
// Apply transpose to A, B as needed if they differ from the normal case.
409+
// If C is CblasColMajor then transpose both A, B (again!)
410+
411+
let (a_layout, a_axis, b_layout, b_axis, c_layout) =
412+
match (get_blas_compatible_layout(a),
413+
get_blas_compatible_layout(b),
414+
get_blas_compatible_layout(c))
415+
{
416+
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => {
417+
(a_layout, a_layout.lead_axis(),
418+
b_layout, b_layout.lead_axis(), c_layout)
424419
},
425-
(Some(a_layout), Some(b_layout)) if c_layout_is_f => {
426-
// Transpose equivalency
427-
// A^t B^t = C^t => B A = C
428-
//
429-
// A^t becomes the new B
430-
// B^t becomes the new A
431-
let a_t = a.reversed_axes();
432-
a = b.reversed_axes();
433-
b = a_t;
434-
c = c.reversed_axes();
435-
// Assign (n, k, m) -> (m, k, n) effectively
436-
swap(&mut m, &mut n);
437-
438-
// Continue using the already computed memory layouts
439-
(b_layout.opposite(), a_layout.opposite())
420+
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => {
421+
// CblasColMajor is the "other case"
422+
// Mark a, b as having layouts opposite of what they were detected as, which
423+
// ends up with the correct transpose setting w.r.t col major
424+
(a_layout.opposite(), a_layout.lead_axis(),
425+
b_layout.opposite(), b_layout.lead_axis(), c_layout)
440426
},
441-
_otherwise => {
442-
break 'blas_block;
443-
}
427+
_ => break 'blas_block,
444428
};
445429

446-
let a_trans;
447-
let b_trans;
448-
let lda; // Stride of a
449-
let ldb; // Stride of b
430+
let a_trans = a_layout.to_cblas_transpose();
431+
let lda = blas_stride(&a, a_axis);
450432

451-
if let MemoryOrder::C = a_layout {
452-
lda = blas_stride(&a, 0);
453-
a_trans = CblasNoTrans;
454-
} else {
455-
lda = blas_stride(&a, 1);
456-
a_trans = CblasTrans;
457-
}
433+
let b_trans = b_layout.to_cblas_transpose();
434+
let ldb = blas_stride(&b, b_axis);
458435

459-
if let MemoryOrder::C = b_layout {
460-
ldb = blas_stride(&b, 0);
461-
b_trans = CblasNoTrans;
462-
} else {
463-
ldb = blas_stride(&b, 1);
464-
b_trans = CblasTrans;
465-
}
466-
let ldc = blas_stride(&c, 0);
436+
let ldc = blas_stride(&c, c_layout.lead_axis());
467437

468438
macro_rules! gemm_scalar_cast {
469439
(f32, $var:ident) => {
@@ -487,7 +457,7 @@ fn mat_mul_impl<A>(
487457
// Where Op is notrans/trans/conjtrans
488458
unsafe {
489459
blas_sys::$gemm(
490-
CblasRowMajor,
460+
c_layout.to_cblas_layout(),
491461
a_trans,
492462
b_trans,
493463
m as blas_index, // m, rows of Op(a)
@@ -507,14 +477,15 @@ fn mat_mul_impl<A>(
507477
}
508478
};
509479
}
480+
510481
gemm!(f32, cblas_sgemm);
511482
gemm!(f64, cblas_dgemm);
512-
513483
gemm!(c32, cblas_cgemm);
514484
gemm!(c64, cblas_zgemm);
485+
515486
break 'blas_block;
516487
}
517-
mat_mul_general(alpha, lhs, rhs, beta, c)
488+
mat_mul_general(alpha, a, b, beta, c)
518489
}
519490

520491
/// C ← α A B + β C
@@ -873,13 +844,41 @@ enum MemoryOrder
873844
#[cfg(feature = "blas")]
874845
impl MemoryOrder
875846
{
847+
#[inline]
848+
/// Axis of leading stride (opposite of contiguous axis)
849+
fn lead_axis(self) -> usize
850+
{
851+
match self {
852+
MemoryOrder::C => 0,
853+
MemoryOrder::F => 1,
854+
}
855+
}
856+
857+
/// Get opposite memory order
858+
#[inline]
876859
fn opposite(self) -> Self
877860
{
878861
match self {
879862
MemoryOrder::C => MemoryOrder::F,
880863
MemoryOrder::F => MemoryOrder::C,
881864
}
882865
}
866+
867+
fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE
868+
{
869+
match self {
870+
MemoryOrder::C => CblasNoTrans,
871+
MemoryOrder::F => CblasTrans,
872+
}
873+
}
874+
875+
fn to_cblas_layout(self) -> CBLAS_LAYOUT
876+
{
877+
match self {
878+
MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor,
879+
MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor,
880+
}
881+
}
883882
}
884883

885884
#[cfg(feature = "blas")]

0 commit comments

Comments
 (0)