Skip to content

Commit

Permalink
fix version problem if researcher did not install from pip
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 30, 2022
1 parent d03df6a commit 343f8b0
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion dalle_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from dalle_pytorch.vae import OpenAIDiscreteVAE, VQGanVAE

from pkg_resources import get_distribution
__version__ = get_distribution('dalle_pytorch').version
from dalle_pytorch.version import __version__
1 change: 1 addition & 0 deletions dalle_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__version__ = '1.6.3'
1 change: 1 addition & 0 deletions generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

# dalle related classes and utils

from dalle_pytorch import __version__
from dalle_pytorch import DiscreteVAE, OpenAIDiscreteVAE, VQGanVAE, DALLE
from dalle_pytorch.tokenizer import tokenizer, HugTokenizer, YttmTokenizer, ChineseTokenizer

Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from setuptools import setup, find_packages
exec(open('dalle_pytorch/version.py').read())

setup(
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.6.2',
version = __version__,
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
author_email = 'lucidrains@gmail.com',
long_description_content_type = 'text/markdown',
url = 'https://github.com/lucidrains/dalle-pytorch',
keywords = [
'artificial intelligence',
Expand Down
7 changes: 2 additions & 5 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

from dalle_pytorch import __version__
from dalle_pytorch import OpenAIDiscreteVAE, VQGanVAE, DiscreteVAE, DALLE
from dalle_pytorch import distributed_utils
from dalle_pytorch.loader import TextImageDataset
Expand Down Expand Up @@ -147,10 +148,6 @@ def exists(val):
def get_trainable_params(model):
return [params for params in model.parameters() if params.requires_grad]

def get_pkg_version():
from pkg_resources import get_distribution
return get_distribution('dalle_pytorch').version

def cp_path_to_dir(cp_path, tag):
"""Convert a checkpoint path to a directory with `tag` inserted.
If `cp_path` is already a directory, return it unchanged.
Expand Down Expand Up @@ -540,7 +537,7 @@ def save_model(path, epoch=0):
'hparams': dalle_params,
'vae_params': vae_params,
'epoch': epoch,
'version': get_pkg_version(),
'version': __version__,
'vae_class_name': vae.__class__.__name__
}

Expand Down

0 comments on commit 343f8b0

Please sign in to comment.