From e7c9381800f909bbbc2706081a6cee5b2044326d Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 19:33:39 +0100 Subject: [PATCH 01/14] Quick and dirty linear regression --- examples/linear_regression.rs | 69 +++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 examples/linear_regression.rs diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs new file mode 100644 index 00000000..42055cac --- /dev/null +++ b/examples/linear_regression.rs @@ -0,0 +1,69 @@ +use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Data}; +use ndarray_linalg::{Solve}; + + +/// The simple linear regression model is +/// y = bX + e where e ~ N(0, sigma^2 * I) +/// In probabilistic terms this corresponds to +/// y - bX ~ N(0, sigma^2 * I) +/// y | X, b ~ N(bX, sigma^2 * I) +/// The loss for the model is simply the squared error between the model +/// predictions and the true values: +/// Loss = ||y - bX||^2 +/// The MLE for the model parameters b can be computed in closed form via the +/// normal equation: +/// b = (X^T X)^{-1} X^T y +/// where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose +/// inverse. 
+struct LinearRegression { + beta: Option>, + fit_intercept: bool, +} + +impl LinearRegression { + fn new(fit_intercept: bool) -> LinearRegression { + LinearRegression { + beta: None, + fit_intercept + } + } + + fn fit(&mut self, mut X: Array2, y: Array1) { + let (n_samples, n_features) = X.dim(); + + // Check that our inputs have compatible shapes + assert_eq!(y.dim(), n_samples); + + // If we are fitting the intercept, we need an additional column + if self.fit_intercept { + let dummy_column: Array = Array::ones((n_samples, 1)); + X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); + }; + + let rhs = X.t().dot(&y); + let linear_operator = X.t().dot(&X); + self.beta = Some(linear_operator.solve_into(rhs).unwrap()); + } + + fn predict(&self, mut X: &mut ArrayBase) -> Array1 + where + A: Data, + { + let (n_samples, n_features) = X.dim(); + + // If we are fitting the intercept, we need an additional column + let X = if self.fit_intercept { + let dummy_column: Array = Array::ones((n_samples, 1)); + stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap() + } else { + X.to_owned() + }; + + match &self.beta { + None => panic!("The linear regression estimator has to be fitted first!"), + Some(beta) => { + X.dot(beta) + } + } + } +} From 7f1a2e1e7ffc50af57b2327d083950ed5fb07758 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 19:57:58 +0100 Subject: [PATCH 02/14] Basic version works --- Cargo.toml | 1 + examples/linear_regression.rs | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9257b5dd..f007119b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -49,3 +49,4 @@ optional = true [dev-dependencies] paste = "0.1" +ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"} diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 42055cac..43d20a11 100644 --- a/examples/linear_regression.rs +++ 
b/examples/linear_regression.rs @@ -1,5 +1,6 @@ use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Data}; -use ndarray_linalg::{Solve}; +use ndarray_linalg::{Solve, random}; +use ndarray_stats::DeviationExt; /// The simple linear regression model is @@ -45,7 +46,7 @@ impl LinearRegression { self.beta = Some(linear_operator.solve_into(rhs).unwrap()); } - fn predict(&self, mut X: &mut ArrayBase) -> Array1 + fn predict(&self, mut X: &ArrayBase) -> Array1 where A: Data, { @@ -67,3 +68,25 @@ impl LinearRegression { } } } + +fn get_data(n_train_samples: usize, n_test_samples: usize, n_features: usize) -> ( + Array2, Array2, Array1, Array1 +) { + let X_train: Array2 = random((n_train_samples, n_features)); + let y_train: Array1 = random(n_train_samples); + let X_test: Array2 = random((n_test_samples, n_features)); + let y_test: Array1 = random(n_test_samples); + (X_train, X_test, y_train, y_test) +} + +pub fn main() { + let n_train_samples = 5000; + let n_test_samples = 1000; + let n_features = 15; + let (X_train, X_test, y_train, y_test) = get_data(n_train_samples, n_test_samples, n_features); + let mut linear_regressor = LinearRegression::new(true); + linear_regressor.fit(X_train, y_train); + let test_predictions = linear_regressor.predict(&X_test); + let mean_squared_error = test_predictions.sq_l2_dist(&y_test).unwrap(); + println!("The fitted regressor has a root mean squared error of {:}", mean_squared_error); +} From c4c78e73e93e716f3bbb024de1bc417723247563 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 19:59:49 +0100 Subject: [PATCH 03/14] Fix warnings --- examples/linear_regression.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 43d20a11..836ac223 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -1,3 +1,4 @@ +#![allow(non_snake_case)] use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, 
Data}; use ndarray_linalg::{Solve, random}; use ndarray_stats::DeviationExt; @@ -30,7 +31,7 @@ impl LinearRegression { } fn fit(&mut self, mut X: Array2, y: Array1) { - let (n_samples, n_features) = X.dim(); + let (n_samples, _) = X.dim(); // Check that our inputs have compatible shapes assert_eq!(y.dim(), n_samples); @@ -46,11 +47,11 @@ impl LinearRegression { self.beta = Some(linear_operator.solve_into(rhs).unwrap()); } - fn predict(&self, mut X: &ArrayBase) -> Array1 + fn predict(&self, X: &ArrayBase) -> Array1 where A: Data, { - let (n_samples, n_features) = X.dim(); + let (n_samples, _) = X.dim(); // If we are fitting the intercept, we need an additional column let X = if self.fit_intercept { From ea7b475142896d29f129d0e827a48ad6e9c90e2a Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 20:52:02 +0100 Subject: [PATCH 04/14] Proper generation of data and target --- Cargo.toml | 2 ++ examples/linear_regression.rs | 48 ++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f007119b..b3b5c9d7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,3 +50,5 @@ optional = true [dev-dependencies] paste = "0.1" ndarray-stats = {git = "https://github.com/rust-ndarray/ndarray-stats", branch = "master"} +ndarray-rand = "0.9" +rand = "0.6" diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 836ac223..108bac06 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -2,7 +2,8 @@ use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Data}; use ndarray_linalg::{Solve, random}; use ndarray_stats::DeviationExt; - +use ndarray_rand::RandomExt; +use rand::distributions::StandardNormal; /// The simple linear regression model is /// y = bX + e where e ~ N(0, sigma^2 * I) @@ -18,7 +19,7 @@ use ndarray_stats::DeviationExt; /// where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose /// inverse. 
struct LinearRegression { - beta: Option>, + pub beta: Option>, fit_intercept: bool, } @@ -30,7 +31,7 @@ impl LinearRegression { } } - fn fit(&mut self, mut X: Array2, y: Array1) { + fn fit(&mut self, mut X: Array2, y: Array1) { let (n_samples, _) = X.dim(); // Check that our inputs have compatible shapes @@ -38,7 +39,7 @@ impl LinearRegression { // If we are fitting the intercept, we need an additional column if self.fit_intercept { - let dummy_column: Array = Array::ones((n_samples, 1)); + let dummy_column: Array = Array::ones((n_samples, 1)); X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); }; @@ -47,15 +48,15 @@ impl LinearRegression { self.beta = Some(linear_operator.solve_into(rhs).unwrap()); } - fn predict(&self, X: &ArrayBase) -> Array1 + fn predict(&self, X: &ArrayBase) -> Array1 where - A: Data, + A: Data, { let (n_samples, _) = X.dim(); // If we are fitting the intercept, we need an additional column let X = if self.fit_intercept { - let dummy_column: Array = Array::ones((n_samples, 1)); + let dummy_column: Array = Array::ones((n_samples, 1)); stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap() } else { X.to_owned() @@ -70,24 +71,31 @@ impl LinearRegression { } } -fn get_data(n_train_samples: usize, n_test_samples: usize, n_features: usize) -> ( - Array2, Array2, Array1, Array1 +fn get_data(n_samples: usize, n_features: usize) -> ( + Array2, Array1 ) { - let X_train: Array2 = random((n_train_samples, n_features)); - let y_train: Array1 = random(n_train_samples); - let X_test: Array2 = random((n_test_samples, n_features)); - let y_test: Array1 = random(n_test_samples); - (X_train, X_test, y_train, y_test) + let shape = (n_samples, n_features); + let noise: Array1 = Array::random(n_samples, StandardNormal); + + let beta: Array1 = random(n_features) * 100.; + println!("Beta used to generate target variable: {:.3}", beta); + + let X: Array2 = random(shape); + let y: Array1 = X.dot(&beta) + noise; + (X, y) } pub fn main() { let 
n_train_samples = 5000; let n_test_samples = 1000; - let n_features = 15; - let (X_train, X_test, y_train, y_test) = get_data(n_train_samples, n_test_samples, n_features); - let mut linear_regressor = LinearRegression::new(true); - linear_regressor.fit(X_train, y_train); + let n_features = 3; + let (X, y) = get_data(n_train_samples + n_test_samples, n_features); + let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples); + let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples); + let mut linear_regressor = LinearRegression::new(false); + linear_regressor.fit(X_train.to_owned(), y_train.to_owned()); let test_predictions = linear_regressor.predict(&X_test); - let mean_squared_error = test_predictions.sq_l2_dist(&y_test).unwrap(); - println!("The fitted regressor has a root mean squared error of {:}", mean_squared_error); + let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); + println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap()); + println!("The fitted regressor has a root mean squared error of {:.3}", mean_squared_error); } From 00b9fd40366ad2f729ffa3a6706b2eaa70ba180b Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 20:55:34 +0100 Subject: [PATCH 05/14] Tune range --- examples/linear_regression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 108bac06..18d04947 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -77,7 +77,7 @@ fn get_data(n_samples: usize, n_features: usize) -> ( let shape = (n_samples, n_features); let noise: Array1 = Array::random(n_samples, StandardNormal); - let beta: Array1 = random(n_features) * 100.; + let beta: Array1 = random(n_features) * 10.; println!("Beta used to generate target variable: {:.3}", beta); let X: Array2 = random(shape); From 71ceb20f32d70691226950257c75e916d63289a0 Mon Sep 17 00:00:00 2001 From: 
LukeMathWalker Date: Wed, 17 Jul 2019 21:07:09 +0100 Subject: [PATCH 06/14] Polishing --- examples/linear_regression.rs | 34 ++++++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 18d04947..0af38429 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -1,5 +1,5 @@ #![allow(non_snake_case)] -use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Data}; +use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Ix1, Data}; use ndarray_linalg::{Solve, random}; use ndarray_stats::DeviationExt; use ndarray_rand::RandomExt; @@ -13,11 +13,12 @@ use rand::distributions::StandardNormal; /// The loss for the model is simply the squared error between the model /// predictions and the true values: /// Loss = ||y - bX||^2 -/// The MLE for the model parameters b can be computed in closed form via the -/// normal equation: +/// The maximum likelihood estimation for the model parameters `beta` can be computed +/// in closed form via the normal equation: /// b = (X^T X)^{-1} X^T y -/// where (X^T X)^{-1} X^T is known as the pseudoinverse / Moore-Penrose -/// inverse. +/// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse. 
+/// +/// Adapted from: https://github.com/xinscrs/numpy-ml struct LinearRegression { pub beta: Option>, fit_intercept: bool, @@ -31,21 +32,34 @@ impl LinearRegression { } } - fn fit(&mut self, mut X: Array2, y: Array1) { + fn fit(&mut self, X: ArrayBase, y: ArrayBase) + where + A: Data, + B: Data, + { let (n_samples, _) = X.dim(); // Check that our inputs have compatible shapes assert_eq!(y.dim(), n_samples); // If we are fitting the intercept, we need an additional column - if self.fit_intercept { + self.beta = if self.fit_intercept { let dummy_column: Array = Array::ones((n_samples, 1)); - X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); + let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); + Some(LinearRegression::solve_normal_equation(X, y)) + } else { + Some(LinearRegression::solve_normal_equation(X, y)) }; + } + fn solve_normal_equation(X: ArrayBase, y: ArrayBase) -> Array1 + where + A: Data, + B: Data, + { let rhs = X.t().dot(&y); let linear_operator = X.t().dot(&X); - self.beta = Some(linear_operator.solve_into(rhs).unwrap()); + linear_operator.solve_into(rhs).unwrap() } fn predict(&self, X: &ArrayBase) -> Array1 @@ -93,7 +107,7 @@ pub fn main() { let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples); let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples); let mut linear_regressor = LinearRegression::new(false); - linear_regressor.fit(X_train.to_owned(), y_train.to_owned()); + linear_regressor.fit(X_train, y_train); let test_predictions = linear_regressor.predict(&X_test); let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap()); From ddbd5e13ace0f7b590176da23de45146abe933d3 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 23:06:21 +0100 Subject: [PATCH 07/14] Polish predict function --- examples/linear_regression.rs | 14 ++++++++++---- 1 file changed, 10 
insertions(+), 4 deletions(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 0af38429..d23af1f7 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -69,13 +69,19 @@ impl LinearRegression { let (n_samples, _) = X.dim(); // If we are fitting the intercept, we need an additional column - let X = if self.fit_intercept { + if self.fit_intercept { let dummy_column: Array = Array::ones((n_samples, 1)); - stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap() + let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); + self._predict(X) } else { - X.to_owned() - }; + self._predict(X) + } + } + fn _predict(&self, X: &ArrayBase) -> Array1 + where + A: Data, + { match &self.beta { None => panic!("The linear regression estimator has to be fitted first!"), Some(beta) => { From 208f451d6b8351209f8ded67e4965274a2989601 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 23:14:03 +0100 Subject: [PATCH 08/14] Missing reference --- examples/linear_regression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index d23af1f7..8bc455d4 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -72,7 +72,7 @@ impl LinearRegression { if self.fit_intercept { let dummy_column: Array = Array::ones((n_samples, 1)); let X = stack(Axis(1), &[dummy_column.view(), X.view()]).unwrap(); - self._predict(X) + self._predict(&X) } else { self._predict(X) } From 53f8124511c1a884e048d76bcb692e58b72458f7 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 23:14:37 +0100 Subject: [PATCH 09/14] Spacing --- examples/linear_regression.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 8bc455d4..8aaf955f 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -109,11 +109,14 @@ pub fn main() { let 
n_train_samples = 5000; let n_test_samples = 1000; let n_features = 3; + let (X, y) = get_data(n_train_samples + n_test_samples, n_features); let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples); let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples); + let mut linear_regressor = LinearRegression::new(false); linear_regressor.fit(X_train, y_train); + let test_predictions = linear_regressor.predict(&X_test); let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap()); From d4be0bbebb4b6ba46a07cb6c8ff458929baafae8 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Wed, 17 Jul 2019 23:15:07 +0100 Subject: [PATCH 10/14] Run cargo fmt --- examples/linear_regression.rs | 42 ++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index 8aaf955f..b888d1d4 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -1,8 +1,8 @@ #![allow(non_snake_case)] -use ndarray::{Array1, ArrayBase, Array2, stack, Axis, Array, Ix2, Ix1, Data}; -use ndarray_linalg::{Solve, random}; -use ndarray_stats::DeviationExt; +use ndarray::{stack, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2}; +use ndarray_linalg::{random, Solve}; use ndarray_rand::RandomExt; +use ndarray_stats::DeviationExt; use rand::distributions::StandardNormal; /// The simple linear regression model is @@ -28,14 +28,14 @@ impl LinearRegression { fn new(fit_intercept: bool) -> LinearRegression { LinearRegression { beta: None, - fit_intercept + fit_intercept, } } fn fit(&mut self, X: ArrayBase, y: ArrayBase) where - A: Data, - B: Data, + A: Data, + B: Data, { let (n_samples, _) = X.dim(); @@ -53,9 +53,9 @@ impl LinearRegression { } fn solve_normal_equation(X: ArrayBase, y: ArrayBase) -> Array1 - where - A: Data, - B: Data, + where + A: Data, + B: Data, { 
let rhs = X.t().dot(&y); let linear_operator = X.t().dot(&X); @@ -64,7 +64,7 @@ impl LinearRegression { fn predict(&self, X: &ArrayBase) -> Array1 where - A: Data, + A: Data, { let (n_samples, _) = X.dim(); @@ -79,21 +79,17 @@ impl LinearRegression { } fn _predict(&self, X: &ArrayBase) -> Array1 - where - A: Data, + where + A: Data, { match &self.beta { None => panic!("The linear regression estimator has to be fitted first!"), - Some(beta) => { - X.dot(beta) - } + Some(beta) => X.dot(beta), } } } -fn get_data(n_samples: usize, n_features: usize) -> ( - Array2, Array1 -) { +fn get_data(n_samples: usize, n_features: usize) -> (Array2, Array1) { let shape = (n_samples, n_features); let noise: Array1 = Array::random(n_samples, StandardNormal); @@ -119,6 +115,12 @@ pub fn main() { let test_predictions = linear_regressor.predict(&X_test); let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); - println!("Beta estimated from the training data: {:.3}", linear_regressor.beta.unwrap()); - println!("The fitted regressor has a root mean squared error of {:.3}", mean_squared_error); + println!( + "Beta estimated from the training data: {:.3}", + linear_regressor.beta.unwrap() + ); + println!( + "The fitted regressor has a root mean squared error of {:.3}", + mean_squared_error + ); } From 673c41d53cccf7fa507b7851ba8c192765114245 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sat, 20 Jul 2019 19:14:10 +0100 Subject: [PATCH 11/14] Fix typo in println --- examples/linear_regression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/linear_regression.rs b/examples/linear_regression.rs index b888d1d4..77f2f9b8 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression.rs @@ -120,7 +120,7 @@ pub fn main() { linear_regressor.beta.unwrap() ); println!( - "The fitted regressor has a root mean squared error of {:.3}", + "The fitted regressor has a mean squared error of {:.3}", mean_squared_error ); } From 
de20e5ff1f89104969525d9b8fa907d9a94a91b9 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sat, 20 Jul 2019 19:25:20 +0100 Subject: [PATCH 12/14] Split in two files, with proper visibility for methods on LinearRegression --- .../linear_regression.rs | 75 +++++-------------- examples/linear_regression/main.rs | 49 ++++++++++++ 2 files changed, 67 insertions(+), 57 deletions(-) rename examples/{ => linear_regression}/linear_regression.rs (59%) create mode 100644 examples/linear_regression/main.rs diff --git a/examples/linear_regression.rs b/examples/linear_regression/linear_regression.rs similarity index 59% rename from examples/linear_regression.rs rename to examples/linear_regression/linear_regression.rs index 77f2f9b8..1dda4bfe 100644 --- a/examples/linear_regression.rs +++ b/examples/linear_regression/linear_regression.rs @@ -1,9 +1,6 @@ #![allow(non_snake_case)] -use ndarray::{stack, Array, Array1, Array2, ArrayBase, Axis, Data, Ix1, Ix2}; -use ndarray_linalg::{random, Solve}; -use ndarray_rand::RandomExt; -use ndarray_stats::DeviationExt; -use rand::distributions::StandardNormal; +use ndarray::{stack, Array, Array1, ArrayBase, Axis, Data, Ix1, Ix2}; +use ndarray_linalg::Solve; /// The simple linear regression model is /// y = bX + e where e ~ N(0, sigma^2 * I) @@ -19,20 +16,20 @@ use rand::distributions::StandardNormal; /// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse. 
/// /// Adapted from: https://github.com/xinscrs/numpy-ml -struct LinearRegression { +pub struct LinearRegression { pub beta: Option>, fit_intercept: bool, } impl LinearRegression { - fn new(fit_intercept: bool) -> LinearRegression { + pub fn new(fit_intercept: bool) -> LinearRegression { LinearRegression { beta: None, fit_intercept, } } - fn fit(&mut self, X: ArrayBase, y: ArrayBase) + pub fn fit(&mut self, X: ArrayBase, y: ArrayBase) where A: Data, B: Data, @@ -52,19 +49,9 @@ impl LinearRegression { }; } - fn solve_normal_equation(X: ArrayBase, y: ArrayBase) -> Array1 - where - A: Data, - B: Data, - { - let rhs = X.t().dot(&y); - let linear_operator = X.t().dot(&X); - linear_operator.solve_into(rhs).unwrap() - } - - fn predict(&self, X: &ArrayBase) -> Array1 - where - A: Data, + pub fn predict(&self, X: &ArrayBase) -> Array1 + where + A: Data, { let (n_samples, _) = X.dim(); @@ -78,6 +65,16 @@ impl LinearRegression { } } + fn solve_normal_equation(X: ArrayBase, y: ArrayBase) -> Array1 + where + A: Data, + B: Data, + { + let rhs = X.t().dot(&y); + let linear_operator = X.t().dot(&X); + linear_operator.solve_into(rhs).unwrap() + } + fn _predict(&self, X: &ArrayBase) -> Array1 where A: Data, @@ -88,39 +85,3 @@ impl LinearRegression { } } } - -fn get_data(n_samples: usize, n_features: usize) -> (Array2, Array1) { - let shape = (n_samples, n_features); - let noise: Array1 = Array::random(n_samples, StandardNormal); - - let beta: Array1 = random(n_features) * 10.; - println!("Beta used to generate target variable: {:.3}", beta); - - let X: Array2 = random(shape); - let y: Array1 = X.dot(&beta) + noise; - (X, y) -} - -pub fn main() { - let n_train_samples = 5000; - let n_test_samples = 1000; - let n_features = 3; - - let (X, y) = get_data(n_train_samples + n_test_samples, n_features); - let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples); - let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples); - - let mut linear_regressor = 
LinearRegression::new(false); - linear_regressor.fit(X_train, y_train); - - let test_predictions = linear_regressor.predict(&X_test); - let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); - println!( - "Beta estimated from the training data: {:.3}", - linear_regressor.beta.unwrap() - ); - println!( - "The fitted regressor has a mean squared error of {:.3}", - mean_squared_error - ); -} diff --git a/examples/linear_regression/main.rs b/examples/linear_regression/main.rs new file mode 100644 index 00000000..4e76edb5 --- /dev/null +++ b/examples/linear_regression/main.rs @@ -0,0 +1,49 @@ +#![allow(non_snake_case)] +use ndarray::{Array1, Array2, Array, Axis}; +use ndarray_linalg::random; +use ndarray_stats::DeviationExt; +use ndarray_rand::RandomExt; +use rand::distributions::StandardNormal; + +// Import LinearRegression from other file ("module") in this example +mod linear_regression; +use linear_regression::LinearRegression; + +/// It returns a tuple: input data and the associated target variable. +/// +/// The target variable is a linear function of the input, perturbed by gaussian noise. 
+fn get_data(n_samples: usize, n_features: usize) -> (Array2, Array1) { + let shape = (n_samples, n_features); + let noise: Array1 = Array::random(n_samples, StandardNormal); + + let beta: Array1 = random(n_features) * 10.; + println!("Beta used to generate target variable: {:.3}", beta); + + let X: Array2 = random(shape); + let y: Array1 = X.dot(&beta) + noise; + (X, y) +} + +pub fn main() { + let n_train_samples = 5000; + let n_test_samples = 1000; + let n_features = 3; + + let (X, y) = get_data(n_train_samples + n_test_samples, n_features); + let (X_train, X_test) = X.view().split_at(Axis(0), n_train_samples); + let (y_train, y_test) = y.view().split_at(Axis(0), n_train_samples); + + let mut linear_regressor = LinearRegression::new(false); + linear_regressor.fit(X_train, y_train); + + let test_predictions = linear_regressor.predict(&X_test); + let mean_squared_error = test_predictions.mean_sq_err(&y_test.to_owned()).unwrap(); + println!( + "Beta estimated from the training data: {:.3}", + linear_regressor.beta.unwrap() + ); + println!( + "The fitted regressor has a mean squared error of {:.3}", + mean_squared_error + ); +} From cdd3c5d975ff0a1f7afe523c6a7ced13d485748d Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sat, 20 Jul 2019 19:29:50 +0100 Subject: [PATCH 13/14] Add docs. --- examples/linear_regression/linear_regression.rs | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/examples/linear_regression/linear_regression.rs b/examples/linear_regression/linear_regression.rs index 1dda4bfe..8f3cbd8d 100644 --- a/examples/linear_regression/linear_regression.rs +++ b/examples/linear_regression/linear_regression.rs @@ -29,6 +29,13 @@ impl LinearRegression { } } + /// Given: + /// - an input matrix `X`, with shape `(n_samples, n_features)`; + /// - a target variable `y`, with shape `(n_samples,)`; + /// `fit` tunes the `beta` parameter of the linear regression model + /// to match the training data distribution. 
+ /// + /// `self` is modified in place, nothing is returned. pub fn fit(&mut self, X: ArrayBase, y: ArrayBase) where A: Data, B: Data, @@ -49,6 +56,11 @@ impl LinearRegression { }; } + /// Given an input matrix `X`, with shape `(n_samples, n_features)`, + /// `predict` returns the target variable according to the linear model + /// learned from the training data distribution. + /// + /// **Panics** if `self` has not been `fit`ted before calling `predict`. pub fn predict(&self, X: &ArrayBase) -> Array1 where A: Data, From 85e11f8be9ede442cc576c08bf0ff726a1730597 Mon Sep 17 00:00:00 2001 From: LukeMathWalker Date: Sat, 20 Jul 2019 19:30:38 +0100 Subject: [PATCH 14/14] Quote the original repo, not forks --- examples/linear_regression/linear_regression.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/linear_regression/linear_regression.rs b/examples/linear_regression/linear_regression.rs index 8f3cbd8d..f11956be 100644 --- a/examples/linear_regression/linear_regression.rs +++ b/examples/linear_regression/linear_regression.rs @@ -15,7 +15,7 @@ use ndarray_linalg::Solve; /// b = (X^T X)^{-1} X^T y /// where (X^T X)^{-1} X^T is known as the pseudoinverse or Moore-Penrose inverse. /// -/// Adapted from: https://github.com/xinscrs/numpy-ml +/// Adapted from: https://github.com/ddbourgin/numpy-ml pub struct LinearRegression { pub beta: Option>, fit_intercept: bool, }