Skip to content

Commit

Permalink
support for transparent images
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed May 9, 2022
1 parent fcd35de commit bebc280
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 7 deletions.
1 change: 1 addition & 0 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
assert num_layers >= 1, 'number of layers must be greater than or equal to 1'
has_resblocks = num_resnet_blocks > 0

self.channels = channels
self.image_size = image_size
self.num_tokens = num_tokens
self.num_layers = num_layers
Expand Down
3 changes: 3 additions & 0 deletions dalle_pytorch/vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def __init__(self):
self.dec = load_model(download(OPENAI_VAE_DECODER_PATH))
make_contiguous(self)

self.channels = 3
self.num_layers = 3
self.image_size = 256
self.num_tokens = 8192
Expand Down Expand Up @@ -175,7 +176,9 @@ def __init__(self, vqgan_model_path=None, vqgan_config_path=None):

# f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]

self.num_layers = int(log(f)/log(2))
self.channels = 3
self.image_size = 256
self.num_tokens = config.model.params.n_embed
self.is_gumbel = isinstance(self.model, GumbelVQ)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.5.2',
version = '1.6.0',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down
11 changes: 6 additions & 5 deletions train_dalle.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,6 @@ def cp_path_to_dir(cp_path, tag):
else:
vae = OpenAIDiscreteVAE()

IMAGE_SIZE = vae.image_size
resume_epoch = loaded_obj.get('epoch', 0)
else:
if exists(VAE_PATH):
Expand Down Expand Up @@ -296,8 +295,6 @@ def cp_path_to_dir(cp_path, tag):
else:
vae = OpenAIDiscreteVAE()

IMAGE_SIZE = vae.image_size

dalle_params = dict(
num_text_tokens=tokenizer.vocab_size,
text_seq_len=TEXT_SEQ_LEN,
Expand All @@ -319,6 +316,10 @@ def cp_path_to_dir(cp_path, tag):
)
resume_epoch = 0

IMAGE_SIZE = vae.image_size
CHANNELS = vae.channels
IMAGE_MODE = 'RGBA' if CHANNELS == 4 else 'RGB'

# configure OpenAI VAE for float16s

if isinstance(vae, OpenAIDiscreteVAE) and args.fp16:
Expand All @@ -345,8 +346,8 @@ def group_weight(model):
is_shuffle = not distributed_utils.using_backend(distributed_utils.HorovodBackend)

imagepreproc = T.Compose([
T.Lambda(lambda img: img.convert('RGB')
if img.mode != 'RGB' else img),
T.Lambda(lambda img: img.convert(IMAGE_MODE)
if img.mode != IMAGE_MODE else img),
T.RandomResizedCrop(IMAGE_SIZE,
scale=(args.resize_ratio, 1.),
ratio=(1., 1.)),
Expand Down
9 changes: 8 additions & 1 deletion train_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@

model_group.add_argument('--kl_loss_weight', type = float, default = 0., help = 'KL loss weight')

model_group.add_argument('--transparent', dest = 'transparent', action = 'store_true')

args = parser.parse_args()

# constants
Expand All @@ -88,6 +90,10 @@
HIDDEN_DIM = args.hidden_dim
KL_LOSS_WEIGHT = args.kl_loss_weight

TRANSPARENT = args.transparent
CHANNELS = 4 if TRANSPARENT else 3
IMAGE_MODE = 'RGBA' if TRANSPARENT else 'RGB'

STARTING_TEMP = args.starting_temp
TEMP_MIN = args.temp_min
ANNEAL_RATE = args.anneal_rate
Expand All @@ -107,7 +113,7 @@
ds = ImageFolder(
IMAGE_PATH,
T.Compose([
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
T.Lambda(lambda img: img.convert(IMAGE_MODE) if img.mode != IMAGE_MODE else img),
T.Resize(IMAGE_SIZE),
T.CenterCrop(IMAGE_SIZE),
T.ToTensor()
Expand All @@ -127,6 +133,7 @@
image_size = IMAGE_SIZE,
num_layers = NUM_LAYERS,
num_tokens = NUM_TOKENS,
channels = CHANNELS,
codebook_dim = EMB_DIM,
hidden_dim = HIDDEN_DIM,
num_resnet_blocks = NUM_RESNET_BLOCKS
Expand Down

0 comments on commit bebc280

Please sign in to comment.