Skip to content

Commit

Permalink
fix a bug where the highest resolution unet residual was not being co…
Browse files Browse the repository at this point in the history
…ncatted to the final resnet block
  • Loading branch information
lucidrains committed Jun 16, 2022
1 parent f55f1b0 commit 485186f
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.2',
version = '0.5.0',
license='MIT',
description = 'Video Diffusion - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
5 changes: 3 additions & 2 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ def __init__(

out_dim = default(out_dim, channels)
self.final_conv = nn.Sequential(
block_klass(dim, dim),
block_klass(dim * 2, dim),
nn.Conv3d(dim, out_dim, 1)
)

Expand Down Expand Up @@ -513,13 +513,14 @@ def forward(
x = self.mid_block2(x, t)

for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = block2(x, t)
x = spatial_attn(x)
x = temporal_attn(x, pos_bias = time_rel_pos_bias, focus_present_mask = focus_present_mask)
x = upsample(x)

x = torch.cat((x, h.pop()), dim = 1)
return self.final_conv(x)

# gaussian diffusion trainer class
Expand Down

0 comments on commit 485186f

Please sign in to comment.