StanfordMIMI/CompRx

CompRx: A Benchmark for Diagnostically Lossless Compression of Medical Images

🫁 Overview

Medical images are often acquired at high resolutions (>1 megapixel) to capture the fine-grained details needed for diagnosis. CompRx is a benchmark with clinically relevant evaluation tasks that measure how well compressed medical images preserve fine-grained diagnostic features.

⚡️ Installation

pip install -e .
pip install -r requirements.txt
pre-commit install
pre-commit

Update .env with the path of your project folder on disk. Then run accelerate config to generate a config file, and copy it from ~/.cache/huggingface/accelerate/default_config.yaml to the project directory. Finally, create symlinks in the data/ folder that point to the datasets you want to train on.
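The copy and symlink steps above can also be scripted. The following is a minimal sketch, not part of the repository: prepare_project is a hypothetical helper, and the demo runs in a temporary directory with placeholder paths and a stand-in config file instead of your real project and dataset locations.

```python
import shutil
import tempfile
from pathlib import Path


def prepare_project(project_dir: Path, accelerate_cfg: Path, datasets: dict[str, Path]) -> None:
    """Hypothetical helper: copy the Accelerate config and symlink datasets into data/."""
    project_dir.mkdir(parents=True, exist_ok=True)
    # Copy the file written by `accelerate config` into the project directory.
    shutil.copy(accelerate_cfg, project_dir / accelerate_cfg.name)
    data_dir = project_dir / "data"
    data_dir.mkdir(exist_ok=True)
    for name, target in datasets.items():
        link = data_dir / name
        # Leave existing links or folders untouched.
        if not (link.is_symlink() or link.exists()):
            link.symlink_to(target, target_is_directory=True)


# Demo with throwaway paths; in practice accelerate_cfg would be
# Path.home() / ".cache/huggingface/accelerate/default_config.yaml".
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    cfg = root / "default_config.yaml"
    cfg.write_text("compute_environment: LOCAL_MACHINE\n")  # stand-in content
    raw = root / "datasets" / "rsna"
    raw.mkdir(parents=True)
    project = root / "comprx"
    prepare_project(project, cfg, {"rsna": raw})
    copied = (project / "default_config.yaml").exists()
    linked = (project / "data" / "rsna").is_symlink()
```

Calling the helper a second time is harmless, since existing links are skipped.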

⚙️ Med-VAE

The train_vae.py script can be used to train Med-VAE.

# Train VAE (edit vae.yaml in order to control loss and autoencoder parameters)
accelerate launch comprx/train_vae.py experiment=vae

The infer_vae.py script can be used to cache latents.

# After training VAE, cache learned latents
accelerate launch --num_processes=1 comprx/infer_vae.py \
    experiment=vae_inference \
    csv_stem=malignancy \
    img_size=64 \
    paths.inference_output_dir=/admin/home-sluijs/comprx/data/tmp/ \
    num_latent_channels=1 \
    model.ddconfig.ch_mult=[1,2,4,4] \
    resume_from_ckpt=/fsx/aimi/vae-checkpoints/15000/8x1/step_15000.pt \
    dataset_ids=[1]
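To get a feel for how the overrides above relate to the size of a cached latent, here is a rough sketch. It rests on two assumptions that the README does not confirm: that the encoder follows the CompVis-style layout, where the resolution is halved between consecutive ch_mult levels, and that img_size is the side length of the encoder input. latent_shape is a hypothetical helper, not a function from this repository.

```python
def latent_shape(img_size: int, ch_mult: list[int], num_latent_channels: int) -> tuple[int, int, int]:
    """Estimate (channels, height, width) of a latent for a square input.

    Assumption: one 2x downsampling step between consecutive ch_mult levels,
    i.e. len(ch_mult) - 1 steps in total.
    """
    factor = 2 ** (len(ch_mult) - 1)
    if img_size % factor != 0:
        raise ValueError(f"img_size={img_size} is not divisible by {factor}")
    side = img_size // factor
    return (num_latent_channels, side, side)


# Values taken from the inference command above.
shape = latent_shape(img_size=64, ch_mult=[1, 2, 4, 4], num_latent_channels=1)
print(shape)  # (1, 8, 8)
```

If the actual architecture downsamples differently, only the factor computation needs to change.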

🩺 CompRx Tasks

Fine-Grained Classification

To train a classifier, use the comprx/train_cls.py script.

# Malignancy Detection (model.freeze defaults to True)
accelerate launch comprx/train_cls.py \
    experiment=cls_bicubic_malignancy \
    dataset_id=2 \
    data_subdir=256 \
    model.backbone=resnet50 \
    model.freeze=True \
    ckpt_path=/path/checkpoints/last.pt/pytorch_model.bin

# BI-RADS Prediction
accelerate launch comprx/train_cls.py \
    experiment=cls_bicubic_birads \
    data_subdir=256 \
    model.backbone=resnet50 \
    ckpt_path=/path/checkpoints/last.pt/pytorch_model.bin

# Calcification Detection
accelerate launch comprx/train_cls.py \
    experiment=cls_bicubic_calcification \
    dataset_id=2 \
    data_subdir=256 \
    model.backbone=resnet50 \
    ckpt_path=/path/checkpoints/last.pt/pytorch_model.bin

# Bone Age Prediction
accelerate launch comprx/train_cls.py \
    experiment=cls_bicubic_boneage \
    dataset_id=9 \
    data_subdir=256 \
    model.backbone=resnet50 \
    ckpt_path=/path/checkpoints/last.pt/pytorch_model.bin

# Pediatric Wrist Fracture Detection
accelerate launch comprx/train_cls.py \
    experiment=cls_bicubic_fracture \
    dataset_id=10 \
    data_subdir=256 \
    model.backbone=resnet50 \
    ckpt_path=/path/checkpoints/last.pt/pytorch_model.bin


# Use ImageNet pretrained weights
accelerate launch comprx/train_cls.py experiment=... +model.pretrained=True
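Long multi-line shell commands are easy to break with a missing backslash. One option is to assemble the same command in Python. This is a minimal sketch: build_launch_command is a hypothetical helper, and the override values are copied from the malignancy example above.

```python
import shlex


def build_launch_command(script: str, overrides: dict[str, object]) -> list[str]:
    """Hypothetical helper: build an `accelerate launch` argument list from key=value overrides."""
    args = ["accelerate", "launch", script]
    args += [f"{key}={value}" for key, value in overrides.items()]
    return args


cmd = build_launch_command(
    "comprx/train_cls.py",
    {
        "experiment": "cls_bicubic_malignancy",
        "dataset_id": 2,
        "data_subdir": 256,
        "model.backbone": "resnet50",
        "ckpt_path": "/path/checkpoints/last.pt/pytorch_model.bin",
    },
)
# Print a shell-safe version of the command for inspection.
print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run(cmd, check=True) to start training.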

Perceptual Quality

To evaluate perceptual quality, use the comprx/compute_vae_rec_metrics.py script.

python3 comprx/compute_vae_rec_metrics.py \
    experiment=vae_metrics \
    resume_from_ckpt=bicubic-4x \
    img_size=768 \
    seed=0 
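This README does not list which metrics the script reports. As an illustration of what a reconstruction metric measures, here is a sketch of peak signal-to-noise ratio (PSNR), a common choice for comparing an image with its reconstruction. It is not taken from compute_vae_rec_metrics.py, and the data_range default assumes pixel values scaled to [0, 1].

```python
import numpy as np


def psnr(reference, reconstruction, data_range: float = 1.0) -> float:
    """PSNR in dB between two arrays of equal shape (illustrative, not the repo's code)."""
    reference = np.asarray(reference, dtype=np.float64)
    reconstruction = np.asarray(reconstruction, dtype=np.float64)
    if reference.shape != reconstruction.shape:
        raise ValueError("inputs must have the same shape")
    mse = np.mean((reference - reconstruction) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(data_range**2 / mse))


# Synthetic demo: compare a random image with a noisy copy of itself.
rng = np.random.default_rng(0)
image = rng.random((64, 64))
noisy = np.clip(image + rng.normal(scale=0.05, size=image.shape), 0.0, 1.0)
print(psnr(image, image), psnr(image, noisy))
```

Higher values indicate a reconstruction closer to the original.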

📚 Datasets

Every dataset is loaded through the same class, GenericDataset. It relies on a single CSV file whose columns specify splits, images, labels, or text. The corresponding _transform arguments convert these column values into representations that the ML pipeline can use. Here are two examples, for CANDID-PTX and RSNA Mammo, respectively.

from comprx.dataloaders import GenericDataset

# load_tensor and load_labels are loader helpers; their import path is not shown in this README.

ds = GenericDataset(
    split_path="/path/candid/splits/train.csv",
    data_dir="/path/candid/data/1024",
    dataset_id=0,
    img_column="image_uuid",
    img_suffix=".img.npy",
    img_transform=load_tensor,
    lbl_columns=["ptx", "fracture", "chest_tube"],
    lbl_transform=load_labels,
    txt_column="report",
    txt_transform=lambda x: x,
)
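To make the CSV-driven pattern concrete, here is a simplified stand-in. It is not the GenericDataset implementation: CsvBackedDataset, the inline CSV, and the lambda transforms are all hypothetical and only illustrate how column values could flow through the transform callables.

```python
import csv
import io


class CsvBackedDataset:
    """Hypothetical sketch: one CSV row per sample, columns mapped through transforms."""

    def __init__(self, csv_file, img_column, lbl_columns, img_transform, lbl_transform,
                 split_column=None, split_name=None):
        rows = list(csv.DictReader(csv_file))
        if split_column is not None:
            # Keep only the rows that belong to the requested split.
            rows = [row for row in rows if row[split_column] == split_name]
        self.rows = rows
        self.img_column = img_column
        self.lbl_columns = lbl_columns
        self.img_transform = img_transform
        self.lbl_transform = lbl_transform

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        row = self.rows[idx]
        return {
            "img": self.img_transform(row[self.img_column]),
            "lbl": self.lbl_transform([row[col] for col in self.lbl_columns]),
        }


CSV_TEXT = """image_uuid,split,ptx,fracture
a1,train,1,0
b2,val,0,0
c3,train,0,1
"""

ds_sketch = CsvBackedDataset(
    io.StringIO(CSV_TEXT),
    img_column="image_uuid",
    lbl_columns=["ptx", "fracture"],
    img_transform=lambda uuid: f"{uuid}.img.npy",  # a real transform would load the file
    lbl_transform=lambda values: [float(v) for v in values],
    split_column="split",
    split_name="train",
)
sample = ds_sketch[1]
print(len(ds_sketch), sample)
```

Keeping the transforms as plain callables means the same class can serve images, labels, and text without subclassing.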

In the RSNA Mammo example, ZarrLoader allows efficient random cropping without loading the complete array buffer into memory.

from functools import partial

import torch

ds = GenericDataset(
    split_path="/path/rsna/splits/malignancy.csv",
    split_column="split",
    split_name="train",
    data_dir="/path/rsna/mammogram/mg-1/data",
    dataset_id=0,
    img_column="image_uuid",
    img_transform=ZarrLoader(size=512),
    lbl_columns=["BIRADS"],
    lbl_transform=partial(load_labels, dtype=torch.long, fill_null=0),
)
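The idea behind cropping without loading the whole array can be sketched with a memory-mapped NumPy file. Note the substitution: this example uses np.load with mmap_mode instead of Zarr, and random_crop is a hypothetical helper, so it illustrates the access pattern only and says nothing about how ZarrLoader is implemented.

```python
import os
import tempfile

import numpy as np


def random_crop(array, size: int, rng: np.random.Generator) -> np.ndarray:
    """Hypothetical helper: return a random size x size crop from the last two axes."""
    height, width = array.shape[-2:]
    if size > height or size > width:
        raise ValueError("crop size exceeds image size")
    top = int(rng.integers(0, height - size + 1))
    left = int(rng.integers(0, width - size + 1))
    # Slicing the memory map and copying materializes only the cropped window.
    return np.array(array[..., top:top + size, left:left + size])


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "image.npy")
    # Synthetic 1024 x 1024 image whose values encode their own position.
    np.save(path, np.arange(1024 * 1024, dtype=np.float32).reshape(1024, 1024))
    lazy = np.load(path, mmap_mode="r")  # opens the file without reading it fully
    crop = random_crop(lazy, size=512, rng=np.random.default_rng(0))
    del lazy  # release the memory map before the directory is removed

print(crop.shape, crop.dtype)
```

Chunked formats such as Zarr follow the same principle, reading only the chunks that overlap the requested window.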

🖥️ Acknowledgments

This repository is powered by Hydra and HuggingFace Accelerate. Our implementation of Med-VAE is inspired by prior work on diffusion models from CompVis and Stability AI.
