From 72ad534b23c4f02eef4803eb83c8f1f445ebf4dc Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Tue, 6 Aug 2024 22:07:28 -0400 Subject: [PATCH 1/6] Fixes infinite recursion and off-by-one error --- src/tri.rs | 146 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 94 insertions(+), 52 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 4eab9e105..2f490a236 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -10,7 +10,16 @@ use core::cmp::{max, min}; use num_traits::Zero; -use crate::{dimension::is_layout_f, Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip}; +use crate::{ + dimension::{is_layout_c, is_layout_f}, + Array, + ArrayBase, + Axis, + Data, + Dimension, + IntoDimension, + Zip, +}; impl ArrayBase where @@ -30,38 +39,44 @@ where /// ``` /// use ndarray::array; /// - /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; - /// let res = arr.triu(0); - /// assert_eq!(res, array![[1, 2, 3], [0, 5, 6], [0, 0, 9]]); + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.triu(0), + /// array![ + /// [1, 2, 3], + /// [0, 5, 6], + /// [0, 0, 9] + /// ] + /// ); /// ``` pub fn triu(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } - match is_layout_f(&self.dim, &self.strides) { - true => { - let n = self.ndim(); - let mut x = self.view(); - x.swap_axes(n - 2, n - 1); - let mut tril = x.tril(-k); - tril.swap_axes(n - 2, n - 1); - - tril - } - false => { - let mut res = Array::zeros(self.raw_dim()); - Zip::indexed(self.rows()) - .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); - let lower = max(row_num as isize + k, 0); - dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); - }); - - res - } + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) { + let n = self.ndim(); + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.tril(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; } + let mut res = Array::zeros(self.raw_dim()); + Zip::indexed(self.rows()) + .and(res.rows_mut()) + .for_each(|i, src, mut dst| { + let row_num = i.into_dimension().last_elem(); + let lower = max(row_num as isize + k, 0); + dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); + }); + + res } /// Lower triangular of an array. @@ -75,39 +90,45 @@ where /// ``` /// use ndarray::array; /// - /// let arr = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; - /// let res = arr.tril(0); - /// assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); + /// let arr = array![ + /// [1, 2, 3], + /// [4, 5, 6], + /// [7, 8, 9] + /// ]; + /// assert_eq!( + /// arr.tril(0), + /// array![ + /// [1, 0, 0], + /// [4, 5, 0], + /// [7, 8, 9] + /// ] + /// ); /// ``` pub fn tril(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } - match is_layout_f(&self.dim, &self.strides) { - true => { - let n = self.ndim(); - let mut x = self.view(); - x.swap_axes(n - 2, n - 1); - let mut tril = x.triu(-k); - tril.swap_axes(n - 2, n - 1); - - tril - } - false => { - let mut res = Array::zeros(self.raw_dim()); - let ncols = self.len_of(Axis(self.ndim() - 1)) as isize; - Zip::indexed(self.rows()) - .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); - let upper = min(row_num as isize + k, ncols) + 1; - dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); - }); - - res - } + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) { + let n = self.ndim(); + let mut x = self.view(); + x.swap_axes(n - 2, n - 1); + let mut tril = x.triu(-k); + tril.swap_axes(n - 2, n - 1); + + return tril; } + let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(self.ndim() - 1)) as isize; + Zip::indexed(self.rows()) + .and(res.rows_mut()) + .for_each(|i, src, mut dst| { + let row_num = i.into_dimension().last_elem(); + let upper = min(row_num as isize + k + 1, ncols); + dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); + }); + + res } } @@ -188,6 +209,19 @@ mod tests assert_eq!(res, array![[1, 0, 0], [4, 5, 0], [7, 8, 9]]); } + #[test] + fn test_2d_single() + { + let x = array![[1]]; + + assert_eq!(x.triu(0), array![[1]]); + assert_eq!(x.tril(0), array![[1]]); + assert_eq!(x.triu(1), array![[0]]); + assert_eq!(x.tril(1), array![[1]]); + assert_eq!(x.triu(-1), array![[1]]); + assert_eq!(x.tril(-1), array![[0]]); + } + #[test] fn test_3d() { @@ -285,8 +319,16 @@ mod tests let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); + let x = array![[1, 2, 3], [4, 5, 6]]; + let res = x.tril(0); + assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]); + let x = array![[1, 2], [3, 4], [5, 6]]; let res = x.triu(0); assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]); + + let x = array![[1, 2], [3, 4], [5, 6]]; + let res = x.tril(0); + assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]); } } From 4dbbca143084cb841a62d31701af02ce7c2481c2 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Wed, 7 Aug 2024 16:54:47 -0400 Subject: [PATCH 2/6] Avoids overflow using saturating arithmetic --- src/tri.rs | 35 +++++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 10 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 2f490a236..059b6b78e 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use core::cmp::{max, min}; +use core::{cmp::min, isize}; use num_traits::Zero; @@ -58,8 +58,12 @@ where if self.ndim() <= 1 { return self.to_owned(); } - if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) { - let n = self.ndim(); + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. + // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { let mut x = self.view(); x.swap_axes(n - 2, n - 1); let mut tril = x.tril(-k); @@ -67,12 +71,16 @@ where return tril; } + let mut res = Array::zeros(self.raw_dim()); Zip::indexed(self.rows()) .and(res.rows_mut()) .for_each(|i, src, mut dst| { let row_num = i.into_dimension().last_elem(); - let lower = max(row_num as isize + k, 0); + let lower = match k >= 0 { + true => row_num.saturating_add(k as usize), // Avoid overflow + false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 + }; dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); }); @@ -109,8 +117,12 @@ where if self.ndim() <= 1 { return self.to_owned(); } - if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) { - let n = self.ndim(); + + // Performance optimization for F-order arrays. + // C-order array check prevents infinite recursion in edge cases like [[1]]. + // k-size check prevents underflow when k == isize::MIN + let n = self.ndim(); + if is_layout_f(&self.dim, &self.strides) && !is_layout_c(&self.dim, &self.strides) && k > isize::MIN { let mut x = self.view(); x.swap_axes(n - 2, n - 1); let mut tril = x.triu(-k); @@ -118,13 +130,18 @@ where return tril; } + let mut res = Array::zeros(self.raw_dim()); - let ncols = self.len_of(Axis(self.ndim() - 1)) as isize; + let ncols = self.len_of(Axis(n - 1)); Zip::indexed(self.rows()) .and(res.rows_mut()) .for_each(|i, src, mut dst| { let row_num = i.into_dimension().last_elem(); - let upper = min(row_num as isize + k + 1, ncols); + let mut upper = match k >= 0 { + true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow + false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow + }; + upper = min(upper, ncols); dst.slice_mut(s![..upper]).assign(&src.slice(s![..upper])); }); @@ -319,7 +336,6 @@ mod tests let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); - let x = array![[1, 2, 3], [4, 5, 6]]; let res = x.tril(0); assert_eq!(res, array![[1, 0, 0], [4, 5, 0]]); @@ -327,7 +343,6 @@ mod tests let res = x.triu(0); assert_eq!(res, array![[1, 2], [0, 4], [0, 0]]); - let x = array![[1, 2], [3, 4], [5, 6]]; let res = x.tril(0); assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]); } From 1b29d175fc77d1551ec0258502c6c6f16a548a38 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Wed, 7 Aug 2024 17:04:34 -0400 Subject: [PATCH 3/6] Removes unused import --- src/tri.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tri.rs b/src/tri.rs index 059b6b78e..abc6d5d95 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -6,7 +6,7 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -use core::{cmp::min, isize}; +use core::cmp::min; use num_traits::Zero; From 44d6bb19e8b472e70523fc45b5aed7ce2305202a Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Wed, 7 Aug 2024 17:11:09 -0400 Subject: [PATCH 4/6] Fixes bug for isize::MAX for triu --- src/tri.rs | 57 ++++++++++++++++++++++++++---------------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index abc6d5d95..5686f2008 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -12,13 +12,7 @@ use num_traits::Zero; use crate::{ dimension::{is_layout_c, is_layout_f}, - Array, - ArrayBase, - Axis, - Data, - Dimension, - IntoDimension, - Zip, + Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip, }; impl ArrayBase @@ -53,8 +47,7 @@ where /// ] /// ); /// ``` - pub fn triu(&self, k: isize) -> Array - { + pub fn triu(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } @@ -73,14 +66,16 @@ where } let mut res = Array::zeros(self.raw_dim()); + let ncols = self.len_of(Axis(n - 1)); Zip::indexed(self.rows()) .and(res.rows_mut()) .for_each(|i, src, mut dst| { let row_num = i.into_dimension().last_elem(); - let lower = match k >= 0 { + let mut lower = match k >= 0 { true => row_num.saturating_add(k as usize), // Avoid overflow false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 }; + lower = min(lower, ncols); dst.slice_mut(s![lower..]).assign(&src.slice(s![lower..])); }); @@ -112,8 +107,7 @@ where /// ] /// ); /// ``` - pub fn tril(&self, k: isize) -> Array - { + pub fn tril(&self, k: isize) -> Array { if self.ndim() <= 1 { return self.to_owned(); } @@ -150,14 +144,14 @@ where } #[cfg(test)] -mod tests -{ +mod tests { + use core::isize; + use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; use alloc::vec; #[test] - fn test_keep_order() - { + fn test_keep_order() { let x = Array2::::ones((3, 3).f()); let res = x.triu(0); assert!(dimension::is_layout_f(&res.dim, &res.strides)); @@ -167,8 +161,7 @@ mod tests } #[test] - fn test_0d() - { + fn test_0d() { let x = Array0::::ones(()); let res = x.triu(0); assert_eq!(res, x); @@ -185,8 +178,7 @@ mod tests } #[test] - fn test_1d() - { + fn test_1d() { let x = array![1, 2, 3]; let res = x.triu(0); assert_eq!(res, x); @@ -203,8 +195,7 @@ mod tests } #[test] - fn test_2d() - { + fn test_2d() { let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; // Upper @@ -227,8 +218,7 @@ mod tests } #[test] - fn test_2d_single() - { + fn test_2d_single() { let x = array![[1]]; assert_eq!(x.triu(0), array![[1]]); @@ -240,8 +230,7 @@ mod tests } #[test] - fn test_3d() - { + fn test_3d() { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -300,8 +289,7 @@ mod tests } #[test] - fn test_off_axis() - { + fn test_off_axis() { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -330,8 +318,7 @@ mod tests } #[test] - fn test_odd_shape() - { + fn test_odd_shape() { let x = array![[1, 2, 3], [4, 5, 6]]; let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); @@ -346,4 +333,14 @@ mod tests let res = x.tril(0); assert_eq!(res, array![[1, 0], [3, 4], [5, 6]]); } + + #[test] + fn test_odd_k() { + let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + let z = Array2::zeros([3, 3]); + assert_eq!(x.triu(isize::MIN), x); + assert_eq!(x.tril(isize::MIN), z); + assert_eq!(x.triu(isize::MAX), z); + assert_eq!(x.tril(isize::MAX), x); + } } From 9e1491f8edbdc12605dd9a6295ba13cf640428c6 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Thu, 8 Aug 2024 00:29:34 -0400 Subject: [PATCH 5/6] Fix formatting --- src/tri.rs | 44 +++++++++++++++++++++++++++++++------------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 5686f2008..53dc43fda 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -12,7 +12,13 @@ use num_traits::Zero; use crate::{ dimension::{is_layout_c, is_layout_f}, - Array, ArrayBase, Axis, Data, Dimension, IntoDimension, Zip, + Array, + ArrayBase, + Axis, + Data, + Dimension, + IntoDimension, + Zip, }; impl ArrayBase @@ -47,7 +53,8 @@ where /// ] /// ); /// ``` - pub fn triu(&self, k: isize) -> Array { + pub fn triu(&self, k: isize) -> Array + { if self.ndim() <= 1 { return self.to_owned(); } @@ -107,7 +114,8 @@ where /// ] /// ); /// ``` - pub fn tril(&self, k: isize) -> Array { + pub fn tril(&self, k: isize) -> Array + { if self.ndim() <= 1 { return self.to_owned(); } @@ -144,14 +152,16 @@ where } #[cfg(test)] -mod tests { +mod tests +{ use core::isize; use crate::{array, dimension, Array0, Array1, Array2, Array3, ShapeBuilder}; use alloc::vec; #[test] - fn test_keep_order() { + fn test_keep_order() + { let x = Array2::::ones((3, 3).f()); let res = x.triu(0); assert!(dimension::is_layout_f(&res.dim, &res.strides)); @@ -161,7 +171,8 @@ mod tests { } #[test] - fn test_0d() { + fn test_0d() + { let x = Array0::::ones(()); let res = x.triu(0); assert_eq!(res, x); @@ -178,7 +189,8 @@ mod tests { } #[test] - fn test_1d() { + fn test_1d() + { let x = array![1, 2, 3]; let res = x.triu(0); assert_eq!(res, x); @@ -195,7 +207,8 @@ mod tests { } #[test] - fn test_2d() { + fn test_2d() + { let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; // Upper @@ -218,7 +231,8 @@ mod tests { } #[test] - fn test_2d_single() { + fn test_2d_single() + { let x = array![[1]]; assert_eq!(x.triu(0), array![[1]]); @@ -230,7 +244,8 @@ mod tests { } #[test] - fn test_3d() { + fn test_3d() + { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -289,7 +304,8 @@ mod tests { } #[test] - fn test_off_axis() { + fn test_off_axis() + { let x = array![ [[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[10, 11, 12], [13, 14, 15], [16, 17, 18]], @@ -318,7 +334,8 @@ mod tests { } #[test] - fn test_odd_shape() { + fn test_odd_shape() + { let x = array![[1, 2, 3], [4, 5, 6]]; let res = x.triu(0); assert_eq!(res, array![[1, 2, 3], [0, 5, 6]]); @@ -335,7 +352,8 @@ mod tests { } #[test] - fn test_odd_k() { + fn test_odd_k() + { let x = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; let z = Array2::zeros([3, 3]); assert_eq!(x.triu(isize::MIN), x); From 5e5f551b04ca8443e9e9dfd248f4fbc353b22cd9 Mon Sep 17 00:00:00 2001 From: Adam Kern Date: Thu, 8 Aug 2024 16:28:37 -0400 Subject: [PATCH 6/6] Uses broadcast indices to remove D::Smaller: Copy trait bound --- src/tri.rs | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/src/tri.rs b/src/tri.rs index 53dc43fda..b7d297fcc 100644 --- a/src/tri.rs +++ b/src/tri.rs @@ -17,7 +17,6 @@ use crate::{ Axis, Data, Dimension, - IntoDimension, Zip, }; @@ -26,7 +25,6 @@ where S: Data, D: Dimension, A: Clone + Zero, - D::Smaller: Copy, { /// Upper triangular of an array. /// @@ -74,10 +72,12 @@ where let mut res = Array::zeros(self.raw_dim()); let ncols = self.len_of(Axis(n - 1)); - Zip::indexed(self.rows()) + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { let mut lower = match k >= 0 { true => row_num.saturating_add(k as usize), // Avoid overflow false => row_num.saturating_sub(k.unsigned_abs()), // Avoid underflow, go to 0 @@ -135,10 +135,13 @@ where let mut res = Array::zeros(self.raw_dim()); let ncols = self.len_of(Axis(n - 1)); - Zip::indexed(self.rows()) + let nrows = self.len_of(Axis(n - 2)); + let indices = Array::from_iter(0..nrows); + Zip::from(self.rows()) .and(res.rows_mut()) - .for_each(|i, src, mut dst| { - let row_num = i.into_dimension().last_elem(); + .and_broadcast(&indices) + .for_each(|src, mut dst, row_num| { + // let row_num = i.into_dimension().last_elem(); let mut upper = match k >= 0 { true => row_num.saturating_add(k as usize).saturating_add(1), // Avoid overflow false => row_num.saturating_sub((k + 1).unsigned_abs()), // Avoid underflow