Skip to content

Commit

Permalink
fix a bug with bert encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jul 15, 2022
1 parent f692a12 commit e4d9e9d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'video-diffusion-pytorch',
packages = find_packages(exclude=[]),
version = '0.5.2',
version = '0.5.3',
license='MIT',
description = 'Video Diffusion - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
7 changes: 6 additions & 1 deletion video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,12 @@ def __init__(
register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

# text conditioning parameters

self.text_use_bert_cls = text_use_bert_cls

# dynamic thresholding when sampling

self.use_dynamic_thres = use_dynamic_thres
self.dynamic_thres_percentile = dynamic_thres_percentile

Expand Down Expand Up @@ -715,7 +720,7 @@ def p_losses(self, x_start, t, cond = None, noise = None, **kwargs):
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)

if is_list_str(cond):
cond = bert_embed(tokenize(cond), return_cls_repr = text_use_bert_cls)
cond = bert_embed(tokenize(cond), return_cls_repr = self.text_use_bert_cls)
cond = cond.to(device)

x_recon = self.denoise_fn(x_noisy, t, cond = cond, **kwargs)
Expand Down

0 comments on commit e4d9e9d

Please sign in to comment.