VICReg: Variance-Invariance-Covariance Regularization For Self-Supervised Learning




📌  Introduction

This repository provides a PyTorch Lightning implementation of VICReg, as described in the paper VICReg: Variance-Invariance-Covariance Regularization For Self-Supervised Learning. This repo is inspired by the original repository of Meta AI.

This module follows the style used in Lightning Bolts for other SOTA self-supervised models.
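VICReg's objective combines three terms: an invariance term (mean-squared error between the two embeddings), a variance term (a hinge keeping each embedding dimension's standard deviation above a margin), and a covariance term (penalizing off-diagonal covariance entries). A minimal NumPy sketch of these terms, following the paper's formulas; this repo's actual implementation may differ in details such as reduction and epsilon handling:

```python
import numpy as np

def vicreg_loss(z_a, z_b, inv_coeff=25.0, var_coeff=25.0,
                cov_coeff=1.0, eps=1e-4):
    """Sketch of the VICReg loss for two batches of embeddings (n, d)."""
    n, d = z_a.shape

    # Invariance: mean-squared error between the two views' embeddings.
    inv = np.mean((z_a - z_b) ** 2)

    # Variance: hinge loss on the per-dimension std of each batch.
    std_a = np.sqrt(z_a.var(axis=0) + eps)
    std_b = np.sqrt(z_b.var(axis=0) + eps)
    var = (np.mean(np.maximum(0.0, 1.0 - std_a))
           + np.mean(np.maximum(0.0, 1.0 - std_b)))

    # Covariance: sum of squared off-diagonal covariance entries, per dim.
    def off_diag_cov(z):
        zc = z - z.mean(axis=0)
        c = (zc.T @ zc) / (n - 1)
        off = c - np.diag(np.diag(c))
        return np.sum(off ** 2) / d

    cov = off_diag_cov(z_a) + off_diag_cov(z_b)
    return inv_coeff * inv + var_coeff * var + cov_coeff * cov
```

The three coefficients correspond to the `invariance_coeff`, `variance_coeff`, and `covariance_coeff` arguments used throughout this repo.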

Why PyTorch Lightning?

PyTorch Lightning is a lightweight PyTorch wrapper for high-performance AI research. It keeps your code neatly organized and provides many useful features, such as the ability to run a model on CPU, GPU, multi-GPU clusters, and TPUs.

Why Lightning Bolts?

Lightning Bolts is a community-built deep learning research and production toolbox, featuring a collection of well established and SOTA models and components, pre-trained weights, callbacks, loss functions, data sets, and data modules.​

How to use this module?

Here are some examples!

Python

import pytorch_lightning as pl
from pl_bolts.datamodules import CIFAR10DataModule

# The imports below assume this repo's module layout.
from vicreg_module import VICReg
from transforms import VICRegTrainDataTransform, VICRegEvalDataTransform

model = VICReg(
    arch="resnet18",
    maxpool1=False,
    first_conv=False,
    mlp_expander="2048-2048-2048",
    invariance_coeff=25.0,
    variance_coeff=25.0,
    covariance_coeff=1.0,
    optimizer="lars",
    learning_rate=0.3,
    warmup_steps=10,
)

dm = CIFAR10DataModule(batch_size=128, num_workers=0)

dm.train_transforms = VICRegTrainDataTransform(
    input_height=32,
    gaussian_blur=False,
    jitter_strength=1.0,
)

dm.val_transforms = VICRegEvalDataTransform(
    input_height=32,
    gaussian_blur=False,
    jitter_strength=1.0,
)

trainer = pl.Trainer()
trainer.fit(model, datamodule=dm)
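The `mlp_expander` argument is a dash-separated string describing the widths of the projector (expander) MLP appended to the backbone. A minimal sketch of how such a spec could be parsed into layer sizes; `parse_expander` and the `embedding_dim` default are illustrative assumptions, not this repo's API:

```python
def parse_expander(spec: str, embedding_dim: int = 512):
    """Turn an expander spec like "2048-2048-2048" into the full list
    of MLP layer sizes, with the backbone's embedding dim first."""
    widths = [int(w) for w in spec.split("-")]
    return [embedding_dim] + widths
```

For example, `parse_expander("2048-2048-2048", 512)` yields `[512, 2048, 2048, 2048]`, i.e. three linear layers mapping 512 → 2048 → 2048 → 2048.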

Command line interface [cifar10]

python vicreg_module.py \
    --accelerator gpu \
    --devices 1 \
    --dataset cifar10 \
    --data_dir /path/to/cifar/ \
    --batch_size 128 \
    --arch resnet18 \
    --maxpool1 False \
    --first_conv False \
    --mlp_expander 2048-2048-2048 \
    --invariance_coeff 25.0 \
    --variance_coeff 25.0 \
    --covariance_coeff 1.0 \
    --optimizer adam \
    --learning_rate 0.3 \
    --warmup_steps 10

Command line interface [imagenet]

python vicreg_module.py \
    --accelerator gpu \
    --devices 1 \
    --dataset imagenet \
    --data_dir /path/to/imagenet/ \
    --batch_size 512 \
    --arch resnet50 \
    --maxpool1 True \
    --first_conv True \
    --mlp_expander 8192-8192-8192 \
    --invariance_coeff 25.0 \
    --variance_coeff 25.0 \
    --covariance_coeff 1.0 \
    --optimizer lars \
    --learning_rate 0.6 \
    --warmup_steps 10

Results

I have pre-trained the model on CIFAR-10 (the WandB evaluation metrics for CIFAR-10 are available here).

Check the Colab version

If you love notebooks and free GPUs, the Colab version of this repository can be found here.

