Skip to content

Commit 8429de8

Browse files
committed
MAINT: Add more tests for scaled_add
1 parent 245fcde commit 8429de8

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

tests/oper.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,42 @@ fn scaled_add() {
459459

460460
}
461461

462+
#[test]
463+
fn scaled_add_2() {
464+
let beta = -2.3;
465+
let sizes = vec![(4, 4, 1, 4),
466+
(8, 8, 1, 8),
467+
(17, 15, 17, 15),
468+
(4, 17, 4, 17),
469+
(17, 3, 1, 3),
470+
(19, 18, 19, 18),
471+
(16, 17, 16, 17),
472+
(15, 16, 15, 16),
473+
(67, 63, 1, 63),
474+
];
475+
// test different strides
476+
for &s1 in &[1, 2, -1, -2] {
477+
for &s2 in &[1, 2, -1, -2] {
478+
for &(m, k, n, q) in &sizes {
479+
let mut a = range_mat64(m, k);
480+
let mut answer = a.clone();
481+
let c = range_mat64(n, q);
482+
483+
{
484+
let mut av = a.slice_mut(s![..;s1, ..;s2]);
485+
let c = c.slice(s![..;s1, ..;s2]);
486+
487+
let mut answerv = answer.slice_mut(s![..;s1, ..;s2]);
488+
answerv += &(beta * &c);
489+
av.scaled_add(beta, &c);
490+
}
491+
assert_close(a.view(), answer.view());
492+
}
493+
}
494+
}
495+
}
496+
497+
462498
#[test]
463499
fn gen_mat_mul() {
464500
let alpha = -2.3;

0 commit comments

Comments
 (0)