Skip to content

Commit

Permalink
use He initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Aug 15, 2024
1 parent d6ece24 commit 8d72bf5
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@ class LocationParams(torch.nn.Module):
def __init__(self, n_channels, size) -> None:
super().__init__()

self.params = torch.nn.Parameter(torch.randn(n_channels, size, size))
# He initialization of weights
tensor = torch.randn(n_channels, size, size)
torch.nn.init.kaiming_normal_(tensor, mode="fan_out")
self.params = torch.nn.Parameter(tensor)


def forward(self, cond):
batch_size = cond.shape[0]
Expand Down

0 comments on commit 8d72bf5

Please sign in to comment.