remove all biases, given results from PaLM as well as hearing about @borisdayma experiments
lucidrains committed Apr 22, 2022
1 parent 62f7f20 commit 0032a49
Showing 2 changed files with 11 additions and 11 deletions.
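The change itself is mechanical: every nn.Linear that previously kept its default bias now passes bias = False, following the PaLM recipe of bias-free dense layers, which was reported to improve training stability at scale. A minimal sketch of a sanity check one could run after constructing a model (the assert_no_linear_bias helper below is illustrative, not part of nuwa-pytorch):

import torch.nn as nn

def assert_no_linear_bias(model: nn.Module):
    # walk every submodule and flag any nn.Linear that still carries a bias term
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            assert module.bias is None, f'{name} still has a bias'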
20 changes: 10 additions & 10 deletions nuwa_pytorch/nuwa_pytorch.py
@@ -276,10 +276,10 @@ def __init__(
         self.chunk_size = chunk_size

         self.net = nn.Sequential(
-            nn.Linear(dim, inner_dim * 2),
+            nn.Linear(dim, inner_dim * 2, bias = False),
             GEGLU(),
             nn.Dropout(dropout),
-            nn.Linear(inner_dim, dim)
+            nn.Linear(inner_dim, dim, bias = False)
         )

     def forward(self, x):
@@ -315,7 +315,7 @@ def __init__(
         self.dropout = nn.Dropout(dropout)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

     def forward(
         self,
@@ -638,7 +638,7 @@ def __init__(
         self.talking_heads = nn.Conv3d(heads, heads, 1, bias = False)
         self.dropout = nn.Dropout(dropout)
         self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # handle variables for unfold

@@ -787,7 +787,7 @@ def __init__(
         self.dropout = nn.Dropout(dropout)
         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         # handle variables for 2d unfold

@@ -938,7 +938,7 @@ def __init__(

         self.to_q = nn.Linear(dim, inner_dim, bias = False)
         self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
-        self.to_out = nn.Linear(inner_dim, dim)
+        self.to_out = nn.Linear(inner_dim, dim, bias = False)

         self.null_k = nn.Parameter(torch.randn(heads, dim_head))
         self.null_v = nn.Parameter(torch.randn(heads, dim_head))
@@ -1821,7 +1821,7 @@ def __init__(
             sparse_3dna_rel_pos_bias = sparse_3dna_rel_pos_bias
         )

-        self.to_logits = nn.Linear(dim, num_image_tokens)
+        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)

     def embed_text(self, text, mask = None):
         batch, seq_len, device = *text.shape, text.device
@@ -2090,8 +2090,8 @@ def __init__(
             sparse_2dna_rel_pos_bias = sparse_2dna_rel_pos_bias
         )

-        self.to_video_logits = nn.Linear(dim, num_image_tokens)
-        self.to_audio_logits = nn.Linear(dim, num_audio_tokens)
+        self.to_video_logits = nn.Linear(dim, num_image_tokens, bias = False)
+        self.to_audio_logits = nn.Linear(dim, num_audio_tokens, bias = False)

     def embed_text(self, text, mask = None):
         batch, seq_len, device = *text.shape, text.device
@@ -2413,7 +2413,7 @@ def __init__(
             sparse_3dna_attn = True
         )

-        self.to_logits = nn.Linear(dim, num_image_tokens)
+        self.to_logits = nn.Linear(dim, num_image_tokens, bias = False)

     def embed_sketch(self, sketch, mask = None):
         batch, frames, channels, image_size, _, device = *sketch.shape, sketch.device
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
   name = 'nuwa-pytorch',
   packages = find_packages(exclude=[]),
   include_package_data = True,
-  version = '0.7.2',
+  version = '0.7.3',
   license='MIT',
   description = 'NÜWA - Pytorch',
   author = 'Phil Wang',
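A side effect of dropping the biases is that checkpoints saved under 0.7.2 contain .bias tensors the 0.7.3 modules no longer declare, so a strict load_state_dict will raise on unexpected keys. A minimal sketch of one workaround, assuming model is an already-constructed NUWA instance and the checkpoint path is a placeholder:

import torch

# weights saved under 0.7.2; the path is hypothetical
state = torch.load('old_checkpoint.pt', map_location = 'cpu')

# keep only the keys the bias-free model still declares, then load
model_keys = set(model.state_dict().keys())
model.load_state_dict({k: v for k, v in state.items() if k in model_keys})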
