-
Notifications
You must be signed in to change notification settings - Fork 0
/
C_default_lsun_configs.py
72 lines (63 loc) · 2.03 KB
/
C_default_lsun_configs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import ml_collections
import torch
def get_default_configs():
    """Build the default configuration for LSUN diffusion-model experiments.

    Returns:
        ml_collections.ConfigDict with sections ``training``, ``sampling``,
        ``eval``, ``data``, ``model``, ``optim``, plus top-level ``seed``
        and ``device``.
    """
    config = ml_collections.ConfigDict()

    # training
    config.training = training = ml_collections.ConfigDict()
    training.batch_size = 1
    training.n_iters = 300001
    training.snapshot_freq = 10000
    training.log_freq = 100
    training.eval_freq = 100
    # store additional checkpoints for preemption in cloud computing environments
    training.snapshot_freq_for_preemption = 20000
    # produce samples at each snapshot
    training.snapshot_sampling = True
    training.likelihood_weighting = False
    training.continuous = True
    training.reduce_mean = False

    # sampling
    config.sampling = sampling = ml_collections.ConfigDict()
    sampling.n_steps_each = 1
    sampling.noise_removal = True
    sampling.probability_flow = False
    sampling.snr = 0.075  # signal-to-noise ratio used by the corrector step

    # evaluation
    config.eval = evaluate = ml_collections.ConfigDict()
    evaluate.begin_ckpt = 1
    evaluate.end_ckpt = 100
    evaluate.batch_size = 1
    evaluate.enable_sampling = True
    evaluate.num_samples = 1000
    evaluate.enable_loss = True
    evaluate.enable_bpd = False
    evaluate.bpd_dataset = 'test'

    # data
    config.data = data = ml_collections.ConfigDict()
    data.dataset = 'LSUN'
    data.image_size = 64
    data.random_flip = True
    data.uniform_dequantization = False
    # centered=False keeps data in [0, 1]; consumed by get_data_scaler /
    # its inverse elsewhere in the project
    data.centered = False
    # NOTE(review): 9048 channels is unusual for image data (RGB would be 3)
    # — confirm this value against the data pipeline
    data.num_channels = 9048

    # model
    config.model = model = ml_collections.ConfigDict()
    model.sigma_max = 378
    model.sigma_min = 0.01
    model.num_scales = 2000
    model.beta_min = 0.1
    model.beta_max = 20.
    model.dropout = 0.
    model.embedding_type = 'fourier'

    # optimization
    config.optim = optim = ml_collections.ConfigDict()
    optim.weight_decay = 0
    optim.optimizer = 'Adam'
    optim.lr = 2e-4
    optim.beta1 = 0.5
    optim.eps = 1e-8
    optim.warmup = 5000  # linear LR warmup steps
    optim.grad_clip = 1.

    config.seed = 42
    # prefer the first CUDA device when available, otherwise fall back to CPU
    config.device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
    return config