Merge branch 'custom_dataset' of https://github.com/rom1504/taming-transformers into rom1504-custom_dataset
pesser committed Jun 24, 2021
2 parents 049bd91 + c3eeaff commit f21f6da
Showing 3 changed files with 91 additions and 0 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -180,6 +180,16 @@ included in the repository, run
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
```

## Training on custom data

Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
These are the steps to follow to make this work:
1. install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`
2. put your .jpg files in a folder `your_folder`
3. create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`)
4. adapt `configs/custom_vqgan.yaml` to point to these two files
5. run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1`
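The two list files in the steps above can also be generated with a small Python helper. This is a sketch: `write_image_lists`, the output file names, and the split ratio are illustrative choices, not part of the repository.

```python
import os
import random

def write_image_lists(folder, train_txt="train.txt", test_txt="test.txt",
                      test_fraction=0.1, seed=0):
    """Collect absolute paths of .jpg files under `folder` and split them
    into a training list and a test list, one path per line."""
    paths = sorted(
        os.path.join(os.path.abspath(root), name)
        for root, _, names in os.walk(folder)
        for name in names
        if name.lower().endswith(".jpg")
    )
    rng = random.Random(seed)
    rng.shuffle(paths)  # shuffle so the split is not biased by file order
    n_test = max(1, int(len(paths) * test_fraction)) if paths else 0
    with open(test_txt, "w") as f:
        f.write("\n".join(paths[:n_test]))
    with open(train_txt, "w") as f:
        f.write("\n".join(paths[n_test:]))
    return len(paths) - n_test, n_test
```

Point `training_images_list_file` and `test_images_list_file` in the config at the two files this writes.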

## Data Preparation

### ImageNet
43 changes: 43 additions & 0 deletions configs/custom_vqgan.yaml
@@ -0,0 +1,43 @@
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: some/training.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: some/test.txt
        size: 256
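Each `target`/`params` pair in this config is resolved by importing the dotted path and calling it with the params as keyword arguments. Roughly, the mechanism looks like this; a simplified sketch of the pattern, not the repository's exact implementation:

```python
import importlib

def instantiate_from_config(config):
    # "target" is a dotted import path like "taming.models.vqgan.VQModel";
    # "params" become keyword arguments of the constructor.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", {}))

# A stdlib class, just to show the mechanism:
obj = instantiate_from_config(
    {"target": "fractions.Fraction", "params": {"numerator": 1, "denominator": 2}}
)
print(obj)  # → 1/2
```

This is why `train` and `validation` can point at any importable Dataset class, such as the `CustomTrain`/`CustomTest` classes below.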

38 changes: 38 additions & 0 deletions taming/data/custom.py
@@ -0,0 +1,38 @@
from torch.utils.data import Dataset

from taming.data.base import ImagePaths


class CustomBase(Dataset):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # Subclasses assign the actual dataset (e.g. ImagePaths) to self.data.
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)


class CustomTest(CustomBase):
    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
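`CustomTrain` and `CustomTest` expect one image path per line. `f.read().splitlines()` drops a trailing newline but keeps interior blank lines as empty strings, which would later fail to load as images, so keep the list files free of blank lines. A quick check (the example paths are placeholders):

```python
# Parsing mirrors the classes above: one path per line.
content = "/data/your_folder/img_0001.jpg\n/data/your_folder/img_0002.jpg\n"
paths = content.splitlines()
print(paths)  # the trailing newline does not add an empty entry

# An interior blank line survives as an empty-string entry:
print("/a.jpg\n\n/b.jpg\n".splitlines())  # → ['/a.jpg', '', '/b.jpg']
```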

