
Commit 2b0b75f

lucidrains committed May 3, 2024
1 parent f68f31e
Showing 2 changed files with 17 additions and 9 deletions.
setup.py: 1 addition & 1 deletion
@@ -3,7 +3,7 @@
 setup(
   name='video-diffusion-pytorch',
   packages=find_packages(exclude=[]),
-  version='0.6.3',
+  version='0.7.0',
   license='MIT',
   description='Video Diffusion - Pytorch',
   long_description_content_type='text/markdown',
video_diffusion_pytorch/video_diffusion_pytorch.py: 16 additions & 8 deletions
@@ -160,6 +160,15 @@ def forward(self, x):
         mean = torch.mean(x, dim = 1, keepdim = True)
         return (x - mean) / (var + self.eps).sqrt() * self.gamma
 
+class RMSNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = dim ** 0.5
+        self.gamma = nn.Parameter(torch.ones(dim, 1, 1, 1))
+
+    def forward(self, x):
+        return F.normalize(x, dim = 1) * self.scale * self.gamma
+
 class PreNorm(nn.Module):
     def __init__(self, dim, fn):
         super().__init__()
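A note on what the added RMSNorm computes: F.normalize divides by the L2 norm over the channel dimension, so multiplying by sqrt(dim) turns it into a root-mean-square normalization with a learned per-channel gain. A minimal standalone sketch, assuming the 5-d (batch, channels, frames, height, width) tensors used elsewhere in this file:

import torch
import torch.nn.functional as F

dim = 32
x = torch.randn(2, dim, 4, 16, 16)    # (batch, channels, frames, height, width)
gamma = torch.ones(dim, 1, 1, 1)      # learned per-channel scale, broadcasts over channels
scale = dim ** 0.5

out = F.normalize(x, dim = 1) * scale * gamma

# F.normalize(x, dim = 1) is x / ||x||_2 over channels, and
# ||x||_2 = sqrt(dim) * rms(x), so the two expressions agree:
rms = x.pow(2).mean(dim = 1, keepdim = True).sqrt()
assert torch.allclose(out, x / rms * gamma, atol = 1e-5)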
@@ -174,10 +183,10 @@ def forward(self, x, **kwargs):
 
 
 class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 8):
+    def __init__(self, dim, dim_out):
         super().__init__()
         self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding = (0, 1, 1))
-        self.norm = nn.GroupNorm(groups, dim_out)
+        self.norm = RMSNorm(dim_out)
         self.act = nn.SiLU()
 
     def forward(self, x, scale_shift = None):
@@ -191,15 +200,15 @@ def forward(self, x, scale_shift = None):
         return self.act(x)
 
 class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
+    def __init__(self, dim, dim_out, *, time_emb_dim = None):
         super().__init__()
         self.mlp = nn.Sequential(
             nn.SiLU(),
             nn.Linear(time_emb_dim, dim_out * 2)
         ) if exists(time_emb_dim) else None
 
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
         self.res_conv = nn.Conv3d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
 
     def forward(self, x, time_emb = None):
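Since Block and ResnetBlock no longer accept a groups argument, any external caller passing it must drop it. A hedged sketch of an updated call site, importing directly from the module (shapes are illustrative, not taken from the repo):

import torch
from video_diffusion_pytorch.video_diffusion_pytorch import ResnetBlock

# was: ResnetBlock(64, 128, time_emb_dim = 256, groups = 8)
block = ResnetBlock(64, 128, time_emb_dim = 256)

x = torch.randn(1, 64, 10, 32, 32)    # (batch, channels, frames, height, width)
t = torch.randn(1, 256)               # time embedding of size time_emb_dim
out = block(x, t)                     # -> (1, 128, 10, 32, 32)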
@@ -355,8 +364,7 @@ def __init__(
         init_dim = None,
         init_kernel_size = 7,
         use_sparse_linear_attn = True,
-        block_type = 'resnet',
-        resnet_groups = 8
+        block_type = 'resnet'
     ):
         super().__init__()
         self.channels = channels
@@ -412,7 +420,7 @@ def __init__(
 
         # block type
 
-        block_klass = partial(ResnetBlock, groups = resnet_groups)
+        block_klass = ResnetBlock
         block_klass_cond = partial(block_klass, time_emb_dim = cond_dim)
 
         # modules for all layers
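After this commit, resnet_groups is also gone from the Unet3D signature, so constructing the model looks like the following sketch (mirroring the repo's README-style usage, with illustrative hyperparameters):

from video_diffusion_pytorch import Unet3D

# was: Unet3D(dim = 64, dim_mults = (1, 2, 4, 8), resnet_groups = 8)
model = Unet3D(
    dim = 64,
    dim_mults = (1, 2, 4, 8)
)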
