Skip to content

Commit

Permalink
Merge pull request #18 from nasaharvest/ndvi-in-construct-single-pres…
Browse files Browse the repository at this point in the history
…to-input

Update construct_single_presto_input to unmask NDVI if its calculated
  • Loading branch information
gabrieltseng committed Oct 13, 2023
2 parents d498ba6 + 46e314d commit 613265e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 3 deletions.
12 changes: 10 additions & 2 deletions presto/dataops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .pipelines.s1_s2_era5_srtm import (
BANDS,
ERA5_BANDS,
NORMED_BANDS,
REMOVED_BANDS,
S1_BANDS,
S1_S2_ERA5_SRTM,
Expand Down Expand Up @@ -60,8 +61,15 @@ def construct_single_presto_input(
if dynamic_world is None:
dynamic_world = torch.ones(num_timesteps) * (DynamicWorld2020_2021.class_amount)

keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"]
mask = mask[:, keep_indices]

if normalize:
keep_indices = [idx for idx, val in enumerate(BANDS) if val != "B9"]
mask = mask[:, keep_indices]
# normalize includes x = x[:, keep_indices]
x = S1_S2_ERA5_SRTM.normalize(x)
if s2_bands is not None:
if ("B8" in s2_bands) and ("B4" in s2_bands):
mask[:, NORMED_BANDS.index("NDVI")] = 0
else:
x = x[:, keep_indices]
return x, mask, dynamic_world
26 changes: 25 additions & 1 deletion tests/test_dataops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,39 @@

from presto import construct_single_presto_input
from presto.dataops.pipelines.dynamicworld import DynamicWorld2020_2021
from presto.dataops.pipelines.s1_s2_era5_srtm import NORMED_BANDS


class TestDatopsUtils(TestCase):
def test_construct_single_presto_input(self):
input_bands = ["B2", "B3", "B4", "B8"]
x, mask, dw = construct_single_presto_input(
s2=torch.ones(2, 3), s2_bands=["B2", "B3", "B4"], normalize=False
s2=torch.ones(2, 4), s2_bands=input_bands, normalize=False
)
self.assertTrue(torch.equal(dw, torch.ones_like(dw) * DynamicWorld2020_2021.class_amount))
self.assertEqual(len(dw), x.shape[0])
self.assertEqual(x.shape, mask.shape)
self.assertTrue((x[mask == 1] == 0).all())
self.assertTrue((x[mask == 0] != 0).all())
for idx, band in enumerate(NORMED_BANDS):
if band in input_bands:
self.assertTrue((mask[:, idx] == 0).all())
else:
self.assertTrue((mask[:, idx] == 1).all())

def test_construct_single_presto_input_ndvi(self):
input_bands = ["B2", "B3", "B4", "B8"]
x, mask, dw = construct_single_presto_input(
s2=torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]).float(),
s2_bands=input_bands,
normalize=True,
)
self.assertTrue(torch.equal(dw, torch.ones_like(dw) * DynamicWorld2020_2021.class_amount))
self.assertEqual(len(dw), x.shape[0])
self.assertEqual(x.shape, mask.shape)
# we can't test for equality to 0 since we normalize;
# that's tested above
self.assertTrue((x[mask == 0] != 0).all())
for idx, band in enumerate(NORMED_BANDS):
if band in input_bands + ["NDVI"]:
self.assertTrue((mask[:, idx] == 0).all())

0 comments on commit 613265e

Please sign in to comment.