diff --git a/sgkit/stats/preprocessing.py b/sgkit/stats/preprocessing.py index 195cada68..b7b993f31 100644 --- a/sgkit/stats/preprocessing.py +++ b/sgkit/stats/preprocessing.py @@ -62,8 +62,8 @@ def fit( Alternate allele counts with missing values encoded as either nan or negative numbers. """ - X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0)) - self.mean_ = da.ma.filled(da.mean(X, axis=0), fill_value=np.nan) + X = _replace_missing_with_nan(X) + self.mean_ = da.nanmean(X, axis=0) p = self.mean_ / self.ploidy self.scale_ = da.sqrt(p * (1 - p)) self.n_features_in_ = X.shape[1] @@ -90,10 +90,10 @@ def transform( Alternate allele counts with missing values encoded as either nan or negative numbers. """ - X = da.ma.masked_array(X, mask=da.isnan(X) | (X < 0)) + X = _replace_missing_with_nan(X) X -= self.mean_ X /= self.scale_ - return da.ma.filled(X, fill_value=np.nan) + return X def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayLike: """Invert transform @@ -109,6 +109,14 @@ def inverse_transform(self, X: ArrayLike, copy: Optional[bool] = None) -> ArrayL return X +def _replace_missing_with_nan(X): + if np.issubdtype(X.dtype, np.floating): + nanarray = da.asarray(np.nan, dtype=X.dtype) + else: + nanarray = da.asarray(np.nan) + return da.where(X < 0, nanarray, X) + + def filter_partial_calls( ds: Dataset, *,