## [Swin2-MoSE: A New Single Image Super-Resolution Model for Remote Sensing](https://arxiv.org/abs/2404.18924)

Official PyTorch implementation of **Swin2-MoSE**.

In this paper, we propose the **Swin2-MoSE** model, an enhanced version of Swin2SR for
Single-Image Super-Resolution for Remote Sensing.

Swin2-MoSE Architecture

Swin2-MoSE MoE-SM

Swin2-MoSE Positional Encoding

Authors: Leonardo Rossi, Vittorio Bernuzzi, Tomaso Fontanini,
Massimo Bertozzi, Andrea Prati.

[IMP Lab](http://implab.ce.unipr.it/) -
Dipartimento di Ingegneria e Architettura

University of Parma, Italy


## Abstract

Due to the limitations of current optical and sensor technologies and the high cost of updating them, the spectral and spatial resolution of satellites may not always meet desired requirements.
For these reasons, Remote-Sensing Single-Image Super-Resolution (RS-SISR) techniques have gained significant interest.

In this paper, we propose the Swin2-MoSE model, an enhanced version of Swin2SR.

Our model introduces MoE-SM, an enhanced Mixture-of-Experts (MoE) to replace the Feed-Forward inside all Transformer blocks.
MoE-SM is designed with Smart-Merger, a new layer for merging the output of individual experts, and with a new way to split the work between experts, defining a new per-example strategy instead of the commonly used per-token one.

Furthermore, we analyze how positional encodings interact with each other, demonstrating that per-channel bias and per-head bias can positively cooperate.

Finally, we propose to use a combination of Normalized-Cross-Correlation (NCC) and Structural Similarity Index Measure (SSIM) losses, to avoid typical MSE loss limitations.

Experimental results demonstrate that Swin2-MoSE outperforms SOTA by up to 0.377 ~ 0.958 dB (PSNR) on the task of 2x, 3x and 4x resolution-upscaling (Sen2Venus and OLI2MSI datasets).
We show the efficacy of Swin2-MoSE by applying it to a semantic segmentation task (SeasoNet dataset).


## Usage

### Installation
```bash
$ git clone --branch official_code https://github.com/IMPLabUniPr/swin2-mose.git
$ cd swin2-mose
$ conda env create -n swin2_mose_env --file environment.yml
$ conda activate swin2_mose_env
```

### Prepare Sen2Venus dataset

1) After downloading the files from the official
   [Sen2Venus](https://zenodo.org/records/6514159) website, unzip them
   inside the `./datasets/sen2venus_original` directory.

2) Run the script [split.py](https://github.com/IMPLabUniPr/swin2-mose/tree/official_code/scripts/sen2venus/split.py) to split the dataset into training (~80%) and
   test (~20%) sets:

```bash
python scripts/sen2venus/split.py --input ./datasets/sen2venus_original --output ./data/sen2venus
```

After the successful execution of the script, you will find the `train.csv` and
`test.csv` files inside the `./data/sen2venus` directory.

Note: if you want to skip this step and use our `train.csv` and `test.csv`
files directly, you can download them from the
[Release v1.0](https://github.com/IMPLabUniPr/swin2-mose/releases/tag/v1.0)
page.

3) Run the script [rebuild.py](https://github.com/IMPLabUniPr/swin2-mose/tree/official_code/scripts/sen2venus/rebuild.py) to rebuild the dataset in a compatible
   format:

```bash
python scripts/sen2venus/rebuild.py --data ./datasets/sen2venus_original --output ./data/sen2venus
```

If everything went well, you will have the following file structure:

```
data/sen2venus
├── test
│   ├── 000000_ALSACE_2018-02-14.pt
│   ├── 000001_ALSACE_2018-02-14.pt
|   ...
├── test.csv
├── train
│   ├── 000000_ALSACE_2018-02-14.pt
│   ├── 000001_ALSACE_2018-02-14.pt
|   ...
└── train.csv
```

Note about Sen2Venus: we found a small error in the file naming convention!

In the paper, the authors give the following names for the `4x` files:

```
{id}_05m_b5b6b7b8a.pt - 5m patches (256×256 pix.) for S2 B5, B6, B7 and B8A (from VENµS)
{id}_20m_b5b6b7b8a.pt - 20m patches (64×64 pix.) for S2 B5, B6, B7 and B8A (from Sentinel-2)
```

However, the downloaded files follow this naming convention:

```
ALSACE_C_32ULU_2018-02-14_05m_b4b5b6b8a.pt
ALSACE_C_32ULU_2018-02-14_20m_b4b5b6b8a.pt
```

### Prepare OLI2MSI dataset

Download the dataset from the official
[OLI2MSI](https://github.com/wjwjww/OLI2MSI) website and unzip it
inside the `./data/oli2msi` directory.
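
Once the archive is unpacked, the layout can be sanity-checked with a short script. What follows is a minimal sketch and not part of the repository: the helper names `count_tifs` and `unpaired_files` are hypothetical, and the four sub-directory names (`train_hr`, `train_lr`, `test_hr`, `test_lr`) are taken from the expected `data/oli2msi` layout. It also assumes that each HR image has an LR counterpart with the same file name, which the listed examples suggest but the source does not state.

```python
from pathlib import Path

# Sub-directories expected under data/oli2msi (taken from the README layout).
EXPECTED_SPLITS = ("train_hr", "train_lr", "test_hr", "test_lr")


def _tif_names(directory):
    """Return the set of file names in `directory` ending in .TIF (any case)."""
    return {p.name for p in Path(directory).iterdir() if p.suffix.upper() == ".TIF"}


def count_tifs(root):
    """Return {split: number of .TIF files}; raise if a split directory is missing."""
    root = Path(root)
    counts = {}
    for split in EXPECTED_SPLITS:
        split_dir = root / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing directory: {split_dir}")
        counts[split] = len(_tif_names(split_dir))
    return counts


def unpaired_files(root):
    """Return, for train and test, the file names present in only one of HR/LR.

    Assumption (not confirmed by the source): HR and LR images share file names.
    """
    root = Path(root)
    result = {}
    for part in ("train", "test"):
        hr = _tif_names(root / f"{part}_hr")
        lr = _tif_names(root / f"{part}_lr")
        # Symmetric difference: names that appear on one side only.
        result[part] = sorted(hr ^ lr)
    return result
```

Calling `count_tifs("./data/oli2msi")` should report a non-zero count for every split, and `unpaired_files("./data/oli2msi")` should return empty lists if the pairing assumption holds for the downloaded data.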

If everything went well, you will have the following file structure:

```
data/oli2msi
├── test_hr
│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0071.TIF
│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0108.TIF
|   ...
├── test_lr
│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0071.TIF
│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0108.TIF
|   ...
├── train_hr
│   ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0008.TIF
│   ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0015.TIF
|   ...
└── train_lr
    ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0008.TIF
    ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0015.TIF
    ...
```

### Prepare SeasoNet dataset

Download the dataset from the official [SeasoNet](https://zenodo.org/records/5850307)
website and unzip it inside the `./data/SeasoNet/data` directory.

If everything went well, you will have the following file structure:

```
data/SeasoNet
└── data
    ├── fall
    │   ├── grid1
    │   └── grid2
    ├── meta.csv
    ├── snow
    │   ├── grid1
    │   └── grid2
    ├── splits
    │   ├── test.csv
    │   ├── train.csv
    │   └── val.csv
    ├── spring
    │   ├── grid1
    │   └── grid2
    ├── summer
    │   ├── grid1
    │   └── grid2
    └── winter
        ├── grid1
        └── grid2
```

Note about SeasoNet: the dataset can also be downloaded through the
[TorchGeo](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#seasonet)
class, specifying the root directory `./data/SeasoNet/data/`.

### Download pretrained

Open the
[Release v1.0](https://github.com/IMPLabUniPr/swin2-mose/releases/tag/v1.0)
page and download the .pt (pretrained) and .pkl (results) files.

Unzip them inside the output directory, obtaining the following directory
structure:

```
output2/sen2venus_exp4_2x_v5/
├── checkpoints
│   └── model-70.pt
└── eval
    └── results-70.pt
```

### Train

```bash
python src/main.py --phase train --config $CONFIG_FILE --output $OUT_DIR --epochs ${EPOCH} --epoch -1
python src/main_ssegm.py --phase train --config $CONFIG_FILE --output $OUT_DIR --epochs ${EPOCH} --epoch -1
```

### Validate

```bash
python src/main.py --phase test --config $CONFIG_FILE --output $OUT_DIR --batch_size 32 --epoch ${EPOCH}
python src/main.py --phase test --config $CONFIG_FILE --batch_size 32 --eval_method bicubic
```

### Show results

```bash
python src/main.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 3 --epoch ${EPOCH}
python src/main.py --phase vis --config $CONFIG_FILE --output output/sen2venus_4x_bicubic --num_images 3 --eval_method bicubic
python src/main.py --phase vis --config $CONFIG_FILE --output output/sen2venus_4x_bicubic --num_images 3 --eval_method bicubic --dpi 1200
python src/main_ssegm.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 2 --epoch ${EPOCH}
python src/main_ssegm.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 2 --epoch ${EPOCH} --hide_sr
```

### Compute mean/std

```bash
python src/main.py --phase mean_std --config $CONFIG_FILE
python src/main_ssegm.py --phase mean_std --config $CONFIG_FILE
```

### Measure execution average time

```bash
python src/main.py --phase avg_time --config $CONFIG_FILE --repeat_times 1000 --warm_times 20 --batch_size 8
```

## Results

### Table 1

Ablation study on loss usage.

| # | NCC loss | SSIM loss | MSE loss | NCC | SSIM | PSNR | Conf |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | x |   |   | 0.9550 | 0.5804 | 16.4503 | [conf](cfgs/sen2venus_exp1_v1.yml) |
| 2 |   | x |   | 0.9565 | 0.9847 | 45.5427 | [conf](cfgs/sen2venus_exp1_v2.yml) |
| 3 |   |   | x | 0.9546 | 0.9828 | 45.4759 | [conf](cfgs/sen2venus_exp1_v3.yml) |
| 4 | x | x |   | 0.9572 | 0.9841 | 45.6986 | [conf](cfgs/sen2venus_exp1_v4.yml) |
| 5 | x |   | x | 0.9549 | 0.9828 | 45.5163 | [conf](cfgs/sen2venus_exp1_v5.yml) |
| 6 | x | x | x | 0.9555 | 0.9833 | 45.5542 | [conf](cfgs/sen2venus_exp1_v6.yml) |

### Table 2

Ablation study on positional encoding.

| # | RPE | log CPB | LePE | SSIM | PSNR | Conf |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | x |   |   | 0.9841 | 45.5855 | [conf](cfgs/sen2venus_exp2_v1.yml) |
| 2 |   | x |   | 0.9841 | 45.6986 | [conf](cfgs/sen2venus_exp2_v2.yml) |
| 3 |   |   | x | 0.9843 | 45.7278 | [conf](cfgs/sen2venus_exp2_v3.yml) |
| 4 |   | x | x | 0.9845 | 45.8046 | [conf](cfgs/sen2venus_exp2_v4.yml) |
| 5 | x |   | x | 0.9847 | 45.8539 | [conf](cfgs/sen2venus_exp2_v5.yml) |
| 6 | x | x |   | 0.9843 | 45.6945 | [conf](cfgs/sen2venus_exp2_v6.yml) |
| 7 | x | x | x | 0.9846 | 45.8185 | [conf](cfgs/sen2venus_exp2_v7.yml) |

### Table 3

Ablation study on the MLP replacement (MLP vs. MoE, with and without Smart-Merger, SM).

| # | Arch | SM | MLP #Params (APC) | MLP #Params (SPC) | SSIM | PSNR | Latency (s) | Conf |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | MLP |   | 32’670 | 32’670 | 0.9847 | 45.8539 | 0.194 | [conf](cfgs/sen2venus_exp3_v1.yml) |
| 2 | MoE 8/2 |   | 32’760 | 132’480 | 0.9845 | 45.8647 | 0.202 | [conf](cfgs/sen2venus_exp3_v2.yml) |
| 3 | MoE 8/2 | x | 32’779 | 132’499 | 0.9849 | 45.9272 | 0.212 | [conf](cfgs/sen2venus_exp3_v3.yml) |

### Table 4

Quantitative comparison with SOTA models on the Sen2Venus and OLI2MSI datasets.

| # | Model | Sen2Venus 2x SSIM | Sen2Venus 2x PSNR | OLI2MSI 3x SSIM | OLI2MSI 3x PSNR | Sen2Venus 4x SSIM | Sen2Venus 4x PSNR | Conf 2x | Conf 3x | Conf 4x |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | Bicubic | 0.9883 | 45.5588 | 0.9768 | 42.1835 | 0.9674 | 42.0499 | | | |
| 2 | [SwinIR](https://openaccess.thecvf.com/content/ICCV2021W/AIM/html/Liang_SwinIR_Image_Restoration_Using_Swin_Transformer_ICCVW_2021_paper.html) | 0.9938 | 48.7064 | 0.9860 | 43.7482 | 0.9825 | 45.3460 | [conf](cfgs/sen2venus_exp4_2x_v2.yml) | [conf](cfgs/oli2msi_exp4_3x_v2.yml) | [conf](cfgs/sen2venus_exp4_4x_v2.yml) |
| 3 | [SwinFIR](https://arxiv.org/abs/2208.11247) | 0.9940 | 48.8532 | 0.9863 | 44.4829 | 0.9830 | 45.5500 | [conf](cfgs/sen2venus_exp4_2x_v3.yml) | [conf](cfgs/oli2msi_exp4_3x_v3.yml) | [conf](cfgs/sen2venus_exp4_4x_v3.yml) |
| 4 | [Swin2SR](https://link.springer.com/chapter/10.1007/978-3-031-25063-7_42) | 0.9942 | 49.0467 | 0.9881 | 44.9614 | 0.9828 | 45.4759 | [conf](cfgs/sen2venus_exp4_2x_v4.yml) | [conf](cfgs/oli2msi_exp4_3x_v4.yml) | [conf](cfgs/sen2venus_exp4_4x_v4.yml) |
| 5 | Swin2-MoSE (ours) | 0.9948 | 49.4784 | 0.9912 | 45.9194 | 0.9849 | 45.9272 | [conf](cfgs/sen2venus_exp4_2x_v5.yml) | [conf](cfgs/oli2msi_exp4_3x_v5.yml) | [conf](cfgs/sen2venus_exp4_4x_v5.yml) |


### Figure 11

Results for the Semantic Segmentation task on the SeasoNet dataset.

| # | Model | Conf |
|:---:|:---:|:---:|
| 1 | FarSeg | [conf](cfgs/seasonet_exp5_v1.yml) |
| 2 | FarSeg++ | [conf](cfgs/seasonet_exp5_v2.yml) |
| 3 | FarSeg+S2MFE | [conf](cfgs/seasonet_exp5_v3.yml) |

## License

See [GPL v2](./LICENSE) License.

## Acknowledgement

Project ECS\_00000033\_ECOSISTER funded under the National Recovery and Resilience Plan (NRRP), Mission 4
Component 2 Investment 1.5 - funded by the European Union – NextGenerationEU.
This research benefits from the HPC (High Performance Computing) facility of the University of Parma, Italy.

## Citation

If you find our work useful in your research, please cite:

```
@article{rossi2024swin2,
  title={Swin2-MoSE: A New Single Image Super-Resolution Model for Remote Sensing},
  author={Rossi, Leonardo and Bernuzzi, Vittorio and Fontanini, Tomaso and Bertozzi, Massimo and Prati, Andrea},
  journal={arXiv preprint arXiv:2404.18924},
  year={2024}
}
```
diff --git a/cfgs/oli2msi_exp4_3x_v2.yml b/cfgs/oli2msi_exp4_3x_v2.yml
new file mode 100644
index 0000000..442dc4c
--- /dev/null
+++ b/cfgs/oli2msi_exp4_3x_v2.yml
@@ -0,0 +1,18 @@
+__base__: sen2venus_exp4_4x_v2.yml
+dataset:
+  root_path: data/oli2msi
+  collate_fn: mods.v6.collate_fn
+  denorm: mods.v6.uncollate_fn
+  printable: mods.v6.printable
+  load_dataset: datasets.oli2msi.load_dataset
+  hr_name: null
+  lr_name: null
+super_res: {
+  model: {
+    upscale: 3,
+    in_chans: 3,
+  }
+}
+metrics: {
+  upscale_factor: 3,
+}
diff --git a/cfgs/oli2msi_exp4_3x_v3.yml b/cfgs/oli2msi_exp4_3x_v3.yml
new file mode 100644
index 0000000..21acc6b
--- /dev/null
+++ b/cfgs/oli2msi_exp4_3x_v3.yml
@@ -0,0 +1,18 @@
+__base__: sen2venus_exp4_4x_v3.yml
+dataset:
+  root_path: data/oli2msi
+  collate_fn: mods.v6.collate_fn
+  denorm: mods.v6.uncollate_fn
+  printable: mods.v6.printable
+  load_dataset: datasets.oli2msi.load_dataset
+  hr_name: null
+  lr_name: null
+super_res: {
+  model: {
+    upscale: 3,
+    in_chans: 3,
+  }
+}
+metrics: {
+  upscale_factor: 3,
+}
diff --git a/cfgs/oli2msi_exp4_3x_v4.yml b/cfgs/oli2msi_exp4_3x_v4.yml
new file mode 100644
index 0000000..7ed76f2
--- /dev/null
+++ b/cfgs/oli2msi_exp4_3x_v4.yml
@@ -0,0 +1,18 @@
+__base__: sen2venus_exp4_4x_v4.yml
+dataset:
+  root_path: data/oli2msi
+  collate_fn: mods.v6.collate_fn
+  denorm: mods.v6.uncollate_fn
+  printable: mods.v6.printable
+  load_dataset: datasets.oli2msi.load_dataset
+  hr_name: null
+  lr_name: null
+super_res: {
+  model: {
+    upscale: 3,
+    in_chans: 3,
+  }
+}
+metrics: {
+  upscale_factor: 3,
+}
diff --git a/cfgs/oli2msi_exp4_3x_v5.yml
b/cfgs/oli2msi_exp4_3x_v5.yml new file mode 100644 index 0000000..0469430 --- /dev/null +++ b/cfgs/oli2msi_exp4_3x_v5.yml @@ -0,0 +1,18 @@ +__base__: sen2venus_exp4_v5.yml +dataset: + root_path: data/oli2msi + collate_fn: mods.v6.collate_fn + denorm: mods.v6.uncollate_fn + printable: mods.v6.printable + load_dataset: datasets.oli2msi.load_dataset + hr_name: null + lr_name: null +super_res: { + model: { + upscale: 3, + in_chans: 3, + } +} +metrics: { + upscale_factor: 3, +} diff --git a/cfgs/seasonet_exp5_v1.yml b/cfgs/seasonet_exp5_v1.yml new file mode 100644 index 0000000..dc2da0e --- /dev/null +++ b/cfgs/seasonet_exp5_v1.yml @@ -0,0 +1,48 @@ +batch_size: 128 +losses: { + with_ce_criterion: true, + weights: { + ce: 1.0, + } +} +dataset: + root_path: data/SeasoNet/data + cls: datasets.seasonet.SeasoNetDataset + kwargs: { + seasons: ['Spring'], + bands: ['10m_RGB', '10m_IR',], + grids:[1], + } + collate_fn: mods.v5.collate_fn + stats: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + collate_kwargs: {} + denorm: mods.v5.uncollate_fn + printable: mods.v5.printable +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + model_eps: 0.00000001, + model_weight_decay: 0 +} +semantic_segm: { + pad_before: [4, 4, 4, 4], + in_channels: 4, + type: FarSeg, + model: { + backbone: resnet50, + } +} +train: semantic_segm.training.train +mean_std: mods.v5.get_mean_std +metrics: { + eval_every: 1 +} +visualize: { + model: semantic_segm.model.build_model, + checkpoint: chk_loader.load_state_dict_model_only, +} diff --git a/cfgs/seasonet_exp5_v2.yml b/cfgs/seasonet_exp5_v2.yml new file mode 100644 index 0000000..a1258f5 --- /dev/null +++ b/cfgs/seasonet_exp5_v2.yml @@ -0,0 +1,10 @@ +__base__: seasonet_exp5_v1.yml +semantic_segm: { + conv_up: { + in_ch: 4, + middle_ch: 
90, + out_ch: 64, + kernel_size: 1, + padding: 0, + } +} diff --git a/cfgs/seasonet_exp5_v3.yml b/cfgs/seasonet_exp5_v3.yml new file mode 100644 index 0000000..607d9b9 --- /dev/null +++ b/cfgs/seasonet_exp5_v3.yml @@ -0,0 +1,10 @@ +__base__: seasonet_exp5_v1.yml +semantic_segm: { + in_channels: 90, + pad_before: [0, 0, 0, 0], + pad_after: [4, 4, 4, 4], + upscaler: { + config: cfgs/sen2venus_exp4_2x_v5.yml, + chk: output/sen2venus_exp4_2x_v5/checkpoints/model-70.pt, + } +} diff --git a/cfgs/sen2venus_4x.yml b/cfgs/sen2venus_4x.yml new file mode 100644 index 0000000..710837c --- /dev/null +++ b/cfgs/sen2venus_4x.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_base.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b5b6b7b8a: { + mean: [1182.8475341796875, 2155.208251953125, 2507.487060546875, 2800.94140625], + std: [594.6590576171875, 643.8070068359375, 777.8865356445312, 829.2948608398438], + min: [-8687.0, -3340.0, -2245.0, -5048.0], + max: [20197.0, 16498.0, 16674.0, 21622.0] + } + tensor_20m_b5b6b7b8a: { + mean: [1180.6920166015625, 2149.5302734375, 2500.98779296875, 2794.01220703125], + std: [592.2827758789062, 639.0105590820312, 769.7623291015625, 819.83349609375], + min: [-446.0, -295.0, -340.0, -551.0], + max: [13375.0, 15898.0, 15551.0, 15079.0] + } + hr_name: tensor_05m_b5b6b7b8a + lr_name: tensor_20m_b5b6b7b8a + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 4, + } +} +visualize: { + input_shape: [4, 64, 64], +} +metrics: { + upscale_factor: 4, +} diff --git a/cfgs/sen2venus_base.yml b/cfgs/sen2venus_base.yml new file mode 100644 index 0000000..7c643bc --- /dev/null +++ b/cfgs/sen2venus_base.yml @@ -0,0 +1,38 @@ +batch_size: 8 +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + 
model_eps: 0.00000001, + model_weight_decay: 0 +} +super_res: { + version: 'v2', + model: { + depths: [6, 6, 6, 6], + embed_dim: 90, + img_range: 1., + img_size: 64, + in_chans: 4, + mlp_ratio: 2, + num_heads: [6, 6, 6, 6], + resi_connection: 1conv, + upsampler: pixelshuffledirect, + window_size: 16, + } +} +train: super_res.training.train +mean_std: mods.v3.get_mean_std +visualize: { + checkpoint: chk_loader.load_state_dict_model_only, + model: super_res.model.build_model, +} +metrics: { + only_test_y_channel: false, +} diff --git a/cfgs/sen2venus_exp1_v1.yml b/cfgs/sen2venus_exp1_v1.yml new file mode 100644 index 0000000..e1bbca3 --- /dev/null +++ b/cfgs/sen2venus_exp1_v1.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + with_ssim_criterion: false, + with_cc_criterion: true, + weights: { + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v2.yml b/cfgs/sen2venus_exp1_v2.yml new file mode 100644 index 0000000..3bc23e4 --- /dev/null +++ b/cfgs/sen2venus_exp1_v2.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + with_ssim_criterion: true, + with_cc_criterion: false, + weights: { + ssim: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v3.yml b/cfgs/sen2venus_exp1_v3.yml new file mode 100644 index 0000000..16407e2 --- /dev/null +++ b/cfgs/sen2venus_exp1_v3.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: false, + with_cc_criterion: false, + weights: { + pixel: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v4.yml b/cfgs/sen2venus_exp1_v4.yml new file mode 100644 index 0000000..659e8eb --- /dev/null +++ b/cfgs/sen2venus_exp1_v4.yml @@ -0,0 +1,10 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + with_ssim_criterion: true, + with_cc_criterion: true, + weights: { + ssim: 1.0, + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v5.yml b/cfgs/sen2venus_exp1_v5.yml new file mode 100644 index 0000000..2cdae3a 
--- /dev/null +++ b/cfgs/sen2venus_exp1_v5.yml @@ -0,0 +1,10 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: false, + with_cc_criterion: true, + weights: { + pixel: 1.0, + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v6.yml b/cfgs/sen2venus_exp1_v6.yml new file mode 100644 index 0000000..c5226f2 --- /dev/null +++ b/cfgs/sen2venus_exp1_v6.yml @@ -0,0 +1,11 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: true, + with_cc_criterion: true, + weights: { + cc: 1.0, + ssim: 1.0, + pixel: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp2_v1.yml b/cfgs/sen2venus_exp2_v1.yml new file mode 100644 index 0000000..b05a4cc --- /dev/null +++ b/cfgs/sen2venus_exp2_v1.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: false, + use_cpb_bias: false, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v2.yml b/cfgs/sen2venus_exp2_v2.yml new file mode 100644 index 0000000..6bbe5c5 --- /dev/null +++ b/cfgs/sen2venus_exp2_v2.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp1_v4.yml diff --git a/cfgs/sen2venus_exp2_v3.yml b/cfgs/sen2venus_exp2_v3.yml new file mode 100644 index 0000000..850787d --- /dev/null +++ b/cfgs/sen2venus_exp2_v3.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: false, + use_rpe_bias: false, + } +} diff --git a/cfgs/sen2venus_exp2_v4.yml b/cfgs/sen2venus_exp2_v4.yml new file mode 100644 index 0000000..4769276 --- /dev/null +++ b/cfgs/sen2venus_exp2_v4.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: true, + use_rpe_bias: false, + } +} diff --git a/cfgs/sen2venus_exp2_v5.yml b/cfgs/sen2venus_exp2_v5.yml new file mode 100644 index 0000000..67eea1e --- /dev/null +++ b/cfgs/sen2venus_exp2_v5.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: false, + 
use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v6.yml b/cfgs/sen2venus_exp2_v6.yml new file mode 100644 index 0000000..28a4bf1 --- /dev/null +++ b/cfgs/sen2venus_exp2_v6.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: false, + use_cpb_bias: true, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v7.yml b/cfgs/sen2venus_exp2_v7.yml new file mode 100644 index 0000000..645a502 --- /dev/null +++ b/cfgs/sen2venus_exp2_v7.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: true, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp3_v1.yml b/cfgs/sen2venus_exp3_v1.yml new file mode 100644 index 0000000..0d2d960 --- /dev/null +++ b/cfgs/sen2venus_exp3_v1.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp2_v5.yml diff --git a/cfgs/sen2venus_exp3_v1_wrong.yml b/cfgs/sen2venus_exp3_v1_wrong.yml new file mode 100644 index 0000000..319d71d --- /dev/null +++ b/cfgs/sen2venus_exp3_v1_wrong.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp2_v4.yml diff --git a/cfgs/sen2venus_exp3_v2.yml b/cfgs/sen2venus_exp3_v2.yml new file mode 100644 index 0000000..136c927 --- /dev/null +++ b/cfgs/sen2venus_exp3_v2.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp3_v3.yml +super_res: { + model: { + MoE_config: { + with_smart_merger: null + } + } +} diff --git a/cfgs/sen2venus_exp3_v3.yml b/cfgs/sen2venus_exp3_v3.yml new file mode 100644 index 0000000..04d6404 --- /dev/null +++ b/cfgs/sen2venus_exp3_v3.yml @@ -0,0 +1,17 @@ +__base__: sen2venus_exp2_v5.yml +losses: { + weights: { + moe: 0.2 + } +} +super_res: { + model: { + mlp_ratio: 1, + MoE_config: { + k: 2, + num_experts: 8, + with_noise: false, + with_smart_merger: v1, + } + } +} diff --git a/cfgs/sen2venus_exp4_2x_v2.yml b/cfgs/sen2venus_exp4_2x_v2.yml new file mode 100644 index 0000000..c2e1334 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v2.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v2.yml +dataset: + root_path: 
data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v3.yml b/cfgs/sen2venus_exp4_2x_v3.yml new file mode 100644 index 0000000..19d3ef9 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v3.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v3.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} 
+visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v4.yml b/cfgs/sen2venus_exp4_2x_v4.yml new file mode 100644 index 0000000..5f4ae23 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v4.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v4.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v5.yml b/cfgs/sen2venus_exp4_2x_v5.yml new file mode 100644 index 0000000..a60b242 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v5.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_v5.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + 
max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_4x_v2.yml b/cfgs/sen2venus_exp4_4x_v2.yml new file mode 100644 index 0000000..f98d8bc --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v2.yml @@ -0,0 +1,4 @@ +__base__: sen2venus_exp4_4x_v4.yml +super_res: { + version: 'v1', +} diff --git a/cfgs/sen2venus_exp4_4x_v3.yml b/cfgs/sen2venus_exp4_4x_v3.yml new file mode 100644 index 0000000..1a1ac91 --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v3.yml @@ -0,0 +1,7 @@ +__base__: sen2venus_exp4_4x_v2.yml +super_res: { + version: 'swinfir', + model: { + resi_connection: SFB, + } +} diff --git a/cfgs/sen2venus_exp4_4x_v4.yml b/cfgs/sen2venus_exp4_4x_v4.yml new file mode 100644 index 0000000..fda8c51 --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v4.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp1_v3.yml diff --git a/cfgs/sen2venus_exp4_4x_v5.yml b/cfgs/sen2venus_exp4_4x_v5.yml new file mode 100644 index 0000000..3369f8f --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v5.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp3_v3.yml diff --git a/cfgs/sen2venus_exp4_v5.yml b/cfgs/sen2venus_exp4_v5.yml new file mode 100644 index 0000000..3369f8f --- /dev/null +++ b/cfgs/sen2venus_exp4_v5.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp3_v3.yml diff --git a/cfgs/sen2venus_super_res_x2_all.yml b/cfgs/sen2venus_super_res_x2_all.yml new file mode 100644 index 0000000..21e7040 --- /dev/null +++ b/cfgs/sen2venus_super_res_x2_all.yml @@ -0,0 +1,60 @@ +batch_size: 16 +losses: { + with_pixel_criterion: true, + weights: { + pixel: 1.0, + } +} +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: 
[279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable + places: [] +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + model_eps: 0.00000001, + model_weight_decay: 0 +} +super_res: { + model: { + upscale: 2, + in_chans: 4, + img_size: 64, + window_size: 16, + img_range: 1., + depths: [6, 6, 6, 6], + embed_dim: 90, + num_heads: [6, 6, 6, 6], + mlp_ratio: 2, + upsampler: pixelshuffledirect, + resi_connection: 1conv + } +} +train: super_res.training.train +mean_std: mods.v3.get_mean_std +visualize: { + model: super_res.model.build_model, + checkpoint: chk_loader.load_state_dict_model_only, +} +metrics: { + only_test_y_channel: false, + upscale_factor: 2, +} diff --git a/conda-environment.yaml b/conda-environment.yaml new file mode 100644 index 0000000..a2dc43a --- /dev/null +++ b/conda-environment.yaml @@ -0,0 +1,172 @@ +name: l1bsr3 +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - boltons=23.0.0=py310h06a4308_0 + - brotlipy=0.7.0=py310h7f8727e_1002 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.19.1=h5eee18b_0 + - ca-certificates=2023.08.22=h06a4308_0 + - certifi=2023.7.22=py310h06a4308_0 + - cffi=1.15.1=py310h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - conda=23.9.0=py310h06a4308_0 + - conda-libmamba-solver=23.9.1=py310h06a4308_0 + - conda-package-handling=2.2.0=py310h06a4308_0 + - 
conda-package-streaming=0.9.0=py310h06a4308_0 + - cryptography=41.0.3=py310hdda0065_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.2.140=0 + - cuda-runtime=12.1.0=0 + - ffmpeg=4.3=hf484d3e_0 + - filelock=3.9.0=py310h06a4308_0 + - fmt=9.1.0=hdb19cb5_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py310heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - icu=73.1=h6a678d5_0 + - idna=3.4=py310h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46305 + - jinja2=3.1.2=py310h06a4308_0 + - jpeg=9e=h5eee18b_1 + - jsonpatch=1.32=pyhd3eb1b0_0 + - jsonpointer=2.1=pyhd3eb1b0_0 + - krb5=1.20.1=h143b758_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libarchive=3.6.2=h6ac8c49_2 + - libcublas= + - libcufft= + - libcufile= + - libcurand= + - libcurl=8.2.1=h251f7ec_0 + - libcusolver= + - libcusparse= + - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20221030=h5eee18b_0 + - libev=4.33=h7f8727e_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libmamba=1.5.1=haf1ee3a_0 + - libmambapy=1.5.1=py310h2dafd23_0 + - libnghttp2=1.52.0=h2d74bed_1 + - libnpp= + - libnvjitlink=12.1.105=0 + - libnvjpeg= + - libpng=1.6.39=h5eee18b_0 + - libsolv=0.7.24=he621ea3_0 + - libssh2=1.10.0=hdbd6064_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxml2=2.10.4=hf1b16e4_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_0 + - markupsafe=2.1.1=py310h7f8727e_0 + - mkl=2023.1.0=h213fc3f_46343 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.8=py310h5eee18b_0 + - 
mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.1=py310h06a4308_0 + - numpy=1.26.0=py310h5f9d8c6_0 + - numpy-base=1.26.0=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.11=h7f8727e_2 + - packaging=23.1=py310h06a4308_0 + - pcre2=10.42=hebb0a14_0 + - pillow=10.0.1=py310ha6cbd5a_0 + - pip=23.2.1=py310h06a4308_0 + - pluggy=1.0.0=py310h06a4308_1 + - pybind11-abi=4=hd3eb1b0_1 + - pycosat=0.6.6=py310h5eee18b_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.2.0=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0=py310h5eee18b_1 + - readline=8.2=h5eee18b_0 + - reproc=14.2.4=h295c915_1 + - reproc-cpp=14.2.4=h295c915_1 + - requests=2.31.0=py310h06a4308_0 + - ruamel.yaml=0.17.21=py310h5eee18b_0 + - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 + - setuptools=68.0.0=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - sympy=1.11.1=py310h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchtriton=2.1.0=py310 + - torchvision=0.16.0=py310_cu121 + - tqdm=4.65.0=py310h2f386ee_0 + - truststore=0.8.0=py310h06a4308_0 + - typing_extensions=4.7.1=py310h06a4308_0 + - tzdata=2023c=h04d1e81_0 + - urllib3=1.26.16=py310h06a4308_0 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - yaml-cpp=0.7.0=h295c915_1 + - zlib=1.2.13=h5eee18b_0 + - zstandard=0.19.0=py310h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.0.0 + - cachetools==5.3.1 + - cycler==0.12.1 + - easydict==1.10 + - einops==0.7.0 + - fonttools==4.43.1 + - fsspec==2023.9.2 + - google-auth==2.23.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.59.0 + - kiwisolver==1.4.5 + - markdown==3.4.4 + - matplotlib==3.5.1 + - oauthlib==3.2.2 + - opencv-python== + - 
protobuf==4.24.4 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scipy==1.11.3 + - six==1.16.0 + - tensorboard==2.14.1 + - tensorboard-data-server==0.7.1 + - tifffile==2023.9.26 + - timm==0.6.7 + - werkzeug==3.0.0 +prefix: /hpc/home/leonardo.rossi/.conda/envs/l1bsr3 diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..bbabfe5 --- /dev/null +++ b/environment.yml @@ -0,0 +1,258 @@ +name: swin2_mose_env +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.19.1=h5eee18b_0 + - ca-certificates=2023.08.22=h06a4308_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.2.140=0 + - cuda-runtime=12.1.0=0 + - ffmpeg=4.3=hf484d3e_0 + - fmt=9.1.0=hdb19cb5_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - icu=73.1=h6a678d5_0 + - intel-openmp=2023.1.0=hdb19cb5_46305 + - jpeg=9e=h5eee18b_1 + - jsonpatch=1.32=pyhd3eb1b0_0 + - jsonpointer=2.1=pyhd3eb1b0_0 + - krb5=1.20.1=h143b758_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libarchive=3.6.2=h6ac8c49_2 + - libcublas= + - libcufft= + - libcufile= + - libcurand= + - libcurl=8.2.1=h251f7ec_0 + - libcusolver= + - libcusparse= + - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20221030=h5eee18b_0 + - libev=4.33=h7f8727e_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - 
libmamba=1.5.1=haf1ee3a_0 + - libnghttp2=1.52.0=h2d74bed_1 + - libnpp= + - libnvjitlink=12.1.105=0 + - libnvjpeg= + - libpng=1.6.39=h5eee18b_0 + - libsolv=0.7.24=he621ea3_0 + - libssh2=1.10.0=hdbd6064_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxml2=2.10.4=hf1b16e4_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2023.1.0=h213fc3f_46343 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy-base=1.26.0=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.11=h7f8727e_2 + - pcre2=10.42=hebb0a14_0 + - pybind11-abi=4=hd3eb1b0_1 + - pycparser=2.21=pyhd3eb1b0_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - reproc=14.2.4=h295c915_1 + - reproc-cpp=14.2.4=h295c915_1 + - ruamel.yaml=0.17.21=py310h5eee18b_0 + - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 + - sqlite=3.41.2=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchtriton=2.1.0=py310 + - typing_extensions=4.7.1=py310h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - yaml-cpp=0.7.0=h295c915_1 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.0.0 + - aenum==3.1.15 + - affine==2.4.0 + - aiohttp==3.9.0 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.9.3 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.1.0 + - bitsandbytes==0.41.2.post2 + - boltons==23.0.0 + - brotlipy==0.7.0 + - cachetools==5.3.1 + - certifi==2023.7.22 + - cffi==1.15.1 + - click==8.1.7 + - click-plugins==1.1.1 + - cligj==0.7.2 + - cmake==3.25.0 + - conda==23.9.0 + - 
conda-libmamba-solver==23.9.1 + - conda-package-handling==2.2.0 + - conda-package-streaming==0.9.0 + - cryptography==41.0.3 + - cycler==0.12.1 + - decorator==5.1.1 + - dill==0.3.8 + - docstring-parser==0.15 + - easydict==1.10 + - efficientnet-pytorch==0.7.1 + - einops==0.7.0 + - exceptiongroup==1.1.3 + - executing==2.0.1 + - filelock==3.9.0 + - fiona==1.9.5 + - fonttools==4.43.1 + - frozenlist==1.4.0 + - fsspec==2023.9.2 + - gmpy2==2.1.2 + - google-auth==2.23.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.59.0 + - huggingface-hub==0.19.4 + - hydra-core==1.3.2 + - idna==3.4 + - importlib-resources==6.1.1 + - ipdb==0.13.13 + - ipython==8.17.2 + - jedi==0.19.1 + - jinja2==3.1.2 + - jsonargparse==4.27.0 + - jsonschema==4.20.0 + - jsonschema-specifications==2023.11.1 + - kiwisolver==1.4.5 + - kornia==0.7.0 + - libmambapy==1.5.1 + - lightly==1.4.21 + - lightly-utils==0.0.2 + - lightning==2.1.2 + - lightning-utilities==0.10.0 + - lit==15.0.7 + - markdown==3.4.4 + - markdown-it-py==3.0.0 + - markupsafe==2.1.1 + - matplotlib==3.5.1 + - matplotlib-inline==0.1.6 + - mdurl==0.1.2 + - mkl-fft==1.3.8 + - mkl-random==1.2.4 + - mkl-service==2.4.0 + - mpmath==1.3.0 + - msgpack==1.0.7 + - multidict==6.0.4 + - munch==4.0.0 + - networkx==3.1 + - ninja== + - numpy==1.26.0 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - opencv-python== + - packaging==23.1 + - pandas==2.1.3 + - parso==0.8.3 + - pexpect==4.8.0 + - pillow==10.0.1 + - pip==23.2.1 + - piq==0.8.0 + - pluggy==1.0.0 + - pretrainedmodels==0.7.4 + - prompt-toolkit==3.0.41 + - protobuf==4.24.4 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyarrow==14.0.1 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pycosat==0.6.6 + - pydantic==1.10.13 + - pygments==2.17.1 + - pyopenssl==23.2.0 + - pyparsing==3.1.1 + - pyproj==3.6.1 + - pysocks==1.7.1 + - python-dateutil==2.8.2 + - pytorch-lightning==2.1.2 + - pytz==2023.3.post1 + - pyyaml==6.0 + - rasterio==1.3.9 + - ray==2.8.0 + - referencing==0.31.0 + - requests==2.31.0 + - 
requests-oauthlib==1.3.1 + - rich==13.7.0 + - rpds-py==0.13.1 + - rsa==4.9 + - rtree==1.1.0 + - ruamel-yaml==0.17.21 + - ruamel-yaml-clib==0.2.6 + - safetensors==0.4.0 + - scipy==1.11.3 + - segmentation-models-pytorch==0.3.3 + - setuptools==68.0.0 + - shapely==2.0.2 + - six==1.16.0 + - snuggs==1.4.7 + - stack-data==0.6.3 + - sympy==1.11.1 + - tensorboard==2.14.1 + - tensorboard-data-server==0.7.1 + - tensorboardx== + - tifffile==2023.9.26 + - timm==0.9.2 + - tomli==2.0.1 + - torch==2.0.0+cu118 + - torchgeo==0.5.1 + - torchmetrics==1.2.0 + - torchvision==0.15.0+cu118 + - tqdm==4.65.0 + - traitlets==5.13.0 + - triton==2.0.0 + - truststore==0.8.0 + - typeshed-client==2.4.0 + - typing-extensions==4.7.1 + - tzdata==2023.3 + - urllib3==1.26.16 + - wcwidth==0.2.10 + - werkzeug==3.0.0 + - wheel==0.41.2 + - yarl==1.9.2 + - zstandard==0.19.0 +prefix: /home/hachreak/miniconda2/envs/swin2_mose_env + diff --git a/images/fig1.png b/images/fig1.png new file mode 100644 index 0000000..6b27b49 Binary files /dev/null and b/images/fig1.png differ diff --git a/images/fig2.png b/images/fig2.png new file mode 100644 index 0000000..4097c31 Binary files /dev/null and b/images/fig2.png differ diff --git a/images/fig3.png b/images/fig3.png new file mode 100644 index 0000000..cbd473c Binary files /dev/null and b/images/fig3.png differ diff --git a/output/.gitignore b/output/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/scripts/get_all_metrics.sh b/scripts/get_all_metrics.sh new file mode 100755 index 0000000..2725d47 --- /dev/null +++ b/scripts/get_all_metrics.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +VRS=$1 +EPOCH_INIT=${2:-0} +DATASET=${3:-sen2venus} + +HPC=${HPC:-hpc} + +EPS=`ls -1 output/${HPC}/${DATASET}_v${VRS}/eval| awk -F"-" '{print $2}' | awk -F. 
'{print $1}'|sort -n` +# echo $EPS + +echo "eval ${DATASET} v$VRS" +for EPOCH in $EPS; do + echo "epoch $EPOCH" + if [ $EPOCH -lt $EPOCH_INIT ]; then + echo "skip epoch $EPOCH" + continue + fi + RESULTS=`./scripts/get_metrics.sh $VRS $EPOCH $DATASET` + # echo $RESULTS + + R_EPOCHS="$R_EPOCHS\n$EPOCH" + + R_PSNR="$R_PSNR\n`echo $RESULTS | awk '{print $1}'`" + R_SSIM="$R_SSIM\n`echo $RESULTS | awk '{print $2}'`" + R_CC="$R_CC\n`echo $RESULTS | awk '{print $3}'`" + R_RMSE="$R_RMSE\n`echo $RESULTS | awk '{print $4}'`" + R_SAM="$R_SAM\n`echo $RESULTS | awk '{print $5}'`" + R_ERGAS="$R_ERGAS\n`echo $RESULTS | awk '{print $6}'`" +done + +echo "########" +echo "" +echo "epochs" +echo -e $R_EPOCHS +echo "" +echo "psnr" +echo -e $R_PSNR +echo "" +echo "ssim" +echo -e $R_SSIM +echo "" +echo "cc" +echo -e $R_CC +echo "" +echo "rmse" +echo -e $R_RMSE +echo "" +echo "sam" +echo -e $R_SAM +echo "" +echo "ergas" +echo -e $R_ERGAS diff --git a/scripts/get_metrics.sh b/scripts/get_metrics.sh new file mode 100755 index 0000000..1c758d7 --- /dev/null +++ b/scripts/get_metrics.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +VRS=$1 +EPOCH=$2 +DATASET=${3:-sen2venus} + +HPC=${HPC:-hpc} + +IFS=$'\n ' +###### + +CONFIG_FILE=cfgs/${DATASET}_v${VRS}.yml +OUT_DIR=output/${HPC}/${DATASET}_v${VRS}/ + +RESULT=`python src/main.py --phase test --config $CONFIG_FILE --output $OUT_DIR --epoch ${EPOCH} | grep PSNR -A1 |tail -n1` + +echo $RESULT | awk '{print $1}' +echo $RESULT | awk '{print $2}' +echo $RESULT | awk '{print $3}' +echo $RESULT | awk '{print $4}' +echo $RESULT | awk '{print $5}' +echo $RESULT | awk '{print $6}' diff --git a/scripts/sen2venus/rebuild.py b/scripts/sen2venus/rebuild.py new file mode 100644 index 0000000..95bc027 --- /dev/null +++ b/scripts/sen2venus/rebuild.py @@ -0,0 +1,78 @@ +import pandas as pd +import os +import argparse +import torch + +from functools import lru_cache + + +def parse_configs(): + parser = argparse.ArgumentParser(description='Sen2Ven rebuilder') + + 
parser.add_argument('--data', type=str, default='./data/sen2venus', + help='Directory where original data are stored.') + parser.add_argument('--output', type=str, default='./data/sen2venus/split', + help='Directory where save the output files' + ' and load generated csv files.') + + args = parser.parse_args() + return args + + +@lru_cache(maxsize=5) +def cached_torch_load(filename): + print('load pt file {}'.format(filename)) + return torch.load(filename) + + +def load_rows(row): + cols = [name for name in row.keys() if name.startswith('tensor_')] + return { + col: cached_torch_load(os.path.join( + args.data, row['place'], row[col] + ))[row['index']].clone() + for col in cols + } + + +def get_filename(index, row): + return '{:06d}_{}_{}.pt'.format(index, row['place'], row['date']) + + +def get_dataset_type(filename): + _, name = os.path.split(filename) + name, _ = os.path.splitext(name) + return name + + +def rebuild(input_filename, args): + dtype = get_dataset_type(input_filename) + out_dir = os.path.join(args.output, dtype) + os.makedirs(out_dir, exist_ok=True) + + print('load {}'.format(input_filename)) + df = pd.read_csv(input_filename) + df = df.sort_values(['place', 'date']) + for index in range(len(df)): + row = df.iloc[index] + tensors = load_rows(row) + # build filename + filename = get_filename(index, row) + # save tensor + fname = os.path.join(out_dir, filename) + print('save {}'.format(fname)) + torch.save(tensors, fname) + + +if __name__ == "__main__": + # parse input arguments + args = parse_configs() + for k, v in vars(args).items(): + print('{}: {}'.format(k, v)) + + filename = os.path.join(args.output, 'test.csv') + print('rebuild test dataset..') + rebuild(filename, args) + filename = os.path.join(args.output, 'train.csv') + print('rebuild train dataset..') + rebuild(filename, args) diff --git a/scripts/sen2venus/split.py b/scripts/sen2venus/split.py new file mode 100644 index 0000000..32042dc --- /dev/null +++ b/scripts/sen2venus/split.py @@ 
-0,0 +1,106 @@ +import os +import pandas as pd +import numpy as np +import argparse + +from functools import partial +from sklearn.model_selection import train_test_split + + +def get_list_files(dir_name): + for dirpath, dirs, files in os.walk(dir_name): + for filename in files: + if filename.endswith('index.csv'): + fname = os.path.join(dirpath, filename) + yield fname + + +def get_info_by_id(dataframes, id_): + return dataframes.iloc[np.where(dataframes['start'].values <= id_)[0][-1]] + + +def read_csv_file(): + read_csv = partial(pd.read_csv, sep=r'\s+') + + def f(fnames): + for fname in fnames: + df = read_csv(fname) + yield fname, df + return f + + +def split_dataset(dir_name, test_size=.2, seed=42): + fnames = list(get_list_files(dir_name)) + dirnames = [] + dataframes = [] + pairs_counter = 0 + for fname, df in read_csv_file()(fnames): + # df = read_csv(fname) + dirname = os.path.split(os.path.split(fname)[0])[1] + # collect csv rows + dataframes.append(df) + # count how many examples are in this csv file + count_examples = df['nb_patches'].sum() + # collect directory where the patch is + dirnames.extend([dirname] * count_examples) + # count total (x, y) pairs + pairs_counter += count_examples + dataframes = pd.concat(dataframes) + y = np.array(dirnames) + # index list to shuffle and split in training / test dataset + ids = np.array(list(range(pairs_counter))) + # nb_patches cumulative column (useful to get file from the index) + start_v = [0] + start_v[1:] = dataframes['nb_patches'].cumsum().iloc[:-1] + dataframes['start'] = start_v + # split them + X_train, X_test, y_train, y_test = train_test_split( + ids, y, test_size=test_size, random_state=seed, stratify=y) + return dataframes, X_train, X_test, y_train, y_test + + +def collect_frames_by_ids(dataframes, ids, y): + rows = [get_info_by_id(dataframes, id_).to_dict() for id_ in ids] + df = pd.DataFrame(rows) + df['index'] = ids - df['start'] + df['place'] = y + return df + + +def parse_configs(): + parser = 
argparse.ArgumentParser(description='Sen2Ven splitter') + + parser.add_argument('--seed', type=int, default=123, metavar='N', + help='random seed (default: 123)') + parser.add_argument('--test_size', type=float, default=.2, metavar='N', + help='test size (default: 0.2)') + parser.add_argument('--input', type=str, default='./data/sen2venus', + help='Directory where original data are stored.') + parser.add_argument('--output', type=str, default='./data/sen2venus/split', + help='Directory where to save the output files.') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # parse input arguments + args = parse_configs() + for k, v in vars(args).items(): + print('{}: {}'.format(k, v)) + + dataframes, X_train, X_test, y_train, y_test = split_dataset( + dir_name=args.input, test_size=args.test_size, seed=args.seed) + + df_train = collect_frames_by_ids(dataframes, X_train, y_train) + df_test = collect_frames_by_ids(dataframes, X_test, y_test) + + os.makedirs(args.output, exist_ok=True) + # save train csv file + train_file = os.path.join(args.output, 'train.csv') + print('save train file: {}'.format(train_file)) + df_train.to_csv(train_file) + # save test csv file + test_file = os.path.join(args.output, 'test.csv') + print('save test file: {}'.format(test_file)) + df_test.to_csv(test_file) diff --git a/src/chk_loader.py b/src/chk_loader.py new file mode 100644 index 0000000..6d0bfd0 --- /dev/null +++ b/src/chk_loader.py @@ -0,0 +1,60 @@ +import os +import torch +import numpy as np + + +def get_last_epoch(filenames): + epochs = [int(name.split('-')[1].split('.')[0]) for name in filenames] + return filenames[np.array(epochs).argsort()[-1]] + + +def load_checkpoint(cfg): + dir_chk = os.path.join(cfg.output, 'checkpoints') + if cfg.epoch != -1: + path = os.path.join(dir_chk, 'model-{:02d}.pt'.format(cfg.epoch)) + else: + try: + fnames = os.listdir(dir_chk) + path = get_last_epoch(fnames) + path = os.path.join(dir_chk, path) + except IndexError: + 
raise FileNotFoundError() + # load checkpoint + print('load file {}'.format(path)) + if not os.path.exists(path): + raise FileNotFoundError() + return torch.load(path, map_location=cfg.device) + + +def load_state_dict_model(model, optimizer, checkpoint): + print('load model state') + model.load_state_dict(checkpoint['model_state_dict']) + print('load optimizer state') + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + return checkpoint['epoch'] + 1, checkpoint['index'] + + +def load_state_dict_model_only(model, checkpoint): + print('load model state') + model.load_state_dict(checkpoint['model_state_dict']) + + return checkpoint['epoch'] + 1, checkpoint['index'] + + +def save_state_dict_model(model, optimizer, epoch, index, cfg): + # save checkpoint + n_epoch = epoch + 1 + if (n_epoch) % cfg.snapshot_interval == 0: + dir_chk = os.path.join(cfg.output, 'checkpoints') + os.makedirs(dir_chk, exist_ok=True) + path = os.path.join(dir_chk, 'model-{:02d}.pt'.format(n_epoch)) + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'index': index, + } + + torch.save(checkpoint, path) diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..29d10cb --- /dev/null +++ b/src/config.py @@ -0,0 +1,135 @@ +import yaml +import os + +from easydict import EasyDict +from functools import reduce + + +class DotDict(EasyDict): + def __getattr__(self, k): + try: + v = self[k] + except KeyError: + return super().__getattr__(k) + if isinstance(v, dict): + return DotDict(v) + return v + + def __setitem__(self, k, v): + if isinstance(k, str) and '.' in k: + k = k.split('.') + if isinstance(k, (list, tuple)): + last = reduce(lambda d, kk: d[kk], k[:-1], self) + last[k[-1]] = v + return + return super().__setitem__(k, v) + + def __getitem__(self, k): + if isinstance(k, str) and '.' 
in k: + k = k.split('.') + if isinstance(k, (list, tuple)): + return reduce(lambda d, kk: d[kk], k, self) + return super().__getitem__(k) + + # def get(self, k, default=None): + # if isinstance(k, str) and '.' in k: + # try: + # return self[k] + # except KeyError: + # return default + # return super().get(k, default) + + def update(self, mydict): + for k, v in mydict.items(): + self[k] = v + + +def parse_config(args): + print('load config file ', args.config) + cfg = EasyDict(load_config(args.config)) + + var_args = vars(args) + + # check batch_size to overwrite only if defined + if var_args['batch_size'] is None: + del var_args['batch_size'] + + cfg.update(var_args) + + # backup the config file + os.makedirs(cfg.output, exist_ok=True) + with open(os.path.join(cfg.output, 'cfg.yml'), 'w') as bkfile: + yaml.dump(cfg, bkfile, default_flow_style=False) + + # load from node if exists + load_node_dataset(cfg) + # load from nvme if exists + load_nvme_dataset(cfg) + + return cfg + + +def load_config(path, default_path=None): + ''' Loads config file. + + Args: + path (str): path to config file + default_path (str): path to a fallback config, used when the file defines no __base__ + ''' + # Load configuration from file itself + with open(path, 'r') as f: + cfg_special = yaml.safe_load(f) + + # Check if we should inherit from a config + inherit_from = cfg_special.get('__base__') + + # If yes, load this config first as default + # If no, use the default_path + if inherit_from is not None: + dirname = os.path.split(path)[0] + cfg = load_config(os.path.join(dirname, inherit_from), default_path) + elif default_path is not None: + with open(default_path, 'r') as f: + cfg = yaml.safe_load(f) + else: + cfg = dict() + + # Include main configuration + update_recursive(cfg, cfg_special) + + return cfg + + +def update_recursive(dict1, dict2): + ''' Update two config dictionaries recursively. 
+ + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary whose entries should be used + + ''' + for k, v in dict2.items(): + # Add item if not yet in dict1 + if k not in dict1: + dict1[k] = None + # Update + if isinstance(dict1[k], dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + + +def load_nvme_dataset(cfg): + root_path_nvme = cfg.dataset.root_path.rstrip('/') + '_nvme' + if os.path.exists(root_path_nvme): + # update path to read from nvme disk + cfg.dataset.root_path = root_path_nvme + print('load dataset from {}'.format(cfg.dataset.root_path)) + + +def load_node_dataset(cfg): + root_path_node = cfg.dataset.root_path.rstrip('/') + '_node' + if os.path.exists(root_path_node): + # update path to read from the node copy + cfg.dataset.root_path = root_path_node + print('load dataset from {}'.format(cfg.dataset.root_path)) diff --git a/src/datasets/oli2msi.py b/src/datasets/oli2msi.py new file mode 100644 index 0000000..0ff3b53 --- /dev/null +++ b/src/datasets/oli2msi.py @@ -0,0 +1,118 @@ +# +# Source code: https://github.com/wjwjww/OLI2MSI +# + +import os +import rasterio +import torch + +from torch.utils.data import Dataset, DataLoader + +from utils import load_fun + + +def load_file(filename): + with rasterio.open(filename) as src: + file_ = src.read() + return file_ + + +def _load_files_dir(fdir): + print('load files from {}'.format(fdir)) + file_list = [] + for dirpath, dirs, files in os.walk(fdir): + for filename in files: + if filename.endswith('.TIF'): + file_list.append(os.path.join(dirpath, filename)) + return sorted(file_list) + + +class OLI2MSI(Dataset): + def __init__(self, cfg, is_training=True): + print('dataset OLI2MSI') + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + self.fdir_lr = os.path.join(self.root_path, self.relname + '_lr') + self.fdir_hr = os.path.join(self.root_path, self.relname + '_hr') + self._load_files() + + def 
_load_files(self): + self.files_lr = _load_files_dir(self.fdir_lr) + self.files_hr = _load_files_dir(self.fdir_hr) + + def __len__(self): + return len(self.files_lr) + + def __getitem__(self, index): + file_lr = load_file(self.files_lr[index]) + file_hr = load_file(self.files_hr[index]) + return file_lr, file_hr + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: + train_dset = OLI2MSI(cfg, is_training=True) + val_dset = OLI2MSI(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + + shuffle = True + + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg) + ) + + if not only_test: + train_dset = OLI2MSI(cfg, is_training=True) + + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + val_dset = OLI2MSI(cfg, is_training=False) + + shuffle = False + + # TODO distribute also val_dset + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg) + ) + + return train_dloader, val_dloader, concat_dloader diff --git a/src/datasets/seasonet.py b/src/datasets/seasonet.py new file mode 100644 index 0000000..9a56163 --- /dev/null +++ b/src/datasets/seasonet.py @@ -0,0 +1,131 @@ +import 
torch + +from torch.utils.data import Dataset, DataLoader +from torchgeo.datasets import SeasoNet + +from utils import load_fun + + +class SeasoNetDataset(Dataset): + + classes = [ + 'Continuous urban fabric', + 'Discontinuous urban fabric', + 'Industrial or commercial units', + 'Road and rail networks and associated land', + 'Port areas', + 'Airports', + 'Mineral extraction sites', + 'Dump sites', + 'Construction sites', + 'Green urban areas', + 'Sport and leisure facilities', + 'Non-irrigated arable land', + 'Vineyards', + 'Fruit trees and berry plantations', + 'Pastures', + 'Broad-leaved forest', + 'Coniferous forest', + 'Mixed forest', + 'Natural grasslands', + 'Moors and heathland', + 'Transitional woodland/shrub', + 'Beaches, dunes, sands', + 'Bare rock', + 'Sparsely vegetated areas', + 'Inland marshes', + 'Peat bogs', + 'Salt marshes', + 'Intertidal flats', + 'Water courses', + 'Water bodies', + 'Coastal lagoons', + 'Estuaries', + 'Sea and ocean', + ] + + def __init__(self, cfg, is_training=True): + print('dataset SEASONET') + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + + self.dset = SeasoNet( + root=self.root_path, + split=self.relname, + **cfg.dataset.kwargs + ) + + cfg.classes = self.classes + self.kwargs = cfg.dataset.kwargs + + def __len__(self): + return len(self.dset) + + def __getitem__(self, index): + item = self.dset[index] + + if self.kwargs.get('bands') == ['20m']: + # get only B5, B6, B7, B8A because sen2venus + item['image'] = item['image'][:4] + + return item + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + dataset_cls = load_fun(cfg.dataset.get('cls', 'SeasoNetDataset')) + + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: 
+ train_dset = dataset_cls(cfg, is_training=True) + val_dset = dataset_cls(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + if not only_test: + train_dset = dataset_cls(cfg, is_training=True) + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + val_dset = dataset_cls(cfg, is_training=False) + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + return train_dloader, val_dloader, concat_dloader diff --git a/src/datasets/sen2venus.py b/src/datasets/sen2venus.py new file mode 100644 index 0000000..fb74206 --- /dev/null +++ b/src/datasets/sen2venus.py @@ -0,0 +1,112 @@ +import os +import torch + +from functools import lru_cache +from torch.utils.data import Dataset, DataLoader + +from utils import load_fun + + +@lru_cache(maxsize=10) +def cached_torch_load(filename): + return torch.load(filename) + + +class Sen2VenusDataset(Dataset): + def __init__(self, cfg, is_training=True): + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + self.fdir = os.path.join(self.root_path, self.relname) + self._load_files() + self._filter_files(cfg) + + def _filter_files(self, cfg): + places = cfg.dataset.get('places') + if places is not None and not places == []: + self.files = list(filter( + lambda name: name.lower().split('_')[1] in places, + self.files)) + + def _load_files(self): + print('load {} files 
from {}'.format(self.relname, self.fdir)) + self.files = [] + for dirpath, dirs, files in os.walk(self.fdir): + for filename in files: + if filename.endswith('.pt'): + self.files.append(os.path.join(dirpath, filename)) + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + return cached_torch_load(self.files[index]) + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + dataset_cls = globals()[cfg.dataset.get('cls', 'Sen2VenusDataset')] + hr_name = cfg.dataset.hr_name + lr_name = cfg.dataset.lr_name + + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: + train_dset = dataset_cls(cfg, is_training=True) + val_dset = dataset_cls(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + + shuffle = True + + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + if not only_test: + train_dset = dataset_cls(cfg, is_training=True) + + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + val_dset = dataset_cls(cfg, is_training=False) + + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + 
return train_dloader, val_dloader, concat_dloader diff --git a/src/debug.py b/src/debug.py new file mode 100644 index 0000000..2f59313 --- /dev/null +++ b/src/debug.py @@ -0,0 +1,60 @@ +import torch + +from time import time + +from utils import load_fun + + +def log_losses(losses, phase, writer, index): + # Write the data during training to the training log file + for k, v in losses.items(): + writer.add_scalar("{}/{}".format(phase, k), v.item(), index) + + +def _get_abs_weights_grads(model): + return torch.cat([ + p.grad.detach().view(-1) for p in model.parameters() + if p.requires_grad + ]).abs() + + +def _get_abs_weights(model): + return torch.cat([ + p.detach().view(-1) for p in model.parameters() + if p.requires_grad + ]).abs() + + +def _perf_measure(model, data, count, warm_count=0): + for _ in range(warm_count): + model(data) + start = time() + for _ in range(count): + model(data) + end = time() + return (end - start) / count + + +def measure_avg_time(cfg): + vis = cfg.get('visualize', {}) + model = load_fun(vis.get('model'))(cfg) + x = torch.rand([cfg.batch_size, ] + vis.input_shape).to(cfg.device) + return _perf_measure(model, x, cfg.repeat_times, cfg.warm_times) + + +def log_hr_stats(lr, sr, hr, writer, index, cfg): + if cfg.get('debug', {}).get('sr_hr', False): + def log_delta(delta, name): + writer.add_scalar('stats/mean_{}'.format(name), delta.mean(), + index) + writer.add_scalar('stats/max_{}'.format(name), delta.max(), index) + + q_0_99 = torch.quantile(delta, 0.99, interpolation='nearest') + q_0_999 = torch.quantile(delta, 0.999, interpolation='nearest') + writer.add_scalar('stats/q099_{}'.format(name), q_0_99, index) + writer.add_scalar('stats/q0999_{}'.format(name), q_0_999, index) + + sf = cfg.metrics.upscale_factor + upscale = torch.nn.Upsample(scale_factor=sf, mode='bicubic') + log_delta((sr - upscale(lr)).abs(), 'sr_lr') + log_delta((sr - hr).abs(), 'sr_hr') diff --git a/src/imgproc.py b/src/imgproc.py new file mode 100644 index 
0000000..3ce0672 --- /dev/null +++ b/src/imgproc.py @@ -0,0 +1,59 @@ +# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import math +import os +import random +from typing import Any + +import cv2 +import numpy as np +import torch +from numpy import ndarray +from torch import Tensor +import tifffile + + +def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool): + """Convert the Tensor(NCWH) data type supported by PyTorch to the np.ndarray(WHC) image data type + + Args: + tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1] + range_norm (bool): Scale [-1, 1] data to between [0, 1] + half (bool): Whether to convert torch.float32 similarly to torch.half type. 
+ + Returns: + image (np.ndarray): Data types supported by PIL or OpenCV + + Examples: + >>> example_image = cv2.imread("lr_image.bmp") + >>> example_tensor = image_to_tensor(example_image, range_norm=False, half=False) + + """ + if range_norm: + tensor = tensor.add(1.0).div(2.0) + if half: + tensor = tensor.half() + + # image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8") + # the data are no longer converted to uint8 but stay in float32 + # if the image is RGB/RGBA, apply the permutation; otherwise do not + # permuting a 12-channel image would yield an SR result of size 12x636 with 628 channels + if 1 < tensor.size()[1] <= 4: + # image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("float32") + image = tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy().astype("float32") + else: + # image = tensor.squeeze(0).mul(255).clamp(0, 255).cpu().numpy().astype("float32") + image = tensor.squeeze(0).clamp(0, 1).cpu().numpy().astype("float32") + + return image diff --git a/src/losses/__init__.py b/src/losses/__init__.py new file mode 100644 index 0000000..82e6363 --- /dev/null +++ b/src/losses/__init__.py @@ -0,0 +1,17 @@ +from torch import nn + +from . 
import metrics_loss + + +def build_losses(cfg): + losses = {} + if cfg.losses.get('with_ce_criterion', False): + losses['ce_criterion'] = nn.CrossEntropyLoss() + if cfg.losses.get('with_pixel_criterion', False): + losses['pixel_criterion'] = nn.MSELoss() + if cfg.losses.get('with_cc_criterion', False): + losses['cc_criterion'] = metrics_loss.cc_loss + if cfg.losses.get('with_ssim_criterion', False): + losses['ssim_criterion'] = metrics_loss.ssim_loss(cfg) + print(losses) + return losses diff --git a/src/losses/metrics_loss.py b/src/losses/metrics_loss.py new file mode 100644 index 0000000..e8465e5 --- /dev/null +++ b/src/losses/metrics_loss.py @@ -0,0 +1,46 @@ +import piq + +from metrics import _cc_single_torch +from utils import load_fun + + +def norm_0_to_1(fun): + def wrapper(cfg): + dset = cfg.dataset + use_minmax = cfg.dataset.get('stats').get('use_minmax', False) + denorm = load_fun(dset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(dset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + rfun = fun(cfg) + + def f(sr, hr): + hr, sr, _ = evaluable(*denorm(hr, sr, None)) + sr = sr.clamp(min=0, max=1) + hr = hr.clamp(min=0, max=1) + return rfun(sr, hr) + + return f + return wrapper + + +@norm_0_to_1 +def ssim_loss(cfg): + criterion = piq.SSIMLoss() + + def f(sr, hr): + return criterion(sr, hr) + + return f + + +def cc_loss(sr, hr): + cc_value = _cc_single_torch(sr, hr) + return 1 - ((cc_value + 1) * .5) diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..d07e81b --- /dev/null +++ b/src/main.py @@ -0,0 +1,127 @@ +import argparse +import torch + +from config import parse_config +from utils import load_fun, set_deterministic +from visualize import main as vis_main +from validation import main as val_main, print_metrics as val_print_metrics, \ + load_metrics +from debug import measure_avg_time 
+ + +def parse_configs(): + parser = argparse.ArgumentParser(description='SuperRes model') + # For training and testing + parser.add_argument('--config', + default="cfgs/sen2venus_v26_8.yml", + help='Configuration file.') + parser.add_argument('--phase', + default='train', + choices=['train', 'test', 'mean_std', 'vis', + 'plot_data', 'avg_time'], + help='Training or testing or play phase.') + parser.add_argument('--seed', + type=int, + default=123, + metavar='N', + help='random seed (default: 123)') + parser.add_argument('--batch_size', + type=int, + default=None, + metavar='B', + help='Batch size. If defined, overwrite cfg file.') + help_num_workers = 'The number of workers to load dataset. Default: 0' + parser.add_argument('--num_workers', + type=int, + default=0, + metavar='N', + help=help_num_workers) + parser.add_argument('--output', + default='./output/sen2venus_v26_8', + help='Directory where save the output.') + help_epoch = 'The epoch to restart from (training) or to eval (testing).' 
+ parser.add_argument('--epoch', + type=int, + default=-1, + help=help_epoch) + parser.add_argument('--epochs', + type=int, + default=200, + metavar='N', + help='number of epochs (default: 200)') + help_snapshot = 'The epoch interval of model snapshot (default: 1)' + parser.add_argument('--snapshot_interval', + type=int, + default=1, + metavar='N', + help=help_snapshot) + parser.add_argument("--num_images", + type=int, + default=10, + help="Number of images to plot") + parser.add_argument('--eval_method', + default=None, type=str, + help='Non-DL method to use on evaluation.') + parser.add_argument('--repeat_times', + type=int, + default=1000, + help='Measure times repeating model call') + help_warm = 'Warm model calling it before starting the measure' + parser.add_argument('--warm_times', + type=int, + default=10, + help=help_warm) + parser.add_argument('--dpi', + type=int, + default=2400, + help="dpi in png output file.") + + args = parser.parse_args() + return parse_config(args) + + +def main(cfg): + cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + load_dataset_fun = load_fun(cfg.dataset.get( + 'load_dataset', 'datasets.sen2venus.load_dataset')) + + if cfg.phase == 'avg_time': + print(measure_avg_time(cfg)) + elif cfg.phase == 'train': + train_fun = load_fun(cfg.get('train', 'srgan.training.train')) + train_dloader, val_dloader, _ = load_dataset_fun(cfg) + train_fun = load_fun(cfg.get('train', 'srgan.training.train')) + train_fun(train_dloader, val_dloader, cfg) + elif cfg.phase == 'mean_std': + if 'stats' in cfg.dataset.keys(): + cfg.dataset.pop('stats') + _, _, concat_dloader = load_dataset_fun(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'plot_data': + _, _, concat_dloader = load_dataset_fun(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'vis': + cfg['batch_size'] = 1 + vis_main(cfg) + elif cfg.phase == 
'test': + try: + if cfg.eval_method is not None: + raise FileNotFoundError() + metrics = load_metrics(cfg) + except FileNotFoundError: + _, val_dloader, _ = load_dataset_fun(cfg, only_test=True) + metrics = val_main( + val_dloader, cfg, save_metrics=cfg.eval_method is None) + val_print_metrics(metrics) + + +if __name__ == "__main__": + # parse input arguments + cfg = parse_configs() + # fix random seed + set_deterministic(cfg.seed) + # run main + main(cfg) diff --git a/src/main_ssegm.py b/src/main_ssegm.py new file mode 100644 index 0000000..5070ce2 --- /dev/null +++ b/src/main_ssegm.py @@ -0,0 +1,101 @@ +import argparse +import torch + +from config import parse_config +from datasets.seasonet import load_dataset +from utils import load_fun, set_deterministic +from validation import print_metrics as val_print_metrics +from semantic_segm.validation import load_metrics, main as val_main +from semantic_segm.visualize import main as vis_main + + +def parse_configs(): + parser = argparse.ArgumentParser(description='Semantic Segmentation model') + # For training and testing + parser.add_argument('--config', + default="cfgs/seasonet_v1.yml", + help='Configuration file.') + parser.add_argument('--phase', + default='test', + choices=['train', 'test', 'mean_std', 'vis'], + help='Training or testing or play phase.') + parser.add_argument('--seed', + type=int, + default=123, + metavar='N', + help='random seed (default: 123)') + parser.add_argument('--batch_size', + type=int, + default=None, + metavar='B', + help='Batch size. If defined, overwrite cfg file.') + help_num_workers = 'The number of workers to load dataset. Default: 0' + parser.add_argument('--num_workers', + type=int, + default=0, + metavar='N', + help=help_num_workers) + parser.add_argument('--output', + default='./output/demo', + help='Directory where save the output.') + help_epoch = 'The epoch to restart from (training) or to eval (testing).' 
+ parser.add_argument('--epoch', + type=int, + default=-1, + help=help_epoch) + parser.add_argument('--epochs', + type=int, + default=200, + metavar='N', + help='number of epochs (default: 200)') + help_snapshot = 'The epoch interval of model snapshot (default: 1)' + parser.add_argument('--snapshot_interval', + type=int, + default=1, + metavar='N', + help=help_snapshot) + parser.add_argument("--num_images", + type=int, + default=10, + help="Number of images to plot") + parser.add_argument('--hide_sr', + action='store_true', + default=False) + parser.add_argument('--dpi', + type=int, + default=2400, + help="dpi in png output file.") + + args = parser.parse_args() + return parse_config(args) + + +if __name__ == "__main__": + # parse input arguments + cfg = parse_configs() + # set random seed + set_deterministic(cfg.seed) + + cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if cfg.phase == 'train': + train_dloader, val_dloader, _ = load_dataset(cfg) + train_fun = load_fun(cfg.train) + train_fun(train_dloader, val_dloader, cfg) + elif cfg.phase == 'mean_std': + if 'stats' in cfg.dataset.keys(): + cfg.dataset.pop('stats') + _, _, concat_dloader = load_dataset(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'vis': + vis_main(cfg) + elif cfg.phase == 'test': + try: + metrics = load_metrics(cfg) + except FileNotFoundError: + # validate if not already done + _, val_dloader, _ = load_dataset(cfg) + metrics = val_main(val_dloader, cfg) + # print metrics + val_print_metrics(metrics) diff --git a/src/metrics.py b/src/metrics.py new file mode 100644 index 0000000..c88f74c --- /dev/null +++ b/src/metrics.py @@ -0,0 +1,317 @@ +# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections.abc +import math +import typing +import warnings +from itertools import repeat +from typing import Any + +import cv2 +import numpy as np +import torch +import piq +from numpy import ndarray +from scipy.io import loadmat +from scipy.ndimage.filters import convolve +from scipy.special import gamma +from torch import nn +from torch.nn import functional as F + + +class piq_metric(object): + def __init__(self, cfg): + pass + + def to(self, device): + return self + + def __call__(self, x, y): + x = x.clone() + x[x < 0] = 0. + x[x > 1] = 1. + return self._metric(x, y) + + def _metric(self, x, y): + pass + + +class piq_psnr(piq_metric): + def _metric(self, x, y): + return piq.psnr(x, y) + + +class piq_ssim(piq_metric): + def _metric(self, x, y): + return piq.ssim(x, y) + + +class piq_rmse(piq_metric): + def __init__(self, cfg): + self.mse = nn.MSELoss() + + def _metric(self, x, y): + return torch.sqrt(self.mse(x, y)) + + +def _ergas_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + """ + Compute the ERGAS (Erreur Relative Globale Adimensionnelle De Synthèse) metric for a pair of input tensors. + + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared (typically the reconstructed image). 
+ dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + ERGAS (torch.Tensor): The ERGAS metric score. + + """ + # Compute the number of spectral bands + N_spectral = raw_tensor.shape[1] + + # Reshape images for processing + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + N_pixels = raw_tensor_reshaped.shape[1] + + # Assuming HR size is 256x256 and LR size is 128x128 + hr_size = torch.tensor(256, device=raw_tensor.device) + lr_size = torch.tensor(128, device=raw_tensor.device) + + # Calculate the beta value (kept on the same device as the input) + beta = hr_size / lr_size + + # Calculate RMSE of each band + rmse = torch.sqrt(torch.nansum((dst_tensor_reshaped - raw_tensor_reshaped) ** 2, dim=1) / N_pixels) + mu_ref = torch.mean(dst_tensor_reshaped, dim=1) + + # Calculate ERGAS + ERGAS = 100 * (1 / beta ** 2) * torch.sqrt(torch.nansum(torch.div(rmse, mu_ref) ** 2) / N_spectral) + + return ERGAS + + +class ERGAS(nn.Module): + """ + PyTorch implementation of ERGAS (Erreur Relative Globale Adimensionnelle De Synthèse) metric. + + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute ERGAS metric between two input tensors representing images. + + Example: + ergas_calculator = ERGAS() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + ergas_score = ergas_calculator(raw_image, dst_image) + print(f"ERGAS Score: {ergas_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute ERGAS metric between two input tensors representing images. 
+ + Args: + raw_tensor (torch.Tensor): The image tensor to be compared (typically the reconstructed image). + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + ergas_metrics (torch.Tensor): The ERGAS metric score. + + Note: + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + """ + ergas_metrics = _ergas_single_torch(raw_tensor, dst_tensor) + return ergas_metrics + + +def _cc_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + + """ + Compute the Cross-Correlation (CC) metric between two input tensors representing images. + + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + CC (torch.Tensor): The Cross-Correlation (CC) metric score. + + """ + N_spectral = raw_tensor.shape[1] + + # Reshaping fused and reference data + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + + # Calculating mean value + mean_raw = torch.mean(raw_tensor_reshaped, 1).unsqueeze(1) + mean_dst = torch.mean(dst_tensor_reshaped, 1).unsqueeze(1) + + CC = torch.sum((raw_tensor_reshaped - mean_raw) * (dst_tensor_reshaped - mean_dst), 1) / torch.sqrt( + torch.sum((raw_tensor_reshaped - mean_raw) ** 2, 1) * torch.sum((dst_tensor_reshaped - mean_dst) ** 2, 1)) + + CC = torch.mean(CC) + + return CC + + +class CC(nn.Module): + """ + PyTorch implementation of the Cross-Correlation (CC) metric for image similarity. + + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. 
+ + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute the Cross-Correlation (CC) metric between two input tensors representing images. + + Example: + cc_calculator = CC() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + cc_score = cc_calculator(raw_image, dst_image) + print(f"Cross-Correlation Score: {cc_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute the Cross-Correlation (CC) metric between two input tensors representing images. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + cc_metrics (torch.Tensor): The Cross-Correlation (CC) metric score. + + Note: + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. + + """ + cc_metrics = _cc_single_torch(raw_tensor, dst_tensor) + return cc_metrics + + +def _sam_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + """ + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + SAM (torch.Tensor): The Spectral Angle Mapper (SAM) metric score. 
+ + """ + # Compute the number of spectral bands + N_spectral = raw_tensor.shape[1] + + # Reshape fused and reference data + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + N_pixels = raw_tensor_reshaped.shape[1] + + # Calculate the inner product + inner_prod = torch.nansum(raw_tensor_reshaped * dst_tensor_reshaped, 0) + raw_norm = torch.nansum(raw_tensor_reshaped ** 2, dim=0).sqrt() + dst_norm = torch.nansum(dst_tensor_reshaped ** 2, dim=0).sqrt() + + # Calculate SAM + SAM = torch.rad2deg(torch.nansum(torch.acos(inner_prod / (raw_norm * dst_norm))) / N_pixels) + + return SAM + + +class SAM(nn.Module): + """ + PyTorch implementation of the Spectral Angle Mapper (SAM) metric for spectral similarity. + + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. + + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + Example: + sam_calculator = SAM() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + sam_score = sam_calculator(raw_image, dst_image) + print(f"Spectral Angle Mapper Score: {sam_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + sam_metrics (torch.Tensor): The Spectral Angle Mapper (SAM) metric score. + + Note: + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. 
+ + """ + sam_metrics = _sam_single_torch(raw_tensor, dst_tensor) + return sam_metrics diff --git a/src/mods/v3.py b/src/mods/v3.py new file mode 100644 index 0000000..d821c15 --- /dev/null +++ b/src/mods/v3.py @@ -0,0 +1,162 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + + +def collate_fn(cfg, hr_name, lr_name): + print('collate_fn hr field: ', hr_name) + print('collate_fn lr field: ', lr_name) + + def do_nothing(x): + return x + norm_lr = do_nothing + norm_hr = do_nothing + + if 'stats' in cfg.dataset and cfg.phase != 'plot_data': + hr_mean = cfg.dataset.stats.get(hr_name).mean + hr_std = cfg.dataset.stats.get(hr_name).std + lr_mean = cfg.dataset.stats.get(lr_name).mean + lr_std = cfg.dataset.stats.get(lr_name).std + norm_hr = transforms.Normalize(mean=hr_mean, std=hr_std) + norm_lr = transforms.Normalize(mean=lr_mean, std=lr_std) + + def f(batch): + batch = default_collate(batch) + + lr = norm_lr(batch[lr_name].float()) + hr = norm_hr(batch[hr_name].float()) + + return {'lr': lr, 'hr': hr} + return f + + +def uncollate_fn(cfg, hr_name, lr_name): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + def f(hr=None, sr=None, lr=None): + stats_hr = cfg.dataset.stats.get(hr_name) + stats_lr = cfg.dataset.stats.get(lr_name) + + if hr is not None: + hr = denorm(hr, stats_hr.mean, stats_hr.std) + if sr is not None: + sr = denorm(sr, stats_hr.mean, stats_hr.std) + if lr is not None: + lr = denorm(lr, stats_lr.mean, stats_lr.std) + + return hr, sr, lr + + return f + + +def printable(cfg, hr_name, lr_name, filter_outliers=False, use_minmax=False): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = 
t1.view(t1.shape + (1, 1)) + return t1 + + def _printable(tensor, stats): + if use_minmax: + max_ = to_shape(torch.tensor( + stats.max), tensor).to(cfg.device) + min_ = to_shape(torch.tensor( + stats.min), tensor).to(cfg.device) + else: + # get stats + mean = torch.tensor(stats.mean).to(cfg.device) + std = torch.tensor(stats.std).to(cfg.device) + # compute min and max + max_ = to_shape(mean + std, tensor) + min_ = to_shape(mean - std, tensor) + + # filter outliers if needed to visualize them + if filter_outliers: + tensor = torch.min(torch.max(tensor, min_), max_) + # printable + return (tensor - min_) / (max_ - min_) + + def f(hr=None, sr=None, lr=None): + stats_hr = cfg.dataset.stats.get(hr_name) + stats_lr = cfg.dataset.stats.get(lr_name) + + if hr is not None: + hr = _printable(hr, stats_hr) + if sr is not None: + sr = _printable(sr, stats_hr) + if lr is not None: + lr = _printable(lr, stats_lr) + + return hr, sr, lr + + return f + + +class MeanStd(object): + def __init__(self, device): + self.channels_sum = 0 + self.channels_sqrd_sum = 0 + self.num_batches = 0 + self.max = torch.full((4,), -torch.inf, device=device) + self.min = torch.full((4,), torch.inf, device=device) + + def __call__(self, data): + # max + d_perm = data.permute(1, 0, 2, 3) + d_max = d_perm.reshape(data.shape[1], -1).max(dim=1)[0] + self.max = torch.where(d_max > self.max, d_max, self.max) + # min + d_min = d_perm.reshape(data.shape[1], -1).min(dim=1)[0] + self.min = torch.where(d_min <= self.min, d_min, self.min) + # mean, std + self.channels_sum += torch.mean(data, dim=[0, 2, 3]) + self.channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3]) + self.num_batches += 1 + + def get_mean_std(self): + mean = self.channels_sum / self.num_batches + std = (self.channels_sqrd_sum / self.num_batches - mean**2) ** 0.5 + return mean.tolist(), std.tolist() + + def get_min_max(self): + return self.min.tolist(), self.max.tolist() + + +def get_mean_std(train_dloader, cfg): + hr_ms = 
MeanStd(device=cfg.device) + lr_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True).float() + lr = batch["lr"].to(device=cfg.device, non_blocking=True).float() + + hr_ms(hr) + lr_ms(lr) + + hr_mean, hr_std = hr_ms.get_mean_std() + lr_mean, lr_std = lr_ms.get_mean_std() + hr_min, hr_max = hr_ms.get_min_max() + lr_min, lr_max = lr_ms.get_min_max() + + print('HR (mean, std, min, max)') + print('mean: {},'.format(hr_mean)) + print('std: {},'.format(hr_std)) + print('min: {},'.format(hr_min)) + print('max: {}'.format(hr_max)) + print('LR (mean, std, min, max)') + print('mean: {},'.format(lr_mean)) + print('std: {},'.format(lr_std)) + print('min: {},'.format(lr_min)) + print('max: {}'.format(lr_max)) diff --git a/src/mods/v5.py b/src/mods/v5.py new file mode 100644 index 0000000..fd797bf --- /dev/null +++ b/src/mods/v5.py @@ -0,0 +1,94 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + +from .v3 import MeanStd + + +def collate_fn(cfg): + print('collate_fn') + + def do_nothing(x): + return x + norm_fun = do_nothing + + if 'stats' in cfg.dataset and cfg.phase != 'plot_data': + mean = cfg.dataset.stats.mean + std = cfg.dataset.stats.std + norm_fun = transforms.Normalize(mean=mean, std=std) + + def f(batch): + batch = default_collate(batch) + batch['image'] = norm_fun(batch['image'].float()) + return batch + return f + + +def uncollate_fn(cfg): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + def f(sr, lr): + stats = cfg.dataset.stats + if sr is not None: + sr = denorm(sr.float(), stats.mean, stats.std) 
+ lr = denorm(lr.float(), stats.mean, stats.std) + return sr, lr + + return f + + +def printable(cfg): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view(t1.shape + (1, 1)) + return t1 + + def _printable(tensor, stats): + max_ = to_shape(torch.tensor( + stats.max), tensor).to(cfg.device) + min_ = to_shape(torch.tensor( + stats.min), tensor).to(cfg.device) + + # printable + return (tensor - min_) / (max_ - min_) + + def f(sr, lr): + stats = cfg.dataset.stats + # batch = default_collate(batch) + if sr is not None: + sr = _printable(sr, stats) + lr = _printable(lr, stats) + return sr, lr + + return f + + +def get_mean_std(train_dloader, cfg): + image_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + image = batch["image"].to(device=cfg.device, non_blocking=True).float() + + image_ms(image) + + image_mean, image_std = image_ms.get_mean_std() + image_min, image_max = image_ms.get_min_max() + + print('Image (mean, std, min, max)') + print('mean: {},'.format(image_mean)) + print('std: {},'.format(image_std)) + print('min: {},'.format(image_min)) + print('max: {}'.format(image_max)) diff --git a/src/mods/v6.py b/src/mods/v6.py new file mode 100644 index 0000000..b7c29f4 --- /dev/null +++ b/src/mods/v6.py @@ -0,0 +1,102 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + + +def collate_fn(cfg): + + def f(batch): + batch = default_collate(batch) + + lr = batch[0].float() + hr = batch[1].float() + + return {'lr': lr, 'hr': hr} + return f + + +def uncollate_fn(cfg, hr_name, lr_name): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + 
def f(hr=None, sr=None, lr=None): + return hr, sr, lr + + return f + + +def printable(cfg, hr_name, lr_name, filter_outliers=False, use_minmax=False): + def f(hr=None, sr=None, lr=None): + return hr, sr, lr + + return f + + +class MeanStd(object): + def __init__(self, device): + self.channels_sum = 0 + self.channels_sqrd_sum = 0 + self.num_batches = 0 + self.max = torch.full((4,), -torch.inf, device=device) + self.min = torch.full((4,), torch.inf, device=device) + + def __call__(self, data): + # max + d_perm = data.permute(1, 0, 2, 3) + d_max = d_perm.reshape(data.shape[1], -1).max(dim=1)[0] + self.max = torch.where(d_max > self.max, d_max, self.max) + # min + d_min = d_perm.reshape(data.shape[1], -1).min(dim=1)[0] + self.min = torch.where(d_min <= self.min, d_min, self.min) + # mean, std + self.channels_sum += torch.mean(data, dim=[0, 2, 3]) + self.channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3]) + self.num_batches += 1 + + def get_mean_std(self): + mean = self.channels_sum / self.num_batches + std = (self.channels_sqrd_sum / self.num_batches - mean**2) ** 0.5 + return mean.tolist(), std.tolist() + + def get_min_max(self): + return self.min.tolist(), self.max.tolist() + + +def get_mean_std(train_dloader, cfg): + hr_ms = MeanStd(device=cfg.device) + lr_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True).float() + lr = batch["lr"].to(device=cfg.device, non_blocking=True).float() + + hr_ms(hr) + lr_ms(lr) + + hr_mean, hr_std = hr_ms.get_mean_std() + lr_mean, lr_std = lr_ms.get_mean_std() + hr_min, hr_max = hr_ms.get_min_max() + lr_min, lr_max = lr_ms.get_min_max() + + print('HR (mean, std, min, max)') + print('mean: {},'.format(hr_mean)) + print('std: {},'.format(hr_std)) + print('min: {},'.format(hr_min)) + print('max: {}'.format(hr_max)) + print('LR (mean, std, min, max)') + print('mean: {},'.format(lr_mean)) + print('std: 
{},'.format(lr_std)) + print('min: {},'.format(lr_min)) + print('max: {}'.format(lr_max)) diff --git a/src/optim.py b/src/optim.py new file mode 100644 index 0000000..aa70419 --- /dev/null +++ b/src/optim.py @@ -0,0 +1,13 @@ +from torch import optim + + +def build_optimizer(model, cfg): + o_cfg = cfg + + optimizer = optim.Adam(model.parameters(), + o_cfg.optim.learning_rate, + o_cfg.optim.model_betas, + o_cfg.optim.model_eps, + o_cfg.optim.model_weight_decay) + + return optimizer diff --git a/src/semantic_segm/model.py b/src/semantic_segm/model.py new file mode 100644 index 0000000..a04b268 --- /dev/null +++ b/src/semantic_segm/model.py @@ -0,0 +1,100 @@ +import torch + +from torchgeo.models import FarSeg +from easydict import EasyDict +from torch import nn +from torch.nn import functional as F + +from chk_loader import load_state_dict_model_only +from config import load_config +from super_res.model import build_model as super_res_build_model +from utils import set_required_grad + + +class UpScaler(nn.Module): + def __init__(self, cfg): + super().__init__() + u_cfg = cfg.semantic_segm.upscaler + + # load dl upscaler + if u_cfg.get('chk') is not None: + print('load super_res model config file ', u_cfg.config) + model_cfg = EasyDict(load_config(u_cfg.config)) + model_cfg.device = cfg.device + # load super_res + print("loading super_res model") + self.super_res = super_res_build_model(model_cfg) + # load checkpoint + print("loading super_res chk file {}".format(u_cfg.chk)) + swin_checkpoint = torch.load(u_cfg.chk) + load_state_dict_model_only(self.super_res, swin_checkpoint) + # put eval mode + print('set super_res eval mode') + set_required_grad(self.super_res, False) + + def forward(self, x): + with torch.no_grad(): + x = self.super_res.forward_backbone(x) + return x + + +class SemanticSegm(nn.Module): + def __init__(self, cfg): + super().__init__() + ssegm = cfg.semantic_segm + + # padding before and after + self.pad_before = ssegm.pad_before + self.pad_after = 
self.pad_before + if 'pad_after' in ssegm: + self.pad_after = ssegm.pad_after + + # segmentation model + if ssegm.type == 'FarSeg': + self.model = FarSeg(classes=len(cfg.classes), **ssegm.model) + out_ch = self.model.backbone.conv1.out_channels + self.model.backbone.conv1 = nn.Conv2d( + ssegm.in_channels, out_ch, kernel_size=7, stride=2, padding=3 + ) + + # super res model + if 'upscaler' in ssegm: + self.upscaler = UpScaler(cfg) + # or conv [4, H, W] -> [90, H, W] for model++ + if 'conv_up' in ssegm: + print('model++ with conv_up') + print(ssegm.conv_up) + kernel_size = ssegm.conv_up.get('kernel_size', 1) + padding = ssegm.conv_up.get('padding', 0) + self.conv_up = nn.Conv2d( + ssegm.conv_up.in_ch, ssegm.conv_up.middle_ch, + kernel_size=kernel_size, padding=padding) + self.model.backbone.conv1 = nn.Conv2d( + ssegm.conv_up.middle_ch, ssegm.conv_up.out_ch, + kernel_size=7, stride=2, padding=3) + + def forward(self, x): + # upscale input if super res model is defined + if hasattr(self, 'upscaler'): + x = self.upscaler(x) + if not torch.is_tensor(x): + # remove loss_moe value + x, _ = x + + pad_x = F.pad(x, self.pad_before, "constant", 0) + + # or use conv to increase channels (for baseline++) + if hasattr(self, 'conv_up'): + pad_x = self.conv_up(pad_x) + + # semantic segmentation model + pad_x = self.model(pad_x) + x = pad_x[..., + self.pad_after[0]:-self.pad_after[1], + self.pad_after[2]:-self.pad_after[3]] + + return x + + +def build_model(cfg): + return SemanticSegm(cfg).to(cfg.device) diff --git a/src/semantic_segm/training.py b/src/semantic_segm/training.py new file mode 100644 index 0000000..83ed929 --- /dev/null +++ b/src/semantic_segm/training.py @@ -0,0 +1,77 @@ +import datetime + +import debug + +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter + +from .validation import validate, save_metrics +from chk_loader import load_checkpoint, load_state_dict_model, \ + save_state_dict_model +from optim import build_optimizer +from .model import 
build_model +from losses import build_losses + + +def train(train_dloader, val_dloader, cfg): + # Tensorboard + writer = SummaryWriter(cfg.output + '/tensorboard/train_{}'.format( + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))) + # eval every x + eval_every = cfg.metrics.get('eval_every', 1) + + model = build_model(cfg) + losses = build_losses(cfg) + optimizer = build_optimizer(model, cfg) + + begin_epoch = 0 + index = 0 + try: + checkpoint = load_checkpoint(cfg) + begin_epoch, index = load_state_dict_model( + model, optimizer, checkpoint) + except FileNotFoundError: + print('no checkpoint found') + + for e in range(begin_epoch, cfg.epochs): + index = train_epoch( + model, train_dloader, losses, optimizer, e, writer, index, cfg) + + if (e+1) % eval_every == 0: + result = validate( + model, val_dloader, e, writer, 'test', cfg) + # save result of eval + cfg.epoch = e+1 + save_metrics(result, cfg) + + save_state_dict_model(model, optimizer, e, index, cfg) + + +def train_epoch(model, train_dloader, losses, optimizer, epoch, writer, + index, cfg): + weights = cfg.losses.weights + for index, batch in tqdm( + enumerate(train_dloader, index), total=len(train_dloader), + desc='Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = model(batch['image']) + + loss_tracker = {} + + if 'ce_criterion' in losses: + loss_tracker['ce_loss'] = losses['ce_criterion']( + out, batch['mask']) * weights.ce + + # train + loss_tracker['train_loss'] = sum(loss_tracker.values()) + optimizer.zero_grad() + loss_tracker['train_loss'].backward() + optimizer.step() + + # log_stats(writer, index, cfg) + debug.log_losses(loss_tracker, 'train', writer, index) + + return index diff --git a/src/semantic_segm/validation.py b/src/semantic_segm/validation.py new file mode 100644 index 0000000..a2e1191 --- /dev/null +++ b/src/semantic_segm/validation.py @@ -0,0 +1,99 @@ +import torch + +from tqdm 
import tqdm +from collections import OrderedDict + +from utils import F1AverageMeter, load_fun +from chk_loader import load_checkpoint +from validation import get_result_filename + + +def validate(g_model, val_dloader, epoch, writer, mode, cfg): + g_model.eval() + avg_metrics = build_avg_metrics(cfg) + + with torch.no_grad(): + for j, batch in tqdm( + enumerate(val_dloader), total=len(val_dloader), + desc='Val Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = g_model(batch['image']) + preds = out.max(-1)[1] + + for k, fun in avg_metrics.items(): + avg_metrics[k].update((preds, batch['mask'])) + + if writer is not None: + for k, v in avg_metrics.items(): + try: + writer.add_scalar( + "{}/{}".format(mode, k), v.avg.item(), epoch+1) + except RuntimeError: + # skip if metric is a list (like f1 per class) + pass + + g_model.train() + return avg_metrics + + +def build_avg_metrics(cfg): + return OrderedDict([ + ('f1_macro', F1AverageMeter( + name="f1_macro", fmt=":4.4f", average='macro', cfg=cfg)), + ('f1_micro', F1AverageMeter( + name="f1_micro", fmt=":4.4f", average='micro', cfg=cfg)), + ('f1_class', F1AverageMeter( + name="f1_class", fmt=":4.4f", average=None, cfg=cfg)), + ]) + + +def main(val_dloader, cfg): + model = load_eval_method(cfg) + result = validate( + model, val_dloader, cfg.epoch - 1, None, 'test', cfg) + save_metrics(result, cfg) + return result + + +def load_eval_method(cfg): + vis = cfg.visualize + model = load_fun(vis.get('model'))(cfg) + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception as e: + print(e) + exit(-1) + + return model + + +def load_metrics(cfg): + filename = get_result_filename(cfg) + print('load results {}'.format(filename)) + result = torch.load(filename) + # check if epoch corresponds + assert result['epoch'] == cfg.epoch + # load classes + 
cfg.classes = result['classes'] + # build AVG objects + avg_metrics = build_avg_metrics(cfg) + for k, v in result['metrics'].items(): + avg_metrics[k].avg = v + return avg_metrics + + +def save_metrics(metrics, cfg): + filename = get_result_filename(cfg) + print('save results {}'.format(filename)) + torch.save({ + 'classes': cfg.classes, + 'epoch': cfg.epoch, + 'metrics': OrderedDict([ + (k, v.avg) for k, v in metrics.items() + ]) + }, filename) diff --git a/src/semantic_segm/visualize.py b/src/semantic_segm/visualize.py new file mode 100644 index 0000000..448c3b1 --- /dev/null +++ b/src/semantic_segm/visualize.py @@ -0,0 +1,122 @@ +import torch +import os +import matplotlib.pyplot as plt +import imgproc + +from tqdm import tqdm + +from chk_loader import load_checkpoint +from datasets.seasonet import load_dataset +from utils import load_fun + + +def plot_images(img, out_path, basename, fname, batch_number, dpi): + # dir: out_dir / basename + out_path = os.path.join(out_path, basename) + if not os.path.exists(out_path): + os.makedirs(out_path) + + n_cols = 1 + n_bands = img.shape[1] + fig, axs = plt.subplots( + n_bands, n_cols, sharex=True, sharey=True, figsize=(n_bands, 1)) + if img is not None: + img = img.squeeze(0) + for i in range(img.shape[0]): + axs[i].imshow( + imgproc.tensor_to_image(img[i].detach(), False, False)) + axs[i].axis('off') + + out_fname = os.path.join(out_path, '{}_{}.png'.format(fname, batch_number)) + plt.savefig(out_fname, dpi=dpi) + plt.close() + + +def save_fig(output, filename, dpi): + plt.imshow(output.cpu().numpy().squeeze(0)) + plt.savefig(filename, dpi=dpi, pad_inches=0.01) + print('image saved: {}'.format(filename)) + + +def plot_classes(image, out_path, basename, index, cfg): + # dir: out_dir / basename + out_path = os.path.join(out_path, basename) + if not os.path.exists(out_path): + os.makedirs(out_path) + + for idx, name in enumerate(cfg.classes): + output = (image == idx).int() * 255. 
+ unq = output.unique() + # import ipdb; ipdb.set_trace() + if len(unq) > 1 or unq.item() != 0: + print('class ', idx) + # import ipdb; ipdb.set_trace() + fname = 'image_{:02d}_{:02d}.jpg'.format(index, idx) + full_name = os.path.join(out_path, fname) + save_fig(output, full_name, cfg.dpi) + + +def main(cfg): + cfg['batch_size'] = 1 + vis = cfg.get('visualize', {}) + # Load dataset + _, val_dloader, _ = load_dataset(cfg, only_test=True) + # Initialize model + model = load_fun(vis.get('model'))(cfg) + denorm = load_fun(cfg.dataset.get('denorm'))(cfg) + # evaluable = load_fun(cfg.dataset.get('printable'))(cfg) + printable = load_fun(cfg.dataset.get('printable'))(cfg) + + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception: + print('no model checkpoint found') + exit(0) + + # Create a folder of super-resolution experiment results + out_path = os.path.join(cfg.output, 'images_{}'.format(cfg.epoch)) + if not os.path.exists(out_path): + os.makedirs(out_path) + + model.to(cfg.device) + model.eval() + + iterations = min(cfg.num_images, len(val_dloader)) + + for index, batch in tqdm( + enumerate(val_dloader), total=iterations, + desc='%d Images' % (iterations)): + + if index >= iterations: + break + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = model(batch['image']) + preds = out.max(-1)[1] + + plot_classes(preds, out_path, 'preds', index, cfg) + plot_classes(batch['mask'], out_path, 'gt', index, cfg) + + sr = None + if not cfg.hide_sr: + if hasattr(model, 'upscaler'): + sr = model.upscaler(batch['image']) + + if not torch.is_tensor(sr): + sr, _ = sr + + # denormalize to original values + sr, lr = denorm(sr, batch['image']) + + # normalize [0, 1] removing outliers to have a printable version + sr, lr = printable(sr, lr) + + # plot images + plot_images(lr, out_path, 'images', 'lr', index, cfg.dpi) + + if not cfg.hide_sr and 
sr is not None: + plot_images(sr, out_path, 'images', 'sr', index, cfg.dpi) diff --git a/src/super_res/model.py b/src/super_res/model.py new file mode 100644 index 0000000..fb1fb98 --- /dev/null +++ b/src/super_res/model.py @@ -0,0 +1,15 @@ + +def build_model(cfg): + version = cfg.super_res.get('version', 'v1') + print('load super_res {}'.format(version)) + + if version == 'v1': + from .network_swinir import SwinIR as SRModel + elif version == 'v2': + from .network_swin2sr import Swin2SR as SRModel + elif version == 'swinfir': + from .swinfir_arch import SwinFIR as SRModel + + model = SRModel(**cfg.super_res.model).to(cfg.device) + + return model diff --git a/src/super_res/moe.py b/src/super_res/moe.py new file mode 100644 index 0000000..5815c21 --- /dev/null +++ b/src/super_res/moe.py @@ -0,0 +1,324 @@ +# +# Source code: https://github.com/davidmrau/mixture-of-experts +# + +# Sparsely-Gated Mixture-of-Experts Layers. +# See "Outrageously Large Neural Networks" +# https://arxiv.org/abs/1701.06538 +# +# Author: David Rau +# +# The code is based on the TensorFlow implementation: +# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py + + +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +from copy import deepcopy +import numpy as np + +from .utils import Mlp as MLP + + +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. + The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". 
+ The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates): + """Create a SparseDispatcher.""" + + self._gates = gates + self._num_experts = num_experts + # sort experts + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] + # calculate num samples that each expert gets + self._part_sizes = (gates > 0).sum(0).tolist() + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. 
+ The `Tensor` for an expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape `[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + # expand according to batch index so we can just split by _part_sizes + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + def combine(self, expert_out, multiply_by_gates=True, cnn_combine=None): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. + """ + # concatenate the expert outputs along the batch dimension + stitched = torch.cat(expert_out, 0) + + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates.unsqueeze(1)) + zeros = torch.zeros((self._gates.size(0),) + expert_out[-1].shape[1:], + requires_grad=True, device=stitched.device) + # combine samples that have been processed by the same k experts + + if cnn_combine is not None: + return self.smartly_combine(stitched, cnn_combine) + + combined = zeros.index_add(0, self._batch_index, stitched.float()) + return combined + + def smartly_combine(self, stitched, cnn_combine): + idxes = [] + for i in self._batch_index.unique(): + idx = (self._batch_index == i).nonzero().squeeze(1) + idxes.append(idx) + idxes = torch.stack(idxes) + return cnn_combine(stitched[idxes]).squeeze(1) + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. 
+ Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + + +def build_experts(experts_cfg, default_cfg, num_experts): + experts_cfg = deepcopy(experts_cfg) + if experts_cfg is None: + # old build way + return nn.ModuleList([ + MLP(*default_cfg) + for i in range(num_experts)]) + # new build way: mix mlp with leff + experts = [] + for e_cfg in experts_cfg: + type_ = e_cfg.pop('type') + if type_ == 'mlp': + experts.append(MLP(*default_cfg)) + return nn.ModuleList(experts) + + +class MoE(nn.Module): + """Call a Sparsely gated mixture of experts layer with 1-layer + Feed-Forward networks as experts. + + Args: + input_size: integer - size of the input + output_size: integer - size of the output + num_experts: an integer - number of experts + hidden_size: an integer - hidden size of the experts + noisy_gating: a boolean + k: an integer - how many experts to use for each batch element + """ + + def __init__(self, input_size, output_size, num_experts, hidden_size, + experts=None, noisy_gating=True, k=4, + x_gating=None, with_noise=True, with_smart_merger=None): + super(MoE, self).__init__() + self.noisy_gating = noisy_gating + self.num_experts = num_experts + self.output_size = output_size + self.input_size = input_size + self.hidden_size = hidden_size + self.k = k + self.with_noise = with_noise + # instantiate experts + self.experts = build_experts( + experts, + (self.input_size, self.hidden_size, self.output_size), + num_experts) + self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + + self.x_gating = x_gating + if self.x_gating == 'conv1d': + self.x_gate = nn.Conv1d(4096, 1, kernel_size=3, padding=1) + + self.softplus = nn.Softplus() + self.softmax = nn.Softmax(1) + 
self.register_buffer("mean", torch.tensor([0.0])) + self.register_buffer("std", torch.tensor([1.0])) + assert(self.k <= self.num_experts) + + self.cnn_combine = None + if with_smart_merger == 'v1': + print('with SMART MERGER') + self.cnn_combine = nn.Conv2d(self.k, 1, kernel_size=3, padding=1) + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. + """ + eps = 1e-10 + # if only num_experts = 1 + + if x.shape[0] == 1: + return torch.tensor([0], device=x.device, dtype=x.dtype) + return x.float().var() / (x.float().mean()**2 + eps) + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. 
+ """ + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + + threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + threshold_positions_if_out = threshold_positions_if_in - 1 + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + normal = Normal(self.mean, self.std) + prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. 
+ noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + clean_logits = x @ self.w_gate + if self.noisy_gating and train: + raw_noise_stddev = x @ self.w_noise + noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) + top_k_logits = top_logits[:, :self.k] + top_k_indices = top_indices[:, :self.k] + top_k_gates = self.softmax(top_k_logits) + + zeros = torch.zeros_like(logits, requires_grad=True) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + + if self.noisy_gating and self.k < self.num_experts and train: + load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + else: + load = self._gates_to_load(gates) + return gates, load + + def forward(self, x, loss_coef=1e-2): + """Args: + x: tensor shape [batch_size, input_size] + train: a boolean scalar. + loss_coef: a scalar - multiplier on load-balancing losses + + Returns: + y: a tensor with shape [batch_size, output_size]. + extra_training_loss: a scalar. This should be added into the overall + training loss of the model. The backpropagation of this loss + encourages all experts to be approximately equally used across a batch. 
+ """ + if self.x_gating is not None: + xg = self.x_gate(x).squeeze(1) + else: + xg = x.mean(1) + + gates, load = self.noisy_top_k_gating( + xg, self.training and self.with_noise) + # calculate importance loss + importance = gates.sum(0) + # + loss = self.cv_squared(importance) + self.cv_squared(load) + loss *= loss_coef + + dispatcher = SparseDispatcher(self.num_experts, gates) + expert_inputs = dispatcher.dispatch(x) + gates = dispatcher.expert_to_gates() + expert_outputs = [self.experts[i](expert_inputs[i]) + for i in range(self.num_experts)] + y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine) + return y, loss diff --git a/src/super_res/network_swin2sr.py b/src/super_res/network_swin2sr.py new file mode 100644 index 0000000..c41e533 --- /dev/null +++ b/src/super_res/network_swin2sr.py @@ -0,0 +1,1173 @@ +# +# Source code: https://github.com/mv-lab/swin2sr +# +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345 +# Written by Conde and Choi et al. +# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .utils import window_reverse, Mlp, window_partition +from .moe import MoE + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0], + use_lepe=False, + use_cpb_bias=True, + use_rpe_bias=False): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + self.use_cpb_bias = use_cpb_bias + + if self.use_cpb_bias: + print('positional encoder: CPB') + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise 
relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.use_rpe_bias = use_rpe_bias + if self.use_rpe_bias: + print('positional encoder: RPE') + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + rpe_relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("rpe_relative_position_index", rpe_relative_position_index) + + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.qkv = 
nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + self.use_lepe = use_lepe + if self.use_lepe: + print('positional encoder: LEPE') + self.get_v = nn.Conv2d( + dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + if self.use_lepe: + lepe = self.lepe_pos(v) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. 
/ 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + if self.use_cpb_bias: + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if self.use_rpe_bias: + relative_position_bias = self.relative_position_bias_table[self.rpe_relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v) + + if self.use_lepe: + x = x + lepe + + x = x.transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def lepe_pos(self, v): + B, NH, HW, NW = v.shape + C = NH * NW + H = W = int(math.sqrt(HW)) + v = v.transpose(-2, -1).contiguous().view(B, C, H, W) + lepe = self.get_v(v) + lepe = lepe.reshape(-1, self.num_heads, NW, HW) + lepe = lepe.permute(0, 1, 3, 2).contiguous() + return lepe + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of 
N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. 
+ """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + use_rpe_bias=use_rpe_bias) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if MoE_config is None: + print('-->>> MLP') + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + else: + print('-->>> MOE') + print(MoE_config) + self.mlp = MoE( + input_size=dim, output_size=dim, hidden_size=mlp_hidden_dim, + **MoE_config) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + 
if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + + loss_moe = None + res = self.mlp(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = x + self.drop_path(self.norm2(res)) + + return x, loss_moe + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. 
Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + loss_moe_all = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + if self.downsample is not None: + x = self.downsample(x) + return x, loss_moe_all + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + 
return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). 
+ + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and 
memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + loss_moe = None + res = self.residual_group(x, x_size) + + if not torch.is_tensor(res): + res, loss_moe = res + + res = self.patch_embed(self.conv(self.patch_unembed(res, x_size))) + return res + x, loss_moe + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. 
' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. 
Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False, + **kwargs): + super(Swin2SR, self).__init__() + print('==== SWIN 2SR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches +
self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias, + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + 
mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = 
nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers: + x = layer(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers_hf: + x = layer(x, x_size) + + if not 
torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward_backbone(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + else: + raise Exception('not implemented yet') + + x = x / self.img_range + self.mean + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + loss_moe = 0 + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x_before = 
self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + + res_hf = self.forward_features_hf(x_hf) + if not torch.is_tensor(res_hf): + res_hf, loss_moe_hf = res_hf + loss_moe += loss_moe_hf + + x_hf = self.conv_after_body_hf(res_hf) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + + res = self.forward_features(x_first) + if not torch.is_tensor(res): + res, loss_moe = res + + res = self.conv_after_body(res) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux, loss_moe + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + else: + return x[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + def flops(self): + flops = 0 + H, W = 
self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], mlp_ratio=2, upsampler='pixelshuffledirect') + print(model) + print(height, width, model.flops() / 1e9) + + x = torch.randn((1, 3, height, width)) + # forward() returns (output, loss_moe) for this upsampler, so unpack before reading the shape + x, loss_moe = model(x) + print(x.shape) diff --git a/src/super_res/network_swinir.py b/src/super_res/network_swinir.py new file mode 100644 index 0000000..39edfab --- /dev/null +++ b/src/super_res/network_swinir.py @@ -0,0 +1,853 @@ +# +# Original code from: https://github.com/JingyunLiang/SwinIR +# +# ----------------------------------------------------------------------------------- +# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257 +# Originally Written by Ze Liu, Modified by Jingyun Liang. +# ----------------------------------------------------------------------------------- + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .utils import window_reverse, Mlp, window_partition + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both shifted and non-shifted windows. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with 
shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, 
window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." 
+ + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + use_lepe=False, use_cpb_bias=True, + MoE_config=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + loss_moe_all = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + if self.downsample is not None: + x = self.downsample(x) + + return x, loss_moe_all + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. 
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', + use_lepe=False, use_cpb_bias=True, + MoE_config=None): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 
1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + loss_moe = None + res = self.residual_group(x, x_size) + if not torch.is_tensor(res): + res, loss_moe = res + return ( + self.patch_embed( + self.conv(self.patch_unembed(res, x_size))) + x + ), loss_moe + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. 
+ + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction module. 
'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + use_transpose=False, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + **kwargs): + super(SwinIR, self).__init__() + print('==== SWIN IR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + 
self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + ) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + 
##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = 
(self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers: + x = layer(x, x_size) + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + loss_moe = None + if self.upsampler == 'pixelshuffle': + # for classical SR; forward_features returns (features, loss_moe), keep only the features + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)[0]) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'sunet': + # for lightweight SR + x = self.conv_first(x) + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR; forward_features returns (features, loss_moe), keep only the features + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)[0]) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = 
self.conv_after_body(self.forward_features(x_first)[0]) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops diff --git a/src/super_res/swinfir_arch.py b/src/super_res/swinfir_arch.py new file mode 100644 index 0000000..904b5a0 --- /dev/null +++ b/src/super_res/swinfir_arch.py @@ -0,0 +1,537 @@ +# +# Source code: https://github.com/Zdafeng/SwinFIR +# Paper: https://arxiv.org/pdf/2208.11247v3.pdf +# + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import to_2tuple, trunc_normal_ +from .swinfir_utils import WindowAttention, DropPath, Mlp, SFB, \ + PatchEmbed, PatchUnEmbed, Upsample, UpsampleOneStep, window_partition, \ + window_reverse + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, 
window_size*window_size, c + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}') + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + # blk.forward requires x_size, so it must be passed through checkpoint as well + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate.
Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == 'SFB': + self.conv = SFB(dim) + elif resi_connection == 'HSFB': + self.conv = SFB(dim, 2) + elif resi_connection == 'identity': + self.conv = nn.Identity() + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + +# @ARCH_REGISTRY.register() 
+class SwinFIR(nn.Module): + r""" SwinFIR + A PyTorch impl of : `SwinFIR: Revisiting the SwinIR with Fast Fourier Convolution and + Improved Training for Image Super-Resolution`, based on Swin Transformer and Fast Fourier Convolution. + + Args: + img_size (int | tuple(int)): Input image size. Default: 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection.
'1conv'/'SFB'/'HSFB'/'identity' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1., + upsampler='', + resi_connection='SFB', + **kwargs): + super(SwinFIR, self).__init__() + print('==== SWIN FIR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.3014, 0.3152, 0.3094) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed,
std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' 
+ self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # b seq_len c + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + input = x + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + 
x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x diff --git a/src/super_res/swinfir_utils.py b/src/super_res/swinfir_utils.py new file mode 100644 index 0000000..ce9db86 --- /dev/null +++ b/src/super_res/swinfir_utils.py @@ -0,0 +1,725 @@ +# +# Source code: https://github.com/Zdafeng/SwinFIR +# Paper: https://arxiv.org/pdf/2208.11247v3.pdf +# + +import math +import torch +import torch.nn as nn + +from einops import rearrange +from timm.models.layers import to_2tuple, trunc_normal_ + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
+ + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. 
+ """ + + def __init__(self, num_feat, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), + nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + 
+ q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + +class CAB(nn.Module): + + def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30): + super(CAB, self).__init__() + + self.cab = nn.Sequential( + nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), + nn.GELU(), + nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), + ChannelAttention(num_feat, squeeze_factor) + ) + + def forward(self, x): + return self.cab(x) + + +class WindowAttentionHATFIR(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, rpi, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class 
HAB(nn.Module): + r""" Hybrid Attention Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionHATFIR( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.conv_scale = conv_scale + self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, 
squeeze_factor=squeeze_factor) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size, rpi_sa, attn_mask): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # Conv_X + conv_x = self.conv_block(x.permute(0, 3, 1, 2)) + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = attn_mask + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attn_x = shifted_x + attn_x = attn_x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, 'input feature has wrong size' + assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.' + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class OCAB(nn.Module): + # overlapping cross-attention block + + def __init__(self, dim, + input_resolution, + window_size, + overlap_ratio, + num_heads, + qkv_bias=True, + qk_scale=None, + mlp_ratio=2, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.overlap_win_size = int(window_size * overlap_ratio) + window_size + + self.norm1 = norm_layer(dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, + padding=(self.overlap_win_size - window_size) // 2) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + self.proj = 
nn.Linear(dim, dim) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU) + + def forward(self, x, x_size, rpi): + h, w = x_size + b, _, c = x.shape + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w + q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c + kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w + + # partition windows + q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c + q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + kv_windows = self.unfold(kv) # b, c*w*w, nw + kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, + owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c + k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c + + b_, nq, _ = q_windows.shape + _, n, _ = k_windows.shape + d = self.dim // self.num_heads + q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, nq, d + k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( + self.window_size * self.window_size, self.overlap_win_size * self.overlap_win_size, + -1) # ws*ws, wse*wse, nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, ws*ws, wse*wse + attn = attn + relative_position_bias.unsqueeze(0) + + attn = self.softmax(attn) + attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim) + 
x = window_reverse(attn_windows, self.window_size, h, w) # b h w c + x = x.view(b, h * w, self.dim) + + x = self.proj(x) + shortcut + + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. 
+ + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + +class FourierUnit(nn.Module): + def __init__(self, embed_dim, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.conv_layer = torch.nn.Conv2d(embed_dim * 2, embed_dim * 2, 1, 1, 0) + self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(ffted) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(0, 1, 3, 4, + 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + return output + + +class SpectralTransform(nn.Module): + def __init__(self, embed_dim, last_conv=False): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.last_conv = last_conv + + self.conv1 = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True) + ) + self.fu = FourierUnit(embed_dim // 2) + + self.conv2 = torch.nn.Conv2d(embed_dim // 2, embed_dim, 1, 1, 0) + + if self.last_conv: + self.last_conv = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + output = self.fu(x) + output = 
self.conv2(x + output) + if self.last_conv: + output = self.last_conv(output) + return output + + +## Residual Block (RB) +class ResB(nn.Module): + def __init__(self, embed_dim, red=1): + super(ResB, self).__init__() + self.body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // red, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embed_dim // red, embed_dim, 3, 1, 1), + ) + + def forward(self, x): + out = self.body(x) + return out + x + + +class SFB(nn.Module): + def __init__(self, embed_dim, red=1): + super(SFB, self).__init__() + self.S = ResB(embed_dim, red) + self.F = SpectralTransform(embed_dim) + self.fusion = nn.Conv2d(embed_dim * 2, embed_dim, 1, 1, 0) + + def forward(self, x): + s = self.S(x) + f = self.F(x) + out = torch.cat([s, f], dim=1) + out = self.fusion(out) + return out diff --git a/src/super_res/training.py b/src/super_res/training.py new file mode 100644 index 0000000..96ba272 --- /dev/null +++ b/src/super_res/training.py @@ -0,0 +1,109 @@ +import datetime +import torch + +import debug + +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter + +from validation import validate, do_save_metrics as save_metrics +from chk_loader import load_checkpoint, load_state_dict_model, \ + save_state_dict_model +from validation import build_eval_metrics +from losses import build_losses +from optim import build_optimizer +from .model import build_model + + +def train(train_dloader, val_dloader, cfg): + + # Tensorboard + writer = SummaryWriter(cfg.output + '/tensorboard/train_{}'.format( + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))) + # eval every x + eval_every = cfg.metrics.get('eval_every', 1) + + model = build_model(cfg) + losses = build_losses(cfg) + optimizer = build_optimizer(model, cfg) + + begin_epoch = 0 + index = 0 + try: + checkpoint = load_checkpoint(cfg) + begin_epoch, index = load_state_dict_model( + model, optimizer, checkpoint) + except FileNotFoundError: + print('no checkpoint found') + + print('build 
eval metrics') + metrics = build_eval_metrics(cfg) + + for e in range(begin_epoch, cfg.epochs): + index = train_epoch( + model, + train_dloader, + losses, + optimizer, + e, + writer, + index, + cfg) + + if (e+1) % eval_every == 0: + result = validate( + model, val_dloader, metrics, e, writer, 'test', cfg) + # save result of eval + cfg.epoch = e+1 + save_metrics(result, cfg) + + save_state_dict_model(model, optimizer, e, index, cfg) + + +def train_epoch(model, train_dloader, losses, optimizer, epoch, writer, + index, cfg): + weights = cfg.losses.weights + for index, batch in tqdm( + enumerate(train_dloader, index), total=len(train_dloader), + desc='Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + # Transfer in-memory data to CUDA devices to speed up training + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + sr = model(lr) + + loss_tracker = {} + + loss_moe = None + if not torch.is_tensor(sr): + sr, loss_moe = sr + if torch.is_tensor(loss_moe): + loss_tracker['loss_moe'] = loss_moe * weights.moe + + sr = sr.contiguous() + + if 'pixel_criterion' in losses: + loss_tracker['pixel_loss'] = weights.pixel * \ + losses['pixel_criterion'](sr, hr) + + # cc loss + if 'cc_criterion' in losses: + loss_tracker['cc_loss'] = weights.cc * \ + losses['cc_criterion'](sr, hr) + + # ssim loss + if 'ssim_criterion' in losses: + loss_tracker['ssim_loss'] = weights.ssim * \ + losses['ssim_criterion'](sr, hr) + + # train + loss_tracker['train_loss'] = sum(loss_tracker.values()) + optimizer.zero_grad() + loss_tracker['train_loss'].backward() + optimizer.step() + + debug.log_hr_stats(lr, sr, hr, writer, index, cfg) + debug.log_losses(loss_tracker, 'train', writer, index) + + return index diff --git a/src/super_res/utils.py b/src/super_res/utils.py new file mode 100644 index 0000000..e7bb96d --- /dev/null +++ b/src/super_res/utils.py @@ -0,0 +1,56 @@ +from torch import nn + + +def window_reverse(windows, window_size, H, 
W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, + W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view( + -1, window_size, window_size, C) + return windows diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..2ef46c2 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,117 @@ +import torch +import importlib +import numpy as np +import random + +from enum import Enum +from torchmetrics.classification import MulticlassF1Score + + +def set_deterministic(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + np.random.seed(seed) # Numpy module. + random.seed(seed) # Python random module. 
+ torch.manual_seed(seed) + torch.backends.cudnn.benchmark = False # benchmark autotuning can select non-deterministic kernels + torch.backends.cudnn.deterministic = True + + +def w_count(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def load_fun(fullname): + path, name = fullname.rsplit('.', 1) + return getattr(importlib.import_module(path), name) + + +class Summary(Enum): + NONE = 0 + AVERAGE = 1 + SUM = 2 + COUNT = 3 + + +class AverageMeter(object): + def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): + self.name = name + self.fmt = fmt + self.summary_type = summary_type + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + self.avg_item = self.avg.tolist() + fmtstr = "{avg_item" + self.fmt + "}" + try: + return fmtstr.format(**self.__dict__) + except TypeError: + # print a list of elements + fmtstr = "{" + self.fmt + "}" + return ' '.join([ + fmtstr for _ in range(len(self.avg_item)) + ]).format(*self.avg_item) + + def summary(self): + if self.summary_type is Summary.NONE: + fmtstr = "" + elif self.summary_type is Summary.AVERAGE: + fmtstr = "{name} {avg:.2f}" + elif self.summary_type is Summary.SUM: + fmtstr = "{name} {sum:.2f}" + elif self.summary_type is Summary.COUNT: + fmtstr = "{name} {count:.2f}" + else: + raise ValueError( + "Invalid summary type {}".format(self.summary_type)) + + return fmtstr.format(**self.__dict__) + + +class F1AverageMeter(AverageMeter): + def __init__(self, cfg, average, **kwargs): + self.cfg = cfg + self._cfg_average = average + super().__init__(**kwargs) + + def reset(self): + super().reset() + self._to_update = True + self.fun = MulticlassF1Score( + len(self.cfg.classes), average=self._cfg_average + ).to(self.cfg.device) + + @property + def avg(self): + if self._to_update: + self._avg = self.fun.compute() + self._to_update = False + return 
self._avg + + @avg.setter + def avg(self, value): + self._avg = value + self._to_update = False + + def update(self, val, n=1): + pred, gt = val + self.fun.update(pred, gt) + self._to_update = True + + +def set_required_grad(model, value): + for parameters in model.parameters(): + parameters.requires_grad = value diff --git a/src/validation.py b/src/validation.py new file mode 100644 index 0000000..778ddf8 --- /dev/null +++ b/src/validation.py @@ -0,0 +1,187 @@ +import os +import torch + +from torch import nn +from tqdm import tqdm +from torch.nn import Upsample +from collections import OrderedDict + +from utils import AverageMeter, load_fun +from metrics import CC, SAM, ERGAS, piq_psnr, piq_ssim, \ + piq_rmse +from chk_loader import load_checkpoint + + +def validate(g_model, val_dloader, metrics, epoch, writer, mode, cfg): + # Put the adversarial network model in validation mode + g_model.eval() + + avg_metrics = build_avg_metrics() + + use_minmax = cfg.dataset.get('stats', {}).get('use_minmax', False) + dset = cfg.dataset + denorm = load_fun(dset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(dset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + with torch.no_grad(): + for j, batch in tqdm( + enumerate(val_dloader), total=len(val_dloader), + desc='Val Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + # Use the generator model to generate a fake sample + sr = g_model(lr) + + if not torch.is_tensor(sr): + sr, _ = sr + + sr = sr.contiguous() + + # denormalize to original values + hr, sr, lr = denorm(hr, sr, lr) + + # normalize [0, 1] using also the outliers to evaluate + hr, sr, lr = evaluable(hr, sr, lr) + + # Statistical loss value for terminal data output + for k, fun in metrics.items(): + for i 
in range(len(sr)): + avg_metrics[k].update(fun(sr[i][None], hr[i][None])) + + if writer is not None: + for k, v in avg_metrics.items(): + writer.add_scalar("{}/{}".format(mode, k), v.avg.item(), epoch+1) + + if cfg.get('eval_return_to_train', True): + g_model.train() + return avg_metrics + + +def build_eval_metrics(cfg): + # Create an IQA evaluation model + metrics = { + 'psnr_model': piq_psnr(cfg), + 'ssim_model': piq_ssim(cfg), + 'cc_model': CC(), + 'rmse_model': piq_rmse(cfg), + 'sam_model': SAM(), + 'ergas_model': ERGAS(), + } + + for k in metrics.keys(): + metrics[k] = metrics[k].to(cfg.device) + + return metrics + + +def build_avg_metrics(): + return OrderedDict([ + ('psnr_model', AverageMeter("PIQ_PSNR", ":4.4f")), + ('ssim_model', AverageMeter("PIQ_SSIM", ":4.4f")), + ('cc_model', AverageMeter("CC", ":4.4f")), + ('rmse_model', AverageMeter("PIQ_RMSE", ":4.4f")), + ('sam_model', AverageMeter("SAM", ":4.4f")), + ('ergas_model', AverageMeter("ERGAS", ":4.4f")), + ]) + + +def main(val_dloader, cfg, save_metrics=True): + model = load_eval_method(cfg) + print('build eval metrics') + metrics = build_eval_metrics(cfg) + result = validate( + model, val_dloader, metrics, cfg.epoch, None, 'test', cfg) + if save_metrics: + do_save_metrics(result, cfg) + return result + + +def get_result_filename(cfg): + output_dir = os.path.join(cfg.output, 'eval') + os.makedirs(output_dir, exist_ok=True) + return os.path.join(output_dir, 'results-{:02d}.pt'.format(cfg.epoch)) + + +def do_save_metrics(metrics, cfg): + filename = get_result_filename(cfg) + print('save results {}'.format(filename)) + torch.save({ + 'epoch': cfg.epoch, + 'metrics': OrderedDict([ + (k, v.avg) for k, v in metrics.items() + ]) + }, filename) + + +def load_metrics(cfg): + filename = get_result_filename(cfg) + print('load results {}'.format(filename)) + result = torch.load(filename) + # check if epoch corresponds + assert result['epoch'] == cfg.epoch + # build AVG objects + avg_metrics = build_avg_metrics() + 
+ for k, v in result['metrics'].items(): + avg_metrics[k].avg = v + return avg_metrics + + +def print_metrics(metrics): + names = [] + values = [] + for v in metrics.values(): + try: + names.append(v.name) + values.append(v) + except AttributeError: + # skip for backward compatibility + pass + print(*names) + print(*values) + + +def load_eval_method(cfg): + if cfg.eval_method is None: + vis = cfg.visualize + model = load_fun(vis.get('model'))(cfg) + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception as e: + print(e) + raise SystemExit(1) # non-zero status: the checkpoint could not be loaded + + return model + + print('load non-dl upsampler: {}'.format(cfg.eval_method)) + return NonDLEvalMethod(cfg) + + +class NonDLEvalMethod(object): + def __init__(self, cfg): + self.upscale_factor = cfg.metrics.upscale_factor + self.upsampler = Upsample( + scale_factor=self.upscale_factor, + mode=cfg.eval_method) + + def __call__(self, x): + return self.upsampler(x) + + def eval(self): + pass + + def train(self): + pass + + def to(self, device): + return self diff --git a/src/visualize.py b/src/visualize.py new file mode 100644 index 0000000..f97ff35 --- /dev/null +++ b/src/visualize.py @@ -0,0 +1,101 @@ +import os +import torch +import matplotlib.pyplot as plt +import imgproc + +from tqdm import tqdm + +from validation import build_eval_metrics, load_eval_method +from utils import load_fun + + +def plot_images(img, out_path, basename, fname, dpi): + # dir: out_dir / basename + out_path = os.path.join(out_path, basename) + if not os.path.exists(out_path): + os.makedirs(out_path) + + img = img.squeeze(0) + for i in range(img.shape[0]): + plt.imshow(imgproc.tensor_to_image(img[i].detach(), False, False)) + plt.axis('off') + out_fname = os.path.join(out_path, '{}_{}.png'.format(fname, i)) + plt.savefig(out_fname, dpi=dpi) + plt.close() + + +def main(cfg): + # Load model state dict + model = load_eval_method(cfg) + + # manipulation function for 
data + use_minmax = cfg.dataset.get('stats', {}).get('use_minmax', False) + denorm = load_fun(cfg.dataset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(cfg.dataset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + printable = load_fun(cfg.dataset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + # Create a folder of super-resolution experiment results + out_path = os.path.join(cfg.output, 'images_{}'.format(cfg.epoch)) + if not os.path.exists(out_path): + os.makedirs(out_path) + + # move to device + model.to(cfg.device) + model.eval() + # Load dataset + load_dataset_fun = load_fun(cfg.dataset.get( + 'load_dataset', 'datasets.sen2venus.load_dataset')) + _, val_dloader, _ = load_dataset_fun(cfg, only_test=True) + # Define metrics to evaluate each image + metrics = build_eval_metrics(cfg) + + indices = dict.fromkeys(metrics.keys(), None) + iterations = min(cfg.num_images, len(val_dloader)) + + with torch.no_grad(): + for index, batch in tqdm( + enumerate(val_dloader), total=iterations, + desc='%d Images' % (iterations)): + + if index >= iterations: + break + + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + sr = model(lr) + + if not torch.is_tensor(sr): + sr, _ = sr + + # denormalize to original values + hr_dn, sr_dn, lr_dn = denorm(hr, sr, lr) + + # normalize [0, 1] using also the outliers to evaluate + hr_eval, sr_eval, lr_eval = evaluable(hr_dn, sr_dn, lr_dn) + # compute metrics + for k, fun in metrics.items(): + for i in range(len(sr)): + res = fun(sr_eval[i][None], hr_eval[i][None]).detach() + indices[k] = res if not res.shape else res.squeeze(0) + + # normalize [0, 1] removing outliers to have a printable version + hr, sr, lr = printable(hr_dn, sr_dn, 
lr_dn) + # plot images + plot_images(lr, out_path, 'lr', index, cfg.dpi) + plot_images(hr, out_path, 'hr', index, cfg.dpi) + plot_images(sr, out_path, 'sr', index, cfg.dpi) + plot_images((hr - sr).abs(), out_path, 'delta', index, cfg.dpi)
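Review note on the step counter in `train_epoch` (`src/super_res/training.py`): the loop iterates over `enumerate(train_dloader, index)` and then returns `index`, which is the last step already used rather than the next free one. As far as can be read from this patch, the following epoch therefore appears to log its first batch under the same step as the previous epoch's last batch. The sketch below reproduces only that counting behaviour in plain Python; `run_epoch`, `collect_steps` and `advance_on_return` are hypothetical names introduced for illustration and are not part of the repository.

```python
def run_epoch(num_batches, start_index, advance_on_return=False):
    """Mimic the counting in train_epoch: enumerate batches from start_index."""
    index = start_index
    logged_steps = []
    for index, _batch in enumerate(range(num_batches), start_index):
        # Stands in for debug.log_losses(..., writer, index) in the patch.
        logged_steps.append(index)
    # Returning `index` hands back the last step already used;
    # returning `index + 1` hands back the next free step instead.
    if advance_on_return and logged_steps:
        return index + 1, logged_steps
    return index, logged_steps


def collect_steps(epochs, num_batches, advance_on_return):
    """Run several epochs and gather every step value that was logged."""
    index, steps = 0, []
    for _ in range(epochs):
        index, logged = run_epoch(num_batches, index, advance_on_return)
        steps.extend(logged)
    return steps


as_written = collect_steps(2, 3, advance_on_return=False)
adjusted = collect_steps(2, 3, advance_on_return=True)
print(as_written)  # [0, 1, 2, 2, 3, 4]
print(adjusted)    # [0, 1, 2, 3, 4, 5]
```

If unique steps are wanted, one possible adjustment is to return `index + 1` from `train_epoch`. Whether that is desirable depends on how `load_state_dict_model` and `save_state_dict_model` interpret the stored index, which this patch excerpt does not show.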