Skip to content

Commit

Permalink
enh: keep target dtype in output (#865)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Sep 12, 2024
1 parent 1df8480 commit 9cf7424
Show file tree
Hide file tree
Showing 6 changed files with 937 additions and 852 deletions.
22 changes: 11 additions & 11 deletions nbs/src/core/core.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@
" def _output_fcst(self, models, attr, h, X, level=tuple()):\n",
" #returns empty output according to method\n",
" cuts, has_level_models = self._get_cols(models=models, attr=attr, h=h, X=X, level=level)\n",
" out = np.full((self.n_groups * h, cuts[-1]), fill_value=np.nan, dtype=np.float32)\n",
" out = np.full((self.n_groups * h, cuts[-1]), fill_value=np.nan, dtype=self.data.dtype)\n",
" return out, cuts, has_level_models\n",
"\n",
" def predict(self, fm, h, X=None, level=tuple()):\n",
Expand Down Expand Up @@ -292,7 +292,7 @@
" if fitted:\n",
" #for the moment we dont return levels for fitted values in \n",
" #forecast mode\n",
" fitted_vals = np.full((self.data.shape[0], 1 + cuts[-1]), np.nan, dtype=np.float32)\n",
" fitted_vals = np.full((self.data.shape[0], 1 + cuts[-1]), np.nan, dtype=self.data.dtype)\n",
" if self.data.ndim == 1:\n",
" fitted_vals[:, 0] = self.data\n",
" else:\n",
Expand Down Expand Up @@ -367,9 +367,9 @@
" n_models = len(models)\n",
" cuts, has_level_models = self._get_cols(models=models, attr='forecast', h=h, X=None, level=level)\n",
" # first column of out is the actual y\n",
" out = np.full((self.n_groups, n_windows, h, 1 + cuts[-1]), np.nan, dtype=np.float32)\n",
" out = np.full((self.n_groups, n_windows, h, 1 + cuts[-1]), np.nan, dtype=self.data.dtype)\n",
" if fitted:\n",
" fitted_vals = np.full((self.data.shape[0], n_windows, n_models + 1), np.nan, dtype=np.float32)\n",
" fitted_vals = np.full((self.data.shape[0], n_windows, n_models + 1), np.nan, dtype=self.data.dtype)\n",
" fitted_idxs = np.full((self.data.shape[0], n_windows), False, dtype=bool)\n",
" last_fitted_idxs = np.full_like(fitted_idxs, False, dtype=bool)\n",
" matches = ['mean', 'lo', 'hi']\n",
Expand Down Expand Up @@ -554,7 +554,7 @@
" \n",
" def fit(self, y, X):\n",
" self.last_value = y[-1]\n",
" self.fitted_values = np.full(y.size, np.nan, np.float32)\n",
" self.fitted_values = np.full(y.size, np.nan, dtype=y.dtype)\n",
" self.fitted_values[1:] = y[:1]\n",
" return self\n",
" \n",
Expand All @@ -574,7 +574,7 @@
" mean = y[-1] + np.arange(1, h + 1)\n",
" res = {'mean': mean}\n",
" if fitted:\n",
" fitted_values = np.full(y.size, np.nan, np.float32)\n",
" fitted_values = np.full(y.size, np.nan, dtype=y.dtype)\n",
" fitted_values[1:] = y[1:]\n",
" res['fitted'] = fitted_values\n",
" if level is not None:\n",
Expand All @@ -588,7 +588,7 @@
" mean = self.last_value + np.arange(1, h + 1)\n",
" res = {'mean': mean}\n",
" if fitted:\n",
" fitted_values = np.full(y.size, np.nan, np.float32)\n",
" fitted_values = np.full(y.size, np.nan, dtype=mean.dtype)\n",
" fitted_values[1:] = y[1:]\n",
" res['fitted'] = fitted_values\n",
" if level is not None:\n",
Expand Down Expand Up @@ -1736,7 +1736,7 @@
"\n",
" def _forecast_parallel(self, h, fitted, X, level, target_col):\n",
" n_series = self.ga.n_groups\n",
" forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=np.float32))\n",
" forecast_res = defaultdict(lambda: np.empty(n_series * h, dtype=self.ga.data.dtype))\n",
" fitted_res = defaultdict(\n",
" lambda: np.empty(self.ga.data.shape[0], dtype=self.ga.data.dtype)\n",
" )\n",
Expand Down Expand Up @@ -3029,7 +3029,7 @@
" test_eq(pd.to_datetime(series['ds']), fitted['ds'])\n",
" else:\n",
" test_eq(series['ds'], fitted['ds'])\n",
" test_eq(series['y'].astype(np.float32), fitted['y'])\n",
" test_eq(series['y'], fitted['y'])\n",
"test_fcst_fitted(series)\n",
"test_fcst_fitted(series, str_ds=True)"
]
Expand All @@ -3053,7 +3053,7 @@
" fitted_res = fitted_fcst.forecast(df=series, h=14, fitted=True)\n",
" fitted = fitted_fcst.forecast_fitted_values()\n",
" test_eq(series['ds'], fitted['ds'])\n",
" test_eq(series['y'].astype(np.float32), fitted['y'])\n",
" test_eq(series['y'], fitted['y'])\n",
" # test NullModel actualy fails\n",
" fitted_fcst = StatsForecast(\n",
" models=[NullModel()],\n",
Expand Down Expand Up @@ -3870,7 +3870,7 @@
" 'unique_id': [0] * 10 + [1] * 10,\n",
" 'ds': np.hstack([np.arange(10), np.arange(10)]),\n",
" 'y': np.random.rand(20),\n",
" 'x': np.arange(20, dtype=np.float32),\n",
" 'x': np.arange(20, dtype=np.float64),\n",
" }\n",
")\n",
"train_mask = df['ds'] < 6\n",
Expand Down
Loading

0 comments on commit 9cf7424

Please sign in to comment.