
Conditional Mutual Information for Disentanglement

This is the implementation of Conditional Mutual Information for Disentanglement (CMID) from the paper Conditional Mutual Information for Disentangled Representations in Reinforcement Learning.

This code is based on the DrQ PyTorch implementation by Yarats et al. and the DMControl Generalisation Benchmark by Hansen et al., which also contains the official SVEA implementation. As in the original codebases, we use kornia for data augmentation.

The CMID auxiliary task, applied to SVEA as the base RL algorithm, is largely contained in the algorithms/svea_cmid.py file. The dmc2gym folder contains the dmc2gym code, amended slightly to create the colour correlations.

Requirements

We assume you have access to MuJoCo and a GPU that can run CUDA 11.7. Then, the simplest way to install all required dependencies is to create a conda environment by running:

conda env create -f conda_env.yml

You can activate your environment with:

conda activate cmid

Instructions

You can run the code using the configuration specified in arguments.py with:

python train.py

The configs folder contains bash scripts for all the algorithms used in the paper on the cartpole task as an example. You can run a specific configuration using the bash script, for example:

sh configs/cartpole_colour_correlation_svea_cmid.sh

This will create the runs folder, where all outputs, including the train/eval logs, are stored.

The console output is also available in the form:

| train | E: 5 | S: 5000 | R: 11.4359 | D: 66.8 s | BR: 0.0581 | ALOSS: -1.0640 | CLOSS: 0.0996 | TLOSS: -23.1683 | TVAL: 0.0945 | AENT: 3.8132 | CMIDD: 0.7837 | CMIDA: 0.6953

a training entry decodes as

train - training episode
E - total number of episodes
S - total number of environment steps
R - episode return
D - duration in seconds
BR - average reward of a sampled batch
ALOSS - average loss of the actor
CLOSS - average loss of the critic
TLOSS - average loss of the temperature parameter
TVAL - the value of temperature
AENT - the actor's entropy
CMIDD - average of the CMID discriminator loss
CMIDA - average of the CMID adversarial loss

while an evaluation entry

| eval  | E: 20 | S: 20000 | R: 10.9356

contains

E - evaluation was performed after E episodes
S - evaluation was performed after S environment steps
R - average episode return computed over `num_eval_episodes` (usually 10)
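Both entry types are plain pipe-delimited text, so they are easy to post-process. As a minimal sketch (the repository does not ship such a helper, and the function name here is our own), a few lines of Python can turn either entry into a dictionary of floats:

```python
def parse_log_line(line):
    """Split a pipe-delimited console entry into (mode, fields).

    Assumes the format shown above: the first field is the mode
    ("train" or "eval") and each remaining field is "KEY: value",
    where any unit suffix (e.g. the trailing "s" on D) is dropped.
    """
    parts = [p.strip() for p in line.strip().strip("|").split("|")]
    mode, fields = parts[0], {}
    for part in parts[1:]:
        key, value = part.split(":", 1)
        fields[key.strip()] = float(value.split()[0])  # keep the numeric part only
    return mode, fields

mode, fields = parse_log_line("| eval  | E: 20 | S: 20000 | R: 10.9356")
```

This is convenient for pulling, say, the episode return `R` out of a saved console log when plotting training curves.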

Results and Plots From Paper

The data for the experimental results in the paper can be found here. These files contain the evaluation returns for all algorithms and seeds used to create Figures 4 and 5.

Citation

@inproceedings{dunion2023cmid,
   title={Conditional Mutual Information for Disentangled Representations in Reinforcement Learning},
   author={Mhairi Dunion and Trevor McInroe and Kevin Sebastian Luck and Josiah Hanna and Stefano V. Albrecht},
   booktitle={Conference on Neural Information Processing Systems},
   year={2023}
}
