Skip to content

Commit

Permalink
Merge pull request #25 from IBM/fix/swin_instantiation
Browse files Browse the repository at this point in the history
Fix/swin instantiation
  • Loading branch information
CarlosGomes98 authored Jun 28, 2024
2 parents da82a3a + 313b9f9 commit 4cbf229
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 51 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ dependencies = [
"lightly>=1.4.25",
"h5py>=3.10.0",
"geobench>=1.0.0",
"mlflow>=2.12.1"
"mlflow>=2.12.1",
"lightning<=2.2.5"
]

[project.optional-dependencies]
Expand Down
7 changes: 2 additions & 5 deletions src/terratorch/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
# import so they get registered
from terratorch.models.backbones.prithvi_vit import TemporalViTEncoder

__all__ = ["TemporalViTEncoder"]
__all__ = ["TemporalViTEncoder"]
__all__ = ["TemporalViTEncoder"]
import terratorch.models.backbones.prithvi_vit
import terratorch.models.backbones.prithvi_swin
40 changes: 0 additions & 40 deletions src/terratorch/models/backbones/prithvi_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,6 @@ def _cfg(file: Path = "", **kwargs) -> dict:
**kwargs,
}

default_cfgs = generate_default_cfgs(
{
"prithvi_swin_90_us": {
"hf_hub_id": "ibm-nasa-geospatial/Prithvi-100M",
"hf_hub_filename": "Prithvi_100M.pt"
}
}
)

def convert_weights_swin2mmseg(ckpt):
# from https://github.com/open-mmlab/mmsegmentation/blob/main/tools/model_converters/swin2mmseg.py
new_ckpt = OrderedDict()
Expand Down Expand Up @@ -215,37 +206,6 @@ def prepare_features_for_image_model(x):
return model


@register_model
def prithvi_swin_90_us(
pretrained: bool = False, # noqa: FBT002, FBT001
pretrained_bands: list[HLSBands] | None = None,
bands: list[int] | None = None,
**kwargs,
) -> MMSegSwinTransformer:
"""Prithvi Swin 90M"""
if pretrained_bands is None:
pretrained_bands = PRETRAINED_BANDS
if bands is None:
bands = pretrained_bands
logging.info(
f"Model bands not passed. Assuming bands are ordered in the same way as {PRETRAINED_BANDS}.\
Pretrained patch_embed layer may be misaligned with current bands"
)

model_args = {
"patch_size": 4,
"window_size": 7,
"embed_dim": 128,
"depths": (2, 2, 18, 2),
"in_chans": 6,
"num_heads": (4, 8, 16, 32),
}
transformer = _create_swin_mmseg_transformer(
"prithvi_swin_90_us", pretrained_bands, bands, pretrained=pretrained, **dict(model_args, **kwargs)
)
return transformer


@register_model
def prithvi_swin_B(
pretrained: bool = False, # noqa: FBT002, FBT001
Expand Down
4 changes: 2 additions & 2 deletions tests/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ def input_386():
return torch.ones((1, NUM_CHANNELS, 386, 386))


@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"]) #["prithvi_swin_90_us", "prithvi_vit_100", "prithvi_vit_300"])
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False)
input_tensor = request.getfixturevalue(test_input)
backbone(input_tensor)


@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300"])
@pytest.mark.parametrize("model_name", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
@pytest.mark.parametrize("test_input", ["input_224", "input_512"])
def test_can_create_backbones_from_timm_features_only(model_name, test_input, request):
backbone = timm.create_model(model_name, pretrained=False, features_only=True)
Expand Down
6 changes: 3 additions & 3 deletions tests/test_prithvi_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def model_input() -> torch.Tensor:
return torch.ones((1, NUM_CHANNELS, 224, 224))


@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
@pytest.mark.parametrize("backbone",["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("loss", ["ce", "jaccard", "focal", "dice"])
def test_create_segmentation_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
Expand All @@ -38,7 +38,7 @@ def test_create_segmentation_task(backbone, decoder, loss, model_factory: Prithv
)


@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("loss", ["mae", "rmse", "huber"])
def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
Expand All @@ -55,7 +55,7 @@ def test_create_regression_task(backbone, decoder, loss, model_factory: PrithviM
)


@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300"])
@pytest.mark.parametrize("backbone", ["prithvi_vit_100", "prithvi_vit_300", "prithvi_swin_B"])
@pytest.mark.parametrize("decoder", ["FCNDecoder", "UperNetDecoder", "IdentityDecoder"])
@pytest.mark.parametrize("loss", ["ce", "bce", "jaccard", "focal"])
def test_create_classification_task(backbone, decoder, loss, model_factory: PrithviModelFactory):
Expand Down

0 comments on commit 4cbf229

Please sign in to comment.