Unsupervised Data Augmentation for Consistency Training

This repo contains a simple and clear PyTorch implementation of the main building blocks of "Unsupervised Data Augmentation for Consistency Training" by Qizhe Xie, Zihang Dai, Eduard Hovy, Minh-Thang Luong, and Quoc V. Le.

Parameters

--mod:          default='semisup':          Supervised (sup) or semi-supervised training (semisup)
--sup_num:      default=4000:               Number of samples in supervised training set (out of 50K)
--val_num:      default=1000:               Number of samples in validation set (out of 50K)
--rand_seed:    default=89:                 Random seed for dataset shuffle
--sup_aug:      default=['crop', 'hflip']:  Data augmentation for supervised and unsupervised samples (crop, hflip, cutout, randaug)
--unsup_aug:    default=['randaug']:        Data augmentation (Noise) for unsupervised noisy samples (crop, hflip, cutout, randaug)
--bsz_sup:      default=64:                 Batch size for supervised training
--bsz_unsup:    default=448:                Batch size for unsupervised training
--softmax_temp: default=0.4:                Softmax temperature for target distribution (unsup)
--conf_thresh:  default=0.8:                Confidence threshold for target distribution (unsup)
--unsup_loss_w: default=1.0:                Unsupervised loss weight
--max_iter:     default=500000:             Total training iterations
--vis_idx:      default=10:                 Iteration interval for output visualization
--eval_idx:     default=1000:               Iteration interval for validation
--out_dir:      default='./output/':        Output directory
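
The flags above can be declared with argparse. The names and defaults below are taken from the parameter list; the declarations themselves (types, nargs for the list-valued augmentation flags) are an assumed sketch, not the repo's actual code:

```python
import argparse

parser = argparse.ArgumentParser(description='UDA training (sketch)')
parser.add_argument('--mod', default='semisup', choices=['sup', 'semisup'],
                    help='Supervised (sup) or semi-supervised (semisup) training')
parser.add_argument('--sup_num', type=int, default=4000)
parser.add_argument('--val_num', type=int, default=1000)
parser.add_argument('--rand_seed', type=int, default=89)
# nargs='+' lets the CLI accept e.g. --sup_aug 'crop' 'hflip'
parser.add_argument('--sup_aug', nargs='+', default=['crop', 'hflip'])
parser.add_argument('--unsup_aug', nargs='+', default=['randaug'])
parser.add_argument('--bsz_sup', type=int, default=64)
parser.add_argument('--bsz_unsup', type=int, default=448)
parser.add_argument('--softmax_temp', type=float, default=0.4)
parser.add_argument('--conf_thresh', type=float, default=0.8)
parser.add_argument('--unsup_loss_w', type=float, default=1.0)
parser.add_argument('--max_iter', type=int, default=500000)
parser.add_argument('--vis_idx', type=int, default=10)
parser.add_argument('--eval_idx', type=int, default=1000)
parser.add_argument('--out_dir', default='./output/')
```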

Example runs

For semi-supervised training:

python main.py --mod 'semisup' --sup_num 4000 --sup_aug 'crop' 'hflip' --unsup_aug 'randaug' --bsz_sup 64 --bsz_unsup 448

For supervised training:

python main.py --mod 'sup' --sup_num 49000 --sup_aug 'randaug' --bsz_sup 64
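
The three unsupervised-loss parameters (--softmax_temp, --conf_thresh, --unsup_loss_w) map onto UDA's sharpened target distribution, confidence-based masking, and loss weighting. A minimal PyTorch sketch of that consistency term, assuming hypothetical function and variable names (not the repo's actual code):

```python
import torch
import torch.nn.functional as F

def unsup_consistency_loss(logits_orig, logits_aug,
                           softmax_temp=0.4, conf_thresh=0.8):
    """Consistency loss between predictions on original and
    noised (augmented) unsupervised samples."""
    with torch.no_grad():
        # Sharpen the target distribution with a low softmax temperature
        targets = F.softmax(logits_orig / softmax_temp, dim=-1)
        # Keep only samples whose unsharpened confidence exceeds the threshold
        conf = F.softmax(logits_orig, dim=-1).max(dim=-1).values
        mask = (conf >= conf_thresh).float()
    log_probs = F.log_softmax(logits_aug, dim=-1)
    # Per-sample KL divergence between target and augmented prediction
    kl = F.kl_div(log_probs, targets, reduction='none').sum(dim=-1)
    return (kl * mask).sum() / mask.sum().clamp(min=1.0)
```

In training, this term would be scaled by --unsup_loss_w and added to the supervised cross-entropy loss.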

Notes

Some of the code for this implementation is borrowed from online sources, as detailed below:
