diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a9d15ac7..8e183866 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -11,17 +11,16 @@ on: jobs: Tests: runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10"] steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v2 + - uses: actions/checkout@v3 + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 with: - python-version: 3.8.12 - - uses: syphar/restore-virtualenv@v1 - id: cnetenv - with: - requirement_files: setup.cfg - - uses: syphar/restore-pip-download-cache@v1 - if: steps.cnetenv.outputs.cache-hit != 'true' + python-version: ${{ matrix.python-version }} + cache: pip # caching pip dependencies based on changes to pyproject.toml - name: Install GDAL binaries run: | # Temporary? dpkg fix: https://askubuntu.com/questions/1276111/error-upgrading-grub-efi-amd64-signed-special-device-old-ssd-does-not-exist @@ -39,19 +38,19 @@ jobs: - name: Install Python packages run: | # Install Python GDAL - pip install -U pip setuptools wheel - pip install -U Cython "numpy<=1.21.0" + pip install -U pip + pip install -U setuptools wheel + pip install -U numpy==1.24.4 + pip install setuptools==57.5.0 GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') pip install GDAL==$GDAL_VERSION --no-binary=gdal - name: Install PyTorch run: | TORCH_CPU="https://download.pytorch.org/whl/cpu" - TORCH_VERSION="1.13.0" - pip install --upgrade --no-cache-dir setuptools>=0.59.5 - pip install torch==${TORCH_VERSION} torchvision torchaudio --extra-index-url $TORCH_CPU - PYG_TORCH_CPU="https://data.pyg.org/whl/torch-${TORCH_VERSION}+cpu.html" - pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f $PYG_TORCH_CPU - if: steps.cnetenv.outputs.cache-hit != 'true' + TORCH_VERSION="2.2.2" + pip install -U --no-cache-dir setuptools>=65.5.1 + pip install torch==${TORCH_VERSION} torchvision==0.17.2 torchaudio==${TORCH_VERSION} --extra-index-url $TORCH_CPU + pip install natten==0.17.1+torch220cpu -f https://shi-labs.com/natten/wheels - name: Install cultionet run: | pip install . @@ -60,52 +59,3 @@ jobs: pip install pytest cd tests/ python -m pytest - -# Version: -# needs: Tests -# if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') -# runs-on: ubuntu-latest -# concurrency: release -# steps: -# - uses: actions/checkout@v2 -# with: -# fetch-depth: 0 -# token: ${{ secrets.CULTIONET_TOKEN }} -# - uses: actions/setup-python@v2 -# with: -# python-version: 3.8.12 -# - name: Python Semantic Release -# run: | -# python -m pip install python-semantic-release -# # Add credentials -# git config user.name "github-actions" -# git config user.email "github-actions@github.com" -# # Bump cultionet version -# semantic-release publish -# env: -# GH_TOKEN: ${{ secrets.CULTIONET_TOKEN }} - -# # https://github.com/fnkr/github-action-ghr -# Release: -# needs: Version -# if: github.event_name == 'push' && github.ref == 'refs/heads/main' && !contains(github.event.head_commit.message, 'chore(release):') -# name: Upload release -# runs-on: ubuntu-latest -# steps: -# - uses: actions/checkout@v2 -# - uses: actions/setup-python@v2 -# with: -# python-version: 3.8.12 -# - uses: syphar/restore-virtualenv@v1 -# id: cnetenv -# with: -# requirement_files: setup.cfg -# - uses: syphar/restore-pip-download-cache@v1 -# if: steps.cnetenv.outputs.cache-hit != 'true' -# - name: Checkout -# uses: fnkr/github-action-ghr@v1 -# if: startsWith(github.ref, 'refs/tags/') -# env: -# GHR_PATH: . -# GHR_COMPRESS: gz -# GITHUB_TOKEN: ${{ secrets.CULTIONET_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3d6302c9..270d3072 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,12 +3,14 @@ repos: rev: 22.3.0 hooks: - id: black + exclude: src/cultionet/utils/stats.py language_version: python3 args: [--skip-string-normalization] - repo: https://github.com/myint/docformatter rev: v1.4 hooks: - id: docformatter + exclude: src/cultionet/utils/stats.py args: [ --in-place, --wrap-summaries, @@ -20,3 +22,11 @@ repos: rev: 6.0.0 hooks: - id: flake8 + exclude: src/cultionet/utils/stats.py + - repo: https://github.com/pycqa/isort + rev: 5.11.5 + hooks: + - id: isort + exclude: src/cultionet/utils/stats.py + name: isort (python) + args: [--settings-path=pyproject.toml] diff --git a/README.md b/README.md index 25bf792e..34038f11 100644 --- a/README.md +++ b/README.md @@ -1,66 +1,65 @@ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) -[![python](https://img.shields.io/badge/Python-3.8%20%7C%203.9-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org) -[![](https://img.shields.io/github/v/release/jgrss/cultionet?display_name=release)](https://github.com/jgrss/cultionet/releases) +[![python](https://img.shields.io/badge/Python-3.9%20%7C%203.10-3776AB.svg?style=flat&logo=python&logoColor=white)](https://www.python.org) + +![](https://img.shields.io/badge/Version-v2.0.0b-8A2BE2) [![](https://github.com/jgrss/cultionet/actions/workflows/ci.yml/badge.svg)](https://github.com/jgrss/cultionet/actions?query=workflow%3ACI) -**cultionet** is a library for semantic segmentation of cultivated land using a neural network. There are various model configurations that can -be used in **cultionet**, but the base architecture is [UNet 3+](https://arxiv.org/abs/2004.08790) with [multi-stream decoders](https://arxiv.org/abs/1902.04099). The library is built on **[PyTorch Lightning](https://www.pytorchlightning.ai/)** and the segmentation objectives (class targets and losses) were designed following [previous work in the remote sensing community](https://www.sciencedirect.com/science/article/abs/pii/S0034425720301115). +## Cultionet -Below are highlights of **cultionet**: +Cultionet is a library for semantic segmentation of cultivated land with a neural network. The base architecture is a UNet variant, inspired by [UNet 3+](https://arxiv.org/abs/2004.08790) and [Psi-Net](https://arxiv.org/abs/1902.04099), with convolution blocks following [ResUNet-a](https://arxiv.org/abs/1904.00592). The library is built on [PyTorch Lightning](https://www.pytorchlightning.ai/) and the segmentation objectives (class targets and losses) were designed following [previous work in the remote sensing community](https://www.sciencedirect.com/science/article/abs/pii/S0034425720301115). -1. satellite image time series instead of individual dates for training and inference -2. [UNet 3+](https://arxiv.org/abs/2004.08790) [Psi](https://arxiv.org/abs/1902.04099) residual convolution (`ResUNet3Psi`) architecture -3. [Spatial-channel attention](https://www.mdpi.com/2072-4292/14/9/2253) -4. [Tanimoto loss](https://www.mdpi.com/2072-4292/13/18/3707) -5. Deep supervision and temporal features with [RNN STAR](https://www.sciencedirect.com/science/article/pii/S0034425721003230) -6. Deep, multi-output supervision +Key features of Cultionet: -## The cultionet input data +* uses satellite image time series instead of individual dates for training and inference +* uses a [Transformer](https://arxiv.org/abs/1706.03762) time series embeddings +* uses a UNet architecture with skip connections and deep supervision similar to [UNet 3+](https://arxiv.org/abs/2004.08790) +* uses multi-stream outputs inspired by [Psi-Net](https://arxiv.org/abs/1902.04099) +* uses residual [ResUNet-a](https://arxiv.org/abs/1904.00592) blocks with [Dilated Neighborhood Attention](https://arxiv.org/abs/2209.15001) +* uses the [Tanimoto loss](https://www.mdpi.com/2072-4292/13/18/3707) -The model inputs are satellite time series (e.g., bands or spectral indices). Data are stored in a [PyTorch Geometric Data object](https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.Data). For example, **cultionet** datasets will have data -that look something like the following. +## Install Cultionet -```python -from torch_geometric.data import Data +If PyTorch is installed -Data( - x=[10000, 65], y=[10000], bdist=[10000], - height=100, width=100, ntime=13, nbands=5, - zero_padding=0, start_year=2020, end_year=2021, - left=, bottom=, - right=, top=, - res=10.0, train_id='{site id}_2021_1_none', num_nodes=10000 -) +```commandline +pip install git@github.com:jgrss/cultionet.git ``` -where +See the [installation section](#installation) for more detailed instructions. -``` -x = input features = torch.Tensor of (samples x bands*time) -y = labels = torch.Tensor of (samples,) -bdist = distance transform = torch.Tensor of (samples,) -height = image height/rows = int -width = image width/columns = int -ntime = image time dimensions/sequence length = int -nbands = image band dimensions/channels = int -left = image left coordinate bounds = float -bottom = image bottom coordinate bounds = float -right = image right coordinate bounds = float -top = image top coordinate bounds = float -res = image spatial resolution = float -train_id = image id = str +--- + +## Data format + +The model inputs are satellite time series (e.g., bands or spectral indices). Data are stored in a PyTorch [Data](https://github.com/jgrss/cultionet/blob/99fb16797f2d84b812c47dd9d03aea92b6b7aefa/src/cultionet/data/data.py#L51) object. For example, Cultionet datasets will have data that look something like the following. + +```python +Data( + x=[1, 3, 12, 100, 100], y=[1, 100, 100], bdist=[1, 100, 100], + start_year=torch.tensor([2020]), end_year=torch.tensor([2021]), + left=torch.tensor([]), bottom=torch.tensor([]), + right=torch.tensor([]), top=torch.tensor([]), + res=torch.tensor([10.0]), batch_id=['{site id}_2021_1_none'], +) ``` -As an example, for a time series of red, green, blue, and NIR with 25 time steps (bi-weekly + 1 additional end point), -the data would be shaped like: +where ``` -x = [[r_w1, ..., r_w25, g_w1, ..., g_wN, b_w1, ..., b_wN, n_w1, ..., n_wN]] +x = input features = torch.Tensor of (batch x channels/bands x time x height x width) +y = labels = torch.Tensor of (batch x height x width) +bdist = distance transform = torch.Tensor of (batch x height x width) +left = image left coordinate bounds = torch.Tensor +bottom = image bottom coordinate bounds = torch.Tensor +right = image right coordinate bounds = torch.Tensor +top = image top coordinate bounds = torch.Tensor +res = image spatial resolution = torch.Tensor +batch_id = image id = list ``` -## Create train dataset +## Datasets -### Create the training data +### Create the vector training dataset Training data pairs should consist of two files per grid/year. One file is a polygon vector file (stored as a GeoPandas-compatible format like GeoPackage) of the training grid for a region. The other file is a polygon vector file (stored in the same format) @@ -75,76 +74,63 @@ of the training labels for a grid. **What is a training label?** > Training labels are __polygons__ of delineated cropland (i.e., crop fields). The training labels will be clipped to the -> training grid (described above). Thus, it is important to exhaustively digitize all crop fields within a grid. +> training grid (described above). Thus, it is important to digitize all crop fields within a grid unless data are to be used +> for partial labels. **Configuration file** -> The configuration file (`cultionet/scripts/config.yml`) is used to create training datasets. This file is only meant -> to be a template. For each project, copy this template and modify it accordingly. - -* image_vis - * A list of image indices to use for training. -* regions - * The start and end range of the training regions to use in the dataset. -* years - * A list of years to use in the training dataset. Image years correspond to the _end_ period of the time series. - Thus, 2021 would align with a time 2020-2021 series. +> The configuration file is used to create training datasets. Copy the [config template](scripts/config.yml) and modify it accordingly. **Training data requirements** > The polygon vector file should have a field with values for crop fields set equal to 1. Other crop classes are allowed and > can be recoded during the data creation step. However, the current version of cultionet expects the final data to be binary -> (i.e., 0=non-cropland; 1=cropland). For grids with all null data (i.e., non-crop), simply create an empty grid file. +> (i.e., 0=non-cropland; 1=cropland). For grids with all null data (i.e., non-crop), simply create a grid file with no intersecting +> crop polygons. **Training name requirements** -> The polygon/grid pairs should be named with the format **{region}_{poly}_{year}.gpkg**. The region name can be any string -> or integer. However, integers should have six character length (e.g., the region might correspond to grid 1 and be -> named '000001_poly_2020.gpkg'. +> There are no requirements. Simply specify the paths in the configuration file. -Example directory structure and format for training data. For a single AOI, there is a grid file and a polygon file. The -number of grid/polygon pairs is unlimited. +Example directory structure and format for training data. For each region, there is a grid file and a polygon file. The +number of grid/polygon pairs within the region is unlimited. ```yaml -project_dir: - user_train: - '{region}_grid_{time_series_end_year}.gpkg' - '{region}_poly_{time_series_end_year}.gpkg' +region_id_file: + - /user_data/training/grid_REGION_A_YEAR.gpkg + - /user_data/training/grid_REGION_B_YEAR.gpkg + - ... + +polygon_file: + - /user_data/training/crop_polygons_REGION_A_YEAR.gpkg + - /user_data/training/crop_polygons_REGION_B_YEAR.gpkg + - ... ``` -Using the format above, a train directory might look like: +The grid file should contain polygons of the AOIs. The AOIs represent the area that imagery will be clipped and masked to (only 1 km x 1 km has been tested). Required +columns include 'geo_id' and 'year', which are a unique identifier and the sampling year, respectively. -```yaml -project_dir: - user_train: - 'site1_grid_2021.gpkg' - 'site1_poly_2021.gpkg' - 'site1_grid_2022.gpkg' - 'site1_poly_2022.gpkg' - 'site2_grid_2020.gpkg' - 'site2_poly_2020.gpkg' - ... +```python +grid_df = gpd.read_file("/user_data/training/grid_REGION_A_YEAR.gpkg") +grid_df.head(2) + + geo_id year geometry +0 REGION_A_e3a4f2346f50984d87190249a5def1d0 2021 POLYGON ((... +1 REGION_A_18485a3271482f2f8a10bb16ae59be74 2021 POLYGON ((... ``` -or +The polygon file should contain polygons of field boundaries, with a column for the crop class. Any number of other columns can be included. Note that polygons do not need to be clipped to the grids. -```yaml -project_dir: - user_train: - '000001_grid_2021.gpkg' - '000001_poly_2021.gpkg' - '000001_grid_2022.gpkg' - '000001_poly_2022.gpkg' - '000002_grid_2020.gpkg' - '000002_poly_2020.gpkg' - ... +```python +import geopandas as gpd +poly_df = gpd.read_file("/user_data/training/crop_polygons_REGION_A_YEAR.gpkg") +poly_df.head(2) + crop_class geometry +0 1 POLYGON ((... +1 1 POLYGON ((... ``` -> **Note:** a site can have multiple grid/polygon pairs if collected across different timeframes - ### Create the image time series -This must be done outside of **cultionet**. Essentially, a directory with band or VI time series must be generated before -using **cultionet**. - -> **Note:** it is expected that the time series have length greater than 1 +This must be done outside of Cultionet. Essentially, a directory with band or VI time series must be generated before +using Cultionet. - The raster files should be stored as GeoTiffs with names that follow a date format (e.g., `yyyyddd.tif` or `yyymmdd.tif`). - The date format can be specified at the CLI. @@ -152,43 +138,42 @@ using **cultionet**. - Just note that a higher frequency will result in larger memory footprints for the GPU, plus slower training and inference. - While there is no requirement for the time series frequency, time series _must_ have different start and end years. - For example, a northern hemisphere time series might consist of (1 Jan 2020 to 1 Jan 2021) whereas a southern hemisphere time series might range from (1 July 2020 to 1 July 2021). In either case, note that something like (1 Jan 2020 to 1 Dec 2020) will not work. -- The years in the directories must align with the training data files. More specifically, the training data year (year in the polygon/grid pairs) should correspond to the time series end year. - - For example, a file named `000001_poly_2020.gpkg` should be trained on 2019-2020 imagery, while `000001_poly_2022.gpkg` would match a 2021-2022 time series. +- Time series should align with the training data files. More specifically, the training data year (year in the grid vector file) should correspond to the time series start year. + - For example, a training grid 'year' column equal to 2022 should be trained on a 2022-2023 image time series. - The image time series footprints (bounding box) can be of any size, but should encompass the training data bounds. During data creation (next step below), only the relevant bounds of the image are extracted and matched with the training data using the training grid bounds. -**Example time series directory with bi-weekly cadence for three VIs (i.e., evi2, gcvi, kndvi)** +Example time series directory with bi-weekly cadence for three VIs (i.e., evi2, gcvi, kndvi) ```yaml project_dir: time_series_vars: - region: + grid_id_a: evi2: - 2020001.tif - 2020014.tif - ... - 2021001.tif - 2021014.tif - ... 2022001.tif + 2022014.tif + ... + 2023001.tif gcvi: kndvi: + grid_id_b: + ``` -### Create the time series training data +### Create the time series training dataset After training data and image time series have been created, the training data PyTorch files (.pt) can be generated using the commands below. -> **Note:** Modify a copy of `cultionet/scripts/config.yml` as needed. +> **Note:** Modify a copy of the [config template](scripts/config.yml) as needed and save in the project directory. The command below assumes image time series are saved under `/project_dir/time_series_vars`. The training polygon and grid paths are taken from the config.yml file. + +This command would generate .pt files with image time series of 100 x 100 height/width and a spatial resolution of 10 meters. ```commandline -# Navigate to the cultionet script directory. -cd cultionet/scripts/ -# Activate the virtual environment. See installation section below for environment details. +# Activate your virtual environment. See installation section below for environment details. pyenv venv venv.cultionet # Create the training dataset. -(venv.cultionet) cultionet create --project-path /project_dir --grid-size 100 100 --config-file config.yml +(venv.cultionet) cultionet create --project-path /project_dir --grid-size 100 100 --destination train -r 10.0 --max-crop-class 1 --crop-column crop_class --image-date-format %Y%m%d --num-workers 8 --config-file config.yml ``` The output .pt data files will be stored in `/project_dir/data/train/processed`. Each .pt data file will consist of @@ -196,11 +181,10 @@ all the information needed to train the segmentation model. ## Training a model -To train the model, you will need to create the train dataset object and pass it to the **cultionet** fit method. A script -is provided to help ease this process. To train a model on a dataset, use (as an example): +To train a model on a dataset, use (as an example): ```commandline -(venv.cultionet) cultionet train --project-path /project_dir --val-frac 0.2 --random-seed 500 --batch-size 4 --epochs 30 --filters 32 --device gpu --patience 5 --learning-rate 0.001 --reset-model +(venv.cultionet) cultionet train --val-frac 0.2 --augment-prob 0.5 --epochs 100 --hidden-channels 32 --processes 8 --load-batch-workers 8 --batch-size 4 --accumulate-grad-batches 4 --dropout 0.2 --deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 --weight-decay 1e-4 --attention-weights natten ``` For more CLI options, see: @@ -209,7 +193,7 @@ For more CLI options, see: (venv.cultionet) cultionet train -h ``` -After a model has been fit, the last checkpoint file can be found at `/project_dir/ckpt/last.ckpt`. +After a model has been fit, the best/last checkpoint file can be found at `/project_dir/ckpt/last.ckpt`. ## Predicting on an image with a trained model @@ -228,164 +212,47 @@ After a model has been fit, the last checkpoint file can be found at `/project_d ## Installation -### (Option 1) Build Docker images - -If using a GPU with CUDA 11.3, see the **cultionet** [Dockerfile](https://github.com/jgrss/cultionet/blob/main/Dockerfile) -and [dockerfiles/README.md](https://github.com/jgrss/cultionet/blob/main/dockerfiles/README.md) to build a Docker image. - -If installing from scratch locally, see the instructions below. - -### (Option 2) Install with Conda Mamba on a CPU - -#### 1) Create a Conda `environment.yml` file with: - -```yaml -name: venv.cnet -channels: -- defaults -dependencies: -- python=3.8.12 -- libgcc -- libspatialindex -- libgdal=3.4.1 -- gdal=3.4.1 -- numpy>=1.22.0 -- pip -``` - -#### 2) Install Python packages +#### Install Cultionet (assumes a working CUDA installation) +1. Create a new virtual environment (example using [pyenv](https://github.com/pyenv/pyenv)) ```commandline -conda install -c conda-forge mamba -conda config --add channels conda-forge -mamba env create --file environment.yml -conda activate venv.cnet -(venv.cnet) mamba install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch -(venv.cnet) mamba install pyg -c pyg -(venv.cnet) pip install -U pip setuptools wheel -(venv.cnet) pip install cultionet@git+https://github.com/jgrss/cultionet.git +pyenv virtualenv 3.10.14 venv.cultionet +pyenv activate venv.cultionet ``` -### (Option 3) Install with pip on a CPU - -This section assumes you have all the necessary Linux builds, such as GDAL. If not, see the next installation section. - -#### Install Python packages - -```commandline -pyenv virtualenv 3.8.12 venv.cnet -pyenv activate venv.cnet -(venv.cnet) pip install -U pip setuptools wheel numpy cython -(venv.cnet) pip install gdal==$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') --no-binary=gdal -(venv.cnet) pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu -(venv.cnet) TORCH_VERSION=$(python -c "import torch;print(torch.__version__)") -(venv.cnet) pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-${TORCH_VERSION}.html -(venv.cnet) pip install cultionet@git+https://github.com/jgrss/cultionet.git -``` - -### (Option 4) Install CUDA and built GPU packages - -1. Install NVIDIA driver (skip if using the CPU) - -```commandline -sudo add-apt-repository ppa:graphics-drivers/ppa -sudo apt-get update -sudo apt install ubuntu-drivers-common -ubuntu-drivers devices -sudo apt install nvidia-driver-465 -``` - -`reboot machine` - -2. Install CUDA toolkit (skip if using the CPU) -> See https://developer.nvidia.com/cuda-11.3.0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_local - -`reboot machine` - -3. Install Pyenv -> See https://github.com/pyenv/pyenv/wiki#suggested-build-environment - -```commandline -curl https://pyenv.run | bash - -sudo apt-get update; sudo apt-get install make build-essential libssl-dev zlib1g-dev \ -libbz2-dev libreadline-dev libsqlite3-dev wget curl llvm \ -libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev -``` - -4. Add to the ~/.bashrc: +2. Update install numpy and Python GDAL (assumes GDAL binaries are already installed) ```commandline -echo 'export PYENV_ROOT="$HOME/.pyenv" -export PATH="$PYENV_ROOT/bin:$PATH" -if which pyenv > /dev/null; then eval "$(pyenv init --path)"; fi -if which pyenv > /dev/null; then eval "$(pyenv init -)"; fi -if which pyenv > /dev/null; then eval "$(pyenv virtualenv-init -)"; fi' >> ~/.bashrc -source ~/.bashrc -``` -Then run ~/.bashrc -``` commandline -source ~/.bashrc -``` - -5. Install new version of Python -```commandline -pyenv install 3.8.12 -``` - -6. Create a new virtual environment -```commandline -pyenv virtualenv 3.8.12 venv.cultionet -``` - -7. Install libraries -```commandline -pyenv activate venv.seg +(venv.cultionet) pip install -U pip +(venv.cultionet) pip install -U setuptools wheel +pip install -U numpy==1.24.4 +(venv.cultionet) pip install setuptools==57.5.0 +(venv.cultionet) GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') +(venv.cultionet) pip install GDAL==$GDAL_VERSION --no-binary=gdal ``` -8. Update install libraries +3. Install PyTorch 2.2.1 for CUDA 11.4 and 11.8 ```commandline -(venv.cultionet) pip install -U pip setuptools wheel "cython>=0.29.*" "numpy<=1.21.0" -# required to build GDAL Python bindings for 3.2.1 -(venv.cultionet) pip install --upgrade --no-cache-dir "setuptools<=58.*" +(venv.cultionet) pip install -U --no-cache-dir setuptools>=65.5.1 +(venv.cultionet) pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2 --index-url https://download.pytorch.org/whl/cu118 ``` -9. Install PyTorch -> See https://pytorch.org/get-started/locally/ -```commandline -(venv.cultionet) pip install torch==1.10.1+cu113 torchvision==0.11.2+cu113 torchaudio==0.10.1+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html -``` +The command below should print `True` if PyTorch can access a GPU. ```commandline python -c "import torch;print(torch.cuda.is_available())" ``` -10. Install PyTorch geometric dependencies +4. Install `natten` for CUDA 11.8 if using [neighborhood attention](https://github.com/SHI-Labs/NATTEN). ```commandline -(venv.cultionet) pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric torch-geometric-temporal -f https://data.pyg.org/whl/torch-1.10.1+cu113.html +(venv.cultionet) pip install natten==0.17.1+torch220cu118 -f https://shi-labs.com/natten/wheels ``` -11. Install GDAL -```commandline -sudo add-apt-repository ppa:ubuntugis/ppa -sudo apt install build-essential -sudo apt update -sudo apt install libspatialindex-dev libgdal-dev gdal-bin +5. Install cultionet -export CPLUS_INCLUDE_PATH=/usr/include/gdal -export C_INCLUDE_PATH=/usr/include/gdal -``` - -12. Install GDAL Python bindings ```commandline -(venv.cultionet) pip install GDAL==3.2.1 +(venv.cultionet) pip install git@github.com:jgrss/cultionet.git ``` -### Package - -Install **cultionet** +### Installing CUDA on Ubuntu -```commandline -git clone git@github.com:jgrss/cultionet.git -cd cultionet -(venv.cultionet) pip install . -``` +See [CUDA installation](docs/cuda_installation.md) diff --git a/dockerfiles/Dockerfile_cuda117_torch2.0 b/dockerfiles/Dockerfile_cuda117_torch2.0 new file mode 100644 index 00000000..75add9dd --- /dev/null +++ b/dockerfiles/Dockerfile_cuda117_torch2.0 @@ -0,0 +1,58 @@ +FROM nvidia/cuda:11.6.0-base-ubuntu20.04 + +# Install GDAL +RUN apt update -y && \ + apt upgrade -y && \ + apt install software-properties-common -y && \ + add-apt-repository ppa:ubuntugis/ubuntugis-unstable -y && \ + apt update -y && \ + apt install \ + build-essential \ + python3.8 \ + python3-pip \ + libgeos++-dev \ + libgeos-3.8.0 \ + libgeos-c1v5 \ + libgeos-dev \ + libgeos-doc \ + libspatialindex-dev \ + g++ \ + libgdal-dev \ + gdal-bin \ + libproj-dev \ + libspatialindex-dev \ + geotiff-bin \ + libgl1 \ + git -y + +ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" +ENV C_INCLUDE_PATH="/usr/include/gdal" +ENV LD_LIBRARY_PATH="/usr/local/lib" +ENV PATH="/root/.local/bin:$PATH" + +RUN pip install -U pip setuptools wheel +RUN pip install -U --no-cache-dir "setuptools>=59.5.0" +RUN pip install -U "Cython>=0.29.0,<3.0.0" numpy>=1.22.0 +RUN pip install intel-openmp + +# Install PyTorch Geometric and its dependencies +RUN pip install \ + torch>=2.0.0 \ + torchvision \ + torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 + +RUN TORCH_VERSION=`(python -c "import torch;print(torch.__version__)")` && + pip install \ + torch-scatter \ + torch-sparse \ + torch-cluster \ + torch-spline-conv \ + torch-geometric --extra-index-url https://data.pyg.org/whl/torch-${TORCH_VERSION}.html + +RUN GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') && \ + pip install GDAL==$GDAL_VERSION --no-binary=gdal + +# Install cultionet +RUN pip install --user cultionet@git+https://github.com/jgrss/cultionet.git@jgrss/transfer_ltae + +CMD ["cultionet"] diff --git a/dockerfiles/Dockerfile_cuda118_torch2.0 b/dockerfiles/Dockerfile_cuda118_torch2.0 new file mode 100644 index 00000000..c9eabe36 --- /dev/null +++ b/dockerfiles/Dockerfile_cuda118_torch2.0 @@ -0,0 +1,58 @@ +FROM nvidia/cuda:11.6.0-base-ubuntu20.04 + +# Install GDAL +RUN apt update -y && \ + apt upgrade -y && \ + apt install software-properties-common -y && \ + add-apt-repository ppa:ubuntugis/ubuntugis-unstable -y && \ + apt update -y && \ + apt install \ + build-essential \ + python3.8 \ + python3-pip \ + libgeos++-dev \ + libgeos-3.8.0 \ + libgeos-c1v5 \ + libgeos-dev \ + libgeos-doc \ + libspatialindex-dev \ + g++ \ + libgdal-dev \ + gdal-bin \ + libproj-dev \ + libspatialindex-dev \ + geotiff-bin \ + libgl1 \ + git -y + +ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" +ENV C_INCLUDE_PATH="/usr/include/gdal" +ENV LD_LIBRARY_PATH="/usr/local/lib" +ENV PATH="/root/.local/bin:$PATH" + +RUN pip install -U pip setuptools wheel +RUN pip install -U --no-cache-dir "setuptools>=59.5.0" +RUN pip install -U "Cython>=0.29.0,<3.0.0" numpy>=1.22.0 +RUN pip install intel-openmp + +# Install PyTorch Geometric and its dependencies +RUN pip install \ + torch>=2.0.0 \ + torchvision \ + torchaudio --extra-index-url https://download.pytorch.org/whl/cu118 + +RUN TORCH_VERSION=`(python -c "import torch;print(torch.__version__)")` && + pip install \ + torch-scatter \ + torch-sparse \ + torch-cluster \ + torch-spline-conv \ + torch-geometric --extra-index-url https://data.pyg.org/whl/torch-${TORCH_VERSION}.html + +RUN GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') && \ + pip install GDAL==$GDAL_VERSION --no-binary=gdal + +# Install cultionet +RUN pip install --user cultionet@git+https://github.com/jgrss/cultionet.git@jgrss/transfer_ltae + +CMD ["cultionet"] diff --git a/dockerfiles/Dockerfile_cuda121_torch2.1 b/dockerfiles/Dockerfile_cuda121_torch2.1 new file mode 100644 index 00000000..7822b494 --- /dev/null +++ b/dockerfiles/Dockerfile_cuda121_torch2.1 @@ -0,0 +1,50 @@ +FROM nvidia/cuda:12.1.0-base-ubuntu20.04 + +# Install GDAL +RUN apt update -y && \ + apt upgrade -y && \ + apt install software-properties-common -y && \ + add-apt-repository ppa:ubuntugis/ubuntugis-unstable -y && \ + apt update -y && \ + apt install \ + build-essential \ + python3.8 \ + python3-pip \ + libgeos++-dev \ + libgeos-3.8.0 \ + libgeos-c1v5 \ + libgeos-dev \ + libgeos-doc \ + libspatialindex-dev \ + g++ \ + libgdal-dev \ + gdal-bin \ + libproj-dev \ + libspatialindex-dev \ + geotiff-bin \ + libgl1 \ + git -y + +ENV CPLUS_INCLUDE_PATH="/usr/include/gdal" +ENV C_INCLUDE_PATH="/usr/include/gdal" +ENV LD_LIBRARY_PATH="/usr/local/lib" +ENV PATH="/root/.local/bin:$PATH" + +RUN pip install -U pip setuptools wheel +RUN pip install -U --no-cache-dir "setuptools>=59.5.0" +RUN pip install -U "Cython>=0.29.0,<3.0.0" numpy>=1.22.0 +RUN pip install intel-openmp + +# Install PyTorch Geometric and its dependencies +RUN pip install \ + torch==2.1.0 \ + torchvision \ + torchaudio --extra-index-url https://download.pytorch.org/whl/cu121 + +RUN GDAL_VERSION=$(gdal-config --version | awk -F'[.]' '{print $1"."$2"."$3}') && \ + pip install GDAL==$GDAL_VERSION --no-binary=gdal + +# Install cultionet +RUN pip install --user cultionet@git+https://github.com/jgrss/cultionet.git@jgrss/transfer_ltae + +CMD ["cultionet"] diff --git a/docs/cuda_installation.md b/docs/cuda_installation.md new file mode 100644 index 00000000..a831429b --- /dev/null +++ b/docs/cuda_installation.md @@ -0,0 +1,27 @@ +## (Option 1) Build Docker images + +If using a GPU with CUDA 11.3, see the cultionet [Dockerfile](https://github.com/jgrss/cultionet/blob/main/Dockerfile) +and [dockerfiles/README.md](https://github.com/jgrss/cultionet/blob/main/dockerfiles/README.md) to build a Docker image. + +If installing from scratch locally, see the instructions below. + +## (Option 2) Install locally with GPU + +### Install CUDA driver, if necessary + +1. Install NVIDIA driver + +```commandline +sudo add-apt-repository ppa:graphics-drivers/ppa +sudo apt-get update +sudo apt install ubuntu-drivers-common +ubuntu-drivers devices +sudo apt install nvidia-driver-465 +``` + +`reboot machine` + +2. Install CUDA toolkit +> See https://developer.nvidia.com/cuda-11.3.0-download-archive?target_os=Linux&target_arch=x86_64&Distribution=Ubuntu&target_version=20.04&target_type=deb_local + +`reboot machine` \ No newline at end of file diff --git a/docs/requirements.txt b/docs/requirements.txt index ceb2bdbe..01840c6b 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -25,7 +25,7 @@ tensorboard>=2.2.0 PyYAML>=5.1 geowombat@git+https://github.com/jgrss/geowombat.git tsaug@git+https://github.com/jgrss/tsaug.git -setuptools==59.5.0 +setuptools>=70 numpydoc sphinx sphinx-automodapi diff --git a/pyproject.toml b/pyproject.toml index a9626824..639eb319 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,8 @@ [build-system] requires = [ - 'setuptools>=65.5.1', + 'setuptools>=70', 'wheel', - 'numpy>=1.22.0', - 'Cython>=0.29.0,<3.0.0', + 'numpy<2,>=1.22', ] [tool.black] @@ -24,3 +23,11 @@ exclude = ''' | dist )/ ''' + +[tool.isort] +line_length = 79 +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +ensure_newline_before_comments = true diff --git a/scripts/config.yml b/scripts/config.yml index 5011b617..4ccc8e29 100644 --- a/scripts/config.yml +++ b/scripts/config.yml @@ -1,21 +1,51 @@ +# The bands or band indices to use +# +# Time series should be store din `time_series_vars` under the project directory. +# E.g., +# time_series_vars//evi2/20210101.tif +# ... +# time_series_vars//evi2/20220101.tif +# +# See the README.md for more details. image_vis: + - avi - evi2 - gcvi - kndvi + - sipi1 -# The regions to process (start, end) -regions: - - 161 - - 171 +# The training region/grid file path +# +# This file should contain each training site as a polygon feature. +# There should be a 'geo_id' column that contains the unique site id. This site id should +# match the corresponding time series variables (e.g., time_series_vars//evi2/20210101.tif). +# +# geo_id year geometry +# 0 site_id_1 2019 POLYGON ((... +# ... ... ... ... +# N site_id_n 2021 POLYGON ((... +region_id_file: + - /home/grids-train.gpkg -# End years (i.e., 2020 = 2019 planting/harvest year) -# 2019 = 2018 CDL -# 2020 = 2019 CDL -# 2021 = 2020 CDL -# 2022 = 2021 CDL -years: - - 2020 - - 2021 - - 2022 +# The training field/boundaries file path +# +# This file should field polygons. The fields do not need to be clipped to the grids. +# Note that grids with no intersecting polygons will be used in training as treated as +# entirely non-cultivated (i.e., all zeros). There should be a column that defines the +# crop class. For a binary model (crop|not crop), this column can be filled with 1s. +# +# geo_id year crop geometry +# 0 poly_id_1 2020 1 POLYGON ((... +# ... ... ... ... ... +# N poly_id_n 2019 1 POLYGON ((... +polygon_file: + - /home/fields-train.gpkg -lc_path: !!null +# Each year in `region_id_file` should correspond to the year of harvest +# For US harvest year 2019, an end date of 12-31 would mean 2019-01-01 to 2020-01-01 +# For Argentina harvest year 2019, an end date of 07-01 would mean 2018-07-01 to 2019-07-01 +start_mmdd: '01-01' +end_mmdd: '12-31' + +# The length of the time series +num_months: 12 diff --git a/scripts/move_and_reshape_data.py b/scripts/move_and_reshape_data.py new file mode 100644 index 00000000..5da2d306 --- /dev/null +++ b/scripts/move_and_reshape_data.py @@ -0,0 +1,90 @@ +import argparse +from pathlib import Path + +import joblib +import torch +from einops import rearrange +from tqdm import tqdm + +from cultionet.data import Data + + +def reshape_batch(filename: Path) -> Data: + # Load the old file + batch = joblib.load(filename) + + batch_x = rearrange( + batch.x, + '(h w) (c t) -> 1 c t h w', + c=batch.nbands, + t=batch.ntime, + h=batch.height, + w=batch.width, + ) + batch_y = rearrange( + batch.y, '(h w) -> 1 h w', h=batch.height, w=batch.width + ) + batch_bdist = rearrange( + batch.bdist, '(h w) -> 1 h w', h=batch.height, w=batch.width + ) + + return Data( + x=batch_x, + y=batch_y, + bdist=batch_bdist, + start_year=torch.tensor([batch.start_year]).long(), + end_year=torch.tensor([batch.end_year]).long(), + left=torch.tensor([batch.left]).float(), + bottom=torch.tensor([batch.bottom]).float(), + right=torch.tensor([batch.right]).float(), + top=torch.tensor([batch.top]).float(), + res=torch.tensor([batch.res]).float(), + batch_id=[batch.train_id], + ) + + +def read_and_move( + input_data_path: str, + output_data_path: str, +): + input_data_path = Path(input_data_path) + output_data_path = Path(output_data_path) + output_data_path.mkdir(parents=True, exist_ok=True) + + # Get raw data only + data_list = list(input_data_path.glob("*_none.pt")) + + for fn in tqdm(data_list, desc='Moving files'): + new_batch = reshape_batch(fn) + new_batch.to_file(output_data_path / fn.name) + + +def main(): + parser = argparse.ArgumentParser( + description="Move and reshape data batches", + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--input-data-path", + dest="input_data_path", + help="The input path of data to reshape (default: %(default)s)", + default=None, + ) + parser.add_argument( + "--output-data-path", + dest="output_data_path", + help="The output path of reshaped data (default: %(default)s)", + default=None, + ) + + args = parser.parse_args() + + read_and_move( + input_data_path=args.input_data_path, + output_data_path=args.output_data_path, + ) + + +if __name__ == '__main__': + main() diff --git a/setup.cfg b/setup.cfg index c6f12693..e02e65b8 100644 --- a/setup.cfg +++ b/setup.cfg @@ -12,10 +12,11 @@ version = attr: cultionet.__version__ classifiers = Intended Audience :: Science/Research Topic :: Scientific :: Agriculture + Topic :: Scientific :: Cropland Topic :: Scientific :: Neural Network + Topic :: Scientific :: Time series Topic :: Scientific :: Segmentation - Programming Language :: Cython - Programming Language :: Python :: 3.8 :: 3.9 + Programming Language :: Python :: 3.9 :: 3.10 [options] package_dir= @@ -23,42 +24,47 @@ package_dir= packages=find: include_package_data = True setup_requires = - setuptools>=65.5.1 - Cython>=0.29.0,<3.0.0 - numpy>=1.22.0 + setuptools>=70 + wheel + numpy<2,>=1.22 python_requires = - >=3.8.0,<4.0.0 + >=3.9,<3.11 install_requires = - attrs>=21.0 - frozendict>=2.2.0 - frozenlist>=1.3.0 - numpy>=1.22.0 - scipy>=1.5.0 - pandas>=1.0.0,<=1.3.5 - geopandas>=0.10.0 - rasterio - shapely>=1.8.0 - scikit-image>=0.19.0 - xarray>=2022.6.0 - opencv-python>=4.5.5.0 + attrs>=21 + dask>=2024.8.0 + distributed>=2024.8.0 + xarray>=2024.7.0 + frozendict>=2.2 + frozenlist>=1.3 + numpy<2,>=1.22 + scipy>=1.5 + pandas>=1 + geopandas>=0.10 + rasterio<2,>=1.3 + shapely>=1.8 + fiona>=1.9 + scikit-image>=0.19 + opencv-python>=4.5.5 decorator==4.4.2 rtree>=0.9.7 - graphviz>=0.19.0 - tqdm>=4.62.0 + graphviz>=0.19 + tqdm>=4.66 pyDeprecate==0.3.1 future>=0.17.1 - tensorboard>=2.2.0 + tensorboard>=2.2 PyYAML>=5.1 - pytorch_lightning>=1.7.6,<=1.9.5 - torchmetrics>=0.10.0,<0.11.0 - ray>=2.0.0,<=2.1.0 - gudhi>=3.7.1 - pyarrow>=11.0.0 - geowombat@git+https://github.com/jgrss/geowombat.git@v2.1.9 + lightning>=2.2 + torchmetrics>=1.3 + einops>=0.7 + ray>=2.34 + pyarrow>=11 + typing-extensions + lz4 + rich-argparse + pyogrio>=0.7 + geowombat@git+https://github.com/jgrss/geowombat.git tsaug@git+https://github.com/jgrss/tsaug.git - geosample@git+https://github.com/jgrss/geosample.git@v1.0.1 - setuptools>=65.5.1 - Cython>=0.29.0,<3.0.0 + pygrts@git+https://github.com/jgrss/pygrts.git@v1.4.1 [options.extras_require] docs = numpydoc @@ -67,6 +73,7 @@ docs = numpydoc test = black flake8 docformatter + pytest [options.entry_points] console_scripts = diff --git a/setup.py b/setup.py index 3141a208..181d5ff3 100644 --- a/setup.py +++ b/setup.py @@ -1,30 +1,4 @@ from distutils.core import setup -from distutils.extension import Extension -from Cython.Build import cythonize - -try: - from Cython.Distutils import build_ext -except: - from distutils.command import build_ext - -try: - import numpy as np -except: - raise ImportError('NumPy must be installed.') - - -def get_extensions(): - return [Extension('*', sources=['src/cultionet/networks/_build_network.pyx'])] - - -def setup_package(): - metadata = dict( - ext_modules=cythonize(get_extensions()), - include_dirs=[np.get_include()] - ) - - setup(**metadata) - if __name__ == '__main__': - setup_package() + setup() diff --git a/src/cultionet/__init__.py b/src/cultionet/__init__.py index 9db456b7..f0dd0c66 100644 --- a/src/cultionet/__init__.py +++ b/src/cultionet/__init__.py @@ -1,5 +1,10 @@ __path__: str = __import__("pkgutil").extend_path(__path__, __name__) -__version__ = "1.7.3" -from .model import fit, load_model, predict, predict_lightning +__version__ = "2.0.0b" +from .model import fit, fit_transfer, load_model, predict_lightning -__all__ = ["fit", "fit_maskrcnn", "load_model", "predict", "predict_lightning"] +__all__ = [ + "fit", + "fit_transfer", + "load_model", + "predict_lightning", +] diff --git a/src/cultionet/augment/augmenter_utils.py b/src/cultionet/augment/augmenter_utils.py index 8642ad93..582494a4 100644 --- a/src/cultionet/augment/augmenter_utils.py +++ b/src/cultionet/augment/augmenter_utils.py @@ -1,17 +1,16 @@ import typing as T import numpy as np -from scipy.ndimage.measurements import label as nd_label -from tsaug import AddNoise, Drift, TimeWarp import torch +import torch.nn.functional as F +from einops import rearrange +from tsaug import AddNoise, Drift, TimeWarp -from ..data.utils import LabeledData +from ..data import Data -def feature_stack_to_tsaug( - x: np.ndarray, ntime: int, nbands: int, nrows: int, ncols: int -) -> np.ndarray: - """Reshapes from (T*C x H x W) -> (H*W x T X C) +def feature_stack_to_tsaug(x: torch.Tensor) -> torch.Tensor: + """Reshapes from (1 x C x T x H x W) -> (H*W x T X C) where, T = time @@ -19,24 +18,18 @@ def feature_stack_to_tsaug( H = height W = width - Args: - x: The array to reshape. The input shape is (T*C x H x W). - ntime: The number of array time periods (T). - nbands: The number of array bands/channels (C). - nrows: The number of array rows (H). - ncols: The number of array columns (W). + Parameters + ========== + x + The array to reshape. The input shape is (1 x C x T x H x W). """ - return ( - x.transpose(1, 2, 0) - .reshape(nrows * ncols, ntime * nbands) - .reshape(nrows * ncols, ntime, nbands) - ) + return rearrange(x, '1 c t h w -> (h w) t c') def tsaug_to_feature_stack( - x: np.ndarray, nfeas: int, nrows: int, ncols: int -) -> np.ndarray: - """Reshapes from (H*W x T X C) -> (T*C x H x W) + x: torch.Tensor, height: int, width: int +) -> torch.Tensor: + """Reshapes from (H*W x T X C) -> (1 x C x T x H x W) where, T = time @@ -44,179 +37,324 @@ def tsaug_to_feature_stack( H = height W = width - Args: - x: The array to reshape. The input shape is (H*W x T X C). - nfeas: The number of array features (time x channels). - nrows: The number of array rows (height). - ncols: The number of array columns (width). + Parameters + ========== + x + The array to reshape. The input shape is (H*W x T X C). + height + The number of array rows (height). + width + The number of array columns (width). """ - return x.reshape(nrows * ncols, nfeas).T.reshape(nfeas, nrows, ncols) - - -def get_prop_data( - ldata: LabeledData, p: T.Any, x: np.ndarray -) -> T.Tuple[tuple, np.ndarray, np.ndarray]: - # Get the segment bounding box - min_row, min_col, max_row, max_col = p.bbox - bounds_slice = (slice(min_row, max_row), slice(min_col, max_col)) - # Get the segment features within the bounds - xseg = x[(slice(0, None),) + bounds_slice].copy() - # Get the segments within the bounds - seg = ldata.segments[bounds_slice].copy() - # Get the segment mask - mask = np.uint8(seg == p.label)[np.newaxis] - - return bounds_slice, xseg, mask - - -def reinsert_prop( - x: np.ndarray, - bounds_slice: tuple, - mask: np.ndarray, - x_update: np.ndarray, - x_original: np.ndarray, -) -> np.ndarray: - x[(slice(0, None),) + bounds_slice] = np.where( - mask == 1, x_update, x_original + return rearrange( + x, + '(h w) t c -> 1 c t h w', + h=height, + w=width, + ) + + +class SegmentParcel: + def __init__( + self, + coords_slices: tuple, + dims_slice: tuple, + xseg: torch.Tensor, + ): + self.coords_slices = coords_slices + self.dims_slice = dims_slice + self.xseg = xseg + + @classmethod + def from_prop(cls, ldata: Data, p: T.Any) -> "SegmentParcel": + # Get the segment bounding box + min_row, min_col, max_row, max_col = p.bbox + coords_slices = (slice(0, None),) * 3 + dims_slice = ( + slice(min_row, max_row), + slice(min_col, max_col), + ) + + # Get the segment features within the bounds + xseg = ldata.x[coords_slices + dims_slice] + + return cls( + coords_slices=coords_slices, + dims_slice=dims_slice, + xseg=xseg, + ) + + +def insert_parcel( + parcel_data: Data, + augmented: torch.Tensor, + segment_parcel: SegmentParcel, + prop: object, +) -> Data: + parcel_data.x[ + segment_parcel.coords_slices + segment_parcel.dims_slice + ] = torch.where( + rearrange( + torch.from_numpy(parcel_data.segments)[segment_parcel.dims_slice], + 'h w -> 1 1 1 h w', + ) + == prop.label, + augmented, + parcel_data.x[ + segment_parcel.coords_slices + segment_parcel.dims_slice + ], ) - return x + return parcel_data def augment_time( - ldata: LabeledData, + ldata: Data, p: T.Any, - x: np.ndarray, - ntime: int, - nbands: int, add_noise: bool, warper: T.Union[AddNoise, Drift, TimeWarp], aug: str, ) -> np.ndarray: """Applies temporal augmentation to a dataset.""" - bounds_slice, xseg, mask = get_prop_data(ldata=ldata, p=p, x=x) + segment_parcel = SegmentParcel.from_prop(ldata=ldata, p=p) - # xseg shape = (ntime*nbands x nrows x ncols) - xseg_original = xseg.copy() - nfeas, nrows, ncols = xseg.shape - assert nfeas == int( - ntime * nbands - ), "The array feature dimensions do not match the expected shape." + ( + num_batch, + num_channels, + num_time, + height, + width, + ) = segment_parcel.xseg.shape - # (H*W x T X C) - xseg = feature_stack_to_tsaug(xseg, ntime, nbands, nrows, ncols) + # (1 x C x T x H x W) -> (H*W x T X C) + xseg = feature_stack_to_tsaug(segment_parcel.xseg) if aug == "tspeaks": - new_indices = np.sort( - np.random.choice( - range(0, ntime * 2 - 8), replace=False, size=ntime - ) + half_a = F.interpolate( + rearrange(xseg, 'b t c -> b c t'), + size=num_time // 2, + mode='linear', + ) + half_b = F.interpolate( + rearrange(xseg, 'b t c -> b c t'), + size=num_time - half_a.shape[-1], + mode='linear', + ) + xseg = rearrange( + torch.cat((half_a, half_b), dim=-1), + 'b c t -> b t c', ) - xseg = np.concatenate((xseg, xseg), axis=1)[:, 4:-4][:, new_indices] + # Warp the segment - xseg = warper.augment(xseg) + xseg = warper.augment(xseg.numpy()) + if add_noise: noise_warper = AddNoise(scale=np.random.uniform(low=0.01, high=0.05)) xseg = noise_warper.augment(xseg) - # Reshape back from (H*W x T x C) -> (T*C x H x W) - xseg = tsaug_to_feature_stack(xseg, nfeas, nrows, ncols).clip(0, 1) - - # Insert back into full array - x = reinsert_prop( - x=x, - bounds_slice=bounds_slice, - mask=mask, - x_update=xseg, - x_original=xseg_original, - ) - return x + # Reshape back from (H*W x T x C) -> (1 x C x T x H x W) + xseg = tsaug_to_feature_stack( + torch.from_numpy(xseg), height=height, width=width + ).clip(0, 1) + + return insert_parcel( + parcel_data=ldata, + augmented=xseg, + segment_parcel=segment_parcel, + prop=p, + ) def roll_time( - ldata: LabeledData, p: T.Any, x: np.ndarray, ntime: int -) -> np.ndarray: - bounds_slice, xseg, mask = get_prop_data(ldata=ldata, p=p, x=x) - xseg_original = xseg.copy() + ldata: Data, + p: object, + rng: T.Optional[np.random.Generator] = None, + random_seed: T.Optional[int] = None, +) -> Data: + if rng is None: + rng = np.random.default_rng(random_seed) + + segment_parcel = SegmentParcel.from_prop(ldata=ldata, p=p) # Get a temporal shift for the object - shift = np.random.choice( - range(-int(x.shape[0] * 0.25), int(x.shape[0] * 0.25) + 1), size=1 - )[0] - # Shift time in each band separately - for b in range(0, xseg.shape[0], ntime): - # Get the slice for the current band, n time steps - xseg[b : b + ntime] = np.roll(xseg[b : b + ntime], shift=shift, axis=0) - - # Insert back into full array - x = reinsert_prop( - x=x, - bounds_slice=bounds_slice, - mask=mask, - x_update=xseg, - x_original=xseg_original, + shift = rng.choice( + range(-int(ldata.num_time * 0.25), int(ldata.num_time * 0.25) + 1) ) - return x + # Shift time + # (1 x C x T x H x W) + xseg = torch.roll(segment_parcel.xseg, shift, dims=2) + return insert_parcel( + parcel_data=ldata, + augmented=xseg, + segment_parcel=segment_parcel, + prop=p, + ) -def create_parcel_masks( - labels_array: np.ndarray, max_crop_class: int -) -> T.Union[None, dict]: - """Creates masks for each instance. - Reference: - https://torchtutorialstaging.z5.web.core.windows.net/intermediate/torchvision_tutorial.html +def interpolant(t): + return t * t * t * (t * (t * 6 - 15) + 10) + + +def scale_min_max( + x: torch.Tensor, + in_range: tuple, + out_range: tuple, +) -> torch.Tensor: + min_in, max_in = in_range + min_out, max_out = out_range + + return (((max_out - min_out) * (x - min_in)) / (max_in - min_in)) + min_out + + +def generate_perlin_noise_3d( + shape: T.Tuple[int, int, int], + res: T.Tuple[int, int, int], + tileable: T.Tuple[bool, bool, bool] = ( + False, + False, + False, + ), + out_range: T.Optional[T.Tuple[float, float]] = None, + interpolant: T.Callable = interpolant, + rng: T.Optional[np.random.Generator] = None, + random_seed: T.Optional[int] = None, +) -> torch.Tensor: + """Generates a 3D tensor of perlin noise. + + Parameters + ========== + shape + The shape of the generated array (tuple of three ints). This must + be a multiple of res. + res + The number of periods of noise to generate along each + axis (tuple of three ints). Note shape must be a multiple of res. + tileable + If the noise should be tileable along each axis (tuple of three bools). + Defaults to (False, False, False). + interpolant + The interpolation function, defaults to t*t*t*(t*(t*6 - 15) + 10). + + Returns + ======= + A tensor with the generated noise. + + Raises: + ValueError: If shape is not a multiple of res. + + Source: + https://github.com/pvigier/perlin-numpy/tree/master + + MIT License + Copyright (c) 2019 Pierre Vigier """ - # Remove edges - mask = np.where( - (labels_array > 0) & (labels_array <= max_crop_class), 1, 0 + if out_range is None: + out_range = (-0.1, 0.1) + + if rng is None: + rng = np.random.default_rng(random_seed) + + delta = (res[0] / shape[0], res[1] / shape[1], res[2] / shape[2]) + d = (shape[0] // res[0], shape[1] // res[1], shape[2] // res[2]) + grid = np.mgrid[ + : res[0] : delta[0], : res[1] : delta[1], : res[2] : delta[2] + ] + grid = np.mgrid[ + : res[0] : delta[0], : res[1] : delta[1], : res[2] : delta[2] + ] + grid = grid.transpose(1, 2, 3, 0) % 1 + + grid = torch.from_numpy(grid) + + # Gradients + torch.manual_seed(rng.integers(low=0, high=2147483647)) + theta = 2 * np.pi * torch.rand(res[0] + 1, res[1] + 1, res[2] + 1) + torch.manual_seed(rng.integers(low=0, high=2147483647)) + phi = 2 * np.pi * torch.rand(res[0] + 1, res[1] + 1, res[2] + 1) + gradients = torch.stack( + ( + torch.sin(phi) * torch.cos(theta), + torch.sin(phi) * torch.sin(theta), + torch.cos(phi), + ), + axis=3, + ) + + if tileable[0]: + gradients[-1] = gradients[0] + if tileable[1]: + gradients[:, -1] = gradients[:, 0] + if tileable[2]: + gradients[..., -1] = gradients[..., 0] + + gradients = ( + gradients.repeat_interleave(d[0], 0) + .repeat_interleave(d[1], 1) + .repeat_interleave(d[2], 2) + ) + g000 = gradients[: -d[0], : -d[1], : -d[2]] + g100 = gradients[d[0] :, : -d[1], : -d[2]] + g010 = gradients[: -d[0], d[1] :, : -d[2]] + g110 = gradients[d[0] :, d[1] :, : -d[2]] + g001 = gradients[: -d[0], : -d[1], d[2] :] + g101 = gradients[d[0] :, : -d[1], d[2] :] + g011 = gradients[: -d[0], d[1] :, d[2] :] + g111 = gradients[d[0] :, d[1] :, d[2] :] + + # Ramps + n000 = torch.sum( + torch.stack((grid[..., 0], grid[..., 1], grid[..., 2]), dim=3) * g000, + dim=3, + ) + n100 = torch.sum( + torch.stack((grid[..., 0] - 1, grid[..., 1], grid[..., 2]), dim=3) + * g100, + dim=3, ) - mask = nd_label(mask)[0] - obj_ids = np.unique(mask) - # first id is the background, so remove it - obj_ids = obj_ids[1:] - # split the color-encoded mask into a set - # of binary masks - masks = mask == obj_ids[:, None, None] - - # get bounding box coordinates for each mask - num_objs = len(obj_ids) - boxes = [] - small_box_idx = [] - for i in range(num_objs): - pos = np.where(masks[i]) - xmin = np.min(pos[1]) - xmax = np.max(pos[1]) - ymin = np.min(pos[0]) - ymax = np.max(pos[0]) - # Fields too small - if (xmax - xmin == 0) or (ymax - ymin == 0): - small_box_idx.append(i) - continue - boxes.append([xmin, ymin, xmax, ymax]) - - if small_box_idx: - good_idx = np.array( - [ - idx - for idx in range(0, masks.shape[0]) - if idx not in small_box_idx - ] + n010 = torch.sum( + torch.stack((grid[..., 0], grid[..., 1] - 1, grid[..., 2]), dim=3) + * g010, + dim=3, + ) + n110 = torch.sum( + torch.stack((grid[..., 0] - 1, grid[..., 1] - 1, grid[..., 2]), dim=3) + * g110, + dim=3, + ) + n001 = torch.sum( + torch.stack((grid[..., 0], grid[..., 1], grid[..., 2] - 1), dim=3) + * g001, + dim=3, + ) + n101 = torch.sum( + torch.stack((grid[..., 0] - 1, grid[..., 1], grid[..., 2] - 1), dim=3) + * g101, + dim=3, + ) + n011 = torch.sum( + torch.stack((grid[..., 0], grid[..., 1] - 1, grid[..., 2] - 1), dim=3) + * g011, + dim=3, + ) + n111 = torch.sum( + torch.stack( + (grid[..., 0] - 1, grid[..., 1] - 1, grid[..., 2] - 1), dim=3 ) - masks = masks[good_idx] - # convert everything into arrays - boxes = torch.as_tensor(boxes, dtype=torch.float32) - if boxes.size(0) == 0: - return None - # there is only one class - labels = torch.ones((masks.shape[0],), dtype=torch.int64) - masks = torch.as_tensor(masks, dtype=torch.uint8) - - assert ( - boxes.size(0) == labels.size(0) == masks.size(0) - ), "The tensor sizes do not match." - - target = {"boxes": boxes, "labels": labels, "masks": masks} - - return target + * g111, + dim=3, + ) + + # Interpolation + t = interpolant(grid) + n00 = n000 * (1 - t[..., 0]) + t[..., 0] * n100 + n10 = n010 * (1 - t[..., 0]) + t[..., 0] * n110 + n01 = n001 * (1 - t[..., 0]) + t[..., 0] * n101 + n11 = n011 * (1 - t[..., 0]) + t[..., 0] * n111 + n0 = (1 - t[..., 1]) * n00 + t[..., 1] * n10 + n1 = (1 - t[..., 1]) * n01 + t[..., 1] * n11 + + x = (1 - t[..., 2]) * n0 + t[..., 2] * n1 + + return scale_min_max(x, in_range=(-0.5, 0.5), out_range=out_range) diff --git a/src/cultionet/augment/augmenters.py b/src/cultionet/augment/augmenters.py index 7c0b2c0f..97f69f5d 100644 --- a/src/cultionet/augment/augmenters.py +++ b/src/cultionet/augment/augmenters.py @@ -1,163 +1,65 @@ -from abc import abstractmethod import typing as T -import enum -from dataclasses import dataclass, replace +from abc import abstractmethod +from dataclasses import replace from pathlib import Path -from tsaug import AddNoise, Drift, TimeWarp -import numpy as np -import cv2 -from skimage import util as sk_util -from torch_geometric.data import Data +import einops import joblib +import numpy as np +import torch +from frozendict import frozendict +from torchvision.transforms import InterpolationMode, v2 +from torchvision.transforms.v2 import functional as VF +from tsaug import AddNoise, Drift, TimeWarp -from .augmenter_utils import augment_time, create_parcel_masks, roll_time -from ..data.utils import create_data_object, LabeledData -from ..networks import SingleSensorNetwork -from ..utils.reshape import nd_to_columns - - -@dataclass -class DataCopies: - x: np.ndarray - y: T.Union[np.ndarray, None] - bdist: T.Union[np.ndarray, None] - - -@dataclass -class AugmenterArgs: - ntime: int - nbands: int - max_crop_class: int - k: int - instance_seg: bool - zero_padding: int - kwargs: dict +from ..data import Data +from .augmenter_utils import augment_time, generate_perlin_noise_3d, roll_time -class AugmenterModule(object): +class AugmenterModule: """Prepares, augments, and finalizes data.""" prefix: str = "data_" suffix: str = ".pt" - def __call__(self, ldata: LabeledData, aug_args: AugmenterArgs) -> Data: + def __call__(self, ldata: Data) -> Data: assert hasattr(self, "name_") assert isinstance(self.name_, str) - cdata = self.prepare_data(ldata) - cdata = self.forward(cdata, ldata, aug_args) - data = self.finalize( - x=cdata.x, y=cdata.y, bdist=cdata.bdist, aug_args=aug_args - ) + cdata = self.forward(ldata.copy()) + cdata.x = cdata.x.float().clip(1e-9, 1) + cdata.bdist = cdata.bdist.float().clip(0, 1) + if cdata.y is not None: + cdata.y = cdata.y.long() - return data + return cdata @abstractmethod - def forward( - self, cdata: DataCopies, ldata: LabeledData, aug_args: AugmenterArgs - ) -> DataCopies: + def forward(self, cdata: Data) -> Data: raise NotImplementedError def file_name(self, uid: str) -> str: return f"{self.prefix}{uid}{self.suffix}" - def save(self, out_directory: Path, data: Data, compress: int = 5) -> None: + def save( + self, out_directory: Path, data: Data, compress: T.Union[int, str] = 5 + ) -> None: out_path = out_directory / self.file_name(data.train_id) joblib.dump(data, out_path, compress=compress) - def prepare_data(self, ldata: LabeledData) -> DataCopies: - x = ldata.x.copy() - y = ldata.y - bdist = ldata.bdist - # TODO: for orientation layer - # ori = ldata.ori - # if zero_padding > 0: - # zpad = torch.nn.ZeroPad2d(zero_padding) - # x = zpad(torch.tensor(x)).numpy() - # y = zpad(torch.tensor(y)).numpy() - # bdist = zpad(torch.tensor(bdist)).numpy() - # ori = zpad(torch.tensor(ori)).numpy() - - if y is not None: - y = y.copy() - if bdist is not None: - bdist = bdist.copy() - - return DataCopies(x=x, y=y, bdist=bdist) - - def finalize( - self, - x: np.ndarray, - y: T.Union[np.ndarray, None], - bdist: T.Union[np.ndarray, None], - aug_args: AugmenterArgs, - ) -> Data: - # Create the network - nwk = SingleSensorNetwork( - np.ascontiguousarray(x, dtype="float64"), k=aug_args.k - ) - - ( - edge_indices_a, - edge_indices_b, - edge_attrs_diffs, - edge_attrs_dists, - __, - __, - ) = nwk.create_network() - edge_indices = np.c_[edge_indices_a, edge_indices_b] - edge_attrs = np.c_[edge_attrs_diffs, edge_attrs_dists] - - # Create the node position tensor - dims, height, width = x.shape - # pos_x = np.arange(0, width * kwargs['res'], kwargs['res']) - # pos_y = np.arange(height * kwargs['res'], 0, -kwargs['res']) - # grid_x, grid_y = np.meshgrid(pos_x, pos_y, indexing='xy') - # xy = np.c_[grid_x.flatten(), grid_y.flatten()] - - x = nd_to_columns(x, dims, height, width) - - mask_y = None - if aug_args.instance_seg: - mask_y = create_parcel_masks(y, aug_args.max_crop_class) - - return create_data_object( - x, - edge_indices, - edge_attrs, - ntime=aug_args.ntime, - nbands=aug_args.nbands, - height=height, - width=width, - y=y, - mask_y=mask_y, - bdist=bdist, - # ori=ori_aug, - zero_padding=aug_args.zero_padding, - **aug_args.kwargs, - ) - class AugmentTimeMixin(AugmenterModule): - def forward( - self, cdata: DataCopies, ldata: LabeledData, aug_args: AugmenterArgs - ) -> DataCopies: + def forward(self, cdata: Data) -> Data: # Warp each segment - for p in ldata.props: - x = augment_time( - ldata, + for p in cdata.props: + cdata = augment_time( + cdata, p=p, - x=cdata.x, - ntime=aug_args.ntime, - nbands=aug_args.nbands, add_noise=self.add_noise_, warper=self.warper, aug=self.name_, ) - cdata = replace(cdata, x=x) - # y and bdist are unaltered return cdata @@ -167,10 +69,12 @@ def __init__( name: str, n_speed_change_lim: T.Tuple[int, int] = None, max_speed_ratio_lim: T.Tuple[float, float] = None, + rng: T.Optional[np.random.Generator] = None, ): + self.name_ = name self.n_speed_change_lim = n_speed_change_lim self.max_speed_ratio_lim = max_speed_ratio_lim - self.name_ = name + self.rng = rng self.add_noise_ = True if self.n_speed_change_lim is None: @@ -179,10 +83,13 @@ def __init__( self.max_speed_ratio_lim = (1.1, 1.5) self.warper = TimeWarp( - n_speed_change=np.random.randint( - low=self.n_speed_change_lim[0], high=self.n_speed_change_lim[1] + n_speed_change=int( + self.rng.integers( + low=self.n_speed_change_lim[0], + high=self.n_speed_change_lim[1], + ) ), - max_speed_ratio=np.random.uniform( + max_speed_ratio=self.rng.uniform( low=self.max_speed_ratio_lim[0], high=self.max_speed_ratio_lim[1], ), @@ -191,8 +98,13 @@ def __init__( class AugmentAddTimeNoise(AugmentTimeMixin): - def __init__(self, scale_lim: T.Tuple[int, int] = None): + def __init__( + self, + scale_lim: T.Tuple[int, int] = None, + rng: T.Optional[np.random.Generator] = None, + ): self.scale_lim = scale_lim + self.rng = rng self.name_ = "tsnoise" self.add_noise_ = False @@ -200,7 +112,7 @@ def __init__(self, scale_lim: T.Tuple[int, int] = None): self.scale_lim = (0.01, 0.05) self.warper = AddNoise( - scale=np.random.uniform( + scale=self.rng.uniform( low=self.scale_lim[0], high=self.scale_lim[1] ) ) @@ -211,9 +123,11 @@ def __init__( self, max_drift_lim: T.Tuple[int, int] = None, n_drift_points_lim: T.Tuple[int, int] = None, + rng: T.Optional[np.random.Generator] = None, ): self.max_drift_lim = max_drift_lim self.n_drift_points_lim = n_drift_points_lim + self.rng = rng self.name_ = "tsdrift" self.add_noise_ = True @@ -223,220 +137,245 @@ def __init__( self.n_drift_points_lim = (1, 6) self.warper = Drift( - max_drift=np.random.uniform( - low=self.max_drift_lim[0], high=self.max_drift_lim[1] + max_drift=self.rng.uniform( + low=self.max_drift_lim[0], + high=self.max_drift_lim[1], ), - n_drift_points=np.random.randint( - low=self.n_drift_points_lim[0], high=self.n_drift_points_lim[1] + n_drift_points=int( + self.rng.integers( + low=self.n_drift_points_lim[0], + high=self.n_drift_points_lim[1], + ) ), static_rand=True, ) -class Rotate(AugmenterModule): - def __init__(self, deg: int): - self.name_ = f"rotate-{deg}" +class Roll(AugmenterModule): + def __init__(self, rng: T.Optional[np.random.Generator] = None): + self.rng = rng + self.name_ = "roll" - deg_dict = { - 90: cv2.ROTATE_90_CLOCKWISE, - 180: cv2.ROTATE_180, - 270: cv2.ROTATE_90_COUNTERCLOCKWISE, - } - self.deg_func = deg_dict[deg] + def forward(self, cdata: Data) -> Data: + for p in cdata.props: + cdata = roll_time(cdata, p, rng=self.rng) - def forward( - self, - cdata: DataCopies, - ldata: LabeledData = None, - aug_args: AugmenterArgs = None, - ) -> DataCopies: - # Create the output array for rotated features - x = np.zeros( - ( - cdata.x.shape[0], - *cv2.rotate(np.float32(cdata.x[0]), self.deg_func).shape, - ), - dtype=cdata.x.dtype, - ) - for i in range(0, cdata.x.shape[0]): - x[i] = cv2.rotate(np.float32(cdata.x[i]), self.deg_func) + return cdata - # Rotate labels - label_dtype = "float" if "float" in cdata.y.dtype.name else "int" - if label_dtype == "float": - y = cv2.rotate(np.float32(cdata.y), self.deg_func) - else: - y = cv2.rotate(np.uint8(cdata.y), self.deg_func) - # Rotate the distance transform - bdist = cv2.rotate(np.float32(cdata.bdist), self.deg_func) - # ori_aug = cv2.rotate(np.float32(ori), self.deg_func) - cdata = replace(cdata, x=x, y=y, bdist=bdist) +class PerlinNoise(AugmenterModule): + def __init__(self, rng: T.Optional[np.random.Generator] = None): + self.rng = rng + self.name_ = "perlin" + + def forward(self, cdata: Data) -> Data: + res = self.rng.choice([2, 5, 10]) + noise = generate_perlin_noise_3d( + shape=cdata.x.shape[2:], + res=(1, res, res), + tileable=(False, False, False), + out_range=(-0.03, 0.03), + rng=self.rng, + ) + + noise = einops.rearrange(noise, 't h w -> 1 1 t h w') + cdata.x = cdata.x + noise.to( + dtype=cdata.x.dtype, device=cdata.x.device + ) return cdata -class Roll(AugmenterModule): - def __init__(self): - self.name_ = "roll" +class Rotate(AugmenterModule): + def __init__(self, deg: int, **kwargs): + self.deg = deg + self.name_ = f"rotate-{deg}" + + def forward(self, cdata: Data) -> Data: + x = einops.rearrange(cdata.x, '1 c t h w -> 1 t c h w') + + x_rotation_transform = v2.RandomRotation( + degrees=[self.deg, self.deg], + interpolation=InterpolationMode.BILINEAR, + ) + y_rotation_transform = v2.RandomRotation( + degrees=[self.deg, self.deg], + interpolation=InterpolationMode.NEAREST, + ) + + cdata.x = einops.rearrange( + x_rotation_transform(x), + '1 t c h w -> 1 c t h w', + ) + cdata.bdist = x_rotation_transform(cdata.bdist) + cdata.y = y_rotation_transform(cdata.y) - def forward( - self, - cdata: DataCopies, - ldata: LabeledData = None, - aug_args: AugmenterArgs = None, - ) -> DataCopies: - for p in ldata.props: - x = roll_time(ldata, p, cdata.x, aug_args.ntime) - cdata = replace(cdata, x=x) - - # y and bdist are unaltered return cdata class Flip(AugmenterModule): - def __init__(self, direction: str): + def __init__(self, direction: str, **kwargs): self.direction = direction self.name_ = direction - def forward( - self, - cdata: DataCopies, - ldata: LabeledData = None, - aug_args: AugmenterArgs = None, - ) -> DataCopies: - x = cdata.x.copy() - if self.direction == "flipfb": - # Reverse the channels - for b in range(0, cdata.x.shape[0], aug_args.ntime): - # Get the slice for the current band, n time steps - x[b : b + aug_args.ntime] = x[b : b + aug_args.ntime][::-1] - - # y and bdist are unaltered - cdata = replace(cdata) + def forward(self, cdata: Data) -> Data: + x = einops.rearrange(cdata.x, '1 c t h w -> 1 t c h w') + + if self.direction == 'fliplr': + flip_transform = VF.hflip + elif self.direction == 'flipud': + flip_transform = VF.vflip else: - flip_func = getattr(np, self.direction) - for i in range(0, x.shape[0]): - x[i] = flip_func(x[i]) + raise NameError("The direction is not supported.") - y = flip_func(cdata.y) - bdist = flip_func(cdata.bdist) - # ori_aug = getattr(np, aug)(ori) - cdata = replace(cdata, x=x, y=y, bdist=bdist) + cdata.x = einops.rearrange( + flip_transform(x), + '1 t c h w -> 1 c t h w', + ) + cdata.bdist = flip_transform(cdata.bdist) + cdata.y = flip_transform(cdata.y) return cdata -class SKLearnMixin(AugmenterModule): - def forward( - self, - cdata: DataCopies, - ldata: LabeledData = None, - aug_args: AugmenterArgs = None, - ) -> DataCopies: - x = cdata.x.copy() - for i in range(0, x.shape[0]): - x[i] = sk_util.random_noise( - x[i], mode=self.name_, clip=True, **self.kwargs - ) +class RandomCropResize(AugmenterModule): + def __init__(self, rng: T.Optional[np.random.Generator] = None): + self.rng = rng + self.name_ = "cropresize" - # y and bdist are unaltered - cdata = replace(cdata, x=x) - - return cdata + def forward(self, cdata: Data) -> Data: + div = self.rng.choice([2, 4]) + size = (cdata.y.shape[-2] // div, cdata.y.shape[-1] // div) + random_seed = self.rng.integers(low=0, high=2147483647) -class GaussianNoise(SKLearnMixin): - def __init__(self, **kwargs): - self.kwargs = kwargs - self.name_ = "gaussian" + x = einops.rearrange(cdata.x, 'b c t h w -> b t c h w') + x = self.random_crop( + x, + interpolation=InterpolationMode.BILINEAR, + size=size, + random_seed=random_seed, + ) + cdata.x = einops.rearrange(x, 'b t c h w -> b c t h w') + cdata.bdist = self.random_crop( + cdata.bdist, + interpolation=InterpolationMode.BILINEAR, + size=size, + random_seed=random_seed, + ) + cdata.y = self.random_crop( + cdata.y, + interpolation=InterpolationMode.NEAREST, + size=size, + random_seed=random_seed, + ) + return cdata -class SaltAndPepperNoise(SKLearnMixin): - def __init__(self, **kwargs): - self.kwargs = kwargs - self.name_ = "s&p" + def random_crop( + self, + x: torch.Tensor, + size: tuple, + interpolation: str, + random_seed: int, + ) -> torch.Tensor: + np.random.seed(random_seed) + torch.manual_seed(random_seed) + + transform = v2.RandomCrop( + size=size, + ) + resize = v2.Resize( + size=x.shape[-2:], + interpolation=interpolation, + ) + return resize(transform(x)) -class SpeckleNoise(SKLearnMixin): - """ - Example: - >>> augmenter = SpeckleNoise() - >>> data = augmenter(labeled_data, **kwargs) - """ - def __init__(self, **kwargs): +class GaussianBlur(AugmenterModule): + def __init__(self, rng: T.Optional[np.random.Generator] = None, **kwargs): self.kwargs = kwargs - self.name_ = "speckle" + self.name_ = "gaussian" + def forward(self, cdata: Data) -> Data: + transform = v2.GaussianBlur(kernel_size=3, **self.kwargs) + cdata.x = transform(cdata.x) -class NoAugmentation(AugmenterModule): - def __init__(self): - self.name_ = "none" - - def forward( - self, - cdata: DataCopies, - ldata: LabeledData = None, - aug_args: AugmenterArgs = None, - ) -> DataCopies: return cdata -class AugmenterMapping(enum.Enum): - """Key: Augmenter mappings""" +class SaltAndPepperNoise(AugmenterModule): + def __init__(self, rng: T.Optional[np.random.Generator] = None, **kwargs): + self.rng = rng + self.kwargs = kwargs + self.name_ = "s&p" - tswarp = AugmentTimeWarp(name="tswarp") - tsnoise = AugmentAddTimeNoise() - tsdrift = AugmentTimeDrift() - tspeaks = AugmentTimeWarp("tspeaks") - rot90 = Rotate(deg=90) - rot180 = Rotate(deg=180) - rot270 = Rotate(deg=270) - roll = Roll() - fliplr = Flip(direction="fliplr") - flipud = Flip(direction="flipud") - gaussian = GaussianNoise(mean=0.0, var=0.005) - saltpepper = SaltAndPepperNoise(amount=0.01) - speckle = SpeckleNoise(mean=0.0, var=0.05) - none = NoAugmentation() + def forward(self, cdata: Data) -> Data: + random_seed = self.rng.integers(low=0, high=2147483647) + cdata.x = self.gaussian_noise( + cdata.x, + random_seed=random_seed, + **self.kwargs, + ) + return cdata -class AugmenterBase(object): - def __init__( - self, - augmentations: T.Sequence[str], - ntime: int, - nbands: int, - max_crop_class: int, - k: int = 3, - instance_seg: bool = False, - zero_padding: int = 0, - **kwargs, - ): - self.augmentations = augmentations - self.augmenters_ = [] - self.aug_args = AugmenterArgs( - ntime=ntime, - nbands=nbands, - max_crop_class=max_crop_class, - k=k, - instance_seg=instance_seg, - zero_padding=zero_padding, - kwargs=kwargs, - ) + def gaussian_noise( + self, x: torch.Tensor, random_seed: int, sigma: float = 0.01 + ) -> torch.Tensor: + np.random.seed(random_seed) + torch.manual_seed(random_seed) - self._init_augmenters() + return x + sigma * torch.randn_like(x) - def _init_augmenters(self): - for augmentation in self.augmentations: - self.augmenters_.append(AugmenterMapping[augmentation].value) - def update_aug_args(self, **kwargs): - self.aug_args = replace(self.aug_args, **kwargs) +class NoAugmentation(AugmenterModule): + def __init__(self, **kwargs): + self.name_ = "none" + + def forward(self, cdata: Data) -> Data: + return cdata -class Augmenters(AugmenterBase): +AUGMENTER_METHODS = frozendict( + tswarp=AugmentTimeWarp, + tsnoise=AugmentAddTimeNoise, + tsdrift=AugmentTimeDrift, + tspeaks=AugmentTimeWarp, + rot90=Rotate, + rot180=Rotate, + rot270=Rotate, + roll=Roll, + fliplr=Flip, + flipud=Flip, + gaussian=GaussianBlur, + saltpepper=SaltAndPepperNoise, + cropresize=RandomCropResize, + perlin=PerlinNoise, + none=NoAugmentation, +) + +MODULE_DEFAULTS = dict( + tswarp=dict(name="tswarp"), + tsnoise={}, + tsdrift={}, + tspeaks=dict(name="tspeaks"), + rot90=dict(deg=90), + rot180=dict(deg=180), + rot270=dict(deg=270), + roll={}, + fliplr=dict(direction="fliplr"), + flipud=dict(direction="flipud"), + gaussian=dict(sigma=(0.2, 0.5)), + saltpepper=dict(sigma=0.01), + cropresize={}, + perlin={}, + none={}, +) + + +class Augmenters: """Applies augmentations for a sequence of augmentation methods. Inputs to callables: @@ -455,22 +394,52 @@ class Augmenters(AugmenterBase): each labeled parcel in `y`. aug_args: Additional keyword arguments passed to the - `torch_geometric.data.Data` object. + `Data` object. Example: - >>> aug = Augmenters( - >>> augmentations=['tswarp'], - >>> ntime=13, - >>> nbands=5, - >>> max_crop_class=1 - >>> ) - >>> - >>> for method in aug: - >>> method(ldata, aug_args=aug.aug_args) + >>> augmenters = Augmenters(augmentations=['tswarp']) + >>> ldata = augmenters(ldata) """ - def __init__(self, **kwargs): - super(Augmenters, self).__init__(**kwargs) + def __init__( + self, + augmentations: T.Sequence[str], + rng: T.Optional[np.random.Generator] = None, + random_seed: T.Optional[int] = None, + **kwargs, + ): + self.augmentations = augmentations + self.augmenters_ = [] + self.kwargs = kwargs + + if rng is None: + rng = np.random.default_rng(random_seed) + + self._init_augmenters(rng) + + def _init_augmenters(self, rng: np.random.Generator): + for aug_name in self.augmentations: + self.augmenters_.append( + AUGMENTER_METHODS[aug_name]( + **{ + "rng": rng, + **MODULE_DEFAULTS[aug_name], + **self.kwargs, + } + ) + ) + + def update_aug_args(self, **kwargs): + self.aug_args = replace(self.aug_args, **kwargs) def __iter__(self): yield from self.augmenters_ + + def __call__(self, batch: Data) -> Data: + return self.forward(batch) + + def forward(self, batch: Data) -> Data: + for augmenter in self: + batch = augmenter(batch) + + return batch diff --git a/src/cultionet/callbacks.py b/src/cultionet/callbacks.py index 3145738c..b6e14459 100644 --- a/src/cultionet/callbacks.py +++ b/src/cultionet/callbacks.py @@ -1,16 +1,42 @@ +import hashlib import typing as T -import filelock from pathlib import Path +import filelock import geowombat as gw import rasterio as rio -from rasterio.windows import Window import torch -from pytorch_lightning.callbacks import BasePredictionWriter -from torch_geometric.data import Data +from lightning.pytorch.callbacks import ( + BasePredictionWriter, + LearningRateMonitor, + ModelCheckpoint, + ModelPruning, + RichProgressBar, + StochasticWeightAveraging, +) +from lightning.pytorch.callbacks.progress.rich_progress import ( + RichProgressBarTheme, +) +from rasterio.windows import Window + +from .data.constant import SCALE_FACTOR +from .enums import InferenceNames -from .data.const import SCALE_FACTOR -from .utils.reshape import ModelOutputs +PROGRESS_BAR_CALLBACK = RichProgressBar( + refresh_rate=1, + theme=RichProgressBarTheme( + description="#cacaca", + progress_bar="#ACFCD6", + progress_bar_finished="#ACFCD6", + progress_bar_pulse="#FCADED", + batch_progress="#AA9439", + time="grey54", + processing_speed="grey70", + metrics="#cacaca", + metrics_text_delimiter=" • ", + metrics_format=".3f", + ), +) def tile_size_is_correct( @@ -24,8 +50,6 @@ def __init__( self, reference_image: Path, out_path: Path, - num_classes: int, - ref_res: float, resampling, compression: str, write_interval: str = "batch", @@ -36,125 +60,90 @@ def __init__( self.out_path = out_path self.out_path.parent.mkdir(parents=True, exist_ok=True) - with gw.config.update(ref_res=ref_res): - with gw.open(reference_image, resampling=resampling) as src: - rechunk = False - new_row_chunks = src.gw.check_chunksize( - src.gw.row_chunks, src.gw.nrows - ) - if new_row_chunks != src.gw.row_chunks: - rechunk = True - new_col_chunks = src.gw.check_chunksize( - src.gw.col_chunks, src.gw.ncols + with gw.open(reference_image, resampling=resampling) as src: + self.crs = src.crs + rechunk = False + new_row_chunks = src.gw.check_chunksize( + src.gw.row_chunks, src.gw.nrows + ) + if new_row_chunks != src.gw.row_chunks: + rechunk = True + new_col_chunks = src.gw.check_chunksize( + src.gw.col_chunks, src.gw.ncols + ) + if new_col_chunks != src.gw.col_chunks: + rechunk = True + if rechunk: + src = src.chunk( + chunks={ + 'band': -1, + 'y': new_row_chunks, + 'x': new_col_chunks, + } ) - if new_col_chunks != src.gw.col_chunks: - rechunk = True - if rechunk: - src = src.chunk( - chunks={ - 'band': -1, - 'y': new_row_chunks, - 'x': new_col_chunks, - } - ) - profile = { - "crs": src.crs, - "transform": src.gw.transform, - "height": src.gw.nrows, - "width": src.gw.ncols, - # distance (+1) + edge (+1) + crop (+1) crop types (+N) - # `num_classes` includes background - "count": 3 + num_classes - 1, - "dtype": "uint16", - "blockxsize": src.gw.col_chunks, - "blockysize": src.gw.row_chunks, - "driver": "GTiff", - "sharing": False, - "compress": compression, - } - profile["tiled"] = tile_size_is_correct( - profile["blockxsize"], profile["blockysize"] + + self.profile = { + "crs": self.crs, + "transform": src.gw.transform, + "height": src.gw.nrows, + "width": src.gw.ncols, + # distance (+1) + edge (+1) + crop (+1) + "count": 3, + "dtype": "uint16", + "blockxsize": src.gw.col_chunks, + "blockysize": src.gw.row_chunks, + "driver": "GTiff", + "sharing": False, + "compress": compression, + } + + self.profile["tiled"] = tile_size_is_correct( + self.profile["blockxsize"], self.profile["blockysize"] ) - with rio.open(out_path, mode="w", **profile): + + with rio.open(self.out_path, mode="w", **self.profile): pass - self.dst = rio.open(out_path, mode="r+") + + self.dst = rio.open(self.out_path, mode="r+") def write_on_epoch_end( self, trainer, pl_module, predictions, batch_indices ): self.dst.close() - def reshape_predictions( + def slice_predictions( self, - batch: Data, + batch_slice: tuple, distance_batch: torch.Tensor, edge_batch: torch.Tensor, crop_batch: torch.Tensor, - crop_type_batch: T.Union[torch.Tensor, None], - batch_index: int, - ) -> T.Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, T.Union[torch.Tensor, None] - ]: - pad_slice2d = ( - slice( - int(batch.row_pad_before[batch_index]), - int(batch.height[batch_index]) - - int(batch.row_pad_after[batch_index]), - ), - slice( - int(batch.col_pad_before[batch_index]), - int(batch.width[batch_index]) - - int(batch.col_pad_after[batch_index]), - ), - ) - pad_slice3d = ( + ) -> T.Dict[str, torch.Tensor]: + + distance_batch = distance_batch[batch_slice] + edge_batch = edge_batch[batch_slice] + crop_batch = crop_batch[batch_slice] + + if crop_batch.shape[0] > 1: + crop_batch = crop_batch[[1]] + + return { + InferenceNames.DISTANCE: distance_batch, + InferenceNames.EDGE: edge_batch, + InferenceNames.CROP: crop_batch, + } + + def get_batch_slice(self, padding: int, window: Window) -> tuple: + return ( slice(0, None), slice( - int(batch.row_pad_before[batch_index]), - int(batch.height[batch_index]) - - int(batch.row_pad_after[batch_index]), + padding, + padding + window.height, ), slice( - int(batch.col_pad_before[batch_index]), - int(batch.width[batch_index]) - - int(batch.col_pad_after[batch_index]), + padding, + padding + window.width, ), ) - rheight = pad_slice2d[0].stop - pad_slice2d[0].start - rwidth = pad_slice2d[1].stop - pad_slice2d[1].start - - def reshaper(x: torch.Tensor, channel_dims: int) -> torch.Tensor: - if channel_dims == 1: - return ( - x.reshape( - int(batch.height[batch_index]), - int(batch.width[batch_index]), - )[pad_slice2d] - .contiguous() - .view(-1)[:, None] - ) - else: - return ( - x.t() - .reshape( - channel_dims, - int(batch.height[batch_index]), - int(batch.width[batch_index]), - )[pad_slice3d] - .permute(1, 2, 0) - .reshape(rheight * rwidth, channel_dims) - ) - - distance_batch = reshaper(distance_batch, channel_dims=1) - edge_batch = reshaper(edge_batch, channel_dims=1) - crop_batch = reshaper(crop_batch, channel_dims=2) - if crop_type_batch is not None: - num_classes = crop_type_batch.size(1) - crop_type_batch = reshaper( - crop_type_batch, channel_dims=num_classes - ) - - return distance_batch, edge_batch, crop_batch, crop_type_batch def write_on_batch_end( self, @@ -166,57 +155,117 @@ def write_on_batch_end( batch_idx, dataloader_idx, ): - distance = prediction["dist"] - edge = prediction["edge"] - crop = prediction["crop"] - crop_type = prediction["crop_type"] - for batch_index in batch.batch.unique(): - mask = batch.batch == batch_index - w = Window( - row_off=int(batch.window_row_off[batch_index]), - col_off=int(batch.window_col_off[batch_index]), - height=int(batch.window_height[batch_index]), - width=int(batch.window_width[batch_index]), + pred_df = prediction.get("pred_df") + if pred_df is not None: + if not pred_df.empty: + pred_df = pred_df.set_crs(crs=self.crs, allow_override=True) + # Create a hash to avoid long file names + batch_hash = hashlib.shake_256( + '-'.join(batch.batch_id).encode() + ) + pred_df.to_file( + self.out_path.parent + / f"{self.out_path.stem}_{batch_hash.hexdigest(16)}.gpkg", + driver="GPKG", + ) + + distance = prediction[InferenceNames.DISTANCE] + edge = prediction[InferenceNames.EDGE] + crop = prediction[InferenceNames.CROP] + + for batch_index in range(batch.x.shape[0]): + window_row_off = int(batch.window_row_off[batch_index]) + window_height = int(batch.window_height[batch_index]) + window_col_off = int(batch.window_col_off[batch_index]) + window_width = int(batch.window_width[batch_index]) + + if window_row_off + window_height > self.profile["height"]: + window_height = self.profile["height"] - window_row_off + if window_col_off + window_width > self.profile["width"]: + window_width = self.profile["width"] - window_col_off + + write_window = Window( + row_off=window_row_off, + col_off=window_col_off, + height=window_height, + width=window_width, ) - w_pad = Window( - row_off=int(batch.window_pad_row_off[batch_index]), - col_off=int(batch.window_pad_col_off[batch_index]), - height=int(batch.window_pad_height[batch_index]), - width=int(batch.window_pad_width[batch_index]), + + batch_slice = self.get_batch_slice( + padding=batch.padding[batch_index], + window=write_window, ) - ( - distance_batch, - edge_batch, - crop_batch, - crop_type_batch, - ) = self.reshape_predictions( - batch=batch, - distance_batch=distance[mask], - edge_batch=edge[mask], - crop_batch=crop[mask], - crop_type_batch=crop_type[mask] - if crop_type is not None - else None, - batch_index=batch_index, + + batch_dict = self.slice_predictions( + batch_slice=batch_slice, + distance_batch=distance[batch_index], + edge_batch=edge[batch_index], + crop_batch=crop[batch_index], ) - if crop_type_batch is None: - crop_type_batch = torch.zeros( - (crop_batch.size(0), 2), dtype=crop_batch.dtype + + stack = ( + torch.cat( + ( + batch_dict[InferenceNames.DISTANCE], + batch_dict[InferenceNames.EDGE], + batch_dict[InferenceNames.CROP], + ), + dim=0, ) - mo = ModelOutputs( - distance=distance_batch, - edge=edge_batch, - crop=crop_batch, - crop_type=crop_type_batch, - instances=None, - apply_softmax=False, + .detach() + .cpu() + .numpy() ) - stack = mo.stack_outputs(w, w_pad) + stack = (stack * SCALE_FACTOR).clip(0, SCALE_FACTOR) with filelock.FileLock("./dst.lock"): self.dst.write( stack, indexes=range(1, self.dst.profile["count"] + 1), - window=w, + window=write_window, ) + + +def setup_callbacks( + ckpt_file: T.Union[str, Path], + stochastic_weight_averaging: bool = False, + stochastic_weight_averaging_lr: float = 0.05, + stochastic_weight_averaging_start: float = 0.8, + model_pruning: bool = False, +) -> T.Tuple[LearningRateMonitor, T.Sequence[T.Any]]: + # Checkpoint + cb_train_loss = ModelCheckpoint(monitor="loss") + # Validation and test loss + cb_val_loss = ModelCheckpoint( + dirpath=ckpt_file.parent, + filename=ckpt_file.stem, + save_last=False, + save_top_k=1, + mode="min", + monitor="val_score", + every_n_train_steps=0, + every_n_epochs=1, + ) + # Early stopping + # early_stop_callback = EarlyStopping( + # monitor="val_score", + # min_delta=early_stopping_min_delta, + # patience=early_stopping_patience, + # mode="min", + # check_on_train_epoch_end=False, + # ) + # Learning rate + lr_monitor = LearningRateMonitor(logging_interval="epoch") + callbacks = [lr_monitor, cb_train_loss, cb_val_loss] + if stochastic_weight_averaging: + callbacks.append( + StochasticWeightAveraging( + swa_lrs=stochastic_weight_averaging_lr, + swa_epoch_start=stochastic_weight_averaging_start, + ) + ) + if 0 < model_pruning <= 1: + callbacks.append(ModelPruning("l1_unstructured", amount=model_pruning)) + + return lr_monitor, callbacks diff --git a/src/cultionet/data/__init__.py b/src/cultionet/data/__init__.py index e69de29b..02a0feb2 100644 --- a/src/cultionet/data/__init__.py +++ b/src/cultionet/data/__init__.py @@ -0,0 +1 @@ +from .data import Data diff --git a/src/cultionet/data/const.py b/src/cultionet/data/constant.py similarity index 100% rename from src/cultionet/data/const.py rename to src/cultionet/data/constant.py diff --git a/src/cultionet/data/create.py b/src/cultionet/data/create.py index 8fe37821..969faceb 100644 --- a/src/cultionet/data/create.py +++ b/src/cultionet/data/create.py @@ -1,197 +1,64 @@ import typing as T from pathlib import Path -from functools import partial -import warnings +import dask.array as da +import einops +import geopandas as gpd import geowombat as gw -from geowombat.core import polygon_to_array -from geowombat.core.windows import get_window_offsets import numpy as np -from scipy.ndimage.measurements import label as nd_label -import cv2 -from rasterio.warp import calculate_default_transform -from rasterio.windows import Window +import pandas as pd +import psutil +import ray +import torch import xarray as xr -import geopandas as gpd +from affine import Affine +from psutil._common import bytes2human +from rasterio.windows import Window, from_bounds +from ray.exceptions import RayTaskError +from ray.util.dask import ray_dask_get +from scipy.ndimage import label as nd_label from skimage.measure import regionprops -from tqdm.auto import tqdm -from torch_geometric.data import Data -import joblib -from joblib import delayed, parallel_backend from threadpoolctl import threadpool_limits -from .utils import LabeledData, get_image_list_dims -from ..augment.augmenters import Augmenters, AugmenterMapping -from ..errors import TopologyClipError +from ..augment.augmenters import AUGMENTER_METHODS from ..utils.logging import set_color_logger -from ..utils.model_preprocessing import TqdmParallel - +from .data import Data, LabeledData +from .store import BatchStore +from .utils import ( + cleanup_edges, + edge_gradient, + fillz, + get_crop_count, + get_image_list_dims, + normalize_boundary_distances, + polygon_to_array, +) logger = set_color_logger(__name__) -def roll( - arr_pad: np.ndarray, - shift: T.Union[int, T.Tuple[int, int]], - axis: T.Union[int, T.Tuple[int, int]], -) -> np.ndarray: - """Rolls array elements along a given axis and slices off padded edges.""" - return np.roll(arr_pad, shift, axis=axis)[1:-1, 1:-1] - - -def close_edge_ends(array: np.ndarray) -> np.ndarray: - """Closes 1 pixel gaps at image edges.""" - # Top - idx = np.where(array[1] == 1) - z = np.zeros(array.shape[1], dtype="uint8") - z[idx] = 1 - array[0] = z - # Bottom - idx = np.where(array[-2] == 1) - z = np.zeros(array.shape[1], dtype="uint8") - z[idx] = 1 - array[-1] = z - # Left - idx = np.where(array[:, 1] == 1) - z = np.zeros(array.shape[0], dtype="uint8") - z[idx] = 1 - array[:, 0] = z - # Right - idx = np.where(array[:, -2] == 1) - z = np.zeros(array.shape[0], dtype="uint8") - z[idx] = 1 - array[:, -1] = z - - return array - - -def get_other_crop_count(array: np.ndarray) -> np.ndarray: - array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") - - rarray = roll(array_pad, 1, axis=0) - crop_count = np.uint8((rarray > 0) & (rarray != array) & (array > 0)) - rarray = roll(array_pad, -1, axis=0) - crop_count += np.uint8((rarray > 0) & (rarray != array) & (array > 0)) - rarray = roll(array_pad, 1, axis=1) - crop_count += np.uint8((rarray > 0) & (rarray != array) & (array > 0)) - rarray = roll(array_pad, -1, axis=1) - crop_count += np.uint8((rarray > 0) & (rarray != array) & (array > 0)) - - return crop_count - - -def fill_edge_gaps(labels: np.ndarray, array: np.ndarray) -> np.ndarray: - """Fills neighboring 1-pixel edge gaps.""" - # array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode='edge') - # hsum = roll(array_pad, 1, axis=0) + roll(array_pad, -1, axis=0) - # vsum = roll(array_pad, 1, axis=1) + roll(array_pad, -1, axis=1) - # array = np.where( - # (hsum == 2) & (vsum == 0), 1, array - # ) - # array = np.where( - # (hsum == 0) & (vsum == 2), 1, array - # ) - other_count = get_other_crop_count(np.where(array == 1, 0, labels)) - array = np.where(other_count > 0, 1, array) - - return array - - -def get_crop_count(array: np.ndarray, edge_class: int) -> np.ndarray: - array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") - - rarray = roll(array_pad, 1, axis=0) - crop_count = np.uint8((rarray > 0) & (rarray != edge_class)) - rarray = roll(array_pad, -1, axis=0) - crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) - rarray = roll(array_pad, 1, axis=1) - crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) - rarray = roll(array_pad, -1, axis=1) - crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) - - return crop_count - - -def get_edge_count(array: np.ndarray, edge_class: int) -> np.ndarray: - array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") - - edge_count = np.uint8(roll(array_pad, 1, axis=0) == edge_class) - edge_count += np.uint8(roll(array_pad, -1, axis=0) == edge_class) - edge_count += np.uint8(roll(array_pad, 1, axis=1) == edge_class) - edge_count += np.uint8(roll(array_pad, -1, axis=1) == edge_class) - - return edge_count - - -def get_non_count(array: np.ndarray) -> np.ndarray: - array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") - - non_count = np.uint8(roll(array_pad, 1, axis=0) == 0) - non_count += np.uint8(roll(array_pad, -1, axis=0) == 0) - non_count += np.uint8(roll(array_pad, 1, axis=1) == 0) - non_count += np.uint8(roll(array_pad, -1, axis=1) == 0) - - return non_count - - -def cleanup_edges( - array: np.ndarray, original: np.ndarray, edge_class: int -) -> np.ndarray: - """Removes crop pixels that border non-crop pixels.""" - array_pad = np.pad(original, pad_width=((1, 1), (1, 1)), mode="edge") - original_zero = np.uint8(roll(array_pad, 1, axis=0) == 0) - original_zero += np.uint8(roll(array_pad, -1, axis=0) == 0) - original_zero += np.uint8(roll(array_pad, 1, axis=1) == 0) - original_zero += np.uint8(roll(array_pad, -1, axis=1) == 0) - - # Fill edges - array = np.where( - (array == 0) - & (get_crop_count(array, edge_class) > 0) - & (get_edge_count(array, edge_class) > 0), - edge_class, - array, - ) - # Remove crops next to non-crop - array = np.where( - (array > 0) - & (array != edge_class) - & (get_non_count(array) > 0) - & (get_edge_count(array, edge_class) > 0), - 0, - array, - ) - # Fill in non-cropland - array = np.where(original_zero == 4, 0, array) - # Remove isolated crop pixels (i.e., crop clumps with 2 or fewer pixels) - array = np.where( - (array > 0) - & (array != edge_class) - & (get_crop_count(array, edge_class) <= 1), - 0, - array, - ) - - return array - - def is_grid_processed( process_path: Path, transforms: T.List[str], - group_id: str, - grid_id: T.Union[str, int], + region: str, + start_date: str, + end_date: str, uid_format: str, ) -> bool: """Checks if a grid is already processed.""" + batches_stored = [] for aug in transforms: - aug_method = AugmenterMapping[aug].value + aug_method = AUGMENTER_METHODS[aug]() train_id = uid_format.format( - GROUP_ID=group_id, ROW_ID=grid_id, AUGMENTER=aug_method.name_ + REGION_ID=region, + START_DATE=start_date, + END_DATE=end_date, + AUGMENTER=aug_method.name_, ) train_path = process_path / aug_method.file_name(train_id) - if train_path.is_file(): + if train_path.exists(): batch_stored = True else: batch_stored = False @@ -201,737 +68,711 @@ def is_grid_processed( return all(batches_stored) -def create_boundary_distances( - labels_array: np.ndarray, train_type: str, cell_res: float -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: - """Creates distances from boundaries.""" - if train_type.lower() == "polygon": - mask = np.uint8(labels_array) - else: - mask = np.uint8(1 - labels_array) - # Get unique segments - segments = nd_label(mask)[0] - # Get the distance from edges - bdist = cv2.distanceTransform(mask, cv2.DIST_L2, 3) - bdist *= cell_res - - grad_x = cv2.Sobel( - np.pad(bdist, 5, mode="edge"), cv2.CV_32F, dx=1, dy=0, ksize=5 - ) - grad_y = cv2.Sobel( - np.pad(bdist, 5, mode="edge"), cv2.CV_32F, dx=0, dy=1, ksize=5 - ) - ori = cv2.phase(grad_x, grad_y, angleInDegrees=False) - ori = ori[5:-5, 5:-5] / np.deg2rad(360) - ori[labels_array == 0] = 0 - - return mask, segments, bdist, ori - - -def normalize_boundary_distances( - labels_array: np.ndarray, - train_type: str, - cell_res: float, - normalize: bool = True, -) -> T.Tuple[np.ndarray, np.ndarray]: - """Normalizes boundary distances.""" - # Create the boundary distances - __, segments, bdist, ori = create_boundary_distances( - labels_array, train_type, cell_res - ) - dist_max = 1e9 - if normalize: - dist_max = 1.0 - # Normalize each segment by the local max distance - props = regionprops(segments, intensity_image=bdist) - for p in props: - if p.label > 0: - bdist = np.where( - segments == p.label, bdist / p.max_intensity, bdist - ) - bdist = np.nan_to_num( - bdist.clip(0, dist_max), nan=1.0, neginf=1.0, posinf=1.0 +def reshape_and_mask_array( + data: xr.DataArray, + num_time: int, + num_bands: int, + gain: float, + offset: int, + apply_gain: bool = True, +) -> xr.DataArray: + """Reshapes an array and masks no-data values.""" + + dtype = 'float32' if apply_gain else 'int16' + + time_series = xr.DataArray( + # Data are stored [(band x time) x height x width] + ( + data.data.reshape( + num_bands, + num_time, + data.gw.nrows, + data.gw.ncols, + ).transpose(1, 0, 2, 3) + ).astype(dtype), + dims=('time', 'band', 'y', 'x'), + coords={ + 'time': range(num_time), + 'band': range(num_bands), + 'y': data.y, + 'x': data.x, + }, + attrs=data.attrs.copy(), ) - ori = np.nan_to_num(ori.clip(0, 1), nan=1.0, neginf=1.0, posinf=1.0) - return bdist, ori + if apply_gain: + with xr.set_options(keep_attrs=True): + # Mask and scale the data + time_series = ( + time_series.gw.mask_nodata() * gain + offset + ).fillna(0) -def edge_gradient(array: np.ndarray) -> np.ndarray: - """Calculates the morphological gradient of crop fields.""" - se = np.array([[1, 1], [1, 1]], dtype="uint8") - array = np.uint8( - cv2.morphologyEx(np.uint8(array), cv2.MORPH_GRADIENT, se) > 0 - ) - - return array + return time_series -def create_image_vars( - image: T.Union[str, Path, list], - max_crop_class: int, - bounds: tuple, - num_workers: int, +def create_predict_dataset( + image_list: T.List[T.List[T.Union[str, Path]]], + region: str, + process_path: Path = None, + date_format: str = "%Y%j", gain: float = 1e-4, offset: float = 0.0, - grid_edges: T.Optional[gpd.GeoDataFrame] = None, - ref_res: T.Optional[T.Union[float, T.Tuple[float, float]]] = 10.0, - resampling: T.Optional[str] = "nearest", - crop_column: T.Optional[str] = "class", - keep_crop_classes: T.Optional[bool] = False, - replace_dict: T.Optional[T.Dict[int, int]] = None, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, int, int]: - """Creates the initial image training data.""" - edge_class = max_crop_class + 1 - - if isinstance(image, list): - image = [str(fn) for fn in image] - - # Open the image variables - with gw.config.update(ref_bounds=bounds, ref_res=ref_res): - with gw.open( - image, - stack_dim="band", - band_names=list(range(1, len(image) + 1)), - resampling=resampling, - ) as src_ts: - # 65535 'no data' values = nan - mask = xr.where(src_ts > 10_000, np.nan, 1) - # X variables - time_series = ( - ( - src_ts.gw.set_nodata( - src_ts.gw.nodataval, - 0, - out_range=(0, 1), - dtype="float64", - scale_factor=gain, - offset=offset, - ) - * mask + ref_res: T.Union[float, T.Tuple[float, float]] = 10.0, + resampling: str = "nearest", + window_size: int = 100, + padding: int = 20, + num_workers: int = 1, + compress_method: T.Union[int, str] = 'zlib', +): + """Creates a prediction dataset for an image.""" + + # Read windows larger than the re-chunk window size + read_chunksize = 256 + while True: + if read_chunksize < window_size + padding: + read_chunksize *= 2 + else: + break + + total_cpus = psutil.cpu_count(logical=True) + threads_per_worker = total_cpus // num_workers + + logger.info(f"Opening images with window chunk sizes of {read_chunksize}.") + logger.info( + f"Re-chunking image arrays to chunk sizes of {window_size} with padding of {padding}." + ) + logger.info( + f"Virtual memory available is {bytes2human(psutil.virtual_memory().available)}." + ) + logger.info( + f"Creating PyTorch dataset with {num_workers} processes and {threads_per_worker} threads." + ) + + with threadpool_limits(limits=threads_per_worker, user_api="blas"): + + with gw.config.update(ref_res=ref_res): + with gw.open( + image_list, + stack_dim="band", + band_names=list(range(1, len(image_list) + 1)), + resampling=resampling, + chunks=read_chunksize, + ) as src_ts: + # Get the time and band count + num_time, num_bands = get_image_list_dims(image_list, src_ts) + + time_series = reshape_and_mask_array( + data=src_ts, + num_time=num_time, + num_bands=num_bands, + gain=gain, + offset=offset, + apply_gain=False, ) - .fillna(0) - .gw.compute(num_workers=num_workers) - ) - # Get the time and band count - ntime, nbands = get_image_list_dims(image, src_ts) - if grid_edges is not None: - if replace_dict is not None: - for crop_class in grid_edges[crop_column].unique(): - if crop_class not in list(replace_dict.keys()): - grid_edges[crop_column] = grid_edges[ - crop_column - ].replace({crop_class: -999}) - replace_dict[-999] = 1 - grid_edges[crop_column] = grid_edges[crop_column].replace( - replace_dict - ) - # Remove any non-crop polygons - grid_edges = grid_edges.query(f"{crop_column} != 0") - if grid_edges.empty: - labels_array = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="uint8" - ) - bdist = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="float64" - ) - ori = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="float64" - ) - edges = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="uint8" - ) - else: - # Get the field polygons - labels_array_copy = ( - polygon_to_array( - grid_edges.assign( - **{ - crop_column: range( - 1, len(grid_edges.index) + 1 - ) - } - ), - col=crop_column, - data=src_ts, - all_touched=False, - ) - .squeeze() - .gw.compute(num_workers=num_workers) - ) - labels_array = ( - polygon_to_array( - grid_edges, - col=crop_column, - data=src_ts, - all_touched=False, - ) - .squeeze() - .gw.compute(num_workers=num_workers) - ) - # Get the field edges - edges = ( - polygon_to_array( - ( - grid_edges.boundary.to_frame(name="geometry") - .reset_index() - .rename(columns={"index": crop_column}) - .assign( - **{ - crop_column: range( - 1, len(grid_edges.index) + 1 - ) - } - ) - ), - col=crop_column, - data=src_ts, - all_touched=False, - ) - .squeeze() - .gw.compute(num_workers=num_workers) - ) - if not edges.flags["WRITEABLE"]: - edges = edges.copy() - edges[edges > 0] = 1 - assert edges.max() <= 1, "Edges were not created." - if edges.max() == 0: - return None, None, None, None, None, None - image_grad = edge_gradient(labels_array_copy) - image_grad_count = get_crop_count(image_grad, edge_class) - edges = np.where(image_grad_count > 0, edges, 0) - # Recode - if not keep_crop_classes: - labels_array = np.where( - labels_array > 0, max_crop_class, 0 - ) - # Set edges - labels_array[edges == 1] = edge_class - # No crop pixel should border non-crop - labels_array = cleanup_edges( - labels_array, labels_array_copy, edge_class - ) - assert ( - labels_array.max() <= edge_class - ), "The labels array have larger than expected values." - # Normalize the boundary distances for each segment - bdist, ori = normalize_boundary_distances( - np.uint8( - (labels_array > 0) & (labels_array != edge_class) + # Chunk the array into the windows + time_series_array = time_series.chunk( + { + "time": -1, + "band": -1, + "y": window_size, + "x": window_size, + } + ).data + + # Check if the array needs to be padded + # First, get the end chunk size of rows and columns + height_end_chunk = time_series_array.chunks[-2][-1] + width_end_chunk = time_series_array.chunks[-1][-1] + + height_padding = 0 + width_padding = 0 + if padding > height_end_chunk: + height_padding = padding - height_end_chunk + if padding > width_end_chunk: + width_padding = padding - width_end_chunk + + if (height_padding > 0) or (width_padding > 0): + # Pad the full array if the end chunk is smaller than the padding + time_series_array = da.pad( + time_series_array, + pad_width=( + (0, 0), + (0, 0), + (0, height_padding), + (0, width_padding), ), - grid_edges.geom_type.values[0], - src_ts.gw.celly, - ) - # import matplotlib.pyplot as plt - # def save_labels(out_fig: Path): - # fig, axes = plt.subplots(2, 2, figsize=(6, 5), sharey=True, sharex=True, dpi=300) - # axes = axes.flatten() - # for ax, im, title in zip( - # axes, - # (labels_array_copy, labels_array, bdist, ori), - # ('Fields', 'Edges', 'Distance', 'Orientation') - # ): - # ax.imshow(im, interpolation='nearest') - # ax.set_title(title) - # ax.axis('off') - - # plt.tight_layout() - # plt.savefig(out_fig, dpi=300) - # import uuid - # fig_dir = Path('figures') - # fig_dir.mkdir(exist_ok=True, parents=True) - # hash_id = uuid.uuid4().hex - # save_labels( - # out_fig=fig_dir / f'{hash_id}.png' - # ) - else: - labels_array = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="uint8" - ) - bdist = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="float64" - ) - ori = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="float64" - ) - edges = np.zeros( - (src_ts.gw.nrows, src_ts.gw.ncols), dtype="uint8" + ).rechunk({0: -1, 1: -1, 2: window_size, 3: window_size}) + + # Add the padding to each chunk + time_series_array = time_series_array.map_overlap( + lambda x: x, + depth={0: 0, 1: 0, 2: padding, 3: padding}, + boundary=0, + trim=False, ) - return time_series, labels_array, bdist, ori, ntime, nbands + if not ray.is_initialized(): + ray.init(num_cpus=num_workers) + + try: + with BatchStore( + data=time_series, + write_path=process_path, + res=ref_res, + resampling=resampling, + region=region, + start_date=pd.to_datetime( + Path(image_list[0]).stem, format=date_format + ).strftime("%Y%m%d"), + end_date=pd.to_datetime( + Path(image_list[-1]).stem, format=date_format + ).strftime("%Y%m%d"), + window_size=window_size, + padding=padding, + compress_method=compress_method, + ) as batch_store: + batch_store.save( + time_series_array, + scheduler=ray_dask_get, + ) + except RayTaskError as e: + logger.warning(e) + ray.shutdown() + + if ray.is_initialized(): + ray.shutdown() + + +class ReferenceArrays: + def __init__( + self, + labels_array: np.ndarray = None, + boundary_distance: np.ndarray = None, + orientation: np.ndarray = None, + edge_array: np.ndarray = None, + ): + self.labels_array = labels_array + self.boundary_distance = boundary_distance + self.orientation = orientation + self.edge_array = edge_array + + @classmethod + def from_polygons( + cls, + df_polygons_grid: gpd.GeoDataFrame, + max_crop_class: int, + edge_class: int, + crop_column: str, + keep_crop_classes: bool, + data_array: xr.DataArray, + nonag_is_unknown: bool = False, + all_touched: bool = True, + ) -> "ReferenceArrays": + # Polygon label array, where each polygon has a + # unique raster value. + labels_array_unique = polygon_to_array( + df=df_polygons_grid.assign( + **{crop_column: range(1, len(df_polygons_grid.index) + 1)} + ), + reference_data=data_array, + column=crop_column, + ) -def save_and_update( - write_path: Path, predict_data: Data, name: str, compress: int = 5 -) -> None: - predict_path = write_path / f"data_{name}.pt" - joblib.dump(predict_data, predict_path, compress=compress) + # Polygon label array, where each polygon has a value + # equal to the GeoDataFrame `crop_column`. + fill_value = 0 + dtype = "uint8" + if nonag_is_unknown: + # Background values are unknown, so they need to be + # filled with -1 + fill_value = -1 + dtype = "int16" + + labels_array = polygon_to_array( + df=df_polygons_grid, + reference_data=data_array, + column=crop_column, + fill_value=fill_value, + dtype=dtype, + ) + # Get the polygon edges as an array + edge_array = polygon_to_array( + df=( + df_polygons_grid.boundary.to_frame(name="geometry") + .reset_index() + .rename(columns={"index": crop_column}) + .assign( + **{crop_column: range(1, len(df_polygons_grid.index) + 1)} + ) + ), + reference_data=data_array, + column=crop_column, + all_touched=all_touched, + ) -def read_slice(darray: xr.DataArray, w_pad: Window) -> xr.DataArray: - slicer = ( - slice(0, None), - slice(w_pad.row_off, w_pad.row_off + w_pad.height), - slice(w_pad.col_off, w_pad.col_off + w_pad.width), - ) + if not edge_array.flags["WRITEABLE"]: + edge_array = edge_array.copy() - return darray[slicer] + edge_array[edge_array > 0] = 1 + assert edge_array.max() <= 1, "Edges were not created." + # Get the edges from the unique polygon array + image_grad = edge_gradient(labels_array_unique) + # Fill in edges that may have been missed by the polygon boundary + image_grad_count = get_crop_count(image_grad, edge_class) + edge_array = np.where(image_grad_count > 0, edge_array, 0) -def get_window_chunk( - windows: T.List[T.Tuple[Window, Window]], chunksize: int -) -> T.List[T.Tuple[Window, Window]]: - for i in range(0, len(windows), chunksize): - yield windows[i : i + chunksize] + if not keep_crop_classes: + # Recode all crop polygons to a single class + labels_array = np.where( + labels_array > 0, max_crop_class, fill_value + ) + # Set edges within the labels array + # E.g., + # 0 = background + # 1 = crop + # 2 = crop edge + labels_array[edge_array == 1] = edge_class + # No crop pixel should border non-crop + labels_array = cleanup_edges( + np.where(labels_array == fill_value, 0, labels_array), + labels_array_unique, + edge_class, + ) + labels_array = np.where(labels_array == 0, fill_value, labels_array) -def create_and_save_window( - write_path: Path, - ntime: int, - nbands: int, - image_height: int, - image_width: int, - res: float, - resampling: str, - region: str, - year: int, - window_size: int, - padding: int, - darray: xr.DataArray, - w: Window, - w_pad: Window, -) -> None: - x = darray.data.compute(num_workers=1) - - size = window_size + padding * 2 - x_height = x.shape[1] - x_width = x.shape[2] - - row_pad_before = 0 - col_pad_before = 0 - col_pad_after = 0 - row_pad_after = 0 - if (x_height != size) or (x_width != size): - # Pre-padding - if w.row_off < padding: - row_pad_before = padding - w.row_off - if w.col_off < padding: - col_pad_before = padding - w.col_off - # Post-padding - if w.row_off + window_size + padding > image_height: - row_pad_after = size - x_height - if w.col_off + window_size + padding > image_width: - col_pad_after = size - x_width - - x = np.pad( - x, - pad_width=( - (0, 0), - (row_pad_before, row_pad_after), - (col_pad_before, col_pad_after), - ), - mode="constant", + assert ( + labels_array.max() <= edge_class + ), "The labels array have larger than expected values." + + # Normalize the boundary distances for each segment + boundary_distance, orientation = normalize_boundary_distances( + np.uint8((labels_array > 0) & (labels_array != edge_class)), + df_polygons_grid.geom_type.values[0], + data_array.gw.celly, ) - if x.shape[1:] != (size, size): - logger.warning("The array does not match the expected size.") - ldata = LabeledData( - x=x, y=None, bdist=None, ori=None, segments=None, props=None - ) + return cls( + labels_array=labels_array, + boundary_distance=boundary_distance, + orientation=orientation, + edge_array=edge_array, + ) - augmenters = Augmenters( - augmentations=["none"], - ntime=ntime, - nbands=nbands, - max_crop_class=0, - k=3, - instance_seg=False, - zero_padding=0, - window_row_off=w.row_off, - window_col_off=w.col_off, - window_height=w.height, - window_width=w.width, - window_pad_row_off=w_pad.row_off, - window_pad_col_off=w_pad.col_off, - window_pad_height=w_pad.height, - window_pad_width=w_pad.width, - row_pad_before=row_pad_before, - row_pad_after=row_pad_after, - col_pad_before=col_pad_before, - col_pad_after=col_pad_after, - res=res, - resampling=resampling, - left=darray.gw.left, - bottom=darray.gw.bottom, - right=darray.gw.right, - top=darray.gw.top, - ) - for aug_method in augmenters: - aug_kwargs = augmenters.aug_args.kwargs - aug_kwargs["train_id"] = f"{region}_{year}_{w.row_off}_{w.col_off}" - augmenters.update_aug_args(kwargs=aug_kwargs) - predict_data = aug_method(ldata, aug_args=augmenters.aug_args) - aug_method.save( - out_directory=write_path, data=predict_data, compress=5 + +class ImageVariables: + def __init__( + self, + time_series: np.ndarray = None, + labels_array: np.ndarray = None, + boundary_distance: np.ndarray = None, + orientation: np.ndarray = None, + edge_array: np.ndarray = None, + num_time: int = None, + num_bands: int = None, + ): + self.time_series = time_series + self.labels_array = labels_array + self.boundary_distance = boundary_distance + self.orientation = orientation + self.edge_array = edge_array + self.num_time = num_time + self.num_bands = num_bands + + @staticmethod + def recode_polygons( + df_polygons_grid: gpd.GeoDataFrame, + crop_column: str, + replace_dict: dict, + ) -> gpd.GeoDataFrame: + """Recodes polygon labels.""" + + df_polygons_grid[crop_column] = df_polygons_grid[crop_column].replace( + to_replace=replace_dict ) + # Remove any non-crop polygons + return df_polygons_grid.query(f"{crop_column} != 0") + + @staticmethod + def get_default_arrays(num_rows: int, num_cols: int) -> tuple: + labels_array = np.zeros((num_rows, num_cols), dtype="uint8") + boundary_distance = np.zeros((num_rows, num_cols), dtype="float64") + orientation = np.zeros_like(boundary_distance) + edge_array = np.zeros_like(labels_array) + + return labels_array, boundary_distance, orientation, edge_array + + @classmethod + def create_image_vars( + cls, + region: str, + image: T.Union[str, Path, list], + reference_grid: gpd.GeoDataFrame, + max_crop_class: int, + grid_size: T.Optional[ + T.Union[T.Tuple[int, int], T.List[int], None] + ] = None, + gain: float = 1e-4, + offset: float = 0.0, + df_polygons_grid: T.Optional[gpd.GeoDataFrame] = None, + ref_res: float = 10.0, + resampling: str = "nearest", + crop_column: str = "class", + keep_crop_classes: bool = False, + replace_dict: T.Optional[T.Dict[int, int]] = None, + nonag_is_unknown: bool = False, + all_touched: bool = True, + ) -> "ImageVariables": + """Creates the initial image training data.""" + + # Get the reference bounds from the grid geometry + ref_bounds = reference_grid.total_bounds.tolist() + + # Pre-check before opening files + if grid_size is not None: + ref_window = from_bounds( + *ref_bounds, + Affine( + ref_res, 0.0, ref_bounds[0], 0.0, -ref_res, ref_bounds[3] + ), + ) -def create_predict_dataset( - image_list: T.List[T.List[T.Union[str, Path]]], - region: str, - year: int, - process_path: Path = None, - gain: float = 1e-4, - offset: float = 0.0, - ref_res: T.Union[float, T.Tuple[float, float]] = 10.0, - resampling: str = "nearest", - window_size: int = 100, - padding: int = 101, - num_workers: int = 1, - chunksize: int = 100, -): - with threadpool_limits(limits=1, user_api="blas"): - with gw.config.update(ref_res=ref_res): + ref_window = Window( + row_off=int(ref_window.row_off), + col_off=int(ref_window.col_off), + height=int(round(ref_window.height)), + width=int(round(ref_window.width)), + ) + + assert (int(ref_window.height) == grid_size[0]) and ( + int(ref_window.width) == grid_size[1] + ), ( + f"The reference grid size is {ref_window.height} rows x {ref_window.width} columns, but the expected " + f"dimensions are {grid_size[0]} rows x {grid_size[1]} columns" + ) + + # Open the image variables + with gw.config.update( + ref_bounds=ref_bounds, + ref_crs=reference_grid.crs, + ref_res=ref_res, + ): with gw.open( - image_list, + image, stack_dim="band", - band_names=list(range(1, len(image_list) + 1)), + band_names=list(range(1, len(image) + 1)), resampling=resampling, - chunks=512, ) as src_ts: - windows = get_window_offsets( - src_ts.gw.nrows, - src_ts.gw.ncols, - window_size, - window_size, - padding=(padding, padding, padding, padding), - ) - time_series = ( - (src_ts.astype("float64") * gain + offset) - .clip(0, 1) - .chunk({"band": -1, "y": window_size, "x": window_size}) - .transpose("band", "y", "x") - .assign_attrs(**src_ts.attrs) - ) + if grid_size is not None: + if not ( + (src_ts.gw.nrows == grid_size[0]) + and (src_ts.gw.ncols == grid_size[1]) + ): + logger.warning( + f"The reference image size is {src_ts.gw.nrows} rows x {src_ts.gw.ncols} columns, but the expected " + f"dimensions are {grid_size[0]} rows x {grid_size[1]} columns" + ) + return cls() + + # Get the time and band count + num_time, num_bands = get_image_list_dims(image, src_ts) + + time_series = reshape_and_mask_array( + data=src_ts, + num_time=num_time, + num_bands=num_bands, + gain=gain, + offset=offset, + ).data.compute(num_workers=1) + + # Fill isolated zeros + time_series = fillz(time_series) + + # NaNs are filled with 0 in reshape_and_mask_array() + zero_mask = time_series.sum(axis=0) == 0 + if zero_mask.all(): + logger.warning( + f"The {region} time series contains all NaNs." + ) + return cls() - ntime, nbands = get_image_list_dims(image_list, src_ts) - - partial_create = partial( - create_and_save_window, - process_path, - ntime, - nbands, - src_ts.gw.nrows, - src_ts.gw.ncols, - ref_res, - resampling, - region, - year, - window_size, - padding, + # Default outputs + ( + labels_array, + boundary_distance, + orientation, + edge_array, + ) = cls.get_default_arrays( + num_rows=src_ts.gw.nrows, num_cols=src_ts.gw.ncols ) - with tqdm( - total=len(windows), - desc="Creating prediction windows", - position=1, - ) as pbar_total: - with parallel_backend(backend="loky", n_jobs=num_workers): - for window_chunk in get_window_chunk( - windows, chunksize - ): - with TqdmParallel( - tqdm_kwargs={ - "total": len(window_chunk), - "desc": "Window chunks", - "position": 2, - "leave": False, - }, - temp_folder="/tmp", - ) as pool: - __ = pool( - delayed(partial_create)( - read_slice(time_series, window_pad), - window, - window_pad, - ) - for window, window_pad in window_chunk - ) - pbar_total.update(len(window_chunk)) - - -def create_dataset( + # Any polygons intersecting the grid? + if df_polygons_grid is not None: + if replace_dict is not None: + # Recode polygons + df_polygons_grid = cls.recode_polygons( + df_polygons_grid=df_polygons_grid, + crop_column=crop_column, + replace_dict=replace_dict, + ) + + if not df_polygons_grid.empty: + reference_arrays: ReferenceArrays = ( + ReferenceArrays.from_polygons( + df_polygons_grid=df_polygons_grid, + max_crop_class=max_crop_class, + edge_class=max_crop_class + 1, + crop_column=crop_column, + keep_crop_classes=keep_crop_classes, + data_array=src_ts, + nonag_is_unknown=nonag_is_unknown, + all_touched=all_touched, + ) + ) + + if reference_arrays.labels_array is not None: + labels_array = reference_arrays.labels_array + boundary_distance = ( + reference_arrays.boundary_distance + ) + orientation = reference_arrays.orientation + edge_array = reference_arrays.edge_array + + return cls( + time_series=time_series, + labels_array=labels_array, + boundary_distance=boundary_distance, + orientation=orientation, + edge_array=edge_array, + num_time=num_time, + num_bands=num_bands, + ) + + +@threadpool_limits.wrap(limits=1, user_api="blas") +def create_train_batch( image_list: T.List[T.List[T.Union[str, Path]]], - df_grids: gpd.GeoDataFrame, - df_edges: gpd.GeoDataFrame, + df_grid: gpd.GeoDataFrame, + df_polygons: gpd.GeoDataFrame, max_crop_class: int, - group_id: str = None, + region: str, process_path: Path = None, - transforms: T.List[str] = None, + date_format: str = "%Y%j", gain: float = 1e-4, offset: float = 0.0, ref_res: float = 10.0, resampling: str = "nearest", - num_workers: int = 1, grid_size: T.Optional[ T.Union[T.Tuple[int, int], T.List[int], None] ] = None, - instance_seg: T.Optional[bool] = False, - zero_padding: T.Optional[int] = 0, crop_column: T.Optional[str] = "class", keep_crop_classes: T.Optional[bool] = False, replace_dict: T.Optional[T.Dict[int, int]] = None, - pbar: T.Optional[object] = None, + nonag_is_unknown: bool = False, + all_touched: bool = True, + compress_method: T.Union[int, str] = 'zlib', ) -> None: - """Creates a dataset for training. - - Args: - image_list: A list of images. - df_grids: The training grids. - df_edges: The training edges. - max_crop_class: The maximum expected crop class value. - group_id: A group identifier, used for logging. - process_path: The main processing path. - transforms: A list of augmentation transforms to apply. - gain: A gain factor to apply to the images. - offset: An offset factor to apply to the images. - ref_res: The reference cell resolution to resample the images to. - resampling: The image resampling method. - num_workers: The number of dask workers. - grid_size: The requested grid size, in (rows, columns) or (height, width). - lc_path: The land cover image path. - n_ts: The number of temporal augmentations. - data_type: The target data type. - instance_seg: Whether to get instance segmentation mask targets. - zero_padding: Zero padding to apply. - crop_column: The crop column name in the polygon vector files. - keep_crop_classes: Whether to keep the crop classes as they are (True) or recode all - non-zero classes to crop (False). - replace_dict: A dictionary of crop class remappings. + """Creates a batch file for training. + + Parameters + ========== + image_list + A list of images. + df_grid + The training grid. + df_polygons + The training polygons. + max_crop_class + The maximum expected crop class value. + group_id + A group identifier, used for logging. + process_path + The main processing path. + gain + A gain factor to apply to the images. + offset + An offset factor to apply to the images. + ref_res + The reference cell resolution to resample the images to. + resampling + The image resampling method. + grid_size + The requested grid size, in (rows, columns) or (height, width). + lc_path + The land cover image path. + n_ts + The number of temporal augmentations. + data_type + The target data type. + instance_seg + Whether to get instance segmentation mask targets. + zero_padding + Zero padding to apply. + crop_column + The crop column name in the polygon vector files. + keep_crop_classes + Whether to keep the crop classes as they are (True) or recode all + non-zero classes to crop (False). + replace_dict + A dictionary of crop class remappings. + nonag_is_unknown + Whether the non-agricultural background is unknown. + all_touched + Rasterio/Shapely rasterization flag. """ - if transforms is None: - transforms = ["none"] - - merged_grids = [] - sindex = df_grids.sindex - - # Get the image CRS - with gw.open(image_list[0]) as src: - image_crs = src.crs - if ref_res is None: - ref_res = (src.gw.celly, src.gw.cellx) - else: - ref_res = (ref_res, ref_res) - - input_height = None - input_width = None - unprocessed = [] - for row in df_grids.itertuples(): - # Check if the grid has already been saved - if hasattr(row, "grid"): - row_grid_id = row.grid - elif hasattr(row, "region"): - row_grid_id = row.region - else: - raise AttributeError( - "The grid id should be given as 'grid' or 'region'." - ) - - uid_format = "{GROUP_ID}_{ROW_ID}_{AUGMENTER}" + start_date = pd.to_datetime( + Path(image_list[0]).stem, format=date_format + ).strftime("%Y%m%d") + end_date = pd.to_datetime( + Path(image_list[-1]).stem, format=date_format + ).strftime("%Y%m%d") + + uid_format = "{REGION_ID}_{START_DATE}_{END_DATE}_none" + group_id = f"{region}_{start_date}_{end_date}_none" + + transforms = ["none"] + + # Check if the grid has already been saved + batch_stored = is_grid_processed( + process_path=process_path, + transforms=transforms, + region=region, + start_date=start_date, + end_date=end_date, + uid_format=uid_format, + ) - batch_stored = is_grid_processed( - process_path=process_path, - transforms=transforms, - group_id=group_id, - grid_id=row_grid_id, - uid_format=uid_format, - ) - if batch_stored: - pbar.set_description(f"{group_id} stored.") - continue - - # Clip the edges to the current grid - try: - grid_edges = gpd.clip(df_edges, row.geometry) - except ValueError: - logger.warning( - TopologyClipError( - "The input GeoDataFrame contains topology errors." - ) - ) - df_edges = gpd.GeoDataFrame( - data=df_edges[crop_column].values, - columns=[crop_column], - geometry=df_edges.buffer(0).geometry, - ) - grid_edges = gpd.clip(df_edges, row.geometry) - - # These are grids with no crop fields. They should still - # be used for training. - if grid_edges.loc[~grid_edges.is_empty].empty: - grid_edges = df_grids.copy() - grid_edges = grid_edges.assign(**{crop_column: 0}) - # Remove empty geometry - grid_edges = grid_edges.loc[~grid_edges.is_empty] - - if not grid_edges.empty: - # Check if the edges overlap multiple grids - int_idx = sorted( - list( - sindex.intersection( - tuple(grid_edges.total_bounds.flatten()) - ) - ) - ) + if batch_stored: + return - if len(int_idx) > 1: - # Check if any of the grids have already been stored - if any( - [ - rowg in merged_grids - for rowg in df_grids.iloc[int_idx].grid.values.tolist() - ] - ): - pbar.set_description(f"No edges in {group_id}") - continue - - grid_edges = gpd.clip( - df_edges, df_grids.iloc[int_idx].geometry - ) - merged_grids.append(row.grid) + # These are grids with no crop fields. They should still + # be used for training. + if df_polygons.loc[~df_polygons.is_empty].empty: + df_polygons = df_grid.copy() + df_polygons = df_polygons.assign(**{crop_column: 0}) - nonzero_mask = grid_edges[crop_column] != 0 + # Remove empty geometries + df_polygons = df_polygons.loc[~df_polygons.is_empty] - # left, bottom, right, top - ref_bounds = ( - df_grids.to_crs(image_crs).iloc[int_idx].total_bounds.tolist() + if not df_polygons.empty: + type_mask = df_polygons.geom_type == "GeometryCollection" + if type_mask.any(): + exploded_collections = df_polygons.loc[type_mask].explode( + column="geometry" ) - if grid_size is not None: - height, width = grid_size - left, bottom, right, top = ref_bounds - + exploded_collections = exploded_collections.loc[ + (exploded_collections.geom_type == "Polygon") + | (exploded_collections.geom_type == "MultiPolygon") + ] + df_polygons = pd.concat( ( - dst_transform, - dst_width, - dst_height, - ) = calculate_default_transform( - src_crs=image_crs, - dst_crs=image_crs, - width=int(abs(round((right - left) / ref_res[1]))), - height=int(abs(round((top - bottom) / ref_res[0]))), - left=left, - bottom=bottom, - right=right, - top=top, - dst_width=width, - dst_height=height, + df_polygons.loc[~type_mask], + exploded_collections.droplevel(1), ) - dst_left = dst_transform[2] - dst_top = dst_transform[5] - dst_right = dst_left + abs(dst_width * dst_transform[0]) - dst_bottom = dst_top - abs(dst_height * dst_transform[4]) - ref_bounds = [dst_left, dst_bottom, dst_right, dst_top] - - # Data for graph network - xvars, labels_array, bdist, ori, ntime, nbands = create_image_vars( - image=image_list, - max_crop_class=max_crop_class, - bounds=ref_bounds, - num_workers=num_workers, - gain=gain, - offset=offset, - grid_edges=grid_edges if nonzero_mask.any() else None, - ref_res=ref_res[0], - resampling=resampling, - crop_column=crop_column, - keep_crop_classes=keep_crop_classes, - replace_dict=replace_dict, - ) - if xvars is None: - pbar.set_description(f"No fields in {group_id}") - continue - if (xvars.shape[1] < 5) or (xvars.shape[2] < 5): - pbar.set_description(f"{group_id} too small") - continue - - # Get the upper left lat/lon - left, bottom, right, top = ( - df_grids.iloc[int_idx] - .to_crs("epsg:4326") - .total_bounds.tolist() ) - if isinstance(group_id, str): - end_year = int(group_id.split("_")[-1]) - start_year = end_year - 1 - else: - start_year, end_year = None, None - - segments = nd_label(labels_array)[0] - props = regionprops(segments) - - ldata = LabeledData( - x=xvars, - y=labels_array, - bdist=bdist, - ori=ori, - segments=segments, - props=props, - ) + df_polygons = df_polygons.reset_index(drop=True) + df_polygons = df_polygons.loc[df_polygons.geom_type != "Point"] + type_mask = df_polygons.geom_type == "MultiPolygon" + if type_mask.any(): + raise TypeError("MultiPolygons should not exist.") + + # Get a mask of valid polygons + nonzero_mask = df_polygons[crop_column] != 0 + + # Data for the model network + image_variables = ImageVariables.create_image_vars( + region=region, + image=image_list, + reference_grid=df_grid, + df_polygons_grid=df_polygons if nonzero_mask.any() else None, + max_crop_class=max_crop_class, + grid_size=grid_size, + gain=gain, + offset=offset, + ref_res=ref_res, + resampling=resampling, + crop_column=crop_column, + keep_crop_classes=keep_crop_classes, + replace_dict=replace_dict, + nonag_is_unknown=nonag_is_unknown, + all_touched=all_touched, + ) - if input_height is None: - input_height = ldata.y.shape[0] - else: - if ldata.y.shape[0] != input_height: - warnings.warn( - f"{group_id}_{row_grid_id} does not have the same height as the rest of the dataset.", - UserWarning, - ) - unprocessed.append(f"{group_id}_{row_grid_id}") - continue - if input_width is None: - input_width = ldata.y.shape[1] - else: - if ldata.y.shape[1] != input_width: - warnings.warn( - f"{group_id}_{row_grid_id} does not have the same width as the rest of the dataset.", - UserWarning, - ) - unprocessed.append(f"{group_id}_{row_grid_id}") - continue - - augmenters = Augmenters( - augmentations=transforms, - ntime=ntime, - nbands=nbands, - max_crop_class=max_crop_class, - k=3, - instance_seg=instance_seg, - zero_padding=zero_padding, - start_year=start_year, - end_year=end_year, - left=left, - bottom=bottom, - right=right, - top=top, - res=ref_res, - ) - for aug_method in augmenters: - aug_kwargs = augmenters.aug_args.kwargs - aug_kwargs["train_id"] = uid_format.format( - GROUP_ID=group_id, - ROW_ID=row_grid_id, - AUGMENTER=aug_method.name_, - ) - augmenters.update_aug_args(kwargs=aug_kwargs) - aug_data = aug_method(ldata, aug_args=augmenters.aug_args) - aug_method.save( - out_directory=process_path, data=aug_data, compress=5 - ) + if image_variables.time_series is None: + return + + if (image_variables.time_series.shape[1] < 5) or ( + image_variables.time_series.shape[2] < 5 + ): + return + + # Get the upper left lat/lon + lat_left, lat_bottom, lat_right, lat_top = df_grid.to_crs( + "epsg:4326" + ).total_bounds.tolist() + + segments = nd_label( + (image_variables.labels_array > 0) + & (image_variables.labels_array < max_crop_class + 1) + )[0] + props = regionprops(segments) + + labeled_data = LabeledData( + x=image_variables.time_series, + y=image_variables.labels_array, + bdist=image_variables.boundary_distance, + ori=image_variables.orientation, + segments=segments, + props=props, + ) - # if unprocessed: - # logger.warning('Could not process the following grids.') - # logger.info(', '.join(unprocessed)) + batch = Data( + x=einops.rearrange( + torch.from_numpy(labeled_data.x / gain).to(dtype=torch.int32), + 't c h w -> 1 c t h w', + ), + y=einops.rearrange( + torch.from_numpy(labeled_data.y).to( + dtype=torch.int16 if nonag_is_unknown else torch.uint8 + ), + 'h w -> 1 h w', + ), + bdist=einops.rearrange( + torch.from_numpy(labeled_data.bdist / gain).to( + dtype=torch.int32 + ), + 'h w -> 1 h w', + ), + start_year=torch.tensor( + [pd.Timestamp(Path(image_list[0]).stem).year], + dtype=torch.int32, + ), + end_year=torch.tensor( + [pd.Timestamp(Path(image_list[-1]).stem).year], + dtype=torch.int32, + ), + left=torch.tensor([lat_left], dtype=torch.float32), + bottom=torch.tensor([lat_bottom], dtype=torch.float32), + right=torch.tensor([lat_right], dtype=torch.float32), + top=torch.tensor([lat_top], dtype=torch.float32), + batch_id=[group_id], + ) - return pbar + # FIXME: this doesn't support augmentations + for aug in transforms: + aug_method = AUGMENTER_METHODS[aug]() + train_id = uid_format.format( + REGION_ID=region, + START_DATE=start_date, + END_DATE=end_date, + AUGMENTER=aug_method.name_, + ) + train_path = process_path / aug_method.file_name(train_id) + batch.to_file(train_path, compress=compress_method) diff --git a/src/cultionet/data/data.py b/src/cultionet/data/data.py new file mode 100644 index 00000000..5f56d897 --- /dev/null +++ b/src/cultionet/data/data.py @@ -0,0 +1,328 @@ +import inspect +from copy import deepcopy +from dataclasses import dataclass +from functools import singledispatch +from pathlib import Path +from typing import List, Optional, Sequence, Tuple, Union + +import geowombat as gw +import joblib +import numpy as np +import torch +import xarray as xr +from pyproj import CRS +from pyproj.aoi import AreaOfInterest +from pyproj.crs import CRSError +from pyproj.database import query_utm_crs_info +from rasterio.coords import BoundingBox +from rasterio.transform import from_bounds +from rasterio.warp import transform_bounds + + +@singledispatch +def sanitize_crs(crs: CRS) -> CRS: + try: + return crs + except CRSError: + return CRS.from_string("epsg:4326") + + +@sanitize_crs.register +def _(crs: str) -> CRS: + return CRS.from_string(crs) + + +@sanitize_crs.register +def _(crs: int) -> CRS: + return CRS.from_epsg(crs) + + +@singledispatch +def sanitize_res(res: tuple) -> Tuple[float, float]: + return tuple(map(float, res)) + + +@sanitize_res.register(int) +@sanitize_res.register(float) +def _(res) -> Tuple[float, float]: + return sanitize_res((res, res)) + + +class Data: + def __init__( + self, + x: torch.Tensor, + y: Optional[torch.Tensor] = None, + **kwargs, + ): + self.x = x + self.y = y + if kwargs is not None: + for k, v in kwargs.items(): + if v is not None: + assert isinstance( + v, (torch.Tensor, np.ndarray, list) + ), "Only tensors, arrays, and lists are supported." + + setattr(self, k, v) + + def _get_attrs(self) -> set: + members = inspect.getmembers( + self, predicate=lambda x: not inspect.ismethod(x) + ) + return set(dict(members).keys()).intersection( + set(self.__dict__.keys()) + ) + + def to_dict( + self, device: Optional[str] = None, dtype: Optional[str] = None + ) -> dict: + kwargs = {} + for key in self._get_attrs(): + value = getattr(self, key) + if isinstance(value, torch.Tensor): + kwargs[key] = value.clone() + if device is not None: + kwargs[key] = kwargs[key].to(device=device, dtype=dtype) + elif isinstance(value, np.ndarray): + kwargs[key] = value.copy() + else: + if value is None: + kwargs[key] = None + else: + try: + kwargs[key] = deepcopy(value) + except RecursionError: + kwargs[key] = value + + return kwargs + + def to( + self, device: Optional[str] = None, dtype: Optional[str] = None + ) -> "Data": + return Data(**self.to_dict(device=device, dtype=dtype)) + + def __add__(self, other: "Data") -> "Data": + out_dict = {} + for key, value in self.to_dict().items(): + if isinstance(value, torch.Tensor): + out_dict[key] = value + getattr(other, key) + + return Data(**out_dict) + + def __iadd__(self, other: "Data") -> "Data": + self = self + other + + return self + + def copy(self) -> "Data": + return Data(**self.to_dict()) + + @property + def num_samples(self) -> int: + return self.x.shape[0] + + @property + def num_channels(self) -> int: + return self.x.shape[1] + + @property + def num_time(self) -> int: + return self.x.shape[2] + + @property + def height(self) -> int: + return self.x.shape[3] + + @property + def width(self) -> int: + return self.x.shape[4] + + def to_file( + self, filename: Union[Path, str], compress: Union[int, str] = 'zlib' + ) -> None: + Path(filename).parent.mkdir(parents=True, exist_ok=True) + joblib.dump( + self.to_dict(), + filename, + compress=compress, + ) + + @classmethod + def from_file(cls, filename: Union[Path, str]) -> "Data": + return Data(**joblib.load(filename)) + + def __str__(self): + data_string = f"Data(x={tuple(self.x.shape)}" + if self.y is not None: + data_string += f", y={tuple(self.y.shape)}" + + for k, v in self.to_dict().items(): + if k not in ( + 'x', + 'y', + ): + if isinstance(v, (np.ndarray, torch.Tensor)): + if len(v.shape) == 1: + data_string += f", {k}={v.numpy().tolist()}" + else: + data_string += f", {k}={tuple(v.shape)}" + elif isinstance(v, list): + if len(v) == 1: + data_string += f", {k}={v}" + else: + data_string += f", {k}={[len(v)]}" + + data_string += ")" + + return data_string + + def __repr__(self): + return str(self) + + def plot( + self, + channel: Union[int, Sequence[int]], + res: Union[float, Sequence[float]], + crs: Optional[Union[int, str]] = None, + ) -> tuple: + + import matplotlib.pyplot as plt + + fig, axes = plt.subplots(1, 3, figsize=(8, 4), sharey=True, dpi=150) + + ds = self.to_dataset(res=res, crs=crs) + + bands = ds["bands"].assign_attrs(**ds.attrs).sel(channel=channel) + bands = bands.where(lambda x: x > 0) + cv = bands.std(dim='time') / bands.mean(dim='time') + + cv.plot.imshow( + add_colorbar=False, + robust=True, + interpolation="nearest", + ax=axes[0], + ) + ( + ds["labels"].where(lambda x: x != -1).assign_attrs(**ds.attrs) + ).plot.imshow(add_colorbar=False, interpolation="nearest", ax=axes[1]) + (ds["distances"].assign_attrs(**ds.attrs)).plot.imshow( + add_colorbar=False, interpolation="nearest", ax=axes[2] + ) + + for ax in axes: + ax.set_xlabel('') + ax.set_ylabel('') + + axes[0].set_title("CV") + axes[1].set_title("Labels") + axes[2].set_title("Distances") + + fig.supxlabel("X") + fig.supylabel("Y") + + return fig, axes + + def utm_bounds(self) -> CRS: + utm_crs_info = query_utm_crs_info( + datum_name="WGS 84", + area_of_interest=AreaOfInterest( + west_lon_degree=self.left[0], + south_lat_degree=self.bottom[0], + east_lon_degree=self.right[0], + north_lat_degree=self.top[0], + ), + )[0] + + return CRS.from_epsg(utm_crs_info.code) + + def transform_bounds(self, crs: CRS) -> BoundingBox: + """Transforms a bounding box to a new CRS.""" + + bounds = transform_bounds( + src_crs=sanitize_crs("epsg:4326"), + dst_crs=sanitize_crs(crs), + left=self.left[0], + bottom=self.bottom[0], + right=self.right[0], + top=self.top[0], + ) + + return BoundingBox(*bounds) + + def from_bounds( + self, + bounds: BoundingBox, + res: Union[float, Sequence[float]], + ) -> tuple: + """Converts a bounding box to a transform adjusted by the + resolution.""" + + res = sanitize_res(res) + + adjusted_bounds = BoundingBox( + left=bounds.left, + bottom=bounds.top - self.height * float(abs(res[1])), + right=bounds.left + self.width * float(abs(res[0])), + top=bounds.top, + ) + + adjusted_transform = from_bounds( + *adjusted_bounds, + width=self.width, + height=self.height, + ) + + return adjusted_bounds, adjusted_transform + + def to_dataset( + self, + res: Union[float, Sequence[float]], + crs: Optional[Union[int, str]] = None, + ) -> xr.Dataset: + """Converts a PyTorch data batch to an Xarray Dataset.""" + + if crs is None: + crs = self.utm_bounds() + + crs = sanitize_crs(crs) + dst_bounds = self.transform_bounds(crs) + dst_bounds, transform = self.from_bounds(dst_bounds, res=res) + + return xr.Dataset( + data_vars=dict( + bands=( + ["channel", "time", "y", "x"], + self.x[0].numpy() * 1e-4, + ), + labels=(["y", "x"], self.y[0].numpy()), + distances=(["y", "x"], self.bdist[0].numpy() * 1e-4), + ), + coords={ + "channel": range(1, self.num_channels + 1), + "time": range(1, self.num_time + 1), + "y": np.linspace( + dst_bounds.top, dst_bounds.bottom, self.height + ), + "x": np.linspace( + dst_bounds.left, dst_bounds.right, self.width + ), + }, + attrs={ + "name": self.batch_id[0], + "crs": crs.to_epsg(), + "res": (float(abs(transform[0])), float(abs(transform[4]))), + "transform": transform, + "_FillValue": -1, + }, + ) + + +@dataclass +class LabeledData: + x: np.ndarray + y: Union[None, np.ndarray] + bdist: Union[None, np.ndarray] + ori: Union[None, np.ndarray] + segments: Union[None, np.ndarray] + props: Union[None, List] diff --git a/src/cultionet/data/datasets.py b/src/cultionet/data/datasets.py index c67f4d74..0f6d74b6 100644 --- a/src/cultionet/data/datasets.py +++ b/src/cultionet/data/datasets.py @@ -1,25 +1,29 @@ import typing as T -from pathlib import Path +from copy import deepcopy from functools import partial +from pathlib import Path -import numpy as np import attr -import torch -from torch_geometric.data import Data, Dataset +import geopandas as gpd +import lightning as L +import numpy as np import psutil -import joblib +import pygrts +import torch from joblib import delayed, parallel_backend -import pandas as pd -import geopandas as gpd +from scipy.ndimage.measurements import label as nd_label from shapely.geometry import box -from pytorch_lightning import seed_everything -from geosample import QuadTree +from skimage.measure import regionprops from tqdm.auto import tqdm +from ..augment.augmenters import Augmenters from ..errors import TensorShapeError from ..utils.logging import set_color_logger -from ..utils.model_preprocessing import TqdmParallel - +from ..utils.model_preprocessing import ParallelProgress +from ..utils.normalize import NormValues +from .constant import SCALE_FACTOR +from .data import Data +from .spatial_dataset import SpatialDataset ATTRVINSTANCE = attr.validators.instance_of ATTRVIN = attr.validators.in_ @@ -28,123 +32,133 @@ logger = set_color_logger(__name__) -def add_dims(d: torch.Tensor) -> torch.Tensor: - return d.unsqueeze(0) - - -def update_data( - batch: Data, - idx: T.Optional[int] = None, - x: T.Optional[torch.Tensor] = None, -) -> Data: - image_id = None - if idx is not None: - if hasattr(batch, "boxes"): - if batch.boxes is not None: - image_id = ( - torch.zeros_like(batch.box_labels, dtype=torch.int64) + idx - ) - - if x is not None: - exclusion = ("x",) - - return Data( - x=x, - image_id=image_id, - **{k: getattr(batch, k) for k in batch.keys if k not in exclusion}, - ) - else: - return Data( - image_id=image_id, **{k: getattr(batch, k) for k in batch.keys} - ) - - -def zscores( - batch: Data, - data_means: torch.Tensor, - data_stds: torch.Tensor, - idx: T.Optional[int] = None, -) -> Data: - """Normalizes data to z-scores. - - Args: - batch (Data): A `torch_geometric` data object. - data_means (Tensor): The data feature-wise means. - data_stds (Tensor): The data feature-wise standard deviations. - - z = (x - μ) / σ - """ - x = (batch.x - add_dims(data_means)) / add_dims(data_stds) - - return update_data(batch=batch, idx=idx, x=x) - - def _check_shape( - d1: int, h1: int, w1: int, d2: int, h2: int, w2: int, index: int, uid: str + expected_time: int, + expected_height: int, + expected_width: int, + in_time: int, + in_height: int, + in_width: int, + index: int, + uid: str, ) -> T.Tuple[bool, int, str]: - if (d1 != d2) or (h1 != h2) or (w1 != w2): + if ( + (expected_time != in_time) + or (expected_height != in_height) + or (expected_width != in_width) + ): return False, index, uid return True, index, uid -@attr.s -class EdgeDataset(Dataset): - """An edge dataset.""" - - root: T.Union[str, Path, bytes] = attr.ib(default=".") - transform: T.Any = attr.ib(default=None) - pre_transform: T.Any = attr.ib(default=None) - data_means: T.Optional[torch.Tensor] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None - ) - data_stds: T.Optional[torch.Tensor] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None - ) - crop_counts: T.Optional[torch.Tensor] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None - ) - edge_counts: T.Optional[torch.Tensor] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(torch.Tensor)), default=None - ) - pattern: T.Optional[str] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(str)), default="data*.pt" - ) - processes: T.Optional[int] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(int)), default=psutil.cpu_count() - ) - threads_per_worker: T.Optional[int] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(int)), default=1 - ) - random_seed: T.Optional[int] = attr.ib( - validator=ATTRVOPTIONAL(ATTRVINSTANCE(int)), default=42 - ) +class EdgeDataset(SpatialDataset): + """An edge dataset. + + Parameters + ========== + root + The root data directory. + log_transform + Whether to log-transform the data before truncating by Sigmoid. For details, see: + + @article{brown_etal_2022, + title={Dynamic World, Near real-time global 10 m land use land cover mapping}, + author={Brown, Christopher F and Brumby, Steven P and Guzder-Williams, Brookie and Birch, Tanya and Hyde, Samantha Brooks and Mazzariello, Joseph and Czerwinski, Wanda and Pasquarella, Valerie J and Haertel, Robert and Ilyushchenko, Simon and others}, + journal={Scientific Data}, + volume={9}, + number={1}, + pages={251}, + year={2022}, + publisher={Nature Publishing Group UK London}, + url={https://www.nature.com/articles/s41597-022-01307-4}, + } + normalize + Whether to normalize the data by mean and standard deviation, centering the data around the mean or median. Note that if + ``log_transform=True``, it is applied before normalization. + norm_values + The normalization data object. + pattern + The data search pattern. + processes + The number of parallel processes. + random_seed + A random seed value. + augment_prob + The probability of applying random augmentation. + """ data_list_ = None grid_id_column = "grid_id" - def __attrs_post_init__(self): - super(EdgeDataset, self).__init__( - str(self.root), - transform=self.transform, - pre_transform=self.pre_transform, - ) - seed_everything(self.random_seed, workers=True) + def __init__( + self, + root: T.Union[str, Path, bytes] = ".", + log_transform: bool = False, + norm_values: T.Optional[NormValues] = None, + pattern: str = "data*.pt", + processes: int = psutil.cpu_count(), + random_seed: int = 42, + augment_prob: float = 0.0, + ): + self.root = root + self.log_transform = log_transform + self.norm_values = norm_values + self.pattern = pattern + self.processes = processes + self.random_seed = random_seed + self.augment_prob = augment_prob + + L.seed_everything(self.random_seed) self.rng = np.random.default_rng(self.random_seed) + self.augmentations_ = [ + 'tswarp', + 'tsnoise', + 'tsdrift', + 'tspeaks', + 'rot90', + 'rot180', + 'rot270', + 'roll', + 'fliplr', + 'flipud', + 'gaussian', + 'saltpepper', + 'cropresize', + 'perlin', + ] + + self.data_list_ = None + self.processed_dir = Path(self.root) / 'processed' + self.get_data_list() + def get_data_list(self): """Gets the list of data files.""" - self.data_list_ = list(Path(self.processed_dir).glob(self.pattern)) + data_list_ = sorted(list(Path(self.processed_dir).glob(self.pattern))) - if not self.data_list_: + if not data_list_: logger.exception( f"No .pt files were found with pattern {self.pattern}." ) + self.data_list_ = np.array(data_list_) + + @property + def data_list(self): + """Get a list of processed files.""" + return self.data_list_ + + def __len__(self): + """Returns the dataset length.""" + return len(self.data_list) + def cleanup(self): for fn in self.data_list_: fn.unlink() - def shuffle_items(self, data: T.Optional[list] = None): + self.data_list_ = [] + + def shuffle(self, data: T.Optional[list] = None): """Applies a random in-place shuffle to the data list.""" if data is not None: self.rng.shuffle(data) @@ -152,52 +166,13 @@ def shuffle_items(self, data: T.Optional[list] = None): self.rng.shuffle(self.data_list_) @property - def num_time_features(self): - """Get the number of time features.""" - data = self[0] - return int(data.ntime) + def num_channels(self) -> int: + return self[0].num_channels @property - def raw_file_names(self): - """Get the raw file names.""" - if not self.data_list_: - self.get_data_list() - - return self.data_list_ - - def to_frame(self) -> gpd.GeoDataFrame: - """Converts the Dataset to a GeoDataFrame.""" - - def get_box_id(data_id: str, *bounds): - return data_id, box(*bounds).centroid - - with parallel_backend(backend="loky", n_jobs=self.processes): - with TqdmParallel( - tqdm_kwargs={ - "total": len(self), - "desc": "Building GeoDataFrame", - } - ) as pool: - results = pool( - delayed(get_box_id)( - data.train_id, - data.left, - data.bottom, - data.right, - data.top, - ) - for data in self - ) - - ids, geometry = list(map(list, zip(*results))) - df = gpd.GeoDataFrame( - data=ids, - columns=[self.grid_id_column], - geometry=geometry, - crs="epsg:4326", - ) - - return df + def num_time(self) -> int: + """Get the number of time features.""" + return self[0].num_time def get_spatial_partitions( self, @@ -205,14 +180,19 @@ def get_spatial_partitions( splits: int = 0, ) -> None: """Gets the spatial partitions.""" - self.create_spatial_index() + self.create_spatial_index( + id_column=self.grid_id_column, n_jobs=self.processes + ) if isinstance(spatial_partitions, (str, Path)): spatial_partitions = gpd.read_file(spatial_partitions) else: - spatial_partitions = self.to_frame() + spatial_partitions = self.to_frame( + id_column=self.grid_id_column, + n_jobs=self.processes, + ) if splits > 0: - qt = QuadTree(spatial_partitions, force_square=False) + qt = pygrts.QuadTree(spatial_partitions, force_square=False) for __ in range(splits): qt.split() spatial_partitions = qt.to_frame() @@ -276,10 +256,9 @@ def split_indices( else: return self[indices] - def spatial_kfoldcv_iter( - self, partition_column: str - ) -> T.Tuple[str, "EdgeDataset", "EdgeDataset"]: + def spatial_kfoldcv_iter(self, partition_column: str): """Yield generator to iterate over spatial partitions.""" + for kfold in self.spatial_partitions.itertuples(): # Bounding box and indices of the kth fold kfold_indices = self.query_partition_by_name( @@ -292,31 +271,9 @@ def spatial_kfoldcv_iter( yield str(getattr(kfold, partition_column)), train_ds, test_ds - def create_spatial_index(self): - """Creates the spatial index.""" - dataset_grid_path = ( - Path(self.processed_dir).parent.parent / "dataset_grids.gpkg" - ) - if dataset_grid_path.is_file(): - self.dataset_df = gpd.read_file(dataset_grid_path) - else: - self.dataset_df = self.to_frame() - self.dataset_df.to_file(dataset_grid_path, driver="GPKG") - - def download(self): - pass - - def process(self): - pass - - @property - def processed_file_names(self): - """Get a list of processed files.""" - return self.data_list_ - def check_dims( self, - expected_dim: int, + expected_time: int, expected_height: int, expected_width: int, delete_mismatches: bool = False, @@ -324,15 +281,16 @@ def check_dims( ): """Checks if all tensors in the dataset match in shape dimensions.""" check_partial = partial( - _check_shape, expected_dim, expected_height, expected_width + _check_shape, + expected_time=expected_time, + expected_height=expected_height, + expected_width=expected_width, ) - with parallel_backend( backend="loky", n_jobs=self.processes, - inner_max_num_threads=self.threads_per_worker, ): - with TqdmParallel( + with ParallelProgress( tqdm_kwargs={ "total": len(self), "desc": "Checking dimensions", @@ -341,14 +299,15 @@ def check_dims( ) as pool: results = pool( delayed(check_partial)( - self[i].x.shape[1], - self[i].height, - self[i].width, - i, - self[i].train_id, + in_time=self[i].num_time, + in_height=self[i].height, + in_width=self[i].width, + index=i, + uid=self[i].batch_id, ) for i in range(0, len(self)) ) + matches, indices, ids = list(map(list, zip(*results))) if not all(matches): indices = np.array(indices) @@ -368,10 +327,6 @@ def check_dims( else: raise TensorShapeError - def len(self): - """Returns the dataset length.""" - return len(self.processed_file_names) - def split_train_val_by_partition( self, spatial_partitions: str, @@ -382,7 +337,7 @@ def split_train_val_by_partition( self.get_spatial_partitions(spatial_partitions=spatial_partitions) train_indices = [] val_indices = [] - self.shuffle_items() + self.shuffle() # self.spatial_partitions is a GeoDataFrame with Point geometry for row in tqdm( self.spatial_partitions.itertuples(), @@ -415,96 +370,131 @@ def split_train_val( val_frac: float, spatial_overlap_allowed: bool = True, spatial_balance: bool = True, + crs: str = "EPSG:8857", ) -> T.Tuple["EdgeDataset", "EdgeDataset"]: """Splits the dataset into train and validation. - Args: - val_frac (float): The validation fraction. + Parameters + ========== + val_frac + The validation fraction. - Returns: - train dataset, validation dataset + Returns + ======= + train dataset, validation dataset """ - id_column = "common_id" - self.shuffle_items() + # We do not need augmentations when loading batches for + # sample splits. + augment_prob = deepcopy(self.augment_prob) + self.augment_prob = 0.0 + if spatial_overlap_allowed: + self.shuffle() n_train = int(len(self) * (1.0 - val_frac)) train_ds = self[:n_train] val_ds = self[n_train:] else: - # Create a GeoDataFrame of every .pt file in - # the dataset. - self.create_spatial_index() - # Create column of each site's common id - # (i.e., without the year and augmentation). - self.dataset_df[id_column] = self.dataset_df.grid_id.str.split( - "_", expand=True - ).loc[:, 0] - unique_ids = self.dataset_df.common_id.unique() - if spatial_balance: - # Separate train and validation by spatial location - - # Get unique site coordinates - # NOTE: We do this becuase augmentations are stacked at - # the same site, thus creating multiple files with the - # same centroid. - df_unique_locations = gpd.GeoDataFrame( - pd.Series(unique_ids) - .to_frame(name=id_column) - .merge(self.dataset_df, on=id_column) - .drop_duplicates(id_column) - .drop(columns=["grid_id"]) - ).to_crs("EPSG:8858") - # Setup a quad-tree using the GRTS method - # (see https://github.com/jgrss/geosample for details) - qt = QuadTree(df_unique_locations, force_square=False) - # Recursively split the quad-tree until each grid has - # only one sample. - qt.split_recursive(max_samples=1) - n_val = int(val_frac * len(df_unique_locations.index)) - # `qt.sample` random samples from the quad-tree in a - # spatially balanced manner. Thus, `df_val_sample` is - # a GeoDataFrame with `n_val` sites spatially balanced. - df_val_sample = qt.sample(n=n_val) - # Since we only took one sample from each coordinate, - # we need to find all of the .pt files that share - # coordinates with the sampled sites. - val_mask = self.dataset_df.common_id.isin( - df_val_sample.common_id - ) - else: - # Randomly sample a percentage for validation - df_val_ids = ( - pd.Series(unique_ids) - .sample(frac=val_frac, random_state=self.random_seed) - .to_frame(name=id_column) - ) - # Get all ids for validation samples - val_mask = self.dataset_df.common_id.isin(df_val_ids.common_id) - # Get train/val indices - val_idx = self.dataset_df.loc[val_mask].index.tolist() - train_idx = self.dataset_df.loc[~val_mask].index.tolist() - # Slice the dataset - train_ds = self[train_idx] - val_ds = self[val_idx] + self.create_spatial_index( + id_column=self.grid_id_column, + n_jobs=self.processes, + ) + + train_ds, val_ds = self.spatial_splits( + val_frac=val_frac, + id_column=self.grid_id_column, + spatial_balance=spatial_balance, + crs=crs, + random_state=self.random_seed, + ) + + train_ds.augment_prob = augment_prob + val_ds.augment_prob = 0.0 return train_ds, val_ds def load_file(self, filename: T.Union[str, Path]) -> Data: - return joblib.load(filename) + return Data.from_file(filename) - def get(self, idx): - """Gets an individual data object from the dataset. + def __getitem__( + self, idx: T.Union[int, np.ndarray] + ) -> T.Union[dict, "EdgeDataset"]: + if isinstance(idx, (int, np.integer)): + return self.get(idx) + else: + return self.index_select(idx) + + def index_select(self, idx: np.ndarray) -> "EdgeDataset": + dataset = deepcopy(self) + dataset.data_list_ = self.data_list_[idx] - Args: - idx (int): The dataset index position. + return dataset - Returns: - A `torch_geometric` data object. + def get(self, idx: int) -> dict: + """Gets an individual data object from the dataset. + + Parameters + ========== + idx + The dataset index position. """ + batch = self.load_file(self.data_list_[idx]) - if isinstance(self.data_means, torch.Tensor): - batch = zscores(batch, self.data_means, self.data_stds, idx=idx) - else: - batch = update_data(batch=batch, idx=idx) + + batch.x = (batch.x / SCALE_FACTOR).clip(1e-9, 1) + + if hasattr(batch, 'bdist'): + batch.bdist = (batch.bdist / SCALE_FACTOR).clip(1e-9, 1) + + if batch.y is not None: + if self.rng.random() > (1 - self.augment_prob): + # Choose one augmentation to apply + aug_name = self.rng.choice(self.augmentations_) + + if aug_name in ( + 'roll', + 'tswarp', + 'tsnoise', + 'tsdrift', + 'tspeaks', + ): + # FIXME: By default, the crop value is 1 (background is 0 and edges are 2). + # But, it would be better to get 1 from an argument. + # Label properties are only used in 5 augmentations + batch.segments = np.uint8( + nd_label(batch.y.squeeze().numpy() == 1)[0] + ) + batch.props = regionprops(batch.segments) + + # Create the augmenter object + aug_modules = Augmenters( + # NOTE: apply a single augmenter + # TODO: could apply a series of augmenters + augmentations=[aug_name], + rng=self.rng, + ) + + # Apply the object + batch = aug_modules(batch) + batch.segments = None + batch.props = None + + if self.log_transform: + # Dynamic World log transform + # NOTE: If inputs are 0-10,000, then (x * 0.005) + batch.x = torch.log(batch.x * 50.0 + 1.0).clamp_min(1e-9) + + if self.norm_values is not None: + # Center values around the mean or median + batch = self.norm_values(batch) + + # Get the centroid + centroid = box( + float(batch.left), + float(batch.bottom), + float(batch.right), + float(batch.top), + ).centroid + batch.lon = torch.tensor([centroid.x]) + batch.lat = torch.tensor([centroid.y]) return batch diff --git a/src/cultionet/data/lookup.py b/src/cultionet/data/lookup.py deleted file mode 100644 index 641ae371..00000000 --- a/src/cultionet/data/lookup.py +++ /dev/null @@ -1,470 +0,0 @@ -NON_AG = frozenset(("unknown", "developed", "trees")) -NON_CROP = frozenset(("hay", "pasture")) - -CDL_COLORS = dict( - background="#ffffff", - all_crops="#E4A520", - maize="#ffd300", - spring_maize="#dca50c", - maize1="#dca50c", - maize2="#ffd300", - dbl_maize="#b29300", - dbl_spring_maize="#dca50c", - dbl_maize1="#b29300", - dbl_maize2="#b29300", - dbl_cotton="#fe2725", - soybeans="#267000", - spring_soybeans="#267000", - soybeans1="#359c00", - soybeans2="#267000", - dbl_soybeans="#1a4f00", - dbl_spring_soybeans="#1a4f00", - dbl_soybeans1="#1a4f00", - dbl_soybeans2="#1a4f00", - cotton="#fe2725", - peanuts="#70a500", - dbl_wheat="#a57000", - dbl_winter_wheat_soy="#707002", - dbl_winter_wheat_maize="#ffd301", - millet="#70004a", - spring_millet="#8f005f", - pecans="#b4705b", - rye="#ac017c", - oats="#a15989", - dbl_cropping="#9a622a", - sorghum="#fe9e0c", - spring_sorghum="#eb9109", - dbl_sorghum="#905906", - winter_wheat="#a57000", - spring_wheat="#d9b56b", - durum_wheat="#896454", - dry_beans="#a40000", - safflower="#d6d700", - dbl_safflower="#d6d700", - rape_seed="#d1ff00", - mustard="#00b04a", - buckwheat="#d69dbd", - sudangrass="#b663a0", - dbl_sudangrass="#8d5060", - onions="#ff6966", - camelina="#02ad4c", - peas="#54ff00", - watermelons="#ff6766", - honeydew_melons="#ff6766", - dbl_soy_oats="#267000", - dbl_maize_soy="#ffd300", - sweet_potatoes="#702601", - hops="#00ae4c", - pumpkins="#ff6766", - dbl_durum_wheat_sorghum="#ff9e0a", - dbl_barley_sorghum="#ff9e0a", - triticale="#d69dbd", - pop_orn_maize="#dca50c", - almonds="#00a682", - pistachios="#00ff8c", - aquaculture="#01ffff", - dbl_winter_wheat_cotton="#a57000", - dbl_soy_cotton="#267000", - sweet_maize="#dca50c", - sunflower="#ffff00", - dbl_sunflower="#ffff00", - flaxseed="#8099fe", - clover="#e8c0ff", - sod_grass="#afffdc", - lentils="#00deaf", - sugarbeets="#a700e4", - walnuts="#ead6af", - dbl_oats_maize="#ffd300", - herbs="#7fd3ff", - blueberries="#000098", - peaches="#ff8daa", - pears="#b29b71", - grapes="#6f4489", - orchard="#6f4489", - cucumbers="#fd6666", - chick_peaks="#00b04a", - misc_fruits_vegs="#ff6766", - carrots="#ff6666", - asparagus="#ff6666", - garlic="#ff6666", - cantaloupes="#ff6666", - prunes="#ff8fab", - olives="#334a33", - oranges="#e37026", - broccoli="#ff6666", - cabbage="#ff6666", - cauliflower="#ff6666", - celery="#ff6666", - peppers="#ff6766", - pomegranates="#b09970", - nectarines="#ff8fab", - greens="#ff6666", - plums="#ff8fab", - strawberries="#ff6666", - rice="#01a8e6", - potatoes="#702601", - alfalfa="#df91c7", - other_crops="#00ae4c", - sugarcane="#648d6c", - dbl_sugarcane="#648d6c", - speltz="#d69dbd", - winter_barley="#e2007d", - barley="#e240a4", - dbl_barley="#e2007d", - dbl_winter_wheat_sorghum="#a57001", - dbl_barley_maize="#ffd300", - dbl_barley_soy="#267000", - canola="#d1ff00", - switchgrass="#00ae4c", - tomatoes="#f3a378", - tobacco="#008539", - pastureland="#e9ffbf", - grassland_pasture="#e9ffbf", - savanna="#739f73", - other_hay="#a5f18c", - dbl_hay="#7fef81", - fallow="#bfbf77", - harvested="#bfbf77", - planted="#9bbf77", - cherries="#ff00ff", - apples="#bb004f", - squash="#ff6766", - apricots="#ff8fab", - vetch="#00b04a", - lettuce="#ff6666", - turnips="#ff6766", - eggplants="#ff6766", - radishes="#ff6766", - gourds="#ff6666", - cranberries="#ff6666", - christmas_trees="#007776", - other_tree_crops="#b29b71", - citrus="#ffff7d", - deciduous_forest="#92cc92", - evergreen_forest="#92ccaf", - mixed_forest="#afcc92", - forest="#4E6507", - deforestation="#ff35e4", - shrubland="#c7d5a0", - cactus="#c7d5a0", - woody_wetlands="#7fb39a", - herbaceous_wetlands="#7fb2b3", - wetlands="#7fb2b3", - barren="#cdbfa4", - plantation="#7833ad", - open_water="#5990B1", - water="#4c70a4", - developed="#707A88", - developed_high="#5f0100", -) - -CDL_LABELS = dict( - background=0, - cropland=1, - maize=1, - cotton=2, - rice=3, - sorghum=4, - soybeans=5, - sunflower=6, - spring_maize=7, - spring_soybeans=8, - spring_sorghum=9, - peanuts=10, - tobacco=11, - sweet_maize=12, - pop_orn_maize=13, - mint=14, - maize1=15, - maize2=16, - soybeans1=17, - soybeans2=18, - spring_millet=19, - barley=20, - winter_barley=21, - durum_wheat=22, - spring_wheat=23, - winter_wheat=24, - other_small_grains=25, - dbl_winter_wheat_soy=26, - rye=27, - oats=28, - millet=29, - speltz=30, - canola=31, - flaxseed=32, - safflower=33, - rape_seed=34, - mustard=35, - alfalfa=36, - other_hay=37, - camelina=38, - buckwheat=39, - sudangrass=40, - sugarbeets=41, - dry_beans=42, - potatoes=43, - other_crops=44, - sugarcane=45, - sweet_potatoes=46, - misc_fruits_vegs=47, - watermelons=48, - onions=49, - cucumbers=50, - chick_peaks=51, - lentils=52, - peas=53, - tomatoes=54, - hops=56, - herbs=57, - clover=58, - sod_grass=59, - switchgrass=60, - fallow=61, - harvested=62, - planted=63, - young=64, - cherries=66, - peaches=67, - apples=68, - grapes=69, - christmas_trees=70, - other_tree_crops=71, - citrus=72, - pecans=74, - almonds=75, - walnuts=76, - pears=77, - orchard=78, - tilled=79, - dbl_maize=80, - dbl_cotton=81, - dbl_sorghum=82, - dbl_soybeans=83, - dbl_sunflower=84, - dbl_tobacco=85, - dbl_millet=86, - dbl_hay=87, - dbl_sudangrass=88, - dbl_dry_beans=89, - dbl_other_crops=90, - dbl_sugarcane=91, - aquaculture=92, - dbl_onions=93, - dbl_rice=94, - dbl_alfalfa=95, - dbl_clover=96, - dbl_wheat=97, - dbl_barley=98, - dbl_oats=99, - dbl_spring_maize=100, - dbl_spring_soybeans=101, - dbl_safflower=104, - dbl_cropping=110, - open_water=111, - developed_open=121, - developed_low=122, - developed_medium=123, - developed_high=124, - barren=131, - plantation=138, - eucalyptus_plantation=139, - pine_plantation=140, - deciduous_forest=141, - evergreen_forest=142, - mixed_forest=143, - forest=144, - deforestation=145, - reforestation=146, - shrubland=152, - cactus=153, - savanna=173, - grassland=174, - pastureland=175, - grassland_pasture=176, - woody_wetlands=190, - herbaceous_wetlands=195, - pistachios=204, - triticale=205, - carrots=206, - asparagus=207, - garlic=208, - cantaloupes=209, - prunes=210, - olives=211, - oranges=212, - honeydew_melons=213, - broccoli=214, - peppers=216, - pomegranates=217, - nectarines=218, - greens=219, - plums=220, - strawberries=221, - squash=222, - apricots=223, - vetch=224, - dbl_winter_wheat_maize=225, - dbl_oats_maize=226, - lettuce=227, - pumpkins=229, - dbl_durum_wheat_sorghum=234, - dbl_barley_sorghum=235, - dbl_winter_wheat_sorghum=236, - dbl_barley_maize=237, - dbl_winter_wheat_cotton=238, - dbl_soy_cotton=239, - dbl_soy_oats=240, - dbl_maize_soy=241, - blueberries=242, - cabbage=243, - cauliflower=244, - celery=245, - radishes=246, - turnips=247, - eggplants=248, - gourds=249, - cranberries=250, - dbl_barley_soy=254, -) - -CDL_CROP_LABELS = dict( - maize=1, - cotton=2, - rice=3, - sorghum=4, - soybeans=5, - sunflower=6, - spring_maize=7, - spring_soybeans=8, - spring_sorghum=9, - peanuts=10, - tobacco=11, - sweet_maize=12, - pop_orn_maize=13, - mint=14, - maize1=15, - maize2=16, - soybeans1=17, - soybeans2=18, - spring_millet=19, - barley=20, - winter_barley=21, - durum_wheat=22, - spring_wheat=23, - winter_wheat=24, - other_small_grains=25, - dbl_winter_wheat_soy=26, - rye=27, - oats=28, - millet=29, - speltz=30, - canola=31, - flaxseed=32, - safflower=33, - rape_seed=34, - mustard=35, - alfalfa=36, - other_hay=37, - camelina=38, - buckwheat=39, - sudangrass=40, - sugarbeets=41, - dry_beans=42, - potatoes=43, - other_crops=44, - sugarcane=45, - sweet_potatoes=46, - misc_fruits_vegs=47, - watermelons=48, - onions=49, - cucumbers=50, - chick_peaks=51, - lentils=52, - peas=53, - tomatoes=54, - hops=56, - herbs=57, - clover=58, - sod_grass=59, - switchgrass=60, - fallow=61, - cherries=66, - peaches=67, - apples=68, - grapes=69, - citrus=72, - pecans=74, - almonds=75, - walnuts=76, - pears=77, - dbl_maize=80, - dbl_cotton=81, - dbl_sorghum=82, - dbl_soybeans=83, - dbl_sunflower=84, - dbl_tobacco=85, - dbl_millet=86, - dbl_hay=87, - dbl_sudangrass=88, - dbl_dry_beans=89, - dbl_other_crops=90, - dbl_sugarcane=91, - aquaculture=92, - dbl_onions=93, - dbl_rice=94, - dbl_alfalfa=95, - dbl_clover=96, - dbl_wheat=97, - dbl_barley=98, - dbl_oats=99, - dbl_spring_maize=100, - dbl_spring_soybeans=101, - dbl_safflower=104, - dbl_cropping=110, - pistachios=204, - triticale=205, - carrots=206, - asparagus=207, - garlic=208, - cantaloupes=209, - prunes=210, - olives=211, - oranges=212, - honeydew_melons=213, - broccoli=214, - peppers=216, - pomegranates=217, - nectarines=218, - greens=219, - plums=220, - strawberries=221, - squash=222, - apricots=223, - vetch=224, - dbl_winter_wheat_maize=225, - dbl_oats_maize=226, - lettuce=227, - pumpkins=229, - dbl_durum_wheat_sorghum=234, - dbl_barley_sorghum=235, - dbl_winter_wheat_sorghum=236, - dbl_barley_maize=237, - dbl_winter_wheat_cotton=238, - dbl_soy_cotton=239, - dbl_soy_oats=240, - dbl_maize_soy=241, - blueberries=242, - cabbage=243, - cauliflower=244, - celery=245, - radishes=246, - turnips=247, - eggplants=248, - gourds=249, - cranberries=250, - dbl_barley_soy=254, -) - -CDL_LABELS_r = {v: k for k, v in CDL_LABELS.items()} -CDL_CROP_LABELS_r = {v: k for k, v in CDL_CROP_LABELS.items()} diff --git a/src/cultionet/data/modules.py b/src/cultionet/data/modules.py index c4fbef00..a91ec922 100644 --- a/src/cultionet/data/modules.py +++ b/src/cultionet/data/modules.py @@ -1,10 +1,11 @@ import typing as T -from torch.utils.data import Sampler -from pytorch_lightning import LightningDataModule -from torch_geometric.loader import DataLoader +import torch +from lightning import LightningDataModule +from torch.utils.data import DataLoader, Sampler from .datasets import EdgeDataset +from .utils import collate_fn class EdgeDataModule(LightningDataModule): @@ -16,10 +17,13 @@ def __init__( val_ds: T.Optional[EdgeDataset] = None, test_ds: T.Optional[EdgeDataset] = None, predict_ds: T.Optional[EdgeDataset] = None, - batch_size: int = 5, + batch_size: int = 4, num_workers: int = 0, shuffle: bool = True, sampler: T.Optional[Sampler] = None, + pin_memory: bool = False, + persistent_workers: bool = True, + generator: T.Optional[torch.Generator] = None, ): super().__init__() @@ -31,6 +35,11 @@ def __init__( self.num_workers = num_workers self.shuffle = shuffle self.sampler = sampler + self.pin_memory = pin_memory + self.persistent_workers = ( + False if num_workers == 0 else persistent_workers + ) + self.generator = generator def train_dataloader(self): """Returns a data loader for train data.""" @@ -40,6 +49,10 @@ def train_dataloader(self): shuffle=None if self.sampler is not None else self.shuffle, num_workers=self.num_workers, sampler=self.sampler, + pin_memory=self.pin_memory, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + generator=self.generator, ) def val_dataloader(self): @@ -47,8 +60,11 @@ def val_dataloader(self): return DataLoader( self.val_ds, batch_size=self.batch_size, - shuffle=self.shuffle, + shuffle=False, num_workers=self.num_workers, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + generator=self.generator, ) def test_dataloader(self): @@ -56,8 +72,11 @@ def test_dataloader(self): return DataLoader( self.test_ds, batch_size=self.batch_size, - shuffle=self.shuffle, + shuffle=False, num_workers=self.num_workers, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + generator=self.generator, ) def predict_dataloader(self): @@ -65,6 +84,9 @@ def predict_dataloader(self): return DataLoader( self.predict_ds, batch_size=self.batch_size, - shuffle=self.shuffle, + shuffle=False, num_workers=self.num_workers, + collate_fn=collate_fn, + persistent_workers=self.persistent_workers, + generator=self.generator, ) diff --git a/src/cultionet/data/spatial_dataset.py b/src/cultionet/data/spatial_dataset.py new file mode 100644 index 00000000..c23c9c1e --- /dev/null +++ b/src/cultionet/data/spatial_dataset.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Optional, Tuple + +import geopandas as gpd +import pygrts +from joblib import delayed, parallel_backend +from shapely.geometry import box +from torch.utils.data import Dataset + +from ..utils.model_preprocessing import ParallelProgress + + +def get_box_id(data_id: str, *bounds) -> tuple: + return data_id, box(*list(map(float, bounds))).centroid + + +class SpatialDataset(Dataset): + dataset_df = None + + @property + def grid_gpkg_path(self) -> Path: + return self.root / "dataset_grids.gpkg" + + def create_spatial_index(self, id_column: str, n_jobs: int): + """Creates the spatial index.""" + + if self.grid_gpkg_path.exists(): + self.dataset_df = gpd.read_file(self.grid_gpkg_path) + else: + self.dataset_df = self.to_frame(id_column=id_column, n_jobs=n_jobs) + self.dataset_df.to_file(self.grid_gpkg_path, driver="GPKG") + + def to_frame(self, id_column: str, n_jobs: int) -> gpd.GeoDataFrame: + """Converts the Dataset to a GeoDataFrame.""" + + with parallel_backend(backend="loky", n_jobs=n_jobs): + with ParallelProgress( + tqdm_kwargs={ + "total": len(self), + "desc": "Building GeoDataFrame", + "ascii": "\u2015\u25E4\u25E5\u25E2\u25E3\u25AA", + "colour": "green", + } + ) as pool: + results = pool( + delayed(get_box_id)( + data.batch_id, + data.left, + data.bottom, + data.right, + data.top, + ) + for data in self + ) + + ids, geometry = list(map(list, zip(*results))) + df = gpd.GeoDataFrame( + data=ids, + columns=[id_column], + geometry=geometry, + crs="epsg:4326", + ) + + return df + + def spatial_splits( + self, + val_frac: float, + id_column: str, + spatial_balance: bool = True, + crs: str = "EPSG:8857", + random_state: Optional[int] = None, + ) -> Tuple[Dataset, Dataset]: + """Takes spatially-balanced splits of the dataset.""" + + if spatial_balance: + # Separate train and validation by spatial location + + # Setup a quad-tree using the GRTS method + # (see https://github.com/jgrss/pygrts for details) + qt = pygrts.QuadTree( + self.dataset_df.to_crs(crs), + force_square=False, + ) + + # Recursively split the quad-tree until each grid has + # only one sample. + qt.split_recursive(max_samples=1) + + n_val = int(val_frac * len(self.dataset_df.index)) + # `qt.sample` random samples from the quad-tree in a + # spatially balanced manner. Thus, `df_val_sample` is + # a GeoDataFrame with `n_val` sites spatially balanced. + df_val_sample = qt.sample(n=n_val, random_state=random_state) + + # Since we only took one sample from each coordinate, + # we need to find all of the .pt files that share + # coordinates with the sampled sites. + val_mask = self.dataset_df[id_column].isin( + df_val_sample[id_column] + ) + else: + # Randomly sample a percentage for validation + df_val_ids = self.dataset_df.sample( + frac=val_frac, random_state=random_state + ).to_frame(name=id_column) + + # Get all ids for validation samples + val_mask = self.dataset_df[id_column].isin(df_val_ids[id_column]) + + # Get train/val indices + val_idx = self.dataset_df.loc[val_mask].index.values + train_idx = self.dataset_df.loc[~val_mask].index.values + + # Slice the dataset + train_ds = self[train_idx] + val_ds = self[val_idx] + + return train_ds, val_ds diff --git a/src/cultionet/data/store.py b/src/cultionet/data/store.py new file mode 100644 index 00000000..5e0cc11f --- /dev/null +++ b/src/cultionet/data/store.py @@ -0,0 +1,158 @@ +from pathlib import Path +from typing import Union + +import dask.array as da +import einops +import numpy as np +import pandas as pd +import torch +import xarray as xr +from dask.delayed import Delayed +from dask.utils import SerializableLock +from rasterio.windows import Window +from retry import retry + +from ..utils.logging import set_color_logger +from .data import Data + +logger = set_color_logger(__name__) + + +class BatchStore: + """``dask.array.store`` for data batches.""" + + lock_ = SerializableLock() + + def __init__( + self, + data: xr.DataArray, + write_path: Path, + res: float, + resampling: str, + region: str, + start_date: str, + end_date: str, + window_size: int, + padding: int, + compress_method: Union[int, str], + ): + self.data = data + self.res = res + self.resampling = resampling + self.region = region + self.start_date = start_date + self.end_date = end_date + self.write_path = write_path + self.window_size = window_size + self.padding = padding + self.compress_method = compress_method + + def __setitem__(self, key: tuple, item: np.ndarray) -> None: + time_range, index_range, y, x = key + + item_window = Window( + col_off=x.start, + row_off=y.start, + width=x.stop - x.start, + height=y.stop - y.start, + ) + pad_window = Window( + col_off=x.start, + row_off=y.start, + width=item.shape[-1], + height=item.shape[-2], + ) + + self.write_batch(item, w=item_window, w_pad=pad_window) + + @retry(IOError, tries=5, delay=1) + def write_batch(self, x: np.ndarray, w: Window, w_pad: Window): + image_height = self.window_size + self.padding * 2 + image_width = self.window_size + self.padding * 2 + + # Get row adjustments + row_after_to_pad = image_height - w_pad.height + + # Get column adjustments + col_after_to_pad = image_width - w_pad.width + + if any([row_after_to_pad > 0, col_after_to_pad > 0]): + x = np.pad( + x, + pad_width=( + (0, 0), + (0, 0), + (0, row_after_to_pad), + (0, col_after_to_pad), + ), + mode="constant", + constant_values=0, + ) + + x = einops.rearrange( + torch.from_numpy(x.astype('int32')).to(dtype=torch.int32), + 't c h w -> 1 c t h w', + ) + + assert x.shape[-2:] == ( + image_height, + image_width, + ), "The padded array does not have the correct height/width dimensions." + + batch_id = f"{self.region}_{self.start_date}_{self.end_date}_{w.row_off}_{w.col_off}" + + # Get the upper left lat/lon + ( + lat_left, + lat_bottom, + lat_right, + lat_top, + ) = self.data.gw.geodataframe.to_crs("epsg:4326").total_bounds.tolist() + + batch = Data( + x=x, + start_year=torch.tensor( + [pd.Timestamp(self.start_date).year], + dtype=torch.int32, + ), + end_year=torch.tensor( + [pd.Timestamp(self.end_date).year], + dtype=torch.int32, + ), + padding=[self.padding], + window_row_off=[w.row_off], + window_col_off=[w.col_off], + window_height=[w.height], + window_width=[w.width], + res=[self.res], + resampling=[self.resampling], + left=torch.tensor([lat_left], dtype=torch.float32), + bottom=torch.tensor([lat_bottom], dtype=torch.float32), + right=torch.tensor([lat_right], dtype=torch.float32), + top=torch.tensor([lat_top], dtype=torch.float32), + batch_id=[batch_id], + ) + + batch.to_file( + self.write_path / f"{batch_id}.pt", + compress=self.compress_method, + ) + + try: + _ = batch.from_file(self.write_path / f"{batch_id}.pt") + except EOFError: + raise IOError + + def __enter__(self) -> "BatchStore": + self.closed = False + + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.closed = True + + def _open(self) -> "BatchStore": + return self + + def save(self, data: da.Array, **kwargs) -> Delayed: + da.store(data, self, lock=self.lock_, compute=True, **kwargs) diff --git a/src/cultionet/data/utils.py b/src/cultionet/data/utils.py index 183d7c60..e06b5bd3 100644 --- a/src/cultionet/data/utils.py +++ b/src/cultionet/data/utils.py @@ -1,28 +1,71 @@ -import shutil import typing as T -from dataclasses import dataclass +from functools import singledispatch from pathlib import Path -import uuid +import cv2 +import geopandas as gpd import numpy as np -import xarray as xr +import pandas as pd import torch -from torch_geometric.data import Data +import xarray as xr +from affine import Affine +from rasterio.features import rasterize as rio_rasterize +from scipy.ndimage import label as nd_label +from scipy.ndimage import uniform_filter +from skimage.measure import regionprops + +from .data import Data + + +@singledispatch +def get_empty(template: torch.Tensor) -> torch.Tensor: + return torch.tensor([]) + + +@get_empty.register +def _(template: np.ndarray) -> np.ndarray: + return np.array([]) + + +@get_empty.register +def _(template: list) -> list: + return [] + + +@get_empty.register +def _(template: None) -> None: + return None + + +@singledispatch +def concat(value: torch.Tensor, other: torch.Tensor) -> torch.Tensor: + return torch.cat((value, other)) + + +@concat.register +def _(value: np.ndarray, other: np.ndarray) -> np.ndarray: + return np.concatenate((value, other)) + + +@concat.register +def _(value: list, other: list) -> list: + return value + other -from .datasets import EdgeDataset -from ..networks import SingleSensorNetwork -from ..utils.reshape import nd_to_columns -from ..utils.normalize import NormValues +def collate_fn(data_list: T.List[Data]) -> Data: + kwargs = {} + # Iterate over data keys + for key in data_list[0].to_dict().keys(): + # Get an empty container + key_value = get_empty(getattr(data_list[0], key)) + if key_value is not None: + # Fill the container + for sample in data_list: + key_value = concat(key_value, getattr(sample, key)) -@dataclass -class LabeledData: - x: np.ndarray - y: T.Union[None, np.ndarray] - bdist: T.Union[None, np.ndarray] - ori: T.Union[None, np.ndarray] - segments: T.Union[None, np.ndarray] - props: T.Union[None, T.List] + kwargs[key] = key_value + + return Data(**kwargs) def get_image_list_dims( @@ -37,170 +80,248 @@ def get_image_list_dims( return ntime, nbands -def create_data_object( - x: np.ndarray, - edge_indices: np.ndarray, - edge_attrs: np.ndarray, - ntime: int, - nbands: int, - height: int, - width: int, - y: T.Optional[np.ndarray] = None, - mask_y: T.Optional[np.ndarray] = None, - bdist: T.Optional[np.ndarray] = None, - ori: T.Optional[np.ndarray] = None, - zero_padding: T.Optional[int] = 0, - other: T.Optional[np.ndarray] = None, - **kwargs, -) -> Data: - """Creates a training data object.""" - # edge_indices = torch.tensor(edge_indices, dtype=torch.long).t().contiguous() - # edge_attrs = torch.tensor(edge_attrs, dtype=torch.float) - edge_indices = None - edge_attrs = None - x = torch.tensor(x, dtype=torch.float) - - boxes = None - box_labels = None - box_masks = None - if mask_y is not None: - boxes = mask_y["boxes"] - box_labels = mask_y["labels"] - box_masks = mask_y["masks"] - - if y is None: - train_data = Data( - x=x, - edge_index=edge_indices, - edge_attrs=edge_attrs, - height=height, - width=width, - ntime=ntime, - nbands=nbands, - boxes=boxes, - box_labels=box_labels, - box_masks=box_masks, - zero_padding=zero_padding, - **kwargs, +def split_multipolygons(df: gpd.GeoDataFrame) -> gpd.GeoDataFrame: + """Splits a MultiPolygon into a Polygon.""" + + # Check for multi-polygons + multi_polygon_mask = df.geom_type == "MultiPolygon" + + if multi_polygon_mask.any(): + new_polygons = [] + for _, multi_polygon_df in df.loc[multi_polygon_mask].iterrows(): + # Split the multi-polygon into a list of polygons + polygon_list = list(multi_polygon_df.geometry.geoms) + # Duplicate the row, replacing the geometry + for split_polygon in polygon_list: + new_polygons.append( + multi_polygon_df.to_frame().T.assign( + geometry=[split_polygon] + ) + ) + + # Stack and replace + df = pd.concat( + ( + df.loc[~multi_polygon_mask], + pd.concat(new_polygons), + ) ) + + return df + + +def roll( + arr_pad: np.ndarray, + shift: T.Union[int, T.Tuple[int, int]], + axis: T.Union[int, T.Tuple[int, int]], +) -> np.ndarray: + """Rolls array elements along a given axis and slices off padded edges.""" + return np.roll(arr_pad, shift, axis=axis)[1:-1, 1:-1] + + +def get_crop_count(array: np.ndarray, edge_class: int) -> np.ndarray: + array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") + + rarray = roll(array_pad, 1, axis=0) + crop_count = np.uint8((rarray > 0) & (rarray != edge_class)) + rarray = roll(array_pad, -1, axis=0) + crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) + rarray = roll(array_pad, 1, axis=1) + crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) + rarray = roll(array_pad, -1, axis=1) + crop_count += np.uint8((rarray > 0) & (rarray != edge_class)) + + return crop_count + + +def get_edge_count(array: np.ndarray, edge_class: int) -> np.ndarray: + array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") + + edge_count = np.uint8(roll(array_pad, 1, axis=0) == edge_class) + edge_count += np.uint8(roll(array_pad, -1, axis=0) == edge_class) + edge_count += np.uint8(roll(array_pad, 1, axis=1) == edge_class) + edge_count += np.uint8(roll(array_pad, -1, axis=1) == edge_class) + + return edge_count + + +def get_non_count(array: np.ndarray) -> np.ndarray: + array_pad = np.pad(array, pad_width=((1, 1), (1, 1)), mode="edge") + + non_count = np.uint8(roll(array_pad, 1, axis=0) == 0) + non_count += np.uint8(roll(array_pad, -1, axis=0) == 0) + non_count += np.uint8(roll(array_pad, 1, axis=1) == 0) + non_count += np.uint8(roll(array_pad, -1, axis=1) == 0) + + return non_count + + +def cleanup_edges( + array: np.ndarray, + original: np.ndarray, + edge_class: int, +) -> np.ndarray: + """Removes crop pixels that border non-crop pixels.""" + array_pad = np.pad(original, pad_width=((1, 1), (1, 1)), mode="edge") + original_zero = np.uint8(roll(array_pad, 1, axis=0) == 0) + original_zero += np.uint8(roll(array_pad, -1, axis=0) == 0) + original_zero += np.uint8(roll(array_pad, 1, axis=1) == 0) + original_zero += np.uint8(roll(array_pad, -1, axis=1) == 0) + + # Fill edges + array = np.where( + (array == 0) + & (get_crop_count(array, edge_class) > 0) + & (get_edge_count(array, edge_class) > 0), + edge_class, + array, + ) + # Remove crops next to non-crop + array = np.where( + (array > 0) + & (array != edge_class) + & (get_non_count(array) > 0) + & (get_edge_count(array, edge_class) > 0), + 0, + array, + ) + # Fill in non-cropland + array = np.where(original_zero == 4, 0, array) + # Remove isolated crop pixels (i.e., crop clumps with 2 or fewer pixels) + array = np.where( + (array > 0) + & (array != edge_class) + & (get_crop_count(array, edge_class) <= 1) + & (get_edge_count(array, edge_class) <= 1), + 0, + array, + ) + + return array + + +def create_boundary_distances( + labels_array: np.ndarray, train_type: str, cell_res: float +) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Creates distances from boundaries.""" + if train_type.lower() == "polygon": + mask = np.uint8(labels_array) else: - y = torch.tensor( - y.flatten(), - dtype=torch.float if "float" in y.dtype.name else torch.long, - ) - bdist_ = torch.tensor(bdist.flatten(), dtype=torch.float) - # ori_ = torch.tensor(ori.flatten(), dtype=torch.float) - - if other is None: - train_data = Data( - x=x, - edge_index=edge_indices, - edge_attrs=edge_attrs, - y=y, - bdist=bdist_, - # ori=ori_, - height=height, - width=width, - ntime=ntime, - nbands=nbands, - boxes=boxes, - box_labels=box_labels, - box_masks=box_masks, - zero_padding=zero_padding, - **kwargs, - ) - else: - other_ = torch.tensor(other.flatten(), dtype=torch.float) - - train_data = Data( - x=x, - edge_index=edge_indices, - edge_attrs=edge_attrs, - y=y, - bdist=bdist_, - # ori=ori_, - other=other_, - height=height, - width=width, - ntime=ntime, - nbands=nbands, - boxes=boxes, - box_labels=box_labels, - box_masks=box_masks, - zero_padding=zero_padding, - **kwargs, - ) + mask = np.uint8(1 - labels_array) - # Ensure the correct node count - train_data.num_nodes = x.shape[0] + # Get unique segments + segments = nd_label(mask)[0] - return train_data + # Get the distance from edges + bdist = cv2.distanceTransform(mask, cv2.DIST_L2, 3) + bdist *= cell_res + grad_x = cv2.Sobel( + np.pad(bdist, 5, mode="edge"), cv2.CV_32F, dx=1, dy=0, ksize=5 + ) + grad_y = cv2.Sobel( + np.pad(bdist, 5, mode="edge"), cv2.CV_32F, dx=0, dy=1, ksize=5 + ) + ori = cv2.phase(grad_x, grad_y, angleInDegrees=False) + ori = ori[5:-5, 5:-5] / np.deg2rad(360) + ori[labels_array == 0] = 0 + + return mask, segments, bdist, ori + + +def normalize_boundary_distances( + labels_array: np.ndarray, + train_type: str, + cell_res: float, + normalize: bool = True, +) -> T.Tuple[np.ndarray, np.ndarray]: + """Normalizes boundary distances.""" -def create_network_data(xvars: np.ndarray, ntime: int, nbands: int) -> Data: - # Create the network - nwk = SingleSensorNetwork( - np.ascontiguousarray(xvars, dtype="float64"), k=3 + # Create the boundary distances + __, segments, bdist, ori = create_boundary_distances( + labels_array, train_type, cell_res ) + dist_max = 1e9 + if normalize: + dist_max = 1.0 + # Normalize each segment by the local max distance + props = regionprops(segments, intensity_image=bdist) + for p in props: + if p.label > 0: + bdist = np.where( + segments == p.label, bdist / p.max_intensity, bdist + ) + bdist = np.nan_to_num( + bdist.clip(0, dist_max), nan=1.0, neginf=1.0, posinf=1.0 + ) + ori = np.nan_to_num(ori.clip(0, 1), nan=1.0, neginf=1.0, posinf=1.0) + + return bdist, ori - ( - edge_indices_a, - edge_indices_b, - edge_attrs_diffs, - edge_attrs_dists, - xpos, - ypos, - ) = nwk.create_network() - edge_indices = np.c_[edge_indices_a, edge_indices_b] - edge_attrs = np.c_[edge_attrs_diffs, edge_attrs_dists] - xy = np.c_[xpos, ypos] - nfeas, nrows, ncols = xvars.shape - xvars = nd_to_columns(xvars, nfeas, nrows, ncols) - - return create_data_object( - xvars, - edge_indices, - edge_attrs, - xy, - ntime=ntime, - nbands=nbands, - height=nrows, - width=ncols, + +def edge_gradient(array: np.ndarray) -> np.ndarray: + """Calculates the morphological gradient of crop fields.""" + se = np.array([[1, 1], [1, 1]], dtype="uint8") + array = np.uint8( + cv2.morphologyEx(np.uint8(array), cv2.MORPH_GRADIENT, se) > 0 ) + return array + -class NetworkDataset(object): - def __init__(self, data: Data, data_path: Path, data_values: NormValues): - self.data_values = data_values - self.data_path = data_path +def polygon_to_array( + df: gpd.GeoDataFrame, + reference_data: xr.DataArray, + column: str, + fill_value: int = 0, + default_value: int = 1, + all_touched: bool = False, + dtype: str = "uint8", +) -> np.ndarray: + """Converts a polygon, or polygons, to an array.""" - self.processed_path = self.data_path / "processed" - self.processed_path.mkdir(parents=True, exist_ok=True) + df = df.copy() - # Create a random filename so that the processed - # directory can be used by other processes - filename = str(uuid.uuid4()).replace("-", "") - pt_name = f"{filename}_.pt" - self.pattern = f"{filename}*.pt" - self.pt_file = self.processed_path / pt_name + if df.crs != reference_data.crs: + # Transform the geometry + df = df.to_crs(reference_data.crs) - self._save(data) + # Get the reference bounds + left, bottom, right, top = reference_data.gw.bounds + # Get intersecting polygons + df = df.cx[left:right, bottom:top] + # Clip the polygons to the reference bounds + df = gpd.clip(df, reference_data.gw.geodataframe) - def _save(self, data: Data) -> None: - torch.save(data, self.pt_file) + # Get the output dimensions + dst_transform = Affine( + reference_data.gw.cellx, 0.0, left, 0.0, -reference_data.gw.celly, top + ) - def clear(self) -> None: - if self.processed_path.is_dir(): - shutil.rmtree(str(self.processed_path)) + # Get the shape geometry and encoding value + shapes = list(zip(df.geometry, df[column])) - def unlink(self) -> None: - self.pt_file.unlink() + # Override dtype + if (dtype == "uint8") and (df[column].max() > 255): + dtype = "int32" - @property - def ds(self) -> EdgeDataset: - return EdgeDataset( - self.data_path, - data_means=self.data_values.mean, - data_stds=self.data_values.std, - pattern=self.pattern, - ) + # Convert the polygon(s) to an array + polygon_array = rio_rasterize( + shapes, + out_shape=(reference_data.gw.nrows, reference_data.gw.ncols), + fill=fill_value, + transform=dst_transform, + all_touched=all_touched, + default_value=default_value, + dtype=dtype, + ) + + return polygon_array + + +def fillz(x: np.ndarray) -> np.ndarray: + """Fills zeros with the focal mean value.""" + + focal_mean = uniform_filter(x, size=(0, 0, 3, 3), mode='reflect') + + return np.where(x == 0, focal_mean, x) diff --git a/src/cultionet/enums/__init__.py b/src/cultionet/enums/__init__.py new file mode 100644 index 00000000..3af0408f --- /dev/null +++ b/src/cultionet/enums/__init__.py @@ -0,0 +1,101 @@ +import enum + + +class StrEnum(str, enum.Enum): + """ + Source: + https://github.com/irgeek/StrEnum/blob/master/strenum/__init__.py + """ + + def __new__(cls, value, *args, **kwargs): + return super().__new__(cls, value, *args, **kwargs) + + def __str__(self) -> str: + return self.value + + +class DataColumns(StrEnum): + GEOID = "geo_id" + YEAR = "year" + + +class AttentionTypes(StrEnum): + NATTEN = "natten" + SPATIAL_CHANNEL = "spatial_channel" + + +class CLISteps(StrEnum): + CREATE = "create" + CREATE_PREDICT = "create-predict" + SKFOLDCV = "skfoldcv" + TRAIN = "train" + TRAIN_TRANSFER = "train-transfer" + PREDICT = "predict" + PREDICT_TRANSFER = "predict-transfer" + VERSION = "version" + + +class Destinations(StrEnum): + CKPT = 'ckpt' + DATA = 'data' + FIGURES = 'figures' + PREDICT = 'predict' + PROCESSED = 'processed' + TRAIN = 'train' + TEST = 'test' + TIME_SERIES_VARS = 'time_series_vars' + USER_TRAIN = 'user_train' + + +class InferenceNames(StrEnum): + CLASSES_L2 = 'classes_l2' + CLASSES_L3 = 'classes_l3' + CROP_TYPE = 'crop_type' + DISTANCE = 'distance' + EDGE = 'edge' + CROP = 'crop' + RECONSTRUCTION = 'reconstruction' + + +class LossTypes(StrEnum): + BOUNDARY = "BoundaryLoss" + CENTERLINE_DICE = "CLDiceLoss" + CLASS_BALANCED_MSE = "ClassBalancedMSELoss" + LOG_COSH = "LogCoshLoss" + FOCAL_TVERSKY = "FocalTverskyLoss" + TANIMOTO_COMPLEMENT = "TanimotoComplementLoss" + TANIMOTO = "TanimotoDistLoss" + TANIMOTO_COMBINED = "TanimotoCombined" + TVERSKY = "TverskyLoss" + + +class ModelNames(StrEnum): + CLASS_INFO = "classes.info" + CKPT_NAME = "last.ckpt" + CKPT_TRANSFER_NAME = "last_transfer.ckpt" + NORM = "last.norm" + + +class ModelTypes(StrEnum): + TOWERUNET = 'TowerUNet' + + +class ResBlockTypes(StrEnum): + RES = 'res' + RESA = 'resa' + + +class LearningRateSchedulers(StrEnum): + COSINE_ANNEALING_LR = 'CosineAnnealingLR' + EXPONENTIAL_LR = 'ExponentialLR' + ONE_CYCLE_LR = 'OneCycleLR' + STEP_LR = 'StepLR' + + +class ValidationNames(StrEnum): + TRUE_CROP = 'true_crop' + TRUE_EDGE = 'true_edge' + TRUE_CROP_AND_EDGE = 'true_crop_and_edge' + TRUE_CROP_OR_EDGE = 'true_crop_or_edge' + TRUE_CROP_TYPE = 'true_crop_type' + MASK = 'mask' # 1|0 mask diff --git a/src/cultionet/layers/__init__.py b/src/cultionet/layers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cultionet/layers/encodings.py b/src/cultionet/layers/encodings.py new file mode 100644 index 00000000..09877a32 --- /dev/null +++ b/src/cultionet/layers/encodings.py @@ -0,0 +1,35 @@ +""" +Source: + https://github.com/VSainteuf/utae-paps/blob/main/src/backbones/positional_encoding.py + + MIT License + + Copyright (c) 2021 VSainteuf +""" + +import numpy as np +import torch + + +def calc_angle(position: int, hid_idx: int, d_hid: int, time_scaler: int): + return position / np.power(time_scaler, 2 * (hid_idx // 2) / d_hid) + + +def get_posi_angle_vec(position, d_hid, time_scaler): + return [ + calc_angle(position, hid_j, d_hid, time_scaler) + for hid_j in range(d_hid) + ] + + +def get_sinusoid_encoding_table( + positions: int, d_hid: int, time_scaler: int = 1_000 +): + positions = list(range(positions)) + sinusoid_table = np.array( + [get_posi_angle_vec(pos_i, d_hid, time_scaler) for pos_i in positions] + ) + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + return torch.tensor(sinusoid_table, dtype=torch.float32) diff --git a/src/cultionet/layers/weights.py b/src/cultionet/layers/weights.py new file mode 100644 index 00000000..a702cd42 --- /dev/null +++ b/src/cultionet/layers/weights.py @@ -0,0 +1,39 @@ +from typing import Callable + +import torch.nn as nn + + +def init_attention_weights(module: Callable) -> None: + if isinstance( + module, + ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.Linear, + ), + ): + nn.init.kaiming_normal_(module.weight.data, a=0, mode="fan_in") + if module.bias is not None: + nn.init.normal_(module.bias.data) + elif isinstance(module, (nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0.0) + + +def init_conv_weights(module: Callable) -> None: + if isinstance( + module, + ( + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.Linear, + ), + ): + nn.init.kaiming_normal_(module.weight.data, a=0, mode="fan_in") + if module.bias is not None: + nn.init.normal_(module.bias.data) + elif isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): + nn.init.normal_(module.weight.data, 1.0, 0.02) + nn.init.constant_(module.bias.data, 0.0) diff --git a/src/cultionet/losses/__init__.py b/src/cultionet/losses/__init__.py index ad63f773..c563ca79 100644 --- a/src/cultionet/losses/__init__.py +++ b/src/cultionet/losses/__init__.py @@ -1 +1,11 @@ -from .losses import TanimotoDistLoss +from .losses import ( + BoundaryLoss, + ClassBalancedMSELoss, + CombinedLoss, + FocalTverskyLoss, + LogCoshLoss, + LossPreprocessing, + TanimotoComplementLoss, + TanimotoDistLoss, + TverskyLoss, +) diff --git a/src/cultionet/losses/losses.py b/src/cultionet/losses/losses.py index 1cdf1397..2e327307 100644 --- a/src/cultionet/losses/losses.py +++ b/src/cultionet/losses/losses.py @@ -1,163 +1,106 @@ import typing as T -import warnings -import numpy as np +import einops import torch +import torch.nn as nn import torch.nn.functional as F -from torch_geometric.data import Data -import torchmetrics -from . import topological -from ..models import model_utils +class LossPreprocessing(nn.Module): + def __init__( + self, transform_logits: bool = False, one_hot_targets: bool = True + ): + super().__init__() -def one_hot(targets: torch.Tensor, dims: int) -> torch.Tensor: - return F.one_hot(targets.contiguous().view(-1), dims).float() - - -class LossPreprocessing(torch.nn.Module): - def __init__(self, inputs_are_logits: bool, apply_transform: bool): - super(LossPreprocessing, self).__init__() - - self.inputs_are_logits = inputs_are_logits - self.apply_transform = apply_transform - self.sigmoid = torch.nn.Sigmoid() + self.transform_logits = transform_logits + self.one_hot_targets = one_hot_targets def forward( - self, inputs: torch.Tensor, targets: torch.Tensor = None - ) -> T.Tuple[torch.Tensor, T.Union[torch.Tensor, None]]: + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + ) -> T.Tuple[torch.Tensor, torch.Tensor]: """Forward pass to transform logits. If logits are single-dimension then they are transformed by Sigmoid. If logits are multi-dimension then they are transformed by Softmax. """ - if self.inputs_are_logits: - if targets is not None: - if (len(targets.unique()) > inputs.size(1)) or ( - targets.unique().max() + 1 > inputs.size(1) - ): - raise ValueError( - "The targets should be ordered values of equal length to the inputs 2nd dimension." - ) - if self.apply_transform: - if inputs.shape[1] == 1: - inputs = self.sigmoid(inputs) - else: - inputs = F.softmax(inputs, dim=1, dtype=inputs.dtype) + + if self.transform_logits: + if inputs.shape[1] == 1: + inputs = F.sigmoid(inputs).to(dtype=inputs.dtype) + else: + inputs = F.softmax(inputs, dim=1, dtype=inputs.dtype) inputs = inputs.clip(0, 1) - if targets is not None: - targets = one_hot(targets, dims=inputs.shape[1]) + + if self.one_hot_targets and (inputs.shape[1] > 1): + + with torch.no_grad(): + targets = einops.rearrange( + F.one_hot(targets, num_classes=inputs.shape[1]), + 'b h w c -> b c h w', + ) else: - inputs = inputs.unsqueeze(1) - targets = targets.unsqueeze(1) + if len(targets.shape) == 3: + targets = einops.rearrange(targets, 'b h w -> b 1 h w') - return inputs, targets + if mask is not None: + if len(mask.shape) == 3: + mask = einops.rearrange(mask, 'b h w -> b 1 h w') -class TopologicalLoss(torch.nn.Module): - """ - Reference: - https://arxiv.org/abs/1906.05404 - https://arxiv.org/pdf/1906.05404.pdf - https://github.com/HuXiaoling/TopoLoss/blob/5cb98177de50a3694f5886137ff7c6f55fd51493/topoloss_pytorch.py - """ + # Apply a mask to zero-out weight + inputs = inputs * mask + targets = targets * mask - def __init__(self): - super(TopologicalLoss, self).__init__() + return inputs, targets - self.gc = model_utils.GraphToConv() - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor, data: Data - ) -> torch.Tensor: - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) +class CombinedLoss(nn.Module): + def __init__(self, losses: T.List[T.Callable]): + super().__init__() - input_dims = inputs.shape[1] - # Probabilities are ether Sigmoid or Softmax - input_index = 0 if input_dims == 1 else 1 + self.losses = losses - inputs = self.gc(inputs, batch_size, height, width) - targets = self.gc(targets.unsqueeze(1), batch_size, height, width) - # Clone tensors before detaching from GPU - inputs_clone = inputs.clone() - targets_clone = targets.clone() + def forward( + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Performs a single forward pass. - topo_cp_weight_map = np.zeros( - inputs_clone[:, input_index].shape, dtype="float32" - ) - topo_cp_ref_map = np.zeros( - inputs_clone[:, input_index].shape, dtype="float32" - ) - topo_mask = np.zeros(inputs_clone[:, input_index].shape, dtype="uint8") + Parameters + ========== + inputs + Predictions from model (probabilities or labels), shaped (B, C, H, W). + targets + Ground truth values, shaped (B, C, H, W). + mask + Values to mask (0) or keep (1), shaped (B, 1, H, W). + + Returns + ======= + Average distance loss (float) + """ - # Detach from GPU for gudhi libary - inputs_clone = ( - inputs_clone[:, input_index].float().cpu().detach().numpy() - ) - targets_clone = targets_clone[:, 0].float().cpu().detach().numpy() + loss = 0.0 + for loss_func in self.losses: + loss = loss + loss_func( + inputs=inputs, + targets=targets, + mask=mask, + ) - pd_lh, bcp_lh, dcp_lh, pairs_lh_pa = topological.critical_points( - inputs_clone - ) - pd_gt, __, __, pairs_lh_gt = topological.critical_points(targets_clone) - - if pairs_lh_pa and pairs_lh_gt: - for batch in range(0, batch_size): - if (pd_lh[batch].size > 0) and (pd_gt[batch].size > 0): - ( - __, - idx_holes_to_fix, - idx_holes_to_remove, - ) = topological.compute_dgm_force( - pd_lh[batch], pd_gt[batch], pers_thresh=0.03 - ) - ( - topo_cp_weight_map[batch], - topo_cp_ref_map[batch], - topo_mask[batch], - ) = topological.set_topology_weights( - likelihood=inputs_clone[batch], - topo_cp_weight_map=topo_cp_weight_map[batch], - topo_cp_ref_map=topo_cp_ref_map[batch], - topo_mask=topo_mask[batch], - bcp_lh=bcp_lh[batch], - dcp_lh=dcp_lh[batch], - idx_holes_to_fix=idx_holes_to_fix, - idx_holes_to_remove=idx_holes_to_remove, - height=inputs.shape[-2], - width=inputs.shape[-1], - ) - - topo_cp_weight_map = torch.tensor( - topo_cp_weight_map, dtype=inputs.dtype, device=inputs.device - ) - topo_cp_ref_map = torch.tensor( - topo_cp_ref_map, dtype=inputs.dtype, device=inputs.device - ) - topo_mask = torch.tensor(topo_mask, dtype=bool, device=inputs.device) - if not topo_mask.any(): - topo_loss = ( - (inputs[:, input_index] * topo_cp_weight_map) - topo_cp_ref_map - ) ** 2 - else: - topo_loss = ( - ( - inputs[:, input_index][topo_mask] - * topo_cp_weight_map[topo_mask] - ) - - topo_cp_ref_map[topo_mask] - ) ** 2 + loss = loss / len(self.losses) - return topo_loss.mean() + return loss -class TanimotoComplementLoss(torch.nn.Module): +class TanimotoComplementLoss(nn.Module): """Tanimoto distance loss. Adapted from publications and source code below: @@ -192,68 +135,118 @@ def __init__( self, smooth: float = 1e-5, depth: int = 5, - targets_are_labels: bool = True, + transform_logits: bool = False, + one_hot_targets: bool = True, ): - super(TanimotoComplementLoss, self).__init__() + super().__init__() self.smooth = smooth self.depth = depth - self.targets_are_labels = targets_are_labels + self.one_hot_targets = one_hot_targets self.preprocessor = LossPreprocessing( - inputs_are_logits=True, apply_transform=True + transform_logits=transform_logits, + one_hot_targets=one_hot_targets, ) + def tanimoto_distance( + self, + y: torch.Tensor, + yhat: torch.Tensor, + dim: T.Optional[T.Tuple[int, ...]] = None, + ) -> torch.Tensor: + if dim is None: + dim = (1, 2, 3) + + scale = 1.0 / self.depth + + tpl = y * yhat + sq_sum = y**2 + yhat**2 + + tpl = tpl.sum(dim=dim) + sq_sum = sq_sum.sum(dim=dim) + + denominator = 0.0 + for d in range(0, self.depth): + a = 2.0**d + b = -(2.0 * a - 1.0) + denominator = denominator + torch.reciprocal( + ((a * sq_sum) + (b * tpl)) + self.smooth + ) + + numerator = tpl + self.smooth + + if dim == (2, 3): + distance = ((numerator * denominator) * scale).sum(dim=1) + else: + distance = (numerator * denominator) * scale + + loss = 1.0 - distance + + return loss + def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + dim: T.Optional[T.Tuple[int, ...]] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predictions from model (probabilities or labels). - targets: Ground truth values. - - Returns: - Tanimoto distance loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities or labels), shaped (B, C, H, W). + targets + Ground truth values, shaped (B, C, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + + Returns + ======= + Tanimoto distance loss (float) """ - if self.targets_are_labels: - # Discrete targets - if inputs.shape[1] > 1: - # Softmax and One-hot encoding - inputs, targets = self.preprocessor(inputs, targets) - - if len(inputs.shape) == 1: - inputs = inputs.unsqueeze(1) - if len(targets.shape) == 1: - targets = targets.unsqueeze(1) - - length = inputs.shape[1] - - def tanimoto(y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor: - scale = 1.0 / self.depth - tpl = (y * yhat).sum(dim=0) - numerator = tpl + self.smooth - sq_sum = (y**2 + yhat**2).sum(dim=0) - denominator = torch.zeros(length, dtype=inputs.dtype).to( - device=inputs.device - ) - for d in range(0, self.depth): - a = 2**d - b = -(2.0 * a - 1.0) - denominator = denominator + torch.reciprocal( - (a * sq_sum) + (b * tpl) + self.smooth - ) + inputs, targets = self.preprocessor( + inputs=inputs, targets=targets, mask=mask + ) + + loss1 = self.tanimoto_distance(targets, inputs, dim=dim) + loss2 = self.tanimoto_distance(1.0 - targets, 1.0 - inputs, dim=dim) + loss = (loss1 + loss2) * 0.5 - return numerator * denominator * scale + return loss.mean() + + +def tanimoto_dist( + ypred: torch.Tensor, + ytrue: torch.Tensor, + smooth: float, + dim: T.Optional[T.Tuple[int, ...]] = None, +) -> torch.Tensor: + """Tanimoto distance.""" + + if dim is None: + dim = (1, 2, 3) + + ytrue = ytrue.to(dtype=ypred.dtype) + + tpl = ypred * ytrue + sq_sum = ypred**2 + ytrue**2 + + tpl = tpl.sum(dim=dim) + sq_sum = sq_sum.sum(dim=dim) + + numerator = tpl + smooth + denominator = (sq_sum - tpl) + smooth + distance = numerator / denominator - score = tanimoto(targets, inputs) - if inputs.shape[1] == 1: - score = (score + tanimoto(1.0 - targets, 1.0 - inputs)) * 0.5 + loss = 1.0 - distance - return (1.0 - score).mean() + return loss -class TanimotoDistLoss(torch.nn.Module): +class TanimotoDistLoss(nn.Module): """Tanimoto distance loss. References: @@ -294,325 +287,577 @@ class TanimotoDistLoss(torch.nn.Module): def __init__( self, smooth: float = 1e-5, - beta: T.Optional[float] = 0.999, - class_counts: T.Optional[torch.Tensor] = None, - scale_pos_weight: T.Optional[bool] = False, - transform_logits: T.Optional[bool] = False, + transform_logits: bool = False, + one_hot_targets: bool = True, ): - super(TanimotoDistLoss, self).__init__() - - if scale_pos_weight and (class_counts is None): - warnings.warn( - "Cannot balance classes without class weights. Weights will be derived for each batch.", - UserWarning, - ) + super().__init__() self.smooth = smooth - self.beta = beta - self.class_counts = class_counts - self.scale_pos_weight = scale_pos_weight - self.transform_logits = transform_logits + self.preprocessor = LossPreprocessing( - inputs_are_logits=True, apply_transform=True + transform_logits=transform_logits, + one_hot_targets=one_hot_targets, ) def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predictions from model (probabilities, logits or labels). - targets: Ground truth values. - - Returns: - Tanimoto distance loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities or labels), shaped (B, C, H, W). + targets + Ground truth values, shaped (B, C, H, W). + mask + Values to mask (0) or keep (1), shaped (B, 1, H, W). + + Returns + ======= + Tanimoto distance loss (float) """ - if self.transform_logits: - if len(inputs.shape) == 1: - inputs, __ = self.preprocessor(inputs) - else: - if inputs.shape[1] == 1: - inputs, __ = self.preprocessor(inputs) - else: - inputs, targets = self.preprocessor(inputs, targets) - else: - if len(inputs.shape) > 1: - if inputs.shape[1] > 1: - targets = one_hot(targets, dims=inputs.shape[1]) - - if len(inputs.shape) == 1: - inputs = inputs.unsqueeze(1) - if len(targets.shape) == 1: - targets = targets.unsqueeze(1) - - def tanimoto_loss(yhat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - y = y.to(dtype=yhat.dtype) - if self.scale_pos_weight: - if self.class_counts is None: - class_counts = y.sum(dim=0) - else: - class_counts = self.class_counts - effective_num = 1.0 - self.beta**class_counts - weights = (1.0 - self.beta) / effective_num - weights = weights / weights.sum() * class_counts.shape[0] - else: - weights = torch.ones( - inputs.shape[1], dtype=inputs.dtype, device=inputs.device - ) - # Reduce - tpl = (yhat * y).sum(dim=0) - sq_sum = (yhat**2 + y**2).sum(dim=0) - numerator = tpl * weights + self.smooth - denominator = (sq_sum - tpl) * weights + self.smooth - tanimoto = numerator / denominator - loss = 1.0 - tanimoto - return loss + inputs, targets = self.preprocessor( + inputs=inputs, targets=targets, mask=mask + ) - loss = tanimoto_loss(inputs, targets) - if inputs.shape[1] == 1: - compl_loss = tanimoto_loss(1.0 - inputs, 1.0 - targets) - loss = (loss + compl_loss) * 0.5 + loss1 = tanimoto_dist( + inputs, + targets, + smooth=self.smooth, + ) + loss2 = tanimoto_dist( + 1.0 - inputs, + 1.0 - targets, + smooth=self.smooth, + ) + loss = (loss1 + loss2) * 0.5 return loss.mean() -class CrossEntropyLoss(torch.nn.Module): - """Cross entropy loss.""" - - def __init__( - self, - weight: T.Optional[torch.Tensor] = None, - reduction: T.Optional[str] = "mean", - label_smoothing: T.Optional[float] = 0.1, - ): - super(CrossEntropyLoss, self).__init__() - - self.loss_func = torch.nn.CrossEntropyLoss( - weight=weight, reduction=reduction, label_smoothing=label_smoothing - ) +class LogCoshLoss(nn.Module): + def __init__(self): + super().__init__() def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predictions from model. - targets: Ground truth values. - - Returns: - Loss (float) + Parameters + ========== + inputs + Predictions from model (real values), shaped (B, H, W) or (B, 1, H, W). + targets + Targets (real values), shaped (B, H, W) or (B, 1, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + + Returns + ======= + Log Hyperbolic Cosine loss (float) """ - return self.loss_func(inputs, targets) + if len(inputs.shape) == 3: + inputs = einops.rearrange(inputs, 'b h w -> b 1 h w') -class FocalLoss(torch.nn.Module): - """Focal loss. + if len(targets.shape) == 3: + targets = einops.rearrange(targets, 'b h w -> b 1 h w') - Reference: - https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch/notebook - """ + loss = torch.log(torch.cosh(inputs - targets)) - def __init__( - self, - alpha: float = 0.8, - gamma: float = 2.0, - weight: T.Optional[torch.Tensor] = None, - label_smoothing: T.Optional[float] = 0.1, - ): - super(FocalLoss, self).__init__() + if mask is not None: - self.alpha = alpha - self.gamma = gamma + if len(mask.shape) == 3: + mask = einops.rearrange(mask, 'b h w -> b 1 h w') - self.preprocessor = LossPreprocessing( - inputs_are_logits=True, apply_transform=True - ) - self.cross_entropy_loss = torch.nn.CrossEntropyLoss( - weight=weight, reduction="none", label_smoothing=label_smoothing - ) + loss = loss * mask + loss = loss.sum() / mask.sum() + + else: + loss = loss.mean() + + return loss + + +class ClassBalancedMSELoss(nn.Module): + r"""Class-balanced mean squared error loss. + + License: + MIT License + Copyright (c) 2023 Adill Al-Ashgar + + References: + @article{xia_etal_2024, + title={Crop field extraction from high resolution remote sensing images based on semantic edges and spatial structure map}, + author={Xia, Liegang and Liu, Ruiyan and Su, Yishao and Mi, Shulin and Yang, Dezhi and Chen, Jun and Shen, Zhanfeng}, + journal={Geocarto International}, + volume={39}, + number={1}, + pages={2302176}, + year={2024}, + publisher={Taylor \& Francis}, + } + + Source: + https://github.com/Adillwma/ACB_MSE + """ + + def __init__(self): + super().__init__() def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: - inputs, targets = self.preprocessor(inputs, targets) - ce_loss = self.cross_entropy_loss(inputs, targets.half()) - ce_exp = torch.exp(-ce_loss) - focal_loss = self.alpha * (1.0 - ce_exp) ** self.gamma * ce_loss + """ + Parameters + ========== + inputs + Predictions (probabilities), shaped (B, H, W) or (B, 1, H, W). + targets + Ground truth values, shaped (B, H, W) or (B, 1, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + """ - return focal_loss.mean() + if len(inputs.shape) == 4: + inputs = einops.rearrange(inputs, 'b 1 h w -> b h w') + if len(targets.shape) == 4: + targets = einops.rearrange(targets, 'b 1 h w -> b h w') -class QuantileLoss(torch.nn.Module): - """Loss function for quantile regression. + if mask is not None: - Reference: - https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss + if len(mask.shape) == 4: + mask = einops.rearrange(mask, 'b 1 h w -> b h w') - THE MIT License + neg_mask = (targets <= 0.5) & (mask != 0) + pos_mask = (targets > 0.5) & (mask != 0) + target_count = mask.sum() - Copyright 2020 Jan Beitner - """ + else: - def __init__(self, quantiles: T.Tuple[float, float, float]): - super(QuantileLoss, self).__init__() + neg_mask = targets <= 0.5 + pos_mask = ~neg_mask + target_count = targets.nelement() - self.quantiles = quantiles + targets_neg = targets[neg_mask] + targets_pos = targets[pos_mask] - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - """Performs a single forward pass. + inputs_neg = inputs[neg_mask] + inputs_pos = inputs[pos_mask] - Args: - inputs: Predictions from model (probabilities, logits or labels). - targets: Ground truth values. + beta = pos_mask.sum() / target_count - Returns: - Quantile loss (float) - """ - losses = [] - for i, q in enumerate(self.quantiles): - errors = targets - inputs[:, i] - losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(1)) - loss = torch.cat(losses, dim=1).sum(dim=1).mean() + assert 0 <= beta <= 1 + + neg_loss = torch.log( + torch.cosh( + torch.pow(inputs_neg - targets_neg.to(dtype=inputs.dtype), 2) + ) + ).mean() + + pos_loss = torch.log( + torch.cosh( + torch.pow(inputs_pos - targets_pos.to(dtype=inputs.dtype), 2) + ) + ).mean() + + if torch.isnan(neg_loss): + neg_loss = 0.0 + + if torch.isnan(pos_loss): + pos_loss = 0.0 + + loss = beta * neg_loss + (1.0 - beta) * pos_loss return loss -class WeightedL1Loss(torch.nn.Module): - """Weighted L1Loss loss.""" +class BoundaryLoss(nn.Module): + """Boundary loss. + + License: + MIT License + Copyright (c) 2023 Hoel Kervadec + + Reference: + @inproceedings{kervadec_etal_2019, + title={Boundary loss for highly unbalanced segmentation}, + author={Kervadec, Hoel and Bouchtiba, Jihene and Desrosiers, Christian and Granger, Eric and Dolz, Jose and Ayed, Ismail Ben}, + booktitle={International conference on medical imaging with deep learning}, + pages={285--296}, + year={2019}, + organization={PMLR}, + } + + Source: + https://github.com/LIVIAETS/boundary-loss/tree/108bd9892adca476e6cdf424124bc6268707498e + """ def __init__(self): - super(WeightedL1Loss, self).__init__() + super().__init__() def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predictions from model. - targets: Ground truth values. - - Returns: - Loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities), shaped (B, 1, H, W). + targets + Target distance map, shaped (B, H, W) or (B, 1, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + + Returns + ======= + Boundary loss (float) """ - inputs = inputs.contiguous().view(-1) - targets = targets.contiguous().view(-1) - mae = torch.abs(inputs - targets) - weight = inputs + targets - loss = (mae * weight).mean() + if len(targets.shape) == 3: + targets = einops.rearrange(targets, 'b h w -> b 1 h w') - return loss + if mask is not None: + if len(mask.shape) == 3: + mask = einops.rearrange(mask, 'b h w -> b 1 h w') + # Apply a mask to zero-out weight + inputs = inputs * mask + targets = targets * mask -class MSELoss(torch.nn.Module): - """MSE loss.""" + hadamard_product = torch.einsum('bchw, bchw -> bchw', inputs, targets) + + if mask is not None: + hadamard_mean = hadamard_product.sum() / mask.sum() + else: + hadamard_mean = hadamard_product.mean() + + return 1.0 - hadamard_mean - def __init__(self): - super(MSELoss, self).__init__() - self.loss_func = torch.nn.MSELoss() +class SoftSkeleton(nn.Module): + """Soft skeleton. + + License: + MIT License + Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit + + Reference: + @inproceedings{shit_etal_2021, + title={clDice-a novel topology-preserving loss function for tubular structure segmentation}, + author={Shit, Suprosanna and Paetzold, Johannes C and Sekuboyina, Anjany and Ezhov, Ivan and Unger, Alexander and Zhylka, Andrey and Pluim, Josien PW and Bauer, Ulrich and Menze, Bjoern H}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={16560--16569}, + year={2021}, + } + + Source: + https://github.com/jocpae/clDice/tree/master + """ + + def __init__(self, num_iter: int): + super().__init__() + + self.num_iter = num_iter + + def soft_erode(self, img: torch.Tensor) -> torch.Tensor: + if len(img.shape) == 4: + + p1 = -F.max_pool2d( + -img, kernel_size=(3, 1), stride=(1, 1), padding=(1, 0) + ) + p2 = -F.max_pool2d( + -img, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1) + ) + + eroded = torch.min(p1, p2) + + elif len(img.shape) == 5: + + p1 = -F.max_pool3d( + -img, + kernel_size=(3, 1, 1), + stride=(1, 1, 1), + padding=(1, 0, 0), + ) + p2 = -F.max_pool3d( + -img, + kernel_size=(1, 3, 1), + stride=(1, 1, 1), + padding=(0, 1, 0), + ) + p3 = -F.max_pool3d( + -img, + kernel_size=(1, 1, 3), + stride=(1, 1, 1), + padding=(0, 0, 1), + ) + + eroded = torch.min(torch.min(p1, p2), p3) + + return eroded + + def soft_dilate(self, img: torch.Tensor) -> torch.Tensor: + if len(img.shape) == 4: + dilated = F.max_pool2d( + img, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1) + ) + elif len(img.shape) == 5: + dilated = F.max_pool3d( + img, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1) + ) + + return dilated + + def soft_open(self, img: torch.Tensor) -> torch.Tensor: + return self.soft_dilate(self.soft_erode(img)) + + def soft_skeleton(self, img: torch.Tensor) -> torch.Tensor: + img1 = self.soft_open(img) + skeleton = F.relu(img - img1) + + for j in range(self.num_iter): + img = self.soft_erode(img) + img1 = self.soft_open(img) + delta = F.relu(img - img1) + skeleton = skeleton + F.relu(delta - skeleton * delta) + + return skeleton + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.soft_skeleton(x) + + +class CLDiceLoss(nn.Module): + """Centerline Dice loss. + + License: + MIT License + Copyright (c) 2021 Johannes C. Paetzold and Suprosanna Shit + + Reference: + @inproceedings{shit_etal_2021, + title={clDice-a novel topology-preserving loss function for tubular structure segmentation}, + author={Shit, Suprosanna and Paetzold, Johannes C and Sekuboyina, Anjany and Ezhov, Ivan and Unger, Alexander and Zhylka, Andrey and Pluim, Josien PW and Bauer, Ulrich and Menze, Bjoern H}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={16560--16569}, + year={2021}, + } + + Source: + https://github.com/jocpae/clDice/tree/master + """ + + def __init__(self, smooth: float = 1.0, num_iter: int = 10): + super().__init__() + + self.smooth = smooth + + self.soft_skeleton = SoftSkeleton(num_iter=num_iter) + + def precision_recall( + self, skeleton: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + return ((skeleton * mask).sum() + self.smooth) / ( + skeleton.sum() + self.smooth + ) def forward( - self, inputs: torch.Tensor, targets: torch.Tensor + self, + inputs: torch.Tensor, + targets: torch.Tensor, + transform_logits: bool = True, + mask: T.Optional[torch.Tensor] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predictions from model. - targets: Ground truth values. - - Returns: - Loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities), shaped (B, 1, H, W). + targets + Binary targets, where background is 0 and targets are 1, shaped (B, H, W). + mask + Values to mask (0) or keep (1), shaped (B, 1, H, W). + + Returns + ======= + Centerline Dice loss (float) """ - return self.loss_func( - inputs.contiguous().view(-1), targets.contiguous().view(-1) + + targets = einops.rearrange(targets, 'b h w -> b 1 h w') + + if transform_logits: + inputs = F.softmax(inputs, dim=1)[:, [1]] + + # Get the predicted label + y_pred = (inputs > 0.5).long() + + # Add background + # TODO: this could be optional + pred_background = (1 - y_pred).abs() + y_pred = torch.cat((pred_background, y_pred), dim=1) + + true_background = (1 - targets).abs() + y_true = torch.cat((true_background, targets), dim=1) + + if mask is not None: + y_true = y_true * mask + y_pred = y_pred * mask + + pred_skeleton = self.soft_skeleton(y_pred.to(dtype=inputs.dtype)) + true_skeleton = self.soft_skeleton(y_true.to(dtype=inputs.dtype)) + + topo_precision = self.precision_recall(pred_skeleton, y_true) + topo_recall = self.precision_recall(true_skeleton, y_pred) + + cl_dice = 1.0 - 2.0 * (topo_precision * topo_recall) / ( + topo_precision + topo_recall ) + return cl_dice -class BoundaryLoss(torch.nn.Module): - """Boundary (surface) loss. - Reference: - https://github.com/LIVIAETS/boundary-loss - """ +class TverskyLoss(nn.Module): + """Tversky loss.""" - def __init__(self): - super(BoundaryLoss, self).__init__() + def __init__( + self, + alpha: float = 0.4, + beta: float = 0.6, + smooth: float = 1.0, + transform_logits: bool = False, + one_hot_targets: bool = True, + ): + super().__init__() + + self.alpha = alpha + self.beta = beta + self.smooth = smooth - self.gc = model_utils.GraphToConv() + self.preprocessor = LossPreprocessing( + transform_logits=transform_logits, + one_hot_targets=one_hot_targets, + ) def forward( - self, inputs: torch.Tensor, targets: torch.Tensor, data: Data + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + dim: T.Optional[T.Tuple[int, ...]] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predicted probabilities. - targets: Ground truth inverse distance transform, where distances - along edges are 1. - data: Data object used to extract dimensions. - - Returns: - Loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities), shaped (B, H, W) or (B, 1, H, W). + targets + Target labels, shaped (B, H, W) or (B, 1, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + + Returns + ======= + Tversky loss (float) """ - height = ( - int(data.height) if data.batch is None else int(data.height[0]) + + if dim is None: + dim = (1, 2, 3) + + inputs, targets = self.preprocessor( + inputs=inputs, targets=targets, mask=mask ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - inputs = self.gc(inputs.unsqueeze(1), batch_size, height, width) - targets = self.gc(targets.unsqueeze(1), batch_size, height, width) + if mask is not None: - return torch.einsum("bchw, bchw -> bchw", inputs, targets).mean() + if len(mask.shape) == 3: + mask = einops.rearrange(mask, 'b h w -> b 1 h w') + inputs = inputs * mask + targets = targets * mask -class MultiScaleSSIMLoss(torch.nn.Module): - """Multi-scale Structural Similarity Index Measure loss.""" + tp = (inputs * targets).sum(dim=dim) + fp = ((1 - targets) * inputs).sum(dim=dim) + fn = (targets * (1 - inputs)).sum(dim=dim) - def __init__(self): - super(MultiScaleSSIMLoss, self).__init__() - - self.gc = model_utils.GraphToConv() - self.msssim = torchmetrics.MultiScaleStructuralSimilarityIndexMeasure( - gaussian_kernel=False, - kernel_size=3, - data_range=1.0, - k1=1e-4, - k2=9e-4, + tversky = (tp + self.smooth) / ( + tp + self.alpha * fp + self.beta * fn + self.smooth + ) + + loss = 1.0 - tversky + + return loss.mean() + + +class FocalTverskyLoss(nn.Module): + """Focal Tversky loss.""" + + def __init__( + self, + alpha: float = 0.2, + beta: float = 0.8, + gamma: float = 2.0, + smooth: float = 1.0, + ): + super().__init__() + + self.gamma = gamma + + self.tversky_loss = TverskyLoss( + alpha=alpha, + beta=beta, + smooth=smooth, ) def forward( - self, inputs: torch.Tensor, targets: torch.Tensor, data: Data + self, + inputs: torch.Tensor, + targets: torch.Tensor, + mask: T.Optional[torch.Tensor] = None, + dim: T.Optional[T.Tuple[int, ...]] = None, ) -> torch.Tensor: """Performs a single forward pass. - Args: - inputs: Predicted probabilities. - targets: Ground truth inverse distance transform, where distances - along edges are 1. - data: Data object used to extract dimensions. - - Returns: - Loss (float) + Parameters + ========== + inputs + Predictions from model (probabilities), shaped (B, H, W) or (B, 1, H, W). + targets + Target labels, shaped (B, H, W) or (B, 1, H, W). + mask + Values to mask (0) or keep (1), shaped (B, H, W) or (B, 1, H, W). + + Returns + ======= + Focal Tversky loss (float) """ - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - inputs = self.gc(inputs.unsqueeze(1), batch_size, height, width) - targets = self.gc(targets.unsqueeze(1), batch_size, height, width).to( - dtype=inputs.dtype + tversky_loss = self.tversky_loss( + inputs=inputs, + targets=targets, + mask=mask, + dim=dim, ) - loss = 1.0 - self.msssim(inputs, targets) + loss = torch.pow(tversky_loss, self.gamma) - return loss + return loss.mean() diff --git a/src/cultionet/losses/topological.py b/src/cultionet/losses/topological.py deleted file mode 100644 index 4ebe150b..00000000 --- a/src/cultionet/losses/topological.py +++ /dev/null @@ -1,304 +0,0 @@ -import typing as T - -import numpy as np -import torch -import gudhi - - -def critical_points( - x: torch.Tensor, -) -> T.Tuple[T.List[np.ndarray], T.List[np.ndarray], T.List[np.ndarray], bool]: - batch_size = x.shape[0] - lh_vector = 1.0 - x.flatten() - cubical_complex = gudhi.CubicalComplex( - dimensions=x.shape, top_dimensional_cells=lh_vector - ) - cubical_complex.persistence(homology_coeff_field=2, min_persistence=0) - cofaces = cubical_complex.cofaces_of_persistence_pairs() - cofaces_batch_size = len(cofaces[0]) - - if (cofaces_batch_size == 0) or (cofaces_batch_size != batch_size): - return None, None, None, False - - pd_lh = [ - np.c_[ - lh_vector[cofaces[0][batch][:, 0]], - lh_vector[cofaces[0][batch][:, 1]], - ] - for batch in range(0, batch_size) - ] - bcp_lh = [ - np.c_[ - cofaces[0][batch][:, 0] // x.shape[-1], - cofaces[0][batch][:, 0] % x.shape[-1], - ] - for batch in range(0, batch_size) - ] - dcp_lh = [ - np.c_[ - cofaces[0][batch][:, 1] // x.shape[-1], - cofaces[0][batch][:, 1] % x.shape[-1], - ] - for batch in range(0, batch_size) - ] - - return pd_lh, bcp_lh, dcp_lh, True - - -def compute_dgm_force( - lh_dgm: np.ndarray, - gt_dgm: np.ndarray, - pers_thresh: float = 0.03, - pers_thresh_perfect: float = 0.99, - do_return_perfect: bool = False, -) -> T.Tuple[np.ndarray, np.ndarray]: - """Compute the persistent diagram of the image. - - Args: - lh_dgm: likelihood persistent diagram. - gt_dgm: ground truth persistent diagram. - pers_thresh: Persistent threshold, which also called dynamic value, which measure the difference. - between the local maximum critical point value with its neighouboring minimum critical point value. - Values smaller than the persistent threshold should be filtered. Default is 0.03. - pers_thresh_perfect: The distance difference between two critical points that can be considered as - correct match. Default is 0.99. - do_return_perfect: Return the persistent point or not from the matching. Default is ``False``. - - Returns: - force_list: The matching between the likelihood and ground truth persistent diagram. - idx_holes_to_fix: The index of persistent points that requires to fix in the following training process. - idx_holes_to_remove: The index of persistent points that require to remove for the following training - process. - """ - lh_pers = abs(lh_dgm[:, 1] - lh_dgm[:, 0]) - if gt_dgm.shape[0] == 0: - gt_pers = None - gt_n_holes = 0 - else: - gt_pers = gt_dgm[:, 1] - gt_dgm[:, 0] - gt_n_holes = gt_pers.size # number of holes in gt - - if (gt_pers is None) or (gt_n_holes == 0): - idx_holes_to_fix = np.array([], dtype=int) - idx_holes_to_remove = np.array(list(set(range(lh_pers.size)))) - idx_holes_perfect = [] - else: - # check to ensure that all gt dots have persistence 1 - tmp = gt_pers > pers_thresh_perfect - - # get "perfect holes" - holes which do not need to be fixed, i.e., find top - # lh_n_holes_perfect indices - # check to ensure that at least one dot has persistence 1; it is the hole - # formed by the padded boundary - # if no hole is ~1 (ie >.999) then just take all holes with max values - tmp = lh_pers > pers_thresh_perfect # old: assert tmp.sum() >= 1 - lh_pers_sorted_indices = np.argsort(lh_pers)[::-1] - if np.sum(tmp) >= 1: - lh_n_holes_perfect = tmp.sum() - idx_holes_perfect = lh_pers_sorted_indices[:lh_n_holes_perfect] - else: - idx_holes_perfect = [] - - # find top gt_n_holes indices - idx_holes_to_fix_or_perfect = lh_pers_sorted_indices[:gt_n_holes] - - # the difference is holes to be fixed to perfect - idx_holes_to_fix = np.array( - list(set(idx_holes_to_fix_or_perfect) - set(idx_holes_perfect)) - ) - - # remaining holes are all to be removed - idx_holes_to_remove = lh_pers_sorted_indices[gt_n_holes:] - - # only select the ones whose persistence is large enough - # set a threshold to remove meaningless persistence dots - pers_thd = pers_thresh - idx_valid = np.where(lh_pers > pers_thd)[0] - idx_holes_to_remove = np.array( - list(set(idx_holes_to_remove).intersection(set(idx_valid))) - ) - - force_list = np.zeros(lh_dgm.shape) - - # push each hole-to-fix to (0,1) - if idx_holes_to_fix.shape[0] > 0: - force_list[idx_holes_to_fix, 0] = 0 - lh_dgm[idx_holes_to_fix, 0] - force_list[idx_holes_to_fix, 1] = 1 - lh_dgm[idx_holes_to_fix, 1] - - # push each hole-to-remove to (0,1) - if idx_holes_to_remove.shape[0] > 0: - force_list[idx_holes_to_remove, 0] = lh_pers[ - idx_holes_to_remove - ] / np.sqrt(2.0) - force_list[idx_holes_to_remove, 1] = -lh_pers[ - idx_holes_to_remove - ] / np.sqrt(2.0) - - if do_return_perfect: - return ( - force_list, - idx_holes_to_fix, - idx_holes_to_remove, - idx_holes_perfect, - ) - - return force_list, idx_holes_to_fix, idx_holes_to_remove - - -def adjust_holes_to_fix( - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - hole_indices: np.ndarray, - pairs: np.ndarray, - fill_weight: int, - fill_ref: int, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - mask = ( - (pairs[hole_indices][:, 0] >= 0) - * (pairs[hole_indices][:, 0] < height) - * (pairs[hole_indices][:, 1] >= 0) - * (pairs[hole_indices][:, 1] < width) - ) - indices = ( - pairs[hole_indices][:, 0][mask], - pairs[hole_indices][:, 1][mask], - ) - topo_cp_weight_map[indices] = fill_weight - topo_cp_ref_map[indices] = fill_ref - topo_mask[indices] = 1 - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask - - -def adjust_holes_to_remove( - likelihood: np.ndarray, - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - hole_indices: np.ndarray, - pairs_b: np.ndarray, - pairs_d: np.ndarray, - fill_weight: int, - fill_ref: int, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - mask = ( - (pairs_b[hole_indices][:, 0] >= 0) - * (pairs_b[hole_indices][:, 0] < height) - * (pairs_b[hole_indices][:, 1] >= 0) - * (pairs_b[hole_indices][:, 1] < width) - ) - indices = ( - pairs_b[hole_indices][:, 0][mask], - pairs_b[hole_indices][:, 1][mask], - ) - topo_cp_weight_map[indices] = fill_weight - topo_mask[indices] = 1 - - nested_mask = ( - mask - * (pairs_d[hole_indices][:, 0] >= 0) - * (pairs_d[hole_indices][:, 0] < height) - * (pairs_d[hole_indices][:, 1] >= 0) - * (pairs_d[hole_indices][:, 1] < width) - ) - indices_b = ( - pairs_b[hole_indices][:, 0][nested_mask], - pairs_b[hole_indices][:, 1][nested_mask], - ) - indices_d = ( - pairs_d[hole_indices][:, 0][nested_mask], - pairs_d[hole_indices][:, 1][nested_mask], - ) - topo_cp_ref_map[indices_b] = likelihood[indices_d] - topo_mask[indices_b] = 1 - - indices_inv = ( - pairs_b[hole_indices][:, 0][mask], - pairs_b[hole_indices][:, 1][mask], - ) - topo_cp_ref_map[indices_inv] = fill_ref - topo_mask[indices_inv] = 1 - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask - - -def set_topology_weights( - likelihood: np.ndarray, - topo_cp_weight_map: np.ndarray, - topo_cp_ref_map: np.ndarray, - topo_mask: np.ndarray, - bcp_lh: np.ndarray, - dcp_lh: np.ndarray, - idx_holes_to_fix: np.ndarray, - idx_holes_to_remove: np.ndarray, - height: int, - width: int, -) -> T.Tuple[np.ndarray, np.ndarray, np.ndarray]: - x = 0 - y = 0 - - if len(idx_holes_to_fix) > 0: - topo_cp_weight_map, topo_cp_ref_map, topo_mask = adjust_holes_to_fix( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_fix, - pairs=bcp_lh, - fill_weight=1, - fill_ref=0, - height=height, - width=width, - ) - topo_cp_weight_map, topo_cp_ref_map, topo_mask = adjust_holes_to_fix( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_fix, - pairs=dcp_lh, - fill_weight=1, - fill_ref=1, - height=height, - width=width, - ) - if len(idx_holes_to_remove) > 0: - ( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask, - ) = adjust_holes_to_remove( - likelihood, - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_remove, - pairs_b=bcp_lh, - pairs_d=dcp_lh, - fill_weight=1, - fill_ref=1, - height=height, - width=width, - ) - ( - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask, - ) = adjust_holes_to_remove( - likelihood, - topo_cp_weight_map, - topo_cp_ref_map, - topo_mask=topo_mask, - hole_indices=idx_holes_to_remove, - pairs_b=dcp_lh, - pairs_d=bcp_lh, - fill_weight=1, - fill_ref=0, - height=height, - width=width, - ) - - return topo_cp_weight_map, topo_cp_ref_map, topo_mask diff --git a/src/cultionet/model.py b/src/cultionet/model.py index 444d7047..17f5137c 100644 --- a/src/cultionet/model.py +++ b/src/cultionet/model.py @@ -1,37 +1,40 @@ +import json +import logging import typing as T from pathlib import Path -import logging -import json +import attr +import lightning as L import numpy as np -from scipy.stats import mode as sci_mode -from rasterio.windows import Window import torch -from torch_geometric.data import Data -import pytorch_lightning as pl -from pytorch_lightning.callbacks import ( - ModelCheckpoint, - LearningRateMonitor, - StochasticWeightAveraging, - ModelPruning, -) -from pytorch_lightning.callbacks.early_stopping import EarlyStopping +from lightning.pytorch.tuner import Tuner +from rasterio.windows import Window +from scipy.stats import mode as sci_mode from torchvision import transforms -from .callbacks import LightningGTiffWriter -from .data.const import SCALE_FACTOR -from .data.datasets import EdgeDataset, zscores +from .callbacks import ( + PROGRESS_BAR_CALLBACK, + LightningGTiffWriter, + setup_callbacks, +) +from .data import Data +from .data.constant import SCALE_FACTOR +from .data.datasets import EdgeDataset from .data.modules import EdgeDataModule -from .data.samplers import EpochRandomSampler -from .models.cultio import GeoRefinement -from .models.lightning import ( - CultioLitModel, - MaskRCNNLitModel, - RefineLitModel, + +# from .data.samplers import EpochRandomSampler +from .enums import ( + AttentionTypes, + LearningRateSchedulers, + LossTypes, + ModelNames, + ModelTypes, + ResBlockTypes, ) -from .utils.reshape import ModelOutputs +from .models.lightning import CultionetLitModel, CultionetLitTransferModel from .utils.logging import set_color_logger - +from .utils.normalize import NormValues +from .utils.reshape import ModelOutputs logging.getLogger("lightning").addHandler(logging.NullHandler()) logging.getLogger("lightning").propagate = False @@ -39,285 +42,166 @@ logger = set_color_logger(__name__) -def fit_maskrcnn( - dataset: EdgeDataset, - ckpt_file: T.Union[str, Path], - test_dataset: T.Optional[EdgeDataset] = None, - val_frac: T.Optional[float] = 0.2, - batch_size: T.Optional[int] = 4, - accumulate_grad_batches: T.Optional[int] = 1, - filters: T.Optional[int] = 64, - num_classes: T.Optional[int] = 2, - learning_rate: T.Optional[float] = 0.001, - epochs: T.Optional[int] = 30, - save_top_k: T.Optional[int] = 1, - early_stopping_patience: T.Optional[int] = 7, - early_stopping_min_delta: T.Optional[float] = 0.01, - gradient_clip_val: T.Optional[float] = 1.0, - reset_model: T.Optional[bool] = False, - auto_lr_find: T.Optional[bool] = False, - device: T.Optional[str] = "gpu", - devices: T.Optional[int] = 1, - weight_decay: T.Optional[float] = 1e-5, - precision: T.Optional[int] = 32, - stochastic_weight_averaging: T.Optional[bool] = False, - stochastic_weight_averaging_lr: T.Optional[float] = 0.05, - stochastic_weight_averaging_start: T.Optional[float] = 0.8, - model_pruning: T.Optional[bool] = False, - resize_height: T.Optional[int] = 201, - resize_width: T.Optional[int] = 201, - min_image_size: T.Optional[int] = 100, - max_image_size: T.Optional[int] = 600, - trainable_backbone_layers: T.Optional[int] = 3, -): - """Fits a Mask R-CNN instance model. - - Args: - dataset (EdgeDataset): The dataset to fit on. - ckpt_file (str | Path): The checkpoint file path. - test_dataset (Optional[EdgeDataset]): A test dataset to evaluate on. If given, early stopping - will switch from the validation dataset to the test dataset. - val_frac (Optional[float]): The fraction of data to use for model validation. - batch_size (Optional[int]): The data batch size. - filters (Optional[int]): The number of initial model filters. - learning_rate (Optional[float]): The model learning rate. - epochs (Optional[int]): The number of epochs. - save_top_k (Optional[int]): The number of top-k model checkpoints to save. - early_stopping_patience (Optional[int]): The patience (epochs) before early stopping. - early_stopping_min_delta (Optional[float]): The minimum change threshold before early stopping. - gradient_clip_val (Optional[float]): A gradient clip limit. - reset_model (Optional[bool]): Whether to reset an existing model. Otherwise, pick up from last epoch of - an existing model. - auto_lr_find (Optional[bool]): Whether to search for an optimized learning rate. - device (Optional[str]): The device to train on. Choices are ['cpu', 'gpu']. - devices (Optional[int]): The number of GPU devices to use. - weight_decay (Optional[float]): The weight decay passed to the optimizer. Default is 1e-5. - precision (Optional[int]): The data precision. Default is 32. - stochastic_weight_averaging (Optional[bool]): Whether to use stochastic weight averaging. - Default is False. - stochastic_weight_averaging_lr (Optional[float]): The stochastic weight averaging learning rate. - Default is 0.05. - stochastic_weight_averaging_start (Optional[float]): The stochastic weight averaging epoch start. - Default is 0.8. - model_pruning (Optional[bool]): Whether to prune the model. Default is False. - """ - ckpt_file = Path(ckpt_file) - - # Split the dataset into train/validation - train_ds, val_ds = dataset.split_train_val(val_frac=val_frac) - - # Setup the data module - data_module = EdgeDataModule( - train_ds=train_ds, - val_ds=val_ds, - test_ds=test_dataset, - batch_size=batch_size, - num_workers=0, - shuffle=True, +@attr.s +class CultionetParams: + ckpt_file: T.Union[str, Path] = attr.ib(converter=Path, default=None) + spatial_partitions: str = attr.ib(default=None) + dataset: EdgeDataset = attr.ib(default=None) + test_dataset: T.Optional[EdgeDataset] = attr.ib(default=None) + val_frac: float = attr.ib(converter=float, default=0.2) + batch_size: int = attr.ib(converter=int, default=4) + load_batch_workers: int = attr.ib(converter=int, default=0) + edge_class: int = attr.ib(converter=int, default=None) + class_counts: torch.Tensor = attr.ib(default=None) + hidden_channels: int = attr.ib(converter=int, default=64) + model_type: str = attr.ib(converter=str, default=ModelTypes.TOWERUNET) + activation_type: str = attr.ib(converter=str, default="SiLU") + dropout: float = attr.ib(converter=float, default=0.1) + dilations: T.Union[int, T.Sequence[int]] = attr.ib( + converter=list, default=None ) - lit_model = MaskRCNNLitModel( - cultionet_model_file=ckpt_file.parent / "cultionet.pt", - cultionet_num_features=train_ds.num_features, - cultionet_num_time_features=train_ds.num_time_features, - cultionet_filters=filters, - cultionet_num_classes=num_classes, - learning_rate=learning_rate, - weight_decay=weight_decay, - resize_height=resize_height, - resize_width=resize_width, - min_image_size=min_image_size, - max_image_size=max_image_size, - trainable_backbone_layers=trainable_backbone_layers, + res_block_type: str = attr.ib(converter=str, default=ResBlockTypes.RESA) + attention_weights: str = attr.ib(default=None) + optimizer: str = attr.ib(converter=str, default="AdamW") + loss_name: str = attr.ib( + converter=str, default=LossTypes.TANIMOTO_COMPLEMENT ) - - if reset_model: - if ckpt_file.is_file(): - ckpt_file.unlink() - model_file = ckpt_file.parent / "maskrcnn.pt" - if model_file.is_file(): - model_file.unlink() - - # Checkpoint - cb_train_loss = ModelCheckpoint( - dirpath=ckpt_file.parent, - filename=ckpt_file.stem, - save_last=True, - save_top_k=save_top_k, - mode="min", - monitor="loss", - every_n_train_steps=0, - every_n_epochs=1, + learning_rate: float = attr.ib(converter=float, default=0.01) + lr_scheduler: str = attr.ib( + converter=str, default=LearningRateSchedulers.ONE_CYCLE_LR ) - # Validation and test loss - cb_val_loss = ModelCheckpoint(monitor="val_loss") - # Early stopping - early_stop_callback = EarlyStopping( - monitor="val_loss", - min_delta=early_stopping_min_delta, - patience=early_stopping_patience, - mode="min", - check_on_train_epoch_end=False, + steplr_step_size: int = attr.ib(converter=int, default=5) + weight_decay: float = attr.ib(converter=float, default=1e-3) + eps: float = attr.ib(converter=float, default=1e-4) + ckpt_name: str = attr.ib(converter=str, default="last") + model_name: str = attr.ib(converter=str, default="cultionet") + pool_by_max: bool = attr.ib(default=False) + batchnorm_first: bool = attr.ib(default=False) + scale_pos_weight: bool = attr.ib(default=False) + save_batch_val_metrics: bool = attr.ib(default=False) + epochs: int = attr.ib(converter=int, default=100) + accumulate_grad_batches: int = attr.ib(converter=int, default=1) + gradient_clip_val: float = attr.ib(converter=float, default=1.0) + gradient_clip_algorithm: str = attr.ib(converter=str, default="norm") + precision: T.Union[int, str] = attr.ib(default="16-mixed") + device: str = attr.ib(converter=str, default="gpu") + devices: int = attr.ib(converter=int, default=1) + reset_model: bool = attr.ib(default=False) + auto_lr_find: bool = attr.ib(default=False) + stochastic_weight_averaging: bool = attr.ib(default=False) + stochastic_weight_averaging_lr: float = attr.ib( + converter=float, default=0.05 ) - # Learning rate - lr_monitor = LearningRateMonitor(logging_interval="step") - callbacks = [lr_monitor, cb_train_loss, cb_val_loss, early_stop_callback] - if stochastic_weight_averaging: - callbacks.append( - StochasticWeightAveraging( - swa_lrs=stochastic_weight_averaging_lr, - swa_epoch_start=stochastic_weight_averaging_start, - ) + stochastic_weight_averaging_start: float = attr.ib( + converter=float, default=0.8 + ) + model_pruning: bool = attr.ib(default=False) + skip_train: bool = attr.ib(default=False) + finetune: str = attr.ib(default=None) + strategy: str = attr.ib(converter=str, default="ddp") + profiler: str = attr.ib(default=None) + + def check_checkpoint(self) -> None: + if self.reset_model: + if self.ckpt_file.is_file(): + self.ckpt_file.unlink() + + model_file = self.ckpt_file.parent / f"{self.model_name}.pt" + if model_file.is_file(): + model_file.unlink() + + def update_channels( + self, data_module: EdgeDataModule + ) -> "CultionetParams": + self.in_channels = data_module.train_ds.num_channels + self.in_time = data_module.train_ds.num_time + + return self + + def get_callback_params(self) -> dict: + return dict( + ckpt_file=self.ckpt_file, + stochastic_weight_averaging=self.stochastic_weight_averaging, + stochastic_weight_averaging_lr=self.stochastic_weight_averaging_lr, + stochastic_weight_averaging_start=self.stochastic_weight_averaging_start, + model_pruning=self.model_pruning, ) - if 0 < model_pruning <= 1: - callbacks.append(ModelPruning("l1_unstructured", amount=model_pruning)) - trainer = pl.Trainer( - default_root_dir=str(ckpt_file.parent), - callbacks=callbacks, - enable_checkpointing=True, - auto_lr_find=auto_lr_find, - auto_scale_batch_size=False, - accumulate_grad_batches=accumulate_grad_batches, - gradient_clip_val=gradient_clip_val, - gradient_clip_algorithm="value", - check_val_every_n_epoch=1, - min_epochs=5 if epochs >= 5 else epochs, - max_epochs=epochs, - precision=precision, - devices=None if device == "cpu" else devices, - num_processes=0, - accelerator=device, - log_every_n_steps=50, - profiler=None, - deterministic=False, - benchmark=False, - ) + def get_datamodule_params(self) -> dict: + return dict( + dataset=self.dataset, + test_dataset=self.test_dataset, + val_frac=self.val_frac, + spatial_partitions=self.spatial_partitions, + batch_size=self.batch_size, + load_batch_workers=self.load_batch_workers, + ) - if auto_lr_find: - trainer.tune(model=lit_model, datamodule=data_module) - else: - trainer.fit( - model=lit_model, - datamodule=data_module, - ckpt_path=ckpt_file if ckpt_file.is_file() else None, + def get_lightning_params(self) -> dict: + return dict( + in_channels=self.in_channels, + in_time=self.in_time, + hidden_channels=self.hidden_channels, + model_type=self.model_type, + dropout=self.dropout, + activation_type=self.activation_type, + dilations=self.dilations, + res_block_type=self.res_block_type, + attention_weights=self.attention_weights, + optimizer=self.optimizer, + loss_name=self.loss_name, + learning_rate=self.learning_rate, + lr_scheduler=self.lr_scheduler, + steplr_step_size=self.steplr_step_size, + weight_decay=self.weight_decay, + eps=self.eps, + ckpt_name=self.ckpt_name, + model_name=self.model_name, + pool_by_max=self.pool_by_max, + batchnorm_first=self.batchnorm_first, + class_counts=self.class_counts, + edge_class=self.edge_class, + scale_pos_weight=self.scale_pos_weight, + save_batch_val_metrics=self.save_batch_val_metrics, ) - if test_dataset is not None: - trainer.test( - model=lit_model, - dataloaders=data_module.test_dataloader(), - ckpt_path="last", - ) + def get_trainer_params(self) -> dict: + return dict( + default_root_dir=str(self.ckpt_file.parent), + enable_checkpointing=True, + accumulate_grad_batches=self.accumulate_grad_batches, + gradient_clip_val=self.gradient_clip_val, + gradient_clip_algorithm=self.gradient_clip_algorithm, + check_val_every_n_epoch=1, + min_epochs=5 if self.epochs >= 5 else self.epochs, + max_epochs=self.epochs, + precision=self.precision, + devices=self.devices, + accelerator=self.device, + log_every_n_steps=50, + deterministic=False, + benchmark=False, + strategy=self.strategy, + profiler=self.profiler, + ) -def fit( + +def get_data_module( dataset: EdgeDataset, - ckpt_file: T.Union[str, Path], test_dataset: T.Optional[EdgeDataset] = None, val_frac: T.Optional[float] = 0.2, spatial_partitions: T.Optional[T.Union[str, Path]] = None, - partition_name: T.Optional[str] = None, - partition_column: T.Optional[str] = None, batch_size: T.Optional[int] = 4, load_batch_workers: T.Optional[int] = 2, - accumulate_grad_batches: T.Optional[int] = 1, - filters: T.Optional[int] = 32, - num_classes: T.Optional[int] = 2, - edge_class: T.Optional[int] = None, - class_counts: T.Sequence[float] = None, - model_type: str = "ResUNet3Psi", - activation_type: str = "SiLU", - dilations: T.Union[int, T.Sequence[int]] = None, - res_block_type: str = "resa", - attention_weights: str = "spatial_channel", - deep_sup_dist: bool = False, - deep_sup_edge: bool = False, - deep_sup_mask: bool = False, - optimizer: str = "AdamW", - learning_rate: T.Optional[float] = 1e-3, - lr_scheduler: str = "CosineAnnealingLR", - steplr_step_size: T.Optional[T.Sequence[int]] = None, - scale_pos_weight: T.Optional[bool] = True, - epochs: T.Optional[int] = 30, - save_top_k: T.Optional[int] = 1, - early_stopping_patience: T.Optional[int] = 7, - early_stopping_min_delta: T.Optional[float] = 0.01, - gradient_clip_val: T.Optional[float] = 1.0, - gradient_clip_algorithm: T.Optional[float] = "norm", - reset_model: T.Optional[bool] = False, - auto_lr_find: T.Optional[bool] = False, - device: T.Optional[str] = "gpu", - devices: T.Optional[int] = 1, - profiler: T.Optional[str] = None, - weight_decay: T.Optional[float] = 1e-5, - precision: T.Optional[int] = 32, - stochastic_weight_averaging: T.Optional[bool] = False, - stochastic_weight_averaging_lr: T.Optional[float] = 0.05, - stochastic_weight_averaging_start: T.Optional[float] = 0.8, - model_pruning: T.Optional[bool] = False, - save_batch_val_metrics: T.Optional[bool] = False, - skip_train: T.Optional[bool] = False, - refine_model: T.Optional[bool] = False, -): - """Fits a model. - - Args: - dataset (EdgeDataset): The dataset to fit on. - ckpt_file (str | Path): The checkpoint file path. - test_dataset (Optional[EdgeDataset]): A test dataset to evaluate on. If given, early stopping - will switch from the validation dataset to the test dataset. - val_frac (Optional[float]): The fraction of data to use for model validation. - spatial_partitions (Optional[str | Path]): A spatial partitions file. - partition_name (Optional[str]): The spatial partition file column query name. - partition_column (Optional[str]): The spatial partition file column name. - batch_size (Optional[int]): The data batch size. - load_batch_workers (Optional[int]): The number of parallel batches to load. - filters (Optional[int]): The number of initial model filters. - optimizer (Optional[str]): The optimizer. - model_type (Optional[str]): The model type. - activation_type (Optional[str]): The activation type. - dilations (Optional[list]): The dilation size or sizes. - res_block_type (Optional[str]): The residual block type. - attention_weights (Optional[str]): The attention weights. - deep_sup_dist (Optional[bool]): Whether to use deep supervision for distances. - deep_sup_edge (Optional[bool]): Whether to use deep supervision for edges. - deep_sup_mask (Optional[bool]): Whether to use deep supervision for masks. - learning_rate (Optional[float]): The model learning rate. - lr_scheduler (Optional[str]): The learning rate scheduler. - steplr_step_size (Optional[list]): The multiplicative step size factor. - scale_pos_weight (Optional[bool]): Whether to scale class weights (i.e., balance classes). - epochs (Optional[int]): The number of epochs. - save_top_k (Optional[int]): The number of top-k model checkpoints to save. - early_stopping_patience (Optional[int]): The patience (epochs) before early stopping. - early_stopping_min_delta (Optional[float]): The minimum change threshold before early stopping. - gradient_clip_val (Optional[float]): The gradient clip limit. - gradient_clip_algorithm (Optional[str]): The gradient clip algorithm. - reset_model (Optional[bool]): Whether to reset an existing model. Otherwise, pick up from last epoch of - an existing model. - auto_lr_find (Optional[bool]): Whether to search for an optimized learning rate. - device (Optional[str]): The device to train on. Choices are ['cpu', 'gpu']. - devices (Optional[int]): The number of GPU devices to use. - profiler (Optional[str]): A profiler level. Choices are [None, 'simple', 'advanced']. - weight_decay (Optional[float]): The weight decay passed to the optimizer. Default is 1e-5. - precision (Optional[int]): The data precision. Default is 32. - stochastic_weight_averaging (Optional[bool]): Whether to use stochastic weight averaging. - Default is False. - stochastic_weight_averaging_lr (Optional[float]): The stochastic weight averaging learning rate. - Default is 0.05. - stochastic_weight_averaging_start (Optional[float]): The stochastic weight averaging epoch start. - Default is 0.8. - model_pruning (Optional[bool]): Whether to prune the model. Default is False. - save_batch_val_metrics (Optional[bool]): Whether to save batch validation metrics to a parquet file. - skip_train (Optional[bool]): Whether to refine and calibrate a trained model. - refine_model (Optional[bool]): Whether to skip training. - """ - ckpt_file = Path(ckpt_file) - +) -> EdgeDataModule: # Split the dataset into train/validation if spatial_partitions is not None: # TODO: We removed `dataset.split_train_val_by_partition` but # could make it an option in future versions. train_ds, val_ds = dataset.split_train_val( - val_frac=val_frac, spatial_overlap_allowed=False + val_frac=val_frac, + spatial_overlap_allowed=False, + spatial_balance=True, ) else: train_ds, val_ds = dataset.split_train_val(val_frac=val_frac) @@ -332,177 +216,104 @@ def fit( shuffle=True, ) + return data_module + + +def fit_transfer(cultionet_params: CultionetParams) -> None: + """Fits a transfer model.""" + + # This file should already exist + pretrained_ckpt_file = ( + cultionet_params.ckpt_file.parent / ModelNames.CKPT_TRANSFER_NAME + ) + assert ( + pretrained_ckpt_file.exists() + ), "The pretrained checkpoint does not exist." + + # Remove the spatial data because there is no check upstream + if cultionet_params.dataset.grid_gpkg_path.exists(): + cultionet_params.dataset.grid_gpkg_path.unlink() + + # Split the dataset into train/validation + data_module: EdgeDataModule = get_data_module( + **cultionet_params.get_datamodule_params() + ) + + # Get the channel and time dimensions from the dataset + cultionet_params = cultionet_params.update_channels(data_module) + # Setup the Lightning model - lit_model = CultioLitModel( - num_features=train_ds.num_features, - num_time_features=train_ds.num_time_features, - num_classes=num_classes, - filters=filters, - model_type=model_type, - activation_type=activation_type, - dilations=dilations, - res_block_type=res_block_type, - attention_weights=attention_weights, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, - optimizer=optimizer, - learning_rate=learning_rate, - lr_scheduler=lr_scheduler, - steplr_step_size=steplr_step_size, - weight_decay=weight_decay, - class_counts=class_counts, - edge_class=edge_class, - scale_pos_weight=scale_pos_weight, - save_batch_val_metrics=save_batch_val_metrics, + lit_model = CultionetLitTransferModel( + pretrained_ckpt_file=pretrained_ckpt_file, + finetune=cultionet_params.finetune, + **cultionet_params.get_lightning_params(), ) - if reset_model: - if ckpt_file.is_file(): - ckpt_file.unlink() - model_file = ckpt_file.parent / "cultionet.pt" - if model_file.is_file(): - model_file.unlink() - - # Checkpoint - cb_train_loss = ModelCheckpoint(monitor="loss") - # Validation and test loss - cb_val_loss = ModelCheckpoint( - dirpath=ckpt_file.parent, - filename=ckpt_file.stem, - save_last=True, - save_top_k=save_top_k, - mode="min", - monitor="val_score", - every_n_train_steps=0, - every_n_epochs=1, + # Remove the model file if requested + cultionet_params.check_checkpoint() + + _, callbacks = setup_callbacks(**cultionet_params.get_callback_params()) + callbacks.append(PROGRESS_BAR_CALLBACK) + + # Setup the trainer + trainer = L.Trainer( + callbacks=callbacks, + **cultionet_params.get_trainer_params(), ) - # Early stopping - early_stop_callback = EarlyStopping( - monitor="val_score", - min_delta=early_stopping_min_delta, - patience=early_stopping_patience, - mode="min", - check_on_train_epoch_end=False, + + trainer.fit( + model=lit_model, + datamodule=data_module, + ckpt_path=cultionet_params.ckpt_file + if cultionet_params.ckpt_file.exists() + else None, ) - # Learning rate - lr_monitor = LearningRateMonitor(logging_interval="epoch") - callbacks = [lr_monitor, cb_train_loss, cb_val_loss, early_stop_callback] - if stochastic_weight_averaging: - callbacks.append( - StochasticWeightAveraging( - swa_lrs=stochastic_weight_averaging_lr, - swa_epoch_start=stochastic_weight_averaging_start, - ) - ) - if 0 < model_pruning <= 1: - callbacks.append(ModelPruning("l1_unstructured", amount=model_pruning)) - trainer = pl.Trainer( - default_root_dir=str(ckpt_file.parent), + +def fit(cultionet_params: CultionetParams) -> None: + """Fits a model.""" + + # Split the dataset into train/validation + data_module: EdgeDataModule = get_data_module( + **cultionet_params.get_datamodule_params() + ) + + # Get the channel and time dimensions from the dataset + cultionet_params = cultionet_params.update_channels(data_module) + + # Setup the Lightning model + lit_model = CultionetLitModel(**cultionet_params.get_lightning_params()) + + # Remove the model file if requested + cultionet_params.check_checkpoint() + + lr_monitor, callbacks = setup_callbacks( + **cultionet_params.get_callback_params() + ) + callbacks.append(PROGRESS_BAR_CALLBACK) + + # Setup the trainer + trainer = L.Trainer( callbacks=callbacks, - enable_checkpointing=True, - auto_lr_find=auto_lr_find, - auto_scale_batch_size=False, - accumulate_grad_batches=accumulate_grad_batches, - gradient_clip_val=gradient_clip_val, - gradient_clip_algorithm=gradient_clip_algorithm, - check_val_every_n_epoch=1, - min_epochs=5 if epochs >= 5 else epochs, - max_epochs=epochs, - precision=precision, - devices=None if device == "cpu" else devices, - num_processes=0, - accelerator=device, - log_every_n_steps=50, - profiler=profiler, - deterministic=False, - benchmark=False, + **cultionet_params.get_trainer_params(), ) - if auto_lr_find: - trainer.tune(model=lit_model, datamodule=data_module) + if cultionet_params.auto_lr_find: + tuner = Tuner(trainer) + lr_finder = tuner.lr_find(model=lit_model, datamodule=data_module) + opt_lr = lr_finder.suggestion() + logger.info(f"The suggested learning rate is {opt_lr}") else: - if not skip_train: + if not cultionet_params.skip_train: trainer.fit( model=lit_model, datamodule=data_module, - ckpt_path=ckpt_file if ckpt_file.is_file() else None, - ) - if refine_model: - refine_data_module = EdgeDataModule( - train_ds=dataset, - batch_size=batch_size, - num_workers=load_batch_workers, - shuffle=True, - # For each epoch, train on a random - # subset of 50% of the data. - sampler=EpochRandomSampler( - dataset, num_samples=int(len(dataset) * 0.5) - ), - ) - refine_ckpt_file = ckpt_file.parent / "refine" / ckpt_file.name - refine_ckpt_file.parent.mkdir(parents=True, exist_ok=True) - # refine checkpoints - refine_cb_train_loss = ModelCheckpoint( - dirpath=refine_ckpt_file.parent, - filename=refine_ckpt_file.stem, - save_last=True, - save_top_k=save_top_k, - mode="min", - monitor="loss", - every_n_train_steps=0, - every_n_epochs=1, - ) - # Early stopping - refine_early_stop_callback = EarlyStopping( - monitor="loss", - min_delta=early_stopping_min_delta, - patience=5, - mode="min", - check_on_train_epoch_end=False, - ) - refine_callbacks = [ - lr_monitor, - refine_cb_train_loss, - refine_early_stop_callback, - ] - refine_trainer = pl.Trainer( - default_root_dir=str(refine_ckpt_file.parent), - callbacks=refine_callbacks, - enable_checkpointing=True, - auto_lr_find=auto_lr_find, - auto_scale_batch_size=False, - gradient_clip_val=gradient_clip_val, - gradient_clip_algorithm="value", - check_val_every_n_epoch=1, - min_epochs=1 if epochs >= 1 else epochs, - max_epochs=10, - precision=32, - devices=None if device == "cpu" else devices, - num_processes=0, - accelerator=device, - log_every_n_steps=50, - profiler=profiler, - deterministic=False, - benchmark=False, - ) - # Calibrate the logits - refine_model = RefineLitModel( - in_features=train_ds.num_features, - num_classes=num_classes, - edge_class=edge_class, - class_counts=class_counts, - cultionet_ckpt=ckpt_file, - ) - refine_trainer.fit( - model=refine_model, - datamodule=refine_data_module, - ckpt_path=refine_ckpt_file - if refine_ckpt_file.is_file() + ckpt_path=cultionet_params.ckpt_file + if cultionet_params.ckpt_file.exists() else None, ) - if test_dataset is not None: + + if cultionet_params.test_dataset is not None: trainer.test( model=lit_model, dataloaders=data_module.test_dataloader(), @@ -522,24 +333,29 @@ def load_model( model_file: T.Union[str, Path] = None, num_features: T.Optional[int] = None, num_time_features: T.Optional[int] = None, - num_classes: T.Optional[int] = None, filters: T.Optional[int] = None, device: T.Union[str, bytes] = "gpu", devices: T.Optional[int] = 1, - lit_model: T.Optional[CultioLitModel] = None, + lit_model: T.Optional[CultionetLitModel] = None, enable_progress_bar: T.Optional[bool] = True, return_trainer: T.Optional[bool] = False, -) -> T.Tuple[T.Union[None, pl.Trainer], CultioLitModel]: +) -> T.Tuple[T.Union[None, L.Trainer], CultionetLitModel]: """Loads a model from file. - Args: - ckpt_file (str | Path): The model checkpoint file. - model_file (str | Path): The model file. - device (str): The device to apply inference on. - lit_model (CultioLitModel): A model to predict with. If `None`, the model - is loaded from file. - enable_progress_bar (Optional[bool]): Whether to use the progress bar. - return_trainer (Optional[bool]): Whether to return the `pytorch_lightning` `Trainer`. + Parameters + ========== + ckpt_file + The model checkpoint file. + model_file + The model file. + device + The device to apply inference on. + lit_model + A model to predict with. If `None`, the model is loaded from file. + enable_progress_bar + Whether to use the progress bar. + return_trainer + Whether to return the `lightning` `Trainer`. """ if ckpt_file is not None: ckpt_file = Path(ckpt_file) @@ -551,19 +367,14 @@ def load_model( trainer_kwargs = dict( default_root_dir=str(ckpt_file.parent), precision=32, - devices=None if device == "cpu" else devices, - gpus=1 if device == "gpu" else None, + devices=devices, accelerator=device, - num_processes=0, log_every_n_steps=0, logger=False, enable_progress_bar=enable_progress_bar, ) - if trainer_kwargs["accelerator"] == "cpu": - del trainer_kwargs["devices"] - del trainer_kwargs["gpus"] - trainer = pl.Trainer(**trainer_kwargs) + trainer = L.Trainer(**trainer_kwargs) if lit_model is None: if model_file is not None: @@ -574,16 +385,15 @@ def load_model( raise TypeError( "The features must be given to load the model file." ) - lit_model = CultioLitModel( + lit_model = CultionetLitModel( num_features=num_features, num_time_features=num_time_features, filters=filters, - num_classes=num_classes, ) lit_model.load_state_dict(state_dict=torch.load(model_file)) else: assert ckpt_file.is_file(), "The checkpoint file does not exist." - lit_model = CultioLitModel.load_from_checkpoint( + lit_model = CultionetLitModel.load_from_checkpoint( checkpoint_path=str(ckpt_file) ) lit_model.eval() @@ -597,21 +407,20 @@ def predict_lightning( out_path: T.Union[str, Path], ckpt: Path, dataset: EdgeDataset, - batch_size: int, - load_batch_workers: int, - device: str, - devices: int, - precision: int, - num_classes: int, - resampling: str, - ref_res: float, - compression: str, - refine_pt: T.Optional[Path] = None, + device: str = "gpu", + devices: int = 1, + strategy: str = "ddp", + batch_size: int = 4, + load_batch_workers: int = 0, + precision: T.Union[int, str] = "16-mixed", + resampling: str = "nearest", + compression: str = "lzw", + is_transfer_model: bool = False, ): reference_image = Path(reference_image) out_path = Path(out_path) ckpt_file = Path(ckpt) - assert ckpt_file.is_file(), "The checkpoint file does not exist." + assert ckpt_file.exists(), "The checkpoint file does not exist." data_module = EdgeDataModule( predict_ds=dataset, @@ -622,37 +431,33 @@ def predict_lightning( pred_writer = LightningGTiffWriter( reference_image=reference_image, out_path=out_path, - num_classes=num_classes, - ref_res=ref_res, resampling=resampling, compression=compression, ) trainer_kwargs = dict( default_root_dir=str(ckpt_file.parent), - callbacks=[pred_writer], + callbacks=[pred_writer, PROGRESS_BAR_CALLBACK], precision=precision, - devices=None if device == "cpu" else devices, - gpus=1 if device == "gpu" else None, + devices=devices, accelerator=device, - num_processes=0, + strategy=strategy, log_every_n_steps=0, logger=False, ) - trainer = pl.Trainer(**trainer_kwargs) - cultionet_lit_model = CultioLitModel.load_from_checkpoint( - checkpoint_path=str(ckpt_file) - ) + trainer = L.Trainer(**trainer_kwargs) - geo_refine_model = None - if refine_pt is not None: - if refine_pt.is_file(): - geo_refine_model = GeoRefinement( - in_features=dataset.num_features, out_channels=num_classes - ) - geo_refine_model.load_state_dict(torch.load(refine_pt)) - geo_refine_model.eval() - setattr(cultionet_lit_model, "temperature_lit_model", geo_refine_model) + if is_transfer_model: + pretrained_ckpt_file = ckpt.parent / ModelNames.CKPT_TRANSFER_NAME + + cultionet_lit_model = CultionetLitTransferModel.load_from_checkpoint( + checkpoint_path=str(ckpt_file), + pretrained_ckpt_file=pretrained_ckpt_file, + ) + else: + cultionet_lit_model = CultionetLitModel.load_from_checkpoint( + checkpoint_path=str(ckpt_file) + ) # Make predictions trainer.predict( @@ -660,140 +465,3 @@ def predict_lightning( datamodule=data_module, return_predictions=False, ) - - -def predict( - lit_model: CultioLitModel, - data: Data, - written: np.ndarray, - data_values: torch.Tensor, - w: Window = None, - w_pad: Window = None, - device: str = "cpu", - include_maskrcnn: bool = False, -) -> np.ndarray: - """Applies a model to predict image labels|values. - - Args: - lit_model (CultioLitModel): A model to predict with. - data (Data): The data to predict on. - written (ndarray) - data_values (Tensor) - w (Optional[int]): The ``rasterio.windows.Window`` to write to. - w_pad (Optional[int]): The ``rasterio.windows.Window`` to predict on. - device (Optional[str]) - """ - norm_batch = zscores(data, data_values.mean, data_values.std) - if device == "gpu": - norm_batch = norm_batch.to("cuda") - lit_model = lit_model.to("cuda") - with torch.no_grad(): - distance, dist_1, dist_2, dist_3, dist_4, edge, crop = lit_model( - norm_batch - ) - crop_type = torch.zeros((crop.size(0), 2), dtype=crop.dtype) - - if include_maskrcnn: - # TODO: fix this -- separate Mask R-CNN model - predictions = lit_model.mask_forward( - distance=distance, - edge=edge, - height=norm_batch.height, - width=norm_batch.width, - batch=None, - ) - instances = None - if include_maskrcnn: - instances = np.zeros( - (norm_batch.height, norm_batch.width), dtype="float64" - ) - if include_maskrcnn: - scores = predictions[0]["scores"].squeeze() - masks = predictions[0]["masks"].squeeze() - resizer = transforms.Resize((norm_batch.height, norm_batch.width)) - masks = resizer(masks) - # Filter by box scores - masks = masks[scores > 0.5] - scores = scores[scores > 0.5] - # Filter by pixel scores - masks = torch.where(masks > 0.5, masks, 0) - masks = masks.detach().cpu().numpy() - if masks.shape[0] > 0: - distance_mask = ( - distance.detach() - .cpu() - .numpy() - .reshape(norm_batch.height, norm_batch.width) - ) - edge_mask = ( - edge[:, 1] - .detach() - .cpu() - .numpy() - .reshape(norm_batch.height, norm_batch.width) - ) - crop_mask = ( - crop[:, 1] - .detach() - .cpu() - .numpy() - .reshape(norm_batch.height, norm_batch.width) - ) - instances = np.zeros( - (norm_batch.height, norm_batch.width), dtype="float64" - ) - - uid = 1 if written.max() == 0 else written.max() + 1 - - def iou(reference, targets): - tp = ((reference > 0.5) & (targets > 0.5)).sum() - fp = ((reference <= 0.5) & (targets > 0.5)).sum() - fn = ((reference > 0.5) & (targets <= 0.5)).sum() - - return tp / (tp + fp + fn) - - for lyr_idx_ref, lyr_ref in enumerate(masks): - lyr = None - for lyr_idx_targ, lyr_targ in enumerate(masks): - if lyr_idx_targ != lyr_idx_ref: - if iou(lyr_ref, lyr_targ) > 0.5: - lyr = ( - lyr_ref - if scores[lyr_idx_ref] - > scores[lyr_idx_targ] - else lyr_targ - ) - if lyr is None: - lyr = lyr_ref - conditional = ( - (lyr > 0.5) - & (distance_mask > 0.1) - & (edge_mask < 0.5) - & (crop_mask > 0.5) - ) - if written[conditional].max() > 0: - uid = int(sci_mode(written[conditional]).mode) - instances = np.where( - ((instances == 0) & conditional), uid, instances - ) - uid = instances.max() + 1 - instances /= SCALE_FACTOR - else: - logger.warning("No fields were identified.") - - mo = ModelOutputs( - distance=distance, - edge=edge, - crop=crop, - crop_type=crop_type, - instances=instances, - apply_softmax=False, - ) - stack = mo.stack_outputs(w, w_pad) - if include_maskrcnn: - stack[:-1] = (stack[:-1] * SCALE_FACTOR).clip(0, SCALE_FACTOR) - stack[-1] *= SCALE_FACTOR - else: - stack = (stack * SCALE_FACTOR).clip(0, SCALE_FACTOR) - - return stack diff --git a/src/cultionet/models/base_layers.py b/src/cultionet/models/base_layers.py deleted file mode 100644 index f3c4a7fa..00000000 --- a/src/cultionet/models/base_layers.py +++ /dev/null @@ -1,1532 +0,0 @@ -import typing as T -import enum - -import torch -import torch.nn.functional as F -from torch_geometric import nn - -from . import model_utils -from .enums import ResBlockTypes - - -class Swish(torch.nn.Module): - def __init__(self, channels: int, dims: int): - super(Swish, self).__init__() - - self.sigmoid = torch.nn.Sigmoid() - self.beta = torch.nn.Parameter(torch.ones(1)) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x * self.sigmoid(self.beta * x) - - def reset_parameters(self): - torch.nn.init.ones_(self.beta) - - -class SetActivation(torch.nn.Module): - def __init__( - self, - activation_type: str, - channels: T.Optional[int] = None, - dims: T.Optional[int] = None, - ): - """ - Examples: - >>> act = SetActivation('ReLU') - >>> act(x) - >>> - >>> act = SetActivation('LeakyReLU') - >>> act(x) - >>> - >>> act = SetActivation('Swish', channels=32) - >>> act(x) - """ - super(SetActivation, self).__init__() - - if activation_type == "Swish": - assert isinstance( - channels, int - ), "Swish requires the input channels." - assert isinstance( - dims, int - ), "Swish requires the tensor dimension." - self.activation = Swish(channels=channels, dims=dims) - else: - self.activation = getattr(torch.nn, activation_type)(inplace=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.activation(x) - - -class LogSoftmax(torch.nn.Module): - def __init__(self, dim: int = 1): - super(LogSoftmax, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.log_softmax(x, dim=self.dim, dtype=x.dtype) - - -class Softmax(torch.nn.Module): - def __init__(self, dim: int = 1): - super(Softmax, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.softmax(x, dim=self.dim, dtype=x.dtype) - - -class Permute(torch.nn.Module): - def __init__(self, axis_order: T.Sequence[int]): - super(Permute, self).__init__() - self.axis_order = axis_order - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.permute(*self.axis_order) - - -class Add(torch.nn.Module): - def __init__(self): - super(Add, self).__init__() - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x + y - - -class Min(torch.nn.Module): - def __init__(self, dim: int, keepdim: bool = False): - super(Min, self).__init__() - - self.dim = dim - self.keepdim = keepdim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.min(dim=self.dim, keepdim=self.keepdim)[0] - - -class Max(torch.nn.Module): - def __init__(self, dim: int, keepdim: bool = False): - super(Max, self).__init__() - - self.dim = dim - self.keepdim = keepdim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.max(dim=self.dim, keepdim=self.keepdim)[0] - - -class Mean(torch.nn.Module): - def __init__(self, dim: int, keepdim: bool = False): - super(Mean, self).__init__() - - self.dim = dim - self.keepdim = keepdim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.mean(dim=self.dim, keepdim=self.keepdim) - - -class Var(torch.nn.Module): - def __init__( - self, dim: int, keepdim: bool = False, unbiased: bool = False - ): - super(Var, self).__init__() - - self.dim = dim - self.keepdim = keepdim - self.unbiased = unbiased - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.var( - dim=self.dim, keepdim=self.keepdim, unbiased=self.unbiased - ) - - -class Std(torch.nn.Module): - def __init__( - self, dim: int, keepdim: bool = False, unbiased: bool = False - ): - super(Std, self).__init__() - - self.dim = dim - self.keepdim = keepdim - self.unbiased = unbiased - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.std( - dim=self.dim, keepdim=self.keepdim, unbiased=self.unbiased - ) - - -class Squeeze(torch.nn.Module): - def __init__(self, dim: T.Optional[int] = None): - super(Squeeze, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.squeeze(dim=self.dim) - - -class Unsqueeze(torch.nn.Module): - def __init__(self, dim: int): - super(Unsqueeze, self).__init__() - - self.dim = dim - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return x.unsqueeze(self.dim) - - -class SigmoidCrisp(torch.nn.Module): - r"""Sigmoid crisp. - - Adapted from publication and source code below: - - CSIRO BSTD/MIT LICENSE - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that - the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the - following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or other materials provided with the distribution. - 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or - promote products derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, - INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - Citation: - @article{diakogiannis_etal_2021, - title={Looking for change? Roll the dice and demand attention}, - author={Diakogiannis, Foivos I and Waldner, Fran{\c{c}}ois and Caccetta, Peter}, - journal={Remote Sensing}, - volume={13}, - number={18}, - pages={3707}, - year={2021}, - publisher={MDPI} - } - - Reference: - https://www.mdpi.com/2072-4292/13/18/3707 - https://arxiv.org/pdf/2009.02062.pdf - https://github.com/waldnerf/decode/blob/main/FracTAL_ResUNet/nn/activations/sigmoid_crisp.py - """ - - def __init__(self, smooth: float = 1e-2): - super(SigmoidCrisp, self).__init__() - - self.smooth = smooth - self.gamma = torch.nn.Parameter(torch.ones(1)) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out = self.smooth + self.sigmoid(self.gamma) - out = torch.reciprocal(out) - out = x * out - out = self.sigmoid(out) - - return out - - -class ConvBlock2d(torch.nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - padding: int = 0, - dilation: int = 1, - add_activation: bool = True, - activation_type: str = "LeakyReLU", - ): - super(ConvBlock2d, self).__init__() - - layers = [ - torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - bias=False, - ), - torch.nn.BatchNorm2d(out_channels), - ] - if add_activation: - layers += [ - SetActivation(activation_type, channels=out_channels, dims=2) - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class ResBlock2d(torch.nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - padding: int = 0, - dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(ResBlock2d, self).__init__() - - layers = [ - torch.nn.BatchNorm2d(in_channels), - SetActivation(activation_type, channels=in_channels, dims=2), - torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - ), - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class ConvBlock3d(torch.nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - in_time: int = 0, - padding: int = 0, - dilation: int = 1, - add_activation: bool = True, - squeeze: bool = False, - activation_type: str = "LeakyReLU", - ): - super(ConvBlock3d, self).__init__() - - layers = [ - torch.nn.Conv3d( - in_channels, - out_channels, - kernel_size=kernel_size, - padding=padding, - dilation=dilation, - bias=False, - ) - ] - if squeeze: - layers += [Squeeze(), torch.nn.BatchNorm2d(in_time)] - dims = 2 - else: - layers += [torch.nn.BatchNorm3d(out_channels)] - dims = 3 - if add_activation: - layers += [ - SetActivation( - activation_type, channels=out_channels, dims=dims - ) - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class AttentionAdd(torch.nn.Module): - def __init__(self): - super(AttentionAdd, self).__init__() - - self.up = model_utils.UpSample() - - def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: - if x.shape[-2:] != g.shape[-2:]: - x = self.up(x, size=g.shape[-2:], mode="bilinear") - - return x + g - - -class AttentionGate3d(torch.nn.Module): - def __init__(self, high_channels: int, low_channels: int): - super(AttentionGate3d, self).__init__() - - conv_x = torch.nn.Conv3d( - high_channels, high_channels, kernel_size=1, padding=0 - ) - conv_g = torch.nn.Conv3d( - low_channels, - high_channels, - kernel_size=1, - padding=0, - ) - conv1d = torch.nn.Conv3d(high_channels, 1, kernel_size=1, padding=0) - self.up = model_utils.UpSample() - - self.seq = nn.Sequential( - "x, g", - [ - (conv_x, "x -> x"), - (conv_g, "g -> g"), - (AttentionAdd(), "x, g -> x"), - (torch.nn.LeakyReLU(inplace=False), "x -> x"), - (conv1d, "x -> x"), - (torch.nn.Sigmoid(), "x -> x"), - ], - ) - self.final = ConvBlock3d( - in_channels=high_channels, - out_channels=high_channels, - kernel_size=1, - add_activation=False, - ) - - def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: - """ - Args: - x: Higher dimension - g: Lower dimension - """ - h = self.seq(x, g) - if h.shape[-2:] != x.shape[-2:]: - h = self.up(h, size=x.shape[-2:], mode="bilinear") - - return self.final(x * h) - - -class AttentionGate(torch.nn.Module): - def __init__(self, high_channels: int, low_channels: int): - super(AttentionGate, self).__init__() - - conv_x = torch.nn.Conv2d( - high_channels, high_channels, kernel_size=1, padding=0 - ) - conv_g = torch.nn.Conv2d( - low_channels, - high_channels, - kernel_size=1, - padding=0, - ) - conv1d = torch.nn.Conv2d(high_channels, 1, kernel_size=1, padding=0) - self.up = model_utils.UpSample() - - self.seq = nn.Sequential( - "x, g", - [ - (conv_x, "x -> x"), - (conv_g, "g -> g"), - (AttentionAdd(), "x, g -> x"), - (torch.nn.LeakyReLU(inplace=False), "x -> x"), - (conv1d, "x -> x"), - (torch.nn.Sigmoid(), "x -> x"), - ], - ) - self.final = ConvBlock2d( - in_channels=high_channels, - out_channels=high_channels, - kernel_size=1, - add_activation=False, - ) - - def forward(self, x: torch.Tensor, g: torch.Tensor) -> torch.Tensor: - """ - Args: - x: Higher dimension - g: Lower dimension - """ - h = self.seq(x, g) - if h.shape[-2:] != x.shape[-2:]: - h = self.up(h, size=x.shape[-2:], mode="bilinear") - - return self.final(x * h) - - -class TanimotoComplement(torch.nn.Module): - """Tanimoto distance with complement. - - THIS IS NOT CURRENTLY USED ANYWHERE IN THIS REPOSITORY - - Adapted from publications and source code below: - - CSIRO BSTD/MIT LICENSE - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that - the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the - following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or other materials provided with the distribution. - 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or - promote products derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, - INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - References: - https://www.mdpi.com/2072-4292/14/22/5738 - https://arxiv.org/abs/2009.02062 - https://github.com/waldnerf/decode/blob/main/FracTAL_ResUNet/nn/layers/ftnmt.py - """ - - def __init__( - self, - smooth: float = 1e-5, - depth: int = 5, - dim: T.Union[int, T.Sequence[int]] = 0, - targets_are_labels: bool = True, - ): - super(TanimotoComplement, self).__init__() - - self.smooth = smooth - self.depth = depth - self.dim = dim - self.targets_are_labels = targets_are_labels - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - """Performs a single forward pass. - - Args: - inputs: Predictions from model (probabilities or labels). - targets: Ground truth values. - - Returns: - Tanimoto distance loss (float) - """ - if self.depth == 1: - scale = 1.0 - else: - scale = 1.0 / self.depth - - def tanimoto(y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor: - tpl = torch.sum(y * yhat, dim=self.dim, keepdim=True) - numerator = tpl + self.smooth - sq_sum = torch.sum(y**2 + yhat**2, dim=self.dim, keepdim=True) - denominator = torch.zeros(1, dtype=inputs.dtype).to( - device=inputs.device - ) - for d in range(0, self.depth): - a = 2**d - b = -(2.0 * a - 1.0) - denominator = denominator + torch.reciprocal( - (a * sq_sum) + (b * tpl) + self.smooth - ) - - return numerator * denominator * scale - - l1 = tanimoto(targets, inputs) - l2 = tanimoto(1.0 - targets, 1.0 - inputs) - score = (l1 + l2) * 0.5 - - return score - - -class TanimotoDist(torch.nn.Module): - r"""Tanimoto distance. - - Adapted from publication and source code below: - - CSIRO BSTD/MIT LICENSE - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that - the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the - following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or other materials provided with the distribution. - 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or - promote products derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, - INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - Citation: - @article{diakogiannis_etal_2021, - title={Looking for change? Roll the dice and demand attention}, - author={Diakogiannis, Foivos I and Waldner, Fran{\c{c}}ois and Caccetta, Peter}, - journal={Remote Sensing}, - volume={13}, - number={18}, - pages={3707}, - year={2021}, - publisher={MDPI} - } - - References: - https://www.mdpi.com/2072-4292/13/18/3707 - https://arxiv.org/abs/2009.02062 - https://arxiv.org/pdf/2009.02062.pdf - https://github.com/waldnerf/decode/blob/9e922a2082e570e248eaee10f7a1f2f0bd852b42/FracTAL_ResUNet/nn/layers/ftnmt.py - - Adapted from source code below: - - MIT License - - Copyright (c) 2017-2020 Matej Aleksandrov, Matej Batič, Matic Lubej, Grega Milčinski (Sinergise) - Copyright (c) 2017-2020 Devis Peressutti, Jernej Puc, Anže Zupanc, Lojze Žust, Jovan Višnjić (Sinergise) - - Reference: - https://github.com/sentinel-hub/eo-flow/blob/master/eoflow/models/losses.py - """ - - def __init__( - self, - smooth: float = 1e-5, - weight: T.Optional[torch.Tensor] = None, - dim: T.Union[int, T.Sequence[int]] = 0, - ): - super(TanimotoDist, self).__init__() - - self.smooth = smooth - self.weight = weight - self.dim = dim - - def forward( - self, inputs: torch.Tensor, targets: torch.Tensor - ) -> torch.Tensor: - """Performs a single forward pass. - - Args: - inputs: Predictions from model (probabilities, logits or labels). - targets: Ground truth values. - - Returns: - Tanimoto distance loss (float) - """ - - def _tanimoto(yhat: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - tpl = torch.sum(yhat * y, dim=self.dim, keepdim=True) - sq_sum = torch.sum(yhat**2 + y**2, dim=self.dim, keepdim=True) - numerator = tpl + self.smooth - denominator = (sq_sum - tpl) + self.smooth - tanimoto_score = numerator / denominator - - return tanimoto_score - - score = _tanimoto(inputs, targets) - compl_score = _tanimoto(1.0 - inputs, 1.0 - targets) - score = (score + compl_score) * 0.5 - - return score - - -class FractalAttention(torch.nn.Module): - """Fractal Tanimoto Attention Layer (FracTAL) - - Adapted from publication and source code below: - - CSIRO BSTD/MIT LICENSE - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that - the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the - following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or other materials provided with the distribution. - 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or - promote products derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, - INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - Reference: - https://www.mdpi.com/2072-4292/13/18/3707 - https://arxiv.org/pdf/2009.02062.pdf - https://github.com/waldnerf/decode/blob/9e922a2082e570e248eaee10f7a1f2f0bd852b42/FracTAL_ResUNet/nn/units/fractal_resnet.py - https://github.com/waldnerf/decode/blob/9e922a2082e570e248eaee10f7a1f2f0bd852b42/FracTAL_ResUNet/nn/layers/attention.py - """ - - def __init__(self, in_channels: int, out_channels: int): - super(FractalAttention, self).__init__() - - self.query = torch.nn.Sequential( - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - add_activation=False, - ), - torch.nn.Sigmoid(), - ) - self.key = torch.nn.Sequential( - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - add_activation=False, - ), - torch.nn.Sigmoid(), - ) - self.value = torch.nn.Sequential( - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - add_activation=False, - ), - torch.nn.Sigmoid(), - ) - - self.spatial_sim = TanimotoDist(dim=1) - self.channel_sim = TanimotoDist(dim=[2, 3]) - self.norm = torch.nn.BatchNorm2d(out_channels) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - q = self.query(x) - k = self.key(x) - v = self.value(x) - - attention_spatial = self.spatial_sim(q, k) - v_spatial = attention_spatial * v - - attention_channel = self.channel_sim(q, k) - v_channel = attention_channel * v - - attention = (v_spatial + v_channel) * 0.5 - attention = self.norm(attention) - - return attention - - -class ChannelAttention(torch.nn.Module): - def __init__(self, out_channels: int, activation_type: str): - super(ChannelAttention, self).__init__() - - # Channel attention - self.channel_adaptive_avg = torch.nn.AdaptiveAvgPool2d(1) - self.channel_adaptive_max = torch.nn.AdaptiveMaxPool2d(1) - self.sigmoid = torch.nn.Sigmoid() - self.seq = torch.nn.Sequential( - torch.nn.Conv2d( - in_channels=out_channels, - out_channels=int(out_channels / 2), - kernel_size=1, - padding=0, - bias=False, - ), - SetActivation(activation_type=activation_type), - torch.nn.Conv2d( - in_channels=int(out_channels / 2), - out_channels=out_channels, - kernel_size=1, - padding=0, - bias=False, - ), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - avg_attention = self.seq(self.channel_adaptive_avg(x)) - max_attention = self.seq(self.channel_adaptive_max(x)) - attention = avg_attention + max_attention - attention = self.sigmoid(attention) - - return attention.expand_as(x) - - -class SpatialAttention(torch.nn.Module): - def __init__(self): - super(SpatialAttention, self).__init__() - - self.conv = torch.nn.Conv2d( - in_channels=2, out_channels=1, kernel_size=3, padding=1, bias=False - ) - self.channel_mean = Mean(dim=1, keepdim=True) - self.channel_max = Max(dim=1, keepdim=True) - self.sigmoid = torch.nn.Sigmoid() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - avg_attention = self.channel_mean(x) - max_attention = self.channel_max(x) - attention = torch.cat([avg_attention, max_attention], dim=1) - attention = self.conv(attention) - attention = self.sigmoid(attention) - - return attention.expand_as(x) - - -class SpatialChannelAttention(torch.nn.Module): - """Spatial-Channel Attention Block. - - References: - https://arxiv.org/abs/1807.02758 - https://github.com/yjn870/RCAN-pytorch - https://www.mdpi.com/2072-4292/14/9/2253 - https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py - """ - - def __init__(self, out_channels: int, activation_type: str): - super(SpatialChannelAttention, self).__init__() - - self.channel_attention = ChannelAttention( - out_channels=out_channels, activation_type=activation_type - ) - self.spatial_attention = SpatialAttention() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - channel_attention = self.channel_attention(x) - spatial_attention = self.spatial_attention(x) - attention = (channel_attention + spatial_attention) * 0.5 - - return attention - - -class ResSpatioTemporalConv3d(torch.nn.Module): - """A spatio-temporal convolution layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - activation_type: str = "LeakyReLU", - ): - super(ResSpatioTemporalConv3d, self).__init__() - - layers = [ - # Conv -> Batchnorm -> Activation - ConvBlock3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ), - # Conv -> Batchnorm - ConvBlock3d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=2, - dilation=2, - add_activation=False, - ), - ] - - self.seq = torch.nn.Sequential(*layers) - # Conv -> Batchnorm - self.skip = ConvBlock3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ) - self.final_act = SetActivation(activation_type=activation_type) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.seq(x) + self.skip(x) - - return self.final_act(x) - - -class SpatioTemporalConv3d(torch.nn.Module): - """A spatio-temporal convolution layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - activation_type: str = "LeakyReLU", - ): - super(SpatioTemporalConv3d, self).__init__() - - layers = [ - # Conv -> Batchnorm -> Activation - ConvBlock3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ), - # Conv -> Batchnorm - ConvBlock3d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=2, - dilation=2, - activation_type=activation_type, - ), - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class DoubleConv(torch.nn.Module): - """A double convolution layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(DoubleConv, self).__init__() - - layers = [] - - init_channels = in_channels - if init_point_conv: - layers += [ - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - activation_type=activation_type, - ) - ] - init_channels = out_channels - - layers += [ - ConvBlock2d( - in_channels=init_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ), - ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=double_dilation, - dilation=double_dilation, - activation_type=activation_type, - ), - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class AtrousPyramidPooling(torch.nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - dilation_b: int = 2, - dilation_c: int = 3, - dilation_d: int = 4, - ): - super(AtrousPyramidPooling, self).__init__() - - self.up = model_utils.UpSample() - - self.pool_a = torch.nn.AdaptiveAvgPool2d((1, 1)) - self.pool_b = torch.nn.AdaptiveAvgPool2d((2, 2)) - self.pool_c = torch.nn.AdaptiveAvgPool2d((4, 4)) - self.pool_d = torch.nn.AdaptiveAvgPool2d((8, 8)) - - self.conv_a = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ) - self.conv_b = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation_b, - dilation=dilation_b, - add_activation=False, - ) - self.conv_c = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation_c, - dilation=dilation_c, - add_activation=False, - ) - self.conv_d = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation_d, - dilation=dilation_d, - add_activation=False, - ) - self.final = ConvBlock2d( - in_channels=int(in_channels * 4) + int(out_channels * 4), - out_channels=out_channels, - kernel_size=3, - padding=1, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - out_pa = self.up(self.pool_a(x), size=x.shape[-2:], mode="bilinear") - out_pb = self.up(self.pool_b(x), size=x.shape[-2:], mode="bilinear") - out_pc = self.up(self.pool_c(x), size=x.shape[-2:], mode="bilinear") - out_pd = self.up(self.pool_d(x), size=x.shape[-2:], mode="bilinear") - out_ca = self.conv_a(x) - out_cb = self.conv_b(x) - out_cc = self.conv_c(x) - out_cd = self.conv_d(x) - out = torch.cat( - [out_pa, out_pb, out_pc, out_pd, out_ca, out_cb, out_cc, out_cd], - dim=1, - ) - out = self.final(out) - - return out - - -class PoolConvSingle(torch.nn.Module): - """Max pooling followed by convolution.""" - - def __init__( - self, in_channels: int, out_channels: int, pool_size: int = 2 - ): - super(PoolConvSingle, self).__init__() - - self.seq = torch.nn.Sequential( - torch.nn.MaxPool2d(pool_size), - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - ), - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class PoolConv(torch.nn.Module): - """Max pooling with (optional) dropout.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - pool_size: int = 2, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - dropout: T.Optional[float] = None, - ): - super(PoolConv, self).__init__() - - layers = [torch.nn.MaxPool2d(pool_size)] - if dropout is not None: - layers += [torch.nn.Dropout(dropout)] - layers += [ - DoubleConv( - in_channels=in_channels, - out_channels=out_channels, - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - ] - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class ResidualConvInit(torch.nn.Module): - """A residual convolution layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - activation_type: str = "LeakyReLU", - ): - super(ResidualConvInit, self).__init__() - - self.seq = torch.nn.Sequential( - # Conv -> Batchnorm -> Activation - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ), - # Conv -> Batchnorm - ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=2, - dilation=2, - add_activation=False, - ), - ) - # Conv -> Batchnorm - self.skip = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ) - self.final_act = SetActivation(activation_type=activation_type) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.seq(x) + self.skip(x) - - return self.final_act(x) - - -class ResConvLayer(torch.nn.Module): - """Convolution layer designed for a residual activation. - - if num_blocks [Conv2d-BatchNorm-Activation -> Conv2dAtrous-BatchNorm] - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - dilation: int, - activation_type: str = "LeakyReLU", - num_blocks: int = 2, - ): - super(ResConvLayer, self).__init__() - - assert num_blocks > 0 - - if num_blocks == 1: - layers = [ - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation, - dilation=dilation, - add_activation=False, - ) - ] - else: - # Block 1 - layers = [ - ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ) - ] - if num_blocks > 2: - # Blocks 2:N-1 - layers += [ - ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation, - dilation=dilation, - activation_type=activation_type, - ) - for __ in range(num_blocks - 2) - ] - # Block N - layers += [ - ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=dilation, - dilation=dilation, - add_activation=False, - ) - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class ResidualConv(torch.nn.Module): - """A residual convolution layer with (optional) attention.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - dilation: int = 2, - attention_weights: str = None, - activation_type: str = "LeakyReLU", - ): - super(ResidualConv, self).__init__() - - self.attention_weights = attention_weights - - if self.attention_weights is not None: - assert self.attention_weights in [ - "fractal", - "spatial_channel", - ], "The attention method is not supported." - - self.gamma = torch.nn.Parameter(torch.ones(1)) - - if self.attention_weights == "fractal": - self.attention_conv = FractalAttention( - in_channels=in_channels, out_channels=out_channels - ) - elif self.attention_weights == "spatial_channel": - self.attention_conv = SpatialChannelAttention( - out_channels=out_channels, activation_type=activation_type - ) - - # Ends with Conv2d -> BatchNorm2d - self.seq = ResConvLayer( - in_channels=in_channels, - out_channels=out_channels, - dilation=dilation, - activation_type=activation_type, - num_blocks=2, - ) - self.skip = None - if in_channels != out_channels: - # Conv2d -> BatchNorm2d - self.skip = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ) - self.final_act = SetActivation(activation_type=activation_type) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - if self.skip is not None: - # Align channels - residual = self.skip(x) - residual = residual + self.seq(x) - - if self.attention_weights is not None: - # Get the attention weights - if self.attention_weights == "spatial_channel": - # Get weights from the residual - attention = self.attention_conv(residual) - elif self.attention_weights == "fractal": - # Get weights from the input - attention = self.attention_conv(x) - - # 1 + γA - attention = 1.0 + self.gamma * attention - residual = residual * attention - - out = self.final_act(residual) - - return out - - -class ResidualAConv(torch.nn.Module): - r"""Residual convolution with atrous/dilated convolutions. - - Adapted from publication below: - - CSIRO BSTD/MIT LICENSE - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that - the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the - following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and - the following disclaimer in the documentation and/or other materials provided with the distribution. - 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or - promote products derived from this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, - INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, - WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE - USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - Citation: - @article{diakogiannis_etal_2020, - title={ResUNet-a: A deep learning framework for semantic segmentation of remotely sensed data}, - author={Diakogiannis, Foivos I and Waldner, Fran{\c{c}}ois and Caccetta, Peter and Wu, Chen}, - journal={ISPRS Journal of Photogrammetry and Remote Sensing}, - volume={162}, - pages={94--114}, - year={2020}, - publisher={Elsevier} - } - - References: - https://www.sciencedirect.com/science/article/abs/pii/S0924271620300149 - https://arxiv.org/abs/1904.00592 - https://arxiv.org/pdf/1904.00592.pdf - - Modules: - module1: [Conv2dAtrous-BatchNorm] - ... - moduleN: [Conv2dAtrous-BatchNorm] - - Dilation sum: - sum = [module1 + module2 + ... + moduleN] - out = sum + skip - - Attention: - out = out * attention - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - dilations: T.List[int] = None, - attention_weights: str = None, - activation_type: str = "LeakyReLU", - ): - super(ResidualAConv, self).__init__() - - self.attention_weights = attention_weights - - if self.attention_weights is not None: - assert self.attention_weights in [ - "fractal", - "spatial_channel", - ], "The attention method is not supported." - - self.gamma = torch.nn.Parameter(torch.ones(1)) - - if self.attention_weights == "fractal": - self.attention_conv = FractalAttention( - in_channels=in_channels, out_channels=out_channels - ) - elif self.attention_weights == "spatial_channel": - self.attention_conv = SpatialChannelAttention( - out_channels=out_channels, activation_type=activation_type - ) - - self.res_modules = torch.nn.ModuleList( - [ - # Conv2dAtrous -> Batchnorm - ResConvLayer( - in_channels=in_channels, - out_channels=out_channels, - dilation=dilation, - activation_type=activation_type, - num_blocks=1, - ) - for dilation in dilations - ] - ) - self.skip = None - if in_channels != out_channels: - # Conv2dAtrous -> BatchNorm2d - self.skip = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ) - self.final_act = SetActivation(activation_type=activation_type) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - residual = x - if self.skip is not None: - # Align channels - residual = self.skip(x) - - for seq in self.res_modules: - residual = residual + seq(x) - - if self.attention_weights is not None: - # Get the attention weights - if self.attention_weights == "spatial_channel": - # Get weights from the residual - attention = self.attention_conv(residual) - elif self.attention_weights == "fractal": - # Get weights from the input - attention = self.attention_conv(x) - - # 1 + γA - attention = 1.0 + self.gamma * attention - residual = residual * attention - - out = self.final_act(residual) - - return out - - -class PoolResidualConv(torch.nn.Module): - """Max pooling followed by a residual convolution.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - pool_size: int = 2, - dropout: T.Optional[float] = None, - dilations: T.List[int] = None, - attention_weights: str = None, - activation_type: str = "LeakyReLU", - res_block_type: enum = ResBlockTypes.RESA, - ): - super(PoolResidualConv, self).__init__() - - assert res_block_type in (ResBlockTypes.RES, ResBlockTypes.RESA) - - layers = [torch.nn.MaxPool2d(pool_size)] - - if dropout is not None: - assert isinstance( - dropout, float - ), "The dropout arg must be a float." - layers += [torch.nn.Dropout(dropout)] - - if res_block_type == ResBlockTypes.RES: - layers += [ - ResidualConv( - in_channels, - out_channels, - attention_weights=attention_weights, - dilation=dilations[0], - activation_type=activation_type, - ) - ] - else: - layers += [ - ResidualAConv( - in_channels, - out_channels, - attention_weights=attention_weights, - dilations=dilations, - activation_type=activation_type, - ) - ] - - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class SingleConv3d(torch.nn.Module): - """A single convolution layer.""" - - def __init__(self, in_channels: int, out_channels: int): - super(SingleConv3d, self).__init__() - - self.seq = ConvBlock3d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class SingleConv(torch.nn.Module): - """A single convolution layer.""" - - def __init__( - self, - in_channels: int, - out_channels: int, - activation_type: str = "LeakyReLU", - ): - super(SingleConv, self).__init__() - - self.seq = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type=activation_type, - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) - - -class TemporalConv(torch.nn.Module): - """A temporal convolution layer.""" - - def __init__( - self, in_channels: int, hidden_channels: int, out_channels: int - ): - super(TemporalConv, self).__init__() - - layers = [ - ConvBlock3d( - in_channels=in_channels, - in_time=0, - out_channels=hidden_channels, - kernel_size=3, - padding=1, - ), - ConvBlock3d( - in_channels=hidden_channels, - in_time=0, - out_channels=hidden_channels, - kernel_size=3, - padding=2, - dilation=2, - ), - ConvBlock3d( - in_channels=hidden_channels, - in_time=0, - out_channels=out_channels, - kernel_size=1, - padding=0, - ), - ] - self.seq = torch.nn.Sequential(*layers) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.seq(x) diff --git a/src/cultionet/models/convstar.py b/src/cultionet/models/convstar.py deleted file mode 100644 index 425d0ebe..00000000 --- a/src/cultionet/models/convstar.py +++ /dev/null @@ -1,253 +0,0 @@ -"""Sources: - https://www.sciencedirect.com/science/article/pii/S0034425721003230 - https://github.com/0zgur0/ms-convSTAR -""" -import typing as T - -import torch -from torch.autograd import Variable - -from .base_layers import Softmax, ResidualConv - - -class ConvSTARCell(torch.nn.Module): - """Generates a convolutional STAR cell.""" - - def __init__(self, input_size: int, hidden_size: int, kernel_size: int): - super(ConvSTARCell, self).__init__() - - padding = int(kernel_size / 2.0) - self.sigmoid = torch.nn.Sigmoid() - self.tanh = torch.nn.Tanh() - self.input_size = input_size - self.hidden_size = hidden_size - self.gate = torch.nn.Conv2d( - input_size + hidden_size, hidden_size, kernel_size, padding=padding - ) - self.update = torch.nn.Conv2d( - input_size, hidden_size, kernel_size, padding=padding - ) - - torch.nn.init.orthogonal(self.update.weight) - torch.nn.init.orthogonal(self.gate.weight) - torch.nn.init.constant(self.update.bias, 0.0) - torch.nn.init.constant(self.gate.bias, 1.0) - - def forward( - self, inputs: torch.Tensor, prev_state: T.Union[None, torch.Tensor] - ) -> torch.Tensor: - # get batch and spatial sizes - batch_size = inputs.data.size()[0] - spatial_size = inputs.data.size()[2:] - - # generate empty prev_state, if None is provided - if prev_state is None: - state_size = [batch_size, self.hidden_size] + list(spatial_size) - prev_state = Variable(torch.zeros(state_size)) - - # data size is [batch, channel, height, width] - stacked_inputs = torch.cat([inputs, prev_state], dim=1) - gain = self.sigmoid(self.gate(stacked_inputs)) - update = self.tanh(self.update(inputs)) - new_state = gain * prev_state + (1.0 - gain) * update - - return new_state - - -class ConvSTAR(torch.nn.Module): - def __init__( - self, - input_size: int, - hidden_sizes: int, - kernel_sizes: int, - n_layers: int, - ): - """Generates a multi-layer convolutional GRU. Preserves spatial - dimensions across cells, only altering depth. - - :param input_size: integer. depth dimension of input tensors. - :param hidden_sizes: integer or list. depth dimensions of hidden state. - if integer, the same hidden size is used for all cells. - :param kernel_sizes: integer or list. sizes of Conv2d gate kernels. - if integer, the same kernel size is used for all cells. - :param n_layers: integer. number of chained `ConvSTARCell`. - """ - super(ConvSTAR, self).__init__() - - self.input_size = input_size - - if type(hidden_sizes) != list: - self.hidden_sizes = [hidden_sizes] * n_layers - else: - assert ( - len(hidden_sizes) == n_layers - ), "`hidden_sizes` must have the same length as n_layers" - self.hidden_sizes = hidden_sizes - if type(kernel_sizes) != list: - self.kernel_sizes = [kernel_sizes] * n_layers - else: - assert ( - len(kernel_sizes) == n_layers - ), "`kernel_sizes` must have the same length as n_layers" - self.kernel_sizes = kernel_sizes - - self.n_layers = n_layers - - cells = [] - for i in range(self.n_layers): - if i == 0: - input_dim = self.input_size - else: - input_dim = self.hidden_sizes[i - 1] - - cell = ConvSTARCell( - input_dim, self.hidden_sizes[i], self.kernel_sizes[i] - ) - name = f"ConvSTARCell_{str(i).zfill(2)}" - - setattr(self, name, cell) - cells.append(getattr(self, name)) - - self.cells = cells - - def forward( - self, x: torch.Tensor, hidden: T.Union[None, T.List[torch.Tensor]] - ) -> T.List[torch.Tensor]: - """ - :param x: 4D input tensor. (batch, channels, height, width). - :param hidden: list of 4D hidden state representations. (batch, channels, height, width). - :returns upd_hidden: 5D hidden representation. (layer, batch, channels, height, width). - """ - if not hidden: - hidden = [None] * self.n_layers - - input_ = x - upd_hidden = [] - - for layer_idx in range(self.n_layers): - cell = self.cells[layer_idx] - cell_hidden = hidden[layer_idx] - - # pass through layer - upd_cell_hidden = cell(input_, cell_hidden) - upd_hidden.append(upd_cell_hidden) - # update input_ to the last updated hidden layer for next pass - input_ = upd_cell_hidden - - # retain tensors in list to allow different hidden sizes - return upd_hidden - - -class StarRNN(torch.nn.Module): - def __init__( - self, - input_dim: int = 3, - hidden_dim: int = 64, - num_classes_l2: int = 2, - num_classes_last: int = 3, - n_stage: int = 3, - kernel_size: int = 3, - n_layers: int = 6, - cell: str = "star", - crop_type_layer: bool = False, - activation_type: str = "LeakyReLU", - final_activation: str = Softmax(dim=1), - ): - super(StarRNN, self).__init__() - - self.n_layers = n_layers - self.hidden_dim = hidden_dim - self.n_stage = n_stage - self.cell = cell - self.crop_type_layer = crop_type_layer - - self.rnn = ConvSTAR( - input_size=input_dim, - hidden_sizes=hidden_dim, - kernel_sizes=kernel_size, - n_layers=n_layers, - ) - - # Level 2 level (non-crop; crop) - self.final_l2 = torch.nn.Sequential( - ResidualConv( - in_channels=int(hidden_dim * 2), - out_channels=hidden_dim, - dilation=2, - activation_type=activation_type, - ), - torch.nn.Dropout(0.1), - torch.nn.Conv2d( - in_channels=hidden_dim, - out_channels=num_classes_l2, - kernel_size=1, - padding=0, - ), - final_activation, - ) - # Last level (non-crop; crop; edges) - self.final_last = torch.nn.Sequential( - ResidualConv( - in_channels=int(hidden_dim * 3), - out_channels=hidden_dim, - dilation=2, - activation_type=activation_type, - ), - torch.nn.Dropout(0.1), - torch.nn.Conv2d( - in_channels=hidden_dim, - out_channels=num_classes_last, - kernel_size=1, - padding=0, - ), - Softmax(dim=1), - ) - - def forward( - self, x, hidden_s: T.Optional[torch.Tensor] = None - ) -> T.Sequence[torch.Tensor]: - # input shape = (B x C x T x H x W) - batch_size, __, time_size, height, width = x.shape - - # convRNN step - # hidden_s is a list (number of layer) of hidden states of size [B x C x H x W] - if hidden_s is None: - hidden_s = [ - torch.zeros( - (batch_size, self.hidden_dim, height, width), - dtype=x.dtype, - device=x.device, - ) - ] * self.n_layers - - for iter_ in range(0, time_size): - hidden_s = self.rnn(x[:, :, iter_, :, :], hidden_s) - - if self.n_layers == 3: - local_1 = hidden_s[0] - local_2 = hidden_s[1] - elif self.n_stage == 3: - local_1 = hidden_s[1] - local_2 = hidden_s[3] - elif self.n_stage == 2: - local_1 = hidden_s[1] - local_2 = hidden_s[2] - elif self.n_stage == 1: - local_1 = hidden_s[-1] - local_2 = hidden_s[-1] - - h_last = hidden_s[-1] - if self.crop_type_layer: - last_l2 = self.final_l2(local_2) - h = torch.cat([local_2, h_last], dim=1) - last = self.final_last(h) - - return h, last_l2, last - else: - h = torch.cat([local_1, local_2], dim=1) - last_l2 = self.final_l2(h) - h = torch.cat([h, h_last], dim=1) - last = self.final_last(h) - - # The output is (B x C x H x W) - return h, last_l2, last diff --git a/src/cultionet/models/cultio.py b/src/cultionet/models/cultio.py deleted file mode 100644 index 86f05ad6..00000000 --- a/src/cultionet/models/cultio.py +++ /dev/null @@ -1,416 +0,0 @@ -import typing as T -import warnings - -import torch -from torch_geometric.data import Data - -from . import model_utils -from .base_layers import ConvBlock2d, ResidualConv, Softmax -from .nunet import UNet3, UNet3Psi, ResUNet3Psi -from .convstar import StarRNN - - -def scale_min_max( - x: torch.Tensor, - min_in: float, - max_in: float, - min_out: float, - max_out: float, -) -> torch.Tensor: - return (((max_out - min_out) * (x - min_in)) / (max_in - min_in)) + min_out - - -class GeoRefinement(torch.nn.Module): - def __init__( - self, - in_features: int, - in_channels: int = 21, - n_hidden: int = 32, - out_channels: int = 2, - ): - super(GeoRefinement, self).__init__() - - # in_channels = - # StarRNN 3 + 2 - # Distance transform x4 - # Edge sigmoid x4 - # Crop softmax x4 - - self.gc = model_utils.GraphToConv() - self.cg = model_utils.ConvToGraph() - - self.gamma = torch.nn.Parameter(torch.ones((1, out_channels, 1, 1))) - self.geo_attention = torch.nn.Sequential( - ConvBlock2d( - in_channels=2, - out_channels=out_channels, - kernel_size=1, - padding=0, - add_activation=False, - ), - torch.nn.Sigmoid(), - ) - - self.x_res_modules = torch.nn.ModuleList( - [ - torch.nn.Sequential( - ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=2, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - torch.nn.Sequential( - ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=3, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - torch.nn.Sequential( - ResidualConv( - in_channels=in_features, - out_channels=n_hidden, - dilation=4, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - ] - ) - self.crop_res_modules = torch.nn.ModuleList( - [ - torch.nn.Sequential( - ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=2, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - torch.nn.Sequential( - ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=3, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - torch.nn.Sequential( - ResidualConv( - in_channels=in_channels, - out_channels=n_hidden, - dilation=4, - activation_type='SiLU', - ), - torch.nn.Dropout(0.5), - ), - ] - ) - - self.fc = torch.nn.Sequential( - ConvBlock2d( - in_channels=( - (n_hidden * len(self.x_res_modules)) - + (n_hidden * len(self.crop_res_modules)) - ), - out_channels=n_hidden, - kernel_size=1, - padding=0, - activation_type="SiLU", - ), - torch.nn.Conv2d( - in_channels=n_hidden, - out_channels=out_channels, - kernel_size=1, - padding=0, - ), - ) - self.softmax = Softmax(dim=1) - - def proba_to_logit(self, x: torch.Tensor) -> torch.Tensor: - return torch.log(x / (1.0 - x)) - - def forward( - self, predictions: T.Dict[str, torch.Tensor], data: Data - ) -> T.Dict[str, torch.Tensor]: - """A single forward pass. - - Edge and crop inputs should be probabilities - """ - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - - latitude_norm = scale_min_max( - data.top - ((data.top - data.bottom) * 0.5), -90.0, 90.0, 0.0, 1.0 - ) - longitude_norm = scale_min_max( - data.left + ((data.right - data.left) * 0.5), - -180.0, - 180.0, - 0.0, - 1.0, - ) - lat_lon = torch.cat( - [ - latitude_norm.reshape(*latitude_norm.shape, 1, 1, 1), - longitude_norm.reshape(*longitude_norm.shape, 1, 1, 1), - ], - dim=1, - ) - geo_attention = self.geo_attention(lat_lon) - geo_attention = 1.0 + self.gamma * geo_attention - - crop_x = torch.cat( - [ - predictions["crop_star_l2"], - predictions["crop_star"], - predictions["dist"], - predictions["dist_3_1"], - predictions["dist_2_2"], - predictions["dist_1_3"], - predictions["edge"], - predictions["edge_3_1"], - predictions["edge_2_2"], - predictions["edge_1_3"], - predictions["crop"], - predictions["crop_3_1"], - predictions["crop_2_2"], - predictions["crop_1_3"], - ], - dim=1, - ) - x = self.gc(data.x, batch_size, height, width) - x = torch.cat([m(x) for m in self.x_res_modules], dim=1) - - crop_x = self.gc(crop_x, batch_size, height, width) - crop_x = torch.cat([m(crop_x) for m in self.crop_res_modules], dim=1) - - x = torch.cat([x, crop_x], dim=1) - x = self.softmax(self.fc(x) * geo_attention) - predictions["crop"] = self.cg(x) - - return predictions - - -class CropTypeFinal(torch.nn.Module): - def __init__(self, in_channels: int, out_channels: int, out_classes: int): - super(CropTypeFinal, self).__init__() - - self.in_channels = in_channels - self.out_channels = out_channels - self.out_classes = out_classes - - self.conv1 = ConvBlock2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - activation_type="ReLU", - ) - layers1 = [ - ConvBlock2d( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - padding=1, - activation_type="ReLU", - ), - torch.nn.Conv2d( - out_channels, - out_channels, - kernel_size=3, - padding=1, - bias=False, - ), - torch.nn.BatchNorm2d(out_channels), - ] - self.seq = torch.nn.Sequential(*layers1) - - layers_final = [ - torch.nn.ReLU(inplace=False), - torch.nn.Conv2d( - out_channels, out_classes, kernel_size=1, padding=0 - ), - ] - self.final = torch.nn.Sequential(*layers_final) - - def forward( - self, x: torch.Tensor, crop_type_star: torch.Tensor - ) -> torch.Tensor: - out1 = self.conv1(x) - out = self.seq(out1) - out = out + out1 - out = self.final(out) - out = out + crop_type_star - - return out - - -def check_batch_dims(batch: Data, attribute: str): - batch_var = getattr(batch, attribute) - if not (batch_var == batch_var[0]).all(): - invalid = batch.train_id[batch_var != torch.mode(batch_var)[0]] - warnings.warn("The following ids do not match the batch mode.") - warnings.warn(invalid) - raise ValueError(f"The {attribute} dimensions do not align.") - - -class CultioNet(torch.nn.Module): - """The cultionet model framework. - - Args: - ds_features (int): The total number of dataset features (bands x time). - ds_time_features (int): The number of dataset time features in each band/channel. - filters (int): The number of output filters for each stream. - num_classes (int): The number of output mask/crop classes. - model_type (str): The model architecture type. - activation_type (str): The nonlinear activation. - dilations (int | list): The convolution dilation or dilations. - res_block_type (str): The residual convolution block type. - attention_weights (str): The attention weight type. - deep_sup_dist (bool): Whether to use deep supervision on the distance layer. - deep_sup_edge (bool): Whether to use deep supervision on the edge layer. - deep_sup_mask (bool): Whether to use deep supervision on the mask layer. - """ - - def __init__( - self, - ds_features: int, - ds_time_features: int, - filters: int = 32, - num_classes: int = 2, - model_type: str = "ResUNet3Psi", - activation_type: str = "SiLU", - dilations: T.Union[int, T.Sequence[int]] = None, - res_block_type: str = "resa", - attention_weights: str = "spatial_channel", - deep_sup_dist: bool = False, - deep_sup_edge: bool = False, - deep_sup_mask: bool = False, - ): - super(CultioNet, self).__init__() - - # Total number of features (time x bands/indices/channels) - self.ds_num_features = ds_features - # Total number of time features - self.ds_num_time = ds_time_features - # Total number of bands - self.ds_num_bands = int(self.ds_num_features / self.ds_num_time) - self.filters = filters - self.num_classes = num_classes - - self.gc = model_utils.GraphToConv() - self.cg = model_utils.ConvToGraph() - self.ct = model_utils.ConvToTime() - - self.star_rnn = StarRNN( - input_dim=self.ds_num_bands, - hidden_dim=self.filters, - n_layers=3, - num_classes_l2=self.num_classes, - num_classes_last=self.num_classes + 1, - crop_type_layer=True if self.num_classes > 2 else False, - activation_type=activation_type, - final_activation=Softmax(dim=1), - ) - unet3_kwargs = { - "in_channels": self.ds_num_bands, - "in_time": self.ds_num_time, - "in_rnn_channels": int(self.filters * 3), - "init_filter": self.filters, - "num_classes": self.num_classes, - "activation_type": activation_type, - "deep_sup_dist": deep_sup_dist, - "deep_sup_edge": deep_sup_edge, - "deep_sup_mask": deep_sup_mask, - "mask_activation": Softmax(dim=1), - } - assert model_type in ( - "UNet3Psi", - "ResUNet3Psi", - ), "The model type is not supported." - if model_type == "UNet3Psi": - unet3_kwargs["dilation"] = 2 if dilations is None else dilations - assert isinstance( - unet3_kwargs["dilation"], int - ), "The dilation for UNet3Psi must be an integer." - self.mask_model = UNet3Psi(**unet3_kwargs) - elif model_type == "ResUNet3Psi": - # ResUNet3Psi - unet3_kwargs["attention_weights"] = ( - None if attention_weights == "none" else attention_weights - ) - unet3_kwargs["res_block_type"] = res_block_type - if res_block_type == "res": - unet3_kwargs["dilations"] = ( - [2] if dilations is None else dilations - ) - assert ( - len(unet3_kwargs["dilations"]) == 1 - ), "The dilations for ResUNet3Psi must be a length-1 integer sequence." - elif res_block_type == "resa": - unet3_kwargs["dilations"] = ( - [1, 2] if dilations is None else dilations - ) - assert isinstance( - unet3_kwargs["dilations"], list - ), "The dilations for ResUNet3Psi must be a sequence of integers." - self.mask_model = ResUNet3Psi(**unet3_kwargs) - - def forward(self, data: Data) -> T.Dict[str, torch.Tensor]: - height = ( - int(data.height) if data.batch is None else int(data.height[0]) - ) - width = int(data.width) if data.batch is None else int(data.width[0]) - batch_size = 1 if data.batch is None else data.batch.unique().size(0) - - for attribute in ("ntime", "nbands", "height", "width"): - check_batch_dims(data, attribute) - - # Reshape from ((H*W) x (C*T)) -> (B x C x H x W) - x = self.gc(data.x, batch_size, height, width) - # Reshape from (B x C x H x W) -> (B x C x T|D x H x W) - x = self.ct(x, nbands=self.ds_num_bands, ntime=self.ds_num_time) - # StarRNN - logits_star_hidden, logits_star_l2, logits_star_last = self.star_rnn(x) - logits_star_l2 = self.cg(logits_star_l2) - logits_star_last = self.cg(logits_star_last) - # Main stream - logits = self.mask_model(x, logits_star_hidden) - logits_distance = self.cg(logits["dist"]) - logits_edges = self.cg(logits["edge"]) - logits_crop = self.cg(logits["mask"]) - - out = { - "dist": logits_distance, - "edge": logits_edges, - "crop": logits_crop, - "crop_type": None, - "crop_star_l2": logits_star_l2, - "crop_star": logits_star_last, - } - - if logits["dist_3_1"] is not None: - out["dist_3_1"] = self.cg(logits["dist_3_1"]) - out["dist_2_2"] = self.cg(logits["dist_2_2"]) - out["dist_1_3"] = self.cg(logits["dist_1_3"]) - if logits["mask_3_1"] is not None: - out["crop_3_1"] = self.cg(logits["mask_3_1"]) - out["crop_2_2"] = self.cg(logits["mask_2_2"]) - out["crop_1_3"] = self.cg(logits["mask_1_3"]) - if logits["edge_3_1"] is not None: - out["edge_3_1"] = self.cg(logits["edge_3_1"]) - out["edge_2_2"] = self.cg(logits["edge_2_2"]) - out["edge_1_3"] = self.cg(logits["edge_1_3"]) - - return out diff --git a/src/cultionet/models/cultionet.py b/src/cultionet/models/cultionet.py new file mode 100644 index 00000000..467fd59b --- /dev/null +++ b/src/cultionet/models/cultionet.py @@ -0,0 +1,110 @@ +import typing as T + +import einops +import torch +import torch.nn as nn + +from ..data import Data +from ..enums import AttentionTypes, InferenceNames, ModelTypes, ResBlockTypes +from .nunet import TowerUNet + + +class CultioNet(nn.Module): + """The cultionet model framework. + + Parameters + ========== + in_channels + The total number of dataset features (bands x time). + in_time + The number of dataset time features in each band/channel. + hidden_channels + The number of hidden channels. + model_type + The model architecture type. + activation_type + The nonlinear activation. + dropout + The dropout fraction / probability. + dilations + The convolution dilation or dilations. + res_block_type + The residual convolution block type. + attention_weights + The attention weight type. + pool_by_max + Whether to apply max pooling before residual block. + batchnorm_first + Whether to apply BatchNorm2d -> Activation -> Convolution2d. Otherwise, + apply Convolution2d -> BatchNorm2d -> Activation. + """ + + def __init__( + self, + in_channels: int, + in_time: int, + hidden_channels: int = 32, + model_type: str = ModelTypes.TOWERUNET, + activation_type: str = "SiLU", + dropout: float = 0.1, + dilations: T.Union[int, T.Sequence[int]] = None, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + pool_by_max: bool = False, + batchnorm_first: bool = False, + use_latlon: bool = False, + ): + super().__init__() + + self.in_channels = in_channels + self.in_time = in_time + self.hidden_channels = hidden_channels + + mask_model_kwargs = { + "in_channels": self.in_channels, + "in_time": self.in_time, + "hidden_channels": self.hidden_channels, + "num_classes": 1, + "attention_weights": attention_weights, + "res_block_type": res_block_type, + "dropout": dropout, + "dilations": dilations, + "activation_type": activation_type, + "edge_activation": True, + "mask_activation": True, + "pool_by_max": pool_by_max, + "batchnorm_first": batchnorm_first, + "use_latlon": use_latlon, + } + + assert model_type in ( + ModelTypes.TOWERUNET + ), "The model type is not supported." + + self.mask_model = TowerUNet(**mask_model_kwargs) + + def forward(self, batch: Data) -> T.Dict[str, torch.Tensor]: + + latlon_coords = torch.cat( + ( + einops.rearrange(batch.lon, 'b -> b 1'), + einops.rearrange(batch.lat, 'b -> b 1'), + ), + dim=1, + ) + + # Main stream + out = self.mask_model( + batch.x, + latlon_coords=latlon_coords, + ) + + out.update( + { + InferenceNames.CROP_TYPE: None, + InferenceNames.CLASSES_L2: None, + InferenceNames.CLASSES_L3: None, + } + ) + + return out diff --git a/src/cultionet/models/enums.py b/src/cultionet/models/enums.py deleted file mode 100644 index dce7639e..00000000 --- a/src/cultionet/models/enums.py +++ /dev/null @@ -1,11 +0,0 @@ -import enum - - -class ModelTypes(enum.Enum): - UNET = enum.auto() - RESUNET = enum.auto() - - -class ResBlockTypes(enum.Enum): - RES = enum.auto() - RESA = enum.auto() diff --git a/src/cultionet/models/kernels.py b/src/cultionet/models/kernels.py deleted file mode 100644 index 5afa676d..00000000 --- a/src/cultionet/models/kernels.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Source: - @inproceedings{ismail-fawaz2022hccf, - author = {Ismail-Fawaz, Ali and Devanne, Maxime and Weber, Jonathan and Forestier, Germain}, - title = {Deep Learning For Time Series Classification Using New Hand-Crafted Convolution Filters}, - booktitle = {2022 IEEE International Conference on Big Data (IEEE BigData 2022)}, - city = {Osaka}, - country = {Japan}, - pages = {972-981}, - url = {doi.org/10.1109/BigData55660.2022.10020496}, - year = {2022}, - organization = {IEEE} - } - -Paper: - https://germain-forestier.info/publis/bigdata2022.pdf - -Code: - https://github.com/MSD-IRIMAS/CF-4-TSC -""" -import torch -import torch.nn.functional as F - - -class Trend(torch.nn.Module): - def __init__(self, kernel_size: int, direction: str = "positive"): - super(Trend, self).__init__() - - assert direction in ( - "positive", - "negative", - ), "The trend direction must be one of 'positive' or 'negative'." - - self.padding = int(kernel_size / 2) - self.weights = torch.ones(kernel_size) - indices_ = torch.arange(kernel_size) - if direction == "positive": - self.weights[indices_ % 2 == 0] *= -1 - elif direction == "negative": - self.weights[indices_ % 2 > 0] *= -1 - - self.weights = self.weights[(None,) * 2] - self.relu = torch.nn.ReLU(inplace=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x = (B x C x T) - x = F.conv1d( - x, - self.weights.to(dtype=x.dtype, device=x.device), - bias=None, - stride=1, - padding=self.padding, - dilation=1, - groups=1, - ) - x = self.relu(x) - - return x - - -class Peaks(torch.nn.Module): - def __init__(self, kernel_size: int, radius: int = 9, sigma: float = 1.5): - super(Peaks, self).__init__() - - self.padding = int(kernel_size / 2) - x = torch.linspace(-radius, radius + 1, kernel_size) - mu = 0.0 - gaussian = ( - 1.0 - / (torch.sqrt(torch.tensor([2.0 * torch.pi])) * sigma) - * torch.exp(-1.0 * (x - mu) ** 2 / (2.0 * sigma**2)) - ) - self.weights = gaussian * (x**2 / sigma**4 - 1.0) / sigma**2 - self.weights -= self.weights.mean() - self.weights /= torch.sum(self.weights * x**2) / 2.0 - self.weights *= -1.0 - - self.weights = self.weights[(None,) * 2] - self.relu = torch.nn.ReLU(inplace=False) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # x = (B x C x T) - x = F.conv1d( - x, - self.weights.to(dtype=x.dtype, device=x.device), - bias=None, - stride=1, - padding=self.padding, - dilation=1, - groups=1, - ) - x = self.relu(x) - - return x diff --git a/src/cultionet/models/lightning.py b/src/cultionet/models/lightning.py index 023141f8..a1143f49 100644 --- a/src/cultionet/models/lightning.py +++ b/src/cultionet/models/lightning.py @@ -1,527 +1,121 @@ +import logging import typing as T -from pathlib import Path -import json import warnings -import logging +from pathlib import Path +import einops import pandas as pd import torch import torch.nn.functional as F -from torch.optim import lr_scheduler as optim_lr_scheduler -from torch_geometric.data import Data -from pytorch_lightning import LightningModule -from torchvision.ops import box_iou -from torchvision import transforms import torchmetrics +from lightning import LightningModule +from torch.optim import lr_scheduler as optim_lr_scheduler -from . import model_utils -from .cultio import CultioNet, GeoRefinement -from .maskcrnn import BFasterRCNN -from .base_layers import Softmax -from ..losses import TanimotoDistLoss - +from .. import losses as cnetlosses +from .. import nn as cunn +from ..data import Data +from ..enums import ( + AttentionTypes, + InferenceNames, + LearningRateSchedulers, + LossTypes, + ModelNames, + ModelTypes, + ResBlockTypes, + ValidationNames, +) +from ..layers.weights import init_conv_weights +from .cultionet import CultioNet warnings.filterwarnings("ignore") logging.getLogger("lightning").addHandler(logging.NullHandler()) logging.getLogger("lightning").propagate = False logging.getLogger("lightning").setLevel(logging.ERROR) - -class MaskRCNNLitModel(LightningModule): - def __init__( - self, - cultionet_model_file: Path, - cultionet_num_features: int, - cultionet_num_time_features: int, - cultionet_filters: int, - cultionet_num_classes: int, - ckpt_name: str = "maskrcnn", - model_name: str = "maskrcnn", - learning_rate: float = 1e-3, - weight_decay: float = 1e-5, - resize_height: int = 201, - resize_width: int = 201, - min_image_size: int = 100, - max_image_size: int = 500, - trainable_backbone_layers: int = 3, - ): - """Lightning model. - - Args: - num_features - num_time_features - filters - learning_rate - weight_decay - """ - super(MaskRCNNLitModel, self).__init__() - self.save_hyperparameters() - - self.ckpt_name = ckpt_name - self.model_name = model_name - self.learning_rate = learning_rate - self.weight_decay = weight_decay - self.num_classes = 2 - self.resize_height = resize_height - self.resize_width = resize_width - - self.cultionet_model = CultioLitModel( - num_features=cultionet_num_features, - num_time_features=cultionet_num_time_features, - filters=cultionet_filters, - num_classes=cultionet_num_classes, - ) - self.cultionet_model.load_state_dict( - state_dict=torch.load(cultionet_model_file) - ) - self.cultionet_model.eval() - self.cultionet_model.freeze() - self.model = BFasterRCNN( - in_channels=4, - out_channels=256, - num_classes=self.num_classes, - sizes=(16, 32, 64, 128, 256), - aspect_ratios=(0.5, 1.0, 3.0), - trainable_backbone_layers=trainable_backbone_layers, - min_image_size=min_image_size, - max_image_size=max_image_size, - ) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def mask_forward( - self, - distance_ori: torch.Tensor, - distance: torch.Tensor, - edge: torch.Tensor, - crop_r: torch.Tensor, - height: T.Union[None, int, torch.Tensor], - width: T.Union[None, int, torch.Tensor], - batch: T.Union[None, int, torch.Tensor], - y: T.Union[None, torch.Tensor] = None, - ): - height = int(height) if batch is None else int(height[0]) - width = int(width) if batch is None else int(width[0]) - batch_size = 1 if batch is None else batch.unique().size(0) - x = torch.cat( - ( - distance_ori, - distance, - edge[:, 1][:, None], - crop_r[:, 1][:, None], - ), - dim=1, - ) - # in x = (H*W x C) - # new x = (B x C x H x W) - gc = model_utils.GraphToConv() - x = gc(x, batch_size, height, width) - resizer = transforms.Resize((self.resize_height, self.resize_width)) - x = [resizer(image) for image in x] - targets = None - if y is not None: - targets = [] - for bidx in y["image_id"].unique(): - batch_dict = {} - batch_slice = y["image_id"] == bidx - for k in y.keys(): - if k == "masks": - batch_dict[k] = resizer(y[k][batch_slice]) - elif k == "boxes": - # [xmin, ymin, xmax, ymax] - batch_dict[k] = self.scale_boxes( - y[k][batch_slice], batch, [height] - ) - else: - batch_dict[k] = y[k][batch_slice] - targets.append(batch_dict) - outputs = self.model(x, targets) - - return outputs - - def scale_boxes( - self, - boxes: torch.Tensor, - batch: torch.Tensor, - height: T.Union[None, int, T.List[int], torch.Tensor], - ): - height = int(height) if batch is None else int(height[0]) - scale = self.resize_height / height - - return boxes * scale - - def forward( - self, - batch: Data, - batch_idx: int = None, - y: T.Optional[torch.Tensor] = None, - ) -> T.Tuple[ - torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor - ]: - """Performs a single model forward pass.""" - with torch.no_grad(): - distance_ori, distance, edge, __, crop_r = self.cultionet_model( - batch - ) - estimates = self.mask_forward( - distance_ori, - distance, - edge, - crop_r, - height=batch.height, - width=batch.width, - batch=batch.batch, - y=y, - ) - - return estimates - - def on_save_checkpoint(self, checkpoint): - """Save the checkpoint.""" - ckpt_file = Path(self.logger.save_dir) / f"{self.ckpt_name}.ckpt" - if ckpt_file.is_file(): - ckpt_file.unlink() - torch.save(checkpoint, ckpt_file) - - def on_validation_epoch_end(self, *args, **kwargs): - """Save the model on validation end.""" - model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt" - if model_file.is_file(): - model_file.unlink() - torch.save(self.state_dict(), model_file) - - def calc_loss( - self, batch: T.Union[Data, T.List], y: T.Optional[torch.Tensor] = None - ): - """Calculates the loss for each layer. - - Returns: - Average loss - """ - losses = self(batch, y=y) - loss = sum(loss for loss in losses.values()) - - return loss - - def training_step(self, batch: Data, batch_idx: int = None): - """Executes one training step.""" - y = { - "boxes": batch.boxes, - "labels": batch.box_labels, - "masks": batch.box_masks, - "image_id": batch.image_id, - } - loss = self.calc_loss(batch, y=y) - self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True) - - return loss - - def _shared_eval_step(self, batch: Data) -> dict: - # Predictions - instances = self(batch) - # True boxes - true_boxes = self.scale_boxes(batch.boxes, batch, batch.height) - - predict_iou_score = torch.tensor(0.0, device=self.device) - iou_score = torch.tensor(0.0, device=self.device) - box_score = torch.tensor(0.0, device=self.device) - for bidx, batch_value in enumerate(batch.image_id.unique()): - # This should be low (i.e., low overlap of predicted boxes) - predict_iou_score += box_iou( - instances[bidx]["boxes"], instances[bidx]["boxes"] - ).mean() - # This should be high (i.e., high overlap of predictions and true boxes) - iou_score += box_iou( - true_boxes[batch.image_id == batch_value], - instances[bidx]["boxes"], - ).mean() - # This should be high (i.e., masks should be confident) - box_score += instances[bidx]["scores"].mean() - predict_iou_score /= batch.image_id.unique().size(0) - iou_score /= batch.image_id.unique().size(0) - box_score /= batch.image_id.unique().size(0) - - total_iou_score = (predict_iou_score + (1.0 - iou_score)) * 0.5 - box_score = 1.0 - box_score - # Minimize intersection-over-union and maximum score - total_score = (total_iou_score + box_score) * 0.5 - - metrics = { - "predict_iou_score": predict_iou_score, - "iou_score": iou_score, - "box_score": box_score, - "mean_score": total_score, - } - - return metrics - - def validation_step(self, batch: Data, batch_idx: int = None) -> dict: - """Executes one valuation step.""" - eval_metrics = self._shared_eval_step(batch) - - metrics = { - "val_loss": eval_metrics["mean_score"], - "val_piou": eval_metrics["predict_iou_score"], - "val_iou": eval_metrics["iou_score"], - "val_box": eval_metrics["box_score"], - } - self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) - - return metrics - - def test_step(self, batch: Data, batch_idx: int = None) -> dict: - """Executes one test step.""" - eval_metrics = self._shared_eval_step(batch) - - metrics = { - "test_loss": eval_metrics["mean_score"], - "test_piou": eval_metrics["predict_iou_score"], - "test_iou": eval_metrics["iou_score"], - "test_box": eval_metrics["box_score"], - } - self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) - - return metrics - - def configure_optimizers(self): - optimizer = torch.optim.AdamW( - list(self.model.parameters()), - lr=self.learning_rate, - weight_decay=self.weight_decay, - eps=1e-4, - ) - lr_scheduler = optim_lr_scheduler.ReduceLROnPlateau( - optimizer, factor=0.1, patience=5 - ) - - return { - "optimizer": optimizer, - "scheduler": lr_scheduler, - "monitor": "val_loss", - } - - -def scale_logits(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor: - return x / t - - -class RefineLitModel(LightningModule): - def __init__( - self, - in_features: int, - num_classes: int = 2, - learning_rate: float = 1e-3, - weight_decay: float = 0.01, - eps: float = 1e-4, - edge_class: int = 2, - class_counts: T.Optional[torch.Tensor] = None, - cultionet_ckpt: T.Optional[T.Union[Path, str]] = None, - ): - super(RefineLitModel, self).__init__() - - self.save_hyperparameters() - - self.learning_rate = learning_rate - self.weight_decay = weight_decay - self.eps = eps - self.edge_class = edge_class - self.class_counts = class_counts - self.cultionet_ckpt = cultionet_ckpt - - self.cultionet_model = None - self.geo_refine_model = GeoRefinement( - in_features=in_features, out_channels=num_classes - ) - - self.configure_loss() - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def forward( - self, - predictions: T.Dict[str, torch.Tensor], - batch: Data, - batch_idx: int = None, - ) -> T.Dict[str, torch.Tensor]: - return self.geo_refine_model(predictions, data=batch) - - def set_true_labels(self, batch: Data) -> torch.Tensor: - # in case of multi-class, `true_crop` = 1, 2, etc. - true_crop = torch.where( - (batch.y > 0) & (batch.y != self.edge_class), 1, 0 - ).long() - - return true_crop - - def calc_loss( - self, - batch: T.Union[Data, T.List], - predictions: T.Dict[str, torch.Tensor], - ): - true_crop = self.set_true_labels(batch) - # Predicted crop values are probabilities - loss = self.crop_loss(predictions["crop"], true_crop) - - return loss - - def training_step( - self, batch: Data, batch_idx: int = None, optimizer_idx: int = None - ): - """Executes one training step.""" - # Apply inference with the main cultionet model - if (self.cultionet_ckpt is not None) and ( - self.cultionet_model is None - ): - self.cultionet_model = CultioLitModel.load_from_checkpoint( - checkpoint_path=str(self.cultionet_ckpt) - ) - self.cultionet_model.to(self.device) - self.cultionet_model.eval() - self.cultionet_model.freeze() - with torch.no_grad(): - predictions = self.cultionet_model(batch) - - predictions = self(predictions, batch) - loss = self.calc_loss(batch, predictions) - - metrics = {"loss": loss} - self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) - - return metrics - - def on_train_epoch_end(self, *args, **kwargs): - """Save the scaling parameters on training end.""" - if self.logger.save_dir is not None: - model_file = Path(self.logger.save_dir) / "refine.pt" - if model_file.is_file(): - model_file.unlink() - torch.save(self.geo_refine_model.state_dict(), model_file) - - def configure_loss(self): - self.crop_loss = TanimotoDistLoss(scale_pos_weight=True) - - def configure_optimizers(self): - optimizer = torch.optim.AdamW( - list(self.geo_refine_model.parameters()), - lr=self.learning_rate, - weight_decay=self.weight_decay, - eps=self.eps, - ) - lr_scheduler = optim_lr_scheduler.CosineAnnealingLR( - optimizer, T_max=20, eta_min=1e-5, last_epoch=-1 - ) - - return { - "optimizer": optimizer, - "lr_scheduler": { - "scheduler": lr_scheduler, - "monitor": "loss", - "interval": "epoch", - "frequency": 1, - }, - } - - -class CultioLitModel(LightningModule): - def __init__( - self, - num_features: int = None, - num_time_features: int = None, - num_classes: int = 2, - filters: int = 32, - model_type: str = "ResUNet3Psi", - activation_type: str = "SiLU", - dilations: T.Union[int, T.Sequence[int]] = None, - res_block_type: str = "resa", - attention_weights: str = "spatial_channel", - optimizer: str = "AdamW", - learning_rate: float = 1e-3, - lr_scheduler: str = "CosineAnnealingLR", - steplr_step_size: int = 5, - weight_decay: float = 0.01, - eps: float = 1e-4, - ckpt_name: str = "last", - model_name: str = "cultionet", - deep_sup_dist: bool = False, - deep_sup_edge: bool = False, - deep_sup_mask: bool = False, - class_counts: T.Optional[torch.Tensor] = None, - edge_class: T.Optional[int] = None, - temperature_lit_model: T.Optional[GeoRefinement] = None, - scale_pos_weight: T.Optional[bool] = True, - save_batch_val_metrics: T.Optional[bool] = False, - ): - """Lightning model.""" - super(CultioLitModel, self).__init__() - - self.save_hyperparameters() - - self.optimizer = optimizer - self.learning_rate = learning_rate - self.lr_scheduler = lr_scheduler - self.steplr_step_size = steplr_step_size - self.weight_decay = weight_decay - self.eps = eps - self.ckpt_name = ckpt_name - self.model_name = model_name - self.num_classes = num_classes - self.num_time_features = num_time_features - self.class_counts = class_counts - self.temperature_lit_model = temperature_lit_model - self.scale_pos_weight = scale_pos_weight - self.save_batch_val_metrics = save_batch_val_metrics - self.deep_sup_dist = deep_sup_dist - self.deep_sup_edge = deep_sup_edge - self.deep_sup_mask = deep_sup_mask - self.sigmoid = torch.nn.Sigmoid() - if edge_class is not None: - self.edge_class = edge_class - else: - self.edge_class = num_classes - - self.model_attr = f"{model_name}_{model_type}" - setattr( - self, - self.model_attr, - CultioNet( - ds_features=num_features, - ds_time_features=num_time_features, - filters=filters, - num_classes=self.num_classes, - model_type=model_type, - activation_type=activation_type, - dilations=dilations, - res_block_type=res_block_type, - attention_weights=attention_weights, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, - ), - ) - self.configure_loss() - self.configure_scorer() +torch.set_float32_matmul_precision("high") + + +LOSS_DICT = { + LossTypes.BOUNDARY: { + "classification": cnetlosses.BoundaryLoss(), + }, + LossTypes.CLASS_BALANCED_MSE: { + "classification": cnetlosses.ClassBalancedMSELoss(), + }, + LossTypes.LOG_COSH: { + "regression": cnetlosses.LogCoshLoss(), + }, + LossTypes.TANIMOTO_COMPLEMENT: { + "classification": cnetlosses.TanimotoComplementLoss(), + "regression": cnetlosses.TanimotoComplementLoss( + transform_logits=False, + one_hot_targets=False, + ), + }, + LossTypes.TANIMOTO: { + "classification": cnetlosses.TanimotoDistLoss(), + "regression": cnetlosses.TanimotoDistLoss( + transform_logits=False, + one_hot_targets=False, + ), + }, + LossTypes.TANIMOTO_COMBINED: { + "classification": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss(), + cnetlosses.TanimotoComplementLoss(), + ], + ), + "regression": cnetlosses.CombinedLoss( + losses=[ + cnetlosses.TanimotoDistLoss( + transform_logits=False, + one_hot_targets=False, + ), + cnetlosses.TanimotoComplementLoss( + transform_logits=False, + one_hot_targets=False, + ), + ], + ), + }, + LossTypes.TVERSKY: { + "classification": cnetlosses.TverskyLoss(), + }, + LossTypes.FOCAL_TVERSKY: { + "classification": cnetlosses.FocalTverskyLoss(), + }, +} + + +class LightningModuleMixin(LightningModule): + def __init__(self): + super().__init__() def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - @property - def cultionet_model(self) -> CultioNet: - return getattr(self, self.model_attr) - def forward( self, batch: Data, batch_idx: int = None ) -> T.Dict[str, torch.Tensor]: """Performs a single model forward pass. - Returns: - distance: Normalized distance transform (from boundaries), [0,1]. - edge: Probabilities of edge|non-edge, [0,1]. - crop: Logits of crop|non-crop. + Returns + ======= + distance + Normalized distance transform (from boundaries), [0,1]. Shaped (B, 1, H, W). + edge + Edge|non-edge predictions, logits or probabilities. Shaped (B, 1, H, W). + crop + Logits of crop|non-crop. Shaped (B, C, H, W). """ return self.cultionet_model(batch) + @property + def cultionet_model(self) -> CultioNet: + """Get the network model name.""" + return getattr(self, self.model_attr) + @staticmethod def get_cuda_memory(): t = torch.cuda.get_device_properties(0).total_memory @@ -529,20 +123,47 @@ def get_cuda_memory(): a = torch.cuda.memory_allocated(0) print(f"{t * 1e-6:.02f}MB", f"{r * 1e-6:.02f}MB", f"{a * 1e-6:.02f}MB") + def probas_to_labels( + self, x: torch.Tensor, thresh: float = 0.5 + ) -> torch.Tensor: + """Converts probabilities to class labels.""" + + if x.shape[1] == 1: + labels = x.gt(thresh).squeeze(dim=1).long() + else: + labels = x.argmax(dim=1).long() + + return labels + + def logits_to_probas(self, x: torch.Tensor) -> T.Union[None, torch.Tensor]: + """Transforms logits to probabilities.""" + + if x is not None: + if x.shape[1] > 1: + x = F.softmax(x, dim=1, dtype=x.dtype) + else: + # Single-dimension inputs are sigmoid probabilities + x = F.sigmoid(x) + + x = x.clip(0, 1) + + return x + def predict_step( self, batch: Data, batch_idx: int = None ) -> T.Dict[str, torch.Tensor]: """A prediction step for Lightning.""" - predictions = self.forward(batch, batch_idx) - if self.temperature_lit_model is not None: - predictions = self.temperature_lit_model(predictions, batch) + + predictions = self.forward(batch, batch_idx=batch_idx) return predictions + @torch.no_grad def get_true_labels( self, batch: Data, crop_type: torch.Tensor = None ) -> T.Dict[str, T.Union[None, torch.Tensor]]: """Gets true labels from the data batch.""" + true_edge = torch.where(batch.y == self.edge_class, 1, 0).long() # Recode all crop classes to 1, otherwise 0 true_crop = torch.where( @@ -568,54 +189,23 @@ def get_true_labels( batch.y == self.edge_class, 0, batch.y ).long() + # Weak supervision mask + mask = None + if batch.y.min() == -1: + mask = torch.where(batch.y == -1, 0, 1).to( + dtype=torch.long, device=batch.y.device + ) + mask = einops.rearrange(mask, 'b h w -> b 1 h w') + return { - "true_edge": true_edge, - "true_crop": true_crop, - "true_crop_and_edge": true_crop_and_edge, - "true_crop_or_edge": true_crop_or_edge, - "true_crop_type": true_crop_type, + ValidationNames.TRUE_EDGE: true_edge, + ValidationNames.TRUE_CROP: true_crop, + ValidationNames.TRUE_CROP_AND_EDGE: true_crop_and_edge, + ValidationNames.TRUE_CROP_OR_EDGE: true_crop_or_edge, + ValidationNames.TRUE_CROP_TYPE: true_crop_type, + ValidationNames.MASK: mask, } - def softmax(self, x: torch.Tensor, dim: int = 1) -> torch.Tensor: - return F.softmax(x, dim=dim, dtype=x.dtype) - - def probas_to_labels( - self, x: torch.Tensor, thresh: float = 0.5 - ) -> torch.Tensor: - if x.shape[1] == 1: - labels = x.gt(thresh).long() - else: - labels = x.argmax(dim=1).long() - - return labels - - def logits_to_probas(self, x: torch.Tensor) -> T.Union[None, torch.Tensor]: - if x is not None: - # Single-dimension inputs are sigmoid probabilities - if x.shape[1] > 1: - # Transform logits to probabilities - x = self.softmax(x) - else: - x = self.sigmoid(x) - x = x.clip(0, 1) - - return x - - # def on_train_epoch_start(self): - # # Get the current learning rate from the optimizer - # eps = self.optimizers().optimizer.param_groups[0]['eps'] - # weight_decay = self.optimizers().optimizer.param_groups[0]['weight_decay'] - # if (weight_decay != self.weight_decay) or (eps != self.eps): - # self.configure_optimizers() - - def on_validation_epoch_end(self, *args, **kwargs): - """Save the model on validation end.""" - if self.logger.save_dir is not None: - model_file = Path(self.logger.save_dir) / f"{self.model_name}.pt" - if model_file.is_file(): - model_file.unlink() - torch.save(self.state_dict(), model_file) - def calc_loss( self, batch: T.Union[Data, T.List], @@ -623,206 +213,296 @@ def calc_loss( ): """Calculates the loss. - Returns: - Total loss + Returns + ======= + Total loss """ - true_labels_dict = self.get_true_labels( - batch, crop_type=predictions["crop_type"] - ) - # RNN level 2 loss (non-crop=0; crop|edge=1) - crop_star_l2_loss = self.crop_star_l2_loss( - predictions["crop_star_l2"], true_labels_dict["true_crop_and_edge"] - ) - # RNN final loss (non-crop=0; crop=1; edge=2) - crop_star_loss = self.crop_star_loss( - predictions["crop_star"], true_labels_dict["true_crop_or_edge"] - ) - # Main loss - loss = ( - # RNN losses - 0.25 * crop_star_l2_loss - + 0.5 * crop_star_loss - ) - # Edge losses - if self.deep_sup_dist: - dist_loss_3_1 = self.dist_loss_3_1( - predictions["dist_3_1"], batch.bdist + weights = { + InferenceNames.DISTANCE: 1.0, + InferenceNames.EDGE: 1.0, + InferenceNames.CROP: 1.0, + } + + with torch.no_grad(): + true_labels_dict = self.get_true_labels( + batch, crop_type=predictions.get(InferenceNames.CROP_TYPE) ) - dist_loss_2_2 = self.dist_loss_2_2( - predictions["dist_2_2"], batch.bdist + + # true_edge_distance = torch.where( + # true_labels_dict[ValidationNames.TRUE_EDGE] == 1, + # 1, + # torch.where( + # true_labels_dict[ValidationNames.TRUE_CROP] == 1, + # (1.0 - batch.bdist) ** 20.0, + # 0, + # ), + # ) + # true_crop_distance = torch.where( + # true_labels_dict[ValidationNames.TRUE_CROP] != 1, + # 0, + # 1.0 - true_edge_distance, + # ) + + # true_edge_distance = einops.rearrange( + # true_edge_distance, 'b h w -> b 1 h w' + # ) + # true_crop_distance = einops.rearrange( + # true_crop_distance, 'b h w -> b 1 h w' + # ) + + loss = 0.0 + + ########################## + # Temporal encoding losses + ########################## + + if predictions[InferenceNames.CLASSES_L2] is not None: + # Temporal encoding level 2 loss (non-crop=0; crop|edge=1) + classes_l2_loss = F.cross_entropy( + predictions[InferenceNames.CLASSES_L2], + true_labels_dict[ValidationNames.TRUE_CROP_AND_EDGE], + weight=self.crop_and_edge_weights, + reduction='none' + if true_labels_dict[ValidationNames.MASK] is not None + else 'mean', ) - dist_loss_1_3 = self.dist_loss_1_3( - predictions["dist_1_3"], batch.bdist + + if true_labels_dict[ValidationNames.MASK] is not None: + classes_l2_loss = classes_l2_loss * einops.rearrange( + true_labels_dict[ValidationNames.MASK], 'b 1 h w -> b h w' + ) + masked_weights = self.crop_and_edge_weights[ + true_labels_dict[ValidationNames.TRUE_CROP_AND_EDGE] + ] * einops.rearrange( + true_labels_dict[ValidationNames.MASK], 'b 1 h w -> b h w' + ) + classes_l2_loss = classes_l2_loss.sum() / masked_weights.sum() + + weights[InferenceNames.CLASSES_L2] = 0.01 + loss = loss + classes_l2_loss * weights[InferenceNames.CLASSES_L2] + + if predictions[InferenceNames.CLASSES_L3] is not None: + # Temporal encoding final loss (non-crop=0; crop=1; edge=2) + classes_last_loss = F.cross_entropy( + predictions[InferenceNames.CLASSES_L3], + true_labels_dict[ValidationNames.TRUE_CROP_OR_EDGE], + weight=self.crop_or_edge_weights, + reduction='none' + if true_labels_dict[ValidationNames.MASK] is not None + else 'mean', ) - # Main loss + + if true_labels_dict[ValidationNames.MASK] is not None: + classes_last_loss = classes_last_loss * einops.rearrange( + true_labels_dict[ValidationNames.MASK], 'b 1 h w -> b h w' + ) + masked_weights = self.crop_or_edge_weights[ + true_labels_dict[ValidationNames.TRUE_CROP_OR_EDGE] + ] * einops.rearrange( + true_labels_dict[ValidationNames.MASK], 'b 1 h w -> b h w' + ) + classes_last_loss = ( + classes_last_loss.sum() / masked_weights.sum() + ) + + weights[InferenceNames.CLASSES_L3] = 0.1 loss = ( - loss - + 0.1 * dist_loss_3_1 - + 0.25 * dist_loss_2_2 - + 0.5 * dist_loss_1_3 + loss + classes_last_loss * weights[InferenceNames.CLASSES_L3] ) + + ############# + # Main losses + ############# + # Distance transform loss - dist_loss = self.dist_loss(predictions["dist"], batch.bdist) - # Main loss - loss = loss + dist_loss - # Distance transform losses - if self.deep_sup_edge: - edge_loss_3_1 = self.edge_loss_3_1( - predictions["edge_3_1"], true_labels_dict["true_edge"] - ) - edge_loss_2_2 = self.edge_loss_2_2( - predictions["edge_2_2"], true_labels_dict["true_edge"] - ) - edge_loss_1_3 = self.edge_loss_1_3( - predictions["edge_1_3"], true_labels_dict["true_edge"] - ) - # Main loss - loss = ( - loss - + 0.1 * edge_loss_3_1 - + 0.25 * edge_loss_2_2 - + 0.5 * edge_loss_1_3 - ) + dist_loss = self.reg_loss( + # Inputs are 0-1 continuous + inputs=predictions[InferenceNames.DISTANCE], + # True data are 0-1 continuous + targets=batch.bdist, + mask=true_labels_dict[ValidationNames.MASK], + ) + loss = loss + dist_loss * weights[InferenceNames.DISTANCE] + # Edge loss - edge_loss = self.edge_loss( - predictions["edge"], true_labels_dict["true_edge"] + edge_loss = self.cls_loss( + # Inputs are single-layer logits or probabilities + inputs=predictions[InferenceNames.EDGE], + # True data are 0|1 + targets=true_labels_dict[ValidationNames.TRUE_EDGE], + mask=true_labels_dict[ValidationNames.MASK], ) - # Main loss - loss = loss + edge_loss - # Crop mask losses - if self.deep_sup_mask: - crop_loss_3_1 = self.crop_loss_3_1( - predictions["crop_3_1"], true_labels_dict["true_crop"] - ) - crop_loss_2_2 = self.crop_loss_2_2( - predictions["crop_2_2"], true_labels_dict["true_crop"] - ) - crop_loss_1_3 = self.crop_loss_1_3( - predictions["crop_1_3"], true_labels_dict["true_crop"] - ) - # Main loss - loss = ( - loss - + 0.1 * crop_loss_3_1 - + 0.25 * crop_loss_2_2 - + 0.5 * crop_loss_1_3 - ) + loss = loss + edge_loss * weights[InferenceNames.EDGE] + # Crop mask loss - crop_loss = self.crop_loss( - predictions["crop"], true_labels_dict["true_crop"] + crop_loss = self.cls_loss( + # Inputs are 2-layer logits or probabilities + inputs=predictions[InferenceNames.CROP], + # True data are 0|1 + targets=true_labels_dict[ValidationNames.TRUE_CROP], + mask=true_labels_dict[ValidationNames.MASK], ) - # Main loss - loss = loss + crop_loss + loss = loss + crop_loss * weights[InferenceNames.CROP] - if predictions["crop_type"] is not None: - # Upstream (deep) loss on crop-type - crop_type_star_loss = self.crop_type_star_loss( - predictions["crop_type_star"], - true_labels_dict["true_crop_type"], - ) - loss = loss + crop_type_star_loss - # Loss on crop-type - crop_type_loss = self.crop_type_loss( - predictions["crop_type"], true_labels_dict["true_crop_type"] - ) - loss = loss + crop_type_loss + loss_report = { + "dloss": dist_loss, + "eloss": edge_loss, + "closs": crop_loss, + } - return loss + return loss / sum(weights.values()), loss_report def training_step(self, batch: Data, batch_idx: int = None): """Executes one training step and logs training step metrics.""" + predictions = self(batch) - loss = self.calc_loss(batch, predictions) - self.log("loss", loss, on_step=False, on_epoch=True, prog_bar=True) + + loss, _ = self.calc_loss(batch, predictions) + + self.log( + "loss", + loss, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=batch.num_samples, + ) return loss def _shared_eval_step(self, batch: Data, batch_idx: int = None) -> dict: + """Evaluation step shared between validation and testing.""" + + # Forward pass to get predictions predictions = self(batch) - loss = self.calc_loss(batch, predictions) - dist_mae = self.dist_mae( - predictions["dist"].contiguous().view(-1), - batch.bdist.contiguous().view(-1), - ) - dist_mse = self.dist_mse( - predictions["dist"].contiguous().view(-1), - batch.bdist.contiguous().view(-1), - ) - # Get the class labels - edge_ypred = self.probas_to_labels(predictions["edge"]) - crop_ypred = self.probas_to_labels(predictions["crop"]) + # Calculate the loss + loss, loss_report = self.calc_loss(batch, predictions) + + # Convert probabilities to class labels + edge_ypred = self.probas_to_labels(predictions[InferenceNames.EDGE]) + crop_ypred = self.probas_to_labels(predictions[InferenceNames.CROP]) + # Get the true edge and crop labels true_labels_dict = self.get_true_labels( - batch, crop_type=predictions["crop_type"] + batch, crop_type=predictions.get(InferenceNames.CROP_TYPE) ) - # F1-score - edge_score = self.edge_f1(edge_ypred, true_labels_dict["true_edge"]) - crop_score = self.crop_f1(crop_ypred, true_labels_dict["true_crop"]) - # MCC - edge_mcc = self.edge_mcc(edge_ypred, true_labels_dict["true_edge"]) - crop_mcc = self.crop_mcc(crop_ypred, true_labels_dict["true_crop"]) - # Dice - edge_dice = self.edge_dice(edge_ypred, true_labels_dict["true_edge"]) - crop_dice = self.crop_dice(crop_ypred, true_labels_dict["true_crop"]) - # Jaccard/IoU - edge_jaccard = self.edge_jaccard( - edge_ypred, true_labels_dict["true_edge"] + + if true_labels_dict[ValidationNames.MASK] is not None: + # Valid sample = True; Invalid sample = False + labels_bool_mask = true_labels_dict[ValidationNames.MASK].to( + dtype=torch.bool + ) + predictions[InferenceNames.DISTANCE] = torch.masked_select( + predictions[InferenceNames.DISTANCE], labels_bool_mask + ) + bdist = torch.masked_select( + batch.bdist, labels_bool_mask.squeeze(dim=1) + ) + + else: + predictions[InferenceNames.DISTANCE] = einops.rearrange( + predictions[InferenceNames.DISTANCE], 'b 1 h w -> (b h w)' + ) + bdist = einops.rearrange(batch.bdist, 'b h w -> (b h w)') + + dist_score_args = (predictions[InferenceNames.DISTANCE], bdist) + + dist_mae = self.mae_scorer(*dist_score_args) + dist_mse = self.mse_scorer(*dist_score_args) + + if true_labels_dict[ValidationNames.MASK] is not None: + edge_ypred = torch.masked_select( + edge_ypred, labels_bool_mask.squeeze(dim=1) + ) + crop_ypred = torch.masked_select( + crop_ypred, labels_bool_mask.squeeze(dim=1) + ) + true_labels_dict[ValidationNames.TRUE_EDGE] = torch.masked_select( + true_labels_dict[ValidationNames.TRUE_EDGE], + labels_bool_mask.squeeze(dim=1), + ) + true_labels_dict[ValidationNames.TRUE_CROP] = torch.masked_select( + true_labels_dict[ValidationNames.TRUE_CROP], + labels_bool_mask.squeeze(dim=1), + ) + + else: + edge_ypred = einops.rearrange(edge_ypred, 'b h w -> (b h w)') + crop_ypred = einops.rearrange(crop_ypred, 'b h w -> (b h w)') + true_labels_dict[ValidationNames.TRUE_EDGE] = einops.rearrange( + true_labels_dict[ValidationNames.TRUE_EDGE], 'b h w -> (b h w)' + ) + true_labels_dict[ValidationNames.TRUE_CROP] = einops.rearrange( + true_labels_dict[ValidationNames.TRUE_CROP], 'b h w -> (b h w)' + ) + + # Scorer input args + edge_score_args = ( + edge_ypred, + true_labels_dict[ValidationNames.TRUE_EDGE], ) - crop_jaccard = self.crop_jaccard( - crop_ypred, true_labels_dict["true_crop"] + crop_score_args = ( + crop_ypred, + true_labels_dict[ValidationNames.TRUE_CROP], ) + # Fβ-score + edge_fscore = self.f_beta_scorer(*edge_score_args) + crop_fscore = self.f_beta_scorer(*crop_score_args) + + # MCC + edge_mcc = self.mcc_scorer(*edge_score_args) + crop_mcc = self.mcc_scorer(*crop_score_args) + total_score = ( loss - + (1.0 - edge_score) - + (1.0 - crop_score) + + (1.0 - edge_fscore) + + (1.0 - crop_fscore) + dist_mae - + (1.0 - edge_mcc) - + (1.0 - crop_mcc) + + (1.0 - edge_mcc.clamp_min(0)) + + (1.0 - crop_mcc.clamp_min(0)) ) metrics = { "loss": loss, "dist_mae": dist_mae, "dist_mse": dist_mse, - "edge_f1": edge_score, - "crop_f1": crop_score, + "edge_f1": edge_fscore, + "crop_f1": crop_fscore, "edge_mcc": edge_mcc, "crop_mcc": crop_mcc, - "edge_dice": edge_dice, - "crop_dice": crop_dice, - "edge_jaccard": edge_jaccard, - "crop_jaccard": crop_jaccard, "score": total_score, } - if predictions["crop_type"] is not None: - crop_type_ypred = self.probas_to_labels( - self.logits_to_probas(predictions["crop_type"]) - ) - crop_type_score = self.crop_type_f1( - crop_type_ypred, true_labels_dict["true_crop_type"] - ) - metrics["crop_type_f1"] = crop_type_score + + metrics.update(loss_report) return metrics def validation_step(self, batch: Data, batch_idx: int = None) -> dict: """Executes one valuation step.""" + eval_metrics = self._shared_eval_step(batch, batch_idx) metrics = { - "val_loss": eval_metrics["loss"], "vef1": eval_metrics["edge_f1"], "vcf1": eval_metrics["crop_f1"], "vmae": eval_metrics["dist_mae"], "val_score": eval_metrics["score"], + "val_loss": eval_metrics["loss"], + "val_dloss": eval_metrics["dloss"], + "val_eloss": eval_metrics["eloss"], + "val_closs": eval_metrics["closs"], } - if "crop_type_f1" in eval_metrics: - metrics["vctf1"] = eval_metrics["crop_type_f1"] - self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) + self.log_dict( + metrics, + on_step=False, + on_epoch=True, + prog_bar=True, + batch_size=batch.num_samples, + ) if self.save_batch_val_metrics: self._save_batch_metrics(metrics, self.current_epoch, batch) @@ -855,6 +535,7 @@ def _save_batch_metrics( def test_step(self, batch: Data, batch_idx: int = None) -> dict: """Executes one test step.""" + eval_metrics = self._shared_eval_step(batch, batch_idx) metrics = { @@ -879,73 +560,81 @@ def test_step(self, batch: Data, batch_idx: int = None) -> dict: return metrics def configure_scorer(self): - self.dist_mae = torchmetrics.MeanAbsoluteError() - self.dist_mse = torchmetrics.MeanSquaredError() - self.edge_f1 = torchmetrics.F1Score(num_classes=2, average="micro") - self.crop_f1 = torchmetrics.F1Score(num_classes=2, average="micro") - self.edge_mcc = torchmetrics.MatthewsCorrCoef(num_classes=2) - self.crop_mcc = torchmetrics.MatthewsCorrCoef(num_classes=2) - self.edge_dice = torchmetrics.Dice(num_classes=2, average="micro") - self.crop_dice = torchmetrics.Dice(num_classes=2, average="micro") - self.edge_jaccard = torchmetrics.JaccardIndex( - average="micro", num_classes=2 + """The fβ value. + + To put equal weight on precision and recall, set to 1. To emphasize + minimizing false positives, set to <1. To emphasize minimizing false + negatives, set to >1. + """ + + self.mae_scorer = torchmetrics.MeanAbsoluteError() + self.mse_scorer = torchmetrics.MeanSquaredError() + self.f_beta_scorer = torchmetrics.FBetaScore( + task="multiclass", num_classes=2, beta=2.0 ) - self.crop_jaccard = torchmetrics.JaccardIndex( - average="micro", num_classes=2 + self.mcc_scorer = torchmetrics.MatthewsCorrCoef( + task="multiclass", num_classes=2 ) - if self.num_classes > 2: - self.crop_type_f1 = torchmetrics.F1Score( - num_classes=self.num_classes, - average="weighted", - ignore_index=0, - ) + + def calc_weights(self, counts: torch.Tensor) -> torch.Tensor: + """Calculates class weights.""" + + num_samples = counts.sum() + num_classes = len(counts) + class_weights = num_samples / (num_classes * counts) + weights = torch.nan_to_num(class_weights, nan=0, neginf=0, posinf=0) + + return weights def configure_loss(self): - self.dist_loss = TanimotoDistLoss() - if self.deep_sup_dist: - self.dist_loss_3_1 = TanimotoDistLoss() - self.dist_loss_2_2 = TanimotoDistLoss() - self.dist_loss_1_3 = TanimotoDistLoss() - # Edge losses - self.edge_loss = TanimotoDistLoss() - if self.deep_sup_edge: - self.edge_loss_3_1 = TanimotoDistLoss() - self.edge_loss_2_2 = TanimotoDistLoss() - self.edge_loss_1_3 = TanimotoDistLoss() - # Crop mask losses - self.crop_loss = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) - if self.deep_sup_mask: - self.crop_loss_3_1 = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) - self.crop_loss_2_2 = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) - self.crop_loss_1_3 = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) - # Crop RNN losses - self.crop_star_l2_loss = TanimotoDistLoss() - self.crop_star_loss = TanimotoDistLoss() - # FIXME: - if self.num_classes > 2: - self.crop_type_star_loss = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) - self.crop_type_loss = TanimotoDistLoss( - scale_pos_weight=self.scale_pos_weight - ) + """Configures loss methods.""" + + # # Weights for crop AND edge + # crop_and_edge_counts = torch.zeros(2, device=self.class_counts.device) + # crop_and_edge_counts[0] = self.class_counts[0] + # crop_and_edge_counts[1] = self.class_counts[1:].sum() + # self.crop_and_edge_weights = self.calc_weights(crop_and_edge_counts) + + # # Weights for crop OR edge + # self.crop_or_edge_weights = self.calc_weights(self.class_counts) + + # # Weights for crop + # crop_counts = torch.zeros(2, device=self.class_counts.device) + # crop_counts[0] = self.class_counts[0] + # crop_counts[1] = self.class_counts[1] + # self.crop_weights = self.calc_weights(crop_counts) + + # Main loss + self.reg_loss = LOSS_DICT[self.loss_name].get("regression") + self.cls_loss = LOSS_DICT[self.loss_name].get("classification") def configure_optimizers(self): + """Configures optimizers.""" + params_list = list(self.cultionet_model.parameters()) - if self.optimizer == "AdamW": + interval = 'epoch' + if self.optimizer == "Adam": + optimizer = torch.optim.Adam( + params_list, + lr=self.learning_rate, + eps=self.eps, + ) + elif self.optimizer == "AdamW": optimizer = torch.optim.AdamW( params_list, lr=self.learning_rate, weight_decay=self.weight_decay, eps=self.eps, + betas=(0.9, 0.98), + ) + elif self.optimizer == "RAdam": + optimizer = torch.optim.RAdam( + params_list, + lr=self.learning_rate, + weight_decay=self.weight_decay, + decoupled_weight_decay=True, + eps=self.eps, + betas=(0.9, 0.99), ) elif self.optimizer == "SGD": optimizer = torch.optim.SGD( @@ -957,15 +646,23 @@ def configure_optimizers(self): else: raise NameError("Choose either 'AdamW' or 'SGD'.") - if self.lr_scheduler == "ExponentialLR": + if self.lr_scheduler == LearningRateSchedulers.COSINE_ANNEALING_LR: + model_lr_scheduler = optim_lr_scheduler.CosineAnnealingLR( + optimizer, T_max=20, eta_min=1e-5, last_epoch=-1 + ) + elif self.lr_scheduler == LearningRateSchedulers.EXPONENTIAL_LR: model_lr_scheduler = optim_lr_scheduler.ExponentialLR( optimizer, gamma=0.5 ) - elif self.lr_scheduler == "CosineAnnealingLR": - model_lr_scheduler = optim_lr_scheduler.CosineAnnealingLR( - optimizer, T_max=20, eta_min=1e-5, last_epoch=-1 + elif self.lr_scheduler == LearningRateSchedulers.ONE_CYCLE_LR: + model_lr_scheduler = optim_lr_scheduler.OneCycleLR( + optimizer, + max_lr=self.learning_rate, + epochs=self.trainer.max_epochs, + steps_per_epoch=self.trainer.estimated_stepping_batches, ) - elif self.lr_scheduler == "StepLR": + interval = 'step' + elif self.lr_scheduler == LearningRateSchedulers.STEP_LR: model_lr_scheduler = optim_lr_scheduler.StepLR( optimizer, step_size=self.steplr_step_size, gamma=0.5 ) @@ -980,7 +677,228 @@ def configure_optimizers(self): "scheduler": model_lr_scheduler, "name": "lr_sch", "monitor": "val_score", - "interval": "epoch", + "interval": interval, "frequency": 1, }, } + + +class CultionetLitTransferModel(LightningModuleMixin): + """Transfer learning module for Cultionet.""" + + def __init__( + self, + pretrained_ckpt_file: T.Union[Path, str], + in_channels: int, + in_time: int, + hidden_channels: int = 64, + model_type: str = ModelTypes.TOWERUNET, + dropout: float = 0.2, + activation_type: str = "SiLU", + dilations: T.Union[int, T.Sequence[int]] = None, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + optimizer: str = "AdamW", + loss_name: str = LossTypes.TANIMOTO_COMPLEMENT, + learning_rate: float = 0.01, + lr_scheduler: str = LearningRateSchedulers.ONE_CYCLE_LR, + steplr_step_size: int = 5, + weight_decay: float = 1e-3, + eps: float = 1e-4, + ckpt_name: str = ModelNames.CKPT_TRANSFER_NAME.replace(".ckpt", ""), + model_name: str = "cultionet_transfer", + pool_by_max: bool = False, + batchnorm_first: bool = False, + class_counts: T.Optional[torch.Tensor] = None, + edge_class: T.Optional[int] = None, + scale_pos_weight: bool = False, + save_batch_val_metrics: bool = False, + finetune: T.Optional[str] = None, + ): + super().__init__() + + self.save_hyperparameters() + + self.optimizer = optimizer + self.loss_name = loss_name + self.learning_rate = learning_rate + self.lr_scheduler = lr_scheduler + self.steplr_step_size = steplr_step_size + self.weight_decay = weight_decay + self.eps = eps + self.ckpt_name = ckpt_name + self.model_name = model_name + self.in_time = in_time + self.class_counts = class_counts + self.scale_pos_weight = scale_pos_weight + self.save_batch_val_metrics = save_batch_val_metrics + self.finetune = finetune + + if edge_class is not None: + self.edge_class = edge_class + else: + self.edge_class = 2 + + self.cultionet_model = CultionetLitModel.load_from_checkpoint( + checkpoint_path=str(pretrained_ckpt_file) + ).cultionet_model + + if self.finetune != "all": + + # Freeze all parameters if not finetuning the full model + self.freeze(self.cultionet_model) + + if self.finetune == "fc": + # Unfreeze fully connected layers + for name, param in self.cultionet_model.named_parameters(): + if name.startswith("mask_model.final_"): + param.requires_grad = True + + else: + + # Update the post-UNet layer with trainable parameters + mask_model_final_a = cunn.TowerUNetFinal( + in_channels=self.cultionet_model.mask_model.final_a.in_channels, + num_classes=self.cultionet_model.mask_model.final_a.num_classes, + activation_type=activation_type, + ) + mask_model_final_a.apply(init_conv_weights) + self.cultionet_model.mask_model.final_a = mask_model_final_a + + mask_model_final_b = cunn.TowerUNetFinal( + in_channels=self.cultionet_model.mask_model.final_b.in_channels, + num_classes=self.cultionet_model.mask_model.final_b.num_classes, + activation_type=activation_type, + resample_factor=2, + ) + mask_model_final_b.apply(init_conv_weights) + self.cultionet_model.mask_model.final_b = mask_model_final_b + + mask_model_final_c = cunn.TowerUNetFinal( + in_channels=self.cultionet_model.mask_model.final_c.in_channels, + num_classes=self.cultionet_model.mask_model.final_c.num_classes, + activation_type=activation_type, + resample_factor=4, + ) + mask_model_final_c.apply(init_conv_weights) + self.cultionet_model.mask_model.final_c = mask_model_final_c + + mask_model_final_combine = cunn.TowerUNetFinalCombine( + num_classes=self.cultionet_model.mask_model.final_combine.num_classes, + edge_activation=self.cultionet_model.mask_model.final_combine.edge_activation, + mask_activation=self.cultionet_model.mask_model.final_combine.mask_activation, + ) + mask_model_final_combine.apply(init_conv_weights) + self.cultionet_model.mask_model.final_combine = ( + mask_model_final_combine + ) + + self.model_attr = f"{model_name}_{model_type}" + setattr( + self, + self.model_attr, + self.cultionet_model, + ) + + self.configure_loss() + self.configure_scorer() + + @property + def is_transfer_model(self) -> bool: + return True + + def freeze(self, layer): + for param in layer.parameters(): + param.requires_grad = False + + def unfreeze(self, layer): + for param in layer.parameters(): + param.requires_grad = True + + return layer + + +class CultionetLitModel(LightningModuleMixin): + def __init__( + self, + in_channels: int, + in_time: int, + hidden_channels: int = 64, + model_type: str = ModelTypes.TOWERUNET, + dropout: float = 0.2, + activation_type: str = "SiLU", + dilations: T.Union[int, T.Sequence[int]] = None, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + optimizer: str = "AdamW", + loss_name: str = LossTypes.TANIMOTO_COMPLEMENT, + learning_rate: float = 0.01, + lr_scheduler: str = LearningRateSchedulers.ONE_CYCLE_LR, + steplr_step_size: int = 5, + weight_decay: float = 1e-3, + eps: float = 1e-4, + ckpt_name: str = "last", + model_name: str = "cultionet", + pool_by_max: bool = False, + batchnorm_first: bool = False, + class_counts: T.Optional[torch.Tensor] = None, + edge_class: T.Optional[int] = None, + scale_pos_weight: bool = False, + save_batch_val_metrics: bool = False, + ): + """Lightning model.""" + + super().__init__() + + self.save_hyperparameters() + + self.optimizer = optimizer + self.loss_name = loss_name + self.learning_rate = learning_rate + self.lr_scheduler = lr_scheduler + self.steplr_step_size = steplr_step_size + self.weight_decay = weight_decay + self.eps = eps + self.ckpt_name = ckpt_name + self.model_name = model_name + self.in_time = in_time + self.class_counts = class_counts + self.scale_pos_weight = scale_pos_weight + self.save_batch_val_metrics = save_batch_val_metrics + + if edge_class is not None: + self.edge_class = edge_class + else: + self.edge_class = 2 + + self.model_attr = f"{model_name}_{model_type}" + setattr( + self, + self.model_attr, + CultioNet( + in_channels=in_channels, + in_time=in_time, + hidden_channels=hidden_channels, + model_type=model_type, + dropout=dropout, + activation_type=activation_type, + dilations=dilations, + res_block_type=res_block_type, + attention_weights=attention_weights, + pool_by_max=pool_by_max, + batchnorm_first=batchnorm_first, + ), + ) + + self.configure_loss() + self.configure_scorer() + + @property + def is_transfer_model(self) -> bool: + return False + + # def on_train_epoch_start(self): + # # Get the current learning rate from the optimizer + # weight_decay = self.optimizers().optimizer.param_groups[0]['weight_decay'] + # if (weight_decay != self.weight_decay) or (eps != self.eps): + # self.configure_optimizers() diff --git a/src/cultionet/models/maskcrnn.py b/src/cultionet/models/maskcrnn.py deleted file mode 100644 index 2cb801a3..00000000 --- a/src/cultionet/models/maskcrnn.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Backbone source: https://github.com/VSainteuf/utae- -paps/blob/main/src/backbones/utae.py.""" -import typing as T - -import torch -from torchvision.models.detection.rpn import AnchorGenerator -from torchvision.models.detection.faster_rcnn import FastRCNNPredictor -from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor -from torchvision.models.detection.transform import GeneralizedRCNNTransform -from torchvision.models.detection import maskrcnn_resnet50_fpn_v2 - - -class BFasterRCNN(torch.nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - num_classes: int, - sizes: T.Optional[T.Sequence[int]] = None, - aspect_ratios: T.Optional[T.Sequence[int]] = None, - trainable_backbone_layers: T.Optional[int] = 3, - min_image_size: int = 800, - max_image_size: int = 1333, - ) -> None: - super(BFasterRCNN, self).__init__() - - if sizes is None: - sizes = (32, 64, 128, 256, 512) - if not isinstance(sizes, tuple): - try: - sizes = tuple(sizes) - except TypeError as e: - raise TypeError(e) - - if aspect_ratios is None: - aspect_ratios = (0.5, 1.0, 2.0) - if not isinstance(aspect_ratios, tuple): - try: - aspect_ratios = tuple(aspect_ratios) - except TypeError as e: - raise TypeError(e) - - # Load a pretrained model - self.model = maskrcnn_resnet50_fpn_v2( - weights="DEFAULT", - trainable_backbone_layers=trainable_backbone_layers, - ) - # Remove image normalization and add custom resizing - self.model.transform = GeneralizedRCNNTransform( - image_mean=(0.0,) * in_channels, - image_std=(1.0,) * in_channels, - min_size=min_image_size, - max_size=max_image_size, - ) - # Replace the first convolution - out_channels = self.model.backbone.body.conv1.out_channels - self.model.backbone.body.conv1 = torch.nn.Conv2d( - in_channels, - out_channels, - kernel_size=7, - stride=2, - padding=3, - bias=False, - ) - self.model.rpn.anchor_generator = AnchorGenerator( - sizes=tuple((size,) for size in sizes), - aspect_ratios=(aspect_ratios,) * len(sizes), - ) - # Update the output classes in the predictor heads - in_features = self.model.roi_heads.box_predictor.cls_score.in_features - self.model.roi_heads.box_predictor = FastRCNNPredictor( - in_features, num_classes - ) - in_features_mask = ( - self.model.roi_heads.mask_predictor.conv5_mask.in_channels - ) - self.model.roi_heads.mask_predictor = MaskRCNNPredictor( - in_features_mask, out_channels, num_classes - ) - - def __call__(self, *args, **kwargs): - return self.forward(*args, **kwargs) - - def forward( - self, x: torch.Tensor, y: T.Optional[torch.Tensor] = None - ) -> torch.Tensor: - return self.model(x, y) diff --git a/src/cultionet/models/model_utils.py b/src/cultionet/models/model_utils.py deleted file mode 100644 index 7a87b47e..00000000 --- a/src/cultionet/models/model_utils.py +++ /dev/null @@ -1,74 +0,0 @@ -import typing as T - -import torch -from torch_geometric import nn -from torch_geometric.data import Data - - -def get_batch_count(batch: torch.Tensor) -> int: - return batch.unique().size(0) - - -class UpSample(torch.nn.Module): - """Up-samples a tensor.""" - - def __init__(self): - super(UpSample, self).__init__() - - def forward( - self, x: torch.Tensor, size: T.Sequence[int], mode: str = "bilinear" - ) -> torch.Tensor: - upsampler = torch.nn.Upsample(size=size, mode=mode, align_corners=True) - - return upsampler(x) - - -class GraphToConv(torch.nn.Module): - """Reshapes a 2d tensor to a 4d tensor.""" - - def __init__(self): - super(GraphToConv, self).__init__() - - def forward( - self, x: torch.Tensor, nbatch: int, nrows: int, ncols: int - ) -> torch.Tensor: - n_channels = x.shape[1] - return x.reshape(nbatch, nrows, ncols, n_channels).permute(0, 3, 1, 2) - - -class ConvToGraph(torch.nn.Module): - """Reshapes a 4d tensor to a 2d tensor.""" - - def __init__(self): - super(ConvToGraph, self).__init__() - - def forward(self, x: torch.Tensor) -> torch.Tensor: - nbatch, n_channels, nrows, ncols = x.shape - - return x.permute(0, 2, 3, 1).reshape( - nbatch * nrows * ncols, n_channels - ) - - -class ConvToTime(torch.nn.Module): - """Reshapes a 4d tensor to a 5d tensor.""" - - def __init__(self): - super(ConvToTime, self).__init__() - - def forward( - self, x: torch.Tensor, nbands: int, ntime: int - ) -> torch.Tensor: - nbatch, __, height, width = x.shape - - return x.reshape(nbatch, nbands, ntime, height, width) - - -def max_pool_neighbor_x( - x: torch.Tensor, edge_index: torch.Tensor -) -> torch.Tensor: - return nn.max_pool_neighbor_x(Data(x=x, edge_index=edge_index)).x - - -def global_max_pool(x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor: - return nn.global_max_pool(x=x, batch=batch, size=x.shape[0]) diff --git a/src/cultionet/models/nunet.py b/src/cultionet/models/nunet.py index 122684b4..a40fa732 100644 --- a/src/cultionet/models/nunet.py +++ b/src/cultionet/models/nunet.py @@ -7,1152 +7,259 @@ import typing as T import torch +import torch.nn as nn +from einops.layers.torch import Rearrange -from . import model_utils -from . import kernels -from .base_layers import ( - AttentionGate, - DoubleConv, - SpatioTemporalConv3d, - Min, - Max, - Mean, - Std, - Permute, - PoolConv, - PoolResidualConv, - ResidualConv, - ResidualAConv, - SingleConv, - Softmax, - SigmoidCrisp, - Squeeze, - SetActivation, -) -from .enums import ResBlockTypes -from .unet_parts import ( - UNet3P_3_1, - UNet3P_2_2, - UNet3P_1_3, - UNet3P_0_4, - UNet3_3_1, - UNet3_2_2, - UNet3_1_3, - UNet3_0_4, - ResUNet3_3_1, - ResUNet3_2_2, - ResUNet3_1_3, - ResUNet3_0_4, -) +from .. import nn as cunn +from ..enums import AttentionTypes, ResBlockTypes +from ..layers.weights import init_conv_weights -def weights_init_kaiming(m): - """ - Source: - https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/init_weights.py - """ - classname = m.__class__.__name__ - if classname.find("Conv") != -1: - torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") - elif classname.find("Linear") != -1: - torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") - elif classname.find("BatchNorm") != -1: - torch.nn.init.normal_(m.weight.data, 1.0, 0.02) - torch.nn.init.constant_(m.bias.data, 0.0) - - -class UNet2(torch.nn.Module): - """UNet++ - - References: - https://arxiv.org/pdf/1807.10165.pdf - https://arxiv.org/pdf/1804.03999.pdf - https://github.com/4uiiurz1/pytorch-nested-unet/blob/master/archs.py - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - init_filter: int = 64, - boundary_layer: bool = False, - out_side_channels: int = 2, - linear_fc: bool = False, - deep_supervision: bool = False, - ): - super(UNet2, self).__init__() - - self.linear_fc = linear_fc - self.boundary_layer = boundary_layer - self.deep_supervision = deep_supervision - - init_filter = int(init_filter) - channels = [ - init_filter, - init_filter * 2, - init_filter * 4, - init_filter * 8, - init_filter * 16, - ] - - self.up = model_utils.UpSample() - - self.attention_0 = AttentionGate( - high_channels=channels[3], low_channels=channels[4] - ) - self.attention_1 = AttentionGate( - high_channels=channels[2], low_channels=channels[3] - ) - self.attention_2 = AttentionGate( - high_channels=channels[1], low_channels=channels[2] - ) - self.attention_3 = AttentionGate( - high_channels=channels[0], low_channels=channels[1] - ) - - if boundary_layer: - # Right stream - self.bound4_1 = DoubleConv(channels[4] + channels[4], channels[0]) - self.bound3_1 = DoubleConv( - channels[0] + channels[3] * 2, channels[0] - ) - self.bound2_1 = DoubleConv( - channels[0] + channels[2] * 2, channels[0] - ) - self.bound1_1 = DoubleConv( - channels[0] + channels[1] * 2, channels[0] - ) - self.bound0_1 = DoubleConv( - channels[0] + channels[0] * 2, channels[0] - ) - # Left stream - self.bound0_0 = ResidualConv(channels[0], channels[0]) - self.bound0_0_pool = PoolConv(channels[0], channels[1]) - self.bound1_0 = DoubleConv(channels[1] * 2, channels[1]) - self.bound1_0_pool = PoolConv(channels[1], channels[2]) - self.bound2_0 = DoubleConv(channels[2] * 2, channels[2]) - self.bound2_0_pool = PoolConv(channels[2], channels[3]) - self.bound3_0 = DoubleConv(channels[3] * 2, channels[3]) - self.bound3_0_pool = PoolConv(channels[3], channels[4]) - self.bound4_0 = DoubleConv(channels[4] * 2, channels[4]) - - self.bound_final = torch.nn.Conv2d( - channels[0], out_side_channels, kernel_size=1, padding=0 - ) - - self.conv0_0 = ResidualConv(in_channels, channels[0]) - self.conv1_0 = PoolConv(channels[0], channels[1], dropout=0.25) - self.conv2_0 = PoolConv(channels[1], channels[2], dropout=0.5) - self.conv3_0 = PoolConv(channels[2], channels[3], dropout=0.5) - self.conv4_0 = PoolConv(channels[3], channels[4], dropout=0.5) - - self.conv0_1 = ResidualConv(channels[0] + channels[1], channels[0]) - self.conv1_1 = DoubleConv(channels[1] + channels[2], channels[1]) - self.conv2_1 = DoubleConv(channels[2] + channels[3], channels[2]) - self.conv3_1 = DoubleConv(channels[3] + channels[4], channels[3]) - - self.conv0_2 = ResidualConv(channels[0] * 2 + channels[1], channels[0]) - self.conv1_2 = DoubleConv(channels[1] * 2 + channels[2], channels[1]) - self.conv2_2 = DoubleConv(channels[2] * 2 + channels[3], channels[2]) - - self.conv0_3 = ResidualConv(channels[0] * 3 + channels[1], channels[0]) - self.conv1_3 = DoubleConv(channels[1] * 3 + channels[2], channels[1]) - - self.conv0_4 = ResidualConv(channels[0] * 4 + channels[1], channels[0]) - - if self.linear_fc: - self.net_final = torch.nn.Sequential( - torch.nn.LeakyReLU(inplace=False), - Permute((0, 2, 3, 1)), - torch.nn.Linear(channels[0], out_channels), - Permute((0, 3, 1, 2)), - ) - else: - if self.deep_supervision: - in_final_layers = out_channels - - self.final_1 = torch.nn.Conv2d( - channels[0], out_channels, kernel_size=1, padding=0 - ) - self.final_2 = torch.nn.Conv2d( - channels[0], out_channels, kernel_size=1, padding=0 - ) - self.final_3 = torch.nn.Conv2d( - channels[0], out_channels, kernel_size=1, padding=0 - ) - self.final_4 = torch.nn.Conv2d( - channels[0], out_channels, kernel_size=1, padding=0 - ) - else: - in_final_layers = channels[0] - - if boundary_layer: - in_final_layers += out_side_channels - - self.net_final = torch.nn.Conv2d( - in_final_layers, out_channels, kernel_size=1, padding=0 - ) - - # Initialise weights - for m in self.modules(): - if isinstance(m, (torch.nn.Conv2d, torch.nn.BatchNorm2d)): - m.apply(weights_init_kaiming) - - def forward( - self, x: torch.Tensor - ) -> T.Dict[str, T.Union[None, torch.Tensor]]: - mask = None - boundary = None - - x0_0 = self.conv0_0(x) - x1_0 = self.conv1_0(x0_0) - - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/1 - x0_1 = self.conv0_1( - torch.cat([x0_0, self.up(x1_0, size=x0_0.shape[-2:])], dim=1) - ) - - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/2 - x1_1 = self.conv1_1( - torch.cat([x1_0, self.up(x2_0, size=x1_0.shape[-2:])], dim=1) - ) - # 1/1 - x0_2 = self.conv0_2( - torch.cat([x0_0, x0_1, self.up(x1_1, size=x0_1.shape[-2:])], dim=1) - ) - - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/4 - x2_1 = self.conv2_1( - torch.cat([x2_0, self.up(x3_0, size=x2_0.shape[-2:])], dim=1) - ) - # 1/2 - x1_2 = self.conv1_2( - torch.cat([x1_0, x1_1, self.up(x2_1, size=x1_1.shape[-2:])], dim=1) - ) - # 1/1 - x0_3 = self.conv0_3( - torch.cat( - [x0_0, x0_1, x0_2, self.up(x1_2, size=x0_2.shape[-2:])], dim=1 - ) - ) - - # 1/16 - x4_0 = self.conv4_0(x3_0) - x3_0 = self.attention_0(x3_0, x4_0) - # 1/8 - x3_1 = self.conv3_1( - torch.cat([x3_0, self.up(x4_0, size=x3_0.shape[-2:])], dim=1) - ) - x2_1 = self.attention_1(x2_1, x3_1) - # 1/4 - x2_2 = self.conv2_2( - torch.cat([x2_0, x2_1, self.up(x3_1, size=x2_1.shape[-2:])], dim=1) - ) - x1_2 = self.attention_2(x1_2, x2_2) - # 1/2 - x1_3 = self.conv1_3( - torch.cat( - [x1_0, x1_1, x1_2, self.up(x2_2, size=x1_2.shape[-2:])], dim=1 - ) - ) - x0_3 = self.attention_3(x0_3, x1_3) - # 1/1 - x0_4 = self.conv0_4( - torch.cat( - [x0_0, x0_1, x0_2, x0_3, self.up(x1_3, size=x0_3.shape[-2:])], - dim=1, - ) - ) - - if self.boundary_layer: - # Left stream - b0_0 = self.bound0_0(x0_0) - b1_0 = self.bound1_0( - torch.cat([x1_0, self.bound0_0_pool(b0_0)], dim=1) - ) - b2_0 = self.bound2_0( - torch.cat([x2_0, self.bound1_0_pool(b1_0)], dim=1) - ) - b3_0 = self.bound3_0( - torch.cat([x3_0, self.bound2_0_pool(b2_0)], dim=1) - ) - b4_0 = self.bound4_0( - torch.cat([x4_0, self.bound3_0_pool(b3_0)], dim=1) - ) - # Right stream - b4_1 = self.bound4_1(torch.cat([b4_0, x4_0], dim=1)) - b3_1 = self.bound3_1( - torch.cat( - [x3_1, b3_0, self.up(b4_1, size=x3_1.shape[-2:])], dim=1 - ) - ) - b2_1 = self.bound2_1( - torch.cat( - [x2_2, b2_0, self.up(b3_1, size=x2_2.shape[-2:])], dim=1 - ) - ) - b1_1 = self.bound1_1( - torch.cat( - [x1_3, b1_0, self.up(b2_1, size=x1_3.shape[-2:])], dim=1 - ) - ) - boundary = self.bound0_1( - torch.cat( - [x0_4, b0_0, self.up(b1_1, size=x0_4.shape[-2:])], dim=1 - ) - ) - - if self.linear_fc: - mask = self.net_final(x0_4) - else: - if self.deep_supervision: - # Average over skip connections - x0_1 = self.final_1(x0_1) - x0_2 = self.final_2(x0_2) - x0_3 = self.final_3(x0_3) - x0_4 = self.final_4(x0_4) - x0_4 = (x0_1 + x0_2 + x0_3 + x0_4) / 4.0 - if self.boundary_layer: - boundary = self.bound_final(boundary) - mask = self.net_final(torch.cat([x0_4, boundary], dim=1)) - else: - mask = self.net_final(x0_4) - - return {"mask": mask, "boundary": boundary} - - -class UNet3(torch.nn.Module): - """UNet+++ - - References: - https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf - """ - +class Conv3d(nn.Module): def __init__( self, in_channels: int, + in_time: int, out_channels: int, - init_filter: int = 64, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "SiLU", - ): - super(UNet3, self).__init__() - - init_filter = int(init_filter) - channels = [ - init_filter, - init_filter * 2, - init_filter * 4, - init_filter * 8, - init_filter * 16, - ] - up_channels = int(channels[0] * 5) - - self.up = model_utils.UpSample() - - self.conv0_0 = SingleConv( - in_channels, channels[0], activation_type=activation_type - ) - self.conv1_0 = PoolConv( - in_channels=channels[0], - out_channels=channels[1], - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.conv2_0 = PoolConv( - in_channels=channels[1], - out_channels=channels[2], - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.conv3_0 = PoolConv( - in_channels=channels[2], - out_channels=channels[3], - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.conv4_0 = PoolConv( - in_channels=channels[3], - out_channels=channels[4], - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - - # Connect 3 - self.convs_3_1 = UNet3P_3_1( - channels=channels, - up_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.convs_2_2 = UNet3P_2_2( - channels=channels, - up_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.convs_1_3 = UNet3P_1_3( - channels=channels, - up_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - self.convs_0_4 = UNet3P_0_4( - channels=channels, - up_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=double_dilation, - activation_type=activation_type, - ) - - self.final = torch.nn.Conv2d( - in_channels=up_channels, - out_channels=out_channels, - kernel_size=1, - padding=0, - ) - - # Initialise weights - for m in self.modules(): - if isinstance(m, (torch.nn.Conv2d, torch.nn.BatchNorm2d)): - m.apply(weights_init_kaiming) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - # Backbone - # 1/1 - x0_0 = self.conv0_0(x) - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/16 - x4_0 = self.conv4_0(x3_0) - - # 1/8 connection - out_3_1 = self.convs_3_1( - x0_0=x0_0, x1_0=x1_0, x2_0=x2_0, x3_0=x3_0, x4_0=x4_0 - ) - # 1/4 connection - out_2_2 = self.convs_2_2( - x0_0=x0_0, x1_0=x1_0, x2_0=x2_0, h3_1=out_3_1, x4_0=x4_0 - ) - # 1/2 connection - out_1_3 = self.convs_1_3( - x0_0=x0_0, x1_0=x1_0, h2_2=out_2_2, h3_1=out_3_1, x4_0=x4_0 - ) - # 1/1 connection - out_0_4 = self.convs_0_4( - x0_0=x0_0, h1_3=out_1_3, h2_2=out_2_2, h3_1=out_3_1, x4_0=x4_0 - ) - - out = self.final(out_0_4) - - return out - - -class PreUnet3Psi(torch.nn.Module): - def __init__( - self, - in_channels: int, - channels: T.Sequence[int], + kernel_size: int, activation_type: str, - trend_kernel_size: int = 5, ): - super(PreUnet3Psi, self).__init__() - - self.cg = model_utils.ConvToGraph() - self.gc = model_utils.GraphToConv() - - self.peak_kernel = kernels.Peaks(kernel_size=trend_kernel_size) - self.pos_trend_kernel = kernels.Trend( - kernel_size=trend_kernel_size, direction="positive" - ) - self.neg_trend_kernel = kernels.Trend( - kernel_size=trend_kernel_size, direction="negative" - ) - self.reduce_trend_to_time = torch.nn.Sequential( - SpatioTemporalConv3d( - in_channels=int(in_channels * 3), - out_channels=1, - activation_type=activation_type, - ), - Squeeze(dim=1), - ) - - self.time_conv0 = SpatioTemporalConv3d( - in_channels=in_channels, - out_channels=channels[0], - activation_type=activation_type, - ) - self.reduce_to_time = torch.nn.Sequential( - SpatioTemporalConv3d( - in_channels=channels[0], - out_channels=1, - activation_type=activation_type, + super().__init__() + + remaining_time = in_time - kernel_size + 1 + + self.seq = nn.Sequential( + # Reduce time + nn.Conv3d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=(kernel_size, 1, 1), + padding=0, + bias=False, ), - Squeeze(dim=1), - ) - # (B x C x T|D x H x W) - # Temporal reductions - # Reduce to 2d (B x C x H x W) - self.reduce_to_channels_min = torch.nn.Sequential( - Min(dim=2), - torch.nn.BatchNorm2d(channels[0]), - SetActivation(activation_type=activation_type), - ) - self.reduce_to_channels_max = torch.nn.Sequential( - Max(dim=2), - torch.nn.BatchNorm2d(channels[0]), - SetActivation(activation_type=activation_type), - ) - self.reduce_to_channels_mean = torch.nn.Sequential( - Mean(dim=2), - torch.nn.BatchNorm2d(channels[0]), - SetActivation(activation_type=activation_type), - ) - self.reduce_to_channels_std = torch.nn.Sequential( - Std(dim=2), - torch.nn.BatchNorm2d(channels[0]), - SetActivation(activation_type=activation_type), - ) - - def forward(self, x: torch.Tensor, rnn_h: torch.Tensor) -> torch.Tensor: - peak_kernels = [] - pos_trend_kernels = [] - neg_trend_kernels = [] - for bidx in range(0, x.shape[1]): - # (B x C x T x H x W) -> (B x T x H x W) - band_input = x[:, bidx] - # (B x T x H x W) -> (B*H*W x T) -> (B*H*W x 1(C) x T) - band_input = self.cg(band_input).unsqueeze(1) - peak_res = self.peak_kernel(band_input) - pos_trend_res = self.pos_trend_kernel(band_input) - neg_trend_res = self.neg_trend_kernel(band_input) - # Reshape (B*H*W x 1(C) x T) -> (B x C X T x H x W) - peak_kernels += [ - self.gc( - # (B*H*W x T) - peak_res.squeeze(), - nbatch=x.shape[0], - nrows=x.shape[-2], - ncols=x.shape[-1], - ).unsqueeze(1) - ] - pos_trend_kernels += [ - self.gc( - # (B*H*W x T) - pos_trend_res.squeeze(), - nbatch=x.shape[0], - nrows=x.shape[-2], - ncols=x.shape[-1], - ).unsqueeze(1) - ] - neg_trend_kernels += [ - self.gc( - # (B*H*W x T) - neg_trend_res.squeeze(), - nbatch=x.shape[0], - nrows=x.shape[-2], - ncols=x.shape[-1], - ).unsqueeze(1) - ] - # Concatentate along the channels - trend_kernels = torch.cat( - peak_kernels + pos_trend_kernels + neg_trend_kernels, dim=1 - ) - - # Inputs shape is (B x C X T|D x H x W) - h = self.time_conv0(x) - h = torch.cat( - [ - self.reduce_to_time(h), - self.reduce_to_channels_min(h), - self.reduce_to_channels_max(h), - self.reduce_to_channels_mean(h), - self.reduce_to_channels_std(h), - rnn_h, - self.reduce_trend_to_time(trend_kernels), - ], - dim=1, - ) - - return h - - -class PostUNet3Psi(torch.nn.Module): - def __init__( - self, - up_channels: int, - num_classes: int, - mask_activation: T.Callable, - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - ): - super(PostUNet3Psi, self).__init__() - - self.deep_sup_dist = deep_sup_dist - self.deep_sup_edge = deep_sup_edge - self.deep_sup_mask = deep_sup_mask - - self.up = model_utils.UpSample() - - self.final_dist = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - torch.nn.Sigmoid(), - ) - self.final_edge = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - SigmoidCrisp(), - ) - self.final_mask = torch.nn.Sequential( - torch.nn.Conv2d( - up_channels, num_classes, kernel_size=1, padding=0 + nn.BatchNorm3d(in_channels), + cunn.SetActivation(activation_type=activation_type), + # Reduce time to 1 + nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=(remaining_time, 1, 1), + padding=0, + bias=False, ), - mask_activation, + # c = channels; t = 1 + Rearrange('b c 1 h w -> b c h w'), + nn.BatchNorm2d(out_channels), + cunn.SetActivation(activation_type=activation_type), ) - if self.deep_sup_dist: - self.final_dist_3_1 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - torch.nn.Sigmoid(), - ) - self.final_dist_2_2 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - torch.nn.Sigmoid(), - ) - self.final_dist_1_3 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - torch.nn.Sigmoid(), - ) - if self.deep_sup_edge: - self.final_edge_3_1 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - SigmoidCrisp(), - ) - self.final_edge_2_2 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - SigmoidCrisp(), - ) - self.final_edge_1_3 = torch.nn.Sequential( - torch.nn.Conv2d(up_channels, 1, kernel_size=1, padding=0), - SigmoidCrisp(), - ) - if self.deep_sup_mask: - self.final_mask_3_1 = torch.nn.Sequential( - torch.nn.Conv2d( - up_channels, num_classes, kernel_size=1, padding=0 - ), - mask_activation, - ) - self.final_mask_2_2 = torch.nn.Sequential( - torch.nn.Conv2d( - up_channels, num_classes, kernel_size=1, padding=0 - ), - mask_activation, - ) - self.final_mask_1_3 = torch.nn.Sequential( - torch.nn.Conv2d( - up_channels, num_classes, kernel_size=1, padding=0 - ), - mask_activation, - ) - - def forward( - self, - out_0_4: T.Dict[str, torch.Tensor], - out_3_1: T.Dict[str, torch.Tensor], - out_2_2: T.Dict[str, torch.Tensor], - out_1_3: T.Dict[str, torch.Tensor], - ) -> T.Dict[str, torch.Tensor]: - dist = self.final_dist(out_0_4["dist"]) - edge = self.final_edge(out_0_4["edge"]) - mask = self.final_mask(out_0_4["mask"]) - - out = { - "dist": dist, - "edge": edge, - "mask": mask, - "dist_3_1": None, - "dist_2_2": None, - "dist_1_3": None, - "edge_3_1": None, - "edge_2_2": None, - "edge_1_3": None, - "mask_3_1": None, - "mask_2_2": None, - "mask_1_3": None, - } - if self.deep_sup_dist: - out["dist_3_1"] = self.final_dist_3_1( - self.up(out_3_1["dist"], size=dist.shape[-2:], mode="bilinear") - ) - out["dist_2_2"] = self.final_dist_2_2( - self.up(out_2_2["dist"], size=dist.shape[-2:], mode="bilinear") - ) - out["dist_1_3"] = self.final_dist_1_3( - self.up(out_1_3["dist"], size=dist.shape[-2:], mode="bilinear") - ) - if self.deep_sup_edge: - out["edge_3_1"] = self.final_edge_3_1( - self.up(out_3_1["edge"], size=edge.shape[-2:], mode="bilinear") - ) - out["edge_2_2"] = self.final_edge_2_2( - self.up(out_2_2["edge"], size=edge.shape[-2:], mode="bilinear") - ) - out["edge_1_3"] = self.final_edge_1_3( - self.up(out_1_3["edge"], size=edge.shape[-2:], mode="bilinear") - ) - if self.deep_sup_mask: - out["mask_3_1"] = self.final_mask_3_1( - self.up(out_3_1["mask"], size=mask.shape[-2:], mode="bilinear") - ) - out["mask_2_2"] = self.final_mask_2_2( - self.up(out_2_2["mask"], size=mask.shape[-2:], mode="bilinear") - ) - out["mask_1_3"] = self.final_mask_1_3( - self.up(out_1_3["mask"], size=mask.shape[-2:], mode="bilinear") - ) - - return out - - -class UNet3Psi(torch.nn.Module): - """UNet+++ with Psi-Net. + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.seq(x) - References: - https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf - https://arxiv.org/abs/1902.04099 - https://github.com/Bala93/Multi-task-deep-network - """ +class PreTimeReduction(nn.Module): def __init__( self, in_channels: int, in_time: int, - in_rnn_channels: int, - init_filter: int = 32, - num_classes: int = 2, - dilation: int = 2, - activation_type: str = "SiLU", - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - mask_activation: T.Union[Softmax, torch.nn.Sigmoid] = Softmax(dim=1), + out_channels: int, + activation_type: str, ): - super(UNet3Psi, self).__init__() - - init_filter = int(init_filter) - channels = [ - init_filter, - init_filter * 2, - init_filter * 4, - init_filter * 8, - init_filter * 16, - ] - up_channels = int(channels[0] * 5) + super().__init__() - self.pre_unet = PreUnet3Psi( + self.conv3 = Conv3d( in_channels=in_channels, - channels=channels, - activation_type=activation_type, - ) - - # Inputs = - # Reduced time dimensions - # Reduced channels (x2) for mean and max - # Input filters for RNN hidden logits - self.conv0_0 = SingleConv( - in_channels=( - in_time - + int(channels[0] * 4) - + in_rnn_channels - # Peak kernels and Trend kernels - + in_time - ), - out_channels=channels[0], - activation_type=activation_type, - ) - self.conv1_0 = PoolConv( - channels[0], - channels[1], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv2_0 = PoolConv( - channels[1], - channels[2], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv3_0 = PoolConv( - channels[2], - channels[3], - double_dilation=dilation, - activation_type=activation_type, - ) - self.conv4_0 = PoolConv( - channels[3], - channels[4], - double_dilation=dilation, + in_time=in_time, + out_channels=out_channels, + kernel_size=3, activation_type=activation_type, ) - # Connect 3 - self.convs_3_1 = UNet3_3_1( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_2_2 = UNet3_2_2( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_1_3 = UNet3_1_3( - channels=channels, - up_channels=up_channels, - dilations=[dilation], - activation_type=activation_type, - ) - self.convs_0_4 = UNet3_0_4( - channels=channels, - up_channels=up_channels, - dilations=[dilation], + self.conv5 = Conv3d( + in_channels=in_channels, + in_time=in_time, + out_channels=out_channels, + kernel_size=5, activation_type=activation_type, ) - self.post_unet = PostUNet3Psi( - up_channels=up_channels, - num_classes=num_classes, - mask_activation=mask_activation, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, + self.layer_norm = nn.Sequential( + Rearrange('b c h w -> b h w c'), + nn.LayerNorm(out_channels), + Rearrange('b h w c -> b c h w'), ) - # Initialise weights - for m in self.modules(): - if isinstance( - m, - ( - torch.nn.Conv2d, - torch.nn.BatchNorm2d, - torch.nn.Conv3d, - torch.nn.BatchNorm3d, - ), - ): - m.apply(weights_init_kaiming) - - def forward( - self, x: torch.Tensor, rnn_h: torch.Tensor - ) -> T.Dict[str, T.Union[None, torch.Tensor]]: - # Inputs shape is (B x C X T|D x H x W) - h = self.pre_unet(x, rnn_h) - # h shape is (B x C x H x W) - # Backbone - # 1/1 - x0_0 = self.conv0_0(h) - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/16 - x4_0 = self.conv4_0(x3_0) - - # 1/8 connection - out_3_1 = self.convs_3_1( - x0_0=x0_0, x1_0=x1_0, x2_0=x2_0, x3_0=x3_0, x4_0=x4_0 - ) - # 1/4 connection - out_2_2 = self.convs_2_2( - x0_0=x0_0, - x1_0=x1_0, - x2_0=x2_0, - h3_1_dist=out_3_1["dist"], - h3_1_edge=out_3_1["edge"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) - # 1/2 connection - out_1_3 = self.convs_1_3( - x0_0=x0_0, - x1_0=x1_0, - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h2_2_edge=out_2_2["edge"], - h3_1_edge=out_3_1["edge"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) - # 1/1 connection - out_0_4 = self.convs_0_4( - x0_0=x0_0, - h1_3_dist=out_1_3["dist"], - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h1_3_edge=out_1_3["edge"], - h2_2_edge=out_2_2["edge"], - h3_1_edge=out_3_1["edge"], - h1_3_mask=out_1_3["mask"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, - ) + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Parameters + ========== + x + Input, shaped (B, C, T, H, W). + """ - out = self.post_unet( - out_0_4=out_0_4, out_3_1=out_3_1, out_2_2=out_2_2, out_1_3=out_1_3 - ) + x3 = self.conv3(x) + x5 = self.conv5(x) - return out + encoded = self.layer_norm(x3 + x5) + return encoded -class ResUNet3Psi(torch.nn.Module): - """Residual UNet+++ with Psi-Net (Multi-head streams) and Attention. - References: - https://arxiv.org/ftp/arxiv/papers/2004/2004.08790.pdf - https://arxiv.org/abs/1902.04099 - https://github.com/Bala93/Multi-task-deep-network - https://github.com/hamidriasat/UNet-3-Plus - """ +class TowerUNet(nn.Module): + """Tower U-Net.""" def __init__( self, in_channels: int, in_time: int, - in_rnn_channels: int, - init_filter: int = 32, - num_classes: int = 2, - dilations: T.Sequence[int] = None, - activation_type: str = "LeakyReLU", - res_block_type: str = "resa", - attention_weights: T.Optional[str] = None, - deep_sup_dist: T.Optional[bool] = False, - deep_sup_edge: T.Optional[bool] = False, - deep_sup_mask: T.Optional[bool] = False, - mask_activation: T.Union[Softmax, torch.nn.Sigmoid] = Softmax(dim=1), + hidden_channels: int = 64, + num_classes: int = 1, + dilations: T.Optional[T.Sequence[int]] = None, + activation_type: str = "SiLU", + dropout: float = 0.0, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + pool_by_max: bool = False, + batchnorm_first: bool = False, + edge_activation: bool = True, + mask_activation: bool = True, + use_latlon: bool = False, ): - super(ResUNet3Psi, self).__init__() + super().__init__() + + if dilations is None: + dilations = [1, 2] - init_filter = int(init_filter) channels = [ - init_filter, - init_filter * 2, - init_filter * 4, - init_filter * 8, - init_filter * 16, + hidden_channels, # a + hidden_channels * 2, # b + hidden_channels * 4, # c + hidden_channels * 8, # d ] - up_channels = int(channels[0] * 5) + up_channels = int(hidden_channels * len(channels)) - self.pre_unet = PreUnet3Psi( - in_channels=in_channels, - channels=channels, - activation_type=activation_type, - ) - - # Inputs = - # Reduced time dimensions - # Reduced channels (x2) for mean and max - # Input filters for RNN hidden logits - if res_block_type.lower() == "res": - self.conv0_0 = ResidualConv( - in_channels=( - in_time - + int(channels[0] * 4) - + in_rnn_channels - # Peak kernels and Trend kernels - + in_time - ), - out_channels=channels[0], - dilation=dilations[0], - activation_type=activation_type, - attention_weights=attention_weights, - ) - else: - self.conv0_0 = ResidualAConv( - in_channels=( - in_time - + int(channels[0] * 4) - + in_rnn_channels - # Peak kernels and Trend kernels - + in_time - ), + self.pre_unet = torch.compile( + PreTimeReduction( + in_channels=in_channels, + in_time=in_time, out_channels=channels[0], - dilations=dilations, activation_type=activation_type, - attention_weights=attention_weights, ) - self.conv1_0 = PoolResidualConv( - channels[0], - channels[1], - dilations=dilations, - attention_weights=attention_weights, - res_block_type=ResBlockTypes[res_block_type.upper()], ) - self.conv2_0 = PoolResidualConv( - channels[1], - channels[2], - dilations=dilations, - activation_type=activation_type, - attention_weights=attention_weights, - res_block_type=ResBlockTypes[res_block_type.upper()], - ) - self.conv3_0 = PoolResidualConv( - channels[2], - channels[3], + + self.encoder = cunn.TowerUNetEncoder( + channels=channels, dilations=dilations, activation_type=activation_type, - attention_weights=attention_weights, - res_block_type=ResBlockTypes[res_block_type.upper()], + dropout=dropout, + res_block_type=res_block_type, + attention_weights=None, + pool_by_max=pool_by_max, + batchnorm_first=batchnorm_first, ) - self.conv4_0 = PoolResidualConv( - channels[3], - channels[4], + + self.decoder = cunn.TowerUNetDecoder( + channels=channels, + up_channels=up_channels, dilations=dilations, activation_type=activation_type, + dropout=dropout, + res_block_type=res_block_type, attention_weights=attention_weights, - res_block_type=ResBlockTypes[res_block_type.upper()], + batchnorm_first=batchnorm_first, ) - # Connect 3 - self.convs_3_1 = ResUNet3_3_1( + self.tower_fusion = cunn.TowerUNetFusion( channels=channels, up_channels=up_channels, dilations=dilations, - attention_weights=attention_weights, activation_type=activation_type, - res_block_type=ResBlockTypes[res_block_type.upper()], + dropout=dropout, + res_block_type=res_block_type, + attention_weights=None, + batchnorm_first=batchnorm_first, + use_latlon=use_latlon, ) - self.convs_2_2 = ResUNet3_2_2( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, + + self.final_a = cunn.TowerUNetFinal( + in_channels=up_channels, + num_classes=num_classes, activation_type=activation_type, - res_block_type=ResBlockTypes[res_block_type.upper()], ) - self.convs_1_3 = ResUNet3_1_3( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, + + self.final_b = cunn.TowerUNetFinal( + in_channels=up_channels, + num_classes=num_classes, activation_type=activation_type, - res_block_type=ResBlockTypes[res_block_type.upper()], + resample_factor=2, ) - self.convs_0_4 = ResUNet3_0_4( - channels=channels, - up_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, + + self.final_c = cunn.TowerUNetFinal( + in_channels=up_channels, + num_classes=num_classes, activation_type=activation_type, - res_block_type=ResBlockTypes[res_block_type.upper()], + resample_factor=4, ) - self.post_unet = PostUNet3Psi( - up_channels=up_channels, + self.final_combine = cunn.TowerUNetFinalCombine( num_classes=num_classes, + edge_activation=edge_activation, mask_activation=mask_activation, - deep_sup_dist=deep_sup_dist, - deep_sup_edge=deep_sup_edge, - deep_sup_mask=deep_sup_mask, ) - # Initialise weights - for m in self.modules(): - if isinstance( - m, - ( - torch.nn.Conv2d, - torch.nn.BatchNorm2d, - torch.nn.Conv3d, - torch.nn.BatchNorm3d, - ), - ): - m.apply(weights_init_kaiming) + # Initialize weights + self.apply(init_conv_weights) def forward( - self, x: torch.Tensor, rnn_h: torch.Tensor - ) -> T.Dict[str, T.Union[None, torch.Tensor]]: - # Inputs shape is (B x C X T|D x H x W) - h = self.pre_unet(x, rnn_h) - # h shape is (B x C x H x W) - # Backbone - # 1/1 - x0_0 = self.conv0_0(h) - # 1/2 - x1_0 = self.conv1_0(x0_0) - # 1/4 - x2_0 = self.conv2_0(x1_0) - # 1/8 - x3_0 = self.conv3_0(x2_0) - # 1/16 - x4_0 = self.conv4_0(x3_0) + self, + x: torch.Tensor, + latlon_coords: T.Optional[torch.Tensor] = None, + ) -> T.Dict[str, torch.Tensor]: + + """Forward pass. + + Parameters + ========== + x + The input image time series, shaped (B, C, T, H, W). + """ + + # Initial temporal reduction and convolutions to + # hidden dimensions + embeddings = self.pre_unet(x) - # 1/8 connection - out_3_1 = self.convs_3_1( - x0_0=x0_0, x1_0=x1_0, x2_0=x2_0, x3_0=x3_0, x4_0=x4_0 + encoded = self.encoder(embeddings) + decoded = self.decoder(encoded) + towers_fused = self.tower_fusion( + encoded=encoded, + decoded=decoded, + latlon_coords=latlon_coords, ) - # 1/4 connection - out_2_2 = self.convs_2_2( - x0_0=x0_0, - x1_0=x1_0, - x2_0=x2_0, - h3_1_dist=out_3_1["dist"], - h3_1_edge=out_3_1["edge"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, + + # Final outputs + + # -> {InferenceNames.DISTANCE_a, InferenceNames.EDGE_a, InferenceNames.CROP_a} + out_a = self.final_a( + towers_fused["x_tower_a"], + suffix="_a", ) - # 1/2 connection - out_1_3 = self.convs_1_3( - x0_0=x0_0, - x1_0=x1_0, - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h2_2_edge=out_2_2["edge"], - h3_1_edge=out_3_1["edge"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, + + # -> {InferenceNames.DISTANCE_b, InferenceNames.EDGE_b, InferenceNames.CROP_b} + out_b = self.final_b( + towers_fused["x_tower_b"], + size=towers_fused["x_tower_a"].shape[-2:], + suffix="_b", ) - # 1/1 connection - out_0_4 = self.convs_0_4( - x0_0=x0_0, - h1_3_dist=out_1_3["dist"], - h2_2_dist=out_2_2["dist"], - h3_1_dist=out_3_1["dist"], - h1_3_edge=out_1_3["edge"], - h2_2_edge=out_2_2["edge"], - h3_1_edge=out_3_1["edge"], - h1_3_mask=out_1_3["mask"], - h2_2_mask=out_2_2["mask"], - h3_1_mask=out_3_1["mask"], - x4_0=x4_0, + + # -> {InferenceNames.DISTANCE_c, InferenceNames.EDGE_c, InferenceNames.CROP_c} + out_c = self.final_c( + towers_fused["x_tower_c"], + size=towers_fused["x_tower_a"].shape[-2:], + suffix="_c", ) - out = self.post_unet( - out_0_4=out_0_4, out_3_1=out_3_1, out_2_2=out_2_2, out_1_3=out_1_3 + out = self.final_combine( + out_a, out_b, out_c, suffixes=["_a", "_b", "_c"] ) return out diff --git a/src/cultionet/models/unet_parts.py b/src/cultionet/models/unet_parts.py deleted file mode 100644 index fff24395..00000000 --- a/src/cultionet/models/unet_parts.py +++ /dev/null @@ -1,1268 +0,0 @@ -import typing as T -import enum - -import torch - -from . import model_utils -from .base_layers import ( - AttentionGate, - AtrousPyramidPooling, - DoubleConv, - PoolConv, - PoolResidualConv, - ResidualAConv, - ResidualConv, -) -from .enums import ModelTypes, ResBlockTypes - - -class UNet3Connector(torch.nn.Module): - """Connects layers in a UNet 3+ architecture.""" - - def __init__( - self, - channels: T.List[int], - up_channels: int, - prev_backbone_channel_index: int, - use_backbone: bool = True, - is_side_stream: bool = True, - n_pools: int = 0, - n_prev_down: int = 0, - n_stream_down: int = 0, - attention_weights: str = "spatial_channel", - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - model_type: enum = ModelTypes.UNET, - res_block_type: enum = ResBlockTypes.RESA, - activation_type: str = "LeakyReLU", - ): - super(UNet3Connector, self).__init__() - - assert attention_weights in [ - "gate", - "fractal", - "spatial_channel", - ], "Choose from 'gate', 'fractal', or 'spatial_channel' attention weights." - - assert model_type in (ModelTypes.UNET, ModelTypes.RESUNET) - assert res_block_type in (ResBlockTypes.RES, ResBlockTypes.RESA) - - self.n_pools = n_pools - self.n_prev_down = n_prev_down - self.n_stream_down = n_stream_down - self.attention_weights = attention_weights - self.use_backbone = use_backbone - self.is_side_stream = is_side_stream - self.cat_channels = 0 - self.pool4_0 = None - - self.up = model_utils.UpSample() - - if dilations is None: - dilations = [2] - - # Pool layers - if n_pools > 0: - if n_pools == 3: - pool_size = 8 - elif n_pools == 2: - pool_size = 4 - else: - pool_size = 2 - - for n in range(0, n_pools): - if model_type == ModelTypes.UNET: - setattr( - self, - f"pool_{n}", - PoolConv( - in_channels=channels[n], - out_channels=channels[0], - pool_size=pool_size, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"pool_{n}", - PoolResidualConv( - in_channels=channels[n], - out_channels=channels[0], - pool_size=pool_size, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - res_block_type=res_block_type, - ), - ) - pool_size = int(pool_size / 2) - self.cat_channels += channels[0] - if self.use_backbone: - if model_type == ModelTypes.UNET: - self.prev_backbone = DoubleConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.prev_backbone = ResidualConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.prev_backbone = ResidualAConv( - in_channels=channels[prev_backbone_channel_index], - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.cat_channels += up_channels - if self.is_side_stream: - if model_type == ModelTypes.UNET: - # Backbone, same level - self.prev = DoubleConv( - in_channels=up_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.prev = ResidualConv( - in_channels=up_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.prev = ResidualAConv( - in_channels=up_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.cat_channels += up_channels - # Previous output, downstream - if self.n_prev_down > 0: - for n in range(0, self.n_prev_down): - if model_type == ModelTypes.UNET: - setattr( - self, - f"prev_{n}", - DoubleConv( - in_channels=up_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - if res_block_type == ResBlockTypes.RES: - setattr( - self, - f"prev_{n}", - ResidualConv( - in_channels=up_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"prev_{n}", - ResidualAConv( - in_channels=up_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - self.cat_channels += up_channels - # Previous output, (same) downstream - if self.n_stream_down > 0: - for n in range(0, self.n_stream_down): - in_stream_channels = up_channels - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - attention_module = AttentionGate(up_channels, up_channels) - setattr(self, f"attn_stream_{n}", attention_module) - in_stream_channels = up_channels * 2 - if model_type == ModelTypes.UNET: - setattr( - self, - f"stream_{n}", - DoubleConv( - in_channels=in_stream_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ), - ) - else: - if res_block_type == ResBlockTypes.RES: - setattr( - self, - f"stream_{n}", - ResidualConv( - in_channels=in_stream_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - else: - setattr( - self, - f"stream_{n}", - ResidualAConv( - in_channels=in_stream_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ), - ) - self.cat_channels += up_channels - - self.cat_channels += channels[0] - if model_type == ModelTypes.UNET: - self.conv4_0 = DoubleConv( - in_channels=channels[4], - out_channels=channels[0], - init_point_conv=init_point_conv, - activation_type=activation_type, - ) - self.final = DoubleConv( - in_channels=self.cat_channels, - out_channels=up_channels, - init_point_conv=init_point_conv, - double_dilation=dilations[0], - activation_type=activation_type, - ) - else: - if res_block_type == ResBlockTypes.RES: - self.conv4_0 = ResidualConv( - in_channels=channels[4], - out_channels=channels[0], - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.final = ResidualConv( - in_channels=self.cat_channels, - out_channels=up_channels, - dilation=dilations[0], - attention_weights=attention_weights, - activation_type=activation_type, - ) - else: - self.conv4_0 = ResidualAConv( - in_channels=channels[4], - out_channels=channels[0], - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - self.final = ResidualAConv( - in_channels=self.cat_channels, - out_channels=up_channels, - dilations=dilations, - attention_weights=attention_weights, - activation_type=activation_type, - ) - # self.pool4_0 = AtrousPyramidPooling( - # in_channels=channels[0], - # out_channels=channels[0] - # ) - - def forward( - self, - prev_same: T.List[T.Tuple[str, torch.Tensor]], - x4_0: torch.Tensor, - pools: T.List[torch.Tensor] = None, - prev_down: T.List[torch.Tensor] = None, - stream_down: T.List[torch.Tensor] = None, - ): - h = [] - # Pooling layer of the backbone - if pools is not None: - assert self.n_pools == len( - pools - ), "There are no convolutions available for the pool layers." - for n, x in zip(range(self.n_pools), pools): - c = getattr(self, f"pool_{n}") - h += [c(x)] - # Up down layers from the previous head - if prev_down is not None: - assert self.n_prev_down == len( - prev_down - ), "There are no convolutions available for the previous downstream layers." - for n, x in zip(range(self.n_prev_down), prev_down): - c = getattr(self, f"prev_{n}") - h += [ - c( - self.up( - x, size=prev_same[0][1].shape[-2:], mode="bilinear" - ) - ) - ] - assert len(prev_same) == sum( - [self.use_backbone, self.is_side_stream] - ), "The previous same layers do not match the setup." - # Previous same layers from the previous head - for conv_name, prev_inputs in prev_same: - c = getattr(self, conv_name) - h += [c(prev_inputs)] - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - prev_same_hidden = h[-1].clone() - # Previous down layers from the same head - if stream_down is not None: - assert self.n_stream_down == len( - stream_down - ), "There are no convolutions available for the downstream layers." - for n, x in zip(range(self.n_stream_down), stream_down): - if self.attention_weights is not None and ( - self.attention_weights == "gate" - ): - # Gate - g = self.up( - x, size=prev_same[0][1].shape[-2:], mode="bilinear" - ) - c_attn = getattr(self, f"attn_stream_{n}") - # Attention gate - attn_out = c_attn(g, prev_same_hidden) - c = getattr(self, f"stream_{n}") - # Concatenate attention weights - h += [c(torch.cat([attn_out, g], dim=1))] - else: - c = getattr(self, f"stream_{n}") - h += [ - c( - self.up( - x, - size=prev_same[0][1].shape[-2:], - mode="bilinear", - ) - ) - ] - - # Lowest level - x4_0_up = self.conv4_0( - self.up(x4_0, size=prev_same[0][1].shape[-2:], mode="bilinear") - ) - if self.pool4_0 is not None: - h += [self.pool4_0(x4_0_up)] - else: - h += [x4_0_up] - h = torch.cat(h, dim=1) - h = self.final(h) - - return h - - -class UNet3P_3_1(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(UNet3P_3_1, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - x3_0: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x3_0)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - - return h - - -class UNet3P_2_2(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 2,2.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(UNet3P_2_2, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x2_0)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1], - ) - - return h - - -class UNet3P_1_3(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 1,3.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(UNet3P_1_3, self).__init__() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - h2_2: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> torch.Tensor: - h = self.conv( - prev_same=[("prev_backbone", x1_0)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1, h2_2], - ) - - return h - - -class UNet3P_0_4(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 0,4.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - double_dilation: int = 1, - activation_type: str = "LeakyReLU", - ): - super(UNet3P_0_4, self).__init__() - - self.up = model_utils.UpSample() - - self.conv = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=[double_dilation], - model_type=ModelTypes.UNET, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - h1_3: torch.Tensor, - h2_2: torch.Tensor, - h3_1: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h = self.conv( - prev_same=[("prev_backbone", x0_0)], - x4_0=x4_0, - stream_down=[h3_1, h2_2, h1_3], - ) - - return h - - -class UNet3_3_1(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "LeakyReLU", - ): - super(UNet3_3_1, self).__init__() - - self.up = model_utils.UpSample() - - # Distance stream connection - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - # Edge stream connection - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - # Mask stream connection - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=3, - n_pools=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - x3_0: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - # Distance logits - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x3_0)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output distance logits pass to edge layer - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x3_0), ("prev", h_dist)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output edge logits pass to mask layer - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x3_0), ("prev", h_edge)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class UNet3_2_2(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 2,2.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "LeakyReLU", - ): - super(UNet3_2_2, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - h3_1_dist: torch.Tensor, - h3_1_edge: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x2_0)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x2_0), ("prev", h_dist)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x2_0), ("prev", h_edge)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class UNet3_1_3(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 1,3.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "LeakyReLU", - ): - super(UNet3_1_3, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x1_0)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x1_0), ("prev", h_dist)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x1_0), ("prev", h_edge)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class UNet3_0_4(torch.nn.Module): - """UNet 3+ connection from backbone to upstream 0,4.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - init_point_conv: bool = False, - dilations: T.Sequence[int] = None, - activation_type: str = "LeakyReLU", - ): - super(UNet3_0_4, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - is_side_stream=False, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - prev_backbone_channel_index=0, - n_stream_down=3, - init_point_conv=init_point_conv, - dilations=dilations, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - h1_3_dist: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h1_3_edge: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h1_3_mask: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x0_0)], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist, h1_3_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x0_0), ("prev", h_dist)], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge, h1_3_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x0_0), ("prev", h_edge)], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask, h1_3_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class ResUNet3_3_1(torch.nn.Module): - """Residual UNet 3+ connection from backbone to upstream 3,1.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - dilations: T.Sequence[int] = None, - attention_weights: str = "spatial_channel", - activation_type: str = "LeakyReLU", - res_block_type: enum = ResBlockTypes.RESA, - ): - super(ResUNet3_3_1, self).__init__() - - self.up = model_utils.UpSample() - - # Distance stream connection - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=3, - n_pools=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - # Edge stream connection - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=3, - n_pools=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - # Mask stream connection - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=3, - n_pools=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - x3_0: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - # Distance logits - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x3_0)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output distance logits pass to edge layer - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x3_0), ("prev", h_dist)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - # Output edge logits pass to mask layer - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x3_0), ("prev", h_edge)], - pools=[x0_0, x1_0, x2_0], - x4_0=x4_0, - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class ResUNet3_2_2(torch.nn.Module): - """Residual UNet 3+ connection from backbone to upstream 2,2.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - dilations: T.Sequence[int] = None, - attention_weights: str = "spatial_channel", - activation_type: str = "LeakyReLU", - res_block_type: enum = ResBlockTypes.RESA, - ): - super(ResUNet3_2_2, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=2, - n_pools=2, - n_stream_down=1, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - x2_0: torch.Tensor, - h3_1_dist: torch.Tensor, - h3_1_edge: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x2_0)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x2_0), ("prev", h_dist)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x2_0), ("prev", h_edge)], - pools=[x0_0, x1_0], - x4_0=x4_0, - stream_down=[h3_1_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class ResUNet3_1_3(torch.nn.Module): - """Residual UNet 3+ connection from backbone to upstream 1,3.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - dilations: T.Sequence[int] = None, - attention_weights: str = "spatial_channel", - activation_type: str = "LeakyReLU", - res_block_type: enum = ResBlockTypes.RESA, - ): - super(ResUNet3_1_3, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=1, - n_pools=1, - n_stream_down=2, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - x1_0: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x1_0)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x1_0), ("prev", h_dist)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x1_0), ("prev", h_edge)], - pools=[x0_0], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } - - -class ResUNet3_0_4(torch.nn.Module): - """Residual UNet 3+ connection from backbone to upstream 0,4.""" - - def __init__( - self, - channels: T.Sequence[int], - up_channels: int, - dilations: T.Sequence[int] = None, - attention_weights: str = "spatial_channel", - activation_type: str = "LeakyReLU", - res_block_type: enum = ResBlockTypes.RESA, - ): - super(ResUNet3_0_4, self).__init__() - - self.up = model_utils.UpSample() - - self.conv_dist = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=False, - prev_backbone_channel_index=0, - n_stream_down=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_edge = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=0, - n_stream_down=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - self.conv_mask = UNet3Connector( - channels=channels, - up_channels=up_channels, - use_backbone=True, - is_side_stream=True, - prev_backbone_channel_index=0, - n_stream_down=3, - dilations=dilations, - attention_weights=attention_weights, - model_type=ModelTypes.RESUNET, - res_block_type=res_block_type, - activation_type=activation_type, - ) - - def forward( - self, - x0_0: torch.Tensor, - h1_3_dist: torch.Tensor, - h2_2_dist: torch.Tensor, - h3_1_dist: torch.Tensor, - h1_3_edge: torch.Tensor, - h2_2_edge: torch.Tensor, - h3_1_edge: torch.Tensor, - h1_3_mask: torch.Tensor, - h2_2_mask: torch.Tensor, - h3_1_mask: torch.Tensor, - x4_0: torch.Tensor, - ) -> T.Dict[str, torch.Tensor]: - h_dist = self.conv_dist( - prev_same=[("prev_backbone", x0_0)], - x4_0=x4_0, - stream_down=[h3_1_dist, h2_2_dist, h1_3_dist], - ) - h_edge = self.conv_edge( - prev_same=[("prev_backbone", x0_0), ("prev", h_dist)], - x4_0=x4_0, - stream_down=[h3_1_edge, h2_2_edge, h1_3_edge], - ) - h_mask = self.conv_mask( - prev_same=[("prev_backbone", x0_0), ("prev", h_edge)], - x4_0=x4_0, - stream_down=[h3_1_mask, h2_2_mask, h1_3_mask], - ) - - return { - "dist": h_dist, - "edge": h_edge, - "mask": h_mask, - } diff --git a/src/cultionet/networks/__init__.py b/src/cultionet/networks/__init__.py deleted file mode 100644 index da5abcf8..00000000 --- a/src/cultionet/networks/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._build_network import SingleSensorNetwork, MultiSensorNetwork - -__all__ = ["SingleSensorNetwork", "MultiSensorNetwork"] diff --git a/src/cultionet/networks/_build_network.pyx b/src/cultionet/networks/_build_network.pyx deleted file mode 100644 index 5540724a..00000000 --- a/src/cultionet/networks/_build_network.pyx +++ /dev/null @@ -1,916 +0,0 @@ -# distutils: language=c++ -# cython: language_level=3 -# cython: profile=False -# cython: cdivision=True -# cython: boundscheck=False -# cython: wraparound=False -# cython: nonecheck=False - -import cython -cimport cython - -import numpy as np -cimport numpy as np - -from libc.stdint cimport int64_t - - -cdef extern from 'stdlib.h' nogil: - double fabs(double value) - - -cdef extern from 'stdlib.h' nogil: - int abs(int value) - - -cdef extern from 'numpy/npy_math.h' nogil: - bint npy_isnan(double x) - - -# nogil vector -cdef extern from "" namespace "std": - cdef cppclass vector[T]: - void push_back(T&) nogil - size_t size() nogil - T& operator[](size_t) nogil - void clear() nogil - - -cdef inline double _euclidean_distance(double xloc, double yloc, double xh, double yh) nogil: - return ((xloc - xh)**2 + (yloc - yh)**2)**0.5 - - -cdef inline double _get_max(double v1, double v2) nogil: - return v1 if v1 >= v2 else v2 - - -cdef inline double _clip_high(double v, double high) nogil: - return high if v > high else v - - -cdef inline double _clip(double v, double low, double high) nogil: - return low if v < low else _clip_high(v, high) - - -cdef inline double _scale_min_max(double xv, double mni, double mxi, double mno, double mxo) nogil: - return (((mxo - mno) * (xv - mni)) / (mxi - mni)) + mno - - -cdef double _get_mean_3d( - double[:, :, ::1] data, - unsigned int nbands, - unsigned int ridx, - unsigned int cidx -) nogil: - """Returns the band-wise mean - """ - cdef: - Py_ssize_t n - double data_mean = 0.0 - double data_val - - for n in range(0, nbands): - data_val = data[n, ridx, cidx] - data_mean += data_val - - return data_mean / nbands - - -cdef double _get_max_3d( - double[:, :, ::1] data, - unsigned int nbands, - unsigned int ridx, - unsigned int cidx -) nogil: - """Returns the band-wise maximum - """ - cdef: - Py_ssize_t n - double data_max = -1e9 - double data_val - - for n in range(0, nbands): - data_val = data[n, ridx, cidx] - data_max = _get_max(data_val, data_max) - - return data_max - - -cdef double _get_max_4d( - double[:, :, :, ::1] data, - unsigned int ntime, - unsigned int nbands, - unsigned int ridx, - unsigned int cidx -) nogil: - """Returns the time- and band-wise maximum - """ - cdef: - Py_ssize_t m, n - double data_max = -1e9 - double data_val - - for m in range(0, ntime): - for n in range(0, nbands): - data_val = data[m, n, ridx, cidx] - data_max = _get_max(data_val, data_max) - - return data_max - - -cdef double _determinant_transform(vector[double] t) nogil: - """The determinant of the transform matrix. - This value is equal to the area scaling factor when the - transform is applied to a shape. - - Reference: - https://github.com/sgillies/affine/blob/master/affine/__init__.py - - Copyright (c) 2014, Sean C. Gillies - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of Sean C. Gillies nor the names of - its contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. - """ - cdef: - double sa, sb, sc, sd, se, sf - - sa, sb, sc, sd, se, sf = t[0], t[1], t[2], t[3], t[4], t[5] - - return sa * se - sb * sd - - -cdef vector[double] _invert_transform(vector[double] t) nogil: - """Returns the inverse transform - - Reference: - https://github.com/sgillies/affine/blob/master/affine/__init__.py - - Copyright (c) 2014, Sean C. Gillies - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of Sean C. Gillies nor the names of - its contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. - """ - cdef: - vector[double] t_ - double idet - double sa, sb, sc, sd, se, sf - double ra, rb, rd, re - - idet = 1.0 / _determinant_transform(t) - sa, sb, sc, sd, se, sf = t[0], t[1], t[2], t[3], t[4], t[5] - ra = se * idet - rb = -sb * idet - rd = -sd * idet - re = sa * idet - - t_.push_back(ra) - t_.push_back(rb) - t_.push_back(-sc * ra - sf * rb) - t_.push_back(rd) - t_.push_back(re) - t_.push_back(-sc * rd - sf * re) - - return t_ - - -cdef void _transform_coords_to_indices( - vector[double] t, - double vx, - double vy, - int64_t[::1] out_indices__ -) nogil: - """Transforms coordinates to indices - - Reference: - https://github.com/sgillies/affine/blob/master/affine/__init__.py - - Copyright (c) 2014, Sean C. Gillies - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - * Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - * Neither the name of Sean C. Gillies nor the names of - its contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - - THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE - ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE - LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR - CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF - SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS - INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN - CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) - ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE - POSSIBILITY OF SUCH DAMAGE. - """ - cdef: - double sa, sb, sc, sd, se, sf - - sa, sb, sc, sd, se, sf = t[0], t[1], t[2], t[3], t[4], t[5] - - out_indices__[0] = (vx * sa + vy * sb + sc) - out_indices__[1] = (vx * sd + vy * se + sf) - - -cdef void _coarse_transformer( - Py_ssize_t i, - Py_ssize_t j, - unsigned int kh, - vector[double] hr_transform_, - vector[double] cr_transform_, - int64_t[::1] out_indices_ -) nogil: - """Transforms coordinates to indices for a coarse-to-high resolution transformation - - Args: - i (int): The row index position for the high-resolution grid. - j (int): The column index position for the high-resolution grid. - kh (int): The center pixel offset for the high-resolution grid. - hr_transform (list): The high-resolution affine transform. - cr_transform (list): The coarse-resolution affine transform. - """ - cdef: - double x, y - Py_ssize_t row_index, col_index - - # Coordinates of the high-resolution center pixel - x = hr_transform_[2] + ((j+kh) * fabs(hr_transform_[0])) - y = hr_transform_[5] - ((i+kh) * fabs(hr_transform_[4])) - - # Invert the coarse resolution affine transform and - # get the indices at the x,y coordinates. - _transform_coords_to_indices(_invert_transform(cr_transform_), x, y, out_indices_) - - -cdef class SingleSensorNetwork(object): - """A network class for a single sensor - """ - cdef: - int64_t[:, ::1] grid - vector[int] edge_indices_a - vector[int] edge_indices_b - vector[double] edge_attrs_diffs, edge_attrs_dists - vector[double] xpos, ypos - unsigned int nbands, nrows, ncols - double[:, :, ::1] varray - int k_, kh - double cell_size_ - double max_dist, max_scaled, eps - - def __init__( - self, - double[:, :, ::1] value_array, - int k=3, - float cell_size=30.0 - ): - self.nbands = value_array.shape[0] - self.nrows = value_array.shape[1] - self.ncols = value_array.shape[2] - self.k_ = k - self.kh = (self.k_ * 0.5) - self.cell_size_ = cell_size - self.varray = value_array - - self.max_dist = _euclidean_distance(0.0, 0.0, self.kh, self.kh) - self.max_scaled = 1.0 - (_euclidean_distance(self.kh, self.kh-1, self.kh, self.kh) / self.max_dist) - self.eps = 1e-6 - - self.grid = np.arange(0, self.nrows*self.ncols).reshape(self.nrows, self.ncols).astype('int64') - - def create_network(self): - self._create_network() - - return self.edge_indices_a, self.edge_indices_b, self.edge_attrs_diffs, self.edge_attrs_dists, self.xpos, self.ypos - - cdef void _create_network(self) nogil: - cdef: - Py_ssize_t i, j, m, n - bint do_connect - unsigned int column_end = self.ncols - self.k_ - - for i in range(0, self.nrows-self.k_): - for j in range(0, self.ncols-self.k_): - # Connect to center node - for m in range(0, self.k_): - for n in range(0, self.k_): - if m+1 < self.k_: - do_connect = True - if (i > 0) and (j == 0): - if m < self.kh: - do_connect = False - - elif j > 0: - # Only the second column half of the window needs updated - if n <= self.kh: - do_connect = False - - if do_connect: - # Vertical connection - self._connect_window(m, n, m+1, n, i, j) - - if n+1 < self.k_: - do_connect = True - if (i > 0) and (j == 0): - if m <= self.kh: - do_connect = False - - elif j > 0: - if n < self.kh: - do_connect = False - - if do_connect: - # Horizontal connection - self._connect_window(m, n, m, n+1, i, j) - - if (j == 0) and (m == 0) and (n == self.kh): - self._connect_window(m, n, self.kh, 0, i, j) - - if (j == column_end) and (m == 0) and (n == self.kh): - self._connect_window(m, n, self.kh, self.k_-1, i, j) - - # Avoid already connected direct neighbors - # o - x - o - # | \ | / | - # x - O - x - # | / | \ | - # o - x - o - if abs(m - self.kh) + abs(n - self.kh) > self.kh: - # Diagonal edges - self._connect_window(m, n, self.kh, self.kh, i, j) - - cdef void _connect_window( - self, - Py_ssize_t isource, - Py_ssize_t jsource, - Py_ssize_t itarg, - Py_ssize_t jtarg, - Py_ssize_t idx, - Py_ssize_t jdx, - bint directed=False - ) nogil: - """ - Args: - isource (int): The source window row index. - jsource (int): The source window column index. - itarg (int): The target window row index. - jtarg (int): The target window column index. - idx (int): The array row index. - jdx (int): The array column index. - max_dist (float): The maximum window distance from the center. - eps (float): An offset value to avoid zero weights. - """ - cdef: - Py_ssize_t b - double w, val_diff - - # COO format: - # [[sources, ...] - # [targets, ...]] - - # Center to link - self.edge_indices_a.push_back(self.grid[idx+isource, jdx+jsource]) - self.edge_indices_b.push_back(self.grid[idx+itarg, jdx+jtarg]) - - if not directed: - # Link to center - self.edge_indices_a.push_back(self.grid[idx+itarg, jdx+jtarg]) - self.edge_indices_b.push_back(self.grid[idx+isource, jdx+jsource]) - - w = 1.0 - (_euclidean_distance(jsource, isource, jtarg, itarg) / self.max_dist) - - w = _scale_min_max(w, 0.0, self.max_scaled, 0.75, 1.0) - w = _clip(w, 0.75, 1.0) - - if npy_isnan(w): - w = self.eps - - val_diff = 0.0 - for b in range(0, self.nbands): - val_diff += self.varray[b, idx+isource, jdx+jsource] - self.varray[b, idx+itarg, jdx+jtarg] - - val_diff /= self.nbands - - val_diff = _clip(fabs(val_diff), 0.0, 1.0) - val_diff = _scale_min_max(val_diff, 0.0, 1.0, self.eps, 1.0) - val_diff = _clip(val_diff, self.eps, 1.0) - - if npy_isnan(val_diff): - val_diff = self.eps - - # Edge attributes - self.edge_attrs_diffs.push_back(val_diff) - self.edge_attrs_dists.push_back(w) - - # x, y coordinates - self.xpos.push_back((jdx+jtarg)*self.cell_size_) - self.ypos.push_back((idx+itarg)*self.cell_size_) - - if not directed: - self.edge_attrs_diffs.push_back(val_diff) - self.edge_attrs_dists.push_back(w) - self.xpos.push_back((jdx+jsource)*self.cell_size_) - self.ypos.push_back((idx+isource)*self.cell_size_) - - -cdef class MultiSensorNetwork(object): - """A class for a multi-sensor network - """ - cdef: - unsigned int ntime, nbands, nrows_, ncols_ - double[:, :, :, ::1] xarray - double[:, :, ::1] yarray - vector[vector[double]] transforms_ - unsigned int n_transforms_ - int64_t[:, ::1] grid_ - vector[int64_t[:, ::1]] grid_c_ - vector[double[:, :, :, ::1]] grid_c_resamp_ - unsigned int k_, kh - double nodata_ - double coarse_window_res_limit_ - double max_edist_hres_ - bint add_coarse_nodes_ - - vector[int] edge_indices_a - vector[int] edge_indices_b - vector[double] edge_attrs - """Creates graph edges and edge attributes - - Args: - xdata (4d array): [time x bands x rows x columns] - ydata (3d array): [band x rows x columns] - nrows (int) - ncols (int) - transforms (list) - direct_to_center (bool): Whether to direct edges connected to the center pixel (i.e., in one direction). - add_coarse_nodes (bool): Whether to add coarse resolution data as separate nodes. - k (int): The local window size. - nodata (float | int) - coarse_window_res_limit (float | int) - """ - - def __init__( - self, - double[:, :, :, ::1] xdata, - double[:, :, ::1] ydata, - vector[vector[double]] transforms, - unsigned int n_transforms, - int64_t[:, ::1] grid, - vector[int64_t[:, ::1]] grid_c, - vector[double[:, :, :, ::1]] grid_c_resamp, - bint direct_to_center=False, - bint add_coarse_nodes=False, - unsigned int k=7, - double nodata=0.0, - double coarse_window_res_limit=30.0, - double max_edist_hres=1.0 - ): - self.xarray = xdata - self.yarray = ydata - self.transforms_ = transforms - - self.n_transforms_ = n_transforms - - self.ntime = self.xarray.shape[0] - self.nbands = self.xarray.shape[1] - self.nrows_ = self.xarray.shape[2] - self.ncols_ = self.xarray.shape[3] - - # 1:1 grid for high-res y and high-res X variables - self.grid_ = grid - self.grid_c_ = grid_c - self.grid_c_resamp_ = grid_c_resamp - - self.add_coarse_nodes_ = add_coarse_nodes - self.k_ = k - self.kh = (self.k_ / 2.0) - self.nodata_ = nodata - self.coarse_window_res_limit_ = coarse_window_res_limit - self.max_edist_hres_ = max_edist_hres - - def create_network(self): - cdef: - Py_ssize_t i, j - int64_t[::1] out_indices = np.zeros(2, dtype='int64') - double[:, :, :, ::1] xarray_ = self.xarray - double[:, :, ::1] yarray_ = self.yarray - int64_t[:, ::1] grid_ = self.grid_ - - with nogil: - # Create node edges and edge weights - for i in range(0, self.nrows_-self.k_): - for j in range(0, self.ncols_-self.k_): - # Local window iteration for direct neighbors - self.create_hr_nodes(i, j, xarray_, yarray_, grid_) - if self.add_coarse_nodes_: - self.create_coarse_undirected_isolated(i, j, self.kh, out_indices) - self.create_coarse_center_edges(i, j, self.kh, out_indices, yarray_, grid_) - - return self.edge_indices_a, self.edge_indices_b, self.edge_attrs - - cdef void _connect_window( - self, - int64_t[:, ::1] grid_, - double[:, :, :, ::1] xarray, - double[:, :, ::1] yarray, - Py_ssize_t targ_i, - Py_ssize_t targ_j, - Py_ssize_t idx, - Py_ssize_t jdx, - Py_ssize_t source_i, - Py_ssize_t source_j, - bint center_weights, - double weight_gain - ) nogil: - """ - Args: - grid_ (2d array): The grid indices. - xarray (4d array) - yarray (3d array) - targ_i (int): The target window row index. - targ_j (int): The target window column index. - idx (int): The array row index. - jdx (int): The array column index. - source_i (int): The source window row index. - source_j (int): The source window column index. - """ - cdef: - double edge_weight - double edist, spdist - double mean_off, mean_center - - # COO format: - # [[sources, ...] - # [targets, ...]] - - # Source -> target - self.edge_indices_a.push_back(grid_[idx+source_i, jdx+source_j]) - self.edge_indices_b.push_back(grid_[idx+targ_i, jdx+targ_j]) - - # Target -> source - self.edge_indices_a.push_back(grid_[idx+targ_i, jdx+targ_j]) - self.edge_indices_b.push_back(grid_[idx+source_i, jdx+source_j]) - - # Both arrays must have data in the neighbors - if (_get_max_4d(xarray, self.ntime, self.nbands, idx+source_i, jdx+source_j) != self.nodata_) and \ - (_get_max_4d(xarray, self.ntime, self.nbands, idx+targ_i, jdx+targ_j) != self.nodata_) and \ - (_get_max_3d(yarray, self.nbands, idx+source_i, jdx+source_j) != self.nodata_) and \ - (_get_max_3d(yarray, self.nbands, idx+targ_i, jdx+targ_j) != self.nodata_): - - if center_weights: - # Inverse euclidean distance - edist = 1.0 - ((_euclidean_distance(self.kh, self.kh, source_i, source_j) * self.transforms_[0][0]) / self.max_edist_hres_) - - # Inverse spectral difference - mean_off = _get_mean_3d(yarray, self.nbands, idx+source_i, jdx+source_j) - mean_center = _get_mean_3d(yarray, self.nbands, idx+targ_i, jdx+targ_j) - - spdist = 1.0 - fabs(mean_off - mean_center) - - # max(edist, spdist) x 10 - edge_weight = _get_max(edist, spdist) * weight_gain - - else: - if (targ_i == self.kh) and (targ_j == self.kh): - edge_weight = 1.0 - else: - edge_weight = 0.5 - - self.edge_attrs.push_back(edge_weight) - self.edge_attrs.push_back(edge_weight) - - else: - self.edge_attrs.push_back(0.0) - self.edge_attrs.push_back(0.0) - - cdef void create_hr_nodes( - self, - Py_ssize_t i, - Py_ssize_t j, - double[:, :, :, ::1] xarray, - double[:, :, ::1] yarray, - int64_t[:, ::1] grid_, - double hr_weight=5.0 - ) nogil: - """Creates high-resolution nodes and edges - """ - cdef: - Py_ssize_t m, n - bint do_connect - - # Connect to center node - for m in range(0, self.k_): - for n in range(0, self.k_): - if m+1 < self.k_: - do_connect = True - if (i > 0) and (j == 0): - if m < self.kh: - do_connect = False - - elif j > 0: - # Only the second column half of the window needs updated - if n <= self.kh: - do_connect = False - - if do_connect: - # Vertical connection - # (grid, targ_i, targ_j, i, j, source_i, source_j) - self._connect_window(grid_, xarray, yarray, m+1, n, i, j, m, n, False, hr_weight) - - if n+1 < self.k_: - do_connect = True - if (i > 0) and (j == 0): - if m <= self.kh: - do_connect = False - - elif j > 0: - if n < self.kh: - do_connect = False - - if do_connect: - # Horizontal connection - self._connect_window(grid_, xarray, yarray, m, n+1, i, j, m, n, False, hr_weight) - - # Avoid already connected direct neighbors - # o - x - o - # | \ | / | - # x - O - x - # | / | \ | - # o - x - o - if abs(m - self.kh) + abs(n - self.kh) <= 1: - continue - - # Diagonal edges - self._connect_window(grid_, xarray, yarray, self.kh, self.kh, i, j, m, n, True, hr_weight) - - cdef void create_coarse_undirected_isolated( - self, - Py_ssize_t i, - Py_ssize_t j, - unsigned int kh, - int64_t[::1] out_indices - ) nogil: - """Creates undirected, isolated (from the center) edges for coarse grids - """ - cdef: - vector[double] hr_transform, cr_transform - Py_ssize_t pidx - int64_t[:, ::1] coarse_grid - double[:, :, :, ::1] coarse_xarray - unsigned int ntime_, nbands_, nr, nc - unsigned int row_off, col_off - unsigned int row_off_nbr, col_off_nbr - - # Static 3x3 window for coarse grids - hr_transform = self.transforms_[0] - - for pidx in range(0, self.n_transforms_-1): - cr_transform = self.transforms_[pidx+1] - # Do not connect extremely coarse grids - if fabs(cr_transform[0]) > self.coarse_window_res_limit_: - continue - - coarse_grid = self.grid_c_[pidx] - coarse_xarray = self.grid_c_resamp_[pidx] - - ntime_ = coarse_xarray.shape[0] - nbands_ = coarse_xarray.shape[1] - nr = coarse_xarray.shape[2] - nc = coarse_xarray.shape[3] - - # Get the row/column indices of the coarse resolution - # that intersect the high-resolution. - _coarse_transformer( - i, - j, - kh, - hr_transform, - cr_transform, - out_indices - ) - - col_off = out_indices[0] - row_off = out_indices[1] - - if row_off > nr - 1: - row_off = nr - 1 - - if col_off > nc - 1: - col_off = nc - 1 - - row_off_nbr = row_off + 1 - col_off_nbr = col_off + 1 - - if col_off < nc - 1: - # Edge 1 - # n1 --> n2 - self.edge_indices_a.push_back(coarse_grid[row_off, col_off]) - self.edge_indices_b.push_back(coarse_grid[row_off, col_off_nbr]) - - # Edge 2 - # n1 <-- n2 - self.edge_indices_a.push_back(coarse_grid[row_off, col_off_nbr]) - self.edge_indices_b.push_back(coarse_grid[row_off, col_off]) - - # Both arrays must have data in the neighbor and at the center - if (_get_max_4d(coarse_xarray, ntime_, nbands_, row_off, col_off) != self.nodata_) and \ - (_get_max_4d(coarse_xarray, ntime_, nbands_, row_off, col_off_nbr) != self.nodata_): - - self.edge_attrs.push_back(0.1) - self.edge_attrs.push_back(0.1) - - else: - self.edge_attrs.push_back(0.0) - self.edge_attrs.push_back(0.0) - - if row_off < nr - 1: - # Edge 1 - # n1 - # ^ - # | - # n2 - self.edge_indices_a.push_back(coarse_grid[row_off, col_off]) - self.edge_indices_b.push_back(coarse_grid[row_off_nbr, col_off]) - - # Edge 2 - # n1 - # | - # v - # n2 - self.edge_indices_a.push_back(coarse_grid[row_off_nbr, col_off]) - self.edge_indices_b.push_back(coarse_grid[row_off, col_off]) - - # Both arrays must have data in the neighbor and at the center - if (_get_max_4d(coarse_xarray, ntime_, nbands_, row_off, col_off) != self.nodata_) and \ - (_get_max_4d(coarse_xarray, ntime_, nbands_, row_off_nbr, col_off) != self.nodata_): - - self.edge_attrs.push_back(0.1) - self.edge_attrs.push_back(0.1) - - else: - self.edge_attrs.push_back(0.0) - self.edge_attrs.push_back(0.0) - - cdef void create_coarse_center_edges( - self, - Py_ssize_t i, - Py_ssize_t j, - unsigned int kh, - int64_t[::1] out_indices, - double[:, :, ::1] yarray, - int64_t[:, ::1] grid_ - ) nogil: - """Creates edges from the coarse resolution to high-resolution center - """ - cdef: - Py_ssize_t ii, jj - unsigned int kh_ - Py_ssize_t row_center, col_center - vector[double] hr_transform, cr_transform - Py_ssize_t pidx - int64_t[:, ::1] coarse_grid - int64_t[:, ::1] prev_grid - double[:, :, :, ::1] coarse_xarray - unsigned int nr, nc - unsigned int row_index, col_index - unsigned int row_off_nbr, col_off_nbr - double edge_weight, weight_step, baseline_weight - - # The first grid edge weights - edge_weight = 1.0 - weight_step = -0.5 - baseline_weight = 0.1 - - for pidx in range(0, self.n_transforms_-1): - # Get the transform vectors - hr_transform = self.transforms_[pidx] - cr_transform = self.transforms_[pidx+1] - - # Get the grid of the previous resolution - if pidx == 0: - prev_grid = grid_ - ii = i - jj = j - kh_ = kh - row_center = i + kh_ - col_center = j + kh_ - - else: - prev_grid = self.grid_c_[pidx-1] - ii = row_index - jj = col_index - kh_ = 0 - row_center = row_index - col_center = col_index - - if row_center >= prev_grid.shape[0] - 1: - row_center = prev_grid.shape[0] - 1 - - if col_center >= prev_grid.shape[1] - 1: - col_center = prev_grid.shape[1] - 1 - - # Get the current coarse(r) resolution grid - coarse_grid = self.grid_c_[pidx] - - # Get the current resampled coarse resolution data - coarse_xarray = self.grid_c_resamp_[pidx] - - ntime_ = coarse_xarray.shape[0] - nbands_ = coarse_xarray.shape[1] - nr = coarse_xarray.shape[2] - nc = coarse_xarray.shape[3] - - # Get the row/column indices of the coarse resolution - # that intersects the high-resolution. - _coarse_transformer( - ii, - jj, - kh_, - hr_transform, - cr_transform, - out_indices - ) - - # Row/column indices for the coarse(r) resolution center pixel - col_index = out_indices[0] - row_index = out_indices[1] - - if row_index > nr - 1: - row_index = nr - 1 - - if col_index > nc - 1: - col_index = nc - 1 - - # Coarse-res edge links to the center_y - self.edge_indices_a.push_back(coarse_grid[row_index, col_index]) - self.edge_indices_b.push_back(prev_grid[row_center, col_center]) - - self.edge_indices_a.push_back(prev_grid[row_center, col_center]) - self.edge_indices_b.push_back(coarse_grid[row_index, col_index]) - - # Both arrays must have data in the neighbor and at the center - if (_get_max_4d(coarse_xarray, ntime_, nbands_, row_index, col_index) != self.nodata_) and \ - (_get_max_3d(yarray, self.nbands, row_center, col_center) != self.nodata_): - - self.edge_attrs.push_back(edge_weight) - self.edge_attrs.push_back(edge_weight) - - else: - self.edge_attrs.push_back(0.0) - self.edge_attrs.push_back(0.0) - - edge_weight += weight_step - - if edge_weight < baseline_weight: - edge_weight = baseline_weight diff --git a/src/cultionet/nn/__init__.py b/src/cultionet/nn/__init__.py new file mode 100644 index 00000000..85fe48b1 --- /dev/null +++ b/src/cultionet/nn/__init__.py @@ -0,0 +1,38 @@ +from .modules.activations import SetActivation +from .modules.attention import NeighborhoodAttention2D, SpatialChannelAttention +from .modules.convolution import ( + ConvBlock2d, + ConvTranspose2d, + PoolResidualConv, + ResidualAConv, + ResidualConv, +) +from .modules.geo_encoding import GeoEmbeddings +from .modules.unet_parts import ( + TowerUNetBlock, + TowerUNetDecoder, + TowerUNetEncoder, + TowerUNetFinal, + TowerUNetFinalCombine, + TowerUNetFusion, + UNetUpBlock, +) + +__all__ = [ + 'ConvBlock2d', + 'ConvTranspose2d', + 'GeoEmbeddings', + 'NeighborhoodAttention2D', + 'PoolResidualConv', + 'ResidualAConv', + 'ResidualConv', + 'SetActivation', + 'SpatialChannelAttention', + 'TowerUNetFinal', + 'TowerUNetFinalCombine', + 'UNetUpBlock', + 'TowerUNetBlock', + 'TowerUNetEncoder', + 'TowerUNetDecoder', + 'TowerUNetFusion', +] diff --git a/src/cultionet/nn/functional.py b/src/cultionet/nn/functional.py new file mode 100644 index 00000000..8874cabe --- /dev/null +++ b/src/cultionet/nn/functional.py @@ -0,0 +1,81 @@ +import cv2 +import einops +import numpy as np +import torch +import torch.nn.functional as F + + +@torch.no_grad +def merge_distances( + foreground_distances: torch.Tensor, + crop_mask: torch.Tensor, + edge_mask: torch.Tensor, + inverse: bool = True, + beta: float = 10.0, +) -> torch.Tensor: + + if len(foreground_distances.shape) == 3: + foreground_distances = einops.rearrange( + foreground_distances, 'b h w -> b 1 h w' + ) + + if len(crop_mask.shape) == 3: + crop_mask = einops.rearrange(crop_mask, 'b h w -> b 1 h w') + + if len(edge_mask.shape) == 3: + edge_mask = einops.rearrange(edge_mask, 'b h w -> b 1 h w') + + background_mask = ( + ((crop_mask == 0) & (edge_mask == 0)).detach().cpu().numpy() + ) + background_dist = np.zeros(background_mask.shape, dtype='float32') + for midx in range(background_mask.shape[0]): + bdist = cv2.distanceTransform( + background_mask[midx].squeeze(axis=0).astype('uint8'), + cv2.DIST_L2, + 3, + ) + bdist /= bdist.max() + + if inverse: + bdist = 1.0 - bdist + + if beta != 1: + bdist = bdist**beta + bdist[np.isnan(bdist)] = 0 + + background_dist[midx] = bdist[None, None] + + if inverse: + foreground_distances = 1.0 - foreground_distances + + if beta != 1: + foreground_distances = foreground_distances**beta + foreground_distances[torch.isnan(foreground_distances)] = 0 + + distance = np.where( + background_mask, + background_dist, + foreground_distances.detach().cpu().numpy(), + ) + targets = torch.tensor( + distance, + dtype=foreground_distances.dtype, + device=foreground_distances.device, + ) + + targets[edge_mask == 1] = 1.0 if inverse else 0.0 + + return targets + + +def check_upsample(x: torch.Tensor, size: torch.Size) -> torch.Tensor: + if x.shape[-2:] != size: + x = F.interpolate( + x, + size=size, + mode="bilinear", + align_corners=True, + ) + + return x diff --git a/src/cultionet/nn/modules/__init__.py b/src/cultionet/nn/modules/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/cultionet/nn/modules/activations.py b/src/cultionet/nn/modules/activations.py new file mode 100644 index 00000000..2f2c3f70 --- /dev/null +++ b/src/cultionet/nn/modules/activations.py @@ -0,0 +1,24 @@ +import torch +import torch.nn as nn + + +class SetActivation(nn.Module): + """ + Examples: + >>> act = SetActivation('ReLU') + >>> act(x) + >>> + >>> act = SetActivation('SiLU') + >>> act(x) + """ + + def __init__(self, activation_type: str): + super().__init__() + + try: + self.activation = getattr(torch.nn, activation_type)(inplace=False) + except TypeError: + self.activation = getattr(torch.nn, activation_type)() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.activation(x) diff --git a/src/cultionet/nn/modules/attention.py b/src/cultionet/nn/modules/attention.py new file mode 100644 index 00000000..9126c0e4 --- /dev/null +++ b/src/cultionet/nn/modules/attention.py @@ -0,0 +1,176 @@ +import typing as T + +import einops +import torch +import torch.nn as nn +import torch.nn.functional as F +from natten.functional import na2d, na2d_av, na2d_qk + +from .activations import SetActivation + + +class ChannelAttention(nn.Module): + def __init__(self, in_channels: int, activation_type: str): + super().__init__() + + # Channel attention + self.channel_adaptive_avg = nn.AdaptiveAvgPool2d(1) + self.channel_adaptive_max = nn.AdaptiveMaxPool2d(1) + self.fc1 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels // 2, + kernel_size=1, + padding=0, + bias=False, + ), + SetActivation(activation_type=activation_type), + nn.Conv2d( + in_channels=in_channels // 2, + out_channels=in_channels, + kernel_size=1, + padding=0, + bias=False, + ), + ) + self.fc2 = nn.Sequential( + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels // 2, + kernel_size=1, + padding=0, + bias=False, + ), + SetActivation(activation_type=activation_type), + nn.Conv2d( + in_channels=in_channels // 2, + out_channels=in_channels, + kernel_size=1, + padding=0, + bias=False, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + + avg_attention = self.fc1(self.channel_adaptive_avg(x)) + max_attention = self.fc2(self.channel_adaptive_max(x)) + attention = avg_attention + max_attention + attention = F.sigmoid(attention) + + return attention.expand(-1, -1, height, width) + + +class SpatialAttention(nn.Module): + def __init__(self): + super().__init__() + + self.conv = nn.Conv2d( + in_channels=2, + out_channels=1, + kernel_size=3, + padding=1, + bias=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + + avg_attention = einops.reduce(x, 'b c h w -> b 1 h w', 'mean') + max_attention = einops.reduce(x, 'b c h w -> b 1 h w', 'max') + attention = torch.cat([avg_attention, max_attention], dim=1) + attention = self.conv(attention) + attention = F.sigmoid(attention) + + return attention.expand(-1, -1, height, width) + + +class SpatialChannelAttention(nn.Module): + """Spatial-Channel Attention Block. + + References: + @inproceedings{woo_etal_2018, + title={Cbam: Convolutional block attention module}, + author={Woo, Sanghyun and Park, Jongchan and Lee, Joon-Young and Kweon, In So}, + booktitle={Proceedings of the European conference on computer vision (ECCV)}, + pages={3--19}, + year={2018}, + url={https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf}, + } + + https://arxiv.org/abs/1807.02758 + https://github.com/yjn870/RCAN-pytorch + https://www.mdpi.com/2072-4292/14/9/2253 + https://github.com/luuuyi/CBAM.PyTorch/blob/master/model/resnet_cbam.py + """ + + def __init__(self, in_channels: int, activation_type: str): + super().__init__() + + self.channel_attention = ChannelAttention( + in_channels=in_channels, + activation_type=activation_type, + ) + self.spatial_attention = SpatialAttention() + self.gamma = nn.Parameter(torch.zeros(1, requires_grad=True)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + channel_attention = self.channel_attention(x) + spatial_attention = self.spatial_attention(x) + attention = (channel_attention + spatial_attention) * 0.5 + attention = 1.0 + self.gamma * attention + + return attention + + +class NeighborhoodAttention2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + dilation: int, + ): + super().__init__() + + self.kernel_size = kernel_size + self.dilation = dilation + + self.query = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + padding=0, + ) + self.key = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + padding=0, + ) + self.value = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + padding=0, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + q = self.query(x) + k = self.key(x) + v = self.value(x) + + q = einops.rearrange(q, 'b c h w -> b h w 1 c') + k = einops.rearrange(k, 'b c h w -> b h w 1 c') + v = einops.rearrange(v, 'b c h w -> b h w 1 c') + + output = na2d( + q, k, v, kernel_size=self.kernel_size, dilation=self.dilation + ) + + output = einops.rearrange(output, 'b h w 1 c -> b c h w') + + return output diff --git a/src/cultionet/nn/modules/convolution.py b/src/cultionet/nn/modules/convolution.py new file mode 100644 index 00000000..d48087b6 --- /dev/null +++ b/src/cultionet/nn/modules/convolution.py @@ -0,0 +1,513 @@ +import logging +import typing as T + +import natten +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops.layers.torch import Rearrange + +from cultionet.enums import AttentionTypes, ResBlockTypes + +from ..functional import check_upsample +from .activations import SetActivation +from .attention import SpatialChannelAttention + +# logging.getLogger("natten").setLevel(logging.ERROR) +# natten.use_fused_na(True) +# natten.use_kv_parallelism_in_fused_na(True) + + +class DepthwiseSeparableConv(nn.Module): + def __init__(self, in_channels: int, out_channels: int, kernel_size: int): + super().__init__() + + self.separable = nn.Sequential( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=in_channels, + ), + nn.Conv2d( + out_channels, + out_channels, + kernel_size=1, + padding=0, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.separable(x) + + +class ConvTranspose2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + stride: int = 2, + padding: int = 1, + ): + super().__init__() + + self.up_conv = nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor: + return check_upsample( + self.up_conv(x), + size=size, + ) + + +class ConvBlock2d(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + dilation: int = 1, + stride: int = 1, + add_activation: bool = True, + activation_type: str = "SiLU", + batchnorm_first: bool = False, + ): + super().__init__() + + layers = [] + + if batchnorm_first: + layers += [ + nn.BatchNorm2d(in_channels), + SetActivation(activation_type), + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + stride=stride, + ), + ] + else: + layers += [ + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + stride=stride, + bias=False, + ), + nn.BatchNorm2d(out_channels), + ] + if add_activation: + layers += [SetActivation(activation_type)] + + self.seq = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.seq(x) + + +class ResConvBlock2d(nn.Module): + """Convolution layer designed for a residual activation.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + dilation: int = 1, + activation_type: str = "SiLU", + num_blocks: int = 2, + batchnorm_first: bool = False, + ): + super().__init__() + + assert num_blocks > 0, "There must be at least one block." + + conv_layers = [] + + conv_layers.append( + ConvBlock2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=0 if kernel_size == 1 else kernel_size // 2, + dilation=1, + activation_type=activation_type, + add_activation=True, + batchnorm_first=batchnorm_first, + ) + ) + + for _ in range(num_blocks - 1): + conv_layers.append( + ConvBlock2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=0 if kernel_size == 1 else max(1, dilation - 1), + dilation=1 if kernel_size == 1 else max(1, dilation - 1), + activation_type=activation_type, + add_activation=True, + batchnorm_first=batchnorm_first, + ) + ) + + self.block = nn.ModuleList(conv_layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + + for layer in self.block: + x = layer(x) + + return x + + +class ResidualConv(nn.Module): + """A residual convolution layer with (optional) attention.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_blocks: int = 2, + attention_weights: T.Optional[str] = None, + activation_type: str = "SiLU", + batchnorm_first: bool = False, + ): + super().__init__() + + self.attention_weights = attention_weights + + if self.attention_weights is not None: + assert self.attention_weights in [ + AttentionTypes.SPATIAL_CHANNEL, + ], "The attention method is not supported." + + self.gamma = nn.Parameter(torch.ones(1, requires_grad=True)) + + self.attention_conv = SpatialChannelAttention( + out_channels=out_channels, activation_type=activation_type + ) + + self.seq = ResConvBlock2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_blocks=num_blocks, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + ) + + self.skip = None + if in_channels != out_channels: + self.skip = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + padding=0, + ) + + if self.attention_weights is not None: + self.final_act = SetActivation(activation_type=activation_type) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.skip is not None: + # Align channels + out = self.skip(x) + else: + out = x + + out = out + self.seq(x) + + if self.attention_weights is not None: + # Get weights from the residual + attention = self.attention_conv(out) + + # 1 + γA + attention = 1.0 + self.gamma * attention + out = out * attention + + out = self.final_act(out) + + return out + + +class ResidualAConv(nn.Module): + r"""Residual convolution with atrous/dilated convolutions. + + Residual convolutions: + + CSIRO BSTD/MIT LICENSE + + Redistribution and use in source and binary forms, with or without modification, are permitted provided that + the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the + following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or other materials provided with the distribution. + 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or + promote products derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, + INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + Citation: + @article{diakogiannis_etal_2020, + title={ResUNet-a: A deep learning framework for semantic segmentation of remotely sensed data}, + author={Diakogiannis, Foivos I and Waldner, Fran{\c{c}}ois and Caccetta, Peter and Wu, Chen}, + journal={ISPRS Journal of Photogrammetry and Remote Sensing}, + volume={162}, + pages={94--114}, + year={2020}, + publisher={Elsevier} + } + + References: + https://www.sciencedirect.com/science/article/abs/pii/S0924271620300149 + https://arxiv.org/abs/1904.00592 + https://arxiv.org/pdf/1904.00592.pdf + + Attention with NATTEN: + MIT License + Copyright (c) 2022 - 2024 Ali Hassani. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_blocks: int = 2, + dilations: T.Optional[T.List[int]] = None, + attention_weights: T.Optional[str] = None, + activation_type: str = "SiLU", + batchnorm_first: bool = False, + natten_num_heads: int = 8, + natten_kernel_size: int = 3, + natten_dilation: int = 1, + natten_attn_drop: float = 0.0, + natten_proj_drop: float = 0.0, + ): + super().__init__() + + if dilations is None: + dilations = [1, 2] + + self.attention_weights = attention_weights + + if in_channels != out_channels: + self.skip = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + padding=0, + ) + else: + self.skip = nn.Identity() + + if self.attention_weights is not None: + + assert self.attention_weights in [ + AttentionTypes.NATTEN, + AttentionTypes.SPATIAL_CHANNEL, + ], "The attention method is not supported." + + if self.attention_weights == AttentionTypes.NATTEN: + + self.attention_conv = nn.Sequential( + Rearrange('b c h w -> b h w c'), + nn.LayerNorm(out_channels), + natten.NeighborhoodAttention2D( + dim=out_channels, + num_heads=natten_num_heads, + kernel_size=natten_kernel_size, + dilation=natten_dilation, + rel_pos_bias=False, + qkv_bias=True, + attn_drop=natten_attn_drop, + proj_drop=natten_proj_drop, + ), + nn.LayerNorm(out_channels), + Rearrange('b h w c -> b c h w'), + ) + + else: + + self.attention_conv = SpatialChannelAttention( + in_channels=out_channels, + activation_type=activation_type, + ) + + self.res_modules = nn.ModuleList( + [ + ResConvBlock2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + dilation=dilation, + activation_type=activation_type, + num_blocks=num_blocks, + batchnorm_first=batchnorm_first, + ) + for dilation in dilations + ] + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.skip(x) + + if self.attention_weights is not None: + skip = out + + # Resunet-a block takes the same input and + # sums multiple outputs with varying dilations. + for layer in self.res_modules: + out = out + layer(x) + + if self.attention_weights is not None: + attention_out = self.attention_conv(skip) + if self.attention_weights == AttentionTypes.NATTEN: + out = out + attention_out + else: + out = out * attention_out + + return out + + +class PoolResidualConv(nn.Module): + """Residual convolution with down-sampling. + + Default: + 1) Convolution block + 2) Down-sampling by adaptive max pooling + + If pool_first=True: + 1) Down-sampling by adaptive max pooling + 2) Convolution block + If dropout > 0 + 3) Dropout + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + kernel_size: int = 3, + num_blocks: int = 2, + attention_weights: T.Optional[str] = None, + activation_type: str = "SiLU", + res_block_type: str = ResBlockTypes.RESA, + dilations: T.Sequence[int] = None, + pool_first: bool = True, + pool_by_max: bool = False, + batchnorm_first: bool = False, + natten_num_heads: int = 8, + natten_kernel_size: int = 3, + natten_dilation: int = 1, + natten_attn_drop: float = 0.0, + natten_proj_drop: float = 0.0, + ): + super().__init__() + + assert res_block_type in ( + ResBlockTypes.RES, + ResBlockTypes.RESA, + ) + + self.pool_first = pool_first + self.pool_by_max = pool_by_max + if self.pool_first: + if not self.pool_by_max: + if batchnorm_first: + self.pool_conv = nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=2, + ) + else: + self.pool_conv = ConvBlock2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + stride=2, + add_activation=False, + batchnorm_first=False, + ) + + in_channels = out_channels + + if res_block_type == ResBlockTypes.RES: + + self.res_conv = ResidualConv( + in_channels, + out_channels, + kernel_size=kernel_size, + attention_weights=attention_weights, + num_blocks=num_blocks, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + ) + + else: + + self.res_conv = ResidualAConv( + in_channels, + out_channels, + kernel_size=kernel_size, + dilations=dilations, + num_blocks=num_blocks, + attention_weights=attention_weights, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + natten_num_heads=natten_num_heads, + natten_kernel_size=natten_kernel_size, + natten_dilation=natten_dilation, + natten_attn_drop=natten_attn_drop, + natten_proj_drop=natten_proj_drop, + ) + + self.dropout_layer = nn.Dropout2d(p=dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + _, _, height, width = x.shape + + if self.pool_first: + if self.pool_by_max: + x = F.adaptive_max_pool2d( + x, output_size=(height // 2, width // 2) + ) + else: + x = self.pool_conv(x) + + # Residual convolution + x = self.res_conv(x) + + # Dropout + x = self.dropout_layer(x) + + return x diff --git a/src/cultionet/nn/modules/geo_encoding.py b/src/cultionet/nn/modules/geo_encoding.py new file mode 100644 index 00000000..27e8d939 --- /dev/null +++ b/src/cultionet/nn/modules/geo_encoding.py @@ -0,0 +1,26 @@ +import torch +from torch import nn + + +class GeoEmbeddings(nn.Module): + def __init__(self, channels: int): + super().__init__() + + self.coord_embedding = nn.Linear(3, channels) + + @torch.no_grad + def decimal_degrees_to_cartesian( + self, degrees: torch.Tensor + ) -> torch.Tensor: + radians = torch.deg2rad(degrees) + cosine = torch.cos(radians) + sine = torch.sin(radians) + x = cosine[:, 1] * cosine[:, 0] + y = cosine[:, 1] * sine[:, 0] + + return torch.stack([x, y, sine[:, 1]], dim=-1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + cartesian_coords = self.decimal_degrees_to_cartesian(x) + + return self.coord_embedding(cartesian_coords) diff --git a/src/cultionet/nn/modules/unet_parts.py b/src/cultionet/nn/modules/unet_parts.py new file mode 100644 index 00000000..7c98d689 --- /dev/null +++ b/src/cultionet/nn/modules/unet_parts.py @@ -0,0 +1,760 @@ +import typing as T + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from cultionet.enums import AttentionTypes, InferenceNames, ResBlockTypes + +from .convolution import ( + ConvBlock2d, + ConvTranspose2d, + PoolResidualConv, + ResidualAConv, + ResidualConv, +) +from .geo_encoding import GeoEmbeddings + +NATTEN_PARAMS = { + "a": { + "natten_num_heads": 4, + "natten_kernel_size": 3, + "natten_dilation": 2, + }, + "b": { + "natten_num_heads": 4, + "natten_kernel_size": 3, + "natten_dilation": 1, + }, + "c": { + "natten_num_heads": 8, + "natten_kernel_size": 3, + "natten_dilation": 1, + }, + "d": { + "natten_num_heads": 8, + "natten_kernel_size": 1, + "natten_dilation": 1, + }, +} + + +class SigmoidCrisp(nn.Module): + r"""Sigmoid crisp. + + Adapted from publication and source code below: + + CSIRO BSTD/MIT LICENSE + + Redistribution and use in source and binary forms, with or without modification, are permitted provided that + the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the + following disclaimer. + 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and + the following disclaimer in the documentation and/or other materials provided with the distribution. + 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or + promote products derived from this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, + INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, + SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, + WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE + USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + Citation: + @article{diakogiannis_etal_2021, + title={Looking for change? Roll the dice and demand attention}, + author={Diakogiannis, Foivos I and Waldner, Fran{\c{c}}ois and Caccetta, Peter}, + journal={Remote Sensing}, + volume={13}, + number={18}, + pages={3707}, + year={2021}, + publisher={MDPI} + } + + Reference: + https://www.mdpi.com/2072-4292/13/18/3707 + https://arxiv.org/pdf/2009.02062.pdf + https://github.com/waldnerf/decode/blob/main/FracTAL_ResUNet/nn/activations/sigmoid_crisp.py + """ + + def __init__(self, smooth: float = 1e-2): + super().__init__() + + self.smooth = smooth + self.gamma = nn.Parameter(torch.ones(1, requires_grad=True)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.smooth + F.sigmoid(self.gamma) + out = torch.reciprocal(out) + out = x * out + out = F.sigmoid(out) + + return out + + +class TowerUNetFinalCombine(nn.Module): + """Final output by fusing all tower outputs.""" + + def __init__( + self, + num_classes: int, + edge_activation: bool = True, + mask_activation: bool = True, + ): + super().__init__() + + edge_activation_layer = ( + SigmoidCrisp() if edge_activation else nn.Identity() + ) + mask_activation_layer = ( + nn.Sigmoid() if mask_activation else nn.Identity() + ) + + self.final_dist = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, padding=0), + nn.Sigmoid(), + ) + self.dist_gamma1 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.dist_gamma2 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.dist_gamma3 = nn.Parameter(torch.ones(1, requires_grad=True)) + + self.final_edge = nn.Sequential( + nn.Conv2d(in_channels=1, out_channels=1, kernel_size=1, padding=0), + edge_activation_layer, + ) + self.edge_gamma1 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.edge_gamma2 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.edge_gamma3 = nn.Parameter(torch.ones(1, requires_grad=True)) + + self.final_crop = nn.Sequential( + nn.Conv2d( + in_channels=num_classes, + out_channels=num_classes, + kernel_size=1, + padding=0, + ), + mask_activation_layer, + ) + self.crop_gamma1 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.crop_gamma2 = nn.Parameter(torch.ones(1, requires_grad=True)) + self.crop_gamma3 = nn.Parameter(torch.ones(1, requires_grad=True)) + + def forward( + self, + out_a: T.Dict[str, torch.Tensor], + out_b: T.Dict[str, torch.Tensor], + out_c: T.Dict[str, torch.Tensor], + suffixes: T.Sequence[str], + ) -> T.Dict[str, torch.Tensor]: + + distance = self.final_dist( + ( + torch.reciprocal(self.dist_gamma1) + * out_a[f"{InferenceNames.DISTANCE}{suffixes[0]}"] + + torch.reciprocal(self.dist_gamma2) + * out_b[f"{InferenceNames.DISTANCE}{suffixes[1]}"] + + torch.reciprocal(self.dist_gamma3) + * out_c[f"{InferenceNames.DISTANCE}{suffixes[2]}"] + ) + ) + + edge = self.final_edge( + ( + torch.reciprocal(self.edge_gamma1) + * out_a[f"{InferenceNames.EDGE}{suffixes[0]}"] + + torch.reciprocal(self.edge_gamma2) + * out_b[f"{InferenceNames.EDGE}{suffixes[1]}"] + + torch.reciprocal(self.edge_gamma3) + * out_c[f"{InferenceNames.EDGE}{suffixes[2]}"] + ) + ) + + crop = self.final_crop( + ( + torch.reciprocal(self.crop_gamma1) + * out_a[f"{InferenceNames.CROP}{suffixes[0]}"] + + torch.reciprocal(self.crop_gamma2) + * out_b[f"{InferenceNames.CROP}{suffixes[1]}"] + + torch.reciprocal(self.crop_gamma3) + * out_c[f"{InferenceNames.CROP}{suffixes[2]}"] + ) + ) + + return { + InferenceNames.DISTANCE: distance, + InferenceNames.EDGE: edge, + InferenceNames.CROP: crop, + } + + +class StreamConv2d(nn.Module): + def __init__( + self, + in_channels: int, + hidden_channels: int, + out_channels: int, + activation_type: str, + ): + super().__init__() + + self.conv = nn.Sequential( + ConvBlock2d( + in_channels=in_channels, + out_channels=hidden_channels, + kernel_size=3, + padding=1, + add_activation=True, + activation_type=activation_type, + ), + nn.Conv2d( + in_channels=hidden_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class TowerUNetFinal(nn.Module): + """Output of an individual tower fusion.""" + + def __init__( + self, + in_channels: int, + num_classes: int, + activation_type: str = "SiLU", + resample_factor: int = 0, + ): + super().__init__() + + self.in_channels = in_channels + self.num_classes = num_classes + + if resample_factor > 1: + self.up_conv = ConvTranspose2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=3, + stride=resample_factor, + padding=1, + ) + + # Hidden -> 3 -> 3 -> 1 + self.dist_conv = StreamConv2d( + in_channels=in_channels, + hidden_channels=3, + out_channels=1, + activation_type=activation_type, + ) + self.edge_conv = StreamConv2d( + in_channels=in_channels, + hidden_channels=3, + out_channels=1, + activation_type=activation_type, + ) + self.crop_conv = StreamConv2d( + in_channels=in_channels, + hidden_channels=3, + out_channels=1, + activation_type=activation_type, + ) + + # 3 -> 3 + self.fuse_conv = ConvBlock2d( + in_channels=3, + out_channels=3, + kernel_size=3, + padding=1, + add_activation=True, + activation_type=activation_type, + ) + + def forward( + self, + x: torch.Tensor, + size: T.Optional[torch.Size] = None, + suffix: str = "", + ) -> T.Dict[str, torch.Tensor]: + if size is not None: + x = self.up_conv(x, size=size) + + # Separate hidden into task streams + # H -> 3 -> 1 + dist_h = self.dist_conv(x) + edge_h = self.edge_conv(x) + crop_h = self.crop_conv(x) + + # [1, 1, 1] -> 3 + h = torch.cat([dist_h, edge_h, crop_h], dim=1) + # 3 -> 3 + h = self.fuse_conv(h) + # -> [1, 1, 1] + dist_out, edge_out, mask_out = torch.chunk(h, 3, dim=1) + + # x --> H(3) --> H(1) --> Concat(3) --> Fuse(3) --> Chunk(1,1,1) + + return { + f"{InferenceNames.DISTANCE}{suffix}": dist_out, + f"{InferenceNames.EDGE}{suffix}": edge_out, + f"{InferenceNames.CROP}{suffix}": mask_out, + } + + +class UNetUpBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + num_blocks: int = 2, + attention_weights: T.Optional[str] = None, + activation_type: str = "SiLU", + res_block_type: str = ResBlockTypes.RESA, + dilations: T.Sequence[int] = None, + batchnorm_first: bool = False, + resample_up: bool = True, + natten_num_heads: int = 8, + natten_kernel_size: int = 3, + natten_dilation: int = 1, + natten_attn_drop: float = 0.0, + natten_proj_drop: float = 0.0, + ): + super().__init__() + + assert res_block_type in ( + ResBlockTypes.RES, + ResBlockTypes.RESA, + ) + + if resample_up: + self.up_conv = ConvTranspose2d(in_channels, in_channels) + + if res_block_type == ResBlockTypes.RES: + + self.res_conv = ResidualConv( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_blocks=num_blocks, + attention_weights=attention_weights, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + ) + + else: + + self.res_conv = ResidualAConv( + in_channels, + out_channels, + kernel_size=kernel_size, + dilations=dilations, + attention_weights=attention_weights, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + natten_num_heads=natten_num_heads, + natten_kernel_size=natten_kernel_size, + natten_dilation=natten_dilation, + natten_attn_drop=natten_attn_drop, + natten_proj_drop=natten_proj_drop, + ) + + def forward(self, x: torch.Tensor, size: torch.Size) -> torch.Tensor: + if x.shape[-2:] != size: + x = self.up_conv(x, size=size) + + return self.res_conv(x) + + +class TowerUNetEncoder(nn.Module): + def __init__( + self, + channels: T.Sequence[int], + dilations: T.Sequence[int] = None, + activation_type: str = "SiLU", + dropout: float = 0.0, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + pool_by_max: bool = False, + batchnorm_first: bool = False, + ): + super().__init__() + + # Backbone layers + backbone_kwargs = dict( + dropout=dropout, + activation_type=activation_type, + res_block_type=res_block_type, + batchnorm_first=batchnorm_first, + pool_by_max=pool_by_max, + natten_attn_drop=dropout, + natten_proj_drop=dropout, + ) + self.down_a = PoolResidualConv( + in_channels=channels[0], + out_channels=channels[0], + dilations=dilations, + pool_first=False, + # Attention applied at 1/1 spatial resolution + attention_weights=attention_weights, + **{**backbone_kwargs, **NATTEN_PARAMS["a"]}, + ) + self.down_b = PoolResidualConv( + in_channels=channels[0], + out_channels=channels[1], + dilations=dilations[:3], + # Attention applied at 1/2 spatial resolution + attention_weights=attention_weights, + **{**backbone_kwargs, **NATTEN_PARAMS["b"]}, + ) + self.down_c = PoolResidualConv( + channels[1], + channels[2], + dilations=dilations[:2], + # Attention applied at 1/4 spatial resolution + attention_weights=attention_weights, + **{**backbone_kwargs, **NATTEN_PARAMS["c"]}, + ) + self.down_d = PoolResidualConv( + channels[2], + channels[3], + kernel_size=1, + num_blocks=1, + dilations=[1], + # Attention applied at 1/8 spatial resolution + attention_weights=None, + **backbone_kwargs, + ) + + def forward(self, x: torch.Tensor) -> T.Dict[str, torch.Tensor]: + # Backbone + x_a = self.down_a(x) # 1/1 of input + x_b = self.down_b(x_a) # 1/2 of input + x_c = self.down_c(x_b) # 1/4 of input + x_d = self.down_d(x_c) # 1/8 of input + + return { + "x_a": x_a, + "x_b": x_b, + "x_c": x_c, + "x_d": x_d, + } + + +class TowerUNetDecoder(nn.Module): + def __init__( + self, + channels: T.Sequence[int], + up_channels: int, + dilations: T.Sequence[int] = None, + activation_type: str = "SiLU", + dropout: float = 0.0, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + batchnorm_first: bool = False, + ): + super().__init__() + + # Up layers + up_kwargs = dict( + activation_type=activation_type, + res_block_type=res_block_type, + batchnorm_first=batchnorm_first, + natten_attn_drop=dropout, + natten_proj_drop=dropout, + ) + self.over_d = UNetUpBlock( + in_channels=channels[3], + out_channels=up_channels, + kernel_size=1, + num_blocks=1, + dilations=[1], + resample_up=False, + # Attention applied at 1/8 spatial resolution + attention_weights=None, + **up_kwargs, + ) + self.up_cu = UNetUpBlock( + in_channels=up_channels, + out_channels=up_channels, + dilations=dilations[:2], + # Attention applied at 1/4 spatial resolution + attention_weights=attention_weights, + **{**up_kwargs, **NATTEN_PARAMS["c"]}, + ) + self.up_bu = UNetUpBlock( + in_channels=up_channels, + out_channels=up_channels, + dilations=dilations[:3], + # Attention applied at 1/2 spatial resolution + attention_weights=attention_weights, + **{**up_kwargs, **NATTEN_PARAMS["b"]}, + ) + self.up_au = UNetUpBlock( + in_channels=up_channels, + out_channels=up_channels, + dilations=dilations, + # Attention applied at 1/1 spatial resolution + attention_weights=attention_weights, + **{**up_kwargs, **NATTEN_PARAMS["a"]}, + ) + + def forward( + self, x: T.Dict[str, torch.Tensor] + ) -> T.Dict[str, torch.Tensor]: + x_du = self.over_d(x["x_d"], size=x["x_d"].shape[-2:]) + + # Up + x_cu = self.up_cu(x_du, size=x["x_c"].shape[-2:]) + x_bu = self.up_bu(x_cu, size=x["x_b"].shape[-2:]) + x_au = self.up_au(x_bu, size=x["x_a"].shape[-2:]) + + return { + "x_au": x_au, + "x_bu": x_bu, + "x_cu": x_cu, + "x_du": x_du, + } + + +class TowerUNetFusion(nn.Module): + def __init__( + self, + channels: T.Sequence[int], + up_channels: int, + dilations: T.Sequence[int] = None, + activation_type: str = "SiLU", + dropout: float = 0.0, + res_block_type: str = ResBlockTypes.RESA, + attention_weights: str = AttentionTypes.NATTEN, + batchnorm_first: bool = False, + use_latlon: bool = False, + ): + super().__init__() + + # Towers + tower_kwargs = dict( + up_channels=up_channels, + out_channels=up_channels, + activation_type=activation_type, + res_block_type=res_block_type, + batchnorm_first=batchnorm_first, + attention_weights=attention_weights, + natten_attn_drop=dropout, + natten_proj_drop=dropout, + use_latlon=use_latlon, + ) + self.tower_c = TowerUNetBlock( + backbone_side_channels=channels[2], + backbone_down_channels=channels[3], + dilations=dilations[:2], + **{**tower_kwargs, **NATTEN_PARAMS["c"]}, + ) + self.tower_b = TowerUNetBlock( + backbone_side_channels=channels[1], + backbone_down_channels=channels[2], + tower=True, + dilations=dilations, + **{**tower_kwargs, **NATTEN_PARAMS["b"]}, + ) + self.tower_a = TowerUNetBlock( + backbone_side_channels=channels[0], + backbone_down_channels=channels[1], + tower=True, + dilations=dilations, + **{**tower_kwargs, **NATTEN_PARAMS["a"]}, + ) + + def forward( + self, + encoded: T.Dict[str, torch.Tensor], + decoded: T.Dict[str, torch.Tensor], + latlon_coords: T.Optional[torch.Tensor] = None, + ) -> T.Dict[str, torch.Tensor]: + + # Central towers + x_tower_c = self.tower_c( + backbone_side=encoded["x_c"], + backbone_down=encoded["x_d"], + decode_side=decoded["x_cu"], + decode_down=decoded["x_du"], + latlon_coords=latlon_coords, + ) + x_tower_b = self.tower_b( + backbone_side=encoded["x_b"], + backbone_down=encoded["x_c"], + decode_side=decoded["x_bu"], + decode_down=decoded["x_cu"], + tower_down=x_tower_c, + latlon_coords=latlon_coords, + ) + x_tower_a = self.tower_a( + backbone_side=encoded["x_a"], + backbone_down=encoded["x_b"], + decode_side=decoded["x_au"], + decode_down=decoded["x_bu"], + tower_down=x_tower_b, + latlon_coords=latlon_coords, + ) + + return { + "x_tower_a": x_tower_a, + "x_tower_b": x_tower_b, + "x_tower_c": x_tower_c, + } + + +class TowerUNetBlock(nn.Module): + def __init__( + self, + backbone_side_channels: int, + backbone_down_channels: int, + up_channels: int, + out_channels: int, + tower: bool = False, + kernel_size: int = 3, + num_blocks: int = 2, + attention_weights: T.Optional[str] = None, + res_block_type: str = ResBlockTypes.RESA, + dilations: T.Sequence[int] = None, + activation_type: str = "SiLU", + batchnorm_first: bool = False, + natten_num_heads: int = 8, + natten_kernel_size: int = 3, + natten_dilation: int = 1, + natten_attn_drop: float = 0.0, + natten_proj_drop: float = 0.0, + use_latlon: bool = False, + ): + super().__init__() + + self.use_latlon = use_latlon + + assert res_block_type in ( + ResBlockTypes.RES, + ResBlockTypes.RESA, + ) + + in_channels = ( + backbone_side_channels + backbone_down_channels + up_channels * 2 + ) + + self.backbone_down_conv = ConvTranspose2d( + in_channels=backbone_down_channels, + out_channels=backbone_down_channels, + kernel_size=3, + stride=2, + padding=1, + ) + + self.decode_down_conv = ConvTranspose2d( + in_channels=up_channels, + out_channels=up_channels, + kernel_size=3, + stride=2, + padding=1, + ) + + if tower: + self.tower_conv = ConvTranspose2d( + in_channels=up_channels, + out_channels=up_channels, + kernel_size=3, + stride=2, + padding=1, + ) + in_channels += up_channels + + if self.use_latlon: + # TODO: make optional + self.geo_embeddings = torch.compile(GeoEmbeddings(up_channels)) + # self.geo_embeddings4 = SphericalHarmonics(out_channels=in_channels, legendre_polys=4) + # self.geo_embeddings8 = SphericalHarmonics(out_channels=in_channels, legendre_polys=8) + + in_channels += up_channels + + if res_block_type == ResBlockTypes.RES: + + self.res_conv = ResidualConv( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + num_blocks=num_blocks, + attention_weights=attention_weights, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + ) + + else: + + self.res_conv = ResidualAConv( + in_channels, + out_channels, + kernel_size=kernel_size, + num_blocks=num_blocks, + dilations=dilations, + attention_weights=attention_weights, + activation_type=activation_type, + batchnorm_first=batchnorm_first, + natten_num_heads=natten_num_heads, + natten_kernel_size=natten_kernel_size, + natten_dilation=natten_dilation, + natten_attn_drop=natten_attn_drop, + natten_proj_drop=natten_proj_drop, + ) + + def forward( + self, + backbone_side: torch.Tensor, + backbone_down: torch.Tensor, + decode_side: torch.Tensor, + decode_down: torch.Tensor, + tower_down: T.Optional[torch.Tensor] = None, + latlon_coords: T.Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + backbone_down = self.backbone_down_conv( + backbone_down, + size=decode_side.shape[-2:], + ) + decode_down = self.decode_down_conv( + decode_down, + size=decode_side.shape[-2:], + ) + + x = torch.cat( + (backbone_side, backbone_down, decode_side, decode_down), + dim=1, + ) + + # Embed coordinates + if self.use_latlon: + assert latlon_coords is not None, "No lat/lon coordinates given." + + latlon_coords = rearrange( + self.geo_embeddings(latlon_coords.to(dtype=x.dtype)), + 'b c -> b c 1 1', + ) + + _, _, height, width = x.shape + x = torch.cat( + (x, latlon_coords.expand(-1, -1, height, width)), dim=1 + ) + + if tower_down is not None: + tower_down = self.tower_conv( + tower_down, + size=decode_side.shape[-2:], + ) + + x = torch.cat((x, tower_down), dim=1) + + return self.res_conv(x) diff --git a/src/cultionet/scripts/args.yml b/src/cultionet/scripts/args.yml index c4e6029c..e256b343 100644 --- a/src/cultionet/scripts/args.yml +++ b/src/cultionet/scripts/args.yml @@ -1,35 +1,17 @@ -epilog: | - ######## - Examples - ######## - - # Create training data - cultionet create --project-path /projects/data - - # Spatial k-fold cross-validation using quadrant partitions on the dataset - cultionet skfoldcv -p . --splits 2 --val-frac 0.5 --processes 8 --epochs 1 --batch-size 4 --precision 16 - - # Train a model - cultionet train --project-path /projects/data - - # Apply inference over an image - cultionet predict --project-path /projects/data -o estimates.tif - dates: start_date: short: sd long: start-date - help: The predict start date (mm-dd) + help: The start date (mm-dd or yyyy-mm-dd for predictions) kwargs: default: 01-01 end_date: short: ed long: end-date - help: The predict start date (mm-dd) + help: The end date (mm-dd or yyyy-mm-dd for predictions) kwargs: default: 01-01 - shared_image: gain: short: '' @@ -85,7 +67,7 @@ shared_create: long: res help: The cell resolution kwargs: - default: !!null + default: 10.0 type: '&float' feature_pattern: short: '' @@ -106,44 +88,15 @@ shared_create: help: Whether to create a dataset for instance segmentation kwargs: action: store_true + add_year: + short: '' + long: add-year + help: The number of years to add to the year column to derive the end year + kwargs: + default: 0 + type: '&int' create: - transforms: - short: tr - long: transforms - help: Augmentation transforms to apply - kwargs: - default: - - none - - fliplr - - flipud - - rot90 - - rot180 - - rot270 - - tswarp - - tsnoise - - tsdrift - - tspeaks - - gaussian - - saltpepper - - speckle - choices: - - none - - fliplr - - flipud - - flipfb - - rot90 - - rot180 - - rot270 - - tswarp - - tsnoise - - tsdrift - - tspeaks - - roll - - gaussian - - saltpepper - - speckle - nargs: '+' grid_size: short: gs long: grid-size @@ -160,7 +113,7 @@ create: help: The data destination kwargs: default: train - choices: ['train', 'test'] + choices: ['train', 'test', 'predict'] crop_column: short: '' long: crop-column @@ -182,15 +135,25 @@ create: replace_dict: short: '' long: replace-dict - help: A dictionary of crop class remappings + help: Crop class recodings (e.g., "61:0 141:0") + bbox_offsets: + short: '' + long: bbox_offsets + help: Additional grid bounding box offsets (e.g., 0,0 1000,0) + nonag_is_unknown: + short: '' + long: nonag-is-unknown + help: Whether the non-agricultural background is unknown + kwargs: + action: store_true + all_touched: + short: '' + long: all-touched + help: Whether to 'burn in' all pixels touched by geometries or only pixels whose center is within the polygon + kwargs: + action: store_true create_predict: - predict_year: - short: 'y' - long: year - help: The predict end year (yyyy) - kwargs: - type: '&int' window_size: short: w long: window-size @@ -209,13 +172,6 @@ create_predict: short: '' long: ts-path help: A path with time series data (overrides the config regions) - chunksize: - short: '' - long: chunksize - help: The window chunksize for processing batches - kwargs: - default: 100 - type: '&int' train_predict: model_type: @@ -223,10 +179,9 @@ train_predict: long: model-type help: The model type kwargs: - default: 'ResUNet3Psi' + default: 'TowerUNet' choices: - - 'UNet3Psi' - - 'ResUNet3Psi' + - 'TowerUNet' activation_type: short: '' long: activation-type @@ -234,18 +189,25 @@ train_predict: kwargs: default: 'SiLU' res_block_type: - short: '' + short: rb long: res-block-type - help: The residual block type (only relevant when --model-type=ResUNet3Psi) + help: The residual block type) kwargs: - default: 'res' + default: 'resa' choices: ['res', 'resa'] + dropout: + short: '' + long: dropout + help: The dropout probability + kwargs: + default: 0.2 + type: '&float' dilations: short: '' long: dilations help: The dilations to use kwargs: - default: [2] + default: [1, 2] nargs: '+' type: '&int' attention_weights: @@ -253,14 +215,14 @@ train_predict: long: attention-weights help: The attention weights kwargs: - default: 'spatial_channel' - choices: ['spatial_channel', 'fractal', 'none'] - filters: + default: 'natten' + choices: ['natten', 'spatial_channel'] + hidden_channels: short: '' - long: filters - help: The number of model input filters + long: hidden-channels + help: The number of input hidden channels kwargs: - default: 32 + default: 64 type: '&int' device: short: '' @@ -288,72 +250,48 @@ train_predict: long: batch-size help: The batch size kwargs: - default: 8 + default: 4 type: '&int' load_batch_workers: short: '' long: load-batch-workers help: The number of parallel batches to load kwargs: - default: 2 + default: 0 type: '&int' precision: short: '' long: precision help: The model data precision kwargs: - default: 16 - type: '&int' - num_classes: - short: '' - long: num-classes - help: The number of classes (overrides file info) - kwargs: - default: !!null - type: '&int' - -maskrcnn: - resize_height: - short: '' - long: resize-height - help: The image resize height - kwargs: - default: 201 - type: '&int' - resize_width: + default: '16-mixed' + strategy: short: '' - long: resize-width - help: The image resize width + long: strategy + help: The model distribution strategy kwargs: - default: 201 - type: '&int' - min_image_size: - short: '' - long: min-image-size - help: The minimum image size - kwargs: - default: 100 - type: '&int' - max_image_size: + default: 'ddp' + choices: ['ddp', 'ddp_spawn', 'fsdp', 'ddp_find_unused_parameters_true'] + data_pattern: short: '' - long: max-image-size - help: The maximum image size + long: data-pattern + help: A glob pattern for data train files kwargs: - default: 600 - type: '&int' - trainable_backbone_layers: + default: 'data*.pt' + log_transform: short: '' - long: trainable-layers - help: The number of trainable backbone layers + long: log-transform + help: Whether to log-transform the data kwargs: - default: 3 - type: '&int' + action: store_true shared_partitions: spatial_partitions: short: '' long: spatial-partitions help: The spatial partitions for spatial k-fold cross-validation or regional training + kwargs: + default: 'yes' partition_column: short: '' long: partition-column @@ -376,10 +314,17 @@ train: val_frac: short: '' long: val-frac - help: the validation fraction + help: The validation fraction kwargs: default: 0.2 type: '&float' + augment_prob: + short: '' + long: augment-prob + help: The augmentation probability + kwargs: + default: 0.5 + type: '&float' random_seed: short: '' long: random-seed @@ -392,14 +337,7 @@ train: long: epochs help: The number of training epochs kwargs: - default: 30 - type: '&int' - save_top_k: - short: '' - long: save-top-k - help: The number of model checkpoints to save (in addition to the last/best) - kwargs: - default: 1 + default: 100 type: '&int' threads: short: t @@ -414,10 +352,10 @@ train: help: Whether to reset the model kwargs: action: store_true - expected_dim: + expected_time: short: '' - long: expected-dim - help: The expected X dimension (time x bands) of the training data + long: expected-time + help: The expected time dimension of the training data kwargs: default: !!null type: '&int' @@ -453,18 +391,6 @@ train: help: The progress bar color for dimension checks kwargs: default: '#ffffff' - mean_color: - short: '' - long: mean-color - help: The progress bar color for means - kwargs: - default: '#ffffff' - sse_color: - short: '' - long: sse-color - help: The progress bar color for sum of squared errors - kwargs: - default: '#ffffff' auto_lr_find: short: '' long: lr-find @@ -484,13 +410,6 @@ train: help: The gradient clip algorithm kwargs: default: 'norm' - patience: - short: '' - long: patience - help: The early stopping patience - kwargs: - default: 20 - type: '&int' optimizer: short: '' long: optimizer @@ -498,42 +417,46 @@ train: kwargs: default: 'AdamW' choices: + - 'Adam' - 'AdamW' + - 'RAdam' - 'SGD' - deep_sup_dist: + pool_by_max: short: '' - long: deep-sup-dist - help: Whether to use deep supervision for distances + long: pool-by-max + help: Whether to apply max pooling before convolution (otherwise, use strided convolution) kwargs: action: store_true - deep_sup_edge: + batchnorm_first: short: '' - long: deep-sup-edge - help: Whether to use deep supervision for edges + long: batchnorm-first + help: Whether to apply BN->Act->Conv, otherwise Conv->BN->Act kwargs: action: store_true - deep_sup_mask: - short: '' - long: deep-sup-mask - help: Whether to use deep supervision for masks + loss_name: + short: l + long: loss-name + help: The loss method name kwargs: - action: store_true + default: 'TanimotoComplementLoss' + choices: ['TanimotoDistLoss', 'TanimotoComplementLoss', 'TanimotoCombined'] learning_rate: short: lr long: learning-rate help: The learning rate kwargs: - default: 1e-3 + default: 0.01 type: '&float' lr_scheduler: short: lrs long: lr-scheduler help: The learning rate scheduler kwargs: - default: 'CosineAnnealingLR' + default: 'OneCycleLR' choices: - - 'ExponentialLR' - 'CosineAnnealingLR' + - 'ExponentialLR' + - 'OneCycleLR' - 'StepLR' steplr_step_size: short: '' @@ -553,7 +476,7 @@ train: long: weight-decay help: Sets the weight decay for Adam optimizer\'s regularization kwargs: - default: 0.01 + default: 1e-3 type: '&float' accumulate_grad_batches: short: agb @@ -612,26 +535,23 @@ train: help: Whether to save batch validation metrics kwargs: action: store_true - refine_model: - short: '' - long: refine-model - help: Whether to refine a trained model - kwargs: - action: store_true skip_train: short: '' long: skip-train help: Whether to skip training kwargs: action: store_true + finetune: + short: '' + long: finetune + help: Layers to finetune (if None, do feature extraction) + kwargs: + default: !!null + choices: + - all + - fc predict: - predict_year: - short: 'y' - long: year - help: The predict end year (yyyy) - kwargs: - type: '&int' out_path: short: 'o' long: out-path @@ -666,7 +586,7 @@ predict: long: padding help: The read padding around the window (padding is sliced off before writing) kwargs: - default: 101 + default: 20 type: '&int' mode: short: '' @@ -687,15 +607,21 @@ predict: help: The compression algorithm to use kwargs: default: lzw - include_maskrcnn: - short: '' - long: include-maskrcnn - help: Whether to include Mask R-CNN - kwargs: - action: store_true delete_dataset: short: '' long: delete-dataset help: Whether to delete the prediction dataset kwargs: action: store_true + +train_transfer: + placeholder: + short: '' + long: placeholder + help: Help for placeholder + +predict_transfer: + placeholder: + short: '' + long: placeholder + help: Help for placeholder \ No newline at end of file diff --git a/src/cultionet/scripts/config.yml b/src/cultionet/scripts/config.yml index 6f9fc3ee..82212f28 100644 --- a/src/cultionet/scripts/config.yml +++ b/src/cultionet/scripts/config.yml @@ -3,20 +3,15 @@ image_vis: - gcvi - kndvi -# The regions to process (start, end) -regions: - - 1 - - 1 - # The region file path region_id_file: !!null +polygon_file: !!null + +# Each year in `region_id_file` should correspond to the year of harvest +# For US harvest year 2019, an end date of 12-31 would mean 2019-01-01 to 2020-01-01 +# For Argentina harvest year 2019, an end date of 07-01 would mean 2018-07-01 to 2019-07-01 +start_mmdd: '01-01' +end_mmdd: '12-31' -# End years (i.e., 2020 = 2019 planting/harvest year) -# 2019 = 2018 CDL -# 2020 = 2019 CDL -# 2021 = 2020 CDL -# 2022 = 2021 CDL -years: - - 2020 - - 2021 - - 2022 +# The length of the time series +num_months: 12 diff --git a/src/cultionet/scripts/cultionet.py b/src/cultionet/scripts/cultionet.py index 979b980d..c20316b4 100644 --- a/src/cultionet/scripts/cultionet.py +++ b/src/cultionet/scripts/cultionet.py @@ -1,44 +1,41 @@ #!/usr/bin/env python -from abc import abstractmethod import argparse -import typing as T -import logging -from pathlib import Path -from datetime import datetime -import asyncio -import filelock import builtins import json -import ast -import itertools +import logging +import typing as T +from collections import namedtuple +from datetime import datetime +from functools import partial +from pathlib import Path -import geowombat as gw -from geowombat.core.windows import get_window_offsets import geopandas as gpd +import numpy as np import pandas as pd -import yaml -import rasterio as rio -from rasterio.windows import Window import torch -import xarray as xr -import ray -from ray.actor import ActorHandle -from tqdm import tqdm -from tqdm.dask import TqdmCallback -from pytorch_lightning import seed_everything +import yaml +from fiona.errors import DriverError +from geowombat.core import sort_images_by_date +from joblib import delayed, parallel_config +from lightning import seed_everything +from rich.markdown import Markdown +from rich_argparse import RichHelpFormatter +from shapely.errors import GEOSException +from shapely.geometry import box import cultionet -from cultionet.data.const import SCALE_FACTOR +from cultionet.data.create import create_predict_dataset, create_train_batch from cultionet.data.datasets import EdgeDataset -from cultionet.utils.project_paths import setup_paths, ProjectPaths +from cultionet.data.utils import split_multipolygons +from cultionet.enums import CLISteps, DataColumns, ModelNames from cultionet.errors import TensorShapeError -from cultionet.utils.normalize import get_norm_values -from cultionet.data.create import create_dataset, create_predict_dataset -from cultionet.data.utils import get_image_list_dims, create_network_data +from cultionet.model import CultionetParams from cultionet.utils import model_preprocessing from cultionet.utils.logging import set_color_logger - +from cultionet.utils.model_preprocessing import ParallelProgress +from cultionet.utils.normalize import NormValues +from cultionet.utils.project_paths import ProjectPaths, setup_paths logger = set_color_logger(__name__) @@ -65,48 +62,42 @@ def get_centroid_coords_from_image( def get_start_end_dates( feature_path: Path, - start_year: int, - start_date: str, - end_date: str, + end_year: T.Union[int, str], + start_mmdd: str, + end_mmdd: str, + num_months: int, date_format: str = "%Y%j", lat: T.Optional[float] = None, ) -> T.Tuple[str, str]: """Gets the start and end dates from user args or from the filenames. - Returns: - str (mm-dd), str (mm-dd) + Returns + ======= + str (mm-dd), str (mm-dd) """ - # Get the first file for the start year - filename = list(feature_path.glob(f"{start_year}*.tif"))[0] - # Get the date from the file name - file_dt = datetime.strptime(filename.stem, date_format) - if start_date is not None: - start_date = start_date - else: - start_date = file_dt.strftime("%m-%d") - if end_date is not None: - end_date = end_date - else: - end_date = file_dt.strftime("%m-%d") + image_dict = sort_images_by_date( + feature_path, + '*.tif', + date_pos=0, + date_start=0, + date_end=8, + date_format=date_format, + ) + image_df = pd.DataFrame( + data=list(image_dict.keys()), + columns=['filename'], + index=list(image_dict.values()), + ) - month = int(start_date.split("-")[0]) + end_date_stamp = pd.Timestamp(f"{end_year}-{end_mmdd}") + start_year = (end_date_stamp - pd.DateOffset(months=num_months - 1)).year + start_date_stamp = pd.Timestamp(f"{start_year}-{start_mmdd}") + image_df = image_df.loc[start_date_stamp:end_date_stamp] - if lat is not None: - if lat > 0: - # Expected time series start in northern hemisphere winter - if 2 < month < 11: - logger.warning( - f"The time series start date is {start_date} but the time series is in the Northern hemisphere." - ) - else: - # Expected time series start in northern southern winter - if (month < 5) or (month > 9): - logger.warning( - f"The time series start date is {start_date} but the time series is in the Southern hemisphere." - ) - - return start_date, end_date + return image_df.index[0].strftime("%Y-%m-%d"), image_df.index[-1].strftime( + "%Y-%m-%d" + ) def get_image_list( @@ -164,840 +155,379 @@ def get_image_list( return image_list -@ray.remote -class ProgressBarActor: - """ - Reference: - https://docs.ray.io/en/releases-1.11.1/ray-core/examples/progress_bar.html - """ +def predict_image(args): + logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) - counter: int - delta: int - event: asyncio.Event + # This is a helper function to manage paths + ppaths = setup_paths( + args.project_path, append_ts=True if args.append_ts == "y" else False + ) - def __init__(self) -> None: - self.counter = 0 - self.delta = 0 - self.event = asyncio.Event() + # Load the z-score norm values + norm_values = NormValues.from_file(ppaths.norm_file) - def update(self, num_items_completed: int) -> None: - """Updates the ProgressBar with the incremental number of items that - were just completed.""" - self.counter += num_items_completed - self.delta += num_items_completed - self.event.set() + ds = EdgeDataset( + root=ppaths.predict_path, + log_transform=args.log_transform, + norm_values=norm_values, + pattern=f"{args.region}_{args.start_date.replace('-', '')}_{args.end_date.replace('-', '')}*.pt", + ) + + cultionet.predict_lightning( + reference_image=args.reference_image, + out_path=args.out_path, + ckpt=ppaths.ckpt_path / ModelNames.CKPT_NAME, + dataset=ds, + device=args.device, + devices=args.devices, + strategy=args.strategy, + batch_size=args.batch_size, + load_batch_workers=args.load_batch_workers, + precision=args.precision, + resampling=ds[0].resampling[0] + if hasattr(ds[0], "resampling") + else "nearest", + compression=args.compression, + is_transfer_model=args.process == CLISteps.PREDICT_TRANSFER, + ) - async def wait_for_update(self) -> T.Tuple[int, int]: - """Blocking call. + if args.delete_dataset: + ds.cleanup() - Waits until somebody calls `update`, then returns a tuple of the number - of updates since the last call to `wait_for_update`, and the total - number of completed items. - """ - await self.event.wait() - self.event.clear() - saved_delta = self.delta - self.delta = 0 - return saved_delta, self.counter +def create_one_id( + args: namedtuple, + config: dict, + ppaths: ProjectPaths, + region_df: gpd.GeoDataFrame, + polygon_df: gpd.GeoDataFrame, + processed_path: Path, + bbox_offsets: T.Optional[T.List[T.Tuple[int, int]]] = None, +) -> None: + """Creates a single dataset. + + Parameters + ========== + args + An ``argparse`` ``namedtuple`` of CLI arguments. + config + The configuration. + ppaths + The project path object. + region_df + The region grid ``geopandas.GeoDataFrame``. + polygon_df + The region polygon ``geopandas.GeoDataFrame``. + processed_path + The time series path. + bbox_offsets + Bounding box (x, y) offsets as [(x, y)]. E.g., shifts of + [(-1000, 0), (0, 1000)] would shift the grid left by 1,000 meters and + then right by 1,000 meters. + + Note that the ``polygon_df`` should support the shifts outside of the grid. + """ - def get_counter(self) -> int: - """Returns the total number of complete items.""" - return self.counter + row_id = processed_path.name + bbox_offset_list = [(0, 0)] + if bbox_offsets is not None: + bbox_offset_list.extend(bbox_offsets) -class ProgressBar: - """ - Reference: - https://docs.ray.io/en/releases-1.11.1/ray-core/examples/progress_bar.html - """ + for grid_offset in bbox_offset_list: - progress_actor: ActorHandle - total: int - desc: str - position: int - leave: bool - pbar: tqdm + if args.destination != "predict": + # Get the grid + row_region_df = region_df.query( + f"{DataColumns.GEOID} == '{row_id}'" + ) - def __init__( - self, total: int, desc: str = "", position: int = 0, leave: bool = True - ): - # Ray actors don't seem to play nice with mypy, generating - # a spurious warning for the following line, - # which we need to suppress. The code is fine. - self.progress_actor = ProgressBarActor.remote() # type: ignore - self.total = total - self.desc = desc - self.position = position - self.leave = leave - - @property - def actor(self) -> ActorHandle: - """Returns a reference to the remote `ProgressBarActor`. - - When you complete tasks, call `update` on the actor. - """ - return self.progress_actor - - def print_until_done(self) -> None: - """Blocking call. - - Do this after starting a series of remote Ray tasks, to which you've - passed the actor handle. Each of them calls `update` on the actor. When - the progress meter reaches 100%, this method returns. - """ - pbar = tqdm( - desc=self.desc, - position=self.position, - total=self.total, - leave=self.leave, - ) - while True: - delta, counter = ray.get(self.actor.wait_for_update.remote()) - pbar.update(delta) - if counter >= self.total: - pbar.close() + if row_region_df.empty: return + left, bottom, right, top = row_region_df.total_bounds + + if grid_offset != (0, 0): + # Create a new, shifted grid + row_region_df = gpd.GeoDataFrame( + geometry=[ + box( + left + grid_offset[1], + bottom + grid_offset[0], + right + grid_offset[1], + top + grid_offset[0], + ), + ], + crs=row_region_df.crs, + ) + left, bottom, right, top = row_region_df.total_bounds -class BlockWriter(object): - def _build_slice(self, window: Window) -> tuple: - return ( - slice(0, None), - slice(window.row_off, window.row_off + window.height), - slice(window.col_off, window.col_off + window.width), - ) + # Clip the polygons to the current grid + # NOTE: .cx gets all intersecting polygons and reduces the problem size for clip() + polygon_df_intersection = polygon_df.cx[left:right, bottom:top] - def predict_write_block(self, w: Window, w_pad: Window): - slc = self._build_slice(w_pad) - # Create the data for the chunk - data = create_network_data( - self.ts[slc].gw.compute(num_workers=1), - ntime=self.ntime, - nbands=self.nbands, - ) - # Apply inference on the chunk - stack = cultionet.predict( - lit_model=self.lit_model, - data=data, - written=None, # self.dst.read(self.bands[-1], window=w_pad), - data_values=self.data_values, - w=w, - w_pad=w_pad, - device=self.device, - include_maskrcnn=self.include_maskrcnn, - ) - # Write the prediction stack to file - with filelock.FileLock("./dst.lock"): - self.dst.write( - stack, - indexes=range(1, self.dst.profile["count"] + 1), - window=w, + # Clip the polygons to the grid edges + try: + row_polygon_df = gpd.clip( + polygon_df_intersection, + row_region_df, + ) + except GEOSException: + try: + # Try clipping with any MultiPolygon split + row_polygon_df = gpd.clip( + split_multipolygons(polygon_df_intersection), + row_region_df, + ) + except GEOSException: + try: + # Try clipping with a ghost buffer + row_polygon_df = gpd.clip( + split_multipolygons( + polygon_df_intersection + ).assign( + geometry=polygon_df_intersection.geometry.buffer( + 0 + ) + ), + row_region_df, + ) + except GEOSException: + logger.warning( + f"Could not create a dataset file for {row_id}." + ) + return + + # Check for multi-polygons + row_polygon_df = split_multipolygons(row_polygon_df) + # Rather than check for a None CRS, just set it + row_polygon_df = row_polygon_df.set_crs( + polygon_df_intersection.crs, allow_override=True ) + end_year = int(row_region_df[DataColumns.YEAR]) -class WriterModule(BlockWriter): - def __init__( - self, - out_path: T.Union[str, Path], - mode: str, - profile: dict, - ntime: int, - nbands: int, - filters: int, - num_classes: int, - ts: xr.DataArray, - data_values: torch.Tensor, - ppaths: ProjectPaths, - device: str, - scale_factor: float, - include_maskrcnn: bool, - ) -> None: - self.out_path = out_path - # Create the output file - if mode == "w": - with rio.open(self.out_path, mode=mode, **profile): - pass - - self.dst = rio.open(self.out_path, mode="r+") - - self.ntime = ntime - self.nbands = nbands - self.ts = ts - self.data_values = data_values - self.ppaths = ppaths - self.device = device - self.scale_factor = scale_factor - self.include_maskrcnn = include_maskrcnn - # self.bands = [1, 2, 3] #+ list(range(4, 4+num_classes-1)) - # if self.include_maskrcnn: - # self.bands.append(self.bands[-1] + 1) - - self.lit_model = cultionet.load_model( - ckpt_file=self.ppaths.ckpt_file, - model_file=self.ppaths.ckpt_file.parent / "cultionet.pt", - num_features=ntime * nbands, - num_time_features=ntime, - filters=filters, - num_classes=num_classes, - device=self.device, - enable_progress_bar=False, - )[1] - - def close_open(self): - self.close() - self.dst = rio.open(self.out_path, mode="r+") - - def close(self): - self.dst.close() - - @abstractmethod - def write( - self, - w: Window, - w_pad: Window, - pba: T.Optional[T.Union[ActorHandle, int]] = None, - ): - raise NotImplementedError - - -@ray.remote -class RemoteWriter(WriterModule): - """A concurrent writer with Ray.""" - - def __init__( - self, - out_path: T.Union[str, Path], - mode: str, - profile: dict, - ntime: int, - nbands: int, - filters: int, - num_classes: int, - ts: xr.DataArray, - data_values: torch.Tensor, - ppaths: ProjectPaths, - device: str, - scale_factor: float, - include_maskrcnn: bool, - ) -> None: - super().__init__( - out_path=out_path, - mode=mode, - profile=profile, - ntime=ntime, - nbands=nbands, - filters=filters, - num_classes=num_classes, - ts=ts, - data_values=data_values, - ppaths=ppaths, - device=device, - scale_factor=scale_factor, - include_maskrcnn=include_maskrcnn, - ) + if args.add_year > 0: + end_year += args.add_year - def write(self, w: Window, w_pad: Window, pba: ActorHandle = None): - self.predict_write_block(w, w_pad) - if pba is not None: - pba.update.remote(1) - - -class SerialWriter(WriterModule): - """A serial writer.""" - - def __init__( - self, - out_path: T.Union[str, Path], - mode: str, - profile: dict, - ntime: int, - nbands: int, - filters: int, - num_classes: int, - ts: xr.DataArray, - data_values: torch.Tensor, - ppaths: ProjectPaths, - device: str, - scale_factor: float, - include_maskrcnn: bool, - ) -> None: - super().__init__( - out_path=out_path, - mode=mode, - profile=profile, - ntime=ntime, - nbands=nbands, - filters=filters, - num_classes=num_classes, - ts=ts, - data_values=data_values, - ppaths=ppaths, - device=device, - scale_factor=scale_factor, - include_maskrcnn=include_maskrcnn, - ) + image_list = [] + for image_vi in config["image_vis"]: + # Set the full path to the images + vi_path = ppaths.image_path.resolve().joinpath( + args.feature_pattern.format(region=row_id, image_vi=image_vi) + ) - def write(self, w: Window, w_pad: Window, pba: int = None): - self.predict_write_block(w, w_pad) - self.close_open() - if pba is not None: - pba.update(1) + if not vi_path.exists(): + logger.warning( + f"The {image_vi} path is missing for {str(vi_path)}." + ) + return + # Get the requested time slice + ts_list = model_preprocessing.get_time_series_list( + vi_path, + date_format=args.date_format, + start_date=pd.to_datetime(args.start_date) + if args.destination == "predict" + else None, + end_date=pd.to_datetime(args.end_date) + if args.destination == "predict" + else None, + end_year=end_year if args.destination != "predict" else None, + start_mmdd=config["start_mmdd"], + end_mmdd=config["end_mmdd"], + num_months=config["num_months"], + ) -def predict_image(args): - logging.getLogger("pytorch_lightning").setLevel(logging.WARNING) + if args.skip_index > 0: + ts_list = ts_list[:: args.skip_index] - config = open_config(args.config_file) - - # This is a helper function to manage paths - ppaths = setup_paths( - args.project_path, append_ts=True if args.append_ts == "y" else False - ) - # Load the z-score norm values - data_values = torch.load(ppaths.norm_file) - with open(ppaths.classes_info_path, mode="r") as f: - class_info = json.load(f) + image_list += ts_list - num_classes = ( - args.num_classes - if args.num_classes is not None - else class_info["max_crop_class"] + 1 - ) + if image_list: + if args.destination == "predict": + create_predict_dataset( + image_list=image_list, + region=row_id, + process_path=ppaths.get_process_path(args.destination), + date_format=args.date_format, + gain=args.gain, + offset=args.offset, + ref_res=args.ref_res, + resampling=args.resampling, + window_size=args.window_size, + padding=args.padding, + num_workers=args.num_workers, + ) + else: + class_info = { + "max_crop_class": args.max_crop_class, + "edge_class": args.max_crop_class + 1, + } + with open(ppaths.classes_info_path, mode="w") as f: + f.write(json.dumps(class_info)) - if args.data_path is not None: - ds = EdgeDataset( - ppaths.predict_path, - data_means=data_values.mean, - data_stds=data_values.std, - pattern=f"data_{args.region}_{args.predict_year}*.pt", - ) - ckpt_file = ppaths.ckpt_path / "last.ckpt" + create_train_batch( + image_list=image_list, + df_grid=row_region_df, + df_polygons=row_polygon_df, + max_crop_class=args.max_crop_class, + region=row_id, + process_path=ppaths.get_process_path(args.destination), + date_format=args.date_format, + gain=args.gain, + offset=args.offset, + ref_res=args.ref_res, + resampling=args.resampling, + grid_size=args.grid_size, + crop_column=args.crop_column, + keep_crop_classes=args.keep_crop_classes, + replace_dict=args.replace_dict, + nonag_is_unknown=args.nonag_is_unknown, + all_touched=args.all_touched, + ) - cultionet.predict_lightning( - reference_image=args.reference_image, - out_path=args.out_path, - ckpt=ckpt_file, - dataset=ds, - batch_size=args.batch_size, - load_batch_workers=args.load_batch_workers, - device=args.device, - devices=args.devices, - precision=args.precision, - num_classes=num_classes, - ref_res=ds[0].res, - resampling=ds[0].resampling, - compression=args.compression, - refine_pt=ckpt_file.parent / "refine" / "refine.pt", - ) - if args.delete_dataset: - ds.cleanup() - else: +def read_training( + filename: T.Union[list, tuple, str, Path], columns: list +) -> gpd.GeoDataFrame: + if isinstance(filename, (list, tuple)): try: - tmp = int(args.grid_id) - region = f"{tmp:06d}" - except ValueError: - region = args.grid_id - - # Get the image list - image_list = get_image_list( - ppaths, - region=region, - predict_year=args.predict_year, - start_date=args.start_date, - end_date=args.end_date, - config=config, - date_format=args.date_format, - skip_index=args.skip_index, - ) - - with gw.open( - image_list, - stack_dim="band", - band_names=list(range(1, len(image_list) + 1)), - ) as src_ts: - time_series = ( - (src_ts * args.gain + args.offset).astype("float64").clip(0, 1) - ) - if args.preload_data: - with TqdmCallback(desc="Loading data"): - time_series.load(num_workers=args.processes) - # Get the image dimensions - nvars = model_preprocessing.VegetationIndices( - image_vis=config["image_vis"] - ).n_vis - nfeas, height, width = time_series.shape - ntime = int(nfeas / nvars) - windows = get_window_offsets( - height, - width, - args.window_size, - args.window_size, - padding=( - args.padding, - args.padding, - args.padding, - args.padding, - ), - ) + df = pd.concat( + [ + gpd.read_file( + fn, + columns=columns, + engine="pyogrio", + ) + for fn in filename + ] + ).reset_index(drop=True) - profile = { - "crs": src_ts.crs, - "transform": src_ts.gw.transform, - "height": height, - "width": width, - # Orientation (+1) + distance (+1) + edge (+1) + crop (+1) crop types (+N) - # `num_classes` includes background - "count": 3 + num_classes - 1, - "dtype": "uint16", - "blockxsize": 64 if 64 < width else width, - "blockysize": 64 if 64 < height else height, - "driver": "GTiff", - "sharing": False, - "compress": args.compression, - } - profile["tiled"] = ( - True - if max(profile["blockxsize"], profile["blockysize"]) >= 16 - else False - ) + except DriverError: + raise IOError("The id file does not exist") - # Get the time and band count - ntime, nbands = get_image_list_dims(image_list, time_series) - - if args.processes == 1: - serial_writer = SerialWriter( - out_path=args.out_path, - mode=args.mode, - profile=profile, - ntime=ntime, - nbands=nbands, - filters=args.filters, - num_classes=num_classes, - ts=time_series, - data_values=data_values, - ppaths=ppaths, - device=args.device, - scale_factor=SCALE_FACTOR, - include_maskrcnn=args.include_maskrcnn, - ) - try: - with tqdm( - total=len(windows), - desc="Predicting windows", - position=0, - ) as pbar: - results = [ - serial_writer.write(w, w_pad, pba=pbar) - for w, w_pad in windows - ] - serial_writer.close() - except Exception as e: - serial_writer.close() - logger.exception(f"The predictions failed because {e}.") - else: - if ray.is_initialized(): - logger.warning("The Ray cluster is already running.") - else: - if args.device == "gpu": - # TODO: support multiple GPUs through CLI - try: - ray.init(num_cpus=args.processes, num_gpus=1) - except KeyError as e: - logger.exception( - f"Ray could not be instantiated with a GPU because {e}." - ) - else: - ray.init(num_cpus=args.processes) - assert ray.is_initialized(), "The Ray cluster is not running." - # Setup the remote ray writer - remote_writer = RemoteWriter.options( - max_concurrency=args.processes - ).remote( - out_path=args.out_path, - mode=args.mode, - profile=profile, - ntime=ntime, - nbands=nbands, - filters=args.filters, - num_classes=num_classes, - ts=ray.put(time_series), - data_values=data_values, - ppaths=ppaths, - device=args.device, - devices=args.devices, - scale_factor=SCALE_FACTOR, - include_maskrcnn=args.include_maskrcnn, - ) - actor_chunksize = args.processes * 8 - try: - with tqdm( - total=len(windows), - desc="Predicting windows", - position=0, - ) as pbar: - for wchunk in range( - 0, len(windows) + actor_chunksize, actor_chunksize - ): - chunk_windows = windows[ - wchunk : wchunk + actor_chunksize - ] - pbar.set_description( - f"Windows {wchunk:,d}--{wchunk+len(chunk_windows):,d}" - ) - # pb = ProgressBar( - # total=len(chunk_windows), - # desc=f'Chunks {wchunk}-{wchunk+len(chunk_windows)}', - # position=1, - # leave=False - # ) - # tqdm_actor = pb.actor - # Write each window concurrently - results = [ - remote_writer.write.remote(w, w_pad) - for w, w_pad in chunk_windows - ] - # Initiate the processing - # pb.print_until_done() - ray.get(results) - # Close the file - ray.get(remote_writer.close_open.remote()) - pbar.update(len(chunk_windows)) - ray.get(remote_writer.close.remote()) - ray.shutdown() - except Exception as e: - ray.get(remote_writer.close.remote()) - ray.shutdown() - logger.exception(f"The predictions failed because {e}.") - - -def cycle_data( - year_lists: list, - regions_lists: list, - project_path_lists: list, - ref_res_lists: list, -): - for years, regions, project_path, ref_res in zip( - year_lists, regions_lists, project_path_lists, ref_res_lists - ): - for region in regions: - for image_year in years: - yield region, image_year, project_path, ref_res + else: + filename = Path(filename) + if not filename.exists(): + raise IOError("The id file does not exist") + df = gpd.read_file(filename) -def get_centroid_coords( - df: gpd.GeoDataFrame, dst_crs: T.Optional[str] = None -) -> T.Tuple[float, float]: - """Gets the lon/lat or x/y coordinates of a centroid.""" - centroid = df.to_crs(dst_crs).centroid + return df - return float(centroid.x), float(centroid.y) +def create_dataset(args): + """Creates a train or predict dataset.""" -def create_datasets(args): config = open_config(args.config_file) - project_path_lists = [args.project_path] - ref_res_lists = [args.ref_res] + + ppaths: ProjectPaths = setup_paths( + args.project_path, + append_ts=True if args.append_ts == "y" else False, + ) if hasattr(args, "max_crop_class"): assert isinstance( args.max_crop_class, int ), "The maximum crop class value must be given." - region_as_list = config["regions"] is not None - region_as_file = config["region_id_file"] is not None + region_df = None + polygon_df = None + if args.destination == "train": + region_id_file = config.get("region_id_file") + polygon_file = config.get("polygon_file") - assert ( - region_as_list or region_as_file - ), "Only submit region as a list or as a given file" + if region_id_file is None: + raise NameError("A region file or file list must be given.") - if hasattr(args, "time_series_path") and ( - args.time_series_path is not None - ): - inputs = model_preprocessing.TrainInputs( - regions=[Path(args.time_series_path).name], - years=[args.predict_year], - ) - else: - if region_as_file: - file_path = config["region_id_file"] - if not Path(file_path).is_file(): - raise IOError("The id file does not exist") - id_data = pd.read_csv(file_path) - assert ( - "id" in id_data.columns - ), f"id column not found in {file_path}." - regions = id_data["id"].unique().tolist() - else: - regions = list( - range(config["regions"][0], config["regions"][1] + 1) - ) + if polygon_file is None: + raise NameError("A polygon file or file list must be given.") - inputs = model_preprocessing.TrainInputs( - regions=regions, years=config["years"] + # Read the training grids + region_df = read_training( + region_id_file, + columns=[DataColumns.GEOID, DataColumns.YEAR, "geometry"], ) - total_iters = len( - list( - itertools.product( - list(itertools.chain.from_iterable(inputs.year_lists)), - list( - itertools.chain.from_iterable(inputs.regions_lists) - ), - ) + # Read the training polygons + polygon_df = read_training( + polygon_file, + columns=[args.crop_column, "geometry"], ) - ) - with tqdm(total=total_iters, position=0, leave=True) as pbar: - for region, end_year, project_path, ref_res in cycle_data( - inputs.year_lists, - inputs.regions_lists, - project_path_lists, - ref_res_lists, - ): - ppaths = setup_paths( - project_path, - append_ts=True if args.append_ts == "y" else False, - ) + polygon_df[args.crop_column] + polygon_df = polygon_df.astype({args.crop_column: int}) - try: - tmp = int(region) - region = f"{tmp:06d}" - except ValueError: - pass - - if args.destination == "predict": - df_grids = None - df_edges = None - else: - # Read the training data - grids = ( - ppaths.edge_training_path - / f"{region}_grid_{end_year}.gpkg" - ) - edges = ( - ppaths.edge_training_path - / f"{region}_edges_{end_year}.gpkg" - ) - if not grids.is_file(): - pbar.update(1) - pbar.set_description("File not exist") - continue - - df_grids = gpd.read_file(grids) - if not {"region", "grid"}.intersection(df_grids.columns.tolist()): - df_grids["region"] = region - - if not edges.is_file(): - edges = ( - ppaths.edge_training_path - / f"{region}_poly_{end_year}.gpkg" - ) - if not edges.is_file(): - # No training polygons - df_edges = gpd.GeoDataFrame( - data=[], geometry=[], crs=df_grids.crs - ) - else: - df_edges = gpd.read_file(edges) - - image_list = [] - for image_vi in model_preprocessing.VegetationIndices( - image_vis=config["image_vis"] - ).image_vis: - # Set the full path to the images - vi_path = ppaths.image_path / args.feature_pattern.format( - region=region, image_vi=image_vi - ) - - if not vi_path.is_dir(): - pbar.update(1) - pbar.set_description("No directory") - continue - - # Get the centroid coordinates of the grid - lat = None - if args.destination != "predict": - lat = get_centroid_coords( - df_grids.centroid, dst_crs="epsg:4326" - )[1] - # Get the start and end dates - start_date, end_date = get_start_end_dates( - vi_path, - start_year=end_year - 1, - start_date=args.start_date, - end_date=args.end_date, - date_format=args.date_format, - lat=lat, - ) - # Get the requested time slice - ts_list = model_preprocessing.get_time_series_list( - vi_path, - end_year - 1, - start_date, - end_date, - date_format=args.date_format, - ) - if len(ts_list) <= 1: - pbar.update(1) - pbar.set_description("TS too short") - continue - - if args.skip_index > 0: - ts_list = ts_list[:: args.skip_index] - image_list += ts_list - - if args.destination != "predict": - class_info = { - "max_crop_class": args.max_crop_class, - "edge_class": args.max_crop_class + 1, - } - with open(ppaths.classes_info_path, mode="w") as f: - f.write(json.dumps(class_info)) - - if image_list: - if args.destination == "predict": - create_predict_dataset( - image_list=image_list, - region=region, - year=end_year, - process_path=ppaths.get_process_path(args.destination), - gain=args.gain, - offset=args.offset, - ref_res=ref_res, - resampling=args.resampling, - window_size=args.window_size, - padding=args.padding, - num_workers=args.num_workers, - chunksize=args.chunksize, - ) - else: - pbar = create_dataset( - image_list=image_list, - df_grids=df_grids, - df_edges=df_edges, - max_crop_class=args.max_crop_class, - group_id=f"{region}_{end_year}", - process_path=ppaths.get_process_path(args.destination), - transforms=args.transforms, - gain=args.gain, - offset=args.offset, - ref_res=ref_res, - resampling=args.resampling, - num_workers=args.num_workers, - grid_size=args.grid_size, - instance_seg=args.instance_seg, - zero_padding=args.zero_padding, - crop_column=args.crop_column, - keep_crop_classes=args.keep_crop_classes, - replace_dict=args.replace_dict, - pbar=pbar, - ) - - pbar.update(1) + assert ( + region_df.crs == polygon_df.crs + ), "The region id CRS does not match the polygon CRS." + assert ( + DataColumns.GEOID in region_df.columns + ), "The geo_id column was not found in the grid region file." -def train_maskrcnn(args): - seed_everything(args.random_seed, workers=True) + assert ( + DataColumns.YEAR in region_df.columns + ), "The year column was not found in the grid region file." - # This is a helper function to manage paths - ppaths = setup_paths(args.project_path, ckpt_name="maskrcnn.ckpt") + if 0 in polygon_df[args.crop_column].unique(): + raise ValueError("The field crop values should not have zeros.") - if ( - (args.expected_dim is not None) - or not ppaths.norm_file.is_file() - or (ppaths.norm_file.is_file() and args.recalc_zscores) + # Get processed ids + if hasattr(args, 'time_series_path') and ( + args.time_series_path is not None ): - ds = EdgeDataset( - ppaths.train_path, - processes=args.processes, - threads_per_worker=args.threads, - random_seed=args.random_seed, - ) - # Check dimensions - if args.expected_dim is not None: - try: - ds.check_dims( - args.expected_dim, args.delete_mismatches, args.dim_color - ) - except TensorShapeError as e: - raise ValueError(e) - # Get the normalization means and std. deviations on the train data - # Calculate the values needed to transform to z-scores, using - # the training data - if ppaths.norm_file.is_file(): - if args.recalc_zscores: - ppaths.norm_file.unlink() - if not ppaths.norm_file.is_file(): - train_ds = ds.split_train_val(val_frac=args.val_frac)[0] - data_values = get_norm_values( - dataset=train_ds, - batch_size=args.batch_size, - mean_color=args.mean_color, - sse_color=args.sse_color, - ) - torch.save(data_values, str(ppaths.norm_file)) + processed_ids = [Path(args.time_series_path)] else: - data_values = torch.load(str(ppaths.norm_file)) + if 'data_pattern' in config: + processed_ids = list( + ppaths.image_path.resolve().glob(config['data_pattern']) + ) + else: + processed_ids = list(ppaths.image_path.resolve().glob('*')) - # Create the train data object again, this time passing - # the means and standard deviation tensors - ds = EdgeDataset( - ppaths.train_path, - data_means=data_values.mean, - data_stds=data_values.std, - random_seed=args.random_seed, - ) - # Check for a test dataset - test_ds = None - if list((ppaths.test_process_path).glob("*.pt")): - test_ds = EdgeDataset( - ppaths.test_path, - data_means=data_values.mean, - data_stds=data_values.std, - random_seed=args.random_seed, + if args.destination == "train": + # Filter ids to those that have been processed + processed_mask = np.isin( + np.array([fn.name for fn in processed_ids]), + region_df[DataColumns.GEOID].values, ) - if args.expected_dim is not None: - try: - test_ds.check_dims( - args.expected_dim, args.delete_mismatches, args.dim_color - ) - except TensorShapeError as e: - raise ValueError(e) - - # Fit the model - cultionet.fit_maskrcnn( - dataset=ds, - ckpt_file=ppaths.ckpt_file, - test_dataset=test_ds, - val_frac=args.val_frac, - batch_size=args.batch_size, - epochs=args.epochs, - save_top_k=args.save_top_k, - accumulate_grad_batches=args.accumulate_grad_batches, - learning_rate=args.learning_rate, - filters=args.filters, - num_classes=args.num_classes, - reset_model=args.reset_model, - auto_lr_find=args.auto_lr_find, - device=args.device, - devices=args.devices, - gradient_clip_val=args.gradient_clip_val, - gradient_clip_algorithm=args.gradient_clip_algorithm, - early_stopping_patience=args.patience, - weight_decay=args.weight_decay, - precision=args.precision, - stochastic_weight_averaging=args.stochastic_weight_averaging, - stochastic_weight_averaging_lr=args.stochastic_weight_averaging_lr, - stochastic_weight_averaging_start=args.stochastic_weight_averaging_start, - model_pruning=args.model_pruning, - resize_height=args.resize_height, - resize_width=args.resize_width, - min_image_size=args.min_image_size, - max_image_size=args.max_image_size, - trainable_backbone_layers=args.trainable_backbone_layers, + processed_ids = np.array(processed_ids)[processed_mask] + + partial_create_one_id = partial( + create_one_id, + args=args, + config=config, + ppaths=ppaths, + region_df=region_df, + polygon_df=polygon_df, + bbox_offsets=args.bbox_offsets + if args.destination == "train" + else None, ) + if args.destination == "predict": + partial_create_one_id(processed_path=processed_ids[0]) + else: + with parallel_config( + backend="loky", + n_jobs=args.num_workers, + ): + with ParallelProgress( + tqdm_params={ + "total": len(processed_ids), + "desc": f"Creating {args.destination} files", + "colour": "green", + "ascii": "\u2015\u25E4\u25E5\u25E2\u25E3\u25AA", + }, + ) as parallel_pool: + parallel_pool( + delayed(partial_create_one_id)( + processed_path=processed_path + ) + for processed_path in processed_ids + ) + def spatial_kfoldcv(args): ppaths = setup_paths(args.project_path) @@ -1006,9 +536,9 @@ def spatial_kfoldcv(args): class_info = json.load(f) ds = EdgeDataset( - ppaths.train_path, + root=ppaths.train_path, + log_transform=args.log_transform, processes=args.processes, - threads_per_worker=args.threads, random_seed=args.random_seed, ) # Read or create the spatial partitions (folds) @@ -1023,17 +553,13 @@ def spatial_kfoldcv(args): ) # Normalize the partition temp_ds = train_ds.split_train_val(val_frac=args.val_frac)[0] - data_values = get_norm_values( + norm_values = NormValues.from_dataset( dataset=temp_ds, class_info=class_info, batch_size=args.batch_size, - mean_color=args.mean_color, - sse_color=args.sse_color, ) - train_ds.data_means = data_values.mean - train_ds.data_stds = data_values.std - test_ds.data_means = data_values.mean - test_ds.data_stds = data_values.std + train_ds.norm_values = norm_values + test_ds.norm_values = norm_values # Get balanced class weights # Reference: https://github.com/scikit-learn/scikit-learn/blob/f3f51f9b6/sklearn/utils/class_weight.py#L10 @@ -1041,9 +567,9 @@ def spatial_kfoldcv(args): # class_weights = recip_freq[torch.arange(0, len(data_values.crop_counts)-1)] # class_weights = torch.tensor([0] + list(class_weights), dtype=torch.float) if torch.cuda.is_available(): - class_counts = data_values.crop_counts.to("cuda") + class_counts = norm_values.crop_counts.to("cuda") else: - class_counts = data_values.crop_counts + class_counts = norm_values.crop_counts # Fit the model cultionet.fit( @@ -1058,7 +584,7 @@ def spatial_kfoldcv(args): accumulate_grad_batches=args.accumulate_grad_batches, optimizer=args.optimizer, learning_rate=args.learning_rate, - filters=args.filters, + hidden_channels=args.hidden_channels, num_classes=args.num_classes if args.num_classes is not None else class_info["max_crop_class"] + 1, @@ -1085,44 +611,6 @@ def spatial_kfoldcv(args): ) -def generate_model_graph(args): - from cultionet.models.convstar import StarRNN - from cultionet.models.nunet import ResUNet3Psi - - ppaths = setup_paths(args.project_path) - data_values = torch.load(str(ppaths.norm_file)) - ds = EdgeDataset( - ppaths.train_path, - data_means=data_values.mean, - data_stds=data_values.std, - crop_counts=data_values.crop_counts, - edge_counts=data_values.edge_counts, - ) - - data = ds[0] - xrnn = data.x.reshape(1, data.nbands, data.ntime, data.height, data.width) - filters = 32 - star_rnn_model = StarRNN( - input_dim=data.nbands, - hidden_dim=filters, - n_layers=6, - num_classes_last=2, - ) - x, __ = star_rnn_model(xrnn) - torch.onnx.export( - star_rnn_model, xrnn, ppaths.ckpt_path / "cultionet_starrnn.onnx" - ) - resunet_model = ResUNet3Psi( - in_channels=int(filters * 3), - init_filter=filters, - num_classes=2, - double_dilation=2, - ) - torch.onnx.export( - resunet_model, x, ppaths.ckpt_path / "cultionet_resunet.onnx" - ) - - def train_model(args): seed_everything(args.random_seed, workers=True) @@ -1133,21 +621,23 @@ def train_model(args): class_info = json.load(f) if ( - (args.expected_dim is not None) + (args.expected_time is not None) or not ppaths.norm_file.is_file() or (ppaths.norm_file.is_file() and args.recalc_zscores) ): ds = EdgeDataset( - ppaths.train_path, + root=ppaths.train_path, + log_transform=args.log_transform, + pattern=args.data_pattern, processes=args.processes, - threads_per_worker=args.threads, random_seed=args.random_seed, ) + # Check dimensions - if args.expected_dim is not None: + if args.expected_time is not None: try: ds.check_dims( - args.expected_dim, + args.expected_time, args.expected_height, args.expected_width, args.delete_mismatches, @@ -1155,51 +645,55 @@ def train_model(args): ) except TensorShapeError as e: raise ValueError(e) + ds = EdgeDataset( - ppaths.train_path, + root=ppaths.train_path, + log_transform=args.log_transform, + pattern=args.data_pattern, processes=args.processes, - threads_per_worker=args.threads, random_seed=args.random_seed, ) + # Get the normalization means and std. deviations on the train data # Calculate the values needed to transform to z-scores, using # the training data - if ppaths.norm_file.is_file(): + if ppaths.norm_file.exists(): if args.recalc_zscores: ppaths.norm_file.unlink() - if not ppaths.norm_file.is_file(): + + if not ppaths.norm_file.exists(): + if ds.grid_gpkg_path.exists(): + ds.grid_gpkg_path.unlink() + if args.spatial_partitions is not None: - # train_ds = ds.split_train_val_by_partition( - # spatial_partitions=args.spatial_partitions, - # partition_column=args.partition_column, - # val_frac=args.val_frac, - # partition_name=args.partition_name - # )[0] train_ds = ds.split_train_val( - val_frac=args.val_frac, spatial_overlap_allowed=False + val_frac=args.val_frac, + spatial_overlap_allowed=False, + spatial_balance=True, )[0] else: train_ds = ds.split_train_val(val_frac=args.val_frac)[0] + # Get means and standard deviations from the training dataset - data_values = get_norm_values( + norm_values: NormValues = NormValues.from_dataset( dataset=train_ds, class_info=class_info, - batch_size=args.batch_size, - mean_color=args.mean_color, - sse_color=args.sse_color, + batch_size=args.batch_size * 4, + num_workers=args.load_batch_workers, ) - torch.save(data_values, str(ppaths.norm_file)) + + norm_values.to_file(ppaths.norm_file) else: - data_values = torch.load(str(ppaths.norm_file)) + norm_values = NormValues.from_file(ppaths.norm_file) # Create the train data object again, this time passing # the means and standard deviation tensors ds = EdgeDataset( - ppaths.train_path, - data_means=data_values.mean, - data_stds=data_values.std, - crop_counts=data_values.crop_counts, - edge_counts=data_values.edge_counts, + root=ppaths.train_path, + log_transform=args.log_transform, + pattern=args.data_pattern, + norm_values=norm_values, + augment_prob=args.augment_prob, random_seed=args.random_seed, ) @@ -1207,125 +701,153 @@ def train_model(args): test_ds = None if list((ppaths.test_process_path).glob("*.pt")): test_ds = EdgeDataset( - ppaths.test_path, - data_means=data_values.mean, - data_stds=data_values.std, - crop_counts=data_values.crop_counts, - edge_counts=data_values.edge_counts, + root=ppaths.test_path, + log_transform=args.log_transform, + norm_values=norm_values, random_seed=args.random_seed, ) - if args.expected_dim is not None: + if args.expected_time is not None: try: test_ds.check_dims( - args.expected_dim, args.delete_mismatches, args.dim_color + args.expected_time, args.delete_mismatches, args.dim_color ) except TensorShapeError as e: raise ValueError(e) + test_ds = EdgeDataset( - ppaths.test_path, - data_means=data_values.mean, - data_stds=data_values.std, - crop_counts=data_values.crop_counts, - edge_counts=data_values.edge_counts, + root=ppaths.test_path, + log_transform=args.log_transform, + norm_values=norm_values, random_seed=args.random_seed, ) - # Get balanced class weights - # Reference: https://github.com/scikit-learn/scikit-learn/blob/f3f51f9b6/sklearn/utils/class_weight.py#L10 - # def get_class_weights(counts: torch.Tensor) -> torch.Tensor: - # recip_freq = counts.sum() / (len(counts) * counts) - # weights = recip_freq[torch.arange(0, len(counts))] - - # if torch.cuda.is_available(): - # return weights.to('cuda') - # else: - # return weights + # Combine edge counts with crop counts + class_counts = torch.zeros(len(norm_values.dataset_crop_counts) + 1) + class_counts[1:-1] = norm_values.dataset_crop_counts[1:] + class_counts[-1] = norm_values.dataset_edge_counts[1] + class_counts[0] = ( + norm_values.dataset_edge_counts[0] + - norm_values.dataset_crop_counts[1:].sum() + ) - # class_weights = get_class_weights(data_values.crop_counts) - # edge_weights = get_class_weights(data_values.edge_counts) if torch.cuda.is_available(): - class_counts = data_values.crop_counts.to("cuda") - else: - class_counts = data_values.crop_counts + class_counts = class_counts.to(device="cuda") - # Fit the model - cultionet.fit( - dataset=ds, + cultionet_params = CultionetParams( ckpt_file=ppaths.ckpt_file, + model_name="cultionet_transfer" + if args.process == CLISteps.TRAIN_TRANSFER + else "cultionet", + dataset=ds, test_dataset=test_ds, val_frac=args.val_frac, spatial_partitions=args.spatial_partitions, - partition_name=args.partition_name, - partition_column=args.partition_column, batch_size=args.batch_size, - epochs=args.epochs, - save_top_k=args.save_top_k, - accumulate_grad_batches=args.accumulate_grad_batches, + load_batch_workers=args.load_batch_workers, + edge_class=args.edge_class + if args.edge_class is not None + else class_info["edge_class"], + class_counts=class_counts, + hidden_channels=args.hidden_channels, model_type=args.model_type, + activation_type=args.activation_type, + dropout=args.dropout, dilations=args.dilations, res_block_type=args.res_block_type, attention_weights=args.attention_weights, - activation_type=args.activation_type, - deep_sup_dist=args.deep_sup_dist, - deep_sup_edge=args.deep_sup_edge, - deep_sup_mask=args.deep_sup_mask, optimizer=args.optimizer, + loss_name=args.loss_name, learning_rate=args.learning_rate, lr_scheduler=args.lr_scheduler, steplr_step_size=args.steplr_step_size, + weight_decay=args.weight_decay, + pool_by_max=args.pool_by_max, + batchnorm_first=args.batchnorm_first, scale_pos_weight=args.scale_pos_weight, - filters=args.filters, - num_classes=args.num_classes - if args.num_classes is not None - else class_info["max_crop_class"] + 1, - edge_class=args.edge_class - if args.edge_class is not None - else class_info["edge_class"], - class_counts=class_counts, - reset_model=args.reset_model, - auto_lr_find=args.auto_lr_find, - device=args.device, - devices=args.devices, - profiler=args.profiler, + save_batch_val_metrics=args.save_batch_val_metrics, + epochs=args.epochs, + accumulate_grad_batches=args.accumulate_grad_batches, gradient_clip_val=args.gradient_clip_val, gradient_clip_algorithm=args.gradient_clip_algorithm, - early_stopping_patience=args.patience, - weight_decay=args.weight_decay, precision=args.precision, + device=args.device, + devices=args.devices, + reset_model=args.reset_model, + auto_lr_find=args.auto_lr_find, stochastic_weight_averaging=args.stochastic_weight_averaging, stochastic_weight_averaging_lr=args.stochastic_weight_averaging_lr, stochastic_weight_averaging_start=args.stochastic_weight_averaging_start, - model_pruning=args.model_pruning, - save_batch_val_metrics=args.save_batch_val_metrics, skip_train=args.skip_train, - refine_model=args.refine_model, + finetune=args.finetune, + strategy=args.strategy, + profiler=args.profiler, ) + # Fit the model + if args.process == CLISteps.TRAIN_TRANSFER: + cultionet.fit_transfer(cultionet_params) + else: + cultionet.fit(cultionet_params) + def main(): args_config = open_config((Path(__file__).parent / "args.yml").absolute()) + RichHelpFormatter.styles["argparse.groups"] = "#ACFCD6" + RichHelpFormatter.styles["argparse.args"] = "#FCADED" + RichHelpFormatter.styles["argparse.prog"] = "#AA9439" + RichHelpFormatter.styles["argparse.help"] = "#cacaca" + + description = "# Cultionet: deep learning network for agricultural field boundary detection" + + epilog = """ +# Examples +--- + +## Create training data +```commandline +cultionet create --project-path /projects/data -gs 100 100 -r 10.0 --max-crop-class 1 --crop-column crop_col --num-workers 8 --config-file config.yml +``` + +## View training help +```commandline +cultionet train --help +``` + +## Train a model +```commandline +cultionet train -p . --val-frac 0.1 --epochs 100 --processes 8 --load-batch-workers 8 --batch-size 4 --accumulate-grad-batches 4 --deep-sup +``` + +## Apply inference over an image +```commandline +cultionet predict --project-path /projects/data -o estimates.tif --region imageid --ref-image time_series_vars/imageid/brdf_ts/ms/evi2/20200101.tif --batch-size 4 --load-batch-workers 8 --start-date 2020-01-01 --end-date 2021-01-01 --config-file config.yml +``` + """ + parser = argparse.ArgumentParser( - description="Cultionet models", - formatter_class=argparse.RawTextHelpFormatter, - epilog=args_config["epilog"], + description=Markdown(description, style="argparse.text"), + formatter_class=RichHelpFormatter, + epilog=Markdown(epilog, style="argparse.text"), ) subparsers = parser.add_subparsers(dest="process") available_processes = [ - "create", - "create-predict", - "skfoldcv", - "train", - "maskrcnn", - "predict", - "graph", - "version", + CLISteps.CREATE, + CLISteps.CREATE_PREDICT, + CLISteps.SKFOLDCV, + CLISteps.TRAIN, + CLISteps.PREDICT, + CLISteps.TRAIN_TRANSFER, + CLISteps.PREDICT_TRANSFER, + CLISteps.VERSION, ] for process in available_processes: - subparser = subparsers.add_parser(process) + subparser = subparsers.add_parser( + process, formatter_class=parser.formatter_class + ) - if process == "version": + if process == CLISteps.VERSION: continue subparser.add_argument( @@ -1334,18 +856,32 @@ def main(): dest="project_path", help="The project path (the directory that contains the grid ids)", ) - if process == "graph": - break process_dict = args_config[process.replace("-", "_")] - if process in ("skfoldcv", "maskrcnn"): + # Processes that use train args in addition to 'train' + if process in (CLISteps.SKFOLDCV, CLISteps.TRAIN_TRANSFER): process_dict.update(args_config["train"]) - if process in ("train", "maskrcnn", "predict", "skfoldcv"): + # Processes that use the predict args in addition to 'predict' + if process in (CLISteps.PREDICT_TRANSFER,): + process_dict.update(args_config["predict"]) + # Processes that use args shared between train and predict + if process in ( + CLISteps.TRAIN, + CLISteps.TRAIN_TRANSFER, + CLISteps.PREDICT, + CLISteps.PREDICT_TRANSFER, + CLISteps.SKFOLDCV, + ): process_dict.update(args_config["train_predict"]) process_dict.update(args_config["shared_partitions"]) - if process in ("create", "create-predict"): + if process in (CLISteps.CREATE, CLISteps.CREATE_PREDICT): process_dict.update(args_config["shared_create"]) - if process in ("create", "create-predict", "predict"): + if process in ( + CLISteps.CREATE, + CLISteps.CREATE_PREDICT, + CLISteps.PREDICT, + CLISteps.PREDICT_TRANSFER, + ): process_dict.update(args_config["shared_image"]) process_dict.update(args_config["dates"]) for process_key, process_values in process_dict.items(): @@ -1369,48 +905,79 @@ def main(): **process_values["kwargs"], ) - if process in ("create", "create-predict", "predict"): - subparser.add_argument( - "--config-file", - dest="config_file", - help="The configuration YAML file (default: %(default)s)", - default=(Path(__file__).parent / "config.yml").absolute(), - ) + # if process in ( + # CLISteps.CREATE, + # CLISteps.CREATE_PREDICT, + # CLISteps.PREDICT, + # CLISteps.PREDICT_TRANSFER, + # ): + subparser.add_argument( + "--config-file", + dest="config_file", + help="The configuration YAML file (default: %(default)s)", + default=(Path(__file__).parent / "config.yml").absolute(), + ) args = parser.parse_args() - if args.process == "create-predict": + + if hasattr(args, "config_file") and (args.config_file is not None): + args.config_file = str(args.config_file) + + if args.process == CLISteps.CREATE_PREDICT: setattr(args, "destination", "predict") - if args.process == "version": + if args.process == CLISteps.VERSION: print(cultionet.__version__) return if hasattr(args, "replace_dict"): if args.replace_dict is not None: - setattr(args, "replace_dict", ast.literal_eval(args.replace_dict)) + replace_dict = dict( + list( + map( + lambda x: list(map(int, x.split(":"))), + args.replace_dict.split(" "), + ) + ) + ) + setattr(args, "replace_dict", replace_dict) + + # config = open_config(args.config_file) + # for k, v in config["train"].get("trainer").items(): + # setattr(args, k, v) + # for k, v in config["train"].get("model").items(): + # setattr(args, k, v) project_path = Path(args.project_path) / "ckpt" project_path.mkdir(parents=True, exist_ok=True) + command_path = Path(args.project_path) / "commands" + command_path.mkdir(parents=True, exist_ok=True) now = datetime.now() + with open( - project_path + command_path / f"{args.process}_command_{now.strftime('%Y%m%d-%H%M')}.json", mode="w", ) as f: f.write(json.dumps(vars(args), indent=4)) - if args.process in ("create", "create-predict"): - create_datasets(args) - elif args.process == "skfoldcv": + if args.process in ( + CLISteps.CREATE, + CLISteps.CREATE_PREDICT, + ): + create_dataset(args) + elif args.process == CLISteps.SKFOLDCV: spatial_kfoldcv(args) - elif args.process == "train": + elif args.process in ( + CLISteps.TRAIN, + CLISteps.TRAIN_TRANSFER, + ): train_model(args) - elif args.process == "maskrcnn": - train_maskrcnn(args) - elif args.process == "predict": + elif args.process in ( + CLISteps.PREDICT, + CLISteps.PREDICT_TRANSFER, + ): predict_image(args) - elif args.process == "graph": - generate_model_graph(args) if __name__ == "__main__": diff --git a/src/cultionet/utils/geometry.py b/src/cultionet/utils/geometry.py deleted file mode 100644 index a49b1538..00000000 --- a/src/cultionet/utils/geometry.py +++ /dev/null @@ -1,29 +0,0 @@ -import typing as T -from pathlib import Path - -import geopandas as gpd -import rasterio as rio -from shapely.geometry import Polygon - - -def bounds_to_frame( - left: float, bottom: float, right: float, top: float, crs: T.Optional[str] = 'epsg:4326' -) -> gpd.GeoDataFrame: - """Converts a bounding box to a GeoDataFrame - """ - geom = Polygon([(left, bottom), (left, top), (right, top), (right, bottom), (left, bottom)]) - df = gpd.GeoDataFrame(data=[0], geometry=[geom], crs=crs) - - return df - - -def warp_by_image( - df: gpd.GeoDataFrame, image_path: T.Union[str, Path] -) -> T.Tuple[gpd.GeoDataFrame, str]: - """Warps a GeoDataFrame CRS by a reference image - """ - with rio.open(image_path) as src: - df = df.to_crs(src.crs.to_epsg()) - ref_crs = f'epsg:{df.crs.to_epsg()}' - - return df, ref_crs diff --git a/src/cultionet/utils/logging.py b/src/cultionet/utils/logging.py index 16579042..7042a383 100644 --- a/src/cultionet/utils/logging.py +++ b/src/cultionet/utils/logging.py @@ -1,10 +1,14 @@ import logging +from joblib import Parallel +from tqdm import tqdm + class ColorFormatter(logging.Formatter): """Reference: - https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output + https://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output """ + grey = "\x1b[38;20m" yellow = "\x1b[33;20m" red = "\x1b[31;20m" @@ -17,7 +21,7 @@ class ColorFormatter(logging.Formatter): logging.INFO: grey + format + reset, logging.WARNING: yellow + format + reset, logging.ERROR: red + format + reset, - logging.CRITICAL: bold_red + format + reset + logging.CRITICAL: bold_red + format + reset, } def format(self, record): @@ -29,10 +33,39 @@ def format(self, record): def set_color_logger(logger_name): logger = logging.getLogger(logger_name) - logger.setLevel(logging.DEBUG) - ch = logging.StreamHandler() - ch.setLevel(logging.DEBUG) - ch.setFormatter(ColorFormatter()) - logger.addHandler(ch) + logger.setLevel(logging.INFO) + formatter = ColorFormatter() + + file_handler = logging.FileHandler( + "cultionet.log", mode="w", encoding="utf-8" + ) + file_handler.setLevel(logging.WARNING) + file_handler.setFormatter(formatter) + logger.addHandler(file_handler) + + stream_handler = logging.StreamHandler() + stream_handler.setLevel(logging.INFO) + stream_handler.setFormatter(formatter) + logger.addHandler(stream_handler) return logger + + +class ParallelProgress(Parallel): + """ + Source: + https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib + """ + + def __init__(self, tqdm_params: dict, **kwargs): + self.tqdm_params = tqdm_params + + super().__init__(**kwargs) + + def __call__(self, *args, **kwargs): + with tqdm(**self.tqdm_params) as self._pbar: + return Parallel.__call__(self, *args, **kwargs) + + def print_progress(self): + self._pbar.n = self.n_completed_tasks + self._pbar.refresh() diff --git a/src/cultionet/utils/model_preprocessing.py b/src/cultionet/utils/model_preprocessing.py index b101d198..cfa78f30 100644 --- a/src/cultionet/utils/model_preprocessing.py +++ b/src/cultionet/utils/model_preprocessing.py @@ -1,20 +1,20 @@ import typing as T from pathlib import Path -from geowombat.core.util import sort_images_by_date - -import pandas as pd import attr -from tqdm.auto import tqdm +import pandas as pd +from geowombat.core.util import sort_images_by_date from joblib import Parallel +from tqdm.auto import tqdm -class TqdmParallel(Parallel): - """A tqdm progress bar for joblib Parallel tasks +class ParallelProgress(Parallel): + """A tqdm progress bar for joblib Parallel tasks. Reference: https://stackoverflow.com/questions/37804279/how-can-we-use-tqdm-in-a-parallel-execution-with-joblib """ + def __init__(self, tqdm_kwargs: dict, **joblib_kwargs): self.tqdm_kwargs = tqdm_kwargs super().__init__(**joblib_kwargs) @@ -30,7 +30,9 @@ def print_progress(self): @attr.s class VegetationIndices(object): - image_vis: T.List[str] = attr.ib(default=None, validator=attr.validators.instance_of(list)) + image_vis: T.List[str] = attr.ib( + default=None, validator=attr.validators.instance_of(list) + ) @property def n_vis(self): @@ -39,8 +41,12 @@ def n_vis(self): @attr.s class TrainInputs(object): - regions: T.List[str] = attr.ib(default=None, validator=attr.validators.instance_of(list)) - years: T.List[int] = attr.ib(default=None, validator=attr.validators.instance_of(list)) + regions: T.List[str] = attr.ib( + default=None, validator=attr.validators.instance_of(list) + ) + years: T.List[int] = attr.ib( + default=None, validator=attr.validators.instance_of(list) + ) def __attrs_post_init__(self): region_list = self.regions @@ -50,13 +56,15 @@ def __attrs_post_init__(self): def get_time_series_list( feature_path: Path, - start_year: int, - start_date: str, - end_date: str, - date_format: str = '%Y%j' + date_format: str = '%Y%j', + start_date: T.Optional[pd.Timestamp] = None, + end_date: T.Optional[pd.Timestamp] = None, + end_year: T.Optional[T.Union[int, str]] = None, + start_mmdd: T.Optional[str] = None, + end_mmdd: T.Optional[str] = None, + num_months: T.Optional[int] = None, ) -> T.List[str]: - """Gets a list of time series paths - """ + """Gets a list of time series paths.""" # Get the requested time slice image_dict = sort_images_by_date( feature_path, @@ -64,17 +72,34 @@ def get_time_series_list( date_pos=0, date_start=0, date_end=7 if date_format == '%Y%j' else 8, - date_format=date_format + date_format=date_format, ) + # Create a DataFrame with paths and dates df = pd.DataFrame( data=list(image_dict.keys()), columns=['name'], - index=list(image_dict.values()) + index=list(image_dict.values()), ) + + if (start_date is not None) and (end_date is not None): + start_date_stamp = start_date + end_date_stamp = end_date + else: + end_date_stamp = pd.Timestamp( + f"{end_year}-{end_mmdd}" + ) + pd.DateOffset(days=1) + start_year = (end_date_stamp - pd.DateOffset(months=num_months)).year + start_date_stamp = pd.Timestamp(f"{start_year}-{start_mmdd}") + + image_df = df.loc[start_date_stamp:end_date_stamp] + + if num_months is not None: + assert ( + num_months <= len(image_df.index) <= num_months + 1 + ), "The image list is not the correct length." + # Slice the requested time series from the dataFrame - ts_list = df.loc[ - f'{start_year}-{start_date}':f'{start_year+1}-{end_date}' - ].name.values.tolist() + ts_list = image_df.name.values.tolist() return ts_list diff --git a/src/cultionet/utils/normalize.py b/src/cultionet/utils/normalize.py index 00189f85..06a23a40 100644 --- a/src/cultionet/utils/normalize.py +++ b/src/cultionet/utils/normalize.py @@ -1,208 +1,213 @@ import typing as T -from dataclasses import dataclass -from functools import partial from pathlib import Path -from ..data.datasets import EdgeDataset -from ..data.modules import EdgeDataModule -from ..utils.model_preprocessing import TqdmParallel -from ..utils.stats import ( - tally_stats, - cache_load_enabled, - Quantile, - Variance -) - -from tqdm import tqdm +import joblib import torch -from joblib import delayed, parallel_backend - +from einops import rearrange +from rich.progress import ( + BarColumn, + Progress, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) +from rich.style import Style +from torch.utils.data import DataLoader, Dataset -@dataclass -class NormValues: - mean: torch.Tensor - std: torch.Tensor - max: torch.Tensor - crop_counts: torch.Tensor - edge_counts: torch.Tensor +from ..data import Data +from ..data.utils import collate_fn +from .stats import Quantile, Variance, cache_load_enabled, tally_stats def add_dim(d: torch.Tensor) -> torch.Tensor: return d.unsqueeze(0) -def inverse_transform(x: torch.Tensor, data_values: NormValues) -> torch.Tensor: - """Transforms the inverse of the z-scores""" - return data_values.std*x + data_values.mean - - -def get_norm_values( - dataset: T.Union[EdgeDataset, torch.utils.data.Dataset], - batch_size: int, - class_info: T.Dict[str, int], - num_workers: int = 0, - processes: int = 1, - threads_per_worker: int = 1, - centering: str = 'mean', - mean_color: str = '#ffffff', - sse_color: str = '#ffffff' -) -> NormValues: - """Normalizes a dataset to z-scores - """ - if not isinstance(dataset, EdgeDataset): - data_loader = torch.utils.data.DataLoader( - dataset, - batch_size=batch_size, - shuffle=True, - num_workers=0 +class NormValues: + """Normalization values.""" + + def __init__( + self, + dataset_mean: torch.Tensor, + dataset_std: torch.Tensor, + dataset_crop_counts: torch.Tensor, + dataset_edge_counts: torch.Tensor, + num_channels: int, + lower_bound: T.Optional[torch.Tensor] = None, + upper_bound: T.Optional[torch.Tensor] = None, + ): + self.dataset_mean = dataset_mean + self.dataset_std = dataset_std + self.dataset_crop_counts = dataset_crop_counts + self.dataset_edge_counts = dataset_edge_counts + self.num_channels = num_channels + self.lower_bound = lower_bound + self.upper_bound = upper_bound + + def __repr__(self): + return ( + "NormValues(" + f" dataset_mean={self.dataset_mean}," + f" dataset_std={self.dataset_std}," + f" dataset_crop_counts={self.dataset_crop_counts}," + f" dataset_edge_counts={self.dataset_edge_counts}," + f" num_channels={self.num_channels}," + f" lower_bound={self.lower_bound}," + f" upper_bound={self.upper_bound}," + ")" + ) + + def __call__(self, batch: Data) -> Data: + return self.transform(batch) + + def transform(self, batch: Data) -> Data: + r"""Normalizes data by the Dynamic World log method or by z-scores. + + Parameters + ========== + batch + A tensor data object. + data_means + The data feature-wise means. + data_stds + The data feature-wise standard deviations. + + z = (x - μ) / σ + """ + batch_copy = batch.copy() + batch_copy.x = ( + batch_copy.x - self.dataset_mean.to(device=batch_copy.x.device) + ) / self.dataset_std.to(device=batch_copy.x.device) + + return batch_copy + + def inverse_transform(self, batch: Data) -> Data: + """Transforms the inverse of the z-scores.""" + batch_copy = batch.copy() + batch_copy.x = self.dataset_std.to( + device=batch_copy.x.device + ) * batch_copy.x + self.dataset_mean.to(device=batch_copy.x.device) + + return batch_copy + + @property + def data_dict(self) -> dict: + return { + 'dataset_mean': self.dataset_mean, + 'dataset_std': self.dataset_std, + 'dataset_crop_counts': self.dataset_crop_counts, + 'dataset_edge_counts': self.dataset_edge_counts, + 'num_channels': self.num_channels, + 'lower_bound': self.lower_bound, + 'upper_bound': self.upper_bound, + } + + def to_file( + self, filename: T.Union[Path, str], compress: str = 'zlib' + ) -> None: + joblib.dump( + self.data_dict, + filename, + compress=compress, ) - data_maxs = torch.zeros(3, dtype=torch.float) - data_sums = torch.zeros(3, dtype=torch.float) - sse = torch.zeros(3, dtype=torch.float) - pix_count = 0.0 - with tqdm( - total=int(len(dataset)/batch_size), - desc='Calculating means', - colour=mean_color - ) as pbar: - for x, y in data_loader: - channel_maxs = torch.tensor([x[0, c, ...].max() for c in range(0, x.shape[1])], dtype=torch.float) - data_maxs = torch.where(channel_maxs > data_maxs, channel_maxs, data_maxs) - # Sum over all data - data_sums += x.sum(dim=(0, 2, 3)) - pix_count += (x.shape[2] * x.shape[3]) - - pbar.update(1) - - data_means = data_sums / float(pix_count) - with tqdm( - total=int(len(dataset)/batch_size), - desc='Calculating SSEs', - colour=sse_color - ) as pbar: - for x, y in data_loader: - sse += ((x - data_means.unsqueeze(0)[..., None, None]).pow(2)).sum(dim=(0, 2, 3)) - - pbar.update(1) - - data_stds = torch.sqrt(sse / pix_count) - - else: - data_module = EdgeDataModule( - train_ds=dataset, + @classmethod + def from_file(cls, filename: T.Union[Path, str]) -> "NormValues": + return cls(**joblib.load(filename)) + + @classmethod + def from_dataset( + cls, + dataset: Dataset, + batch_size: int, + class_info: T.Dict[str, int], + num_workers: int = 0, + centering: str = 'median', + lower_quantile: float = 0.05, + upper_quantile: float = 0.95, + ) -> "NormValues": + """Calculates dataset statistics.""" + + lower_bound = None + upper_bound = None + + data_loader = DataLoader( + dataset, batch_size=batch_size, - num_workers=num_workers + num_workers=num_workers, + shuffle=False, + collate_fn=collate_fn, ) - if centering == 'median': - stat_var = Variance(method='median') - stat_q = Quantile(r=1024*6) - tmp_cache_path = Path.home().absolute() / '.cultionet' - tmp_cache_path.mkdir(parents=True, exist_ok=True) - var_data_cache = tmp_cache_path / '_var.npz' - q_data_cache = tmp_cache_path / '_q.npz' - crop_counts = torch.zeros(class_info['max_crop_class']+1).long() - edge_counts = torch.zeros(2).long() - with cache_load_enabled(True): - with tqdm( - total=int(len(dataset) / batch_size), - desc='Calculating dataset statistics' - ) as pbar: - for batch in tally_stats( + stat_var = Variance(method='median') + stat_q = Quantile(r=1024 * 6) + tmp_cache_path = Path.home().absolute() / '.cultionet' + tmp_cache_path.mkdir(parents=True, exist_ok=True) + var_data_cache = tmp_cache_path / '_var.npz' + q_data_cache = tmp_cache_path / '_q.npz' + crop_counts = torch.zeros(class_info['max_crop_class'] + 1).long() + edge_counts = torch.zeros(2).long() + with cache_load_enabled(True): + with Progress( + TextColumn("Calculating stats", style=Style(color="#cacaca")), + TextColumn("•", style=Style(color="#cacaca")), + BarColumn( + style="#ACFCD6", + complete_style="#AA9439", + finished_style="#ACFCD6", + pulse_style="#FCADED", + ), + TaskProgressColumn(), + TextColumn("•", style=Style(color="#cacaca")), + TimeElapsedColumn(), + ) as pbar: + for batch in pbar.track( + tally_stats( stats=(stat_var, stat_q), - loader=data_module.train_dataloader(), - caches=(var_data_cache, q_data_cache) - ): - stat_var.add(batch.x) - stat_q.add(batch.x) - - crop_counts[0] += ((batch.y == 0) | (batch.y == class_info['edge_class'])).sum() - for i in range(1, class_info['edge_class']): - crop_counts[i] += (batch.y == i).sum() - edge_counts[0] += (batch.y != class_info['edge_class']).sum() - edge_counts[1] += (batch.y == class_info['edge_class']).sum() - - pbar.update(1) - - data_stds = stat_var.std() - data_means = stat_q.median() - - var_data_cache.unlink() - q_data_cache.unlink() - tmp_cache_path.rmdir() + loader=data_loader, + caches=(var_data_cache, q_data_cache), + ), + total=len(data_loader), + ): + # Stack samples + x = rearrange(batch.x, 'b c t h w -> (b t h w) c') + + # Update the stats + stat_var.add(x) + stat_q.add(x) + + # Update counts + crop_counts[0] += ( + (batch.y == 0) | (batch.y == class_info['edge_class']) + ).sum() + for i in range(1, class_info['edge_class']): + crop_counts[i] += (batch.y == i).sum() + + edge_counts[0] += ( + (batch.y >= 0) & (batch.y != class_info['edge_class']) + ).sum() + edge_counts[1] += ( + batch.y == class_info['edge_class'] + ).sum() + + data_stds = stat_var.std() + if centering == 'mean': + data_means = stat_q.mean() else: - def get_info( - x: torch.Tensor, y: torch.Tensor - ) -> T.Tuple[torch.Tensor, int, torch.Tensor, torch.Tensor]: - crop_counts = torch.zeros(class_info['max_crop_class']+1) - edge_counts = torch.zeros(2) - crop_counts[0] = ((y == 0) | (y == class_info['edge_class'])).sum() - for i in range(1, class_info['edge_class']): - crop_counts[i] = (y == i).sum() - edge_counts[0] = (y != class_info['edge_class']).sum() - edge_counts[1] = (y == class_info['edge_class']).sum() - - return x.sum(dim=0), x.shape[0], crop_counts, edge_counts - - with parallel_backend( - backend='loky', - n_jobs=processes, - inner_max_num_threads=threads_per_worker - ): - with TqdmParallel( - tqdm_kwargs={ - 'total': int(len(dataset) / batch_size), - 'desc': 'Calculating means', - 'colour': mean_color - } - ) as pool: - results = pool( - delayed(get_info)( - batch.x, batch.y - ) for batch in data_module.train_dataloader() - ) - data_sums, pix_count, crop_counts, edge_counts = list(map(list, zip(*results))) - - data_sums = torch.stack(data_sums).sum(dim=0) - pix_count = torch.tensor(pix_count).sum() - crop_counts = torch.stack(crop_counts).sum(dim=0) - edge_counts = torch.stack(edge_counts).sum(dim=0) - data_means = data_sums / float(pix_count) - - def get_sse(x_mu: torch.Tensor, x: torch.Tensor) -> torch.Tensor: - return ((x - x_mu).pow(2)).sum(dim=0) - - sse_partial = partial(get_sse, add_dim(data_means)) - - with parallel_backend( - backend='loky', - n_jobs=processes, - inner_max_num_threads=threads_per_worker - ): - with TqdmParallel( - tqdm_kwargs={ - 'total': int(len(dataset) / batch_size), - 'desc': 'Calculating SSEs', - 'colour': sse_color - } - ) as pool: - sses = pool( - delayed(sse_partial)( - batch.x - ) for batch in data_module.train_dataloader() - ) - - sses = torch.stack(sses).sum(dim=0) - data_stds = torch.sqrt(sses / float(pix_count)) - data_maxs = torch.zeros_like(data_means) - - norm_values = NormValues( - mean=data_means, - std=data_stds, - max=data_maxs, - crop_counts=crop_counts, - edge_counts=edge_counts - ) - - return norm_values + data_means = stat_q.median() + lower_bound = stat_q.quantiles(lower_quantile) + upper_bound = stat_q.quantiles(upper_quantile) + + var_data_cache.unlink() + q_data_cache.unlink() + tmp_cache_path.rmdir() + + return cls( + dataset_mean=rearrange(data_means, 'c -> 1 c 1 1 1'), + dataset_std=rearrange(data_stds, 'c -> 1 c 1 1 1'), + lower_bound=rearrange(lower_bound, 'c -> 1 c 1 1 1'), + upper_bound=rearrange(upper_bound, 'c -> 1 c 1 1 1'), + dataset_crop_counts=crop_counts, + dataset_edge_counts=edge_counts, + num_channels=len(data_means), + ) diff --git a/src/cultionet/utils/project_paths.py b/src/cultionet/utils/project_paths.py index 15a9e980..047f6788 100644 --- a/src/cultionet/utils/project_paths.py +++ b/src/cultionet/utils/project_paths.py @@ -2,12 +2,8 @@ from dataclasses import dataclass import shutil import typing as T -import enum - -class Destinations(enum.Enum): - train = 'train' - test = 'test' +from ..enums import Destinations, ModelNames @dataclass @@ -31,6 +27,14 @@ class ProjectPaths: loss_file: Path norm_file: Path + @property + def grid_format(self) -> str: + return "{region}_grid_{end_year}.gpkg" + + @property + def polygon_format(self) -> str: + return "{region}_poly_{end_year}.gpkg" + def remove_train_path(self): if self.process_path.is_dir(): for fn in self.process_path.glob('*.pt'): @@ -45,26 +49,30 @@ def get_process_path(self, destination: str) -> Path: def setup_paths( project_path: T.Union[str, Path, bytes], append_ts: T.Optional[bool] = True, - ckpt_name: T.Optional[str] = 'last.ckpt' + ckpt_name: T.Optional[str] = ModelNames.CKPT_NAME, ) -> ProjectPaths: project_path = Path(project_path) - image_path = project_path / 'time_series_vars' if append_ts else project_path + image_path = ( + project_path / Destinations.TIME_SERIES_VARS + if append_ts + else project_path + ) composite_path = project_path.parent / 'composites' proba_path = project_path.parent / 'composites_probas' - figure_path = project_path / 'figures' - data_path = project_path / 'data' - ckpt_path = project_path / 'ckpt' - classes_info_path = data_path / 'classes.info' - train_path = data_path / 'train' - test_path = data_path / 'test' - predict_path = data_path / 'predict' - process_path = train_path / 'processed' - test_process_path = test_path / 'processed' - predict_process_path = predict_path / 'processed' - edge_training_path = project_path / 'user_train' + figure_path = project_path / Destinations.FIGURES + data_path = project_path / Destinations.DATA + ckpt_path = project_path / Destinations.CKPT + classes_info_path = data_path / ModelNames.CLASS_INFO + train_path = data_path / Destinations.TRAIN + test_path = data_path / Destinations.TEST + predict_path = data_path / Destinations.PREDICT + process_path = train_path / Destinations.PROCESSED + test_process_path = test_path / Destinations.PROCESSED + predict_process_path = predict_path / Destinations.PROCESSED + edge_training_path = project_path / Destinations.USER_TRAIN ckpt_file = ckpt_path / ckpt_name loss_file = ckpt_path / 'losses.npy' - norm_file = ckpt_path / 'last.norm' + norm_file = ckpt_path / ModelNames.NORM for p in [ proba_path, @@ -73,7 +81,7 @@ def setup_paths( process_path, test_process_path, predict_process_path, - ckpt_path + ckpt_path, ]: p.mkdir(exist_ok=True, parents=True) @@ -95,5 +103,5 @@ def setup_paths( edge_training_path=edge_training_path, ckpt_file=ckpt_file, loss_file=loss_file, - norm_file=norm_file + norm_file=norm_file, ) diff --git a/src/cultionet/utils/reshape.py b/src/cultionet/utils/reshape.py index 6c944d42..30b7b632 100644 --- a/src/cultionet/utils/reshape.py +++ b/src/cultionet/utils/reshape.py @@ -1,83 +1,57 @@ import typing as T -import numpy as np -from rasterio.windows import Window import attr +import numpy as np import torch import torch.nn.functional as F - - -def nd_to_columns(data, layers, rows, columns): - """Reshapes an array from nd layout to [samples (rows*columns) x dimensions] - """ - if layers == 1: - return np.ascontiguousarray(data.flatten()[:, np.newaxis]) - else: - return np.ascontiguousarray(data.transpose(1, 2, 0).reshape(rows*columns, layers)) - - -def columns_to_nd(data, layers, rows, columns): - """Reshapes an array from columns layout to [layers x rows x columns] - """ - if layers == 1: - return np.ascontiguousarray(data.reshape(columns, rows).T) - else: - return np.ascontiguousarray(data.T.reshape(layers, rows, columns)) +from rasterio.windows import Window @attr.s class ModelOutputs(object): - """A class for reshaping of the model output estimates - """ - distance: torch.Tensor = attr.ib(validator=attr.validators.instance_of(torch.Tensor)) - edge: torch.Tensor = attr.ib(validator=attr.validators.instance_of(torch.Tensor)) - crop: torch.Tensor = attr.ib(validator=attr.validators.instance_of(torch.Tensor)) + """A class for reshaping of the model output estimates.""" + + distance: torch.Tensor = attr.ib( + validator=attr.validators.instance_of(torch.Tensor) + ) + edge: torch.Tensor = attr.ib( + validator=attr.validators.instance_of(torch.Tensor) + ) + crop: torch.Tensor = attr.ib( + validator=attr.validators.instance_of(torch.Tensor) + ) crop_type: T.Union[torch.Tensor, None] = attr.ib( - validator=attr.validators.optional(attr.validators.instance_of(torch.Tensor)) + validator=attr.validators.optional( + attr.validators.instance_of(torch.Tensor) + ) ) instances: T.Optional[T.Union[None, np.ndarray]] = attr.ib( default=None, - validator=attr.validators.optional(attr.validators.instance_of(np.ndarray)) + validator=attr.validators.optional( + attr.validators.instance_of(np.ndarray) + ), ) apply_softmax: T.Optional[bool] = attr.ib( - default=False, - validator=attr.validators.instance_of(bool) + default=False, validator=attr.validators.instance_of(bool) ) - def stack_outputs(self, w: Window, w_pad: Window) -> np.ndarray: - self.reshape(w, w_pad) - self.nan_to_num() - if (self.crop_type_probas is not None) and len(self.crop_type_probas.shape) == 3: - stack_items = ( - self.edge_dist[None], - self.edge_probas[None], - self.crop_probas[None] - ) - if self.crop_type_probas is not None: - stack_items += (self.crop_type_probas,) - if self.instances is not None: - stack_items += (self.instances[None],) - - return np.vstack(stack_items) - else: - stack_items = ( - self.edge_dist, - self.edge_probas, - self.crop_probas - ) - if self.crop_type_probas is not None: - stack_items += (self.crop_type_probas,) - if self.instances is not None: - stack_items += (self.instances,) - - return np.stack(stack_items) + def stack_outputs(self) -> np.ndarray: + return ( + torch.cat((self.distance, self.edge, self.crop), dim=0) + .detach() + .cpu() + .numpy() + ) @staticmethod - def _clip_and_reshape(tarray: torch.Tensor, window_obj: Window) -> np.ndarray: - if (len(tarray.shape) == 1) or ((len(tarray.shape) > 1) and (tarray.shape[1] == 1)): + def _clip_and_reshape( + tarray: torch.Tensor, window_obj: Window + ) -> np.ndarray: + if (len(tarray.shape) == 1) or ( + (len(tarray.shape) > 1) and (tarray.shape[1] == 1) + ): return ( - tarray - .contiguous() + tarray.contiguous() .view(-1) .detach() .cpu() @@ -89,8 +63,7 @@ def _clip_and_reshape(tarray: torch.Tensor, window_obj: Window) -> np.ndarray: n_layers = tarray.shape[1] return ( - tarray - .contiguous() + tarray.contiguous() .t() .detach() .cpu() @@ -99,11 +72,11 @@ def _clip_and_reshape(tarray: torch.Tensor, window_obj: Window) -> np.ndarray: .reshape(n_layers, window_obj.height, window_obj.width) ) - def inputs_to_probas(self, inputs: np.ndarray, w_pad: Window) -> np.ndarray: + def inputs_to_probas( + self, inputs: np.ndarray, w_pad: Window + ) -> np.ndarray: if self.apply_softmax: - inputs = F.softmax( - inputs, dim=1, dtype=inputs.dtype - )[:, 1] + inputs = F.softmax(inputs, dim=1, dtype=inputs.dtype)[:, 1] else: if len(inputs.shape) > 1: if inputs.shape[1] > 1: @@ -125,19 +98,18 @@ def reshape(self, w: Window, w_pad: Window) -> None: # Get the crop-type probabilities self.crop_type_probas = None if self.crop_type is not None: - self.crop_type_probas = self.inputs_to_probas(self.crop_type, w_pad) + self.crop_type_probas = self.inputs_to_probas( + self.crop_type, w_pad + ) # Reshape the window chunk and slice off padding i = abs(w.row_off - w_pad.row_off) j = abs(w.col_off - w_pad.col_off) - slicer = ( - slice(i, i+w.height), - slice(j, j+w.width) - ) + slicer = (slice(i, i + w.height), slice(j, j + w.width)) slicer3d = ( slice(0, None), - slice(i, i+w.height), - slice(j, j+w.width) + slice(i, i + w.height), + slice(j, j + w.width), ) self.edge_dist = self.edge_dist[slicer] self.edge_probas = self.edge_probas[slicer] @@ -156,31 +128,19 @@ def reshape(self, w: Window, w_pad: Window) -> None: def nan_to_num(self): # Convert the data type to integer and set 'no data' values - self.edge_dist =np.nan_to_num( - self.edge_dist, - nan=-1.0, - neginf=-1.0, - posinf=-1.0 + self.edge_dist = np.nan_to_num( + self.edge_dist, nan=-1.0, neginf=-1.0, posinf=-1.0 ).astype('float32') self.edge_probas = np.nan_to_num( - self.edge_probas, - nan=-1.0, - neginf=-1.0, - posinf=-1.0 + self.edge_probas, nan=-1.0, neginf=-1.0, posinf=-1.0 ).astype('float32') self.crop_probas = np.nan_to_num( - self.crop_probas, - nan=-1.0, - neginf=-1.0, - posinf=-1.0 + self.crop_probas, nan=-1.0, neginf=-1.0, posinf=-1.0 ).astype('float32') if self.crop_type_probas is not None: self.crop_type_probas = np.nan_to_num( - self.crop_type_probas, - nan=-1.0, - neginf=-1.0, - posinf=-1.0 + self.crop_type_probas, nan=-1.0, neginf=-1.0, posinf=-1.0 ).astype('float32') diff --git a/src/cultionet/utils/stats.py b/src/cultionet/utils/stats.py index 3b9bd5e6..b6cd3ef0 100644 --- a/src/cultionet/utils/stats.py +++ b/src/cultionet/utils/stats.py @@ -2,28 +2,29 @@ Source: https://gist.github.com/davidbau/00a9b6763a260be8274f6ba22df9a145 """ -import os import math +import os import struct import typing as T from pathlib import Path import numpy as np import torch -from torch_geometric.data import DataLoader - +from torch.utils.data import DataLoader null_numpy_value = np.array( - struct.unpack('>d', struct.pack('>Q', 0xfff8000000000002))[0], - dtype=np.float64 + struct.unpack('>d', struct.pack('>Q', 0xFFF8000000000002))[0], + dtype=np.float64, ) def is_null_numpy_value(v) -> bool: return ( - isinstance(v, np.ndarray) and np.ndim(v) == 0 - and v.dtype == np.float64 and np.isnan(v) - and 0xfff8000000000002 == struct.unpack('>Q', struct.pack('>d', v))[0] + isinstance(v, np.ndarray) + and np.ndim(v) == 0 + and v.dtype == np.float64 + and np.isnan(v) + and 0xFFF8000000000002 == struct.unpack('>Q', struct.pack('>d', v))[0] ) @@ -42,8 +43,7 @@ def unbox_numpy_null(d): def resolve_state_dict(s): - """Resolves a state, which can be a filename or a dict-like object. - """ + """Resolves a state, which can be a filename or a dict-like object.""" if isinstance(s, str): return unbox_numpy_null(np.load(s)) return s @@ -55,7 +55,7 @@ def save_cached_state(cachefile, obj, args): dat = obj.state_dict() for a, v in args.items(): if a in dat: - assert (dat[a] == v) + assert dat[a] == v dat[a] = v if isinstance(cachefile, dict): cachefile.clear() @@ -66,14 +66,15 @@ def save_cached_state(cachefile, obj, args): global_load_cache_enabled = True + + def load_cached_state( cachefile: T.Union[Path, str], args: T.Optional[dict] = None, quiet: T.Optional[bool] = False, - throw: T.Optional[bool] = False + throw: T.Optional[bool] = False, ): - """Resolves a state, which can be a filename or a dict-like object. - """ + """Resolves a state, which can be a filename or a dict-like object.""" if args is None: args = {} if not global_load_cache_enabled or cachefile is None: @@ -81,7 +82,7 @@ def load_cached_state( try: if isinstance(cachefile, dict): dat = cachefile - cachefile = 'state' # for printed messages + cachefile = 'state' # for printed messages else: dat = unbox_numpy_null(np.load(cachefile)) for a, v in args.items(): @@ -101,64 +102,55 @@ def load_cached_state( class Stat(object): - """Abstract base class for a running pytorch statistic. - """ + """Abstract base class for a running pytorch statistic.""" + def __init__(self, state): """By convention, all Stat subclasses can be initialized by passing - state=; and then they will initialize by calling load_state_dict. - """ + state=; and then they will initialize by calling load_state_dict.""" self.load_state_dict(resolve_state_dict(state)) def add(self, x, *args, **kwargs): """Observes a batch of samples to be incorporated into the statistic. - Dimension 0 should be the batch dimension, and dimension 1 should - be the feature dimension of the pytorch tensor x. + + Dimension 0 should be the batch dimension, and dimension 1 should be + the feature dimension of the pytorch tensor x. """ pass def load_state_dict(self, d): - """Loads this Stat from a dictionary of numpy arrays as saved - by state_dict. - """ + """Loads this Stat from a dictionary of numpy arrays as saved by + state_dict.""" pass def state_dict(self): - """Saves this Stat as a dictionary of numpy arrays that can be - stored in an npz or reloaded later using load_state_dict. - """ + """Saves this Stat as a dictionary of numpy arrays that can be stored + in an npz or reloaded later using load_state_dict.""" return {} def save(self, filename): - """Saves this stat as an npz file containing the state_dict. - """ + """Saves this stat as an npz file containing the state_dict.""" save_cached_state(filename, self, {}) def load(self, filename): - """ - Loads this stat from an npz file containing a saved state_dict. - """ + """Loads this stat from an npz file containing a saved state_dict.""" self.load_state_dict( load_cached_state(filename, {}, quiet=True, throw=True) ) def to_(self, device): - """Moves this Stat to the given device. - """ + """Moves this Stat to the given device.""" pass def cpu_(self): - """Moves this Stat to the cpu device. - """ + """Moves this Stat to the cpu device.""" self.to_('cpu') def cuda_(self): - """Moves this Stat to the default cuda device. - """ + """Moves this Stat to the default cuda device.""" self.to_('cuda') def _normalize_add_shape(self, x, attr='data_shape'): - """Flattens input data to 2d. - """ + """Flattens input data to 2d.""" if not torch.is_tensor(x): x = torch.tensor(x) if len(x.shape) < 1: @@ -173,8 +165,7 @@ def _normalize_add_shape(self, x, attr='data_shape'): return x.view(x.shape[0], int(np.prod(data_shape))) def _restore_result_shape(self, x, attr='data_shape'): - """Restores output data to input data shape. - """ + """Restores output data to input data shape.""" data_shape = getattr(self, attr, None) if data_shape is None: return x @@ -183,8 +174,8 @@ def _restore_result_shape(self, x, attr='data_shape'): class Mean(Stat): - """Running mean - """ + """Running mean.""" + def __init__(self, state=None): if state is not None: return super().__init__(state) @@ -228,7 +219,9 @@ def load_state_dict(self, state): self.count = state['count'] self.batchcount = state['batchcount'] self._mean = torch.from_numpy(state['mean']) - self.data_shape = None if state['data_shape'] is None else tuple(state['data_shape']) + self.data_shape = ( + None if state['data_shape'] is None else tuple(state['data_shape']) + ) def state_dict(self): return dict( @@ -236,7 +229,7 @@ def state_dict(self): count=self.count, data_shape=self.data_shape and tuple(self.data_shape), batchcount=self.batchcount, - mean=self._mean.cpu().numpy() + mean=self._mean.cpu().numpy(), ) @@ -256,7 +249,10 @@ class Quantile(Stat): Based on the optimal KLL quantile algorithm by Karnin, Lang, and Liberty from FOCS 2016. http://ieee-focs.org/FOCS-2016-Papers/3933a071.pdf """ - def __init__(self, r: int = 3072, buffersize: int = None, seed=None, state=None): + + def __init__( + self, r: int = 3072, buffersize: int = None, seed=None, state=None + ): if state is not None: return super().__init__(state) @@ -290,7 +286,7 @@ def _lazy_init(self, incoming): self.depth, self.resolution, dtype=self.dtype, - device=self.device + device=self.device, ) ] self.extremes = torch.zeros( @@ -300,8 +296,7 @@ def _lazy_init(self, incoming): self.extremes[:, -1] = -float('inf') def to_(self, device): - """Switches internal storage to specified device. - """ + """Switches internal storage to specified device.""" if device != self.device: old_data = self.data old_extremes = self.extremes @@ -326,7 +321,7 @@ def add(self, incoming): self._scan_extremes(incoming) chunksize = int(math.ceil(self.buffersize / self.samplerate)) for index in range(0, len(incoming), chunksize): - batch = incoming[index:index + chunksize] + batch = incoming[index : index + chunksize] sample = sample_portion(batch, self.samplerate) if len(sample): self._add_every(sample) @@ -350,8 +345,8 @@ def _add_every(self, incoming): ff = self.firstfree[0] available = self.data[0].shape[1] - ff copycount = min(available, supplied - index) - self.data[0][:, ff:ff + copycount] = torch.t( - incoming[index:index + copycount, :] + self.data[0][:, ff : ff + copycount] = torch.t( + incoming[index : index + copycount, :] ) self.firstfree[0] += copycount index += copycount @@ -373,7 +368,9 @@ def _shift(self): offset = self._randbit() position = self.firstfree[index + 1] subset = data[:, offset::2] - self.data[index + 1][:, position : position + subset.shape[1]] = subset + self.data[index + 1][ + :, position : position + subset.shape[1] + ] = subset self.firstfree[index] = 0 self.firstfree[index + 1] += subset.shape[1] index += 1 @@ -460,9 +457,9 @@ def mean(self): def var(self, unbiased=True): mean = self.mean()[:, None] - return self.integrate( - lambda x: (x - mean).pow(2) - ) / (self.count - (1 if unbiased else 0)) + return self.integrate(lambda x: (x - mean).pow(2)) / ( + self.count - (1 if unbiased else 0) + ) def std(self, unbiased=True): return self.var(unbiased=unbiased).sqrt() @@ -472,7 +469,10 @@ def _expand(self): if cap > 0: # First, make a new layer of the proper capacity. self.data.insert( - 0, torch.zeros(self.depth, cap, dtype=self.dtype, device=self.device) + 0, + torch.zeros( + self.depth, cap, dtype=self.dtype, device=self.device + ), ) self.firstfree.insert(0, 0) else: @@ -491,9 +491,9 @@ def _expand(self): if self.data[index - 1].shape[1] - (amount + position) >= ( -(-self.data[index - 2].shape[1] // 2) if (index - 1) else 1 ): - self.data[index - 1][:, position : position + amount] = self.data[ - index - ][:, :amount] + self.data[index - 1][ + :, position : position + amount + ] = self.data[index][:, :amount] self.firstfree[index - 1] += amount self.firstfree[index] = 0 else: @@ -521,13 +521,15 @@ def _weighted_summary(self, sort=True): self._scan_extremes(self.data[0][:, : self.firstfree[0]].t()) size = sum(self.firstfree) weights = torch.FloatTensor(size) # Floating point - summary = torch.zeros(self.depth, size, dtype=self.dtype, device=self.device) + summary = torch.zeros( + self.depth, size, dtype=self.dtype, device=self.device + ) index = 0 for level, ff in enumerate(self.firstfree): if ff == 0: continue summary[:, index : index + ff] = self.data[level][:, :ff] - weights[index : index + ff] = 2.0 ** level + weights[index : index + ff] = 2.0**level index += ff assert index == summary.shape[1] if sort: @@ -565,7 +567,9 @@ def quantiles(self, quantiles): nsm = summary.cpu().detach().numpy() for d in range(self.depth): result[d] = torch.tensor( - np.interp(nq, ncw[d], nsm[d]), dtype=self.dtype, device=self.device + np.interp(nq, ncw[d], nsm[d]), + dtype=self.dtype, + device=self.device, ) return result.view((self.depth,) + qshape) @@ -575,7 +579,11 @@ def integrate(self, fun): for level, ff in enumerate(self.firstfree): if ff == 0: continue - result.append(torch.sum(fun(self.data[level][:, :ff]) * (2.0 ** level), dim=-1)) + result.append( + torch.sum( + fun(self.data[level][:, :ff]) * (2.0**level), dim=-1 + ) + ) if len(result) == 0: return None @@ -585,9 +593,11 @@ def readout(self, count=1001): return self.quantiles(torch.linspace(0.0, 1.0, count)) def normalize(self, data): - """Given input data as taken from the training distirbution, - normalizes every channel to reflect quantile values, - uniformly distributed, within [0, 1]. + """Given input data as taken from the training distirbution, normalizes + every channel to reflect quantile values, uniformly distributed, + within. + + [0, 1]. """ assert self.count > 0 assert data.shape[0] == self.depth @@ -613,9 +623,11 @@ def normalize(self, data): class Variance(Stat): - """Running computation of mean|median and variance. Use this when you just need - basic stats without covariance. + """Running computation of mean|median and variance. + + Use this when you just need basic stats without covariance. """ + def __init__(self, method: str = 'mean', state=None): if state is not None: return super().__init__(state) @@ -681,7 +693,9 @@ def load_state_dict(self, state): self.batchcount = state['batchcount'] self._mean = torch.from_numpy(state['mean']) self.v_cmom2 = torch.from_numpy(state['cmom2']) - self.data_shape = None if state['data_shape'] is None else tuple(state['data_shape']) + self.data_shape = ( + None if state['data_shape'] is None else tuple(state['data_shape']) + ) def state_dict(self): return dict( @@ -690,7 +704,7 @@ def state_dict(self): data_shape=self.data_shape and tuple(self.data_shape), batchcount=self.batchcount, mean=self._mean.cpu().numpy(), - cmom2=self.v_cmom2.cpu().numpy() + cmom2=self.v_cmom2.cpu().numpy(), ) @@ -698,15 +712,14 @@ def tally_stats( stats: T.Sequence[T.Union[Mean, Variance, Quantile]], loader: DataLoader, caches: T.Sequence[T.Union[Path, str]], - quiet: bool = True + quiet: bool = True, ): - """Tally stats + """Tally stats. To use tally_stats, write code like the following. ds = EdgeDataset( ppaths.train_path, processes=4, - threads_per_worker=2, random_seed=100 ) train_ds, val_ds = ds.split_train_val( @@ -763,6 +776,7 @@ def tally_stats( cached_state = load_cached_state(cache, args, quiet=quiet) if cached_state is not None: stat.load_state_dict(cached_state) + def empty_loader(): return yield @@ -780,10 +794,10 @@ def wrapped_loader(): return wrapped_loader() -class cache_load_enabled(): +class cache_load_enabled: """When used as a context manager, cache_load_enabled(False) will prevent - tally from loading cached statsitics, forcing them to be recomputed. - """ + tally from loading cached statsitics, forcing them to be recomputed.""" + def __init__(self, enabled=True): self.prev = False self.enabled = enabled diff --git a/tests/test_create_dataset.py b/tests/_test_create_dataset.py similarity index 84% rename from tests/test_create_dataset.py rename to tests/_test_create_dataset.py index abcd373f..00495792 100644 --- a/tests/test_create_dataset.py +++ b/tests/_test_create_dataset.py @@ -1,11 +1,11 @@ import shutil -from .data import p -from cultionet.scripts.cultionet import open_config from cultionet.data.create import create_predict_dataset +from cultionet.scripts.cultionet import open_config from cultionet.utils import model_preprocessing from cultionet.utils.project_paths import setup_paths +from .data import p CONFIG = open_config(p / 'config.yml') END_YEAR = CONFIG['years'][-1] @@ -17,7 +17,11 @@ def get_image_list(): for image_vi in CONFIG['image_vis']: vi_path = p / 'time_series_vars' / REGION / image_vi ts_list = model_preprocessing.get_time_series_list( - vi_path, END_YEAR-1, CONFIG['start_date'], CONFIG['end_date'], date_format='%Y%j' + vi_path, + END_YEAR - 1, + CONFIG['start_date'], + CONFIG['end_date'], + date_format='%Y%j', ) image_list += ts_list @@ -40,7 +44,7 @@ def test_predict_dataset(): window_size=50, padding=5, num_workers=2, - chunksize=100 + chunksize=100, ) pt_list = list(ppaths.get_process_path('predict').glob('*.pt')) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..29876aba --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,83 @@ +from pathlib import Path +from typing import Optional + +import numpy as np +import pytest +import torch + +from cultionet.data import Data +from cultionet.data.datasets import EdgeDataset + +RNG = np.random.default_rng(100) + + +@pytest.fixture +def class_info() -> dict: + return {'max_crop_class': 1, 'edge_class': 2} + + +def create_batch( + num_channels: int = 3, + num_time: int = 12, + height: int = 20, + width: int = 20, + rng: Optional[np.random.Generator] = None, +) -> Data: + x = torch.rand(1, num_channels, num_time, height, width) + y = torch.randint(low=-1, high=3, size=(1, height, width)) + bdist = torch.rand(1, height, width) + + if rng is None: + rng = RNG + + idx = rng.integers(low=0, high=99_999) + year = rng.choice([2020, 2021, 2022, 2023]) + + top = rng.uniform(-90, 90, size=1) + bottom = rng.uniform(-90, 90, size=1) + if top < bottom: + top, bottom = bottom, top + + left = rng.uniform(-180, 180, size=1) + right = rng.uniform(-180, 180, size=1) + if right < left: + left, right = right, left + + return Data( + x=x, + y=y, + bdist=bdist, + batch_id=[f"data_{idx:06d}_{year}_none.pt"], + left=torch.from_numpy(left), + bottom=torch.from_numpy(bottom), + right=torch.from_numpy(right), + top=torch.from_numpy(top), + ) + + +@pytest.fixture +def data_batch() -> Data: + return create_batch() + + +def temporary_dataset( + temp_dir: str, + num_samples: int, + rng: Optional[np.random.Generator] = None, + batch_kwargs: Optional[dict] = None, + **kwargs, +) -> EdgeDataset: + if batch_kwargs is None: + batch_kwargs = {} + + train_path = Path(temp_dir) + processed_path = train_path / 'processed' + + if rng is None: + rng = np.random.default_rng(100) + + for _ in range(num_samples): + batch = create_batch(rng=rng, **batch_kwargs) + batch.to_file(processed_path / batch.batch_id[0]) + + return EdgeDataset(train_path, **kwargs) diff --git a/tests/data/train/processed/data_000002_2022_0_none.pt b/tests/data/train/processed/data_000002_2022_0_none.pt new file mode 100644 index 00000000..2528776a Binary files /dev/null and b/tests/data/train/processed/data_000002_2022_0_none.pt differ diff --git a/tests/test_augmentation.py b/tests/test_augmentation.py index 591353e4..18c24337 100644 --- a/tests/test_augmentation.py +++ b/tests/test_augmentation.py @@ -1,11 +1,15 @@ -from cultionet.augment.augmenters import Augmenters +import numpy as np +import torch +from scipy.ndimage import label as nd_label +from skimage.measure import regionprops + from cultionet.augment.augmenter_utils import ( feature_stack_to_tsaug, tsaug_to_feature_stack, ) +from cultionet.augment.augmenters import Augmenters -import numpy as np - +from .conftest import create_batch NTIME = 12 NBANDS = 3 @@ -14,33 +18,194 @@ RNG = np.random.default_rng(200) -def test_feature_stack_to_tsaug(): +def test_tensor_reshape(): """Test array reshaping.""" - x = RNG.random((NTIME * NBANDS, NROWS, NCOLS)) - nfeas = x.shape[0] - # Reshape from (T*C x H x W) -> (H*W x T X C) - x_t = feature_stack_to_tsaug(x, NTIME, NBANDS, NROWS, NCOLS) + + x = torch.rand(1, NBANDS, NTIME, NROWS, NCOLS) + # Reshape to -> (H*W x T X C) + x_t = feature_stack_to_tsaug(x) + assert x_t.shape == ( NROWS * NCOLS, NTIME, NBANDS, ), 'The feature stack was incorrectly reshaped.' + + # First sample, first band, all time + assert torch.allclose(x_t[0, :, 0], x[0, 0, :, 0, 0]) + # First sample, second band, all time + assert torch.allclose(x_t[0, :, 1], x[0, 1, :, 0, 0]) + # First sample, last band, all time + assert torch.allclose(x_t[0, :, -1], x[0, -1, :, 0, 0]) + # Last sample, first band, all time + assert torch.allclose(x_t[-1, :, 0], x[0, 0, :, -1, -1]) + # Reshape from (H*W x T X C) -> (T*C x H x W) - x_tr = tsaug_to_feature_stack(x_t, nfeas, NROWS, NCOLS) - assert np.allclose( + x_tr = tsaug_to_feature_stack(x_t, NROWS, NCOLS) + + assert torch.allclose( x, x_tr ), 'The re-transformed data do not match the original.' def test_augmenter_loading(): augmentations = [ + 'roll', 'tswarp', 'tsnoise', 'tsdrift', 'tspeaks', + 'tsdrift', + 'gaussian', + 'saltpepper', + 'perlin', ] - aug = Augmenters( - augmentations=augmentations, ntime=13, nbands=5, max_crop_class=1 - ) - for i, method in enumerate(aug): - assert method.name_ == augmentations[i] + + for aug_name in augmentations: + aug_modules = Augmenters(augmentations=[aug_name], rng=RNG) + + batch = create_batch( + num_channels=3, + num_time=12, + height=50, + width=50, + ) + + assert batch.x.min() >= 0 + assert batch.x.max() <= 1 + assert batch.y.min() == -1 + + batch.segments = np.uint8(nd_label(batch.y.squeeze().numpy() == 1)[0]) + batch.props = regionprops(batch.segments) + aug_batch = aug_modules(batch.copy()) + + assert not torch.allclose(aug_batch.x, batch.x) + assert torch.allclose(aug_batch.y, batch.y) + assert torch.allclose(aug_batch.bdist, batch.bdist) + + augmentations = [ + 'rot90', + 'rot180', + 'rot270', + 'fliplr', + 'flipud', + 'cropresize', + ] + for aug_name in augmentations: + aug_modules = Augmenters(augmentations=[aug_name], rng=RNG) + + batch = create_batch( + num_channels=3, + num_time=12, + height=50, + width=50, + ) + + assert batch.x.min() >= 0 + assert batch.x.max() <= 1 + assert batch.y.min() == -1 + + aug_batch = aug_modules(batch.copy()) + + if aug_name == 'rotate-90': + assert torch.allclose( + batch.x[0, 0, :, 0, 0], + aug_batch.x[0, 0, :, -1, 0], + rtol=1e-4, + ) + assert torch.allclose( + batch.x[0, 0, :, 0, -1], + aug_batch.x[0, 0, :, 0, 0], + rtol=1e-4, + ) + assert torch.allclose( + batch.y[0, 0, 0], + aug_batch.y[0, -1, 0], + ) + assert torch.allclose( + batch.y[0, 0, -1], + aug_batch.y[0, 0, 0], + ) + assert torch.allclose( + batch.bdist[0, 0, 0], + aug_batch.bdist[0, -1, 0], + ) + assert torch.allclose( + batch.bdist[0, 0, -1], + aug_batch.bdist[0, 0, 0], + ) + elif aug_name == 'fliplr': + assert torch.allclose( + batch.x[0, 0, :, 0, 0], + aug_batch.x[0, 0, :, 0, -1], + rtol=1e-4, + ) + assert torch.allclose( + batch.x[0, 0, :, -1, 0], + aug_batch.x[0, 0, :, -1, -1], + rtol=1e-4, + ) + assert torch.allclose( + batch.y[0, 0, 0], + aug_batch.y[0, 0, -1], + ) + assert torch.allclose( + batch.y[0, -1, 0], + aug_batch.y[0, -1, -1], + ) + assert torch.allclose( + batch.bdist[0, 0, 0], + aug_batch.bdist[0, 0, -1], + ) + assert torch.allclose( + batch.bdist[0, -1, 0], + aug_batch.bdist[0, -1, -1], + ) + elif aug_name == 'flipud': + assert torch.allclose( + batch.x[0, 0, :, 0, 0], + aug_batch.x[0, 0, :, -1, 0], + rtol=1e-4, + ) + assert torch.allclose( + batch.x[0, 0, :, 0, -1], + aug_batch.x[0, 0, :, -1, -1], + rtol=1e-4, + ) + assert torch.allclose( + batch.y[0, 0, 0], + aug_batch.y[0, -1, 0], + ) + assert torch.allclose( + batch.y[0, 0, -1], + aug_batch.y[0, -1, -1], + ) + assert torch.allclose( + batch.bdist[0, 0, 0], + aug_batch.bdist[0, -1, 0], + ) + assert torch.allclose( + batch.bdist[0, 0, -1], + aug_batch.bdist[0, -1, -1], + ) + + assert not torch.allclose(aug_batch.x, batch.x) + assert not torch.allclose(aug_batch.y, batch.y) + assert not torch.allclose(aug_batch.bdist, batch.bdist) + + augmentations = ['none'] + for aug_name in augmentations: + aug_modules = Augmenters(augmentations=[aug_name], rng=RNG) + + batch = create_batch( + num_channels=3, + num_time=12, + height=50, + width=50, + ) + + aug_batch = aug_modules(batch.copy()) + + assert torch.allclose(aug_batch.x, batch.x) + assert torch.allclose(aug_batch.y, batch.y) + assert torch.allclose(aug_batch.bdist, batch.bdist) diff --git a/tests/test_cultionet.py b/tests/test_cultionet.py new file mode 100644 index 00000000..906e860e --- /dev/null +++ b/tests/test_cultionet.py @@ -0,0 +1,120 @@ +import tempfile + +from cultionet.data.modules import EdgeDataModule +from cultionet.enums import ( + AttentionTypes, + InferenceNames, + ModelTypes, + ResBlockTypes, +) +from cultionet.models.cultionet import CultioNet +from cultionet.utils.normalize import NormValues + +from .conftest import temporary_dataset + + +def get_train_dataset( + class_nums: dict, + temp_dir: str, + batch_kwargs: dict, + batch_size: int, + num_samples: int, + val_frac: float, +) -> EdgeDataModule: + + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + batch_kwargs=batch_kwargs, + processes=1, + ) + norm_values = NormValues.from_dataset( + ds, + batch_size=batch_size, + class_info=class_nums, + num_workers=0, + ) + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + batch_kwargs=batch_kwargs, + processes=1, + norm_values=norm_values, + augment_prob=0.1, + ) + train_ds, val_ds = ds.split_train_val( + val_frac=val_frac, + spatial_overlap_allowed=False, + spatial_balance=True, + ) + + return EdgeDataModule( + train_ds=train_ds, + val_ds=val_ds, + batch_size=batch_size, + ) + + +def test_cultionet(class_info: dict): + num_channels = 5 + in_time = 13 + height = 100 + width = 100 + batch_size = 2 + num_samples = 12 + val_frac = 0.2 + + kwargs = dict( + in_channels=num_channels, + in_time=in_time, + hidden_channels=32, + model_type=ModelTypes.TOWERUNET, + activation_type="SiLU", + dilations=[1, 2], + dropout=0.2, + res_block_type=ResBlockTypes.RESA, + attention_weights=AttentionTypes.SPATIAL_CHANNEL, + pool_by_max=True, + ) + + model = CultioNet(**kwargs) + + with tempfile.TemporaryDirectory() as temp_dir: + data_module = get_train_dataset( + class_nums=class_info, + temp_dir=temp_dir, + batch_kwargs=dict( + num_channels=num_channels, + num_time=in_time, + height=height, + width=width, + ), + batch_size=batch_size, + num_samples=num_samples, + val_frac=val_frac, + ) + + assert data_module.train_ds.augment_prob == 0.1 + assert data_module.val_ds.augment_prob == 0.0 + + for batch in data_module.train_dataloader(): + output = model(batch) + + assert output[InferenceNames.DISTANCE].shape == ( + batch_size, + 1, + height, + width, + ) + assert output[InferenceNames.EDGE].shape == ( + batch_size, + 1, + height, + width, + ) + assert output[InferenceNames.CROP].shape == ( + batch_size, + 1, + height, + width, + ) diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 00000000..d975915f --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,149 @@ +import tempfile +from pathlib import Path + +import numpy as np +import torch + +from cultionet.data import Data +from cultionet.data.modules import EdgeDataModule + +from .conftest import temporary_dataset + + +def test_assign_x(): + num_channels = 3 + num_time = 10 + height = 5 + width = 5 + x = torch.rand(1, num_channels, num_time, height, width) + batch = Data(x=x) + + assert batch.x.shape == (1, num_channels, num_time, height, width) + assert batch.y is None + assert torch.allclose(x, batch.x) + assert batch.num_samples == 1 + assert batch.num_channels == num_channels + assert batch.num_time == num_time + assert batch.height == height + assert batch.width == width + + +def test_assign_xy(): + num_channels = 3 + num_time = 10 + height = 5 + width = 5 + x = torch.rand(1, num_channels, num_time, height, width) + y = torch.randint(low=0, high=2, size=(1, height, width)) + batch = Data(x=x, y=y) + + assert batch.x.shape == (1, num_channels, num_time, height, width) + assert batch.y.shape == (1, height, width) + assert torch.allclose(x, batch.x) + assert torch.allclose(y, batch.y) + assert batch.num_samples == 1 + assert batch.num_channels == num_channels + assert batch.num_time == num_time + assert batch.height == height + assert batch.width == width + + +def test_assign_xy_kwargs(): + num_channels = 3 + num_time = 10 + height = 5 + width = 5 + x = torch.rand(1, num_channels, num_time, height, width) + y = torch.randint(low=0, high=2, size=(1, height, width)) + bdist = torch.rand(1, height, width) + batch = Data(x=x, y=y, bdist=bdist) + + assert batch.x.shape == (1, num_channels, num_time, height, width) + assert batch.y.shape == (1, height, width) + assert batch.bdist.shape == (1, height, width) + assert torch.allclose(x, batch.x) + assert torch.allclose(y, batch.y) + assert torch.allclose(bdist, batch.bdist) + assert batch.num_samples == 1 + assert batch.num_channels == num_channels + assert batch.num_time == num_time + assert batch.height == height + assert batch.width == width + + +def test_create_data(): + num_channels = 3 + num_time = 10 + height = 5 + width = 5 + + x = torch.rand(1, num_channels, num_time, height, width) + y = torch.randint(low=0, high=2, size=(1, height, width)) + bdist = torch.rand(1, height, width) + batch = Data(x=x, y=y, bdist=bdist) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_path = Path(temp_dir) / 'test_batch.pt' + + # Save and load a single batch + batch.to_file(temp_path) + loaded_batch = batch.from_file(temp_path) + + assert loaded_batch.x.shape == ( + 1, + num_channels, + num_time, + height, + width, + ) + assert loaded_batch.y.shape == (1, height, width) + assert loaded_batch.bdist.shape == (1, height, width) + assert torch.allclose(x, loaded_batch.x) + assert torch.allclose(y, loaded_batch.y) + assert torch.allclose(bdist, loaded_batch.bdist) + assert loaded_batch.num_samples == 1 + assert loaded_batch.num_channels == num_channels + assert loaded_batch.num_time == num_time + assert loaded_batch.height == height + assert loaded_batch.width == width + + +def test_copy_data(data_batch: Data): + x_clone = data_batch.x.clone() + + batch_copy = data_batch.copy() + + for key in batch_copy.to_dict().keys(): + assert key in data_batch.to_dict().keys() + + batch_copy.x *= 10 + + assert not torch.allclose(data_batch.x, batch_copy.x) + assert torch.allclose(data_batch.x, x_clone) + assert torch.allclose(data_batch.y, batch_copy.y) + + +def test_train_dataset(): + num_samples = 6 + batch_size = 2 + + with tempfile.TemporaryDirectory() as temp_dir: + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + ) + + assert len(ds) == num_samples + + data_module = EdgeDataModule( + train_ds=ds, + batch_size=batch_size, + num_workers=0, + ) + for batch in data_module.train_dataloader(): + assert batch.num_samples == batch_size + for key, value in batch.to_dict().items(): + if isinstance(value, (torch.Tensor, np.ndarray)): + assert value.shape[0] == batch_size + else: + assert len(value) == batch_size diff --git a/tests/test_data_splits.py b/tests/test_data_splits.py new file mode 100644 index 00000000..0f058cfa --- /dev/null +++ b/tests/test_data_splits.py @@ -0,0 +1,23 @@ +import tempfile + +from .conftest import temporary_dataset + + +def test_train_dataset(): + num_samples = 6 + val_frac = 0.2 + + with tempfile.TemporaryDirectory() as temp_dir: + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + processes=1, + ) + train_ds, val_ds = ds.split_train_val( + val_frac=val_frac, + spatial_overlap_allowed=False, + spatial_balance=True, + ) + + assert len(train_ds) == len(ds) - int(len(ds) * val_frac) + assert len(val_ds) == int(len(ds) * val_frac) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 2c71c344..f7bd7838 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,45 +1,110 @@ -from pathlib import Path +import tempfile -from .data import batch_file -from cultionet.data.datasets import EdgeDataset -from cultionet.utils.project_paths import setup_paths import torch - -project_path = Path(__file__).parent.absolute() -ppaths = setup_paths(project_path) -ds = EdgeDataset(ppaths.train_path) -DATA = next(iter(ds)) -LOADED_DATA = ds.load_file(batch_file) - - -def test_load(): - assert torch.allclose(DATA.x, LOADED_DATA.x) - assert torch.allclose(DATA.y, LOADED_DATA.y) - - -def test_ds_type(): - assert isinstance(ds, EdgeDataset) - - -def test_ds_len(): - assert len(ds) == 1 - - -def test_x_type(): - assert isinstance(DATA.x, torch.Tensor) - - -def test_x_shape(): - assert DATA.x.shape == (10000, 39) - - -def test_y_shape(): - assert DATA.y.shape == (10000,) - - -def test_dims_attr(): - assert DATA.nbands == 3 - assert DATA.ntime == 13 - assert DATA.height == 100 - assert DATA.width == 100 +from cultionet.data.modules import EdgeDataModule +from cultionet.utils.normalize import NormValues + +from .conftest import temporary_dataset + + +def test_dataset(class_info: dict) -> EdgeDataModule: + + batch_size = 2 + num_channels = 3 + in_time = 12 + height = 20 + width = 20 + num_samples = 20 + val_frac = 0.1 + + batch_kwargs = dict( + num_channels=num_channels, + num_time=in_time, + height=height, + width=width, + ) + + with tempfile.TemporaryDirectory() as temp_dir: + + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + batch_kwargs=batch_kwargs, + processes=1, + random_seed=100, + ) + norm_values = NormValues.from_dataset( + ds, + batch_size=batch_size, + class_info=class_info, + num_workers=0, + ) + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + batch_kwargs=batch_kwargs, + processes=1, + norm_values=norm_values, + augment_prob=0.1, + random_seed=100, + ) + train_ds, val_ds = ds.split_train_val( + val_frac=val_frac, + spatial_overlap_allowed=False, + spatial_balance=True, + ) + + generator = torch.Generator() + generator.manual_seed(100) + + data_module = EdgeDataModule( + train_ds=train_ds, + val_ds=val_ds, + batch_size=batch_size, + shuffle=False, + generator=generator, + ) + first_train_batch = next(iter(data_module.train_dataloader())) + first_val_batch = next(iter(data_module.val_dataloader())) + assert first_train_batch.batch_id == [ + 'data_002257_2022_none.pt', + 'data_012624_2023_none.pt', + ] + assert first_val_batch.batch_id == [ + 'data_051349_2022_none.pt', + 'data_094721_2022_none.pt', + ] + data_module = EdgeDataModule( + train_ds=train_ds, + val_ds=val_ds, + batch_size=batch_size, + shuffle=True, + generator=generator, + ) + first_train_batch = next(iter(data_module.train_dataloader())) + first_val_batch = next(iter(data_module.val_dataloader())) + assert first_train_batch.batch_id == [ + 'data_034049_2022_none.pt', + 'data_050552_2023_none.pt', + ] + assert first_val_batch.batch_id == [ + 'data_051349_2022_none.pt', + 'data_094721_2022_none.pt', + ] + + assert len(ds) == num_samples + assert len(val_ds) == int(val_frac * len(ds)) + assert len(train_ds) == len(ds) - int(val_frac * len(ds)) + assert ds.num_time == in_time + assert train_ds.num_time == in_time + assert val_ds.num_time == in_time + + assert ds.data_list[0].name == 'data_002257_2022_none.pt' + assert ds.data_list[-1].name == 'data_094721_2022_none.pt' + ds.shuffle() + assert ds.data_list[0].name == 'data_032192_2020_none.pt' + assert ds.data_list[-1].name == 'data_022792_2023_none.pt' + + ds.cleanup() + assert len(ds) == 0 diff --git a/tests/test_loss.py b/tests/test_loss.py index 1c486d1c..1ae5edf2 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -1,16 +1,145 @@ -from cultionet.losses import TanimotoDistLoss - +import numpy as np +import pytest import torch +import torch.nn.functional as F +from einops import rearrange + +from cultionet.losses import ( + CombinedLoss, + LossPreprocessing, + TanimotoComplementLoss, + TanimotoDistLoss, +) + +rng = np.random.default_rng(100) + +BATCH_SIZE = 2 +HEIGHT = 20 +WIDTH = 20 + +INPUTS_CROP_LOGIT = torch.from_numpy( + rng.uniform(low=-3, high=3, size=(BATCH_SIZE, 2, HEIGHT, WIDTH)) +).float() +INPUTS_CROP_PROB = rearrange( + torch.from_numpy( + rng.dirichlet((0.5, 0.5), size=(BATCH_SIZE * HEIGHT * WIDTH)) + ).float(), + '(b h w) c -> b c h w', + b=BATCH_SIZE, + c=2, + h=HEIGHT, + w=WIDTH, +) +INPUTS_EDGE_PROB = torch.from_numpy( + rng.random((BATCH_SIZE, 1, HEIGHT, WIDTH)) +).float() +INPUTS_DIST = torch.from_numpy( + rng.random((BATCH_SIZE, 1, HEIGHT, WIDTH)) +).float() +DISCRETE_TARGETS = torch.from_numpy( + rng.integers(low=0, high=2, size=(BATCH_SIZE, HEIGHT, WIDTH)) +).long() +DISCRETE_EDGE_TARGETS = torch.from_numpy( + rng.integers(low=0, high=1, size=(BATCH_SIZE, HEIGHT, WIDTH)) +).long() +DIST_TARGETS = torch.from_numpy( + rng.random((BATCH_SIZE, HEIGHT, WIDTH)) +).float() +MASK = torch.from_numpy( + rng.integers(low=0, high=2, size=(BATCH_SIZE, 1, HEIGHT, WIDTH)) +).long() + + +def test_loss_preprocessing(): + # Input logits + preprocessor = LossPreprocessing( + transform_logits=True, one_hot_targets=True + ) + inputs, targets = preprocessor(INPUTS_CROP_LOGIT, DISCRETE_TARGETS) + + assert inputs.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH) + assert targets.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH) + assert torch.allclose( + inputs.sum(dim=1), torch.ones(BATCH_SIZE, HEIGHT, WIDTH), rtol=0.1 + ) + assert torch.allclose( + inputs, + F.softmax(INPUTS_CROP_LOGIT, dim=1, dtype=INPUTS_CROP_LOGIT.dtype), + ) + + # Input probabilities + preprocessor = LossPreprocessing( + transform_logits=False, one_hot_targets=True + ) + inputs, targets = preprocessor(INPUTS_CROP_PROB, DISCRETE_TARGETS) + + assert inputs.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH) + assert targets.shape == (BATCH_SIZE, 2, HEIGHT, WIDTH) + assert torch.allclose( + inputs.sum(dim=1), torch.ones(BATCH_SIZE, HEIGHT, WIDTH), rtol=0.1 + ) + assert torch.allclose( + inputs, + INPUTS_CROP_PROB, + ) + + preprocessor = LossPreprocessing( + transform_logits=False, one_hot_targets=True + ) + inputs, targets = preprocessor(INPUTS_EDGE_PROB, DISCRETE_EDGE_TARGETS) + + assert inputs.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) + assert targets.shape == (BATCH_SIZE, 1, HEIGHT, WIDTH) + assert torch.allclose( + inputs, + INPUTS_EDGE_PROB, + ) + + # Regression + preprocessor = LossPreprocessing( + transform_logits=False, one_hot_targets=False + ) + inputs, targets = preprocessor(INPUTS_DIST, DIST_TARGETS) + + # Preprocessing should not change the inputs other than the shape + assert torch.allclose(inputs, INPUTS_DIST) + assert torch.allclose(targets, rearrange(DIST_TARGETS, 'b h w -> b 1 h w')) + + +def test_tanimoto_classification_loss(): + loss_func = TanimotoDistLoss() + + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) + assert round(float(loss.item()), 3) == 0.611 + + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) + assert round(float(loss.item()), 3) == 0.431 + + loss_func = TanimotoComplementLoss() + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) + assert round(float(loss.item()), 3) == 0.824 + + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) + assert round(float(loss.item()), 3) == 0.692 + loss_func = CombinedLoss( + losses=[ + TanimotoDistLoss(), + TanimotoComplementLoss(), + ] + ) + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS) + assert round(float(loss.item()), 3) == 0.717 -torch.manual_seed(100) -n_samples = 100 -INPUTS = torch.randn((n_samples, 2)) -TARGETS = torch.randint(low=0, high=2, size=(n_samples,)) + loss = loss_func(INPUTS_CROP_PROB, DISCRETE_TARGETS, mask=MASK) + assert round(float(loss.item()), 3) == 0.561 -def test_tanimoto_loss(): - loss_func = TanimotoDistLoss(scale_pos_weight=False, transform_logits=True) - loss = loss_func(INPUTS, TARGETS) +def test_tanimoto_regression_loss(): + loss_func = TanimotoDistLoss(one_hot_targets=False) + loss = loss_func(INPUTS_DIST, DIST_TARGETS) + assert round(float(loss.item()), 3) == 0.417 - assert round(loss.mean().item(), 4) == 0.5903 + loss_func = TanimotoComplementLoss(one_hot_targets=False) + loss = loss_func(INPUTS_DIST, DIST_TARGETS) + assert round(float(loss.item()), 3) == 0.704 diff --git a/tests/test_norm.py b/tests/test_norm.py index 46776314..27e60af3 100644 --- a/tests/test_norm.py +++ b/tests/test_norm.py @@ -1,224 +1,132 @@ +import tempfile from pathlib import Path -from cultionet.data.datasets import zscores, EdgeDataset -from cultionet.utils.normalize import get_norm_values -# from cultionet.data.modules import EdgeDataModule -# from cultionet.utils.stats import ( -# tally_stats, -# cache_load_enabled, -# load_cached_state, -# Mean, -# Quantile, -# Variance -# ) - import torch -from torch_geometric.data import Data -import pytest - - -PROJECT_PATH = Path(__file__).parent.absolute() -CLASS_INFO = { - 'max_crop_class': 1, - 'edge_class': 2 -} - - -def create_small_chips(b: torch.Tensor, rc_slice: tuple) -> Data: - """Method used to create new data - - Example: - >>> import joblib - >>> - >>> batch = joblib.load('...') - >>> create_small_chips( - >>> batch, - >>> rc_slice=(slice(0, None), slice(45, 55), slice(45, 55)) - >>> ) - >>> - >>> # Create small data chips in the test dir - >>> out_path = Path('test_dir') - >>> for fn in Path('train/processed').glob('*.pt'): - >>> batch = joblib.load(fn) - >>> small_batch = create_create_small_chipstest_data( - >>> batch, - >>> (slice(0, None), slice(45, 55), slice(45, 55)) - >>> ) - >>> joblib.dump(small_batch, out_path / fn.name) - """ - exclusion = ('x', 'height', 'width') - # Reshape to (C x H x W) - x = b.x.t().reshape(b.ntime*b.nbands, b.height, b.width) - # Take a subset - x = x[rc_slice] - # Reshape back to (S x D) - height = rc_slice[1].stop - rc_slice[1].start - width = rc_slice[2].stop - rc_slice[2].start - x = x.permute(1, 2, 0).reshape(height*width, b.ntime*b.nbands) - - return Data( - x=x, - height=height, - width=width, - **{k: getattr(b, k) for k in b.keys if k not in exclusion} - ) - - -@pytest.fixture(scope='session') -def train_dataset() -> EdgeDataset: - train_path = PROJECT_PATH / 'data' / 'train' / 'small_chips' - - ds = EdgeDataset( - train_path, - processes=1, - threads_per_worker=1, - random_seed=100 - ) +from torch.utils.data import DataLoader - return ds +from cultionet.data import Data +from cultionet.data.utils import collate_fn +from cultionet.utils.normalize import NormValues +from .conftest import temporary_dataset -@pytest.fixture(scope='session') -def serial_ref_data(train_dataset: EdgeDataset) -> torch.Tensor: - ref_data = torch.cat([batch.x for batch in train_dataset], dim=0) - return ref_data - - -@pytest.fixture(scope='session') -def serial_norm_data(train_dataset: EdgeDataset) -> Data: - norm_values = get_norm_values( - dataset=train_dataset, - batch_size=1, - class_info=CLASS_INFO, - num_workers=1, - processes=1, - threads_per_worker=1, - mean_color='#3edf2b', - sse_color='#dfb92b' +def test_norm(): + num_channels = 3 + shape = (1, num_channels, 1, 1, 1) + norm_values = NormValues( + dataset_mean=torch.zeros(shape), + dataset_std=torch.ones(shape), + dataset_crop_counts=None, + dataset_edge_counts=None, + num_channels=num_channels, ) - return norm_values - - -def test_cumnorm_serial( - serial_ref_data: torch.Tensor, - serial_norm_data: Data -): - assert torch.allclose(serial_norm_data.mean, serial_ref_data.mean(dim=0), rtol=1e-4), \ - 'The mean values do not match the expected values.' - assert torch.allclose(serial_norm_data.std, serial_ref_data.std(dim=0, unbiased=False), rtol=1e-4), \ - 'The mean values do not match the expected values.' - - -def test_cumnorm_concurrent(train_dataset: EdgeDataset, serial_ref_data: torch.Tensor): - norm_values = get_norm_values( - dataset=train_dataset, - batch_size=1, - class_info=CLASS_INFO, - num_workers=1, - processes=4, - threads_per_worker=2, - mean_color='#df4a2b', - sse_color='#2ba0df' + batch = Data(x=torch.ones(shape)) + assert torch.allclose( + norm_values(batch).x, + torch.ones(shape), ) + assert torch.allclose(batch.x, torch.ones(shape)) - assert torch.allclose(norm_values.mean, serial_ref_data.mean(dim=0), rtol=1e-4), \ - 'The mean values do not match the expected values.' - assert torch.allclose(norm_values.std, serial_ref_data.std(dim=0, unbiased=False), rtol=1e-4), \ - 'The mean values do not match the expected values.' - - -def test_transform_data(train_dataset: EdgeDataset, serial_norm_data: Data): - ref_batch = train_dataset[0] - batch = zscores( - batch=ref_batch, - data_means=serial_norm_data.mean, - data_stds=serial_norm_data.std, + batch = Data(x=torch.zeros(shape)) + assert torch.allclose( + norm_values(batch).x, + torch.zeros(shape), + ) + assert torch.allclose(batch.x, torch.zeros(shape)) + + norm_values = NormValues( + dataset_mean=torch.zeros(shape) + 0.5, + dataset_std=torch.ones(shape) + 0.5, + dataset_crop_counts=None, + dataset_edge_counts=None, + num_channels=num_channels, ) - # z = (x - μ) / σ - ref_zscores = (ref_batch.x - serial_norm_data.mean) / serial_norm_data.std - - assert torch.allclose(batch.x, ref_zscores), 'The z-scores do not match the expected values.' - - -# NOTE: this module is not currently used, but we will -# keep the test here in case of future use -# def test_norm(): -# train_path = PROJECT_PATH / 'data' / 'train' / 'small_chips' -# mean_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'data_means.npz' -# var_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'data_vars.npz' -# var_median_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'data_vars_median.npz' -# q_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'data_quantiles.npz' -# ref_q_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'ref_data_quantiles.npz' -# ref_var_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'ref_data_vars.npz' -# ref_var_median_data_cache = PROJECT_PATH / 'data' / 'train' / 'small_chips' / 'ref_data_vars_median.npz' - -# ds = EdgeDataset( -# train_path, -# processes=1, -# threads_per_worker=1, -# random_seed=100 -# ) -# # TODO: test this -# # norm_values = get_norm_values( -# # dataset=ds, -# # batch_size=1, -# # class_info=CLASS_INFO, -# # num_workers=4, -# # centering='median' -# # ) - -# data_module = EdgeDataModule( -# train_ds=ds, -# batch_size=1, -# num_workers=0, -# shuffle=False -# ) - -# ref_data = [] -# stat_mean = Mean() -# stat_var = Variance() -# stat_var_median = Variance(method='median') -# stat_q = Quantile() -# with cache_load_enabled(False): -# for batch in tally_stats( -# stats=(stat_mean, stat_var, stat_var_median, stat_q), -# loader=data_module.train_dataloader(), -# caches=(mean_data_cache, var_data_cache, var_median_data_cache, q_data_cache) -# ): -# ref_data.append(batch.x) -# stat_mean.add(batch.x) -# stat_q.add(batch.x) -# stat_var.add(batch.x) -# stat_var_median.add(batch.x) -# ref_data = torch.cat(ref_data, dim=0) -# mean = stat_mean.mean() -# std = stat_var.std() -# std_median = stat_var_median.std() -# median = stat_q.median() - -# ref_stat_var = Variance() -# cached_state = load_cached_state(ref_var_data_cache) -# ref_stat_var.load_state_dict(cached_state) -# ref_std = ref_stat_var.std() - -# ref_stat_var_median = Variance(method='median') -# cached_state = load_cached_state(ref_var_median_data_cache) -# ref_stat_var_median.load_state_dict(cached_state) -# ref_std_median = ref_stat_var_median.std() - -# ref_stat_q = Quantile() -# cached_state = load_cached_state(ref_q_data_cache) -# ref_stat_q.load_state_dict(cached_state) -# ref_median = ref_stat_q.median() - -# assert torch.allclose(ref_data.mean(dim=0), mean, rtol=1e-4), \ -# 'The data means do not match the expected values.' -# assert torch.allclose(std, ref_std, rtol=1e-4), \ -# 'The data standard deviations do not match the cached values.' -# assert torch.allclose(std_median, ref_std_median, rtol=1e-4), \ -# 'The data median standard deviations do not match the cached values.' -# assert torch.allclose(median, ref_median, rtol=1e-4), \ -# 'The data medians do not match the cached values.' + batch = Data(x=torch.ones(shape)) + assert torch.allclose( + norm_values(batch).x, + torch.zeros(shape) + 0.3333, + rtol=0.01, + ) + assert torch.allclose(batch.x, torch.ones(shape)) + + +def test_train_dataset(class_info: dict): + num_samples = 6 + batch_size = 2 + + with tempfile.TemporaryDirectory() as temp_dir: + ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + log_transform=True, + ) + + norm_values = NormValues.from_dataset( + ds, + batch_size=batch_size, + class_info=class_info, + num_workers=0, + ) + + norm_path = Path(temp_dir) / 'data.norm' + norm_values.to_file(norm_path) + loaded_norm_values = NormValues.from_file(norm_path) + + assert torch.allclose( + norm_values.dataset_mean, loaded_norm_values.dataset_mean + ) + assert torch.allclose( + norm_values.dataset_std, loaded_norm_values.dataset_std + ) + assert torch.allclose( + norm_values.dataset_crop_counts, + loaded_norm_values.dataset_crop_counts, + ) + assert torch.allclose( + norm_values.dataset_edge_counts, + loaded_norm_values.dataset_edge_counts, + ) + + assert norm_values.dataset_mean.shape == ( + 1, + norm_values.num_channels, + 1, + 1, + 1, + ) + + # Apply normalization + norm_ds = temporary_dataset( + temp_dir=temp_dir, + num_samples=num_samples, + norm_values=norm_values, + log_transform=True, + ) + data_loader = DataLoader( + ds, + batch_size=batch_size, + num_workers=0, + shuffle=False, + collate_fn=collate_fn, + ) + norm_data_loader = DataLoader( + norm_ds, + batch_size=batch_size, + num_workers=0, + shuffle=False, + collate_fn=collate_fn, + ) + + # The normalization should be applied to each batch + for batch, norm_batch in zip(data_loader, norm_data_loader): + assert not torch.allclose( + batch.x, + norm_batch.x, + ) + assert torch.allclose( + norm_values(batch).x, + norm_batch.x, + ) diff --git a/tests/test_reshape.py b/tests/test_reshape.py deleted file mode 100644 index 906cea18..00000000 --- a/tests/test_reshape.py +++ /dev/null @@ -1,87 +0,0 @@ -from pathlib import Path - -from .data import batch_file -from cultionet.data.datasets import EdgeDataset -from cultionet.utils.project_paths import setup_paths -from cultionet.models import model_utils - -import torch - - -project_path = Path(__file__).parent.absolute() -ppaths = setup_paths(project_path) -ds = EdgeDataset(ppaths.train_path) -DATA = ds.load_file(batch_file) - - -def test_graph_to_conv(): - """Test reshaping from graph/column order to multi-dimensional/convolution order - """ - gc = model_utils.GraphToConv() - - x = gc(DATA.x, 1, DATA.height, DATA.width) - - assert x.shape == (1, DATA.x.shape[1], DATA.height, DATA.width) - assert torch.allclose(x[0, :, 0, 0], DATA.x[0]) - assert torch.allclose(x[0, :, 0, 1], DATA.x[1]) - assert torch.allclose(x[0, :, -1, -2], DATA.x[-2]) - assert torch.allclose(x[0, :, -1, -1], DATA.x[-1]) - - -def test_conv_to_graph(): - """Test reshaping from multi-dimensional/convolution order to graph/column order - """ - gc = model_utils.GraphToConv() - cg = model_utils.ConvToGraph() - - x = gc(DATA.x, 1, DATA.height, DATA.width) - y = cg(x) - - assert torch.allclose(y, DATA.x) - - -def test_conv_to_time(): - """Test reshaping from multi-dimensional/convolution order to time order - """ - gc = model_utils.GraphToConv() - ct = model_utils.ConvToTime() - - x = gc(DATA.x, 1, DATA.height, DATA.width) - t = ct(x, nbands=DATA.nbands, ntime=DATA.ntime) - - assert torch.allclose( - x[0, :DATA.ntime, 0, 0], t[0, 0, :, 0, 0] - ) - assert torch.allclose( - x[0, DATA.ntime:DATA.ntime*2, 0, 0], t[0, 1, :, 0, 0] - ) - assert torch.allclose( - x[0, DATA.ntime*2:, 0, 0], t[0, 2, :, 0, 0] - ) - assert torch.allclose( - x[0, :DATA.ntime, 0, 1], t[0, 0, :, 0, 1] - ) - assert torch.allclose( - x[0, DATA.ntime:DATA.ntime*2, 0, 1], t[0, 1, :, 0, 1] - ) - assert torch.allclose( - x[0, DATA.ntime*2:, 0, 1], t[0, 2, :, 0, 1] - ) - assert torch.allclose( - x[0, :DATA.ntime, 50, 50], t[0, 0, :, 50, 50] - ) - assert torch.allclose( - x[0, DATA.ntime:DATA.ntime*2, 50, 50], t[0, 1, :, 50, 50] - ) - assert torch.allclose( - x[0, DATA.ntime*2:, 50, 50], t[0, 2, :, 50, 50] - ) - assert torch.allclose( - x[0, :DATA.ntime, -1, -1], t[0, 0, :, -1, -1] - ) - assert torch.allclose( - x[0, DATA.ntime:DATA.ntime*2, -1, -1], t[0, 1, :, -1, -1] - ) - assert torch.allclose( - x[0, DATA.ntime*2:, -1, -1], t[0, 2, :, -1, -1] - ) diff --git a/tests/test_tower_unet.py b/tests/test_tower_unet.py new file mode 100644 index 00000000..e06721b9 --- /dev/null +++ b/tests/test_tower_unet.py @@ -0,0 +1,38 @@ +import torch + +from cultionet.enums import InferenceNames, ResBlockTypes +from cultionet.models.nunet import TowerUNet + + +def test_tower_unet(): + batch_size = 2 + num_channels = 3 + hidden_channels = 32 + num_time = 13 + height = 100 + width = 100 + + x = torch.rand( + (batch_size, num_channels, num_time, height, width), + dtype=torch.float32, + ) + + model = TowerUNet( + in_channels=num_channels, + in_time=num_time, + hidden_channels=hidden_channels, + dilations=[1, 2], + res_block_type=ResBlockTypes.RESA, + pool_by_max=False, + ) + + logits = model(x) + + assert logits[InferenceNames.DISTANCE].shape == ( + batch_size, + 1, + height, + width, + ) + assert logits[InferenceNames.EDGE].shape == (batch_size, 1, height, width) + assert logits[InferenceNames.CROP].shape == (batch_size, 1, height, width) diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 00000000..0c80575a --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,133 @@ +import json +import subprocess +import tempfile +from pathlib import Path + +import lightning as L +import numpy as np +import torch + +import cultionet +from cultionet.data import Data +from cultionet.data.datasets import EdgeDataset +from cultionet.enums import AttentionTypes, ModelTypes, ResBlockTypes +from cultionet.model import CultionetParams +from cultionet.utils.project_paths import setup_paths + +L.seed_everything(100) +RNG = np.random.default_rng(200) + + +def create_data(group: int) -> Data: + num_channels = 2 + num_time = 12 + height = 100 + width = 100 + + x = torch.rand( + (1, num_channels, num_time, height, width), + dtype=torch.float32, + ) + bdist = torch.rand((1, height, width), dtype=torch.float32) + y = torch.randint(low=0, high=3, size=(1, height, width)) + + lat_left = RNG.uniform(low=-180, high=180) + lat_bottom = RNG.uniform(low=-90, high=90) + lat_right = RNG.uniform(low=-180, high=180) + lat_top = RNG.uniform(low=-90, high=90) + + batch_data = Data( + x=x, + y=y, + bdist=bdist, + left=torch.tensor([lat_left], dtype=torch.float32), + bottom=torch.tensor([lat_bottom], dtype=torch.float32), + right=torch.tensor([lat_right], dtype=torch.float32), + top=torch.tensor([lat_top], dtype=torch.float32), + batch_id=[group], + ) + + return batch_data + + +def test_train(): + num_data = 10 + with tempfile.TemporaryDirectory() as tmp_path: + ppaths = setup_paths(tmp_path) + for i in range(num_data): + data_path = ( + ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt' + ) + batch_data = create_data(i) + batch_data.to_file(data_path) + + dataset = EdgeDataset( + ppaths.train_path, + processes=0, + random_seed=100, + ) + + cultionet_params = CultionetParams( + ckpt_file=ppaths.ckpt_file, + model_name="cultionet", + dataset=dataset, + val_frac=0.2, + batch_size=2, + load_batch_workers=0, + hidden_channels=16, + edge_class=2, + model_type=ModelTypes.TOWERUNET, + res_block_type=ResBlockTypes.RESA, + activation_type="SiLU", + dilations=[1, 2], + dropout=0.2, + pool_by_max=True, + epochs=1, + device="cpu", + devices=1, + precision="32", + ) + + try: + cultionet.fit(cultionet_params) + except Exception as e: + raise RuntimeError(e) + + +# def test_train_cli(): +# num_data = 10 +# with tempfile.TemporaryDirectory() as tmp_dir: +# tmp_path = Path(tmp_dir) +# ppaths = setup_paths(tmp_path) +# for i in range(num_data): +# data_path = ( +# ppaths.process_path / f'data_{i:06d}_2021_{i:06d}_none.pt' +# ) +# batch_data = create_data(i) +# batch_data.to_file(data_path) + +# with open(tmp_path / "data/classes.info", "w") as f: +# json.dump({"max_crop_class": 1, "edge_class": 2}, f) + +# command = ( +# f"cultionet train -p {str(tmp_path.absolute())} " +# "--val-frac 0.2 --augment-prob 0.5 --epochs 1 --hidden-channels 16 " +# "--processes 1 --load-batch-workers 0 --batch-size 2 --dropout 0.2 " +# "--deep-sup --dilations 1 2 --pool-by-max --learning-rate 0.01 " +# "--weight-decay 1e-4 --attention-weights spatial_channel --device cpu" +# ) + +# try: +# subprocess.run( +# command, +# shell=True, +# check=True, +# capture_output=True, +# universal_newlines=True, +# ) +# except subprocess.CalledProcessError as e: +# raise NameError( +# "Exit code:\n{}\n\nstdout:\n{}\n\nstderr:\n{}".format( +# e.returncode, e.output, e.stderr +# ) +# )