Skip to content

Commit

Permalink
move to film-like conditioning for time
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 25, 2022
1 parent 9aecf56 commit 275998c
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 11 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.4.0',
version = '0.4.2',
license='MIT',
description = 'Video Diffusion - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
29 changes: 19 additions & 10 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,33 +177,42 @@ def forward(self, x, **kwargs):
class Block(nn.Module):
def __init__(self, dim, dim_out, groups = 8):
super().__init__()
self.block = nn.Sequential(
nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1)),
nn.GroupNorm(groups, dim_out),
nn.SiLU()
)
def forward(self, x):
return self.block(x)
self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1))
self.norm = nn.GroupNorm(groups, dim_out)
self.act = nn.SiLU()

def forward(self, x, scale_shift = None):
x = self.proj(x)
x = self.norm(x)

if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift

return self.act(x)

class ResnetBlock(nn.Module):
def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
super().__init__()
self.mlp = nn.Sequential(
nn.SiLU(),
nn.Linear(time_emb_dim, dim_out)
nn.Linear(time_emb_dim, dim_out * 2)
) if exists(time_emb_dim) else None

self.block1 = Block(dim, dim_out, groups = groups)
self.block2 = Block(dim_out, dim_out, groups = groups)
self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

def forward(self, x, time_emb = None):
h = self.block1(x)

scale_shift = None
if exists(self.mlp):
assert exists(time_emb), 'time emb must be passed in'
time_emb = self.mlp(time_emb)
h = rearrange(time_emb, 'b c -> b c 1 1 1') + h
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
scale_shift = time_emb.chunk(2, dim = 1)

h = self.block1(x, scale_shift = scale_shift)

h = self.block2(h)
return h + self.res_conv(x)
Expand Down

0 comments on commit 275998c

Please sign in to comment.