[question] [feature request] Validation for VAE #37

Open
meric-sakarya opened this issue Jun 27, 2021 · 1 comment

Comments

@meric-sakarya

Hello, I am trying to train your VAE for a project of my own and I noticed there is no validation part in the training. Is there an easy way to add validation to that training? Could you help me with how I might do that using your DataLoader? Thanks in advance.

@kncrane

kncrane commented Jul 16, 2021

If you import train_test_split from sklearn.model_selection, you can do something like this:

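# imports needed by this snippet (they belong at the top of train.py)
import os

import numpy as np
from sklearn.model_selection import train_test_split

# CHANGED JPG TO PNG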
images = [im for im in os.listdir(args.folder) if im.endswith('.png')]
images = np.array(images)
n_samples = len(images)

if args.n_samples > 0:
    n_samples = min(n_samples, args.n_samples)

# indices for all time steps where the episode continues
indices = np.arange(n_samples, dtype='int64')
np.random.shuffle(indices)

# NEW SECTION THAT SPLITS INDICES INTO A TRAIN AND VAL SET FIRST BEFORE BATCHING
# train_test_split accepts NumPy arrays directly, so no pandas round trip is needed
train, val = train_test_split(indices, train_size=0.8)

print("{} images in total".format(n_samples))
print("{} images in training set".format(len(train)))
print("{} images in validation set".format(len(val)))

# split indices into minibatches. each minibatchlist is a list of arrays; each
# array holds the sorted ids of the observations in one batch.
# note: a final batch smaller than args.batch_size is dropped
train_minibatchlist = [np.array(sorted(train[start_idx:start_idx + args.batch_size]))
                 for start_idx in range(0, len(train) - args.batch_size + 1, args.batch_size)]

val_minibatchlist = [np.array(sorted(val[start_idx:start_idx + args.batch_size]))
                 for start_idx in range(0, len(val) - args.batch_size + 1, args.batch_size)]

train_data_loader = DataLoader(train_minibatchlist, images, n_workers=2, folder=args.folder)
val_data_loader = DataLoader(val_minibatchlist, images, n_workers=2, folder=args.folder)

This goes in train.py; there is no need to edit the DataLoader class.
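
One thing to watch with the batching above: because partial batches are dropped, a validation set smaller than the batch size yields no validation batches at all. The following is only a rough, NumPy-only sketch of the same split-then-batch idea that makes the behaviour explicit. The helper names (`split_train_val`, `make_minibatches`, `mean_loss`) and the `loss_fn` callable are hypothetical and not part of the repository; how the per-batch loss is actually computed depends on the VAE code in train.py.

```python
import numpy as np


def split_train_val(n_samples, train_fraction=0.8, seed=0):
    """Shuffle sample indices and split them into train and validation sets."""
    rng = np.random.default_rng(seed)
    indices = rng.permutation(n_samples)
    n_train = int(train_fraction * n_samples)
    return indices[:n_train], indices[n_train:]


def make_minibatches(indices, batch_size, drop_last=True):
    """Cut an index array into sorted minibatches.

    With drop_last=True a final batch smaller than batch_size is discarded,
    matching the list comprehensions in the snippet above.
    """
    stop = len(indices) - batch_size + 1 if drop_last else len(indices)
    return [np.sort(indices[start:start + batch_size])
            for start in range(0, stop, batch_size)]


def mean_loss(minibatches, loss_fn):
    """Average a per-batch mean loss over all batches, weighted by batch size.

    loss_fn is a hypothetical callable: it receives an array of image
    indices and returns the mean loss of the model on that batch.
    """
    count = sum(len(batch) for batch in minibatches)
    if count == 0:
        raise ValueError("no batches: is the set smaller than batch_size?")
    total = sum(loss_fn(batch) * len(batch) for batch in minibatches)
    return total / count


train_idx, val_idx = split_train_val(n_samples=100, train_fraction=0.8)
train_batches = make_minibatches(train_idx, batch_size=32)
# Keep the partial batch for validation so that small sets are not lost.
val_batches = make_minibatches(val_idx, batch_size=32, drop_last=False)
```

In a training loop, something like `mean_loss(val_batches, loss_fn)` could be evaluated once per epoch, without any optimizer step, and logged next to the training loss. With the repository's DataLoader you would iterate over val_data_loader instead of the raw index batches.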
