Skip to content
This repository has been archived by the owner on Nov 29, 2023. It is now read-only.

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 9, 2023
1 parent 5abbee1 commit 015796f
Show file tree
Hide file tree
Showing 12 changed files with 3 additions and 20 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<!-- ALL-CONTRIBUTORS-LIST:END -->
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
1 change: 0 additions & 1 deletion satflow/data/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,6 @@ def per_worker_init(self, worker_id: int):
pass

def __getitem__(self, idx):

x = {
SATELLITE_DATA: torch.randn(
self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels
Expand Down
3 changes: 2 additions & 1 deletion satflow/data/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def eumetsat_name_to_datetime(filename: str):

def retrieve_pixel_value(geo_coord, data_source):
"""Return floating-point value that corresponds to given point.
Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal"""
Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal
"""
x, y = geo_coord[0], geo_coord[1]
forward_transform = affine.Affine.from_gdal(*data_source.GetGeoTransform())
reverse_transform = ~forward_transform
Expand Down
2 changes: 0 additions & 2 deletions satflow/models/conv_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "s
)

def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4):

outputs = []

# encoder
Expand Down Expand Up @@ -203,7 +202,6 @@ def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3,
return outputs

def forward(self, x, forecast_steps=0, hidden_state=None):

"""
Parameters
----------
Expand Down
1 change: 0 additions & 1 deletion satflow/models/gan/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def __init__(

mult = 2**n_downsampling
for i in range(n_blocks): # add ResNet blocks

model += [
ResnetBlock(
ngf * mult,
Expand Down
5 changes: 0 additions & 5 deletions satflow/models/layers/Attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ def __init__(
)

def forward(self, x):

return self.model(x)


Expand Down Expand Up @@ -56,7 +55,6 @@ def init_conv(self, conv, glu=True):
conv.bias.data.zero_()

def forward(self, x):

batch_size, C, T, W, H = x.size()

assert T % 2 == 0 and W % 2 == 0 and H % 2 == 0, "T, W, H is not even"
Expand Down Expand Up @@ -111,7 +109,6 @@ def forward(self, x):

class SelfAttention(nn.Module):
def __init__(self, in_dim, activation=F.relu, pooling_factor=2): # TODO for better compability

super(SelfAttention, self).__init__()
self.activation = activation

Expand All @@ -134,7 +131,6 @@ def init_conv(self, conv, glu=True):
conv.bias.data.zero_()

def forward(self, x):

if len(x.size()) == 4:
batch_size, C, W, H = x.size()
T = 1
Expand Down Expand Up @@ -224,7 +220,6 @@ def forward(self, x):


if __name__ == "__main__":

self_attn = SelfAttention(16) # no less than 8
print(self_attn)

Expand Down
1 change: 0 additions & 1 deletion satflow/models/layers/Discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,6 @@ def forward(self, x, class_id):


if __name__ == "__main__":

batch_size = 6
n_frames = 8
n_class = 4
Expand Down
2 changes: 0 additions & 2 deletions satflow/models/layers/GResBlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
self.CBNorm2 = ConditionalNorm(out_channel, n_class)

def forward(self, x, condition=None):

# The time dimension is combined with the batch dimension here, so each frame proceeds
# through the blocks independently
BT, C, W, H = x.size()
Expand Down Expand Up @@ -100,7 +99,6 @@ def forward(self, x, condition=None):


if __name__ == "__main__":

n_class = 96
batch_size = 4
n_frames = 20
Expand Down
3 changes: 0 additions & 3 deletions satflow/models/layers/Generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hier
self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1))

def forward(self, x, class_id):

if self.hierar_flag is True:
noise_emb = torch.split(x, self.in_dim, dim=1)
else:
Expand All @@ -87,7 +86,6 @@ def forward(self, x, class_id):

for k, conv in enumerate(self.conv):
if isinstance(conv, ConvGRU):

if k > 0:
_, C, W, H = y.size()
y = y.view(-1, self.n_frames, C, W, H).contiguous()
Expand Down Expand Up @@ -132,7 +130,6 @@ def forward(self, x, class_id):


if __name__ == "__main__":

batch_size = 5
in_dim = 120
n_class = 4
Expand Down
1 change: 0 additions & 1 deletion satflow/models/layers/Normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def forward(self, x, class_id):


if __name__ == "__main__":

cn = ConditionalNorm(3, 2)
x = torch.rand([4, 3, 64, 64])
class_id = torch.rand([4, 2])
Expand Down
1 change: 0 additions & 1 deletion satflow/models/layers/RUnetLayers.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ def __init__(self, ch_out, t=2, conv_type: str = "standard"):

def forward(self, x):
for i in range(self.t):

if i == 0:
x1 = self.conv(x)

Expand Down
1 change: 0 additions & 1 deletion satflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig):

# Imports should be nested inside @hydra.main to optimize tab completion
# Read more here: https://github.com/facebookresearch/hydra/issues/934
from satflow.core import utils
Expand Down

0 comments on commit 015796f

Please sign in to comment.