From 8028ecf6a983914c9697959979b49c718fc3f70f Mon Sep 17 00:00:00 2001
From: Jean-Guillaume Durand
Date: Mon, 12 Aug 2024 12:43:58 -0700
Subject: [PATCH] Added packaging files

---
 satclip/.gitignore => .gitignore |  4 +++-
 pyproject.toml                   | 39 +++++++++++++++++++++++++++++++++++++++
 satclip/load.py                  |  4 ++--
 satclip/load_lightweight.py      | 10 +++++-----
 satclip/location_encoder.py      | 10 +++++-----
 satclip/main.py                  |  7 ++++---
 satclip/model.py                 | 30 ++++++++++++++----------------
 7 files changed, 72 insertions(+), 32 deletions(-)
 rename satclip/.gitignore => .gitignore (55%)
 create mode 100644 pyproject.toml

diff --git a/satclip/.gitignore b/.gitignore
similarity index 55%
rename from satclip/.gitignore
rename to .gitignore
index 52804a2..e3a142f 100644
--- a/satclip/.gitignore
+++ b/.gitignore
@@ -2,4 +2,6 @@
 satclip_logs/
 
 # Python
-__pycache__/
\ No newline at end of file
+__pycache__/
+*.egg-info/
+build/
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..5f7bbef
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,39 @@
+[build-system]
+requires = ["setuptools >= 61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "satclip"
+description = "A global, general-purpose geographic location encoder"
+version = "0.0.1"
+authors = [
+    {name="Konstantin Klemmer"},
+    {name="Esther Rolf"},
+    {name="Caleb Robinson"},
+    {name="Lester Mackey"},
+    {name="Marc Rußwurm"},
+]
+readme = "README.md"
+license = {file = "LICENSE"}
+requires-python = ">=3.9"
+classifiers = [
+    "Programming Language :: Python :: 3",
+    "Operating System :: OS Independent",
+    "License :: OSI Approved :: MIT License",
+]
+
+dependencies = [
+    "albumentations",
+    "lightning == 2.2.2",
+    "pandas",
+    "rasterio >= 1.3.10",
+    "torchgeo >= 0.5",  # Forces Python 3.9+
+]
+
+[project.urls]
+Homepage = "https://github.com/microsoft/satclip"
+Repository = "https://github.com/microsoft/satclip.git"
+Issues = "https://github.com/microsoft/satclip/issues"
+
+[tool.setuptools.packages.find]
+include = ["satclip", "satclip.*"]
diff --git a/satclip/load.py b/satclip/load.py
index ac975de..5c8448f 100644
--- a/satclip/load.py
+++ b/satclip/load.py
@@ -1,4 +1,4 @@
-from main import *
+from .main import *
 
 def get_satclip(ckpt_path, device, return_all=False):
     ckpt = torch.load(ckpt_path,map_location=device)
@@ -15,4 +15,4 @@ def get_satclip(ckpt_path, device, return_all=False):
     if return_all:
         return geo_model
     else:
-        return geo_model.location
\ No newline at end of file
+        return geo_model.location
diff --git a/satclip/load_lightweight.py b/satclip/load_lightweight.py
index 6dfc828..2ec47d3 100644
--- a/satclip/load_lightweight.py
+++ b/satclip/load_lightweight.py
@@ -1,5 +1,6 @@
 import torch
-from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder
+
+from .location_encoder import get_neural_network, get_positional_encoding, LocationEncoder
 
 
 def get_satclip_loc_encoder(ckpt_path, device):
@@ -14,7 +15,7 @@ def get_satclip_loc_encoder(ckpt_path, device):
         hp['max_radius'],
         hp['frequency_num']
     )
-    
+
     nnet = get_neural_network(
         hp['pe_type'],
         posenc.embedding_dim,
@@ -25,12 +26,11 @@
 
     # only load nnet params from state dict
     state_dict = ckpt['state_dict']
-    state_dict = {k[k.index('nnet'):]:state_dict[k] 
+    state_dict = {k[k.index('nnet'):]:state_dict[k]
                   for k in state_dict.keys() if 'nnet' in k}
-    
+
     loc_encoder = LocationEncoder(posenc, nnet).double()
     loc_encoder.load_state_dict(state_dict)
     loc_encoder.eval()
 
     return loc_encoder
-    
\ No newline at end of file
diff --git a/satclip/location_encoder.py b/satclip/location_encoder.py
index 7bf1617..1911ac9 100644
--- a/satclip/location_encoder.py
+++ b/satclip/location_encoder.py
@@ -1,11 +1,11 @@
-from torch import nn, optim
 import math
+
 import torch
 import torch.nn.functional as F
 from einops import rearrange
-import numpy as np
-from datetime import datetime
-import positional_encoding as PE
+from torch import nn
+
+from . import positional_encoding as PE
 
 """
 FCNet
@@ -110,7 +110,7 @@ def forward(self, x, mods = None):
             x *= rearrange(mod, 'd -> () d')
 
         return self.last_layer(x)
-    
+
 class Sine(nn.Module):
     def __init__(self, w0 = 1.):
         super().__init__()
diff --git a/satclip/main.py b/satclip/main.py
index 03d1daf..a5134ab 100644
--- a/satclip/main.py
+++ b/satclip/main.py
@@ -3,10 +3,11 @@
 
 import lightning.pytorch
 import torch
-from datamodules.s2geo_dataset import S2GeoDataModule
 from lightning.pytorch.cli import LightningCLI
-from loss import SatCLIPLoss
-from model import SatCLIP
+
+from .datamodules.s2geo_dataset import S2GeoDataModule
+from .loss import SatCLIPLoss
+from .model import SatCLIP
 
 torch.set_float32_matmul_precision('high')
 
diff --git a/satclip/model.py b/satclip/model.py
index b75a129..26b400f 100644
--- a/satclip/model.py
+++ b/satclip/model.py
@@ -1,17 +1,15 @@
 from collections import OrderedDict
-from typing import Tuple, Union, Optional
+from typing import Tuple, Union
 
 import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-import math
 
 import timm
-import torchgeo.models
 from torchgeo.models import ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights
-from location_encoder import get_positional_encoding, get_neural_network, LocationEncoder
-from datamodules.s2geo_dataset import S2Geo
+
+from .location_encoder import get_positional_encoding, get_neural_network, LocationEncoder
 
 class Bottleneck(nn.Module):
     expansion = 4
@@ -257,12 +255,12 @@ def __init__(self,
                  # location
                  le_type: str,
                  pe_type: str,
-                 frequency_num: int, 
-                 max_radius: int, 
+                 frequency_num: int,
+                 max_radius: int,
                  min_radius: int,
                  harmonics_calculation: str,
-                 legendre_polys: int=10, 
-                 sh_embedding_dims: int=16, 
+                 legendre_polys: int=10,
+                 sh_embedding_dims: int=16,
                  ffn: bool=True,
                  num_hidden_layers: int=2,
                  capacity: int=256,
@@ -270,7 +268,7 @@
                  **kwargs
                  ):
         super().__init__()
-        
+
         if isinstance(vision_layers, (tuple, list)):
             print('using modified resnet')
             vision_heads = vision_width * 32 // 64
@@ -282,7 +280,7 @@
                 width=vision_width,
                 in_channels=in_channels
             )
-        
+
         elif vision_layers == 'moco_resnet18':
             print('using pretrained moco resnet18')
             weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
@@ -300,7 +298,7 @@
             self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
             self.visual.requires_grad_(False)
             self.visual.fc.requires_grad_(True)
-        
+
         elif vision_layers == 'moco_vit16':
             print('using pretrained moco vit16')
             weights = ViTSmall16_Weights.SENTINEL2_ALL_MOCO
@@ -322,13 +320,13 @@
                 output_dim=embed_dim,
                 in_channels=in_channels
             )
-        
+
         self.posenc = get_positional_encoding(name=le_type, harmonics_calculation=harmonics_calculation, legendre_polys=legendre_polys, min_radius=min_radius, max_radius=max_radius, frequency_num=frequency_num).double()
         self.nnet = get_neural_network(name=pe_type, input_dim=self.posenc.embedding_dim, num_classes=embed_dim, dim_hidden=capacity, num_layers=num_hidden_layers).double()
         self.location = LocationEncoder(self.posenc,
                                         self.nnet
                                         ).double()
-        
+
         self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
 
         self.initialize_parameters()
@@ -362,7 +360,7 @@ def encode_location(self, coords):
 
     def forward(self, image, coords):
-        image_features = self.encode_image(image) 
+        image_features = self.encode_image(image)
         location_features = self.encode_location(coords).float()
 
         # normalized features
         image_features = image_features / image_features.norm(dim=1, keepdim=True)
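
With pyproject.toml in place, the project installs as a regular package from the
repository root (pip install ., or pip install -e . for development), and the
switch to relative imports lets the satclip modules resolve once installed. A
minimal usage sketch; the checkpoint filename below is a placeholder, since
pretrained SatCLIP checkpoints are distributed separately:

    import torch
    from satclip.load import get_satclip

    device = "cpu"

    # Placeholder path: point this at a downloaded SatCLIP checkpoint.
    # With the default return_all=False, only the location encoder is returned.
    loc_encoder = get_satclip("satclip.ckpt", device)

    # Batch of (lon, lat) coordinates; the location encoder runs in double precision.
    coords = torch.tensor([[-122.42, 37.77], [13.40, 52.52]], dtype=torch.float64)
    with torch.no_grad():
        embeddings = loc_encoder(coords)
    print(embeddings.shape)  # (2, embed_dim)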
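
The lightweight loader skips the vision tower entirely: it rebuilds the
positional encoding and the location network from the checkpoint's stored
hyperparameters and loads only the 'nnet' weights. A sketch under the same
placeholder-checkpoint assumption:

    import torch
    from satclip.load_lightweight import get_satclip_loc_encoder

    # Placeholder path: only the 'nnet' entries of the state dict are read,
    # so no torchgeo vision weights are downloaded.
    loc_encoder = get_satclip_loc_encoder("satclip.ckpt", "cpu")

    coords = torch.tensor([[151.21, -33.87]], dtype=torch.float64)  # lon, lat
    with torch.no_grad():
        emb = loc_encoder(coords)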