@@ -18,6 +18,7 @@ use defmac::defmac;
18
18
use itertools:: iproduct;
19
19
use num_complex:: Complex32 ;
20
20
use num_complex:: Complex64 ;
21
+ use num_traits:: Num ;
21
22
22
23
#[ test]
23
24
fn mat_vec_product_1d ( )
@@ -49,46 +50,29 @@ fn mat_vec_product_1d_inverted_axis()
49
50
assert_eq ! ( a. t( ) . dot( & b) , ans) ;
50
51
}
51
52
52
- fn range_mat ( m : Ix , n : Ix ) -> Array2 < f32 >
53
+ fn range_mat < A : Num + Copy > ( m : Ix , n : Ix ) -> Array2 < A >
53
54
{
54
- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
55
- . into_shape_with_order ( ( m, n) )
56
- . unwrap ( )
57
- }
58
-
59
- fn range_mat64 ( m : Ix , n : Ix ) -> Array2 < f64 >
60
- {
61
- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
62
- . into_shape_with_order ( ( m, n) )
63
- . unwrap ( )
55
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
64
56
}
65
57
66
58
fn range_mat_complex ( m : Ix , n : Ix ) -> Array2 < Complex32 >
67
59
{
68
- Array :: linspace ( 0. , ( m * n) as f32 - 1. , m * n)
69
- . into_shape_with_order ( ( m, n) )
70
- . unwrap ( )
71
- . map ( |& f| Complex32 :: new ( f, 0. ) )
60
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
72
61
}
73
62
74
63
fn range_mat_complex64 ( m : Ix , n : Ix ) -> Array2 < Complex64 >
75
64
{
76
- Array :: linspace ( 0. , ( m * n) as f64 - 1. , m * n)
77
- . into_shape_with_order ( ( m, n) )
78
- . unwrap ( )
79
- . map ( |& f| Complex64 :: new ( f, 0. ) )
65
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
80
66
}
81
67
82
68
fn range1_mat64 ( m : Ix ) -> Array1 < f64 >
83
69
{
84
- Array :: linspace ( 0. , m as f64 - 1. , m )
70
+ ArrayBuilder :: new ( m ) . build ( )
85
71
}
86
72
87
73
fn range_i32 ( m : Ix , n : Ix ) -> Array2 < i32 >
88
74
{
89
- Array :: from_iter ( 0 ..( m * n) as i32 )
90
- . into_shape_with_order ( ( m, n) )
91
- . unwrap ( )
75
+ ArrayBuilder :: new ( ( m, n) ) . build ( )
92
76
}
93
77
94
78
// simple, slow, correct (hopefully) mat mul
@@ -163,8 +147,8 @@ where
163
147
fn mat_mul_order ( )
164
148
{
165
149
let ( m, n, k) = ( 50 , 50 , 50 ) ;
166
- let a = range_mat ( m, n) ;
167
- let b = range_mat ( n, k) ;
150
+ let a = range_mat :: < f32 > ( m, n) ;
151
+ let b = range_mat :: < f32 > ( n, k) ;
168
152
let mut af = Array :: zeros ( a. dim ( ) . f ( ) ) ;
169
153
let mut bf = Array :: zeros ( b. dim ( ) . f ( ) ) ;
170
154
af. assign ( & a) ;
@@ -183,7 +167,7 @@ fn mat_mul_order()
183
167
fn mat_mul_broadcast ( )
184
168
{
185
169
let ( m, n, k) = ( 16 , 16 , 16 ) ;
186
- let a = range_mat ( m, n) ;
170
+ let a = range_mat :: < f32 > ( m, n) ;
187
171
let x1 = 1. ;
188
172
let x = Array :: from ( vec ! [ x1] ) ;
189
173
let b0 = x. broadcast ( ( n, k) ) . unwrap ( ) ;
@@ -203,8 +187,8 @@ fn mat_mul_broadcast()
203
187
fn mat_mul_rev ( )
204
188
{
205
189
let ( m, n, k) = ( 16 , 16 , 16 ) ;
206
- let a = range_mat ( m, n) ;
207
- let b = range_mat ( n, k) ;
190
+ let a = range_mat :: < f32 > ( m, n) ;
191
+ let b = range_mat :: < f32 > ( n, k) ;
208
192
let mut rev = Array :: zeros ( b. dim ( ) ) ;
209
193
let mut rev = rev. slice_mut ( s ! [ ..; -1 , ..] ) ;
210
194
rev. assign ( & b) ;
@@ -233,8 +217,8 @@ fn mat_mut_zero_len()
233
217
}
234
218
}
235
219
} ) ;
236
- mat_mul_zero_len ! ( range_mat) ;
237
- mat_mul_zero_len ! ( range_mat64 ) ;
220
+ mat_mul_zero_len ! ( range_mat:: < f32 > ) ;
221
+ mat_mul_zero_len ! ( range_mat :: < f64 > ) ;
238
222
mat_mul_zero_len ! ( range_i32) ;
239
223
}
240
224
@@ -307,11 +291,11 @@ fn gen_mat_mul()
307
291
#[ test]
308
292
fn gemm_64_1_f ( )
309
293
{
310
- let a = range_mat64 ( 64 , 64 ) . reversed_axes ( ) ;
294
+ let a = range_mat :: < f64 > ( 64 , 64 ) . reversed_axes ( ) ;
311
295
let ( m, n) = a. dim ( ) ;
312
296
// m x n times n x 1 == m x 1
313
- let x = range_mat64 ( n, 1 ) ;
314
- let mut y = range_mat64 ( m, 1 ) ;
297
+ let x = range_mat :: < f64 > ( n, 1 ) ;
298
+ let mut y = range_mat :: < f64 > ( m, 1 ) ;
315
299
let answer = reference_mat_mul ( & a, & x) + & y;
316
300
general_mat_mul ( 1.0 , & a, & x, 1.0 , & mut y) ;
317
301
assert_relative_eq ! ( y, answer, epsilon = 1e-12 , max_relative = 1e-7 ) ;
@@ -393,11 +377,8 @@ fn gen_mat_vec_mul()
393
377
for & s1 in & [ 1 , 2 , -1 , -2 ] {
394
378
for & s2 in & [ 1 , 2 , -1 , -2 ] {
395
379
for & ( m, k) in & sizes {
396
- for & rev in & [ false , true ] {
397
- let mut a = range_mat64 ( m, k) ;
398
- if rev {
399
- a = a. reversed_axes ( ) ;
400
- }
380
+ for order in [ Order :: C , Order :: F ] {
381
+ let a = ArrayBuilder :: new ( ( m, k) ) . memory_order ( order) . build ( ) ;
401
382
let ( m, k) = a. dim ( ) ;
402
383
let b = range1_mat64 ( k) ;
403
384
let mut c = range1_mat64 ( m) ;
@@ -438,11 +419,8 @@ fn vec_mat_mul()
438
419
for & s1 in & [ 1 , 2 , -1 , -2 ] {
439
420
for & s2 in & [ 1 , 2 , -1 , -2 ] {
440
421
for & ( m, n) in & sizes {
441
- for & rev in & [ false , true ] {
442
- let mut b = range_mat64 ( m, n) ;
443
- if rev {
444
- b = b. reversed_axes ( ) ;
445
- }
422
+ for order in [ Order :: C , Order :: F ] {
423
+ let b = ArrayBuilder :: new ( ( m, n) ) . memory_order ( order) . build ( ) ;
446
424
let ( m, n) = b. dim ( ) ;
447
425
let a = range1_mat64 ( m) ;
448
426
let mut c = range1_mat64 ( n) ;
0 commit comments