Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Create Python package #19

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion satclip/.gitignore → .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@
satclip_logs/

# Python
__pycache__/
__pycache__/
*.egg-info/
build/
39 changes: 39 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
[build-system]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is the main change in this PR.

requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "satclip"
description = "A global, general-purpose geographic location encoder"
version = "0.0.1"
authors = [
{name="Konstantin Klemmer"},
{name="Esther Rolf"},
{name="Caleb Robinson"},
{name="Lester Mackey"},
{name="Marc Rußwurm"},
]
readme = "README.md"
license = {file = "LICENSE"}
requires-python = ">=3.9"
classifiers = [
"Programming Language :: Python :: 3",
"Operating System :: OS Independent",
"License :: OSI Approved :: MIT License",
]

dependencies = [
"albumentations",
"lightning == 2.2.2",
"pandas",
"rasterio >= 1.3.10",
"torchgeo >= 0.5", # Forces Python 3.9+
]

[project.urls]
Homepage = "https://github.com/microsoft/satclip"
Repository = "https://github.com/microsoft/satclip.git"
Issues = "https://github.com/microsoft/satclip/issues"

[tool.setuptools.packages.find]
include = ["satclip", "satclip.*"]
4 changes: 2 additions & 2 deletions satclip/load.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from main import *
from .main import *

def get_satclip(ckpt_path, device, return_all=False):
ckpt = torch.load(ckpt_path,map_location=device)
Expand All @@ -15,4 +15,4 @@ def get_satclip(ckpt_path, device, return_all=False):
if return_all:
return geo_model
else:
return geo_model.location
return geo_model.location
10 changes: 5 additions & 5 deletions satclip/load_lightweight.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from location_encoder import get_neural_network, get_positional_encoding, LocationEncoder

from .location_encoder import get_neural_network, get_positional_encoding, LocationEncoder


def get_satclip_loc_encoder(ckpt_path, device):
Expand All @@ -14,7 +15,7 @@ def get_satclip_loc_encoder(ckpt_path, device):
hp['max_radius'],
hp['frequency_num']
)

nnet = get_neural_network(
hp['pe_type'],
posenc.embedding_dim,
Expand All @@ -25,12 +26,11 @@ def get_satclip_loc_encoder(ckpt_path, device):

# only load nnet params from state dict
state_dict = ckpt['state_dict']
state_dict = {k[k.index('nnet'):]:state_dict[k]
state_dict = {k[k.index('nnet'):]:state_dict[k]
for k in state_dict.keys() if 'nnet' in k}

loc_encoder = LocationEncoder(posenc, nnet).double()
loc_encoder.load_state_dict(state_dict)
loc_encoder.eval()

return loc_encoder

10 changes: 5 additions & 5 deletions satclip/location_encoder.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from torch import nn, optim
import math

import torch
import torch.nn.functional as F
from einops import rearrange
import numpy as np
from datetime import datetime
import positional_encoding as PE
from torch import nn

from . import positional_encoding as PE

"""
FCNet
Expand Down Expand Up @@ -110,7 +110,7 @@ def forward(self, x, mods = None):
x *= rearrange(mod, 'd -> () d')

return self.last_layer(x)

class Sine(nn.Module):
def __init__(self, w0 = 1.):
super().__init__()
Expand Down
7 changes: 4 additions & 3 deletions satclip/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

import lightning.pytorch
import torch
from datamodules.s2geo_dataset import S2GeoDataModule
from lightning.pytorch.cli import LightningCLI
from loss import SatCLIPLoss
from model import SatCLIP

from .datamodules.s2geo_dataset import S2GeoDataModule
from .loss import SatCLIPLoss
from .model import SatCLIP

torch.set_float32_matmul_precision('high')

Expand Down
30 changes: 14 additions & 16 deletions satclip/model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from collections import OrderedDict
from typing import Tuple, Union, Optional
from typing import Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import math

import timm
import torchgeo.models
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, ViTSmall16_Weights
from location_encoder import get_positional_encoding, get_neural_network, LocationEncoder
from datamodules.s2geo_dataset import S2Geo

from .location_encoder import get_positional_encoding, get_neural_network, LocationEncoder

class Bottleneck(nn.Module):
expansion = 4
Expand Down Expand Up @@ -257,20 +255,20 @@ def __init__(self,
# location
le_type: str,
pe_type: str,
frequency_num: int,
max_radius: int,
frequency_num: int,
max_radius: int,
min_radius: int,
harmonics_calculation: str,
legendre_polys: int=10,
sh_embedding_dims: int=16,
legendre_polys: int=10,
sh_embedding_dims: int=16,
ffn: bool=True,
num_hidden_layers: int=2,
capacity: int=256,
*args,
**kwargs
):
super().__init__()

if isinstance(vision_layers, (tuple, list)):
print('using modified resnet')
vision_heads = vision_width * 32 // 64
Expand All @@ -282,7 +280,7 @@ def __init__(self,
width=vision_width,
in_channels=in_channels
)

elif vision_layers == 'moco_resnet18':
print('using pretrained moco resnet18')
weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
Expand All @@ -300,7 +298,7 @@ def __init__(self,
self.visual.load_state_dict(weights.get_state_dict(progress=True), strict=False)
self.visual.requires_grad_(False)
self.visual.fc.requires_grad_(True)

elif vision_layers == 'moco_vit16':
print('using pretrained moco vit16')
weights = ViTSmall16_Weights.SENTINEL2_ALL_MOCO
Expand All @@ -322,13 +320,13 @@ def __init__(self,
output_dim=embed_dim,
in_channels=in_channels
)

self.posenc = get_positional_encoding(name=le_type, harmonics_calculation=harmonics_calculation, legendre_polys=legendre_polys, min_radius=min_radius, max_radius=max_radius, frequency_num=frequency_num).double()
self.nnet = get_neural_network(name=pe_type, input_dim=self.posenc.embedding_dim, num_classes=embed_dim, dim_hidden=capacity, num_layers=num_hidden_layers).double()
self.location = LocationEncoder(self.posenc,
self.location = LocationEncoder(self.posenc,
self.nnet
).double()

self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

self.initialize_parameters()
Expand Down Expand Up @@ -362,7 +360,7 @@ def encode_location(self, coords):

def forward(self, image, coords):

image_features = self.encode_image(image)
image_features = self.encode_image(image)
location_features = self.encode_location(coords).float()
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
Expand Down