From 3a39bf56dfb80b0a4a09391d15f3854b9ea2a42e Mon Sep 17 00:00:00 2001 From: Mathias Hauser Date: Tue, 19 Dec 2023 20:31:51 +0100 Subject: [PATCH] lin reg: only warn on superfluous predictors --- mesmer/stats/_linear_regression.py | 3 ++- tests/unit/test_linear_regression.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/mesmer/stats/_linear_regression.py b/mesmer/stats/_linear_regression.py index 4f84d4a5..4396e0bf 100644 --- a/mesmer/stats/_linear_regression.py +++ b/mesmer/stats/_linear_regression.py @@ -1,3 +1,4 @@ +import warnings from typing import Mapping, Optional import numpy as np @@ -88,7 +89,7 @@ def predict( if available_predictors - required_predictors: superfluous = sorted(available_predictors - required_predictors) superfluous = "', '".join(superfluous) - raise ValueError(f"Superfluous predictors: '{superfluous}'") + warnings.warn(f"Superfluous predictors: '{superfluous}'") if "intercept" in exclude: prediction = xr.zeros_like(params.intercept) diff --git a/tests/unit/test_linear_regression.py b/tests/unit/test_linear_regression.py index 8ab0d30e..c54584dd 100644 --- a/tests/unit/test_linear_regression.py +++ b/tests/unit/test_linear_regression.py @@ -110,17 +110,25 @@ def test_lr_predict_missing_superfluous(): ) lr.params = params + da = xr.DataArray([0, 1, 2], dims="time") + with pytest.raises(ValueError, match="Missing predictors: 'tas', 'tas2'"): lr.predict({}) with pytest.raises(ValueError, match="Missing predictors: 'tas'"): lr.predict({"tas2": None}) - with pytest.raises(ValueError, match="Superfluous predictors: 'something else'"): - lr.predict({"tas": None, "tas2": None, "something else": None}) + with pytest.warns(UserWarning, match="Superfluous predictors: 'something else'"): + result = lr.predict({"tas": da, "tas2": da, "something else": None}) + + expected = xr.DataArray([[5, 9, 13]], dims=("x", "time")) + xr.testing.assert_equal(result, expected) - with pytest.raises(ValueError, match="Superfluous predictors: 'bar', 'foo'"): - lr.predict({"tas": None, "tas2": None, "foo": None, "bar": None}) + with pytest.warns(UserWarning, match="Superfluous predictors: 'bar', 'foo'"): + result = lr.predict({"tas": da, "tas2": da, "foo": None, "bar": None}) + + expected = xr.DataArray([[5, 9, 13]], dims=("x", "time")) + xr.testing.assert_equal(result, expected) @pytest.mark.parametrize("as_2D", [True, False])