diff --git a/presto/dataops/utils.py b/presto/dataops/utils.py index ce5991f..3c3aa6e 100644 --- a/presto/dataops/utils.py +++ b/presto/dataops/utils.py @@ -6,6 +6,7 @@ from .pipelines.s1_s2_era5_srtm import ( BANDS, ERA5_BANDS, + NORMED_BANDS, REMOVED_BANDS, S1_BANDS, S1_S2_ERA5_SRTM, @@ -60,8 +61,15 @@ def construct_single_presto_input( if dynamic_world is None: dynamic_world = torch.ones(num_timesteps) * (DynamicWorld2020_2021.class_amount) + keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"] + mask = mask[:, keep_indices] + if normalize: - keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"] - mask = mask[:, keep_indices] + # normalize includes x = x[:, keep_indices] x = S1_S2_ERA5_SRTM.normalize(x) + if s2_bands is not None: + if ("B8" in s2_bands) and ("B4" in s2_bands): + mask[:, NORMED_BANDS.index("NDVI")] = 0 + else: + x = x[:, keep_indices] return x, mask, dynamic_world diff --git a/tests/test_dataops_utils.py b/tests/test_dataops_utils.py index 3a06ae9..2975fdc 100644 --- a/tests/test_dataops_utils.py +++ b/tests/test_dataops_utils.py @@ -4,15 +4,39 @@ from presto import construct_single_presto_input from presto.dataops.pipelines.dynamicworld import DynamicWorld2020_2021 +from presto.dataops.pipelines.s1_s2_era5_srtm import NORMED_BANDS class TestDatopsUtils(TestCase): def test_construct_single_presto_input(self): + input_bands = ["B2", "B3", "B4", "B8"] x, mask, dw = construct_single_presto_input( - s2=torch.ones(2, 3), s2_bands=["B2", "B3", "B4"], normalize=False + s2=torch.ones(2, 4), s2_bands=input_bands, normalize=False ) self.assertTrue(torch.equal(dw, torch.ones_like(dw) * DynamicWorld2020_2021.class_amount)) self.assertEqual(len(dw), x.shape[0]) self.assertEqual(x.shape, mask.shape) self.assertTrue((x[mask == 1] == 0).all()) self.assertTrue((x[mask == 0] != 0).all()) + for idx, band in enumerate(NORMED_BANDS): + if band in input_bands: + self.assertTrue((mask[:, idx] == 0).all()) + else: + self.assertTrue((mask[:, idx] == 1).all()) + + def test_construct_single_presto_input_ndvi(self): + input_bands = ["B2", "B3", "B4", "B8"] + x, mask, dw = construct_single_presto_input( + s2=torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]).float(), + s2_bands=input_bands, + normalize=True, + ) + self.assertTrue(torch.equal(dw, torch.ones_like(dw) * DynamicWorld2020_2021.class_amount)) + self.assertEqual(len(dw), x.shape[0]) + self.assertEqual(x.shape, mask.shape) + # we can't test for equality to 0 since we normalize; + # that's tested above + self.assertTrue((x[mask == 0] != 0).all()) + for idx, band in enumerate(NORMED_BANDS): + if band in input_bands + ["NDVI"]: + self.assertTrue((mask[:, idx] == 0).all())