Skip to content

Commit

Permalink
fix a bug where the highest resolution unet residual was not being concatenated to the final resnet block
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 16, 2022
1 parent f55f1b0 commit 4590f60
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.2',
version = '0.5.1',
license='MIT',
description = 'Video Diffusion - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
6 changes: 3 additions & 3 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __init__(

# initial conv

init_dim = default(init_dim, dim // 3 * 2)
init_dim = default(init_dim, dim)
assert is_odd(init_kernel_size)

init_padding = init_kernel_size // 2
Expand Down Expand Up @@ -438,7 +438,7 @@ def __init__(

self.mid_block2 = block_klass_cond(mid_dim, mid_dim)

for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
for ind, (dim_in, dim_out) in enumerate(reversed(in_out)):
is_last = ind >= (num_resolutions - 1)

self.ups.append(nn.ModuleList([
Expand Down Expand Up @@ -513,7 +513,7 @@ def forward(
x = self.mid_block2(x, t)

for block1, block2, spatial_attn, temporal_attn, upsample in self.ups:
x = torch.cat((x, h.pop()), dim=1)
x = torch.cat((x, h.pop()), dim = 1)
x = block1(x, t)
x = block2(x, t)
x = spatial_attn(x)
Expand Down

0 comments on commit 4590f60

Please sign in to comment.