Skip to content

Commit 771bc16

Browse files
committed
blas: Update layout logic for gemm
We compute A B -> C with matrices A, B, C. The blas (cblas) interface supports matrices that adhere to certain criteria: they should be contiguous on one dimension (stride=1). We glance a little at how numpy does this to try to catch all cases. In short, we accept A, B contiguous on either axis (row or column major). We use the case where C is (weakly) row major, but if it is column major we transpose A, B, C => A^t, B^t, C^t so that we are back to the C row major case. (Weakly = contiguous with stride=1 on the inner dimension, while the stride for the other dimension can be larger; this distinguishes it from strictly whole-array contiguous.) Minor change to the gemv function: no functional change, only an update due to the refactoring of the blas layout functions. Fixes #1278
1 parent 8d99e56 commit 771bc16

File tree

4 files changed

+278
-137
lines changed

4 files changed

+278
-137
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ rawpointer = { version = "0.2" }
4747
defmac = "0.2"
4848
quickcheck = { workspace = true }
4949
approx = { workspace = true, default-features = true }
50-
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
50+
itertools = { workspace = true }
5151

5252
[features]
5353
default = ["std"]
@@ -73,6 +73,7 @@ matrixmultiply-threading = ["matrixmultiply/threading"]
7373

7474
portable-atomic-critical-section = ["portable-atomic/critical-section"]
7575

76+
7677
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
7778
portable-atomic = { version = "1.6.0" }
7879
portable-atomic-util = { version = "0.2.0", features = [ "alloc" ] }
@@ -103,6 +104,7 @@ approx = { version = "0.5", default-features = false }
103104
quickcheck = { version = "1.0", default-features = false }
104105
rand = { version = "0.8.0", features = ["small_rng"] }
105106
rand_distr = { version = "0.4.0" }
107+
itertools = { version = "0.13.0", default-features = false, features = ["use_std"] }
106108

107109
[profile.bench]
108110
debug = true

crates/blas-tests/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ test = false
1010

1111
[dependencies]
1212
ndarray = { workspace = true, features = ["approx", "blas"] }
13+
ndarray-gen = { workspace = true }
1314

1415
blas-src = { version = "0.10", optional = true }
1516
openblas-src = { version = "0.10", optional = true }
@@ -21,6 +22,7 @@ defmac = "0.2"
2122
approx = { workspace = true }
2223
num-traits = { workspace = true }
2324
num-complex = { workspace = true }
25+
itertools = { workspace = true }
2426

2527
[features]
2628
# Just for making an example and to help testing, multiple different possible

crates/blas-tests/tests/oper.rs

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ use ndarray::prelude::*;
99

1010
use ndarray::linalg::general_mat_mul;
1111
use ndarray::linalg::general_mat_vec_mul;
12+
use ndarray::Order;
1213
use ndarray::{Data, Ix, LinalgScalar};
14+
use ndarray_gen::array_builder::ArrayBuilder;
1315

1416
use approx::assert_relative_eq;
1517
use defmac::defmac;
18+
use itertools::iproduct;
19+
use ndarray_gen::array_builder::ElementGenerator;
1620
use num_complex::Complex32;
1721
use num_complex::Complex64;
1822

@@ -243,32 +247,59 @@ fn gen_mat_mul()
243247
let sizes = vec![
244248
(4, 4, 4),
245249
(8, 8, 8),
246-
(17, 15, 16),
250+
(10, 10, 10),
251+
(8, 8, 1),
252+
(1, 10, 10),
253+
(10, 1, 10),
254+
(10, 10, 1),
255+
(1, 10, 1),
256+
(10, 1, 1),
257+
(1, 1, 10),
247258
(4, 17, 3),
248259
(17, 3, 22),
249260
(19, 18, 2),
250261
(16, 17, 15),
251262
(15, 16, 17),
252263
(67, 63, 62),
253264
];
254-
// test different strides
255-
for &s1 in &[1, 2, -1, -2] {
256-
for &s2 in &[1, 2, -1, -2] {
257-
for &(m, k, n) in &sizes {
258-
let a = range_mat64(m, k);
259-
let b = range_mat64(k, n);
260-
let mut c = range_mat64(m, n);
265+
let strides = &[1, 2, -1, -2];
266+
let cf_order = [Order::C, Order::F];
267+
268+
// test different strides and memory orders
269+
for (&s1, &s2) in iproduct!(strides, strides) {
270+
for &(m, k, n) in &sizes {
271+
for (ord1, ord2, ord3) in iproduct!(cf_order, cf_order, cf_order) {
272+
println!("Case s1={}, s2={}, orders={:?}, {:?}, {:?}", s1, s2, ord1, ord2, ord3);
273+
let a = ArrayBuilder::new((m, k)).memory_order(ord1).build();
274+
let b = ArrayBuilder::new((k, n)).memory_order(ord2).build();
275+
let mut c = ArrayBuilder::new((m, n))
276+
.memory_order(ord3)
277+
.generator(ElementGenerator::Zero)
278+
.build();
279+
261280
let mut answer = c.clone();
262281

263282
{
264-
let a = a.slice(s![..;s1, ..;s2]);
265-
let b = b.slice(s![..;s2, ..;s2]);
266-
let mut cv = c.slice_mut(s![..;s1, ..;s2]);
283+
let av;
284+
let bv;
285+
let mut cv;
286+
287+
if s1 != 1 || s2 != 1 {
288+
av = a.slice(s![..;s1, ..;s2]);
289+
bv = b.slice(s![..;s2, ..;s2]);
290+
cv = c.slice_mut(s![..;s1, ..;s2]);
291+
} else {
292+
// different stride cases for slicing versus not sliced (for axes of
293+
// len=1); so test not sliced here.
294+
av = a.view();
295+
bv = b.view();
296+
cv = c.view_mut();
297+
}
267298

268-
let answer_part = alpha * reference_mat_mul(&a, &b) + beta * &cv;
299+
let answer_part = alpha * reference_mat_mul(&av, &bv) + beta * &cv;
269300
answer.slice_mut(s![..;s1, ..;s2]).assign(&answer_part);
270301

271-
general_mat_mul(alpha, &a, &b, beta, &mut cv);
302+
general_mat_mul(alpha, &av, &bv, beta, &mut cv);
272303
}
273304
assert_relative_eq!(c, answer, epsilon = 1e-12, max_relative = 1e-7);
274305
}

0 commit comments

Comments
 (0)