Refactor model for multi-device usage and easier disabling of masking #95

Merged · 3 commits · Dec 20, 2023
75 changes: 42 additions & 33 deletions src/model_clay.py
@@ -54,7 +54,6 @@ def __init__(  # noqa: PLR0913
         band_groups,
         dropout,
         emb_dropout,
-        device,
     ):
         super().__init__()
         assert (
@@ -66,13 +65,9 @@ def __init__(  # noqa: PLR0913
         self.dim = dim
         self.bands = bands
         self.band_groups = band_groups
-        self.device = device
         self.num_spatial_patches = (image_size // patch_size) ** 2
         self.num_group_patches = len(band_groups)
         self.num_patches = self.num_spatial_patches * self.num_group_patches
-        self.num_masked_patches = int(
-            self.mask_ratio * self.num_patches
-        )  # Number of patches to be masked out
 
         # Split the embedding dimensions between spatial & band patches equally
         pos_dim = band_dim = dim // 2
@@ -87,16 +82,24 @@ def __init__(  # noqa: PLR0913
         )
 
         # Fix the position & band embedding to sine & cosine functions
-        self.pos_encoding = posemb_sincos_2d(
-            h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
-        )  # [L D/2]
-        self.band_encoding = posemb_sincos_1d(
-            length=self.num_group_patches, dim=band_dim
-        )  # [G D/2]
+        self.register_buffer(
+            name="pos_encoding",
+            tensor=posemb_sincos_2d(
+                h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
+            ),  # [L D/2]
+            persistent=False,
+        )
+        self.register_buffer(
+            name="band_encoding",
+            tensor=posemb_sincos_1d(
+                length=self.num_group_patches, dim=band_dim
+            ),  # [G D/2]
+            persistent=False,
+        )
 
         # Freeze the weights of position & band encoding
-        self.pos_encoding = self.pos_encoding.to(self.device).requires_grad_(False)
-        self.band_encoding = self.band_encoding.to(self.device).requires_grad_(False)
+        self.pos_encoding = self.pos_encoding.requires_grad_(False)
+        self.band_encoding = self.band_encoding.requires_grad_(False)
 
         self.dropout = nn.Dropout(emb_dropout)
 
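The switch from plain tensor attributes to `register_buffer(..., persistent=False)` is what makes the encodings device-agnostic: buffers follow the module through `.to()`, `.cuda()`, and DDP replication, while `persistent=False` keeps them out of `state_dict` so existing checkpoints still load. A minimal standalone sketch of the pattern (the toy module below is illustrative, not part of the PR):

```python
import torch
from torch import nn


class WithFixedEncoding(nn.Module):
    """Toy module showing the register_buffer pattern used above."""

    def __init__(self, length: int, dim: int):
        super().__init__()
        # Non-persistent buffer: moves with the module, is excluded
        # from state_dict, and is never a learnable parameter.
        self.register_buffer(
            "encoding", torch.randn(length, dim), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No manual .to(self.device): the buffer already lives wherever
        # the module was moved to.
        return x + self.encoding


module = WithFixedEncoding(length=16, dim=8)
if torch.cuda.is_available():
    module = module.cuda()  # the buffer moves along automatically
print("encoding" in module.state_dict())  # False, thanks to persistent=False
```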
@@ -234,15 +237,18 @@ def mask_out(self, patches):
         random_indices = torch.argsort(noise, dim=-1)  # [B GL]
         reverse_indices = torch.argsort(random_indices, dim=-1)  # [B GL]
 
+        num_masked_patches = int(
+            self.mask_ratio * self.num_patches
+        )  # Number of patches to be masked out
         masked_indices, unmasked_indices = (
-            random_indices[:, : self.num_masked_patches],  # [B mask_ratio * GL]
-            random_indices[:, self.num_masked_patches :],  # [B (1 - mask_ratio) * GL]
+            random_indices[:, :num_masked_patches],  # [B mask_ratio * GL]
+            random_indices[:, num_masked_patches:],  # [B (1 - mask_ratio) * GL]
         )
 
         # create a mask of shape B G L, where 1 indicates a masked patch
         # and 0 indicates an unmasked patch
         masked_matrix = torch.zeros((B, GL), device=patches.device)  # [B GL] = 0
-        masked_matrix[:, : self.num_masked_patches] = 1  # [B mask_ratio * GL] = 1
+        masked_matrix[:, :num_masked_patches] = 1  # [B mask_ratio * GL] = 1
         masked_matrix = torch.gather(
             masked_matrix, dim=1, index=reverse_indices
         )  # [B GL] -> [B GL] - reorder the patches
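Computing `num_masked_patches` inside `mask_out` instead of caching it in `__init__` is what enables the "easier disabling of masking" part of this PR: `mask_ratio` can be changed on an already-constructed model and takes effect on the next forward pass. A toy sketch of that behaviour (the class below is illustrative, not the project's encoder):

```python
import torch


class ToyMasker:
    """Illustrates deriving the mask count at call time."""

    def __init__(self, mask_ratio: float, num_patches: int):
        self.mask_ratio = mask_ratio
        self.num_patches = num_patches

    def split(self, batch_size: int):
        noise = torch.randn(batch_size, self.num_patches)
        random_indices = torch.argsort(noise, dim=-1)
        # Derived here, so later changes to mask_ratio take effect.
        num_masked = int(self.mask_ratio * self.num_patches)
        return random_indices[:, :num_masked], random_indices[:, num_masked:]


masker = ToyMasker(mask_ratio=0.75, num_patches=8)
masked, unmasked = masker.split(batch_size=2)  # 6 masked, 2 visible
masker.mask_ratio = 0.0  # disable masking after construction
masked, unmasked = masker.split(batch_size=2)  # 0 masked, all 8 visible
```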
@@ -331,7 +337,6 @@ def __init__(  # noqa: PLR0913
         bands,
         band_groups,
         dropout,
-        device,
     ):
         super().__init__()
         self.mask_ratio = mask_ratio
@@ -340,11 +345,9 @@ def __init__(  # noqa: PLR0913
         self.encoder_dim = encoder_dim
         self.dim = dim
         self.band_groups = band_groups
-        self.device = device
         self.num_spatial_patches = (image_size // patch_size) ** 2
         self.num_group_patches = len(band_groups)
         self.num_patches = self.num_spatial_patches * self.num_group_patches
-        self.num_masked_patches = int(self.mask_ratio * self.num_patches)
 
         self.enc_to_dec = (
             nn.Linear(encoder_dim, dim) if encoder_dim != dim else nn.Identity()
@@ -363,16 +366,24 @@ def __init__(  # noqa: PLR0913
         pos_dim = band_dim = dim // 2
 
         # Fix the position & band embedding to sine & cosine functions
-        self.pos_encoding = posemb_sincos_2d(
-            h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
-        )  # [L D/2]
-        self.band_encoding = posemb_sincos_1d(
-            length=self.num_group_patches, dim=band_dim
-        )  # [G D/2]
+        self.register_buffer(
+            name="pos_encoding",
+            tensor=posemb_sincos_2d(
+                h=image_size // patch_size, w=image_size // patch_size, dim=pos_dim
+            ),  # [L D/2]
+            persistent=False,
+        )
+        self.register_buffer(
+            name="band_encoding",
+            tensor=posemb_sincos_1d(
+                length=self.num_group_patches, dim=band_dim
+            ),  # [G D/2]
+            persistent=False,
+        )
 
         # Freeze the weights of position & band encoding
-        self.pos_encoding = self.pos_encoding.to(self.device).requires_grad_(False)
-        self.band_encoding = self.band_encoding.to(self.device).requires_grad_(False)
+        self.pos_encoding = self.pos_encoding.requires_grad_(False)
+        self.band_encoding = self.band_encoding.requires_grad_(False)
 
         self.embed_to_pixels = nn.ModuleDict(
             {
@@ -438,8 +449,9 @@ def reconstruct_and_add_encoding(
 
         # Reconstruct the masked patches from the random mask patch &
         # add position & band encoding to them
+        num_masked_patches = int(self.mask_ratio * self.num_patches)
         masked_patches = repeat(
-            self.mask_patch, "D -> B GL D", B=B, GL=self.num_masked_patches
+            self.mask_patch, "D -> B GL D", B=B, GL=num_masked_patches
         )  # [B GL:mask_ratio D]
         masked_patches = (
             masked_patches + masked_pos_band_encoding
@@ -566,8 +578,6 @@ def __init__(  # noqa: PLR0913
         self.patch_size = patch_size
         self.bands = bands
         self.band_groups = band_groups
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print("Device", device)
 
         self.encoder = Encoder(
             mask_ratio=mask_ratio,
@@ -582,7 +592,6 @@ def __init__(  # noqa: PLR0913
             band_groups=band_groups,
             dropout=dropout,
             emb_dropout=emb_dropout,
-            device=device,
         )
 
         self.decoder = Decoder(
@@ -598,7 +607,6 @@ def __init__(  # noqa: PLR0913
             bands=bands,
             band_groups=band_groups,
             dropout=decoder_dropout,
-            device=device,
         )
 
     def per_pixel_loss(self, cube, pixels, masked_matrix):
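With the hard-coded `torch.device("cuda" ...)` gone from `__init__`, nothing inside the model pins it to a single accelerator; placement becomes the caller's (or the Lightning Trainer's) job, which is what multi-GPU strategies expect. A hedged sketch of the resulting usage with a generic stand-in module (not the project's actual model):

```python
import torch
from torch import nn

# Illustrative stand-in: any module built without a device argument can
# be placed (and re-placed) freely after construction.
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 1))

model = model.to("cpu")  # debug on CPU
if torch.cuda.is_available():
    model = model.to("cuda:0")
    # Or let a Lightning Trainer move it, e.g.
    # Trainer(accelerator="gpu", devices=2, strategy="ddp"), which creates
    # one replica per device -- impossible with a device baked in at init.
```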
@@ -786,7 +794,7 @@ def configure_optimizers(self):
             betas=(self.hparams.b1, self.hparams.b2),
         )
         scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
-            optimizer, T_0=200, T_mult=2, eta_min=self.hparams.lr * 10, last_epoch=-1
+            optimizer, T_0=1000, T_mult=2, eta_min=self.hparams.lr * 10, last_epoch=-1
         )
 
         return {
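For context on the `T_0` bump: `CosineAnnealingWarmRestarts` decays the learning rate along a cosine for `T_0` scheduler steps, restarts at the initial rate, and stretches each subsequent cycle by a factor of `T_mult`. Raising `T_0` from 200 to 1000 therefore delays the first warm restart by 5x. A minimal sketch of the schedule shape (optimizer and values are illustrative):

```python
import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=1000, T_mult=2  # cycle lengths: 1000, 2000, 4000, ...
)

lrs = []
for _ in range(3500):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])
# lrs falls along a cosine within each cycle, jumping back up at the
# restarts (steps 1000 and 3000 in this run).
```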
@@ -807,6 +815,7 @@ def shared_step(self, batch: dict[str, torch.Tensor], batch_idx: int, phase: str
             on_epoch=True,
             prog_bar=True,
             logger=True,
+            sync_dist=True,
         )
         return loss
 
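`sync_dist=True` tells Lightning's `self.log` to reduce (by default, average) the value across processes before logging, so multi-GPU runs record the global metric rather than rank 0's local shard; it complements the device refactor above. A sketch of the pattern in a LightningModule (the `lightning.pytorch` import path assumes Lightning 2.x; older code uses `pytorch_lightning`):

```python
import lightning.pytorch as pl  # import path is an assumption, see above


class Demo(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        loss = batch.mean()  # placeholder loss for illustration
        self.log(
            "train/loss",
            loss,
            prog_bar=True,
            sync_dist=True,  # average across ranks before logging
        )
        return loss
```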