diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..18220bd --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +nohup.out +data/* +output/* +*.pyc diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..d159169 --- /dev/null +++ b/LICENSE @@ -0,0 +1,339 @@ + GNU GENERAL PUBLIC LICENSE + Version 2, June 1991 + + Copyright (C) 1989, 1991 Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The licenses for most software are designed to take away your +freedom to share and change it. By contrast, the GNU General Public +License is intended to guarantee your freedom to share and change free +software--to make sure the software is free for all its users. This +General Public License applies to most of the Free Software +Foundation's software and to any other program whose authors commit to +using it. (Some other Free Software Foundation software is covered by +the GNU Lesser General Public License instead.) You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +this service if you wish), that you receive source code or can get it +if you want it, that you can change the software or use pieces of it +in new free programs; and that you know you can do these things. + + To protect your rights, we need to make restrictions that forbid +anyone to deny you these rights or to ask you to surrender the rights. +These restrictions translate to certain responsibilities for you if you +distribute copies of the software, or if you modify it. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must give the recipients all the rights that +you have. You must make sure that they, too, receive or can get the +source code. And you must show them these terms so they know their +rights. + + We protect your rights with two steps: (1) copyright the software, and +(2) offer you this license which gives you legal permission to copy, +distribute and/or modify the software. + + Also, for each author's protection and ours, we want to make certain +that everyone understands that there is no warranty for this free +software. If the software is modified by someone else and passed on, we +want its recipients to know that what they have is not the original, so +that any problems introduced by others will not reflect on the original +authors' reputations. + + Finally, any free program is threatened constantly by software +patents. We wish to avoid the danger that redistributors of a free +program will individually obtain patent licenses, in effect making the +program proprietary. To prevent this, we have made it clear that any +patent must be licensed for everyone's free use or not licensed at all. + + The precise terms and conditions for copying, distribution and +modification follow. + + GNU GENERAL PUBLIC LICENSE + TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + + 0. This License applies to any program or other work which contains +a notice placed by the copyright holder saying it may be distributed +under the terms of this General Public License. 
The "Program", below, +refers to any such program or work, and a "work based on the Program" +means either the Program or any derivative work under copyright law: +that is to say, a work containing the Program or a portion of it, +either verbatim or with modifications and/or translated into another +language. (Hereinafter, translation is included without limitation in +the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not +covered by this License; they are outside its scope. The act of +running the Program is not restricted, and the output from the Program +is covered only if its contents constitute a work based on the +Program (independent of having been made by running the Program). +Whether that is true depends on what the Program does. + + 1. You may copy and distribute verbatim copies of the Program's +source code as you receive it, in any medium, provided that you +conspicuously and appropriately publish on each copy an appropriate +copyright notice and disclaimer of warranty; keep intact all the +notices that refer to this License and to the absence of any warranty; +and give any other recipients of the Program a copy of this License +along with the Program. + +You may charge a fee for the physical act of transferring a copy, and +you may at your option offer warranty protection in exchange for a fee. + + 2. You may modify your copy or copies of the Program or any portion +of it, thus forming a work based on the Program, and copy and +distribute such modifications or work under the terms of Section 1 +above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices + stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in + whole or in part contains or is derived from the Program or any + part thereof, to be licensed as a whole at no charge to all third + parties under the terms of this License. + + c) If the modified program normally reads commands interactively + when run, you must cause it, when started running for such + interactive use in the most ordinary way, to print or display an + announcement including an appropriate copyright notice and a + notice that there is no warranty (or else, saying that you provide + a warranty) and that users may redistribute the program under + these conditions, and telling the user how to view a copy of this + License. (Exception: if the Program itself is interactive but + does not normally print such an announcement, your work based on + the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If +identifiable sections of that work are not derived from the Program, +and can be reasonably considered independent and separate works in +themselves, then this License, and its terms, do not apply to those +sections when you distribute them as separate works. But when you +distribute the same sections as part of a whole which is a work based +on the Program, the distribution of the whole must be on the terms of +this License, whose permissions for other licensees extend to the +entire whole, and thus to each and every part regardless of who wrote it. 
+ +Thus, it is not the intent of this section to claim rights or contest +your rights to work written entirely by you; rather, the intent is to +exercise the right to control the distribution of derivative or +collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program +with the Program (or with a work based on the Program) on a volume of +a storage or distribution medium does not bring the other work under +the scope of this License. + + 3. You may copy and distribute the Program (or a work based on it, +under Section 2) in object code or executable form under the terms of +Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable + source code, which must be distributed under the terms of Sections + 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three + years, to give any third party, for a charge no more than your + cost of physically performing source distribution, a complete + machine-readable copy of the corresponding source code, to be + distributed under the terms of Sections 1 and 2 above on a medium + customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer + to distribute corresponding source code. (This alternative is + allowed only for noncommercial distribution and only if you + received the program in object code or executable form with such + an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for +making modifications to it. For an executable work, complete source +code means all the source code for all modules it contains, plus any +associated interface definition files, plus the scripts used to +control compilation and installation of the executable. However, as a +special exception, the source code distributed need not include +anything that is normally distributed (in either source or binary +form) with the major components (compiler, kernel, and so on) of the +operating system on which the executable runs, unless that component +itself accompanies the executable. + +If distribution of executable or object code is made by offering +access to copy from a designated place, then offering equivalent +access to copy the source code from the same place counts as +distribution of the source code, even though third parties are not +compelled to copy the source along with the object code. + + 4. You may not copy, modify, sublicense, or distribute the Program +except as expressly provided under this License. Any attempt +otherwise to copy, modify, sublicense or distribute the Program is +void, and will automatically terminate your rights under this License. +However, parties who have received copies, or rights, from you under +this License will not have their licenses terminated so long as such +parties remain in full compliance. + + 5. You are not required to accept this License, since you have not +signed it. However, nothing else grants you permission to modify or +distribute the Program or its derivative works. These actions are +prohibited by law if you do not accept this License. Therefore, by +modifying or distributing the Program (or any work based on the +Program), you indicate your acceptance of this License to do so, and +all its terms and conditions for copying, distributing or modifying +the Program or works based on it. + + 6. 
Each time you redistribute the Program (or any work based on the +Program), the recipient automatically receives a license from the +original licensor to copy, distribute or modify the Program subject to +these terms and conditions. You may not impose any further +restrictions on the recipients' exercise of the rights granted herein. +You are not responsible for enforcing compliance by third parties to +this License. + + 7. If, as a consequence of a court judgment or allegation of patent +infringement or for any other reason (not limited to patent issues), +conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot +distribute so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you +may not distribute the Program at all. For example, if a patent +license would not permit royalty-free redistribution of the Program by +all those who receive copies directly or indirectly through you, then +the only way you could satisfy both it and this License would be to +refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under +any particular circumstance, the balance of the section is intended to +apply and the section as a whole is intended to apply in other +circumstances. + +It is not the purpose of this section to induce you to infringe any +patents or other property right claims or to contest validity of any +such claims; this section has the sole purpose of protecting the +integrity of the free software distribution system, which is +implemented by public license practices. Many people have made +generous contributions to the wide range of software distributed +through that system in reliance on consistent application of that +system; it is up to the author/donor to decide if he or she is willing +to distribute software through any other system and a licensee cannot +impose that choice. + +This section is intended to make thoroughly clear what is believed to +be a consequence of the rest of this License. + + 8. If the distribution and/or use of the Program is restricted in +certain countries either by patents or by copyrighted interfaces, the +original copyright holder who places the Program under this License +may add an explicit geographical distribution limitation excluding +those countries, so that distribution is permitted only in or among +countries not thus excluded. In such case, this License incorporates +the limitation as if written in the body of this License. + + 9. The Free Software Foundation may publish revised and/or new versions +of the General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + +Each version is given a distinguishing version number. If the Program +specifies a version number of this License which applies to it and "any +later version", you have the option of following the terms and conditions +either of that version or of any later version published by the Free +Software Foundation. If the Program does not specify a version number of +this License, you may choose any version ever published by the Free Software +Foundation. + + 10. 
If you wish to incorporate parts of the Program into other free +programs whose distribution conditions are different, write to the author +to ask for permission. For software which is copyrighted by the Free +Software Foundation, write to the Free Software Foundation; we sometimes +make exceptions for this. Our decision will be guided by the two goals +of preserving the free status of all derivatives of our free software and +of promoting the sharing and reuse of software generally. + + NO WARRANTY + + 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY +FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN +OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES +PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED +OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS +TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE +PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, +REPAIR OR CORRECTION. + + 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR +REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, +INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING +OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED +TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY +YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER +PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE +POSSIBILITY OF SUCH DAMAGES. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +convey the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software; you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation; either version 2 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along + with this program; if not, write to the Free Software Foundation, Inc., + 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this +when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. 
+ +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License.  Of course, the commands you use may +be called something other than `show w' and `show c'; they could even be +mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your +school, if any, to sign a "copyright disclaimer" for the program, if +necessary.  Here is a sample; alter the names: + +  Yoyodyne, Inc., hereby disclaims all copyright interest in the program +  `Gnomovision' (which makes passes at compilers) written by James Hacker. + +  , 1 April 1989 +  Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into +proprietary programs.  If your program is a subroutine library, you may +consider it more useful to permit linking proprietary applications with the +library.  If this is what you want to do, use the GNU Lesser General +Public License instead of this License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..8ad3a77 --- /dev/null +++ b/README.md @@ -0,0 +1,315 @@ +## [Swin2-MoSE: A New Single Image Super-Resolution Model for Remote Sensing](https://arxiv.org/abs/2404.18924) + +Official PyTorch implementation of **Swin2-MoSE**. + +In this paper, we propose the **Swin2-MoSE** model, an enhanced version of Swin2SR for +Single-Image Super-Resolution for Remote Sensing. + +Swin2-MoSE Architecture + +Swin2-MoSE MoE-SM + +Swin2-MoSE Positional Encoding + +Authors: Leonardo Rossi, Vittorio Bernuzzi, Tomaso Fontanini, + Massimo Bertozzi, Andrea Prati. + +[IMP Lab](http://implab.ce.unipr.it/) - +Dipartimento di Ingegneria e Architettura + +University of Parma, Italy + + +## Abstract + +Due to the limitations of current optical and sensor technologies and the high cost of updating them, the spectral and spatial resolution of satellites may not always meet the desired requirements. +For these reasons, Remote-Sensing Single-Image Super-Resolution (RS-SISR) techniques have gained significant interest. + +In this paper, we propose the Swin2-MoSE model, an enhanced version of Swin2SR. + +Our model introduces MoE-SM, an enhanced Mixture-of-Experts (MoE) that replaces the Feed-Forward network inside every Transformer block. +MoE-SM is designed with Smart-Merger, a new layer for merging the outputs of individual experts, and with a new way to split the work between experts, defining a per-example strategy instead of the commonly used per-token one. + +Furthermore, we analyze how positional encodings interact with each other, demonstrating that per-channel bias and per-head bias can positively cooperate. + +Finally, we propose to use a combination of Normalized-Cross-Correlation (NCC) and Structural Similarity Index Measure (SSIM) losses, to avoid the typical limitations of the MSE loss. + +Experimental results demonstrate that Swin2-MoSE outperforms SOTA by up to 0.377 ~ 0.958 dB (PSNR) on the task of 2x, 3x and 4x resolution upscaling (Sen2Venus and OLI2MSI datasets). +We show the efficacy of Swin2-MoSE by applying it to a semantic segmentation task (SeasoNet dataset).
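For intuition, here is a minimal, illustrative sketch of the weighted NCC + SSIM objective mentioned above. It is **not** the repository's actual loss implementation (see `src/` and the `losses` section of the configs for that); the function names are hypothetical, and the SSIM term relies on `piq`, which is already listed in `environment.yml`:

```python
import piq
import torch


def ncc_loss(sr: torch.Tensor, hr: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # 1 - NCC, computed per image over all channels and pixels, then averaged.
    b = sr.shape[0]
    sr_c = sr.reshape(b, -1) - sr.reshape(b, -1).mean(dim=1, keepdim=True)
    hr_c = hr.reshape(b, -1) - hr.reshape(b, -1).mean(dim=1, keepdim=True)
    ncc = (sr_c * hr_c).sum(dim=1) / (sr_c.norm(dim=1) * hr_c.norm(dim=1) + eps)
    return 1.0 - ncc.mean()


def combined_loss(sr, hr, w_cc=1.0, w_ssim=1.0):
    # Weights mirror the `cc` / `ssim` entries of cfgs/sen2venus_exp1_v4.yml.
    ssim_term = piq.SSIMLoss(data_range=1.0)(sr.clamp(0, 1), hr.clamp(0, 1))
    return w_cc * ncc_loss(sr, hr) + w_ssim * ssim_term
```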
+ + +## Usage + +### Installation +```bash +$ git clone https://github.com/IMPLabUniPr/swin2-mose/tree/official_code +$ cd swin2-mose +$ conda env create -n swin2_mose_env --file environment.yml +$ conda activate swin2_mose_env +``` + +### Prepare Sen2Venus dataset + +1) After downloading the files from the + [Sen2Venus](https://zenodo.org/records/6514159) official website, unzip them + inside the `./datasets/sen2venus_original` directory. + +2) Run the script [split.py](https://github.com/IMPLabUniPr/swin2-mose/tree/official_code/scripts/sen2venus/split.py) to split the dataset into training (~80%) and + test (~20%) sets: + +```bash +python scripts/sen2venus/split.py --input ./datasets/sen2venus_original --output ./data/sen2venus +``` + +After the script completes successfully, you will find the `train.csv` and +`test.csv` files inside `./data/sen2venus`. + +Note: if you want to skip this step and use our `train.csv` and `test.csv` +files directly, you can download them from the +[Release v1.0](https://github.com/IMPLabUniPr/swin2-mose/releases/tag/v1.0) +page. + +3) Run the script [rebuild.py](https://github.com/IMPLabUniPr/swin2-mose/tree/official_code/scripts/sen2venus/rebuild.py) to rebuild the dataset in a compatible + format: + +```bash +python scripts/sen2venus/rebuild.py --data ./datasets/sen2venus_original --output ./data/sen2venus +``` + +If everything went well, you will have the following file structure: + +``` +data/sen2venus +├── test +│   ├── 000000_ALSACE_2018-02-14.pt +│   ├── 000001_ALSACE_2018-02-14.pt +| ... +├── test.csv +├── train +│   ├── 000000_ALSACE_2018-02-14.pt +│   ├── 000001_ALSACE_2018-02-14.pt +| ... +└── train.csv +``` + +Note about Sen2Venus: we found a small error in the file naming convention! + +In the paper, the authors describe the `4x` files as follows: + +``` +{id}_05m_b5b6b7b8a.pt - 5m patches (256×256 pix.) for S2 B5, B6, B7 and B8A (from VENµS) +{id}_20m_b5b6b7b8a.pt - 20m patches (64×64 pix.) for S2 B5, B6, B7 and B8A (from Sentinel-2) +``` + +However, the actual files follow this naming convention: + +``` +ALSACE_C_32ULU_2018-02-14_05m_b4b5b6b8a.pt +ALSACE_C_32ULU_2018-02-14_20m_b4b5b6b8a.pt +``` + +### Prepare OLI2MSI dataset + +Download the dataset from the +[OLI2MSI](https://github.com/wjwjww/OLI2MSI) official website and unzip it +inside the `./data/oli2msi` directory. + +If everything went well, you will have the following file structure: + +``` +data/oli2msi +├── test_hr +│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0071.TIF +│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0108.TIF +| ... +├── test_lr +│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0071.TIF +│   ├── L8_126038_20190923_S2B_20190923_T49RCQ_N0108.TIF +| ... +├── train_hr +│   ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0008.TIF +│   ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0015.TIF +| ... +└── train_lr + ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0008.TIF + ├── L8_126038_20190923_S2B_20190923_T49RBQ_N0015.TIF + ... +``` + +### Prepare SeasoNet dataset + +Download the dataset from the [SeasoNet](https://zenodo.org/records/5850307) official +website and unzip it inside the `./data/SeasoNet/data` directory.
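As also noted after the directory listing below, SeasoNet can instead be fetched programmatically through TorchGeo. A minimal, hedged sketch follows: the keyword arguments simply mirror `cfgs/seasonet_exp5_v1.yml`, and the exact `SeasoNet` signature may vary across TorchGeo versions (`torchgeo==0.5.1` is pinned in `environment.yml`):

```python
from torchgeo.datasets import SeasoNet

# Download only the Spring season, grid 1, RGB + IR bands (as in the config).
ds = SeasoNet(
    root='./data/SeasoNet/data',
    seasons=['Spring'],
    bands=['10m_RGB', '10m_IR'],
    grids=[1],
    download=True,
)
print(len(ds))
```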
+ +If everything went well, you will have the following file structure: + +``` +data/SeasoNet +└── data +    ├── fall +    │   ├── grid1 +    │   └── grid2 +    ├── meta.csv +    ├── snow +    │   ├── grid1 +    │   └── grid2 +    ├── splits +    │   ├── test.csv +    │   ├── train.csv +    │   └── val.csv +    ├── spring +    │   ├── grid1 +    │   └── grid2 +    ├── summer +    │   ├── grid1 +    │   └── grid2 +    └── winter +    ├── grid1 +    └── grid2 +``` + +Note about SeasoNet: the dataset can also be downloaded automatically through the +[TorchGeo](https://torchgeo.readthedocs.io/en/stable/api/datasets.html#seasonet) +SeasoNet class, specifying `./data/SeasoNet/data/` as the root directory (see the sketch above). + +### Download pretrained + +Open the +[Release v1.0](https://github.com/IMPLabUniPr/swin2-mose/releases/tag/v1.0) +page and download the .pt (pretrained) and .pkl (results) files. + +Unzip them inside the output directory, obtaining the following directory +structure: + +``` +output2/sen2venus_exp4_2x_v5/ +├── checkpoints +│   └── model-70.pt +└── eval + └── results-70.pt +``` + +### Train + +```bash +python src/main.py --phase train --config $CONFIG_FILE --output $OUT_DIR --epochs ${EPOCH} --epoch -1 +python src/main_ssegm.py --phase train --config $CONFIG_FILE --output $OUT_DIR --epochs ${EPOCH} --epoch -1 +``` + +### Validate + +```bash +python src/main.py --phase test --config $CONFIG_FILE --output $OUT_DIR --batch_size 32 --epoch ${EPOCH} +python src/main.py --phase test --config $CONFIG_FILE --batch_size 32 --eval_method bicubic +``` + +### Show results + +``` +python src/main.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 3 --epoch ${EPOCH} +python src/main.py --phase vis --config $CONFIG_FILE --output output/sen2venus_4x_bicubic --num_images 3 --eval_method bicubic +python src/main.py --phase vis --config $CONFIG_FILE --output output/sen2venus_4x_bicubic --num_images 3 --eval_method bicubic --dpi 1200 +python src/main_ssegm.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 2 --epoch ${EPOCH} +python src/main_ssegm.py --phase vis --config $CONFIG_FILE --output $OUT_DIR --num_images 2 --epoch ${EPOCH} --hide_sr +``` + +### Compute mean/std + +``` +python src/main.py --phase mean_std --config $CONFIG_FILE +python src/main_ssegm.py --phase mean_std --config $CONFIG_FILE +``` + +### Measure average execution time + +``` +python src/main.py --phase avg_time --config $CONFIG_FILE --repeat_times 1000 --warm_times 20 --batch_size 8 +``` + +## Results + +### Table 1 + +Ablation study on loss usage. + +| # | Losses ||| Performance ||| Conf | +| # | NCC | SSIM | MSE | NCC | SSIM | PSNR | | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| 1 | x | | | 0.9550 | 0.5804 | 16.4503 | [conf](cfgs/sen2venus_exp1_v1.yml) | +| 2 | | x | | 0.9565 | 0.9847 | 45.5427 | [conf](cfgs/sen2venus_exp1_v2.yml) | +| 3 | | | x | 0.9546 | 0.9828 | 45.4759 | [conf](cfgs/sen2venus_exp1_v3.yml) | +| 4 | x | x | | 0.9572 | 0.9841 | 45.6986 | [conf](cfgs/sen2venus_exp1_v4.yml) | +| 5 | x | | x | 0.9549 | 0.9828 | 45.5163 | [conf](cfgs/sen2venus_exp1_v5.yml) | +| 6 | x | x | x | 0.9555 | 0.9833 | 45.5542 | [conf](cfgs/sen2venus_exp1_v6.yml) | + +### Table 2 + +Ablation study on positional encoding.
+ +| # | Positional Encoding ||| Performance || Conf | +| # | RPE | log CPB | LePE | SSIM | PSNR | | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| 1 | x | | | 0.9841 | 45.5855 | [conf](cfgs/sen2venus_exp2_v1.yml) | +| 2 | | x | | 0.9841 | 45.6986 | [conf](cfgs/sen2venus_exp2_v2.yml) | +| 3 | | | x | 0.9843 | 45.7278 | [conf](cfgs/sen2venus_exp2_v3.yml) | +| 4 | | x | x | 0.9845 | 45.8046 | [conf](cfgs/sen2venus_exp2_v4.yml) | +| 5 | x | | x | 0.9847 | 45.8539 | [conf](cfgs/sen2venus_exp2_v5.yml) | +| 6 | x | x | | 0.9843 | 45.6945 | [conf](cfgs/sen2venus_exp2_v6.yml) | +| 7 | x | x | x | 0.9846 | 45.8185 | [conf](cfgs/sen2venus_exp2_v7.yml) | + +### Table 3 + +Ablation study on the MoE-SM architecture. + +| # | | | MLP #Params || Performance || Latency | Conf | +| # | Arch | SM | APC | SPC | SSIM | PSNR | (s) | | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| 1 | MLP | | 32’670 | 32’670 | 0.9847 | 45.8539 | 0.194 | [conf](cfgs/sen2venus_exp3_v1.yml) | +| 2 | MoE 8/2 | | 32’760 | 132’480 | 0.9845 | 45.8647 | 0.202 | [conf](cfgs/sen2venus_exp3_v2.yml) | +| 3 | MoE 8/2 | x | 32’779 | 132’499 | 0.9849 | 45.9272 | 0.212 | [conf](cfgs/sen2venus_exp3_v3.yml) | + +### Table 4 + +Quantitative comparison with SOTA models on the Sen2Venus and OLI2MSI datasets. + +| # | Model | Sen2Venus 2x || OLI2MSI 3x || Sen2Venus 4x || Conf ||| +| # | | SSIM | PSNR | SSIM | PSNR | SSIM | PSNR | 2x | 3x | 4x | +|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:| +| 1 | Bicubic | 0.9883 | 45.5588 | 0.9768 | 42.1835 | 0.9674 | 42.0499 | | | | +| 2 | [SwinIR](https://openaccess.thecvf.com/content/ICCV2021W/AIM/html/Liang_SwinIR_Image_Restoration_Using_Swin_Transformer_ICCVW_2021_paper.html) | 0.9938 | 48.7064 | 0.9860 | 43.7482 | 0.9825 | 45.3460 | [conf](cfgs/sen2venus_exp4_2x_v2.yml) | [conf](cfgs/oli2msi_exp4_3x_v2.yml) | [conf](cfgs/sen2venus_exp4_4x_v2.yml) | +| 3 | [SwinFIR](https://arxiv.org/abs/2208.11247) | 0.9940 | 48.8532 | 0.9863 | 44.4829 | 0.9830 | 45.5500 | [conf](cfgs/sen2venus_exp4_2x_v3.yml) | [conf](cfgs/oli2msi_exp4_3x_v3.yml) | [conf](cfgs/sen2venus_exp4_4x_v3.yml) | +| 4 | [Swin2SR](https://link.springer.com/chapter/10.1007/978-3-031-25063-7_42) | 0.9942 | 49.0467 | 0.9881 | 44.9614 | 0.9828 | 45.4759 | [conf](cfgs/sen2venus_exp4_2x_v4.yml) | [conf](cfgs/oli2msi_exp4_3x_v4.yml) | [conf](cfgs/sen2venus_exp4_4x_v4.yml) | +| 5 | Swin2-MoSE (ours) | 0.9948 | 49.4784 | 0.9912 | 45.9194 | 0.9849 | 45.9272 | [conf](cfgs/sen2venus_exp4_2x_v5.yml) | [conf](cfgs/oli2msi_exp4_3x_v5.yml) | [conf](cfgs/sen2venus_exp4_4x_v5.yml) | + + +### Figure 11 + +Results for the Semantic Segmentation task on the SeasoNet dataset. + +| # | Model | Conf | +|:---:|:---:|:---:| +| 1 | FarSeg | [conf](cfgs/seasonet_exp5_v1.yml) | +| 2 | FarSeg++ | [conf](cfgs/seasonet_exp5_v2.yml) | +| 3 | FarSeg+S2MFE | [conf](cfgs/seasonet_exp5_v3.yml) | + +## License + +See the [GPL v2](./LICENSE) License. + +## Acknowledgement + +Project ECS\_00000033\_ECOSISTER funded under the National Recovery and Resilience Plan (NRRP), Mission 4 +Component 2 Investment 1.5 - funded by the European Union – NextGenerationEU. +This research benefits from the HPC (High Performance Computing) facility of the University of Parma, Italy.
+ +## Citation +If you find our work useful in your research, please cite: + +``` +@article{rossi2024swin2, + title={Swin2-MoSE: A New Single Image Super-Resolution Model for Remote Sensing}, + author={Rossi, Leonardo and Bernuzzi, Vittorio and Fontanini, Tomaso and Bertozzi, Massimo and Prati, Andrea}, + journal={arXiv preprint arXiv:2404.18924}, + year={2024} +} +``` diff --git a/cfgs/oli2msi_exp4_3x_v2.yml b/cfgs/oli2msi_exp4_3x_v2.yml new file mode 100644 index 0000000..442dc4c --- /dev/null +++ b/cfgs/oli2msi_exp4_3x_v2.yml @@ -0,0 +1,18 @@ +__base__: sen2venus_exp4_4x_v2.yml +dataset: + root_path: data/oli2msi + collate_fn: mods.v6.collate_fn + denorm: mods.v6.uncollate_fn + printable: mods.v6.printable + load_dataset: datasets.oli2msi.load_dataset + hr_name: null + lr_name: null +super_res: { + model: { + upscale: 3, + in_chans: 3, + } +} +metrics: { + upscale_factor: 3, +} diff --git a/cfgs/oli2msi_exp4_3x_v3.yml b/cfgs/oli2msi_exp4_3x_v3.yml new file mode 100644 index 0000000..21acc6b --- /dev/null +++ b/cfgs/oli2msi_exp4_3x_v3.yml @@ -0,0 +1,18 @@ +__base__: sen2venus_exp4_4x_v3.yml +dataset: + root_path: data/oli2msi + collate_fn: mods.v6.collate_fn + denorm: mods.v6.uncollate_fn + printable: mods.v6.printable + load_dataset: datasets.oli2msi.load_dataset + hr_name: null + lr_name: null +super_res: { + model: { + upscale: 3, + in_chans: 3, + } +} +metrics: { + upscale_factor: 3, +} diff --git a/cfgs/oli2msi_exp4_3x_v4.yml b/cfgs/oli2msi_exp4_3x_v4.yml new file mode 100644 index 0000000..7ed76f2 --- /dev/null +++ b/cfgs/oli2msi_exp4_3x_v4.yml @@ -0,0 +1,18 @@ +__base__: sen2venus_exp4_4x_v4.yml +dataset: + root_path: data/oli2msi + collate_fn: mods.v6.collate_fn + denorm: mods.v6.uncollate_fn + printable: mods.v6.printable + load_dataset: datasets.oli2msi.load_dataset + hr_name: null + lr_name: null +super_res: { + model: { + upscale: 3, + in_chans: 3, + } +} +metrics: { + upscale_factor: 3, +} diff --git a/cfgs/oli2msi_exp4_3x_v5.yml b/cfgs/oli2msi_exp4_3x_v5.yml new file mode 100644 index 0000000..0469430 --- /dev/null +++ b/cfgs/oli2msi_exp4_3x_v5.yml @@ -0,0 +1,18 @@ +__base__: sen2venus_exp4_v5.yml +dataset: + root_path: data/oli2msi + collate_fn: mods.v6.collate_fn + denorm: mods.v6.uncollate_fn + printable: mods.v6.printable + load_dataset: datasets.oli2msi.load_dataset + hr_name: null + lr_name: null +super_res: { + model: { + upscale: 3, + in_chans: 3, + } +} +metrics: { + upscale_factor: 3, +} diff --git a/cfgs/seasonet_exp5_v1.yml b/cfgs/seasonet_exp5_v1.yml new file mode 100644 index 0000000..dc2da0e --- /dev/null +++ b/cfgs/seasonet_exp5_v1.yml @@ -0,0 +1,48 @@ +batch_size: 128 +losses: { + with_ce_criterion: true, + weights: { + ce: 1.0, + } +} +dataset: + root_path: data/SeasoNet/data + cls: datasets.seasonet.SeasoNetDataset + kwargs: { + seasons: ['Spring'], + bands: ['10m_RGB', '10m_IR',], + grids:[1], + } + collate_fn: mods.v5.collate_fn + stats: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + collate_kwargs: {} + denorm: mods.v5.uncollate_fn + printable: mods.v5.printable +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + model_eps: 0.00000001, + model_weight_decay: 0 +} +semantic_segm: { + pad_before: [4, 4, 4, 4], + in_channels: 4, + type: FarSeg, + model: { + backbone: resnet50, + } +} +train: semantic_segm.training.train +mean_std: 
mods.v5.get_mean_std +metrics: { + eval_every: 1 +} +visualize: { + model: semantic_segm.model.build_model, + checkpoint: chk_loader.load_state_dict_model_only, +} diff --git a/cfgs/seasonet_exp5_v2.yml b/cfgs/seasonet_exp5_v2.yml new file mode 100644 index 0000000..a1258f5 --- /dev/null +++ b/cfgs/seasonet_exp5_v2.yml @@ -0,0 +1,10 @@ +__base__: seasonet_exp5_v1.yml +semantic_segm: { + conv_up: { + in_ch: 4, + middle_ch: 90, + out_ch: 64, + kernel_size: 1, + padding: 0, + } +} diff --git a/cfgs/seasonet_exp5_v3.yml b/cfgs/seasonet_exp5_v3.yml new file mode 100644 index 0000000..607d9b9 --- /dev/null +++ b/cfgs/seasonet_exp5_v3.yml @@ -0,0 +1,10 @@ +__base__: seasonet_exp5_v1.yml +semantic_segm: { + in_channels: 90, + pad_before: [0, 0, 0, 0], + pad_after: [4, 4, 4, 4], + upscaler: { + config: cfgs/sen2venus_exp4_2x_v5.yml, + chk: output/sen2venus_exp4_2x_v5/checkpoints/model-70.pt, + } +} diff --git a/cfgs/sen2venus_4x.yml b/cfgs/sen2venus_4x.yml new file mode 100644 index 0000000..710837c --- /dev/null +++ b/cfgs/sen2venus_4x.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_base.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b5b6b7b8a: { + mean: [1182.8475341796875, 2155.208251953125, 2507.487060546875, 2800.94140625], + std: [594.6590576171875, 643.8070068359375, 777.8865356445312, 829.2948608398438], + min: [-8687.0, -3340.0, -2245.0, -5048.0], + max: [20197.0, 16498.0, 16674.0, 21622.0] + } + tensor_20m_b5b6b7b8a: { + mean: [1180.6920166015625, 2149.5302734375, 2500.98779296875, 2794.01220703125], + std: [592.2827758789062, 639.0105590820312, 769.7623291015625, 819.83349609375], + min: [-446.0, -295.0, -340.0, -551.0], + max: [13375.0, 15898.0, 15551.0, 15079.0] + } + hr_name: tensor_05m_b5b6b7b8a + lr_name: tensor_20m_b5b6b7b8a + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 4, + } +} +visualize: { + input_shape: [4, 64, 64], +} +metrics: { + upscale_factor: 4, +} diff --git a/cfgs/sen2venus_base.yml b/cfgs/sen2venus_base.yml new file mode 100644 index 0000000..7c643bc --- /dev/null +++ b/cfgs/sen2venus_base.yml @@ -0,0 +1,38 @@ +batch_size: 8 +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + model_eps: 0.00000001, + model_weight_decay: 0 +} +super_res: { + version: 'v2', + model: { + depths: [6, 6, 6, 6], + embed_dim: 90, + img_range: 1., + img_size: 64, + in_chans: 4, + mlp_ratio: 2, + num_heads: [6, 6, 6, 6], + resi_connection: 1conv, + upsampler: pixelshuffledirect, + window_size: 16, + } +} +train: super_res.training.train +mean_std: mods.v3.get_mean_std +visualize: { + checkpoint: chk_loader.load_state_dict_model_only, + model: super_res.model.build_model, +} +metrics: { + only_test_y_channel: false, +} diff --git a/cfgs/sen2venus_exp1_v1.yml b/cfgs/sen2venus_exp1_v1.yml new file mode 100644 index 0000000..e1bbca3 --- /dev/null +++ b/cfgs/sen2venus_exp1_v1.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + with_ssim_criterion: false, + with_cc_criterion: true, + weights: { + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v2.yml b/cfgs/sen2venus_exp1_v2.yml new file mode 100644 index 0000000..3bc23e4 --- /dev/null +++ b/cfgs/sen2venus_exp1_v2.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + 
with_ssim_criterion: true, + with_cc_criterion: false, + weights: { + ssim: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v3.yml b/cfgs/sen2venus_exp1_v3.yml new file mode 100644 index 0000000..16407e2 --- /dev/null +++ b/cfgs/sen2venus_exp1_v3.yml @@ -0,0 +1,9 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: false, + with_cc_criterion: false, + weights: { + pixel: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v4.yml b/cfgs/sen2venus_exp1_v4.yml new file mode 100644 index 0000000..659e8eb --- /dev/null +++ b/cfgs/sen2venus_exp1_v4.yml @@ -0,0 +1,10 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: false, + with_ssim_criterion: true, + with_cc_criterion: true, + weights: { + ssim: 1.0, + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v5.yml b/cfgs/sen2venus_exp1_v5.yml new file mode 100644 index 0000000..2cdae3a --- /dev/null +++ b/cfgs/sen2venus_exp1_v5.yml @@ -0,0 +1,10 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: false, + with_cc_criterion: true, + weights: { + pixel: 1.0, + cc: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp1_v6.yml b/cfgs/sen2venus_exp1_v6.yml new file mode 100644 index 0000000..c5226f2 --- /dev/null +++ b/cfgs/sen2venus_exp1_v6.yml @@ -0,0 +1,11 @@ +__base__: sen2venus_4x.yml +losses: { + with_pixel_criterion: true, + with_ssim_criterion: true, + with_cc_criterion: true, + weights: { + cc: 1.0, + ssim: 1.0, + pixel: 1.0, + }, +} diff --git a/cfgs/sen2venus_exp2_v1.yml b/cfgs/sen2venus_exp2_v1.yml new file mode 100644 index 0000000..b05a4cc --- /dev/null +++ b/cfgs/sen2venus_exp2_v1.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: false, + use_cpb_bias: false, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v2.yml b/cfgs/sen2venus_exp2_v2.yml new file mode 100644 index 0000000..6bbe5c5 --- /dev/null +++ b/cfgs/sen2venus_exp2_v2.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp1_v4.yml diff --git a/cfgs/sen2venus_exp2_v3.yml b/cfgs/sen2venus_exp2_v3.yml new file mode 100644 index 0000000..850787d --- /dev/null +++ b/cfgs/sen2venus_exp2_v3.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: false, + use_rpe_bias: false, + } +} diff --git a/cfgs/sen2venus_exp2_v4.yml b/cfgs/sen2venus_exp2_v4.yml new file mode 100644 index 0000000..4769276 --- /dev/null +++ b/cfgs/sen2venus_exp2_v4.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: true, + use_rpe_bias: false, + } +} diff --git a/cfgs/sen2venus_exp2_v5.yml b/cfgs/sen2venus_exp2_v5.yml new file mode 100644 index 0000000..67eea1e --- /dev/null +++ b/cfgs/sen2venus_exp2_v5.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: false, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v6.yml b/cfgs/sen2venus_exp2_v6.yml new file mode 100644 index 0000000..28a4bf1 --- /dev/null +++ b/cfgs/sen2venus_exp2_v6.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: false, + use_cpb_bias: true, + use_rpe_bias: true, + } +} diff --git a/cfgs/sen2venus_exp2_v7.yml b/cfgs/sen2venus_exp2_v7.yml new file mode 100644 index 0000000..645a502 --- /dev/null +++ b/cfgs/sen2venus_exp2_v7.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp2_v2.yml +super_res: { + model: { + use_lepe: true, + use_cpb_bias: true, + use_rpe_bias: true, + } +} diff --git 
a/cfgs/sen2venus_exp3_v1.yml b/cfgs/sen2venus_exp3_v1.yml new file mode 100644 index 0000000..0d2d960 --- /dev/null +++ b/cfgs/sen2venus_exp3_v1.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp2_v5.yml diff --git a/cfgs/sen2venus_exp3_v1_wrong.yml b/cfgs/sen2venus_exp3_v1_wrong.yml new file mode 100644 index 0000000..319d71d --- /dev/null +++ b/cfgs/sen2venus_exp3_v1_wrong.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp2_v4.yml diff --git a/cfgs/sen2venus_exp3_v2.yml b/cfgs/sen2venus_exp3_v2.yml new file mode 100644 index 0000000..136c927 --- /dev/null +++ b/cfgs/sen2venus_exp3_v2.yml @@ -0,0 +1,8 @@ +__base__: sen2venus_exp3_v3.yml +super_res: { + model: { + MoE_config: { + with_smart_merger: null + } + } +} diff --git a/cfgs/sen2venus_exp3_v3.yml b/cfgs/sen2venus_exp3_v3.yml new file mode 100644 index 0000000..04d6404 --- /dev/null +++ b/cfgs/sen2venus_exp3_v3.yml @@ -0,0 +1,17 @@ +__base__: sen2venus_exp2_v5.yml +losses: { + weights: { + moe: 0.2 + } +} +super_res: { + model: { + mlp_ratio: 1, + MoE_config: { + k: 2, + num_experts: 8, + with_noise: false, + with_smart_merger: v1, + } + } +} diff --git a/cfgs/sen2venus_exp4_2x_v2.yml b/cfgs/sen2venus_exp4_2x_v2.yml new file mode 100644 index 0000000..c2e1334 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v2.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v2.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v3.yml b/cfgs/sen2venus_exp4_2x_v3.yml new file mode 100644 index 0000000..19d3ef9 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v3.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v3.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v4.yml b/cfgs/sen2venus_exp4_2x_v4.yml new file mode 100644 index 0000000..5f4ae23 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v4.yml 
@@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_4x_v4.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_2x_v5.yml b/cfgs/sen2venus_exp4_2x_v5.yml new file mode 100644 index 0000000..a60b242 --- /dev/null +++ b/cfgs/sen2venus_exp4_2x_v5.yml @@ -0,0 +1,33 @@ +__base__: sen2venus_exp4_v5.yml +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable +super_res: { + model: { + upscale: 2, + } +} +metrics: { + upscale_factor: 2, +} +visualize: { + input_shape: [4, 128, 128], +} diff --git a/cfgs/sen2venus_exp4_4x_v2.yml b/cfgs/sen2venus_exp4_4x_v2.yml new file mode 100644 index 0000000..f98d8bc --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v2.yml @@ -0,0 +1,4 @@ +__base__: sen2venus_exp4_4x_v4.yml +super_res: { + version: 'v1', +} diff --git a/cfgs/sen2venus_exp4_4x_v3.yml b/cfgs/sen2venus_exp4_4x_v3.yml new file mode 100644 index 0000000..1a1ac91 --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v3.yml @@ -0,0 +1,7 @@ +__base__: sen2venus_exp4_4x_v2.yml +super_res: { + version: 'swinfir', + model: { + resi_connection: SFB, + } +} diff --git a/cfgs/sen2venus_exp4_4x_v4.yml b/cfgs/sen2venus_exp4_4x_v4.yml new file mode 100644 index 0000000..fda8c51 --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v4.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp1_v3.yml diff --git a/cfgs/sen2venus_exp4_4x_v5.yml b/cfgs/sen2venus_exp4_4x_v5.yml new file mode 100644 index 0000000..3369f8f --- /dev/null +++ b/cfgs/sen2venus_exp4_4x_v5.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp3_v3.yml diff --git a/cfgs/sen2venus_exp4_v5.yml b/cfgs/sen2venus_exp4_v5.yml new file mode 100644 index 0000000..3369f8f --- /dev/null +++ b/cfgs/sen2venus_exp4_v5.yml @@ -0,0 +1 @@ +__base__: sen2venus_exp3_v3.yml diff --git a/cfgs/sen2venus_super_res_x2_all.yml b/cfgs/sen2venus_super_res_x2_all.yml new file mode 100644 index 0000000..21e7040 --- /dev/null +++ b/cfgs/sen2venus_super_res_x2_all.yml @@ -0,0 +1,60 @@ +batch_size: 16 +losses: { + 
with_pixel_criterion: true, + weights: { + pixel: 1.0, + } +} +dataset: + root_path: data/sen2venus + stats: + use_minmax: true + tensor_05m_b2b3b4b8: { + mean: [444.21923828125, 715.9031372070312, 813.4345703125, 2604.867919921875], + std: [279.85552978515625, 385.3569641113281, 648.458984375, 796.9918212890625], + min: [-1025.0, -3112.0, -5122.0, -3851.0], + max: [14748.0, 14960.0, 16472.0, 16109.0] + } + tensor_10m_b2b3b4b8: { + mean: [443.78643798828125, 715.4202270507812, 813.0512084960938, 2602.813232421875], + std: [283.89276123046875, 389.26361083984375, 651.094970703125, 811.5682373046875], + min: [-848.0, -902.0, -946.0, -323.0], + max: [19684.0, 17982.0, 17064.0, 15958.0] + } + hr_name: tensor_05m_b2b3b4b8 + lr_name: tensor_10m_b2b3b4b8 + collate_fn: mods.v3.collate_fn + denorm: mods.v3.uncollate_fn + printable: mods.v3.printable + places: [] +optim: { + learning_rate: 0.0001, + model_betas: [0.9, 0.999], + model_eps: 0.00000001, + model_weight_decay: 0 +} +super_res: { + model: { + upscale: 2, + in_chans: 4, + img_size: 64, + window_size: 16, + img_range: 1., + depths: [6, 6, 6, 6], + embed_dim: 90, + num_heads: [6, 6, 6, 6], + mlp_ratio: 2, + upsampler: pixelshuffledirect, + resi_connection: 1conv + } +} +train: super_res.training.train +mean_std: mods.v3.get_mean_std +visualize: { + model: super_res.model.build_model, + checkpoint: chk_loader.load_state_dict_model_only, +} +metrics: { + only_test_y_channel: false, + upscale_factor: 2, +} diff --git a/conda-environment.yaml b/conda-environment.yaml new file mode 100644 index 0000000..a2dc43a --- /dev/null +++ b/conda-environment.yaml @@ -0,0 +1,172 @@ +name: l1bsr3 +channels: + - pytorch + - nvidia + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - boltons=23.0.0=py310h06a4308_0 + - brotlipy=0.7.0=py310h7f8727e_1002 + - bzip2=1.0.8=h7b6447c_0 + - c-ares=1.19.1=h5eee18b_0 + - ca-certificates=2023.08.22=h06a4308_0 + - certifi=2023.7.22=py310h06a4308_0 + - cffi=1.15.1=py310h5eee18b_3 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - conda=23.9.0=py310h06a4308_0 + - conda-libmamba-solver=23.9.1=py310h06a4308_0 + - conda-package-handling=2.2.0=py310h06a4308_0 + - conda-package-streaming=0.9.0=py310h06a4308_0 + - cryptography=41.0.3=py310hdda0065_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.2.140=0 + - cuda-runtime=12.1.0=0 + - ffmpeg=4.3=hf484d3e_0 + - filelock=3.9.0=py310h06a4308_0 + - fmt=9.1.0=hdb19cb5_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gmpy2=2.1.2=py310heeb90bb_0 + - gnutls=3.6.15=he1e5248_0 + - icu=73.1=h6a678d5_0 + - idna=3.4=py310h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46305 + - jinja2=3.1.2=py310h06a4308_0 + - jpeg=9e=h5eee18b_1 + - jsonpatch=1.32=pyhd3eb1b0_0 + - jsonpointer=2.1=pyhd3eb1b0_0 + - krb5=1.20.1=h143b758_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libarchive=3.6.2=h6ac8c49_2 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.7.2.10=0 + - libcurand=10.3.3.141=0 + - libcurl=8.2.1=h251f7ec_0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20221030=h5eee18b_0 + - libev=4.33=h7f8727e_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - 
libjpeg-turbo=2.0.0=h9bf148f_0 + - libmamba=1.5.1=haf1ee3a_0 + - libmambapy=1.5.1=py310h2dafd23_0 + - libnghttp2=1.52.0=h2d74bed_1 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libpng=1.6.39=h5eee18b_0 + - libsolv=0.7.24=he621ea3_0 + - libssh2=1.10.0=hdbd6064_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxml2=2.10.4=hf1b16e4_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_0 + - markupsafe=2.1.1=py310h7f8727e_0 + - mkl=2023.1.0=h213fc3f_46343 + - mkl-service=2.4.0=py310h5eee18b_1 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - mpmath=1.3.0=py310h06a4308_0 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - networkx=3.1=py310h06a4308_0 + - numpy=1.26.0=py310h5f9d8c6_0 + - numpy-base=1.26.0=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.11=h7f8727e_2 + - packaging=23.1=py310h06a4308_0 + - pcre2=10.42=hebb0a14_0 + - pillow=10.0.1=py310ha6cbd5a_0 + - pip=23.2.1=py310h06a4308_0 + - pluggy=1.0.0=py310h06a4308_1 + - pybind11-abi=4=hd3eb1b0_1 + - pycosat=0.6.6=py310h5eee18b_0 + - pycparser=2.21=pyhd3eb1b0_0 + - pyopenssl=23.2.0=py310h06a4308_0 + - pysocks=1.7.1=py310h06a4308_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - pyyaml=6.0=py310h5eee18b_1 + - readline=8.2=h5eee18b_0 + - reproc=14.2.4=h295c915_1 + - reproc-cpp=14.2.4=h295c915_1 + - requests=2.31.0=py310h06a4308_0 + - ruamel.yaml=0.17.21=py310h5eee18b_0 + - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 + - setuptools=68.0.0=py310h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - sympy=1.11.1=py310h06a4308_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchtriton=2.1.0=py310 + - torchvision=0.16.0=py310_cu121 + - tqdm=4.65.0=py310h2f386ee_0 + - truststore=0.8.0=py310h06a4308_0 + - typing_extensions=4.7.1=py310h06a4308_0 + - tzdata=2023c=h04d1e81_0 + - urllib3=1.26.16=py310h06a4308_0 + - wheel=0.41.2=py310h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - yaml-cpp=0.7.0=h295c915_1 + - zlib=1.2.13=h5eee18b_0 + - zstandard=0.19.0=py310h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.0.0 + - cachetools==5.3.1 + - cycler==0.12.1 + - easydict==1.10 + - einops==0.7.0 + - fonttools==4.43.1 + - fsspec==2023.9.2 + - google-auth==2.23.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.59.0 + - kiwisolver==1.4.5 + - markdown==3.4.4 + - matplotlib==3.5.1 + - oauthlib==3.2.2 + - opencv-python==4.8.1.78 + - protobuf==4.24.4 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pyparsing==3.1.1 + - python-dateutil==2.8.2 + - requests-oauthlib==1.3.1 + - rsa==4.9 + - scipy==1.11.3 + - six==1.16.0 + - tensorboard==2.14.1 + - tensorboard-data-server==0.7.1 + - tifffile==2023.9.26 + - timm==0.6.7 + - werkzeug==3.0.0 +prefix: /hpc/home/leonardo.rossi/.conda/envs/l1bsr3 diff --git a/data/.gitignore b/data/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..bbabfe5 --- /dev/null +++ b/environment.yml @@ -0,0 +1,258 @@ +name: swin2_mose_env +channels: + - pytorch + - nvidia + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - bzip2=1.0.8=h7b6447c_0 
+ - c-ares=1.19.1=h5eee18b_0 + - ca-certificates=2023.08.22=h06a4308_0 + - charset-normalizer=2.0.4=pyhd3eb1b0_0 + - cuda-cudart=12.1.105=0 + - cuda-cupti=12.1.105=0 + - cuda-libraries=12.1.0=0 + - cuda-nvrtc=12.1.105=0 + - cuda-nvtx=12.1.105=0 + - cuda-opencl=12.2.140=0 + - cuda-runtime=12.1.0=0 + - ffmpeg=4.3=hf484d3e_0 + - fmt=9.1.0=hdb19cb5_0 + - freetype=2.12.1=h4a9f257_0 + - giflib=5.2.1=h5eee18b_3 + - gmp=6.2.1=h295c915_3 + - gnutls=3.6.15=he1e5248_0 + - icu=73.1=h6a678d5_0 + - intel-openmp=2023.1.0=hdb19cb5_46305 + - jpeg=9e=h5eee18b_1 + - jsonpatch=1.32=pyhd3eb1b0_0 + - jsonpointer=2.1=pyhd3eb1b0_0 + - krb5=1.20.1=h143b758_1 + - lame=3.100=h7b6447c_0 + - lcms2=2.12=h3be6417_0 + - ld_impl_linux-64=2.38=h1181459_1 + - lerc=3.0=h295c915_0 + - libarchive=3.6.2=h6ac8c49_2 + - libcublas=12.1.0.26=0 + - libcufft=11.0.2.4=0 + - libcufile=1.7.2.10=0 + - libcurand=10.3.3.141=0 + - libcurl=8.2.1=h251f7ec_0 + - libcusolver=11.4.4.55=0 + - libcusparse=12.0.2.55=0 + - libdeflate=1.17=h5eee18b_1 + - libedit=3.1.20221030=h5eee18b_0 + - libev=4.33=h7f8727e_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libiconv=1.16=h7f8727e_2 + - libidn2=2.3.4=h5eee18b_0 + - libjpeg-turbo=2.0.0=h9bf148f_0 + - libmamba=1.5.1=haf1ee3a_0 + - libnghttp2=1.52.0=h2d74bed_1 + - libnpp=12.0.2.50=0 + - libnvjitlink=12.1.105=0 + - libnvjpeg=12.1.1.14=0 + - libpng=1.6.39=h5eee18b_0 + - libsolv=0.7.24=he621ea3_0 + - libssh2=1.10.0=hdbd6064_2 + - libstdcxx-ng=11.2.0=h1234567_1 + - libtasn1=4.19.0=h5eee18b_0 + - libtiff=4.5.1=h6a678d5_0 + - libunistring=0.9.10=h27cfd23_0 + - libuuid=1.41.5=h5eee18b_0 + - libwebp=1.3.2=h11a3e52_0 + - libwebp-base=1.3.2=h5eee18b_0 + - libxml2=2.10.4=hf1b16e4_1 + - llvm-openmp=14.0.6=h9e868ea_0 + - lz4-c=1.9.4=h6a678d5_0 + - mkl=2023.1.0=h213fc3f_46343 + - mkl_fft=1.3.8=py310h5eee18b_0 + - mkl_random=1.2.4=py310hdb19cb5_0 + - mpc=1.1.0=h10f8cd9_1 + - mpfr=4.0.2=hb69a4c5_1 + - ncurses=6.4=h6a678d5_0 + - nettle=3.7.3=hbbd107a_1 + - numpy-base=1.26.0=py310hb5e798b_0 + - openh264=2.1.1=h4ff587b_0 + - openjpeg=2.4.0=h3ad879b_0 + - openssl=3.0.11=h7f8727e_2 + - pcre2=10.42=hebb0a14_0 + - pybind11-abi=4=hd3eb1b0_1 + - pycparser=2.21=pyhd3eb1b0_0 + - python=3.10.13=h955ad1f_0 + - pytorch=2.1.0=py3.10_cuda12.1_cudnn8.9.2_0 + - pytorch-cuda=12.1=ha16c6d3_5 + - pytorch-mutex=1.0=cuda + - readline=8.2=h5eee18b_0 + - reproc=14.2.4=h295c915_1 + - reproc-cpp=14.2.4=h295c915_1 + - ruamel.yaml=0.17.21=py310h5eee18b_0 + - ruamel.yaml.clib=0.2.6=py310h5eee18b_1 + - sqlite=3.41.2=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - torchtriton=2.1.0=py310 + - typing_extensions=4.7.1=py310h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - yaml=0.2.5=h7b6447c_0 + - yaml-cpp=0.7.0=h295c915_1 + - zlib=1.2.13=h5eee18b_0 + - zstd=1.5.5=hc292b87_0 + - pip: + - absl-py==2.0.0 + - aenum==3.1.15 + - affine==2.4.0 + - aiohttp==3.9.0 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.9.3 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.1.0 + - bitsandbytes==0.41.2.post2 + - boltons==23.0.0 + - brotlipy==0.7.0 + - cachetools==5.3.1 + - certifi==2023.7.22 + - cffi==1.15.1 + - click==8.1.7 + - click-plugins==1.1.1 + - cligj==0.7.2 + - cmake==3.25.0 + - conda==23.9.0 + - conda-libmamba-solver==23.9.1 + - conda-package-handling==2.2.0 + - conda-package-streaming==0.9.0 + - cryptography==41.0.3 + - cycler==0.12.1 + - decorator==5.1.1 + - dill==0.3.8 + - docstring-parser==0.15 + - easydict==1.10 + - efficientnet-pytorch==0.7.1 + - einops==0.7.0 + - exceptiongroup==1.1.3 + 
- executing==2.0.1 + - filelock==3.9.0 + - fiona==1.9.5 + - fonttools==4.43.1 + - frozenlist==1.4.0 + - fsspec==2023.9.2 + - gmpy2==2.1.2 + - google-auth==2.23.2 + - google-auth-oauthlib==1.0.0 + - grpcio==1.59.0 + - huggingface-hub==0.19.4 + - hydra-core==1.3.2 + - idna==3.4 + - importlib-resources==6.1.1 + - ipdb==0.13.13 + - ipython==8.17.2 + - jedi==0.19.1 + - jinja2==3.1.2 + - jsonargparse==4.27.0 + - jsonschema==4.20.0 + - jsonschema-specifications==2023.11.1 + - kiwisolver==1.4.5 + - kornia==0.7.0 + - libmambapy==1.5.1 + - lightly==1.4.21 + - lightly-utils==0.0.2 + - lightning==2.1.2 + - lightning-utilities==0.10.0 + - lit==15.0.7 + - markdown==3.4.4 + - markdown-it-py==3.0.0 + - markupsafe==2.1.1 + - matplotlib==3.5.1 + - matplotlib-inline==0.1.6 + - mdurl==0.1.2 + - mkl-fft==1.3.8 + - mkl-random==1.2.4 + - mkl-service==2.4.0 + - mpmath==1.3.0 + - msgpack==1.0.7 + - multidict==6.0.4 + - munch==4.0.0 + - networkx==3.1 + - ninja==1.11.1.1 + - numpy==1.26.0 + - oauthlib==3.2.2 + - omegaconf==2.3.0 + - opencv-python==4.8.1.78 + - packaging==23.1 + - pandas==2.1.3 + - parso==0.8.3 + - pexpect==4.8.0 + - pillow==10.0.1 + - pip==23.2.1 + - piq==0.8.0 + - pluggy==1.0.0 + - pretrainedmodels==0.7.4 + - prompt-toolkit==3.0.41 + - protobuf==4.24.4 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pyarrow==14.0.1 + - pyasn1==0.5.0 + - pyasn1-modules==0.3.0 + - pycosat==0.6.6 + - pydantic==1.10.13 + - pygments==2.17.1 + - pyopenssl==23.2.0 + - pyparsing==3.1.1 + - pyproj==3.6.1 + - pysocks==1.7.1 + - python-dateutil==2.8.2 + - pytorch-lightning==2.1.2 + - pytz==2023.3.post1 + - pyyaml==6.0 + - rasterio==1.3.9 + - ray==2.8.0 + - referencing==0.31.0 + - requests==2.31.0 + - requests-oauthlib==1.3.1 + - rich==13.7.0 + - rpds-py==0.13.1 + - rsa==4.9 + - rtree==1.1.0 + - ruamel-yaml==0.17.21 + - ruamel-yaml-clib==0.2.6 + - safetensors==0.4.0 + - scipy==1.11.3 + - segmentation-models-pytorch==0.3.3 + - setuptools==68.0.0 + - shapely==2.0.2 + - six==1.16.0 + - snuggs==1.4.7 + - stack-data==0.6.3 + - sympy==1.11.1 + - tensorboard==2.14.1 + - tensorboard-data-server==0.7.1 + - tensorboardx==2.6.2.2 + - tifffile==2023.9.26 + - timm==0.9.2 + - tomli==2.0.1 + - torch==2.0.0+cu118 + - torchgeo==0.5.1 + - torchmetrics==1.2.0 + - torchvision==0.15.0+cu118 + - tqdm==4.65.0 + - traitlets==5.13.0 + - triton==2.0.0 + - truststore==0.8.0 + - typeshed-client==2.4.0 + - typing-extensions==4.7.1 + - tzdata==2023.3 + - urllib3==1.26.16 + - wcwidth==0.2.10 + - werkzeug==3.0.0 + - wheel==0.41.2 + - yarl==1.9.2 + - zstandard==0.19.0 +prefix: /home/hachreak/miniconda2/envs/swin2_mose_env + diff --git a/images/fig1.png b/images/fig1.png new file mode 100644 index 0000000..6b27b49 Binary files /dev/null and b/images/fig1.png differ diff --git a/images/fig2.png b/images/fig2.png new file mode 100644 index 0000000..4097c31 Binary files /dev/null and b/images/fig2.png differ diff --git a/images/fig3.png b/images/fig3.png new file mode 100644 index 0000000..cbd473c Binary files /dev/null and b/images/fig3.png differ diff --git a/output/.gitignore b/output/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/scripts/get_all_metrics.sh b/scripts/get_all_metrics.sh new file mode 100755 index 0000000..2725d47 --- /dev/null +++ b/scripts/get_all_metrics.sh @@ -0,0 +1,53 @@ +#!/bin/bash + +VRS=$1 +EPOCH_INIT=${2:-0} +DATASET=${3:-sen2venus} + +HPC=${HPC:-hpc} + +EPS=`ls -1 output/${HPC}/${DATASET}_v${VRS}/eval| awk -F"-" '{print $2}' | awk -F. 
'{print $1}'|sort -n` +# echo $EPS + +echo "eval ${DATASET} v$VRS" +for EPOCH in $EPS; do + echo "epoch $EPOCH" + if [ $EPOCH -lt $EPOCH_INIT ]; then + echo "skip epoch $EPOCH" + continue + fi + RESULTS=`./scripts/get_metrics.sh $VRS $EPOCH $DATASET` + # echo $RESULTS + + R_EPOCHS="$R_EPOCHS\n$EPOCH" + + R_PSNR="$R_PSNR\n`echo $RESULTS | awk '{print $1}'`" + R_SSIM="$R_SSIM\n`echo $RESULTS | awk '{print $2}'`" + R_CC="$R_CC\n`echo $RESULTS | awk '{print $3}'`" + R_RMSE="$R_RMSE\n`echo $RESULTS | awk '{print $4}'`" + R_SAM="$R_SAM\n`echo $RESULTS | awk '{print $5}'`" + R_ERGAS="$R_ERGAS\n`echo $RESULTS | awk '{print $6}'`" +done + +echo "########" +echo "" +echo "epochs" +echo -e $R_EPOCHS +echo "" +echo "psnr" +echo -e $R_PSNR +echo "" +echo "ssim" +echo -e $R_SSIM +echo "" +echo "cc" +echo -e $R_CC +echo "" +echo "rmse" +echo -e $R_RMSE +echo "" +echo "sam" +echo -e $R_SAM +echo "" +echo "ergas" +echo -e $R_ERGAS diff --git a/scripts/get_metrics.sh b/scripts/get_metrics.sh new file mode 100755 index 0000000..1c758d7 --- /dev/null +++ b/scripts/get_metrics.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +VRS=$1 +EPOCH=$2 +DATASET=${3:-sen2venus} + +HPC=${HPC:-hpc} + +IFS=$'\n ' +###### + +CONFIG_FILE=cfgs/${DATASET}_v${VRS}.yml +OUT_DIR=output/${HPC}/${DATASET}_v${VRS}/ + +RESULT=`python src/main.py --phase test --config $CONFIG_FILE --output $OUT_DIR --epoch ${EPOCH} | grep PSNR -A1 |tail -n1` + +echo $RESULT | awk '{print $1}' +echo $RESULT | awk '{print $2}' +echo $RESULT | awk '{print $3}' +echo $RESULT | awk '{print $4}' +echo $RESULT | awk '{print $5}' +echo $RESULT | awk '{print $6}' diff --git a/scripts/sen2venus/rebuild.py b/scripts/sen2venus/rebuild.py new file mode 100644 index 0000000..95bc027 --- /dev/null +++ b/scripts/sen2venus/rebuild.py @@ -0,0 +1,78 @@ +import pandas as pd +import os +import argparse +import torch + +from functools import lru_cache + + +def parse_configs(): + parser = argparse.ArgumentParser(description='Sen2Ven rebuilder') + + parser.add_argument('--data', type=str, default='./data/sen2venus', + help='Directory where original data are stored.') + parser.add_argument('--output', type=str, default='./data/sen2venus/split', + help='Directory where save the output files' + ' and load generated csv files.') + + args = parser.parse_args() + return args + + +@lru_cache(maxsize=5) +def cached_torch_load(filename): + print('load pt file {}'.format(filename)) + return torch.load(filename) + + +def load_rows(row): + cols = [name for name in row.keys() if name.startswith('tensor_')] + return { + col: cached_torch_load(os.path.join( + args.data, row['place'], row[col] + ))[row['index']].clone() + for col in cols + } + + +def get_filename(index, row): + return '{:06d}_{}_{}.pt'.format(index, row['place'], row['date']) + + +def get_dataset_type(filename): + _, name = os.path.split(filename) + name, _ = os.path.splitext(name) + return name + + +def rebuild(input_filename, args): + dtype = get_dataset_type(input_filename) + out_dir = os.path.join(args.output, dtype) + os.makedirs(out_dir, exist_ok=True) + + print('load {}'.format(input_filename)) + df = pd.read_csv(input_filename) + df = df.sort_values(['place', 'date']) + for index in range(len(df)): + row = df.iloc[index] + tensors = load_rows(row) + # build filename + filename = get_filename(index, row) + # save tensor + fname = os.path.join(out_dir, filename) + print('save {}'.format(fname)) + torch.save(tensors, fname) + + +if __name__ == "__main__": + # parse input arguments + args = parse_configs() + for k, v in 
vars(args).items(): + print('{}: {}'.format(k, v)) + + filename = os.path.join(args.output, 'test.csv') + print('rebuild test dataset..') + rebuild(filename, args) + filename = os.path.join(args.output, 'train.csv') + print('rebuild train dataset..') + rebuild(filename, args) diff --git a/scripts/sen2venus/split.py b/scripts/sen2venus/split.py new file mode 100644 index 0000000..32042dc --- /dev/null +++ b/scripts/sen2venus/split.py @@ -0,0 +1,106 @@ +import os +import pandas as pd +import numpy as np +import argparse + +from functools import partial +from sklearn.model_selection import train_test_split + + +def get_list_files(dir_name): + for dirpath, dirs, files in os.walk(dir_name): + for filename in files: + if filename.endswith('index.csv'): + fname = os.path.join(dirpath, filename) + yield fname + + +def get_info_by_id(dataframes, id_): + return dataframes.iloc[np.where(dataframes['start'].values <= id_)[0][-1]] + + +def read_csv_file(): + read_csv = partial(pd.read_csv, sep='\s+') + + def f(fnames): + for fname in fnames: + df = read_csv(fname) + yield fname, df + return f + + +def split_dataset(dir_name, test_size=.2, seed=42): + fnames = list(get_list_files(dir_name)) + dirnames = [] + dataframes = [] + pairs_counter = 0 + for fname, df in read_csv_file()(fnames): + # df = read_csv(fname) + dirname = os.path.split(os.path.split(fname)[0])[1] + # collect csv rows + dataframes.append(df) + # count how many examples on this csv file + count_examples = df['nb_patches'].sum() + # collect directory where the patch is + dirnames.extend([dirname] * count_examples) + # count total (x, y) pairs + pairs_counter += count_examples + dataframes = pd.concat(dataframes) + y = np.array(dirnames) + # index list to shuffle and split in training / test dataset + ids = np.array(list(range(pairs_counter))) + # nb_patches cumulative column (useful to get file from the index) + start_v = [0] + start_v[1:] = dataframes['nb_patches'].cumsum().iloc[:-1] + dataframes['start'] = start_v + # split them + X_train, X_test, y_train, y_test = train_test_split( + ids, y, test_size=test_size, random_state=seed, stratify=y) + return dataframes, X_train, X_test, y_train, y_test + + +def collect_frames_by_ids(dataframes, ids, y): + rows = [get_info_by_id(dataframes, id_).to_dict() for id_ in ids] + df = pd.DataFrame(rows) + df['index'] = ids - df['start'] + df['place'] = y + return df + + +def parse_configs(): + parser = argparse.ArgumentParser(description='Sen2Ven splitter') + + parser.add_argument('--seed', type=int, default=123, metavar='N', + help='random seed (default: 123)') + parser.add_argument('--test_size', type=float, default=.2, metavar='N', + help='testi size (default 0.2)') + parser.add_argument('--input', type=str, default='./data/sen2venus', + help='Directory where original data are stored.') + parser.add_argument('--output', type=str, default='./data/sen2venus/split', + help='Directory where save the output files.') + + args = parser.parse_args() + return args + + +if __name__ == "__main__": + # parse input arguments + args = parse_configs() + for k, v in vars(args).items(): + print('{}: {}'.format(k, v)) + + dataframes, X_train, X_test, y_train, y_test = split_dataset( + dir_name=args.input, test_size=args.test_size, seed=args.seed) + + df_train = collect_frames_by_ids(dataframes, X_train, y_train) + df_test = collect_frames_by_ids(dataframes, X_test, y_test) + + os.makedirs(args.output, exist_ok=True) + # save train csv file + train_file = os.path.join(args.output, 'train.csv') + print('save train 
file: {}'.format(train_file)) + df_train.to_csv(train_file) + # save test csv file + test_file = os.path.join(args.output, 'test.csv') + print('save train file: {}'.format(test_file)) + df_test.to_csv(test_file) diff --git a/src/chk_loader.py b/src/chk_loader.py new file mode 100644 index 0000000..6d0bfd0 --- /dev/null +++ b/src/chk_loader.py @@ -0,0 +1,60 @@ +import os +import torch +import numpy as np + + +def get_last_epoch(filenames): + epochs = [int(name.split('-')[1].split('.')[0]) for name in filenames] + return filenames[np.array(epochs).argsort()[-1]] + + +def load_checkpoint(cfg): + dir_chk = os.path.join(cfg.output, 'checkpoints') + if cfg.epoch != -1: + path = os.path.join(dir_chk, 'model-{:02d}.pt'.format(cfg.epoch)) + else: + try: + fnames = os.listdir(dir_chk) + path = get_last_epoch(fnames) + path = os.path.join(dir_chk, path) + except IndexError: + raise FileNotFoundError() + # load checkpoint + print('load file {}'.format(path)) + if not os.path.exists(path): + raise FileNotFoundError() + return torch.load(path, map_location=cfg.device) + + +def load_state_dict_model(model, optimizer, checkpoint): + print('load model state') + model.load_state_dict(checkpoint['model_state_dict']) + print('load optimizer state') + optimizer.load_state_dict(checkpoint['optimizer_state_dict']) + + return checkpoint['epoch'] + 1, checkpoint['index'] + + +def load_state_dict_model_only(model, checkpoint): + print('load model state') + model.load_state_dict(checkpoint['model_state_dict']) + + return checkpoint['epoch'] + 1, checkpoint['index'] + + +def save_state_dict_model(model, optimizer, epoch, index, cfg): + # save checkpoint + n_epoch = epoch + 1 + if (n_epoch) % cfg.snapshot_interval == 0: + dir_chk = os.path.join(cfg.output, 'checkpoints') + os.makedirs(dir_chk, exist_ok=True) + path = os.path.join(dir_chk, 'model-{:02d}.pt'.format(n_epoch)) + + checkpoint = { + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + 'optimizer_state_dict': optimizer.state_dict(), + 'index': index, + } + + torch.save(checkpoint, path) diff --git a/src/config.py b/src/config.py new file mode 100644 index 0000000..29d10cb --- /dev/null +++ b/src/config.py @@ -0,0 +1,135 @@ +import yaml +import os + +from easydict import EasyDict +from functools import reduce + + +class DotDict(EasyDict): + def __getattr__(self, k): + try: + v = self[k] + except KeyError: + return super().__getattr__(k) + if isinstance(v, dict): + return DotDict(v) + return v + + def __setitem__(self, k, v): + if isinstance(k, str) and '.' in k: + k = k.split('.') + if isinstance(k, (list, tuple)): + last = reduce(lambda d, kk: d[kk], k[:-1], self) + last[k[-1]] = v + return + return super().__setitem__(k, v) + + def __getitem__(self, k): + if isinstance(k, str) and '.' in k: + k = k.split('.') + if isinstance(k, (list, tuple)): + return reduce(lambda d, kk: d[kk], k, self) + return super().__getitem__(k) + + # def get(self, k, default=None): + # if isinstance(k, str) and '.' 
in k: + # try: + # return self[k] + # except KeyError: + # return default + # return super().get(k, default) + + def update(self, mydict): + for k, v in mydict.items(): + self[k] = v + + +def parse_config(args): + print('load config file ', args.config) + cfg = EasyDict(load_config(args.config)) + + var_args = vars(args) + + # check batch_size to overwrite only if defined + if var_args['batch_size'] is None: + del var_args['batch_size'] + + cfg.update(var_args) + + # backup the config file + os.makedirs(cfg.output, exist_ok=True) + with open(os.path.join(cfg.output, 'cfg.yml'), 'w') as bkfile: + yaml.dump(cfg, bkfile, default_flow_style=False) + + # load from node if exists + load_node_dataset(cfg) + # load from nvme if exists + load_nvme_dataset(cfg) + + return cfg + + +def load_config(path, default_path=None): + ''' Loads config file. + + Args: + path (str): path to config file + default_path (bool): whether to use default path + ''' + # Load configuration from file itself + with open(path, 'r') as f: + cfg_special = yaml.safe_load(f) + + # Check if we should inherit from a config + inherit_from = cfg_special.get('__base__') + + # If yes, load this config first as default + # If no, use the default_path + if inherit_from is not None: + dirname = os.path.split(path)[0] + cfg = load_config(os.path.join(dirname, inherit_from), default_path) + elif default_path is not None: + with open(default_path, 'r') as f: + cfg = yaml.load(f) + else: + cfg = dict() + + # Include main configuration + update_recursive(cfg, cfg_special) + + return cfg + + +def update_recursive(dict1, dict2): + ''' Update two config dictionaries recursively. + + Args: + dict1 (dict): first dictionary to be updated + dict2 (dict): second dictionary which entries should be used + + ''' + for k, v in dict2.items(): + # Add item if not yet in dict1 + if k not in dict1: + dict1[k] = None + # Update + if isinstance(dict1[k], dict): + update_recursive(dict1[k], v) + else: + dict1[k] = v + + +def load_nvme_dataset(cfg): + root_path_nvme = cfg.dataset.root_path.rstrip('/') + '_nvme' + if os.path.exists(root_path_nvme): + # update path to read from nvme disk + cfg.dataset.root_path = root_path_nvme + print('load dataset from {}'.format(cfg.dataset.root_path)) + + +def load_node_dataset(cfg): + root_path_node = cfg.dataset.root_path.rstrip('/') + '_node' + if os.path.exists(root_path_node): + # update path to read from nvme disk + cfg.dataset.root_path = root_path_node + print('load dataset from {}'.format(cfg.dataset.root_path)) diff --git a/src/datasets/oli2msi.py b/src/datasets/oli2msi.py new file mode 100644 index 0000000..0ff3b53 --- /dev/null +++ b/src/datasets/oli2msi.py @@ -0,0 +1,118 @@ +# +# Source code: https://github.com/wjwjww/OLI2MSI +# + +import os +import rasterio +import torch + +from torch.utils.data import Dataset, DataLoader + +from utils import load_fun + + +def load_file(filename): + with rasterio.open(filename) as src: + file_ = src.read() + return file_ + + +def _load_files_dir(fdir): + print('load files from {}'.format(fdir)) + file_list = [] + for dirpath, dirs, files in os.walk(fdir): + for filename in files: + if filename.endswith('.TIF'): + file_list.append(os.path.join(dirpath, filename)) + return sorted(file_list) + + +class OLI2MSI(Dataset): + def __init__(self, cfg, is_training=True): + print('dataset OLI2MSI') + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + self.fdir_lr = os.path.join(self.root_path, self.relname + '_lr') + 
self.fdir_hr = os.path.join(self.root_path, self.relname + '_hr') + self._load_files() + + def _load_files(self): + self.files_lr = _load_files_dir(self.fdir_lr) + self.files_hr = _load_files_dir(self.fdir_hr) + + def __len__(self): + return len(self.files_lr) + + def __getitem__(self, index): + file_lr = load_file(self.files_lr[index]) + file_hr = load_file(self.files_hr[index]) + return file_lr, file_hr + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: + train_dset = OLI2MSI(cfg, is_training=True) + val_dset = OLI2MSI(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + + shuffle = True + + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg) + ) + + if not only_test: + train_dset = OLI2MSI(cfg, is_training=True) + + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + val_dset = OLI2MSI(cfg, is_training=False) + + shuffle = False + + # TODO distribute also val_dset + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg) + ) + + return train_dloader, val_dloader, concat_dloader diff --git a/src/datasets/seasonet.py b/src/datasets/seasonet.py new file mode 100644 index 0000000..9a56163 --- /dev/null +++ b/src/datasets/seasonet.py @@ -0,0 +1,131 @@ +import torch + +from torch.utils.data import Dataset, DataLoader +from torchgeo.datasets import SeasoNet + +from utils import load_fun + + +class SeasoNetDataset(Dataset): + + classes = [ + 'Continuous urban fabric', + 'Discontinuous urban fabric', + 'Industrial or commercial units', + 'Road and rail networks and associated land', + 'Port areas', + 'Airports', + 'Mineral extraction sites', + 'Dump sites', + 'Construction sites', + 'Green urban areas', + 'Sport and leisure facilities', + 'Non-irrigated arable land', + 'Vineyards', + 'Fruit trees and berry plantations', + 'Pastures', + 'Broad-leaved forest', + 'Coniferous forest', + 'Mixed forest', + 'Natural grasslands', + 'Moors and heathland', + 'Transitional woodland/shrub', + 'Beaches, dunes, sands', + 'Bare rock', + 'Sparsely vegetated areas', + 'Inland marshes', + 'Peat bogs', + 'Salt marshes', + 'Intertidal flats', + 'Water courses', + 'Water bodies', + 'Coastal lagoons', + 'Estuaries', + 'Sea and ocean', + ] + + def __init__(self, cfg, is_training=True): + print('dataset SEASONET') + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + + self.dset = SeasoNet( + root=self.root_path, + split=self.relname, + **cfg.dataset.kwargs + ) + + cfg.classes = self.classes + self.kwargs = cfg.dataset.kwargs + + def __len__(self): + return len(self.dset) + + def __getitem__(self, index): + item = self.dset[index] + + if self.kwargs.get('bands') 
== ['20m']: + # get only B5, B6, B7, B8A because sen2venus + item['image'] = item['image'][:4] + + return item + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + dataset_cls = load_fun(cfg.dataset.get('cls', 'SeasoNetDataset')) + + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: + train_dset = dataset_cls(cfg, is_training=True) + val_dset = dataset_cls(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + if not only_test: + train_dset = dataset_cls(cfg, is_training=True) + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + val_dset = dataset_cls(cfg, is_training=False) + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg), + ) + + return train_dloader, val_dloader, concat_dloader diff --git a/src/datasets/sen2venus.py b/src/datasets/sen2venus.py new file mode 100644 index 0000000..fb74206 --- /dev/null +++ b/src/datasets/sen2venus.py @@ -0,0 +1,112 @@ +import os +import torch + +from functools import lru_cache +from torch.utils.data import Dataset, DataLoader + +from utils import load_fun + + +@lru_cache(maxsize=10) +def cached_torch_load(filename): + return torch.load(filename) + + +class Sen2VenusDataset(Dataset): + def __init__(self, cfg, is_training=True): + self.root_path = cfg.dataset.root_path + self.relname = 'train' + if not is_training: + self.relname = 'test' + self.fdir = os.path.join(self.root_path, self.relname) + self._load_files() + self._filter_files(cfg) + + def _filter_files(self, cfg): + places = cfg.dataset.get('places') + if places is not None and not places == []: + self.files = list(filter( + lambda name: name.lower().split('_')[1] in places, + self.files)) + + def _load_files(self): + print('load {} files from {}'.format(self.relname, self.fdir)) + self.files = [] + for dirpath, dirs, files in os.walk(self.fdir): + for filename in files: + if filename.endswith('.pt'): + self.files.append(os.path.join(dirpath, filename)) + + def __len__(self): + return len(self.files) + + def __getitem__(self, index): + return cached_torch_load(self.files[index]) + + +def load_dataset(cfg, only_test=False, concat_datasets=False): + dataset_cls = globals()[cfg.dataset.get('cls', 'Sen2VenusDataset')] + hr_name = cfg.dataset.hr_name + lr_name = cfg.dataset.lr_name + + collate_fn = cfg.dataset.get('collate_fn') + if collate_fn is not None: + collate_fn = load_fun(collate_fn) + + persistent_workers = False + if cfg.num_workers > 0: + persistent_workers = True + + train_dset = None + train_dloader = None + concat_dloader = None + + if concat_datasets: + train_dset = dataset_cls(cfg, is_training=True) + val_dset = dataset_cls(cfg, is_training=False) + dset = torch.utils.data.ConcatDataset([train_dset, val_dset]) + + shuffle = 
True + + concat_dloader = DataLoader( + dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=shuffle, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + if not only_test: + train_dset = dataset_cls(cfg, is_training=True) + + train_dloader = DataLoader( + train_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=True, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + val_dset = dataset_cls(cfg, is_training=False) + + val_dloader = DataLoader( + val_dset, + batch_size=cfg.batch_size, + sampler=None, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=True, + drop_last=False, + persistent_workers=persistent_workers, + collate_fn=collate_fn(cfg, hr_name=hr_name, lr_name=lr_name) + ) + + return train_dloader, val_dloader, concat_dloader diff --git a/src/debug.py b/src/debug.py new file mode 100644 index 0000000..2f59313 --- /dev/null +++ b/src/debug.py @@ -0,0 +1,60 @@ +import torch + +from time import time + +from utils import load_fun + + +def log_losses(losses, phase, writer, index): + # Write the data during training to the training log file + for k, v in losses.items(): + writer.add_scalar("{}/{}".format(phase, k), v.item(), index) + + +def _get_abs_weights_grads(model): + return torch.cat([ + p.grad.detach().view(-1) for p in model.parameters() + if p.requires_grad + ]).abs() + + +def _get_abs_weights(model): + return torch.cat([ + p.detach().view(-1) for p in model.parameters() + if p.requires_grad + ]).abs() + + +def _perf_measure(model, data, count, warm_count=0): + for _ in range(warm_count): + model(data) + start = time() + for _ in range(count): + model(data) + end = time() + return (end - start) / count + + +def measure_avg_time(cfg): + vis = cfg.get('visualize', {}) + model = load_fun(vis.get('model'))(cfg) + x = torch.rand([cfg.batch_size, ] + vis.input_shape).to(cfg.device) + return _perf_measure(model, x, cfg.repeat_times, cfg.warm_times) + + +def log_hr_stats(lr, sr, hr, writer, index, cfg): + if cfg.get('debug', {}).get('sr_hr', False): + def log_delta(delta, name): + writer.add_scalar('stats/mean_{}'.format(name), delta.mean(), + index) + writer.add_scalar('stats/max_{}'.format(name), delta.max(), index) + + q_0_99 = torch.quantile(delta, 0.99, interpolation='nearest') + q_0_999 = torch.quantile(delta, 0.999, interpolation='nearest') + writer.add_scalar('stats/q099_{}'.format(name), q_0_99, index) + writer.add_scalar('stats/q0999_{}'.format(name), q_0_999, index) + + sf = cfg.metrics.upscale_factor + upscale = torch.nn.Upsample(scale_factor=sf, mode='bicubic') + log_delta((sr - upscale(lr)).abs(), 'sr_lr') + log_delta((sr - hr).abs(), 'sr_hr') diff --git a/src/imgproc.py b/src/imgproc.py new file mode 100644 index 0000000..3ce0672 --- /dev/null +++ b/src/imgproc.py @@ -0,0 +1,59 @@ +# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import math
+import os
+import random
+from typing import Any
+
+import cv2
+import numpy as np
+import torch
+from numpy import ndarray
+from torch import Tensor
+import tifffile
+
+
+def tensor_to_image(tensor: Tensor, range_norm: bool, half: bool):
+    """Convert the Tensor (NCHW) data type supported by PyTorch to the np.ndarray (HWC) image data type
+
+    Args:
+        tensor (Tensor): Data types supported by PyTorch (NCHW), the data range is [0, 1]
+        range_norm (bool): Scale [-1, 1] data to between [0, 1]
+        half (bool): Whether to convert torch.float32 to the torch.half type.
+
+    Returns:
+        image (np.ndarray): Data types supported by PIL or OpenCV
+
+    Examples:
+        >>> example_image = cv2.imread("lr_image.bmp")
+        >>> example_tensor = image_to_tensor(example_image, range_norm=False, half=False)
+
+    """
+    if range_norm:
+        tensor = tensor.add(1.0).div(2.0)
+    if half:
+        tensor = tensor.half()
+
+    # image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("uint8")
+    # the data is no longer converted to uint8 but kept in float32
+    # if the image is RGB/RGBA apply the permutation, otherwise do not
+    # permuting a 12-channel image would give an SR result of size 12x636 with 628 channels
+    if 1 < tensor.size()[1] <= 4:
+        # image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy().astype("float32")
+        image = tensor.squeeze(0).permute(1, 2, 0).clamp(0, 1).cpu().numpy().astype("float32")
+    else:
+        # image = tensor.squeeze(0).mul(255).clamp(0, 255).cpu().numpy().astype("float32")
+        image = tensor.squeeze(0).clamp(0, 1).cpu().numpy().astype("float32")
+
+    return image
diff --git a/src/losses/__init__.py b/src/losses/__init__.py
new file mode 100644
index 0000000..82e6363
--- /dev/null
+++ b/src/losses/__init__.py
@@ -0,0 +1,17 @@
+from torch import nn
+
+from . 
import metrics_loss + + +def build_losses(cfg): + losses = {} + if cfg.losses.get('with_ce_criterion', False): + losses['ce_criterion'] = nn.CrossEntropyLoss() + if cfg.losses.get('with_pixel_criterion', False): + losses['pixel_criterion'] = nn.MSELoss() + if cfg.losses.get('with_cc_criterion', False): + losses['cc_criterion'] = metrics_loss.cc_loss + if cfg.losses.get('with_ssim_criterion', False): + losses['ssim_criterion'] = metrics_loss.ssim_loss(cfg) + print(losses) + return losses diff --git a/src/losses/metrics_loss.py b/src/losses/metrics_loss.py new file mode 100644 index 0000000..e8465e5 --- /dev/null +++ b/src/losses/metrics_loss.py @@ -0,0 +1,46 @@ +import piq + +from metrics import _cc_single_torch +from utils import load_fun + + +def norm_0_to_1(fun): + def wrapper(cfg): + dset = cfg.dataset + use_minmax = cfg.dataset.get('stats').get('use_minmax', False) + denorm = load_fun(dset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(dset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + rfun = fun(cfg) + + def f(sr, hr): + hr, sr, _ = evaluable(*denorm(hr, sr, None)) + sr = sr.clamp(min=0, max=1) + hr = hr.clamp(min=0, max=1) + return rfun(sr, hr) + + return f + return wrapper + + +@norm_0_to_1 +def ssim_loss(cfg): + criterion = piq.SSIMLoss() + + def f(sr, hr): + return criterion(sr, hr) + + return f + + +def cc_loss(sr, hr): + cc_value = _cc_single_torch(sr, hr) + return 1 - ((cc_value + 1) * .5) diff --git a/src/main.py b/src/main.py new file mode 100644 index 0000000..d07e81b --- /dev/null +++ b/src/main.py @@ -0,0 +1,127 @@ +import argparse +import torch + +from config import parse_config +from utils import load_fun, set_deterministic +from visualize import main as vis_main +from validation import main as val_main, print_metrics as val_print_metrics, \ + load_metrics +from debug import measure_avg_time + + +def parse_configs(): + parser = argparse.ArgumentParser(description='SuperRes model') + # For training and testing + parser.add_argument('--config', + default="cfgs/sen2venus_v26_8.yml", + help='Configuration file.') + parser.add_argument('--phase', + default='train', + choices=['train', 'test', 'mean_std', 'vis', + 'plot_data', 'avg_time'], + help='Training or testing or play phase.') + parser.add_argument('--seed', + type=int, + default=123, + metavar='N', + help='random seed (default: 123)') + parser.add_argument('--batch_size', + type=int, + default=None, + metavar='B', + help='Batch size. If defined, overwrite cfg file.') + help_num_workers = 'The number of workers to load dataset. Default: 0' + parser.add_argument('--num_workers', + type=int, + default=0, + metavar='N', + help=help_num_workers) + parser.add_argument('--output', + default='./output/sen2venus_v26_8', + help='Directory where save the output.') + help_epoch = 'The epoch to restart from (training) or to eval (testing).' 
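+    # Example invocation (illustrative values; cf. scripts/get_metrics.sh):
+    #   python src/main.py --phase test --config cfgs/sen2venus_v26_8.yml \
+    #       --output output/hpc/sen2venus_v26_8/ --epoch 50
+    # With the default --epoch -1, chk_loader.load_checkpoint picks the most
+    # recent checkpoint found in <output>/checkpoints.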
+ parser.add_argument('--epoch', + type=int, + default=-1, + help=help_epoch) + parser.add_argument('--epochs', + type=int, + default=200, + metavar='N', + help='number of epoches (default: 50)') + help_snapshot = 'The epoch interval of model snapshot (default: 10)' + parser.add_argument('--snapshot_interval', + type=int, + default=1, + metavar='N', + help=help_snapshot) + parser.add_argument("--num_images", + type=int, + default=10, + help="Number of images to plot") + parser.add_argument('--eval_method', + default=None, type=str, + help='Non-DL method to use on evaluation.') + parser.add_argument('--repeat_times', + type=int, + default=1000, + help='Measure times repeating model call') + help_warm = 'Warm model calling it before starting the measure' + parser.add_argument('--warm_times', + type=int, + default=10, + help=help_warm) + parser.add_argument('--dpi', + type=int, + default=2400, + help="dpi in png output file.") + + args = parser.parse_args() + return parse_config(args) + + +def main(cfg): + cfg.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + load_dataset_fun = load_fun(cfg.dataset.get( + 'load_dataset', 'datasets.sen2venus.load_dataset')) + + if cfg.phase == 'avg_time': + print(measure_avg_time(cfg)) + elif cfg.phase == 'train': + train_fun = load_fun(cfg.get('train', 'srgan.training.train')) + train_dloader, val_dloader, _ = load_dataset_fun(cfg) + train_fun = load_fun(cfg.get('train', 'srgan.training.train')) + train_fun(train_dloader, val_dloader, cfg) + elif cfg.phase == 'mean_std': + if 'stats' in cfg.dataset.keys(): + cfg.dataset.pop('stats') + _, _, concat_dloader = load_dataset_fun(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'plot_data': + _, _, concat_dloader = load_dataset_fun(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'vis': + cfg['batch_size'] = 1 + vis_main(cfg) + elif cfg.phase == 'test': + try: + if cfg.eval_method is not None: + raise FileNotFoundError() + metrics = load_metrics(cfg) + except FileNotFoundError: + _, val_dloader, _ = load_dataset_fun(cfg, only_test=True) + metrics = val_main( + val_dloader, cfg, save_metrics=cfg.eval_method is None) + val_print_metrics(metrics) + + +if __name__ == "__main__": + # parse input arguments + cfg = parse_configs() + # fix random seed + set_deterministic(cfg.seed) + # run main + main(cfg) diff --git a/src/main_ssegm.py b/src/main_ssegm.py new file mode 100644 index 0000000..5070ce2 --- /dev/null +++ b/src/main_ssegm.py @@ -0,0 +1,101 @@ +import argparse +import torch + +from config import parse_config +from datasets.seasonet import load_dataset +from utils import load_fun, set_deterministic +from validation import print_metrics as val_print_metrics +from semantic_segm.validation import load_metrics, main as val_main +from semantic_segm.visualize import main as vis_main + + +def parse_configs(): + parser = argparse.ArgumentParser(description='Semantic Segmentation model') + # For training and testing + parser.add_argument('--config', + default="cfgs/seasonet_v1.yml", + help='Configuration file.') + parser.add_argument('--phase', + default='test', + choices=['train', 'test', 'mean_std', 'vis'], + help='Training or testing or play phase.') + parser.add_argument('--seed', + type=int, + default=123, + metavar='N', + help='random seed (default: 123)') + parser.add_argument('--batch_size', + type=int, + default=None, + metavar='B', + help='Batch size. 
If defined, overwrite cfg file.') + help_num_workers = 'The number of workers to load dataset. Default: 0' + parser.add_argument('--num_workers', + type=int, + default=0, + metavar='N', + help=help_num_workers) + parser.add_argument('--output', + default='./output/demo', + help='Directory where save the output.') + help_epoch = 'The epoch to restart from (training) or to eval (testing).' + parser.add_argument('--epoch', + type=int, + default=-1, + help=help_epoch) + parser.add_argument('--epochs', + type=int, + default=200, + metavar='N', + help='number of epoches (default: 50)') + help_snapshot = 'The epoch interval of model snapshot (default: 10)' + parser.add_argument('--snapshot_interval', + type=int, + default=1, + metavar='N', + help=help_snapshot) + parser.add_argument("--num_images", + type=int, + default=10, + help="Number of images to plot") + parser.add_argument('--hide_sr', + action='store_true', + default=False) + parser.add_argument('--dpi', + type=int, + default=2400, + help="dpi in png output file.") + + args = parser.parse_args() + return parse_config(args) + + +if __name__ == "__main__": + # parse input arguments + cfg = parse_configs() + # set random seed + set_deterministic(cfg.seed) + + cfg.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + if cfg.phase == 'train': + train_dloader, val_dloader, _ = load_dataset(cfg) + train_fun = load_fun(cfg.train) + train_fun(train_dloader, val_dloader, cfg) + elif cfg.phase == 'mean_std': + if 'stats' in cfg.dataset.keys(): + cfg.dataset.pop('stats') + _, _, concat_dloader = load_dataset(cfg, concat_datasets=True) + fun = load_fun(cfg.get(cfg.phase)) + fun(concat_dloader, cfg) + elif cfg.phase == 'vis': + vis_main(cfg) + elif cfg.phase == 'test': + try: + metrics = load_metrics(cfg) + except FileNotFoundError: + # validate if not already done + _, val_dloader, _ = load_dataset(cfg) + metrics = val_main(val_dloader, cfg) + # print metrics + val_print_metrics(metrics) diff --git a/src/metrics.py b/src/metrics.py new file mode 100644 index 0000000..c88f74c --- /dev/null +++ b/src/metrics.py @@ -0,0 +1,317 @@ +# Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import collections.abc +import math +import typing +import warnings +from itertools import repeat +from typing import Any + +import cv2 +import numpy as np +import torch +import piq +from numpy import ndarray +from scipy.io import loadmat +from scipy.ndimage.filters import convolve +from scipy.special import gamma +from torch import nn +from torch.nn import functional as F + + +class piq_metric(object): + def __init__(self, cfg): + pass + + def to(self, device): + return self + + def __call__(self, x, y): + x = x.clone() + x[x < 0] = 0. + x[x > 1] = 1. 
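+        # clamp predictions to the [0, 1] range expected by the piq metrics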
+ return self._metric(x, y) + + def _metric(self, x, y): + pass + + +class piq_psnr(piq_metric): + def _metric(self, x, y): + return piq.psnr(x, y) + + +class piq_ssim(piq_metric): + def _metric(self, x, y): + return piq.ssim(x, y) + + +class piq_rmse(piq_metric): + def __init__(self, cfg): + self.mse = nn.MSELoss() + + def _metric(self, x, y): + return torch.sqrt(self.mse(x, y)) + + +def _ergas_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + """ + Compute the ERGAS (Erreur Relative Globale Adimensionnelle De Synthèse) metric for a pair of input tensors. + + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared (typically the reconstructed image). + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + ERGAS (torch.Tensor): The ERGAS metric score. + + """ + # Compute the number of spectral bands + N_spectral = raw_tensor.shape[1] + + # Reshape images for processing + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + N_pixels = raw_tensor_reshaped.shape[1] + + # Assuming HR size is 256x256 and LR size is 128x128 + hr_size = torch.tensor(256).cuda() + lr_size = torch.tensor(128).cuda() + + # Calculate the beta value + beta = (hr_size / lr_size).cuda() + + # Calculate RMSE of each band + rmse = torch.sqrt(torch.nansum((dst_tensor_reshaped - raw_tensor_reshaped) ** 2, dim=1) / N_pixels) + mu_ref = torch.mean(dst_tensor_reshaped, dim=1) + + # Calculate ERGAS + ERGAS = 100 * (1 / beta ** 2) * torch.sqrt(torch.nansum(torch.div(rmse, mu_ref) ** 2) / N_spectral) + + return ERGAS + + +class ERGAS(nn.Module): + """ + PyTorch implementation of ERGAS (Erreur Relative Globale Adimensionnelle De Synthèse) metric. + + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute ERGAS metric between two input tensors representing images. + + Example: + ergas_calculator = ERGAS() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + ergas_score = ergas_calculator(raw_image, dst_image) + print(f"ERGAS Score: {ergas_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute ERGAS metric between two input tensors representing images. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared (typically the reconstructed image). + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + ergas_metrics (torch.Tensor): The ERGAS metric score. + + Note: + ERGAS measures the relative global error of synthesis for remote sensing or image processing tasks. + It evaluates the quality of an output image concerning a reference image, taking into account spectral bands. + + """ + ergas_metrics = _ergas_single_torch(raw_tensor, dst_tensor) + return ergas_metrics + + +def _cc_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + + """ + Compute the Cross-Correlation (CC) metric between two input tensors representing images. 
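+    For each spectral band the Pearson correlation coefficient between the two
+    images is computed; the returned score is the mean over all bands.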
+ + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + CC (torch.Tensor): The Cross-Correlation (CC) metric score. + + """ + N_spectral = raw_tensor.shape[1] + + # Reshaping fused and reference data + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + + # Calculating mean value + mean_raw = torch.mean(raw_tensor_reshaped, 1).unsqueeze(1) + mean_dst = torch.mean(dst_tensor_reshaped, 1).unsqueeze(1) + + CC = torch.sum((raw_tensor_reshaped - mean_raw) * (dst_tensor_reshaped - mean_dst), 1) / torch.sqrt( + torch.sum((raw_tensor_reshaped - mean_raw) ** 2, 1) * torch.sum((dst_tensor_reshaped - mean_dst) ** 2, 1)) + + CC = torch.mean(CC) + + return CC + + +class CC(nn.Module): + """ + PyTorch implementation of the Cross-Correlation (CC) metric for image similarity. + + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. + + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute the Cross-Correlation (CC) metric between two input tensors representing images. + + Example: + cc_calculator = CC() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + cc_score = cc_calculator(raw_image, dst_image) + print(f"Cross-Correlation Score: {cc_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute the Cross-Correlation (CC) metric between two input tensors representing images. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + cc_metrics (torch.Tensor): The Cross-Correlation (CC) metric score. + + Note: + CC measures the similarity between two images by calculating the cross-correlation coefficient between spectral bands. + + """ + cc_metrics = _cc_single_torch(raw_tensor, dst_tensor) + return cc_metrics + + +def _sam_single_torch(raw_tensor: torch.Tensor, + dst_tensor: torch.Tensor): + """ + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + SAM (torch.Tensor): The Spectral Angle Mapper (SAM) metric score. 
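+            Per pixel, the angle acos(<raw, dst> / (||raw|| * ||dst||)) is taken
+            across the spectral bands; the score is their mean, in degrees.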
+ + """ + # Compute the number of spectral bands + N_spectral = raw_tensor.shape[1] + + # Reshape fused and reference data + raw_tensor_reshaped = raw_tensor.view(N_spectral, -1) + dst_tensor_reshaped = dst_tensor.view(N_spectral, -1) + N_pixels = raw_tensor_reshaped.shape[1] + + # Calculate the inner product + inner_prod = torch.nansum(raw_tensor_reshaped * dst_tensor_reshaped, 0) + raw_norm = torch.nansum(raw_tensor_reshaped ** 2, dim=0).sqrt() + dst_norm = torch.nansum(dst_tensor_reshaped ** 2, dim=0).sqrt() + + # Calculate SAM + SAM = torch.rad2deg(torch.nansum(torch.acos(inner_prod / (raw_norm * dst_norm))) / N_pixels) + + return SAM + + +class SAM(nn.Module): + """ + PyTorch implementation of the Spectral Angle Mapper (SAM) metric for spectral similarity. + + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. + + Args: + None + + Attributes: + None + + Methods: + forward(raw_tensor, dst_tensor): + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + Example: + sam_calculator = SAM() + raw_image = torch.tensor(...) # Replace with your raw image data + dst_image = torch.tensor(...) # Replace with your reference image data + sam_score = sam_calculator(raw_image, dst_image) + print(f"Spectral Angle Mapper Score: {sam_score.item()}") + + """ + + def __init__(self): + super().__init__() + + def forward(self, raw_tensor: torch.Tensor, dst_tensor: torch.Tensor): + """ + Compute the Spectral Angle Mapper (SAM) metric between two input tensors representing images. + + Args: + raw_tensor (torch.Tensor): The image tensor to be compared. + dst_tensor (torch.Tensor): The reference image tensor. + + Returns: + sam_metrics (torch.Tensor): The Spectral Angle Mapper (SAM) metric score. + + Note: + SAM measures the spectral similarity between two images by calculating the spectral angle between corresponding pixels. 
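+            The per-pixel angles are averaged and the result is returned in degrees.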
+ + """ + sam_metrics = _sam_single_torch(raw_tensor, dst_tensor) + return sam_metrics diff --git a/src/mods/v3.py b/src/mods/v3.py new file mode 100644 index 0000000..d821c15 --- /dev/null +++ b/src/mods/v3.py @@ -0,0 +1,162 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + + +def collate_fn(cfg, hr_name, lr_name): + print('collate_fn hr field: ', hr_name) + print('collate_fn lr field: ', lr_name) + + def do_nothing(x): + return x + norm_lr = do_nothing + norm_hr = do_nothing + + if 'stats' in cfg.dataset and cfg.phase != 'plot_data': + hr_mean = cfg.dataset.stats.get(hr_name).mean + hr_std = cfg.dataset.stats.get(hr_name).std + lr_mean = cfg.dataset.stats.get(lr_name).mean + lr_std = cfg.dataset.stats.get(lr_name).std + norm_hr = transforms.Normalize(mean=hr_mean, std=hr_std) + norm_lr = transforms.Normalize(mean=lr_mean, std=lr_std) + + def f(batch): + batch = default_collate(batch) + + lr = norm_lr(batch[lr_name].float()) + hr = norm_hr(batch[hr_name].float()) + + return {'lr': lr, 'hr': hr} + return f + + +def uncollate_fn(cfg, hr_name, lr_name): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + def f(hr=None, sr=None, lr=None): + stats_hr = cfg.dataset.stats.get(hr_name) + stats_lr = cfg.dataset.stats.get(lr_name) + + if hr is not None: + hr = denorm(hr, stats_hr.mean, stats_hr.std) + if sr is not None: + sr = denorm(sr, stats_hr.mean, stats_hr.std) + if lr is not None: + lr = denorm(lr, stats_lr.mean, stats_lr.std) + + return hr, sr, lr + + return f + + +def printable(cfg, hr_name, lr_name, filter_outliers=False, use_minmax=False): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view(t1.shape + (1, 1)) + return t1 + + def _printable(tensor, stats): + if use_minmax: + max_ = to_shape(torch.tensor( + stats.max), tensor).to(cfg.device) + min_ = to_shape(torch.tensor( + stats.min), tensor).to(cfg.device) + else: + # get stats + mean = torch.tensor(stats.mean).to(cfg.device) + std = torch.tensor(stats.std).to(cfg.device) + # compute min and max + max_ = to_shape(mean + std, tensor) + min_ = to_shape(mean - std, tensor) + + # fitler outliers if needed to visualize them + if filter_outliers: + tensor = torch.min(torch.max(tensor, min_), max_) + # printable + return (tensor - min_) / (max_ - min_) + + def f(hr=None, sr=None, lr=None): + stats_hr = cfg.dataset.stats.get(hr_name) + stats_lr = cfg.dataset.stats.get(lr_name) + + if hr is not None: + hr = _printable(hr, stats_hr) + if sr is not None: + sr = _printable(sr, stats_hr) + if lr is not None: + lr = _printable(lr, stats_lr) + + return hr, sr, lr + + return f + + +class MeanStd(object): + def __init__(self, device): + self.channels_sum = 0 + self.channels_sqrd_sum = 0 + self.num_batches = 0 + self.max = torch.full((4,), -torch.inf, device=device) + self.min = torch.full((4,), torch.inf, device=device) + + def __call__(self, data): + # max + d_perm = data.permute(1, 0, 2, 3) + d_max = d_perm.reshape(data.shape[1], -1).max(dim=1)[0] + self.max = torch.where(d_max > self.max, d_max, self.max) + # min + d_min = d_perm.reshape(data.shape[1], -1).min(dim=1)[0] + self.min = torch.where(d_min <= self.min, d_min, self.min) + # mean, std + self.channels_sum += 
torch.mean(data, dim=[0, 2, 3]) + self.channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3]) + self.num_batches += 1 + + def get_mean_std(self): + mean = self.channels_sum / self.num_batches + std = (self.channels_sqrd_sum / self.num_batches - mean**2) ** 0.5 + return mean.tolist(), std.tolist() + + def get_min_max(self): + return self.min.tolist(), self.max.tolist() + + +def get_mean_std(train_dloader, cfg): + hr_ms = MeanStd(device=cfg.device) + lr_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True).float() + lr = batch["lr"].to(device=cfg.device, non_blocking=True).float() + + hr_ms(hr) + lr_ms(lr) + + hr_mean, hr_std = hr_ms.get_mean_std() + lr_mean, lr_std = lr_ms.get_mean_std() + hr_min, hr_max = hr_ms.get_min_max() + lr_min, lr_max = lr_ms.get_min_max() + + print('HR (mean, std, min, max)') + print('mean: {},'.format(hr_mean)) + print('std: {},'.format(hr_std)) + print('min: {},'.format(hr_min)) + print('max: {}'.format(hr_max)) + print('LR (mean, std, min, max)') + print('mean: {},'.format(lr_mean)) + print('std: {},'.format(lr_std)) + print('min: {},'.format(lr_min)) + print('max: {}'.format(lr_max)) diff --git a/src/mods/v5.py b/src/mods/v5.py new file mode 100644 index 0000000..fd797bf --- /dev/null +++ b/src/mods/v5.py @@ -0,0 +1,94 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + +from .v3 import MeanStd + + +def collate_fn(cfg): + print('collate_fn') + + def do_nothing(x): + return x + norm_fun = do_nothing + + if 'stats' in cfg.dataset and cfg.phase != 'plot_data': + mean = cfg.dataset.stats.mean + std = cfg.dataset.stats.std + norm_fun = transforms.Normalize(mean=mean, std=std) + + def f(batch): + batch = default_collate(batch) + batch['image'] = norm_fun(batch['image'].float()) + return batch + return f + + +def uncollate_fn(cfg): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + def f(sr, lr): + stats = cfg.dataset.stats + if sr is not None: + sr = denorm(sr.float(), stats.mean, stats.std) + lr = denorm(lr.float(), stats.mean, stats.std) + return sr, lr + + return f + + +def printable(cfg): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view(t1.shape + (1, 1)) + return t1 + + def _printable(tensor, stats): + max_ = to_shape(torch.tensor( + stats.max), tensor).to(cfg.device) + min_ = to_shape(torch.tensor( + stats.min), tensor).to(cfg.device) + + # printable + return (tensor - min_) / (max_ - min_) + + def f(sr, lr): + stats = cfg.dataset.stats + # batch = default_collate(batch) + if sr is not None: + sr = _printable(sr, stats) + lr = _printable(lr, stats) + return sr, lr + + return f + + +def get_mean_std(train_dloader, cfg): + image_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + image = batch["image"].to(device=cfg.device, non_blocking=True).float() + + image_ms(image) + + image_mean, image_std = image_ms.get_mean_std() + image_min, image_max = image_ms.get_min_max() + + print('Image (mean, std, min, max)') + print('mean: {},'.format(image_mean)) + print('std: {},'.format(image_std)) + print('min: 
{},'.format(image_min)) + print('max: {}'.format(image_max)) diff --git a/src/mods/v6.py b/src/mods/v6.py new file mode 100644 index 0000000..b7c29f4 --- /dev/null +++ b/src/mods/v6.py @@ -0,0 +1,102 @@ +import torch + +from torchvision import transforms +from torch.utils.data import default_collate +from tqdm import tqdm + + +def collate_fn(cfg): + + def f(batch): + batch = default_collate(batch) + + lr = batch[0].float() + hr = batch[1].float() + + return {'lr': lr, 'hr': hr} + return f + + +def uncollate_fn(cfg, hr_name, lr_name): + def to_shape(t1, t2): + t1 = t1[None].repeat(t2.shape[0], 1) + t1 = t1.view((t2.shape[:2] + (1, 1))) + return t1 + + def denorm(tensor, mean, std): + # get stats + mean = torch.tensor(mean).to(cfg.device) + std = torch.tensor(std).to(cfg.device) + # denorm + return (tensor * to_shape(std, tensor)) + to_shape(mean, tensor) + + def f(hr=None, sr=None, lr=None): + return hr, sr, lr + + return f + + +def printable(cfg, hr_name, lr_name, filter_outliers=False, use_minmax=False): + def f(hr=None, sr=None, lr=None): + return hr, sr, lr + + return f + + +class MeanStd(object): + def __init__(self, device): + self.channels_sum = 0 + self.channels_sqrd_sum = 0 + self.num_batches = 0 + self.max = torch.full((4,), -torch.inf, device=device) + self.min = torch.full((4,), torch.inf, device=device) + + def __call__(self, data): + # max + d_perm = data.permute(1, 0, 2, 3) + d_max = d_perm.reshape(data.shape[1], -1).max(dim=1)[0] + self.max = torch.where(d_max > self.max, d_max, self.max) + # min + d_min = d_perm.reshape(data.shape[1], -1).min(dim=1)[0] + self.min = torch.where(d_min <= self.min, d_min, self.min) + # mean, std + self.channels_sum += torch.mean(data, dim=[0, 2, 3]) + self.channels_sqrd_sum += torch.mean(data**2, dim=[0, 2, 3]) + self.num_batches += 1 + + def get_mean_std(self): + mean = self.channels_sum / self.num_batches + std = (self.channels_sqrd_sum / self.num_batches - mean**2) ** 0.5 + return mean.tolist(), std.tolist() + + def get_min_max(self): + return self.min.tolist(), self.max.tolist() + + +def get_mean_std(train_dloader, cfg): + hr_ms = MeanStd(device=cfg.device) + lr_ms = MeanStd(device=cfg.device) + + for index, batch in tqdm( + enumerate(train_dloader), total=len(train_dloader)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True).float() + lr = batch["lr"].to(device=cfg.device, non_blocking=True).float() + + hr_ms(hr) + lr_ms(lr) + + hr_mean, hr_std = hr_ms.get_mean_std() + lr_mean, lr_std = lr_ms.get_mean_std() + hr_min, hr_max = hr_ms.get_min_max() + lr_min, lr_max = lr_ms.get_min_max() + + print('HR (mean, std, min, max)') + print('mean: {},'.format(hr_mean)) + print('std: {},'.format(hr_std)) + print('min: {},'.format(hr_min)) + print('max: {}'.format(hr_max)) + print('LR (mean, std, min, max)') + print('mean: {},'.format(lr_mean)) + print('std: {},'.format(lr_std)) + print('min: {},'.format(lr_min)) + print('max: {}'.format(lr_max)) diff --git a/src/optim.py b/src/optim.py new file mode 100644 index 0000000..aa70419 --- /dev/null +++ b/src/optim.py @@ -0,0 +1,13 @@ +from torch import optim + + +def build_optimizer(model, cfg): + o_cfg = cfg + + optimizer = optim.Adam(model.parameters(), + o_cfg.optim.learning_rate, + o_cfg.optim.model_betas, + o_cfg.optim.model_eps, + o_cfg.optim.model_weight_decay) + + return optimizer diff --git a/src/semantic_segm/model.py b/src/semantic_segm/model.py new file mode 100644 index 0000000..a04b268 --- /dev/null +++ b/src/semantic_segm/model.py @@ -0,0 +1,100 @@ +import torch + +from 
torchgeo.models import FarSeg
+from easydict import EasyDict
+from torch import nn
+from torch.nn import functional as F
+
+from chk_loader import load_state_dict_model_only
+from config import load_config
+from super_res.model import build_model as super_res_build_model
+from utils import set_required_grad
+
+
+class UpScaler(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        u_cfg = cfg.semantic_segm.upscaler
+
+        # load dl upscaler
+        if u_cfg.get('chk') is not None:
+            print('load super_res model config file ', u_cfg.config)
+            model_cfg = EasyDict(load_config(u_cfg.config))
+            model_cfg.device = cfg.device
+            # load super_res
+            print("loading super_res model")
+            self.super_res = super_res_build_model(model_cfg)
+            # load checkpoint
+            print("loading super_res chk file {}".format(u_cfg.chk))
+            swin_checkpoint = torch.load(u_cfg.chk)
+            load_state_dict_model_only(self.super_res, swin_checkpoint)
+            # set eval mode
+            print('set super_res eval mode')
+            set_required_grad(self.super_res, False)
+
+    def forward(self, x):
+        with torch.no_grad():
+            x = self.super_res.forward_backbone(x)
+        return x
+
+
+class SemanticSegm(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        ssegm = cfg.semantic_segm
+
+        # padding before and after
+        self.pad_before = ssegm.pad_before
+        self.pad_after = self.pad_before
+        if 'pad_after' in ssegm:
+            self.pad_after = ssegm.pad_after
+
+        # segmentation model
+        if ssegm.type == 'FarSeg':
+            self.model = FarSeg(classes=len(cfg.classes), **ssegm.model)
+            out_ch = self.model.backbone.conv1.out_channels
+            self.model.backbone.conv1 = nn.Conv2d(
+                ssegm.in_channels, out_ch, kernel_size=7, stride=2, padding=3
+            )
+
+        # super res model
+        if 'upscaler' in ssegm:
+            self.upscaler = UpScaler(cfg)
+        # or conv [4, H, W] -> [90, H, W] for model++
+        if 'conv_up' in ssegm:
+            print('model++ with conv_up')
+            print(ssegm.conv_up)
+            kernel_size = ssegm.conv_up.get('kernel_size', 1)
+            padding = ssegm.conv_up.get('padding', 0)
+            self.conv_up = nn.Conv2d(
+                ssegm.conv_up.in_ch, ssegm.conv_up.middle_ch,
+                kernel_size=kernel_size, padding=padding)
+            self.model.backbone.conv1 = nn.Conv2d(
+                ssegm.conv_up.middle_ch, ssegm.conv_up.out_ch,
+                kernel_size=7, stride=2, padding=3)
+
+    def forward(self, x):
+        # upscale input if super res model is defined
+        if hasattr(self, 'upscaler'):
+            x = self.upscaler(x)
+            if not torch.is_tensor(x):
+                # remove loss_moe value
+                x, _ = x
+
+        pad_x = F.pad(x, self.pad_before, "constant", 0)
+
+        # or use conv to increase channels (for baseline++)
+        if hasattr(self, 'conv_up'):
+            pad_x = self.conv_up(pad_x)
+
+        # semantic segmentation model
+        pad_x = self.model(pad_x)
+        x = pad_x[...,
+                  self.pad_after[0]:-self.pad_after[1],
+                  self.pad_after[2]:-self.pad_after[3]]
+
+        return x
+
+
+def build_model(cfg):
+    return SemanticSegm(cfg).to(cfg.device)
diff --git a/src/semantic_segm/training.py b/src/semantic_segm/training.py
new file mode 100644
index 0000000..83ed929
--- /dev/null
+++ b/src/semantic_segm/training.py
@@ -0,0 +1,77 @@
+import datetime
+
+import debug
+
+from tqdm import tqdm
+from torch.utils.tensorboard import SummaryWriter
+
+from .validation import validate, save_metrics
+from chk_loader import load_checkpoint, load_state_dict_model, \
+    save_state_dict_model
+from optim import build_optimizer
+from .model import build_model
+from losses import build_losses
+
+
+def train(train_dloader, val_dloader, cfg):
+    # Tensorboard
+    writer = SummaryWriter(cfg.output + '/tensorboard/train_{}'.format(
+        datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")))
+    #
eval every x + eval_every = cfg.metrics.get('eval_every', 1) + + model = build_model(cfg) + losses = build_losses(cfg) + optimizer = build_optimizer(model, cfg) + + begin_epoch = 0 + index = 0 + try: + checkpoint = load_checkpoint(cfg) + begin_epoch, index = load_state_dict_model( + model, optimizer, checkpoint) + except FileNotFoundError: + print('no checkpoint found') + + for e in range(begin_epoch, cfg.epochs): + index = train_epoch( + model, train_dloader, losses, optimizer, e, writer, index, cfg) + + if (e+1) % eval_every == 0: + result = validate( + model, val_dloader, e, writer, 'test', cfg) + # save result of eval + cfg.epoch = e+1 + save_metrics(result, cfg) + + save_state_dict_model(model, optimizer, e, index, cfg) + + +def train_epoch(model, train_dloader, losses, optimizer, epoch, writer, + index, cfg): + weights = cfg.losses.weights + for index, batch in tqdm( + enumerate(train_dloader, index), total=len(train_dloader), + desc='Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = model(batch['image']) + + loss_tracker = {} + + if 'ce_criterion' in losses: + loss_tracker['ce_loss'] = losses['ce_criterion']( + out, batch['mask']) * weights.ce + + # train + loss_tracker['train_loss'] = sum(loss_tracker.values()) + optimizer.zero_grad() + loss_tracker['train_loss'].backward() + optimizer.step() + + # log_stats(writer, index, cfg) + debug.log_losses(loss_tracker, 'train', writer, index) + + return index diff --git a/src/semantic_segm/validation.py b/src/semantic_segm/validation.py new file mode 100644 index 0000000..a2e1191 --- /dev/null +++ b/src/semantic_segm/validation.py @@ -0,0 +1,99 @@ +import torch + +from tqdm import tqdm +from collections import OrderedDict + +from utils import F1AverageMeter, load_fun +from chk_loader import load_checkpoint +from validation import get_result_filename + + +def validate(g_model, val_dloader, epoch, writer, mode, cfg): + g_model.eval() + avg_metrics = build_avg_metrics(cfg) + + with torch.no_grad(): + for j, batch in tqdm( + enumerate(val_dloader), total=len(val_dloader), + desc='Val Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = g_model(batch['image']) + preds = out.max(-1)[1] + + for k, fun in avg_metrics.items(): + avg_metrics[k].update((preds, batch['mask'])) + + if writer is not None: + for k, v in avg_metrics.items(): + try: + writer.add_scalar( + "{}/{}".format(mode, k), v.avg.item(), epoch+1) + except RuntimeError: + # skip if metric is a list (like f1 per class) + pass + + g_model.train() + return avg_metrics + + +def build_avg_metrics(cfg): + return OrderedDict([ + ('f1_macro', F1AverageMeter( + name="f1_macro", fmt=":4.4f", average='macro', cfg=cfg)), + ('f1_micro', F1AverageMeter( + name="f1_micro", fmt=":4.4f", average='micro', cfg=cfg)), + ('f1_class', F1AverageMeter( + name="f1_class", fmt=":4.4f", average=None, cfg=cfg)), + ]) + + +def main(val_dloader, cfg): + model = load_eval_method(cfg) + result = validate( + model, val_dloader, cfg.epoch - 1, None, 'test', cfg) + save_metrics(result, cfg) + return result + + +def load_eval_method(cfg): + vis = cfg.visualize + model = load_fun(vis.get('model'))(cfg) + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception as e: + print(e) + exit(-1) + + return model + + +def load_metrics(cfg): + 
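+    # Reload the metrics file written by save_metrics below, check it matches
+    # cfg.epoch, and rebuild the F1 meters from the stored averages.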
filename = get_result_filename(cfg)
+    print('load results {}'.format(filename))
+    result = torch.load(filename)
+    # check if epoch corresponds
+    assert result['epoch'] == cfg.epoch
+    # load classes
+    cfg.classes = result['classes']
+    # build AVG objects
+    avg_metrics = build_avg_metrics(cfg)
+    for k, v in result['metrics'].items():
+        avg_metrics[k].avg = v
+    return avg_metrics
+
+
+def save_metrics(metrics, cfg):
+    filename = get_result_filename(cfg)
+    print('save results {}'.format(filename))
+    torch.save({
+        'classes': cfg.classes,
+        'epoch': cfg.epoch,
+        'metrics': OrderedDict([
+            (k, v.avg) for k, v in metrics.items()
+        ])
+    }, filename)
diff --git a/src/semantic_segm/visualize.py b/src/semantic_segm/visualize.py
new file mode 100644
index 0000000..448c3b1
--- /dev/null
+++ b/src/semantic_segm/visualize.py
@@ -0,0 +1,122 @@
+import torch
+import os
+import matplotlib.pyplot as plt
+import imgproc
+
+from tqdm import tqdm
+
+from chk_loader import load_checkpoint
+from datasets.seasonet import load_dataset
+from utils import load_fun
+
+
+def plot_images(img, out_path, basename, fname, batch_number, dpi):
+    # dir: out_dir / basename
+    out_path = os.path.join(out_path, basename)
+    if not os.path.exists(out_path):
+        os.makedirs(out_path)
+
+    n_cols = 1
+    n_bands = img.shape[1]
+    fig, axs = plt.subplots(
+        n_bands, n_cols, sharex=True, sharey=True, figsize=(n_bands, 1))
+    if img is not None:
+        img = img.squeeze(0)
+        for i in range(img.shape[0]):
+            axs[i].imshow(
+                imgproc.tensor_to_image(img[i].detach(), False, False))
+            axs[i].axis('off')
+
+    out_fname = os.path.join(out_path, '{}_{}.png'.format(fname, batch_number))
+    plt.savefig(out_fname, dpi=dpi)
+    plt.close()
+
+
+def save_fig(output, filename, dpi):
+    plt.imshow(output.cpu().numpy().squeeze(0))
+    plt.savefig(filename, dpi=dpi, pad_inches=0.01)
+    print('image saved: {}'.format(filename))
+
+
+def plot_classes(image, out_path, basename, index, cfg):
+    # dir: out_dir / basename
+    out_path = os.path.join(out_path, basename)
+    if not os.path.exists(out_path):
+        os.makedirs(out_path)
+
+    for idx, name in enumerate(cfg.classes):
+        output = (image == idx).int() * 255.
+ unq = output.unique() + # import ipdb; ipdb.set_trace() + if len(unq) > 1 or unq.item() != 0: + print('class ', idx) + # import ipdb; ipdb.set_trace() + fname = 'image_{:02d}_{:02d}.jpg'.format(index, idx) + full_name = os.path.join(out_path, fname) + save_fig(output, full_name, cfg.dpi) + + +def main(cfg): + cfg['batch_size'] = 1 + vis = cfg.get('visualize', {}) + # Load dataset + _, val_dloader, _ = load_dataset(cfg, only_test=True) + # Initialize model + model = load_fun(vis.get('model'))(cfg) + denorm = load_fun(cfg.dataset.get('denorm'))(cfg) + # evaluable = load_fun(cfg.dataset.get('printable'))(cfg) + printable = load_fun(cfg.dataset.get('printable'))(cfg) + + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception: + print('no model checkpoint found') + exit(0) + + # Create a folder of super-resolution experiment results + out_path = os.path.join(cfg.output, 'images_{}'.format(cfg.epoch)) + if not os.path.exists(out_path): + os.makedirs(out_path) + + model.to(cfg.device) + model.eval() + + iterations = min(cfg.num_images, len(val_dloader)) + + for index, batch in tqdm( + enumerate(val_dloader), total=iterations, + desc='%d Images' % (iterations)): + + if index >= iterations: + break + + batch['image'] = batch['image'].to(cfg.device) + batch['mask'] = batch['mask'].to(cfg.device) + + out = model(batch['image']) + preds = out.max(-1)[1] + + plot_classes(preds, out_path, 'preds', index, cfg) + plot_classes(batch['mask'], out_path, 'gt', index, cfg) + + sr = None + if not cfg.hide_sr: + if hasattr(model, 'upscaler'): + sr = model.upscaler(batch['image']) + + if not torch.is_tensor(sr): + sr, _ = sr + + # denormalize to original values + sr, lr = denorm(sr, batch['image']) + + # normalize [0, 1] removing outliers to have a printable version + sr, lr = printable(sr, lr) + + # plot images + plot_images(lr, out_path, 'images', 'lr', index, cfg.dpi) + + if not cfg.hide_sr and sr is not None: + plot_images(sr, out_path, 'images', 'sr', index, cfg.dpi) diff --git a/src/super_res/model.py b/src/super_res/model.py new file mode 100644 index 0000000..fb1fb98 --- /dev/null +++ b/src/super_res/model.py @@ -0,0 +1,15 @@ + +def build_model(cfg): + version = cfg.super_res.get('version', 'v1') + print('load super_res {}'.format(version)) + + if version == 'v1': + from .network_swinir import SwinIR as SRModel + elif version == 'v2': + from .network_swin2sr import Swin2SR as SRModel + elif version == 'swinfir': + from .swinfir_arch import SwinFIR as SRModel + + model = SRModel(**cfg.super_res.model).to(cfg.device) + + return model diff --git a/src/super_res/moe.py b/src/super_res/moe.py new file mode 100644 index 0000000..5815c21 --- /dev/null +++ b/src/super_res/moe.py @@ -0,0 +1,324 @@ +# +# Source code: https://github.com/davidmrau/mixture-of-experts +# + +# Sparsely-Gated Mixture-of-Experts Layers. +# See "Outrageously Large Neural Networks" +# https://arxiv.org/abs/1701.06538 +# +# Author: David Rau +# +# The code is based on the TensorFlow implementation: +# https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/expert_utils.py + + +import torch +import torch.nn as nn +from torch.distributions.normal import Normal +from copy import deepcopy +import numpy as np + +from .utils import Mlp as MLP + + +class SparseDispatcher(object): + """Helper for implementing a mixture of experts. 
+ The purpose of this class is to create input minibatches for the + experts and to combine the results of the experts to form a unified + output tensor. + There are two functions: + dispatch - take an input Tensor and create input Tensors for each expert. + combine - take output Tensors from each expert and form a combined output + Tensor. Outputs from different experts for the same batch element are + summed together, weighted by the provided "gates". + The class is initialized with a "gates" Tensor, which specifies which + batch elements go to which experts, and the weights to use when combining + the outputs. Batch element b is sent to expert e iff gates[b, e] != 0. + The inputs and outputs are all two-dimensional [batch, depth]. + Caller is responsible for collapsing additional dimensions prior to + calling this class and reshaping the output to the original shape. + See common_layers.reshape_like(). + Example use: + gates: a float32 `Tensor` with shape `[batch_size, num_experts]` + inputs: a float32 `Tensor` with shape `[batch_size, input_size]` + experts: a list of length `num_experts` containing sub-networks. + dispatcher = SparseDispatcher(num_experts, gates) + expert_inputs = dispatcher.dispatch(inputs) + expert_outputs = [experts[i](expert_inputs[i]) for i in range(num_experts)] + outputs = dispatcher.combine(expert_outputs) + The preceding code sets the output for a particular example b to: + output[b] = Sum_i(gates[b, i] * experts[i](inputs[b])) + This class takes advantage of sparsity in the gate matrix by including in the + `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. + """ + + def __init__(self, num_experts, gates): + """Create a SparseDispatcher.""" + + self._gates = gates + self._num_experts = num_experts + # sort experts + sorted_experts, index_sorted_experts = torch.nonzero(gates).sort(0) + # drop indices + _, self._expert_index = sorted_experts.split(1, dim=1) + # get according batch index for each expert + self._batch_index = torch.nonzero(gates)[index_sorted_experts[:, 1], 0] + # calculate num samples that each expert gets + self._part_sizes = (gates > 0).sum(0).tolist() + # expand gates to match with self._batch_index + gates_exp = gates[self._batch_index.flatten()] + self._nonzero_gates = torch.gather(gates_exp, 1, self._expert_index) + + def dispatch(self, inp): + """Create one input Tensor for each expert. + The `Tensor` for a expert `i` contains the slices of `inp` corresponding + to the batch elements `b` where `gates[b, i] > 0`. + Args: + inp: a `Tensor` of shape "[batch_size, ]` + Returns: + a list of `num_experts` `Tensor`s with shapes + `[expert_batch_size_i, ]`. + """ + + # assigns samples to experts whose gate is nonzero + + # expand according to batch index so we can just split by _part_sizes + inp_exp = inp[self._batch_index].squeeze(1) + return torch.split(inp_exp, self._part_sizes, dim=0) + + def combine(self, expert_out, multiply_by_gates=True, cnn_combine=None): + """Sum together the expert output, weighted by the gates. + The slice corresponding to a particular batch element `b` is computed + as the sum over all experts `i` of the expert output, weighted by the + corresponding gate values. If `multiply_by_gates` is set to False, the + gate values are ignored. + Args: + expert_out: a list of `num_experts` `Tensor`s, each with shape + `[expert_batch_size_i, ]`. + multiply_by_gates: a boolean + Returns: + a `Tensor` with shape `[batch_size, ]`. 
+ """ + # apply exp to expert outputs, so we are not longer in log space + stitched = torch.cat(expert_out, 0) + + if multiply_by_gates: + stitched = stitched.mul(self._nonzero_gates.unsqueeze(1)) + zeros = torch.zeros((self._gates.size(0),) + expert_out[-1].shape[1:], + requires_grad=True, device=stitched.device) + # combine samples that have been processed by the same k experts + + if cnn_combine is not None: + return self.smartly_combine(stitched, cnn_combine) + + combined = zeros.index_add(0, self._batch_index, stitched.float()) + return combined + + def smartly_combine(self, stitched, cnn_combine): + idxes = [] + for i in self._batch_index.unique(): + idx = (self._batch_index == i).nonzero().squeeze(1) + idxes.append(idx) + idxes = torch.stack(idxes) + return cnn_combine(stitched[idxes]).squeeze(1) + + def expert_to_gates(self): + """Gate values corresponding to the examples in the per-expert `Tensor`s. + Returns: + a list of `num_experts` one-dimensional `Tensor`s with type `tf.float32` + and shapes `[expert_batch_size_i]` + """ + # split nonzero gates for each expert + return torch.split(self._nonzero_gates, self._part_sizes, dim=0) + + +def build_experts(experts_cfg, default_cfg, num_experts): + experts_cfg = deepcopy(experts_cfg) + if experts_cfg is None: + # old build way + return nn.ModuleList([ + MLP(*default_cfg) + for i in range(num_experts)]) + # new build way: mix mlp with leff + experts = [] + for e_cfg in experts_cfg: + type_ = e_cfg.pop('type') + if type_ == 'mlp': + experts.append(MLP(*default_cfg)) + return nn.ModuleList(experts) + + +class MoE(nn.Module): + """Call a Sparsely gated mixture of experts layer with 1-layer + Feed-Forward networks as experts. + + Args: + input_size: integer - size of the input + output_size: integer - size of the input + num_experts: an integer - number of experts + hidden_size: an integer - hidden size of the experts + noisy_gating: a boolean + k: an integer - how many experts to use for each batch element + """ + + def __init__(self, input_size, output_size, num_experts, hidden_size, + experts=None, noisy_gating=True, k=4, + x_gating=None, with_noise=True, with_smart_merger=None): + super(MoE, self).__init__() + self.noisy_gating = noisy_gating + self.num_experts = num_experts + self.output_size = output_size + self.input_size = input_size + self.hidden_size = hidden_size + self.k = k + self.with_noise = with_noise + # instantiate experts + self.experts = build_experts( + experts, + (self.input_size, self.hidden_size, self.output_size), + num_experts) + self.w_gate = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + self.w_noise = nn.Parameter(torch.zeros(input_size, num_experts), requires_grad=True) + + self.x_gating = x_gating + if self.x_gating == 'conv1d': + self.x_gate = nn.Conv1d(4096, 1, kernel_size=3, padding=1) + + self.softplus = nn.Softplus() + self.softmax = nn.Softmax(1) + self.register_buffer("mean", torch.tensor([0.0])) + self.register_buffer("std", torch.tensor([1.0])) + assert(self.k <= self.num_experts) + + self.cnn_combine = None + if with_smart_merger == 'v1': + print('with SMART MERGER') + self.cnn_combine = nn.Conv2d(self.k, 1, kernel_size=3, padding=1) + + def cv_squared(self, x): + """The squared coefficient of variation of a sample. + Useful as a loss to encourage a positive distribution to be more uniform. + Epsilons added for numerical stability. + Returns 0 for an empty Tensor. + Args: + x: a `Tensor`. + Returns: + a `Scalar`. 
+ """ + eps = 1e-10 + # if only num_experts = 1 + + if x.shape[0] == 1: + return torch.tensor([0], device=x.device, dtype=x.dtype) + return x.float().var() / (x.float().mean()**2 + eps) + + def _gates_to_load(self, gates): + """Compute the true load per expert, given the gates. + The load is the number of examples for which the corresponding gate is >0. + Args: + gates: a `Tensor` of shape [batch_size, n] + Returns: + a float32 `Tensor` of shape [n] + """ + return (gates > 0).sum(0) + + def _prob_in_top_k(self, clean_values, noisy_values, noise_stddev, noisy_top_values): + """Helper function to NoisyTopKGating. + Computes the probability that value is in top k, given different random noise. + This gives us a way of backpropagating from a loss that balances the number + of times each expert is in the top k experts per example. + In the case of no noise, pass in None for noise_stddev, and the result will + not be differentiable. + Args: + clean_values: a `Tensor` of shape [batch, n]. + noisy_values: a `Tensor` of shape [batch, n]. Equal to clean values plus + normally distributed noise with standard deviation noise_stddev. + noise_stddev: a `Tensor` of shape [batch, n], or None + noisy_top_values: a `Tensor` of shape [batch, m]. + "values" Output of tf.top_k(noisy_top_values, m). m >= k+1 + Returns: + a `Tensor` of shape [batch, n]. + """ + batch = clean_values.size(0) + m = noisy_top_values.size(1) + top_values_flat = noisy_top_values.flatten() + + threshold_positions_if_in = torch.arange(batch, device=clean_values.device) * m + self.k + threshold_if_in = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_in), 1) + is_in = torch.gt(noisy_values, threshold_if_in) + threshold_positions_if_out = threshold_positions_if_in - 1 + threshold_if_out = torch.unsqueeze(torch.gather(top_values_flat, 0, threshold_positions_if_out), 1) + # is each value currently in the top k. + normal = Normal(self.mean, self.std) + prob_if_in = normal.cdf((clean_values - threshold_if_in)/noise_stddev) + prob_if_out = normal.cdf((clean_values - threshold_if_out)/noise_stddev) + prob = torch.where(is_in, prob_if_in, prob_if_out) + return prob + + def noisy_top_k_gating(self, x, train, noise_epsilon=1e-2): + """Noisy top-k gating. + See paper: https://arxiv.org/abs/1701.06538. + Args: + x: input Tensor with shape [batch_size, input_size] + train: a boolean - we only add noise at training time. 
+ noise_epsilon: a float + Returns: + gates: a Tensor with shape [batch_size, num_experts] + load: a Tensor with shape [num_experts] + """ + clean_logits = x @ self.w_gate + if self.noisy_gating and train: + raw_noise_stddev = x @ self.w_noise + noise_stddev = ((self.softplus(raw_noise_stddev) + noise_epsilon)) + noisy_logits = clean_logits + (torch.randn_like(clean_logits) * noise_stddev) + logits = noisy_logits + else: + logits = clean_logits + + # calculate topk + 1 that will be needed for the noisy gates + top_logits, top_indices = logits.topk(min(self.k + 1, self.num_experts), dim=1) + top_k_logits = top_logits[:, :self.k] + top_k_indices = top_indices[:, :self.k] + top_k_gates = self.softmax(top_k_logits) + + zeros = torch.zeros_like(logits, requires_grad=True) + gates = zeros.scatter(1, top_k_indices, top_k_gates) + + if self.noisy_gating and self.k < self.num_experts and train: + load = (self._prob_in_top_k(clean_logits, noisy_logits, noise_stddev, top_logits)).sum(0) + else: + load = self._gates_to_load(gates) + return gates, load + + def forward(self, x, loss_coef=1e-2): + """Args: + x: tensor shape [batch_size, input_size] + train: a boolean scalar. + loss_coef: a scalar - multiplier on load-balancing losses + + Returns: + y: a tensor with shape [batch_size, output_size]. + extra_training_loss: a scalar. This should be added into the overall + training loss of the model. The backpropagation of this loss + encourages all experts to be approximately equally used across a batch. + """ + if self.x_gating is not None: + xg = self.x_gate(x).squeeze(1) + else: + xg = x.mean(1) + + gates, load = self.noisy_top_k_gating( + xg, self.training and self.with_noise) + # calculate importance loss + importance = gates.sum(0) + # + loss = self.cv_squared(importance) + self.cv_squared(load) + loss *= loss_coef + + dispatcher = SparseDispatcher(self.num_experts, gates) + expert_inputs = dispatcher.dispatch(x) + gates = dispatcher.expert_to_gates() + expert_outputs = [self.experts[i](expert_inputs[i]) + for i in range(self.num_experts)] + y = dispatcher.combine(expert_outputs, cnn_combine=self.cnn_combine) + return y, loss diff --git a/src/super_res/network_swin2sr.py b/src/super_res/network_swin2sr.py new file mode 100644 index 0000000..c41e533 --- /dev/null +++ b/src/super_res/network_swin2sr.py @@ -0,0 +1,1173 @@ +# +# Source code: https://github.com/mv-lab/swin2sr +# +# ----------------------------------------------------------------------------------- +# Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345 +# Written by Conde and Choi et al. +# ----------------------------------------------------------------------------------- + +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from .utils import window_reverse, Mlp, window_partition +from .moe import MoE + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. 
Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0], + use_lepe=False, + use_cpb_bias=True, + use_rpe_bias=False): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + self.use_cpb_bias = use_cpb_bias + + if self.use_cpb_bias: + print('positional encoder: CPB') + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential(nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False)) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack( + torch.meshgrid([relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / np.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.use_rpe_bias = use_rpe_bias + if self.use_rpe_bias: + print('positional encoder: RPE') + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start 
from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + rpe_relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("rpe_relative_position_index", rpe_relative_position_index) + + trunc_normal_(self.relative_position_bias_table, std=.02) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + self.use_lepe = use_lepe + if self.use_lepe: + print('positional encoder: LEPE') + self.get_v = nn.Conv2d( + dim, dim, kernel_size=3, stride=1, padding=1, groups=dim) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + if self.use_lepe: + lepe = self.lepe_pos(v) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01)).to(self.logit_scale.device)).exp() + attn = attn * logit_scale + + if self.use_cpb_bias: + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if self.use_rpe_bias: + relative_position_bias = self.relative_position_bias_table[self.rpe_relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v) + + if self.use_lepe: + x = x + lepe + + x = x.transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def lepe_pos(self, v): + B, NH, HW, NW = v.shape + C = NH * NW + H = W = int(math.sqrt(HW)) + v = v.transpose(-2, -1).contiguous().view(B, C, H, W) + lepe = self.get_v(v) + lepe = lepe.reshape(-1, self.num_heads, NW, HW) + lepe = lepe.permute(0, 1, 3, 2).contiguous() + return lepe + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, ' \ + 
f'pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size), + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + use_rpe_bias=use_rpe_bias) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + if MoE_config is None: + print('-->>> MLP') + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + else: + print('-->>> MOE') + print(MoE_config) + self.mlp = MoE( + input_size=dim, output_size=dim, hidden_size=mlp_hidden_dim, + **MoE_config) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + + shortcut = x + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + + loss_moe = None + res = self.mlp(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = x + self.drop_path(self.norm2(res)) + + return x, loss_moe + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. 
+ dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + pretrained_window_size=0, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + loss_moe_all = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + if self.downsample is not None: + x = self.downsample(x) + return x, loss_moe_all + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + B, C, H, W = x.shape + # FIXME look at relaxing size constraints + # assert H == self.img_size[0] and W == self.img_size[1], + # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer(dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias + ) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + loss_moe = None + res = self.residual_group(x, x_size) + + if not torch.is_tensor(res): + res, loss_moe = res + + res = self.patch_embed(self.conv(self.patch_unembed(res, x_size))) + return res + x, loss_moe + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. 
+ patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + +class Upsample_hf(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample_hf, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + + +class Swin2SR(nn.Module): + r""" Swin2SR + A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. 
+ window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + use_rpe_bias=False, + **kwargs): + super(Swin2SR, self).__init__() + print('==== SWIN 2SR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + 
self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias, + ) + self.layers.append(layer) + + if self.upsampler == 'pixelshuffle_hf': + self.layers_hf = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + use_rpe_bias=use_rpe_bias + ) + self.layers_hf.append(layer) + + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffle_aux': + self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_after_aux = nn.Sequential( + nn.Conv2d(3, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffle_hf': + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.upsample_hf = Upsample_hf(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.conv_first_hf = nn.Sequential(nn.Conv2d(num_feat, embed_dim, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 
1, 1) + self.conv_before_upsample_hf = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers: + x = layer(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward_features_hf(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers_hf: + x = layer(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward_backbone(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + else: + raise Exception('not implemented yet') + + x = x / self.img_range + self.mean + return x + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + loss_moe = 0 + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x 
= self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffle_aux': + bicubic = F.interpolate(x, size=(H * self.upscale, W * self.upscale), mode='bicubic', align_corners=False) + bicubic = self.conv_bicubic(bicubic) + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.conv_before_upsample(x) + aux = self.conv_aux(x) # b, 3, LR_H, LR_W + x = self.conv_after_aux(aux) + x = self.upsample(x)[:, :, :H * self.upscale, :W * self.upscale] + bicubic[:, :, :H * self.upscale, :W * self.upscale] + x = self.conv_last(x) + aux = aux / self.img_range + self.mean + elif self.upsampler == 'pixelshuffle_hf': + # for classical SR with HF + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x_before = self.conv_before_upsample(x) + x_out = self.conv_last(self.upsample(x_before)) + + x_hf = self.conv_first_hf(x_before) + + res_hf = self.forward_features_hf(x_hf) + if not torch.is_tensor(res_hf): + res_hf, loss_moe_hf = res_hf + loss_moe += loss_moe_hf + + x_hf = self.conv_after_body_hf(res_hf) + x_hf + x_hf = self.conv_before_upsample_hf(x_hf) + x_hf = self.conv_last_hf(self.upsample_hf(x_hf)) + x = x_out + x_hf + x_hf = x_hf / self.img_range + self.mean + + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + + x = self.conv_after_body(res) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + + res = self.forward_features(x_first) + if not torch.is_tensor(res): + res, loss_moe = res + + res = self.conv_after_body(res) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + if self.upsampler == "pixelshuffle_aux": + return x[:, :, :H*self.upscale, :W*self.upscale], aux, loss_moe + + elif self.upsampler == "pixelshuffle_hf": + x_out = x_out / self.img_range + self.mean + return x_out[:, :, :H*self.upscale, :W*self.upscale], x[:, :, :H*self.upscale, :W*self.upscale], x_hf[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + else: + return x[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops + + +if __name__ == '__main__': + upscale = 4 + window_size = 8 + height = (1024 // upscale // window_size + 1) * window_size + width = (720 // upscale // window_size + 1) * window_size + model = Swin2SR(upscale=2, img_size=(height, width), + window_size=window_size, img_range=1., depths=[6, 6, 6, 6], + embed_dim=60, num_heads=[6, 6, 6, 6], 
                    mlp_ratio=2, upsampler='pixelshuffledirect')
+    print(model)
+    print(height, width, model.flops() / 1e9)
+
+    x = torch.randn((1, 3, height, width))
+    x, loss_moe = model(x)
+    print(x.shape)
diff --git a/src/super_res/network_swinir.py b/src/super_res/network_swinir.py
new file mode 100644
index 0000000..39edfab
--- /dev/null
+++ b/src/super_res/network_swinir.py
@@ -0,0 +1,853 @@
+#
+# Original code from: https://github.com/JingyunLiang/SwinIR
+#
+# -----------------------------------------------------------------------------------
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+# -----------------------------------------------------------------------------------
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from .utils import window_reverse, Mlp, window_partition
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+    It supports both shifted and non-shifted windows.
+
+    Args:
+        dim (int): Number of input channels.
+        window_size (tuple[int]): The height and width of the window.
+        num_heads (int): Number of attention heads.
+        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+    """
+
+    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+        super().__init__()
+        self.dim = dim
+        self.window_size = window_size  # Wh, Ww
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        # define a parameter table of relative position bias
+        self.relative_position_bias_table = nn.Parameter(
+            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH
+
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
+        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
+        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
+        self.register_buffer("relative_position_index", relative_position_index)
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        trunc_normal_(self.relative_position_bias_table, std=.02)
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, x, mask=None):
+        """
+        Args:
+            x: input features with shape of (num_windows*B, N, C)
+            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+        """
+        B_, N, C = x.shape
+        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, kernel_size=3): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + # assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
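+
+    Example (illustrative sketch, not from the original SwinIR code; assumes a 64x64
+    token grid and this repo's modified forward, which also returns a MoE aux loss):
+        >>> layer = BasicLayer(dim=96, input_resolution=(64, 64), depth=2,
+        ...                    num_heads=6, window_size=8)
+        >>> tokens = torch.randn(1, 64 * 64, 96)     # (B, H*W, C)
+        >>> out, loss_moe = layer(tokens, (64, 64))  # out: (1, 4096, 96); loss is 0 for plain blocks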
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + use_lepe=False, use_cpb_bias=True, + MoE_config=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + loss_moe_all = 0 + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + if self.downsample is not None: + x = self.downsample(x) + + return x, loss_moe_all + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. 
+ """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False, + img_size=224, patch_size=4, resi_connection='1conv', + use_lepe=False, use_cpb_bias=True, + MoE_config=None): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential(nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + loss_moe = None + res = self.residual_group(x, x_size) + if not torch.is_tensor(res): + res, loss_moe = res + return ( + self.patch_embed( + self.conv(self.patch_unembed(res, x_size))) + x + ), loss_moe + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. 
+ norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops + + +class SwinIR(nn.Module): + r""" SwinIR + A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 
2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. '1conv'/'3conv' + """ + + def __init__(self, img_size=64, patch_size=1, in_chans=3, + embed_dim=96, depths=[6, 6, 6, 6], num_heads=[6, 6, 6, 6], + window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, upscale=2, img_range=1., upsampler='', resi_connection='1conv', + use_transpose=False, + use_lepe=False, + use_cpb_bias=True, + MoE_config=None, + **kwargs): + super(SwinIR, self).__init__() + print('==== SWIN IR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.4488, 0.4371, 0.4040) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + self.window_size = window_size + + ##################################################################################################### + ################################### 1, shallow feature extraction ################################### + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + ##################################################################################################### + ################################### 2, deep feature extraction ###################################### + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB(dim=embed_dim, + input_resolution=(patches_resolution[0], + patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection, + use_lepe=use_lepe, + use_cpb_bias=use_cpb_bias, + MoE_config=MoE_config, + ) + self.layers.append(layer) 
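+        # Note: each RSTB above gets its own slice of the stochastic-depth schedule `dpr`
+        # (sum(depths) values linearly spaced from 0 to drop_path_rate), so deeper blocks
+        # are assigned progressively larger drop-path rates.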
+ self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + if resi_connection == '1conv': + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv_after_body = nn.Sequential(nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1)) + + ##################################################################################################### + ################################ 3, high quality image reconstruction ################################ + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + self.conv_before_upsample = nn.Sequential(nn.Conv2d(embed_dim, num_feat, 3, 1, 1), + nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + if self.upscale == 4: + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def check_image_size(self, x): + _, _, h, w = x.size() + mod_pad_h = (self.window_size - h % self.window_size) % self.window_size + mod_pad_w = (self.window_size - w % self.window_size) % self.window_size + x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect') + return x + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + loss_moe_all = 0 + for layer in self.layers: + x = layer(x, x_size) + if not torch.is_tensor(x): + x, loss_moe = x + loss_moe_all += loss_moe or 0 + + x = self.norm(x) # B L C + x = self.patch_unembed(x, x_size) + + return x, loss_moe_all + + def forward(self, x): + H, W = x.shape[2:] + x = self.check_image_size(x) + + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + loss_moe = None + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight 
SR + x = self.conv_first(x) + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'sunet': + # for lightweight SR + x = self.conv_first(x) + res = self.forward_features(x) + if not torch.is_tensor(res): + res, loss_moe = res + x = self.conv_after_body(res) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + if self.upscale == 4: + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x[:, :, :H*self.upscale, :W*self.upscale], loss_moe + + def flops(self): + flops = 0 + H, W = self.patches_resolution + flops += H * W * 3 * self.embed_dim * 9 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += H * W * 3 * self.embed_dim * self.embed_dim + flops += self.upsample.flops() + return flops diff --git a/src/super_res/swinfir_arch.py b/src/super_res/swinfir_arch.py new file mode 100644 index 0000000..904b5a0 --- /dev/null +++ b/src/super_res/swinfir_arch.py @@ -0,0 +1,537 @@ +# +# Source code: https://github.com/Zdafeng/SwinFIR +# Paper: https://arxiv.org/pdf/2208.11247v3.pdf +# + +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint + +from timm.models.layers import to_2tuple, trunc_normal_ +from .swinfir_utils import WindowAttention, DropPath, Mlp, SFB, \ + PatchEmbed, PatchUnEmbed, Upsample, UpsampleOneStep, window_partition, \ + window_reverse + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + if self.shift_size > 0: + attn_mask = self.calculate_mask(self.input_resolution) + else: + attn_mask = None + + self.register_buffer('attn_mask', attn_mask) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + h, w = x_size + img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1 + h_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + if self.input_resolution == x_size: + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c + else: + attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(x) + x = x + 
self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}') + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + +class RSTB(nn.Module): + """Residual Swin Transformer Block (RSTB). + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + img_size: Input image size. + patch_size: Patch size. + resi_connection: The convolutional block before residual connection. + """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv'): + super(RSTB, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == 'SFB': + self.conv = SFB(dim) + elif resi_connection == 'HSFB': + self.conv = SFB(dim, 2) + elif resi_connection == 'identity': + self.conv = nn.Identity() + + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x + + +# @ARCH_REGISTRY.register() +class SwinFIR(nn.Module): + r""" SwinFIR + A PyTorch impl of : `SwinFIR: Revisiting the SwinIR with Fast Fourier Convolution and + Improved Training for Image Super-Resolution`, based on Swin Transformer and Fast Fourier Convolution. + + Args: + img_size (int | tuple(int)): Input image size. Default 64 + patch_size (int | tuple(int)): Patch size. Default: 1 + in_chans (int): Number of input image channels. Default: 3 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction + img_range: Image range. 1. or 255. + upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None + resi_connection: The convolutional block before residual connection. 
'1conv'/'3conv' + """ + + def __init__(self, + img_size=64, + patch_size=1, + in_chans=3, + embed_dim=96, + depths=(6, 6, 6, 6), + num_heads=(6, 6, 6, 6), + window_size=7, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + upscale=2, + img_range=1., + upsampler='', + resi_connection='SFB', + **kwargs): + super(SwinFIR, self).__init__() + print('==== SWIN FIR') + num_in_ch = in_chans + num_out_ch = in_chans + num_feat = 64 + self.img_range = img_range + if in_chans == 3: + rgb_mean = (0.3014, 0.3152, 0.3094) + self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) + else: + self.mean = torch.zeros(1, 1, 1, 1) + self.upscale = upscale + self.upsampler = upsampler + + # ------------------------- 1, shallow feature extraction ------------------------- # + self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1) + + # ------------------------- 2, deep feature extraction ------------------------- # + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.num_features = embed_dim + self.mlp_ratio = mlp_ratio + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # merge non-overlapping patches into image + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=embed_dim, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build Residual Swin Transformer blocks (RSTB) + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = RSTB( + dim=embed_dim, + input_resolution=(patches_resolution[0], patches_resolution[1]), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results + norm_layer=norm_layer, + downsample=None, + use_checkpoint=use_checkpoint, + img_size=img_size, + patch_size=patch_size, + resi_connection=resi_connection) + self.layers.append(layer) + self.norm = norm_layer(self.num_features) + + # build the last conv layer in deep feature extraction + self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + # ------------------------- 3, high quality image reconstruction ------------------------- # + if self.upsampler == 'pixelshuffle': + # for classical SR + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.upsample = Upsample(upscale, num_feat) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR (to save parameters) + self.upsample = 
UpsampleOneStep(upscale, embed_dim, num_out_ch, + (patches_resolution[0], patches_resolution[1])) + elif self.upsampler == 'nearest+conv': + # for real-world SR (less artifacts) + assert self.upscale == 4, 'only support x4 now.' + self.conv_before_upsample = nn.Sequential( + nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)) + self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) + self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) + self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + else: + # for image denoising and JPEG compression artifact reduction + self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'absolute_pos_embed'} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'relative_position_bias_table'} + + def forward_features(self, x): + x_size = (x.shape[2], x.shape[3]) + x = self.patch_embed(x) + if self.ape: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x, x_size) + + x = self.norm(x) # b seq_len c + x = self.patch_unembed(x, x_size) + + return x + + def forward(self, x): + input = x + self.mean = self.mean.type_as(x) + x = (x - self.mean) * self.img_range + + if self.upsampler == 'pixelshuffle': + # for classical SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.conv_last(self.upsample(x)) + elif self.upsampler == 'pixelshuffledirect': + # for lightweight SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.upsample(x) + elif self.upsampler == 'nearest+conv': + # for real-world SR + x = self.conv_first(x) + x = self.conv_after_body(self.forward_features(x)) + x + x = self.conv_before_upsample(x) + x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))) + x = self.conv_last(self.lrelu(self.conv_hr(x))) + else: + # for image denoising and JPEG compression artifact reduction + x_first = self.conv_first(x) + res = self.conv_after_body(self.forward_features(x_first)) + x_first + x = x + self.conv_last(res) + + x = x / self.img_range + self.mean + + return x diff --git a/src/super_res/swinfir_utils.py b/src/super_res/swinfir_utils.py new file mode 100644 index 0000000..ce9db86 --- /dev/null +++ b/src/super_res/swinfir_utils.py @@ -0,0 +1,725 @@ +# +# Source code: https://github.com/Zdafeng/SwinFIR +# Paper: https://arxiv.org/pdf/2208.11247v3.pdf +# + +import math +import torch +import torch.nn as nn + +from einops import rearrange +from timm.models.layers import to_2tuple, trunc_normal_ + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class ChannelAttention(nn.Module): + """Channel attention used in RCAN. + Args: + num_feat (int): Channel number of intermediate features. + squeeze_factor (int): Channel squeeze factor. Default: 16. + """ + + def __init__(self, num_feat, squeeze_factor=16): + super(ChannelAttention, self).__init__() + self.attention = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), + nn.ReLU(inplace=True), + nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), + nn.Sigmoid()) + + def forward(self, x): + y = self.attention(x) + return x * y + + +def window_partition(x, window_size): + """ + Args: + x: (b, h, w, c) + window_size (int): window size + + Returns: + windows: (num_windows*b, window_size, window_size, c) + """ + b, h, w, c = x.shape + x = x.view(b, h // window_size, window_size, w // window_size, window_size, c) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c) + return windows + + +def window_reverse(windows, window_size, h, w): + """ + Args: + windows: (num_windows*b, window_size, window_size, c) + window_size (int): Window size + h (int): Height of image + w (int): Width of image + + Returns: + x: (b, h, w, c) + """ + b = int(windows.shape[0] / (h * w / window_size / window_size)) + x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. 
Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + +class CAB(nn.Module): + + def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30): + super(CAB, self).__init__() + + self.cab = nn.Sequential( + nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), + nn.GELU(), + nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), + ChannelAttention(num_feat, squeeze_factor) + ) + + def forward(self, x): + return self.cab(x) + + +class WindowAttentionHATFIR(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. 
+ window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, rpi, mask=None): + """ + Args: + x: input features with shape of (num_windows*b, n, c) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + b_, n, c = x.shape + qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nw = mask.shape[0] + attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, n, n) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(b_, n, c) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class HAB(nn.Module): + r""" Hybrid Attention Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + compress_ratio=3, + squeeze_factor=30, + conv_scale=0.01, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = norm_layer(dim) + self.attn = WindowAttentionHATFIR( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.conv_scale = conv_scale + self.conv_block = CAB(num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x, x_size, rpi_sa, attn_mask): + h, w = x_size + b, _, c = x.shape + # assert seq_len == h * w, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + # Conv_X + conv_x = self.conv_block(x.permute(0, 3, 1, 2)) + conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = attn_mask + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c + x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size + attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c) + shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c + + # reverse cyclic shift + if self.shift_size > 0: + attn_x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + attn_x = shifted_x + attn_x = attn_x.view(b, h * w, c) + + # FFN + x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. 
Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: b, h*w, c + """ + h, w = self.input_resolution + b, seq_len, c = x.shape + assert seq_len == h * w, 'input feature has wrong size' + assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) are not even.' + + x = x.view(b, h, w, c) + + x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c + x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c + x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c + x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c + x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c + x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class OCAB(nn.Module): + # overlapping cross-attention block + + def __init__(self, dim, + input_resolution, + window_size, + overlap_ratio, + num_heads, + qkv_bias=True, + qk_scale=None, + mlp_ratio=2, + norm_layer=nn.LayerNorm + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + self.overlap_win_size = int(window_size * overlap_ratio) + window_size + + self.norm1 = norm_layer(dim) + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.unfold = nn.Unfold(kernel_size=(self.overlap_win_size, self.overlap_win_size), stride=window_size, + padding=(self.overlap_win_size - window_size) // 2) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((window_size + self.overlap_win_size - 1) * (window_size + self.overlap_win_size - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + self.proj = nn.Linear(dim, dim) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU) + + def forward(self, x, x_size, rpi): + h, w = x_size + b, _, c = x.shape + + shortcut = x + x = self.norm1(x) + x = x.view(b, h, w, c) + + qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w + q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c + kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w + + # partition windows + q_windows = window_partition(q, self.window_size) # nw*b, window_size, window_size, c + q_windows = q_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c + + kv_windows = self.unfold(kv) # b, c*w*w, nw + kv_windows = rearrange(kv_windows, 'b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch', nc=2, ch=c, + owh=self.overlap_win_size, oww=self.overlap_win_size).contiguous() # 2, nw*b, ow*ow, c + k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c + + b_, nq, _ = q_windows.shape + _, n, _ = k_windows.shape + d = self.dim // self.num_heads + q = q_windows.reshape(b_, nq, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, nq, d + k = k_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + v = v_windows.reshape(b_, n, self.num_heads, d).permute(0, 2, 1, 3) # nw*b, nH, n, d + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + + relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view( + self.window_size * 
self.window_size, self.overlap_win_size * self.overlap_win_size, + -1) # ws*ws, wse*wse, nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, ws*ws, wse*wse + attn = attn + relative_position_bias.unsqueeze(0) + + attn = self.softmax(attn) + attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, self.dim) + x = window_reverse(attn_windows, self.window_size, h, w) # b h w c + x = x.view(b, h * w, self.dim) + + x = self.proj(x) + shortcut + + x = x + self.mlp(self.norm2(x)) + return x + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) # b Ph*Pw c + if self.norm is not None: + x = self.norm(x) + return x + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c + return x + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. 
Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + +class FourierUnit(nn.Module): + def __init__(self, embed_dim, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.conv_layer = torch.nn.Conv2d(embed_dim * 2, embed_dim * 2, 1, 1, 0) + self.relu = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(ffted) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute(0, 1, 3, 4, + 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + return output + + +class SpectralTransform(nn.Module): + def __init__(self, embed_dim, last_conv=False): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.last_conv = last_conv + + self.conv1 = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // 2, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True) + ) + self.fu = FourierUnit(embed_dim // 2) + + self.conv2 = torch.nn.Conv2d(embed_dim // 2, embed_dim, 1, 1, 0) + + if self.last_conv: + self.last_conv = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1) + + def forward(self, x): + x = self.conv1(x) + output = self.fu(x) + output = self.conv2(x + output) + if self.last_conv: + output = self.last_conv(output) + return output + + +## Residual Block (RB) +class ResB(nn.Module): + def __init__(self, embed_dim, red=1): + super(ResB, self).__init__() + self.body = nn.Sequential( + nn.Conv2d(embed_dim, embed_dim // red, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(embed_dim // red, embed_dim, 3, 1, 1), + ) + + def __call__(self, x): + out = self.body(x) + return out + x + + +class SFB(nn.Module): + def __init__(self, embed_dim, red=1): + super(SFB, self).__init__() + self.S = ResB(embed_dim, red) + self.F = SpectralTransform(embed_dim) + self.fusion = nn.Conv2d(embed_dim * 2, embed_dim, 1, 1, 0) + + def __call__(self, x): + s = self.S(x) + f = self.F(x) + out = torch.cat([s, f], dim=1) + out = self.fusion(out) + return out diff --git a/src/super_res/training.py b/src/super_res/training.py new file mode 100644 index 0000000..96ba272 --- /dev/null +++ b/src/super_res/training.py @@ -0,0 +1,109 @@ +import datetime +import torch + +import debug + +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter + +from validation import validate, do_save_metrics as 
save_metrics +from chk_loader import load_checkpoint, load_state_dict_model, \ + save_state_dict_model +from validation import build_eval_metrics +from losses import build_losses +from optim import build_optimizer +from .model import build_model + + +def train(train_dloader, val_dloader, cfg): + + # Tensorboard + writer = SummaryWriter(cfg.output + '/tensorboard/train_{}'.format( + datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))) + # eval every x + eval_every = cfg.metrics.get('eval_every', 1) + + model = build_model(cfg) + losses = build_losses(cfg) + optimizer = build_optimizer(model, cfg) + + begin_epoch = 0 + index = 0 + try: + checkpoint = load_checkpoint(cfg) + begin_epoch, index = load_state_dict_model( + model, optimizer, checkpoint) + except FileNotFoundError: + print('no checkpoint found') + + print('build eval metrics') + metrics = build_eval_metrics(cfg) + + for e in range(begin_epoch, cfg.epochs): + index = train_epoch( + model, + train_dloader, + losses, + optimizer, + e, + writer, + index, + cfg) + + if (e+1) % eval_every == 0: + result = validate( + model, val_dloader, metrics, e, writer, 'test', cfg) + # save result of eval + cfg.epoch = e+1 + save_metrics(result, cfg) + + save_state_dict_model(model, optimizer, e, index, cfg) + + +def train_epoch(model, train_dloader, losses, optimizer, epoch, writer, + index, cfg): + weights = cfg.losses.weights + for index, batch in tqdm( + enumerate(train_dloader, index), total=len(train_dloader), + desc='Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + + # Transfer in-memory data to CUDA devices to speed up training + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + sr = model(lr) + + loss_tracker = {} + + loss_moe = None + if not torch.is_tensor(sr): + sr, loss_moe = sr + if torch.is_tensor(loss_moe): + loss_tracker['loss_moe'] = loss_moe * weights.moe + + sr = sr.contiguous() + + if 'pixel_criterion' in losses: + loss_tracker['pixel_loss'] = weights.pixel * \ + losses['pixel_criterion'](sr, hr) + + # cc loss + if 'cc_criterion' in losses: + loss_tracker['cc_loss'] = weights.cc * \ + losses['cc_criterion'](sr, hr) + + # ssim loss + if 'ssim_criterion' in losses: + loss_tracker['ssim_loss'] = weights.ssim * \ + losses['ssim_criterion'](sr, hr) + + # train + loss_tracker['train_loss'] = sum(loss_tracker.values()) + optimizer.zero_grad() + loss_tracker['train_loss'].backward() + optimizer.step() + + debug.log_hr_stats(lr, sr, hr, writer, index, cfg) + debug.log_losses(loss_tracker, 'train', writer, index) + + return index diff --git a/src/super_res/utils.py b/src/super_res/utils.py new file mode 100644 index 0000000..e7bb96d --- /dev/null +++ b/src/super_res/utils.py @@ -0,0 +1,56 @@ +from torch import nn + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, + window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, 
hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, + W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view( + -1, window_size, window_size, C) + return windows diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..2ef46c2 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,117 @@ +import torch +import importlib +import numpy as np +import random + +from enum import Enum +from torchmetrics.classification import MulticlassF1Score + + +def set_deterministic(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # if you are using multi-GPU. + np.random.seed(seed) # Numpy module. + random.seed(seed) # Python random module. + torch.manual_seed(seed) + torch.backends.cudnn.benchmark = True + torch.backends.cudnn.deterministic = True + + +def w_count(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def load_fun(fullname): + path, name = fullname.rsplit('.', 1) + return getattr(importlib.import_module(path), name) + + +class Summary(Enum): + NONE = 0 + AVERAGE = 1 + SUM = 2 + COUNT = 3 + + +class AverageMeter(object): + def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): + self.name = name + self.fmt = fmt + self.summary_type = summary_type + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + def __str__(self): + self.avg_item = self.avg.tolist() + fmtstr = "{avg_item" + self.fmt + "}" + try: + return fmtstr.format(**self.__dict__) + except TypeError: + # print a list of elements + fmtstr = "{" + self.fmt + "}" + return ' '.join([ + fmtstr for _ in range(len(self.avg_item)) + ]).format(*self.avg_item) + + def summary(self): + if self.summary_type is Summary.NONE: + fmtstr = "" + elif self.summary_type is Summary.AVERAGE: + fmtstr = "{name} {avg:.2f}" + elif self.summary_type is Summary.SUM: + fmtstr = "{name} {sum:.2f}" + elif self.summary_type is Summary.COUNT: + fmtstr = "{name} {count:.2f}" + else: + raise ValueError( + "Invalid summary type {}".format(self.summary_type)) + + return fmtstr.format(**self.__dict__) + + +class F1AverageMeter(AverageMeter): + def __init__(self, cfg, average, **kwargs): + self.cfg = cfg + self._cfg_average = average + super().__init__(**kwargs) + + def reset(self): + super().reset() + self._to_update = True + self.fun = MulticlassF1Score( + len(self.cfg.classes), average=self._cfg_average + ).to(self.cfg.device) + + @property + def avg(self): + if self._to_update: + self._avg = self.fun.compute() + self._to_update = False + return self._avg + + @avg.setter + def avg(self, value): + self._avg = value + self._to_update = False + + def update(self, val, n=1): + pred, gt = val + self.fun.update(pred, gt) + self._to_update = True + + +def set_required_grad(model, value): + for parameters in model.parameters(): + parameters.requires_grad = value diff --git a/src/validation.py b/src/validation.py new file mode 
100644 index 0000000..778ddf8 --- /dev/null +++ b/src/validation.py @@ -0,0 +1,187 @@ +import os +import torch + +from torch import nn +from tqdm import tqdm +from torch.nn import Upsample +from collections import OrderedDict + +from utils import AverageMeter, load_fun +from metrics import CC, SAM, ERGAS, piq_psnr, piq_ssim, \ + piq_rmse +from chk_loader import load_checkpoint + + +def validate(g_model, val_dloader, metrics, epoch, writer, mode, cfg): + # Put the adversarial network model in validation mode + g_model.eval() + + avg_metrics = build_avg_metrics() + + use_minmax = cfg.dataset.get('stats', {}).get('use_minmax', False) + dset = cfg.dataset + denorm = load_fun(dset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(dset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + with torch.no_grad(): + for j, batch in tqdm( + enumerate(val_dloader), total=len(val_dloader), + desc='Val Epoch: %d / %d' % (epoch + 1, cfg.epochs)): + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + # Use the generator model to generate a fake sample + sr = g_model(lr) + + if not torch.is_tensor(sr): + sr, _ = sr + + sr = sr.contiguous() + + # denormalize to original values + hr, sr, lr = denorm(hr, sr, lr) + + # normalize [0, 1] using also the outliers to evaluate + hr, sr, lr = evaluable(hr, sr, lr) + + # Statistical loss value for terminal data output + for k, fun in metrics.items(): + for i in range(len(sr)): + avg_metrics[k].update(fun(sr[i][None], hr[i][None])) + + if writer is not None: + for k, v in avg_metrics.items(): + writer.add_scalar("{}/{}".format(mode, k), v.avg.item(), epoch+1) + + if cfg.get('eval_return_to_train', True): + g_model.train() + return avg_metrics + + +def build_eval_metrics(cfg): + # Create an IQA evaluation model + metrics = { + 'psnr_model': piq_psnr(cfg), + 'ssim_model': piq_ssim(cfg), + 'cc_model': CC(), + 'rmse_model': piq_rmse(cfg), + 'sam_model': SAM(), + 'ergas_model': ERGAS(), + } + + for k in metrics.keys(): + metrics[k] = metrics[k].to(cfg.device) + + return metrics + + +def build_avg_metrics(): + return OrderedDict([ + ('psnr_model', AverageMeter("PIQ_PSNR", ":4.4f")), + ('ssim_model', AverageMeter("PIQ_SSIM", ":4.4f")), + ('cc_model', AverageMeter("CC", ":4.4f")), + ('rmse_model', AverageMeter("PIQ_RMSE", ":4.4f")), + ('sam_model', AverageMeter("SAM", ":4.4f")), + ('ergas_model', AverageMeter("ERGAS", ":4.4f")), + ]) + + +def main(val_dloader, cfg, save_metrics=True): + model = load_eval_method(cfg) + print('build eval metrics') + metrics = build_eval_metrics(cfg) + result = validate( + model, val_dloader, metrics, cfg.epoch, None, 'test', cfg) + if save_metrics: + do_save_metrics(result, cfg) + return result + + +def get_result_filename(cfg): + output_dir = os.path.join(cfg.output, 'eval') + os.makedirs(output_dir, exist_ok=True) + return os.path.join(output_dir, 'results-{:02d}.pt'.format(cfg.epoch)) + + +def do_save_metrics(metrics, cfg): + filename = get_result_filename(cfg) + print('save results {}'.format(filename)) + torch.save({ + 'epoch': cfg.epoch, + 'metrics': OrderedDict([ + (k, v.avg) for k, v in metrics.items() + ]) + }, filename) + + +def load_metrics(cfg): + filename = get_result_filename(cfg) + print('load results {}'.format(filename)) + result = torch.load(filename) + # check if epoch corresponds + assert result['epoch'] == 
cfg.epoch + # build AVG objects + avg_metrics = build_avg_metrics() + for k, v in result['metrics'].items(): + avg_metrics[k].avg = v + return avg_metrics + + +def print_metrics(metrics): + names = [] + values = [] + for i, v in enumerate(metrics.values()): + try: + names.append(v.name) + values.append(v) + except AttributeError: + # skip for retrocompatibility + pass + print(*names) + print(*values) + + +def load_eval_method(cfg): + if cfg.eval_method is None: + vis = cfg.visualize + model = load_fun(vis.get('model'))(cfg) + # Load model state dict + try: + checkpoint = load_checkpoint(cfg) + _, _ = load_fun(vis.get('checkpoint'))(model, checkpoint) + except Exception as e: + print(e) + exit(0) + + return model + + print('load non-dl upsampler: {}'.format(cfg.eval_method)) + return NonDLEvalMethod(cfg) + + +class NonDLEvalMethod(object): + def __init__(self, cfg): + self.upscale_factor = cfg.metrics.upscale_factor + self.upsampler = Upsample( + scale_factor=self.upscale_factor, + mode=cfg.eval_method) + + def __call__(self, x): + return self.upsampler(x) + + def eval(self): + pass + + def train(self): + pass + + def to(self, device): + return self diff --git a/src/visualize.py b/src/visualize.py new file mode 100644 index 0000000..f97ff35 --- /dev/null +++ b/src/visualize.py @@ -0,0 +1,101 @@ +import os +import torch +import matplotlib.pyplot as plt +import imgproc + +from tqdm import tqdm + +from validation import build_eval_metrics, load_eval_method +from utils import load_fun + + +def plot_images(img, out_path, basename, fname, dpi): + # dir: out_dir / basename + out_path = os.path.join(out_path, basename) + if not os.path.exists(out_path): + os.makedirs(out_path) + + img = img.squeeze(0) + for i in range(img.shape[0]): + plt.imshow(imgproc.tensor_to_image(img[i].detach(), False, False)) + plt.axis('off') + out_fname = os.path.join(out_path, '{}_{}.png'.format(fname, i)) + plt.savefig(out_fname, dpi=dpi) + plt.close() + + +def main(cfg): + # Load model state dict + model = load_eval_method(cfg) + + # manipulation function for data + use_minmax = cfg.dataset.get('stats').get('use_minmax', False) + denorm = load_fun(cfg.dataset.get('denorm'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name) + evaluable = load_fun(cfg.dataset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + printable = load_fun(cfg.dataset.get('printable'))( + cfg, + hr_name=cfg.dataset.hr_name, + lr_name=cfg.dataset.lr_name, + filter_outliers=False, + use_minmax=use_minmax) + + # Create a folder of super-resolution experiment results + out_path = os.path.join(cfg.output, 'images_{}'.format(cfg.epoch)) + if not os.path.exists(out_path): + os.makedirs(out_path) + + # move to device + model.to(cfg.device) + model.eval() + # Load dataset + load_dataset_fun = load_fun(cfg.dataset.get( + 'load_dataset', 'datasets.sen2venus.load_dataset')) + _, val_dloader, _ = load_dataset_fun(cfg, only_test=True) + # Define metrics for evaluate each image + metrics = build_eval_metrics(cfg) + + indices = dict.fromkeys(metrics.keys(), None) + iterations = min(cfg.num_images, len(val_dloader)) + + with torch.no_grad(): + for index, batch in tqdm( + enumerate(val_dloader), total=iterations, + desc='%d Images' % (iterations)): + + if index >= iterations: + break + + hr = batch["hr"].to(device=cfg.device, non_blocking=True) + lr = batch["lr"].to(device=cfg.device, non_blocking=True) + + sr = model(lr) + + if not torch.is_tensor(sr): + 
sr, _ = sr
+
+            # denormalize back to the original value range
+            hr_dn, sr_dn, lr_dn = denorm(hr, sr, lr)
+
+            # normalize to [0, 1], keeping outliers, for metric evaluation
+            hr_eval, sr_eval, lr_eval = evaluable(hr_dn, sr_dn, lr_dn)
+            # compute per-image metrics (only the value of the last image is kept in indices)
+            for k, fun in metrics.items():
+                for i in range(len(sr)):
+                    res = fun(sr_eval[i][None], hr_eval[i][None]).detach()
+                    indices[k] = res if not res.shape else res.squeeze(0)
+
+            # normalize to [0, 1] with outliers removed to get a printable version
+            hr, sr, lr = printable(hr_dn, sr_dn, lr_dn)
+            # plot LR, HR, SR and the absolute error map
+            plot_images(lr, out_path, 'lr', index, cfg.dpi)
+            plot_images(hr, out_path, 'hr', index, cfg.dpi)
+            plot_images(sr, out_path, 'sr', index, cfg.dpi)
+            plot_images((hr - sr).abs(), out_path, 'delta', index, cfg.dpi)
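+
+
+# --- Illustrative usage sketch (not part of the pipeline above): a minimal
+# smoke test of the SwinFIR generator defined alongside swinfir_utils.py.
+# The import path and every hyper-parameter below are assumptions that form
+# one valid lightweight x2 configuration, not the project's actual settings.
+if __name__ == '__main__':
+    import torch
+    from super_res.swinfir import SwinFIR  # assumed module path
+
+    model = SwinFIR(
+        upscale=2,
+        img_size=64,
+        in_chans=3,
+        window_size=8,            # 64 is divisible by 8, so windows tile exactly
+        embed_dim=60,
+        depths=(6, 6, 6, 6),
+        num_heads=(6, 6, 6, 6),
+        mlp_ratio=2,
+        upsampler='pixelshuffledirect')
+    model.eval()
+
+    with torch.no_grad():
+        lr = torch.rand(1, 3, 64, 64)    # dummy low-resolution RGB patch
+        sr = model(lr)
+    print(sr.shape)                      # expected: torch.Size([1, 3, 128, 128])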