add custom dataset and instruction for training on a custom dataset #54

Merged · 1 commit · Jun 24, 2021
README.md (10 additions, 0 deletions)

@@ -121,6 +121,16 @@ included in the repository, run
```
streamlit run scripts/sample_conditional.py -- -r logs/2020-11-20T21-45-44_ade20k_transformer/ --ignore_base_data data="{target: main.DataModuleFromConfig, params: {batch_size: 1, validation: {target: taming.data.ade20k.Examples}}}"
```

## Training on custom data

Training on your own dataset can be beneficial to get better tokens and hence better images for your domain.
These are the steps to follow to make this work:
1. Install the repo with `conda env create -f environment.yaml`, `conda activate taming` and `pip install -e .`.
2. Put your .jpg files in a folder `your_folder`.
3. Create two text files, `xx_train.txt` and `xx_test.txt`, that point to the files in your training and test set respectively (for example `find $(pwd)/your_folder -name "*.jpg" > train.txt`); a Python alternative is sketched after this list.
4. Adapt `configs/custom_vqgan.yaml` to point to these two files.
5. Run `python main.py --base configs/custom_vqgan.yaml -t True --gpus 0,1` to train on two GPUs (use `--gpus 0,` with a trailing comma for a single GPU).
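
For step 3, here is a minimal sketch of building the file lists in Python; the folder name `your_folder`, the output names `train.txt`/`test.txt`, and the 90/10 split are illustrative assumptions, not part of the repo:

```python
import random
from pathlib import Path

# Collect absolute paths to every .jpg in the (assumed) image folder.
paths = sorted(str(p.resolve()) for p in Path("your_folder").glob("*.jpg"))
random.seed(0)
random.shuffle(paths)

# Assumed 90/10 train/test split; one absolute path per line.
cut = int(0.9 * len(paths))
Path("train.txt").write_text("\n".join(paths[:cut]) + "\n")
Path("test.txt").write_text("\n".join(paths[cut:]) + "\n")
```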

## Data Preparation

### ImageNet
configs/custom_vqgan.yaml (43 additions, 0 deletions)

@@ -0,0 +1,43 @@
model:
  base_learning_rate: 4.5e-6
  target: taming.models.vqgan.VQModel
  params:
    embed_dim: 256
    n_embed: 1024
    ddconfig:
      double_z: False
      z_channels: 256
      resolution: 256
      in_channels: 3
      out_ch: 3
      ch: 128
      ch_mult: [ 1,1,2,2,4]  # num_down = len(ch_mult)-1
      num_res_blocks: 2
      attn_resolutions: [16]
      dropout: 0.0

    lossconfig:
      target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
      params:
        disc_conditional: False
        disc_in_channels: 3
        disc_start: 10000
        disc_weight: 0.8
        codebook_weight: 1.0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 5
    num_workers: 8
    train:
      target: taming.data.custom.CustomTrain
      params:
        training_images_list_file: some/training.txt
        size: 256
    validation:
      target: taming.data.custom.CustomTest
      params:
        test_images_list_file: some/test.txt
        size: 256
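
Each `target`/`params` pair in this config is resolved dynamically: the repo's `instantiate_from_config` helper in `main.py` imports the dotted class path and calls it with `params` as keyword arguments. A minimal sketch of that convention (not a verbatim copy of the helper):

```python
import importlib

def instantiate_from_config(config):
    # Resolve the dotted path in `target` to a class, then
    # construct it with the `params` mapping as keyword arguments.
    module_name, cls_name = config["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**config.get("params", dict()))

# The `train` block above thus becomes:
# CustomTrain(training_images_list_file="some/training.txt", size=256)
```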

taming/data/custom.py (38 additions, 0 deletions)

@@ -0,0 +1,38 @@
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset

from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex


class CustomBase(Dataset):
    # Thin Dataset wrapper; subclasses populate self.data
    # with an ImagePaths instance.
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        example = self.data[i]
        return example


class CustomTrain(CustomBase):
    def __init__(self, size, training_images_list_file):
        super().__init__()
        # The list file holds one image path per line.
        with open(training_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)


class CustomTest(CustomBase):
    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            paths = f.read().splitlines()
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)
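
A quick way to sanity-check the new dataset classes, assuming a `train.txt` produced as in the README steps above (`ImagePaths` from `taming.data.base` yields dicts whose `"image"` entry is an HWC float array scaled to [-1, 1]):

```python
from taming.data.custom import CustomTrain

# "train.txt" is a placeholder; one absolute .jpg path per line.
dataset = CustomTrain(size=256, training_images_list_file="train.txt")
print(len(dataset))

example = dataset[0]
print(example["image"].shape)  # expected: (256, 256, 3)
```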