Skip to content

Commit

Permalink
cleaner way to organize code around RiverHasWings application of clas…
Browse files Browse the repository at this point in the history
…sifier free guidance to transformers
  • Loading branch information
lucidrains committed Apr 13, 2022
1 parent 144ce50 commit fcd35de
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 13 deletions.
25 changes: 13 additions & 12 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,18 +525,7 @@ def generate_images(

text, image = out[:, :text_seq_len], out[:, text_seq_len:]

if cond_scale != 1 and use_cache:
# copy the cache state to infer from the same place twice
prev_cache = cache.copy()

logits = self(text, image, cache = cache)

if cond_scale != 1:
# discovery by Katherine Crowson
# https://twitter.com/RiversHaveWings/status/1478093658716966912
null_cond_logits = self(text, image, null_cond_prob = 1., cache = prev_cache)
logits = null_cond_logits + (logits - null_cond_logits) * cond_scale

logits = self.forward_with_cond_scale(text, image, cond_scale = cond_scale, cache = cache)
logits = logits[:, -1, :]

filtered_logits = top_k(logits, thres = filter_thres)
Expand All @@ -556,6 +545,18 @@ def generate_images(

return images

def forward_with_cond_scale(self, *args, cond_scale = 1, cache = None, **kwargs):
if cond_scale == 1:
return self(*args, **kwargs)

prev_cache = cache.copy() if exists(cache) else None
logits = self(*args, cache = cache, **kwargs)

# discovery by Katherine Crowson
# https://twitter.com/RiversHaveWings/status/1478093658716966912
null_cond_logits = self(*args, null_cond_prob = 1., cache = prev_cache, **kwargs)
return null_cond_logits + (logits - null_cond_logits) * cond_scale

def forward(
self,
text,
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.5.1',
version = '1.5.2',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit fcd35de

Please sign in to comment.