Skip to content

Commit

Permalink
one more residual
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 16, 2022
1 parent 4590f60 commit 98d64b4
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.1',
version = '0.5.2',
license='MIT',
description = 'Video Diffusion - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
5 changes: 4 additions & 1 deletion video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __init__(

out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim),
block_klass(dim * 2, dim),
nn.Conv3d(dim, out_dim, 1)
)

Expand Down Expand Up @@ -485,6 +485,8 @@ def forward(
time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device = x.device)

x = self.init_conv(x)
r = x.clone()

x = self.init_temporal_attn(x, pos_bias = time_rel_pos_bias)

t = self.time_mlp(time) if exists(self.time_mlp) else None
Expand Down Expand Up @@ -520,6 +522,7 @@ def forward(
x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
x = upsample(x)

x = torch.cat((x, r), dim = 1)
return self.final_conv(x)

# gaussian diffusion trainer class
Expand Down

0 comments on commit 98d64b4

Please sign in to comment.