Skip to content

Commit e65bd0d

Browse files
committed
tests: Refactor to use ArrayBuilder more places
1 parent 56cac34 commit e65bd0d

File tree

5 files changed

+66
-99
lines changed

5 files changed

+66
-99
lines changed

Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ defmac = "0.2"
4848
quickcheck = { workspace = true }
4949
approx = { workspace = true, default-features = true }
5050
itertools = { workspace = true }
51+
ndarray-gen = { workspace = true }
5152

5253
[features]
5354
default = ["std"]
@@ -93,7 +94,7 @@ default-members = [
9394
]
9495

9596
[workspace.dependencies]
96-
ndarray = { version = "0.16", path = "." }
97+
ndarray = { version = "0.16", path = ".", default-features = false }
9798
ndarray-rand = { path = "ndarray-rand" }
9899
ndarray-gen = { path = "crates/ndarray-gen" }
99100

crates/blas-tests/tests/oper.rs

Lines changed: 21 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use defmac::defmac;
1818
use itertools::iproduct;
1919
use num_complex::Complex32;
2020
use num_complex::Complex64;
21+
use num_traits::Num;
2122

2223
#[test]
2324
fn mat_vec_product_1d()
@@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis()
4950
assert_eq!(a.t().dot(&b), ans);
5051
}
5152

52-
fn range_mat(m: Ix, n: Ix) -> Array2<f32>
53+
fn range_mat<A: Num + Copy>(m: Ix, n: Ix) -> Array2<A>
5354
{
54-
Array::linspace(0., (m * n) as f32 - 1., m * n)
55-
.into_shape_with_order((m, n))
56-
.unwrap()
57-
}
58-
59-
fn range_mat64(m: Ix, n: Ix) -> Array2<f64>
60-
{
61-
Array::linspace(0., (m * n) as f64 - 1., m * n)
62-
.into_shape_with_order((m, n))
63-
.unwrap()
55+
ArrayBuilder::new((m, n)).build()
6456
}
6557

6658
fn range_mat_complex(m: Ix, n: Ix) -> Array2<Complex32>
6759
{
68-
Array::linspace(0., (m * n) as f32 - 1., m * n)
69-
.into_shape_with_order((m, n))
70-
.unwrap()
71-
.map(|&f| Complex32::new(f, 0.))
60+
ArrayBuilder::new((m, n)).build()
7261
}
7362

7463
fn range_mat_complex64(m: Ix, n: Ix) -> Array2<Complex64>
7564
{
76-
Array::linspace(0., (m * n) as f64 - 1., m * n)
77-
.into_shape_with_order((m, n))
78-
.unwrap()
79-
.map(|&f| Complex64::new(f, 0.))
65+
ArrayBuilder::new((m, n)).build()
8066
}
8167

8268
fn range1_mat64(m: Ix) -> Array1<f64>
8369
{
84-
Array::linspace(0., m as f64 - 1., m)
70+
ArrayBuilder::new(m).build()
8571
}
8672

8773
fn range_i32(m: Ix, n: Ix) -> Array2<i32>
8874
{
89-
Array::from_iter(0..(m * n) as i32)
90-
.into_shape_with_order((m, n))
91-
.unwrap()
75+
ArrayBuilder::new((m, n)).build()
9276
}
9377

9478
// simple, slow, correct (hopefully) mat mul
@@ -163,8 +147,8 @@ where
163147
fn mat_mul_order()
164148
{
165149
let (m, n, k) = (50, 50, 50);
166-
let a = range_mat(m, n);
167-
let b = range_mat(n, k);
150+
let a = range_mat::<f32>(m, n);
151+
let b = range_mat::<f32>(n, k);
168152
let mut af = Array::zeros(a.dim().f());
169153
let mut bf = Array::zeros(b.dim().f());
170154
af.assign(&a);
@@ -183,7 +167,7 @@ fn mat_mul_order()
183167
fn mat_mul_broadcast()
184168
{
185169
let (m, n, k) = (16, 16, 16);
186-
let a = range_mat(m, n);
170+
let a = range_mat::<f32>(m, n);
187171
let x1 = 1.;
188172
let x = Array::from(vec![x1]);
189173
let b0 = x.broadcast((n, k)).unwrap();
@@ -203,8 +187,8 @@ fn mat_mul_broadcast()
203187
fn mat_mul_rev()
204188
{
205189
let (m, n, k) = (16, 16, 16);
206-
let a = range_mat(m, n);
207-
let b = range_mat(n, k);
190+
let a = range_mat::<f32>(m, n);
191+
let b = range_mat::<f32>(n, k);
208192
let mut rev = Array::zeros(b.dim());
209193
let mut rev = rev.slice_mut(s![..;-1, ..]);
210194
rev.assign(&b);
@@ -233,8 +217,8 @@ fn mat_mut_zero_len()
233217
}
234218
}
235219
});
236-
mat_mul_zero_len!(range_mat);
237-
mat_mul_zero_len!(range_mat64);
220+
mat_mul_zero_len!(range_mat::<f32>);
221+
mat_mul_zero_len!(range_mat::<f64>);
238222
mat_mul_zero_len!(range_i32);
239223
}
240224

@@ -307,11 +291,11 @@ fn gen_mat_mul()
307291
#[test]
308292
fn gemm_64_1_f()
309293
{
310-
let a = range_mat64(64, 64).reversed_axes();
294+
let a = range_mat::<f64>(64, 64).reversed_axes();
311295
let (m, n) = a.dim();
312296
// m x n times n x 1 == m x 1
313-
let x = range_mat64(n, 1);
314-
let mut y = range_mat64(m, 1);
297+
let x = range_mat::<f64>(n, 1);
298+
let mut y = range_mat::<f64>(m, 1);
315299
let answer = reference_mat_mul(&a, &x) + &y;
316300
general_mat_mul(1.0, &a, &x, 1.0, &mut y);
317301
assert_relative_eq!(y, answer, epsilon = 1e-12, max_relative = 1e-7);
@@ -393,11 +377,8 @@ fn gen_mat_vec_mul()
393377
for &s1 in &[1, 2, -1, -2] {
394378
for &s2 in &[1, 2, -1, -2] {
395379
for &(m, k) in &sizes {
396-
for &rev in &[false, true] {
397-
let mut a = range_mat64(m, k);
398-
if rev {
399-
a = a.reversed_axes();
400-
}
380+
for order in [Order::C, Order::F] {
381+
let a = ArrayBuilder::new((m, k)).memory_order(order).build();
401382
let (m, k) = a.dim();
402383
let b = range1_mat64(k);
403384
let mut c = range1_mat64(m);
@@ -438,11 +419,8 @@ fn vec_mat_mul()
438419
for &s1 in &[1, 2, -1, -2] {
439420
for &s2 in &[1, 2, -1, -2] {
440421
for &(m, n) in &sizes {
441-
for &rev in &[false, true] {
442-
let mut b = range_mat64(m, n);
443-
if rev {
444-
b = b.reversed_axes();
445-
}
422+
for order in [Order::C, Order::F] {
423+
let b = ArrayBuilder::new((m, n)).memory_order(order).build();
446424
let (m, n) = b.dim();
447425
let a = range1_mat64(m);
448426
let mut c = range1_mat64(n);

crates/ndarray-gen/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,5 @@ edition = "2018"
55
publish = false
66

77
[dependencies]
8-
ndarray = { workspace = true }
8+
ndarray = { workspace = true, default-features = false }
99
num-traits = { workspace = true }

crates/ndarray-gen/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#![no_std]
12
// Copyright 2024 bluss and ndarray developers.
23
//
34
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or

0 commit comments

Comments
 (0)