diff --git a/core/testcasecontroller/metrics/metrics.py b/core/testcasecontroller/metrics/metrics.py index 112dc691..b4f72983 100644 --- a/core/testcasecontroller/metrics/metrics.py +++ b/core/testcasecontroller/metrics/metrics.py @@ -49,41 +49,69 @@ def samples_transfer_ratio_func(system_metric_info: dict): def compute(key, matrix): """ - compute BWT and FWT + Compute BWT and FWT scores for a given matrix. """ # pylint: disable=C0103 + # pylint: disable=C0301 + # pylint: disable=C0303 + # pylint: disable=R0912 + + print(f"compute function: key={key}, matrix={matrix}, type(matrix)={type(matrix)}") + length = len(matrix) accuracy = 0.0 BWT_score = 0.0 FWT_score = 0.0 flag = True - for i in range(length): - if len(matrix[i]) != length-1: + + if key == 'all': + for i in range(length-1, 0, -1): + sum_before_i = sum(item['accuracy'] for item in matrix[i][:i]) + sum_after_i = sum(item['accuracy'] for item in matrix[i][-(length - i - 1):]) + if i == 0: + seen_class_accuracy = 0.0 + else: + seen_class_accuracy = sum_before_i / i + if length - 1 - i == 0: + unseen_class_accuracy = 0.0 + else: + unseen_class_accuracy = sum_after_i / (length - 1 - i) + print(f"round {i} : unseen class accuracy is {unseen_class_accuracy}, seen class accuracy is {seen_class_accuracy}") + + for row in matrix: + if not isinstance(row, list) or len(row) != length-1: flag = False break - if flag is False: + + if not flag: BWT_score = np.nan FWT_score = np.nan return BWT_score, FWT_score for i in range(length-1): - accuracy += matrix[length-1][i]['accuracy'] - BWT_score += matrix[length-1][i]['accuracy'] - matrix[i+1][i]['accuracy'] - for i in range(0,length-1): - FWT_score += matrix[i][i]['accuracy'] - matrix[0][i]['accuracy'] - accuracy = accuracy/(length) - BWT_score = BWT_score/(length-1) - FWT_score = FWT_score/(length-1) - #print(f"{key} accuracy: ", accuracy) - print(f"{key} BWT_score: ", BWT_score) - print(f"{key} FWT_score: ", FWT_score) + for j in range(length-1): + if 'accuracy' in matrix[i+1][j] and 'accuracy' in matrix[i][j]: + accuracy += matrix[i+1][j]['accuracy'] + BWT_score += matrix[i+1][j]['accuracy'] - matrix[i][j]['accuracy'] + + for i in range(0, length-1): + if 'accuracy' in matrix[i][i] and 'accuracy' in matrix[0][i]: + FWT_score += matrix[i][i]['accuracy'] - matrix[0][i]['accuracy'] + + accuracy = accuracy / ((length-1) * (length-1)) + BWT_score = BWT_score / ((length-1) * (length-1)) + FWT_score = FWT_score / (length-1) + + print(f"{key} BWT_score: {BWT_score}") + print(f"{key} FWT_score: {FWT_score}") + my_matrix = [] for i in range(length-1): my_matrix.append([]) - for i in range(length-1): for j in range(length-1): - my_matrix[i].append(matrix[i+1][j]['accuracy']) - #self.draw_picture(key,my_matrix) + if 'accuracy' in matrix[i+1][j]: + my_matrix[i].append(matrix[i+1][j]['accuracy']) + return my_matrix, BWT_score, FWT_score def bwt_func(system_metric_info: dict): diff --git a/docs/proposals/algorithms/lifelong-learning/Implementation of a Class Incremental Learning Algorithm Evaluation System based on Ianvs.md b/docs/proposals/algorithms/lifelong-learning/Implementation of a Class Incremental Learning Algorithm Evaluation System based on Ianvs.md new file mode 100644 index 00000000..b516515a --- /dev/null +++ b/docs/proposals/algorithms/lifelong-learning/Implementation of a Class Incremental Learning Algorithm Evaluation System based on Ianvs.md @@ -0,0 +1,270 @@ + + +- [Implementation of a Class Incremental Learning Algorithm Evaluation System based on 
Ianvs](#implementation-of-a-class-incremental-learning-algorithm-evaluation-system-based-on-ianvs) + - [Motivation](#motivation) + - [Background](#background) + - [Goals](#goals) + - [Proposal](#proposal) + - [Design Details](#design-details) + - [Overall Design](#overall-design) + - [Datasets](#datasets) + - [File-level Design](#file-level-design) + - [Test Environment](#test-environment) + - [Test Algorithm](#test-algorithm) + - [Test Report](#test-report) + - [Roadmap](#roadmap) + - [Phase 1 July 1st - August 15th](#phase-1-july-1st---august-15th) + - [Phase 2 August 16th - September 30th](#phase-2-august-16th---september-30th) + + + +# Implementation of a Class Incremental Learning Algorithm Evaluation System based on Ianvs + +## 1 Motivation + +### 1.1 Background +Currently, lifelong learning is facing a challenge: new classes may appear when models are trained on a new data domain ( for example, in the figure below, three classes in red are new classes in `Domain 2` ), which makes it difficult for models to maintain generalization ability and results in a severe performance drop. + +
+MDIL-SS +
+ +Many algorithms have been proposed to solve the class increment problem in domain shift scenario. However, such algorithms lack a unified testing environment, which is not conducive to comparing algorithms. In some cases, new algorithms are only tested on certain datasets, which is not rigorous. + +In this context, it is necessary to develop an algorithm evaluation system that provides standardized testing for class-incremental learning algorithms, which is increasingly widely used in the industry, and evaluates the effectiveness of these algorithms. + +[KubeEdge-Ianvs](https://github.com/kubeedge/ianvs) is a distributed collaborative AI benchmarking project which can perform benchmarks with respect to several types of paradigms (e.g. single-task learning, incremental learning, etc.). This project aims to leverage the benchmarking capabilities of Ianvs to develop an evaluation system for class-incremental learning algorithms, in order to fulfill the benchmarking requirements specific to this type of algorithm. + +### 1.2 Goals + +This project aims to build a benchmarking for class-incremental learning in domain shift scenario on KubeEdge-Ianvs, which includes: + - Reproduce the Multi-Domain Incremental Learning for Semantic Segmentation (MDIL-SS) algorithm proposed in the [WACV2022 paper](https://github.com/prachigarg23/MDIL-SS). + - Use three datasets (including Cityscapes, SYNTHIA, and the Cloud-Robotic dataset provided by KubeEdge SIG AI) to conduct benchmarking tests and generate a comprehensive test report (including rankings, time, algorithm name, dataset, and test metrics, among other details). + +## 2 Proposal + +`Implementation of a Class Incremental Learning Algorithm Evaluation System based on Ianvs` taking MDIL-SS algorithm as an example, aims to test the performance of class-incremental learning models following benchmarking standards, to make the development more efficient and productive. + +The scope of the system includes + +- A test case for class-incremental learning semantic segmentation algorithms, in which a test report can be successfully generated following instructions. +- Easy to expand, allowing users to seamlessly integrate existing algorithms into the system for testing. + +Targeting users include + +- Beginners: Familiarize with distributed synergy AI and lifelong learning, among other concepts. +- Developers: Quickly integrate class-increment algorithms into Ianvs and test the performance for further optimization. + + +## 3 Design Details + +### 3.1 Overall Design + + +First, let's introduce the training process for lifelong learning: +- When the model enters a data domain (e.g., the Cloud-Robotic dataset) for training, the first step is `Unseen Task Detection`. This is done to assess the value of the domain data for updating the model. We consider that samples of unknown tasks have learning value, while samples from known tasks are not worth annotating and using for model training, meaning they provide no value for model updating. +- Once samples from unknown tasks are detected, in order to utilize these samples for model updating, we need to label them (through manual labelling or assisted labelling algorithms). +- Finally, we utilize samples with `Unseen Task Processing`, which means updating the model using labelled samples from unseen tasks. +- In summary, the workflow is `Unseen Task Detection -> Labeling -> Unseen Task Processing`. 
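+
+The workflow above can be condensed into a short sketch. This is only an illustration: `detect_unseen`, `label` and `process_unseen` are hypothetical callables standing in for the corresponding Ianvs lifelong-learning modules, not real APIs.
+
+```python
+def lifelong_update(model, domain_samples, detect_unseen, label, process_unseen):
+    """One pass of the workflow for a newly arrived data domain."""
+    # 1. Unseen Task Detection: keep only the samples the current model cannot handle.
+    unseen = detect_unseen(model, domain_samples)
+    if not unseen:
+        # every task is already known, so this domain brings no value for model updating
+        return model
+    # 2. Labeling: annotate the unseen samples (manually or with assisted labeling).
+    labeled = label(unseen)
+    # 3. Unseen Task Processing: update the model with the labeled unseen samples.
+    return process_unseen(model, labeled)
+```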
+ +Next, let's outline the class-incremental learning algorithm evaluating process for this project, as shown in the figure below: + +![MDIL-SS](images/OSPP_MDIL-SS_11.png) +- In the first round, an ERFNet model initially enters the Synthia data domain for learning. As this is the model's first encounter with this data domain, all samples are considered as unseen task samples. We manually label all of these samples and utilize them for the model's first update, meaning that the model is trained using the Synthia dataset. +- During the testing phase of the first round, some of the testing samples are seen by the model, which is Synthia, while the other samples are unseen. +- Before the start of the second round of training, we manually detect and label all Cityscapes samples. Since unseen task detection and labeling are crucial for training effectiveness, we employ the most reliable method, which is manual way. +- In the second round, the model is updated using the already labeled Cityscapes samples, meaning it is trained with the Cityscapes dataset. +- Similarly, before commencing the final round of training, we manually detect and label all Cloud-Robotic samples. These samples are then used to update the model in the third round, where the model is trained on the Cloud-Robotic dataset. +- In each round, we conduct testing using the same three datasets for lifelong learning metric, and the final test report will demonstrate the model's lifelong learning capabilities. +- Please note that the lifelong learning model will undergo training and testing across three successive data domains, namely Cityscapes, SYNTHIA, and Cloud-Robotics, comprising a total of three rounds of training and testing. As the model shifts among data domains, class also changes, signifying class increments. The following diagram illustrates the differences of classes among these three data domains. + +
+MDIL-SS +
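+
+The round-by-round process described above can be sketched as follows. The function names are placeholders (the real orchestration is handled by the Ianvs lifelong-learning paradigm); what matters is the training order of the domains and the seen/unseen split in each round, which mirrors how the per-round accuracies in the benchmark log are obtained.
+
+```python
+DOMAINS = ["SYNTHIA", "Cityscapes", "Cloud-Robotics"]  # training order of the three rounds
+
+def run_rounds(model, train_sets, test_sets, train_fn, eval_fn):
+    accuracy_matrix = []  # one row per round, evaluated on all three test sets
+    for r, domain in enumerate(DOMAINS, start=1):
+        # Unseen Task Processing: update the model with the labeled samples of this domain.
+        model = train_fn(model, train_sets[domain])
+        # Test on all three datasets; the first r domains are "seen", the rest are "unseen".
+        row = [eval_fn(model, test_sets[d]) for d in DOMAINS]
+        seen = sum(row[:r]) / r
+        unseen = sum(row[r:]) / (len(DOMAINS) - r) if r < len(DOMAINS) else 0.0
+        print(f"round {r}: seen class accuracy {seen:.3f}, unseen class accuracy {unseen:.3f}")
+        accuracy_matrix.append(row)
+    return accuracy_matrix
+```
+
+The rows collected in `accuracy_matrix` form the kind of per-round accuracy matrix from which the BWT and FWT scores in `core/testcasecontroller/metrics/metrics.py` are later computed.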
+ + +In this project, we have maintained a relatively default setup for Unseen Task Detection and Labelling. Our primary focus lies on the `Unseen Task Processing`, which corresponds to the red-boxed section in the ianvs lifelong learning architecture: + +![MDIL-SS](images/OSPP_MDIL-SS_9.png ) + +The architecture diagram for the project is as follows: + +![MDIL-SS](images/OSPP_MDIL-SS_8.png) + +All in all, we use the three class-different datasets to conduct training (i.e., Unseen Task Processing) and testing, and the core concern is to test the ability of the algorithm to update the model using labelled unseen samples (i.e., evaluating `Unseen Task Processing` ability). + +### 3.2 Datasets + +This project will use three datasets, namely **Cityscapes**, **SYNTHIA**, and KubeEdge SIG AI's **Cloud-Robotics** dataset (**CS**, **SYN**, **CR**). + +Ianvs has already provides [Cityscapes and SYNTHIA datasets](https://github.com/kubeedge/ianvs/blob/main/docs/proposals/algorithms/lifelong-learning/Additional-documentation/curb_detetion_datasets.md). The following two images are examples from them respectively. + +| CS Example | SYN Example | +| :----------------------------------------------------------: | :----------------------------------------------------------: | +| ![MDIL-SS](images/OSPP_MDIL-SS_1.png) |![MDIL-SS](images/OSPP_MDIL-SS_2.png) | + +In addition, this project utilizes the CR dataset from KubeEdge. + +| CR Example | +| :----------------------------------------------------------: | +| ![MDIL-SS](images/OSPP_MDIL-SS_3.png) | + +The following code is an excerpt from the `train-index-mix.txt` file. The first column represents the path to the original image, and the second column represents the corresponding label image path. + +```txt +rgb/train/20220420_garden/00480.png gtFine/train/20220420_garden/00480_TrainIds.png +rgb/train/20220420_garden/00481.png gtFine/train/20220420_garden/00481_TrainIds.png +rgb/train/20220420_garden/00483.png gtFine/train/20220420_garden/00483_TrainIds.png +``` + +The following code snippet is an excerpt from the `test-index.txt` file, which follows a similar format to the training set. + +```txt +rgb/test/20220420_garden/01357.png gtFine/test/20220420_garden/01357_TrainIds.png +rgb/test/20220420_garden/01362.png gtFine/test/20220420_garden/01362_TrainIds.png +rgb/test/20220420_garden/01386.png gtFine/test/20220420_garden/01386_TrainIds.png +rgb/test/20220420_garden/01387.png gtFine/test/20220420_garden/01387_TrainIds.png +``` + +As shown in the table below, this dataset contains 7 groups and 30 classes. + +| Group | Classes | +| :----------: | :----------------------------------------------------------: | +| flat | road · sidewalk · ramp · runway | +| human | person · rider | +| vehicle | car · truck · bus · train · motorcycle · bicycle | +| construction | building · wall · fence · stair · curb · flowerbed · door | +| object | pole · traffic sign · traffic light · CCTV camera · Manhole · hydrant · belt · dustbin | +| nature | vegetation · terrain | +| sky | sky | + +More detail about CR dataset please refer to [this link](https://github.com/kubeedge/ianvs/blob/main/docs/proposals/scenarios/Cloud-Robotics/Cloud-Robotics_zh.md). + +### 3.3 File-level Design + +The development consists of two main parts, which are **test environment (test env)** and **test algorithms**. + +Test environment can be understood as an exam paper, which specifies the dataset, evaluation metrics, and the number of increments used for testing. 
It is used to evaluate the performance of the "students". And test algorithms can be seen as the students who will take the exam. + +
+MDIL-SS +
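+
+In code, a "student" is a base model class that Ianvs drives through a handful of entry points. The skeleton below is only a sketch: the method names and signatures mirror the `Model` class in the `basemodel.py` added later in this PR, while the bodies here are placeholders.
+
+```python
+class Model:
+    def __init__(self, **kwargs):
+        # hyperparameters from the algorithm YAML (e.g. learning_rate, epochs) arrive here
+        self.learning_rate = kwargs.get("learning_rate", 1e-4)
+        self.epochs = kwargs.get("epochs", 2)
+
+    def train(self, train_data, valid_data=None, **kwargs):
+        """Update the network on labeled samples and return a checkpoint path."""
+
+    def predict(self, data, **kwargs):
+        """Run inference and return per-sample predictions."""
+
+    def evaluate(self, data, **kwargs):
+        """Score predictions against the ground truth with the configured metric."""
+
+    def load(self, model_url, **kwargs):
+        """Restore weights from a previous round's checkpoint."""
+
+    def save(self, model_path=None):
+        """Persist the current weights for the next round."""
+```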
+ +In addition, `benchmarkingjob.yaml` is used for integrating the configuration of test env and test algorithms, and is a necessary ianvs configuration file. + +For test env, the development work mainly focuses on the implementation of `mIoU.py`. And for test algorithms, development is concentrated on `basemodel.py`, as shown in the picture below. + +![MDIL-SS](images/OSPP_MDIL-SS_5.png) + +#### 3.3.1 Test Environment + +The following code is the `testenv.yaml` file designed for this project. + +As a configuration file for test env, it contains the 3 aspects, which are the dataset and the number of increments, model validation logic, and model evaluation metrics. + +```yaml +# testenv.yaml + +testenv: + + # 1 + dataset: + train_url: "/home/QXY/ianvs/dataset/mdil-ss-dataset/train_data/index.txt" + test_url: "/home/QXY/ianvs/dataset/mdil-ss-dataset/test_data/index.txt" + using: "CS SYN CR" + incremental_rounds: 3 + + # 2 + model_eval: + model_metric: + name: "mIoU" + url: "/home/QXY/ianvs/examples/mdil-ss/testenv/mIoU.py" + threshold: 0 + operator: ">=" + + # 3 + metrics: + - name: "mIoU" + url: "/home/QXY/ianvs/examples/mdil-ss/testenv/mIoU.py" + - name: "BWT" + - name: "FWT" +``` + +After each round of lifelong learning, the model will be evaluated on the validation set. In this project, **mIoU** (mean Intersection over Union) is used as the evaluation metric. If the model achieves an mIoU greater than the specified threshold on the validation set, the model will be updated. + +**BWT** (Backward Transfer) and **FWT** (Forward Transfer) are two important concepts in the field of lifelong learning. BWT refers to the impact of previously learned knowledge on the learning of the current task, while FWT refers to the impact of the current task on the learning of future tasks. Along with mIoU, they serve as testing metrics to assess the lifelong learning capability of the model in semantic segmentation. Functions related to BWT and FWT have already been implemented in [Ianvs repository](https://github.com/kubeedge/ianvs/blob/main/core/testcasecontroller/metrics/metrics.py). + +#### 3.3.2 Test Algorithm + +The following code is the `mdil-ss_algorithm.yaml` file designed for this project. + +```yaml +# mdil-ss_algorithm.yaml + +algorithm: + paradigm_type: "incrementallearning" + + incremental_learning_data_setting: + train_ratio: 0.8 + splitting_method: "default" + + modules: + - type: "basemodel" + + # 1 + name: "ERFNet" + url: "/home/QXY/ianvs/examples/mdil-ss/testalgorithms/mdil-ss/basemodel.py" + + # 2 + hyperparameters: + - learning_rate: + values: + - 0.01 + - 0.0001 + - epochs: + values: + - 5 + - 10 + - batch_size: + values: + - 10 + - 20 +``` + +First, `basemodel.py`, which involves encapsulating various functional components of the model, including its architecture, layers, and operations, which is the focus of development. + +Second, **hyperparameters** setting for the model is also defined in this yaml file. In addition, the evaluation system can perform tests with multiple combinations of hyperparameters at once by configuring multiple hyperparameters in `mdil-ss_algorithm.yaml`. + +#### 3.3.3 Test Report + +The test report is designed as follows, which contains the ranking, algorithm name, three metrics, dataset name, base model, three hyperparameters, and time. 
+ +| Rank | Algorithm | mIoU_Overall | BWT | FWT | Paradigm | Round | Dataset | Basemodel | Learning_rate | Epoch | Batch_size | Time | +| :-------: | :-------: | :------: | :-----: | :-----: | :----------------: | :-----: | :--------------: | :---------: | :-------------: | :-----: | :----------: | :-------------------: | +| 1 | MDIL-SS | 0.8734 | 0.075 | 0.021 | Lifelonglearning | 3 | CS SYN CR | ERFNet | 0.0001 | 1 | 10 | 2023-05-28 17:05:15 | + +## 4 Roadmap + +### 4.1 Phase 1 (July 1st - August 15th) + +- Engage in discussions with the project mentor and the community to finalize the development details. + +- Further refine the workflow of the MDIL-SS testing task, including the relationships between different components and modules. + +- Develop the test environment, including datasets and model metrics. + +- Begin the development of the base model encapsulation for the test algorithms. + +### 4.2 Phase 2 (August 16th - September 30th) + +- Summarize the progress of Phase 1 and generate relevant documentation. + +- Complete the remaining development tasks, including models, test reports, etc. + +- Generate initial algorithm evaluation reports. + +- Engage in discussions with the project mentor and the community to further supplement and improve the project. + +- Organize the project code and related documentation, and merge them into the Ianvs repository. + +- Upon merging into the repository, explore new research areas and produce additional outcomes based on this project. diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_1.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_1.png new file mode 100644 index 00000000..26945fcb Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_1.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_10.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_10.png new file mode 100644 index 00000000..2e88f609 Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_10.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_11.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_11.png new file mode 100644 index 00000000..3a5fab3c Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_11.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_2.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_2.png new file mode 100644 index 00000000..02caae49 Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_2.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_3.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_3.png new file mode 100644 index 00000000..fac3280c Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_3.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_4.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_4.png new file mode 100644 index 00000000..62458079 Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_4.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_5.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_5.png new file mode 100644 index 00000000..a4364771 
Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_5.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_7.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_7.png new file mode 100644 index 00000000..624fab5e Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_7.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_8.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_8.png new file mode 100644 index 00000000..ec0c2db1 Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_8.png differ diff --git a/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_9.png b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_9.png new file mode 100644 index 00000000..8b3ea379 Binary files /dev/null and b/docs/proposals/algorithms/lifelong-learning/images/OSPP_MDIL-SS_9.png differ diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/README.md b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/README.md new file mode 100644 index 00000000..5eb1181f --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/README.md @@ -0,0 +1,112 @@ +# Quick Start about Class Incremental Semantic Segmentation + +Welcome to Ianvs! Ianvs aims to test the performance of distributed synergy AI solutions following recognized standards, +in order to facilitate more efficient and effective development. This semantic segmentation scenario quick start guides you how to test your class incremental algorithm on Ianvs. You can reduce manual procedures to just a few steps so that you can +build and start your distributed synergy AI solution development within minutes. + +Before using Ianvs, you might want to have the device ready: +- One machine is all you need, i.e., a laptop or a virtual machine is sufficient and a cluster is not necessary +- 2 CPUs or more +- 4GB+ free memory, depends on algorithm and simulation setting +- 10GB+ free disk space +- Internet connection for GitHub and pip, etc +- Python 3.6+ installed + + +In this example, we are using the Linux platform with Python 3.8. If you are using Windows, most steps should still apply but a few like commands and package requirements might be different. + +## Step 1. Ianvs Preparation + +First, we download the code of Ianvs. Assuming that we are using `/ianvs` as workspace, Ianvs can be cloned with `Git` +as: + +``` shell +mkdir /ianvs +cd /ianvs # One might use another path preferred + +mkdir project +cd project +git clone https://github.com/kubeedge/ianvs.git +``` + + +Then, we install third-party dependencies for ianvs. +``` shell +sudo apt-get update +sudo apt-get install libgl1-mesa-glx -y +python -m pip install --upgrade pip + +cd ianvs +python -m pip install ./examples/resources/third_party/* +python -m pip install -r requirements.txt +``` + +We are now ready to install Ianvs. +``` shell +python setup.py install +``` + +## Step 2. Dataset Preparation + +Datasets and models can be large. To avoid over-size projects in the Github repository of Ianvs, the Ianvs code base does +not include origin datasets. Then developers do not need to download non-necessary datasets for a quick start. 
+ +``` shell +mkdir dataset +cd dataset +unzip mdil-ss.zip +``` + +The URL address of this dataset then should be filled in the configuration file ``testenv.yaml``. In this quick start, +we have done that for you and the interested readers can refer to [testenv.yaml](https://ianvs.readthedocs.io/en/latest/guides/how-to-test-algorithms.html#step-1-test-environment-preparation) for more details. + + +Related algorithm is also ready in this quick start. + +``` shell +export PYTHONPATH=$PYTHONPATH:/ianvs/project/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet +``` + +The URL address of this algorithm then should be filled in the configuration file ``algorithm.yaml``. In this quick +start, we have done that for you and the interested readers can refer to [algorithm.yaml](https://ianvs.readthedocs.io/en/latest/guides/how-to-test-algorithms.html#step-1-test-environment-preparation) for more details. + + +## Step 3. Ianvs Execution and Presentation + +We are now ready to run the ianvs for benchmarking. + +``` shell +cd /ianvs/project +ianvs -f examples/class_increment_semantic_segmentation/lifelong_learning_bench/benchmarkingjob.yaml +``` + +Finally, the user can check the result of benchmarking on the console and also in the output path( +e.g. `/ianvs/project/ianvs-workspace/mdil-ss/lifelong_learning_bench`) defined in the benchmarking config file ( +e.g. `benchmarkingjob.yaml`). In this quick start, we have done all configurations for you and the interested readers +can refer to [benchmarkingJob.yaml](https://ianvs.readthedocs.io/en/latest/guides/how-to-test-algorithms.html#step-1-test-environment-preparation) for more details. + +The final output might look like this: + +| rank | algorithm | Task_Avg_Acc | BWT | FWT | paradigm | basemodel | task_definition | task_allocation | basemodel-learning_rate | basemodel-epochs | task_definition-origins | task_allocation-origins | time | url | +|:----:|:------------------------:|:--------------------:|:--------------------:|:--------------------:|:----------------:|:---------:|:----------------------:|:----------------------:|:-----------------------:|:----------------:|:-----------------------------------------:|:-----------------------------------------:|:-------------------:|:-------------------------------------------------------------------------------------------------------------------------------:| +| 1 | erfnet_lifelong_learning | 0.027414088670437726 | 0.010395591126145793 | 0.002835451693721201 | lifelonglearning | BaseModel | TaskDefinitionByDomain | TaskAllocationByDomain | 0.0001 | 1 | ['Cityscapes', 'Synthia', 'Cloud-Robotics'] | ['Cityscapes', 'Synthia', 'Cloud-Robotics'] | 2023-09-26 20:13:21 | ./ianvs-workspace/mdil-ss/lifelong_learning_bench/benchmarkingjob/erfnet_lifelong_learning/3a8c73ba-5c64-11ee-8ebd-b07b25dd6922 | + + +In addition, in the log displayed at the end of the test, you can see the accuracy of known and unknown tasks in each round, as shown in the table below (in the testing phase of round 3, all classes are seen). + + +| Round | Seen Class Accuracy | Unseen Class Accuracy | +|:-----:|:---------------------:|:-------------------:| +| 1 | 0.176 | 0.0293 | +| 2 | 0.203 | 0.0265 | +| 3 | 0.311 | 0.0000 | + + + +This ends the quick start experiment. + +# What is next + +If any problems happen, the user can refer to [the issue page on Github](https://github.com/kubeedge/ianvs/issues) for help and are also welcome to raise any new issue. + +Enjoy your journey on Ianvs! 
\ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/benchmarkingjob.yaml b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/benchmarkingjob.yaml new file mode 100644 index 00000000..4eaa1cfe --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/benchmarkingjob.yaml @@ -0,0 +1,72 @@ +benchmarkingjob: + # job name of bechmarking; string type; + name: "benchmarkingjob" + # the url address of job workspace that will reserve the output of tests; string type; + workspace: "./ianvs-workspace/mdil-ss/lifelong_learning_bench" + + # the url address of test environment configuration file; string type; + # the file format supports yaml/yml; + testenv: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/testenv.yaml" + + # the configuration of test object + test_object: + # test type; string type; + # currently the option of value is "algorithms",the others will be added in succession. + type: "algorithms" + # test algorithm configuration files; list type; + algorithms: + # algorithm name; string type; + - name: "erfnet_lifelong_learning" + # the url address of test algorithm configuration file; string type; + # the file format supports yaml/yml + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/test_algorithm.yaml" + + # the configuration of ranking leaderboard + rank: + # rank leaderboard with metric of test case's evaluation and order ; list type; + # the sorting priority is based on the sequence of metrics in the list from front to back; + sort_by: [ { "accuracy": "descend" }, { "BWT": "descend" } ] + + # visualization configuration + visualization: + # mode of visualization in the leaderboard; string type; + # There are quite a few possible dataitems in the leaderboard. Not all of them can be shown simultaneously on the screen. + # In the leaderboard, we provide the "selected_only" mode for the user to configure what is shown or is not shown. + mode: "selected_only" + # method of visualization for selected dataitems; string type; + # currently the options of value are as follows: + # 1> "print_table": print selected dataitems; + method: "print_table" + + # selected dataitem configuration + # The user can add his/her interested dataitems in terms of "paradigms", "modules", "hyperparameters" and "metrics", + # so that the selected columns will be shown. 
+ selected_dataitem: + # currently the options of value are as follows: + # 1> "all": select all paradigms in the leaderboard; + # 2> paradigms in the leaderboard, e.g., "singletasklearning" + paradigms: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all modules in the leaderboard; + # 2> modules in the leaderboard, e.g., "basemodel" + modules: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all hyperparameters in the leaderboard; + # 2> hyperparameters in the leaderboard, e.g., "momentum" + hyperparameters: [ "all" ] + # currently the options of value are as follows: + # 1> "all": select all metrics in the leaderboard; + # 2> metrics in the leaderboard, e.g., "F1_SCORE" + metrics: [ "accuracy", "BWT", "FWT"] + + # model of save selected and all dataitems in workspace `./rank` ; string type; + # currently the options of value are as follows: + # 1> "selected_and_all": save selected and all dataitems; + # 2> "selected_only": save selected dataitems; + save_mode: "selected_and_all" + + + + + + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/accuracy.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/accuracy.py new file mode 100644 index 00000000..51663185 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/accuracy.py @@ -0,0 +1,38 @@ +from basemodel import val_args +from utils.metrics import Evaluator +from tqdm import tqdm +from dataloaders import make_data_loader +from sedna.common.class_factory import ClassType, ClassFactory + +__all__ = ('accuracy') + +@ClassFactory.register(ClassType.GENERAL) +def accuracy(y_true, y_pred, **kwargs): + args = val_args() + _, _, test_loader, num_class = make_data_loader(args, test_data=y_true) + evaluator = Evaluator(num_class) + + tbar = tqdm(test_loader, desc='\r') + for i, (sample, img_path) in enumerate(tbar): + if args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if args.cuda: + image, target = image.cuda(args.gpu_ids), target.cuda(args.gpu_ids) + if args.depth: + depth = depth.cuda(args.gpu_ids) + + target[target > evaluator.num_class-1] = 255 + target = target.cpu().numpy() + # Add batch sample into evaluator + evaluator.add_batch(target, y_pred[i]) + + # Test during the training + # Acc = evaluator.Pixel_Accuracy() + CPA = evaluator.Pixel_Accuracy_Class() + mIoU = evaluator.Mean_Intersection_over_Union() + FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() + + print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU)) + return CPA \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/basemodel.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/basemodel.py new file mode 100644 index 00000000..665a0855 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/basemodel.py @@ -0,0 +1,309 @@ +import os +import numpy as np +import torch +from PIL import Image +import argparse +from train import Trainer +from eval import Validator +from tqdm import tqdm +from eval import load_my_state_dict +from utils.metrics import Evaluator +from dataloaders import make_data_loader +from dataloaders 
import custom_transforms as tr +from torchvision import transforms +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.common.config import Context +from sedna.datasources import TxtDataParse +from torch.utils.data import DataLoader +from sedna.common.file_ops import FileOps +from utils.lr_scheduler import LR_Scheduler + +def preprocess(image_urls): + transformed_images = [] + for paths in image_urls: + if len(paths) == 2: + img_path, depth_path = paths + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(depth_path) + else: + img_path = paths[0] + _img = Image.open(img_path).convert('RGB') + _depth = _img + + sample = {'image': _img, 'depth': _depth, 'label': _img} + composed_transforms = transforms.Compose([ + # tr.CropBlackArea(), + # tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + transformed_images.append((composed_transforms(sample), img_path)) + + return transformed_images + +class Model: + def __init__(self, **kwargs): + self.val_args = val_args() + self.train_args = train_args() + + self.train_args.lr = kwargs.get("learning_rate", 1e-4) + self.train_args.epochs = kwargs.get("epochs", 2) + self.train_args.eval_interval = kwargs.get("eval_interval", 2) + self.train_args.no_val = kwargs.get("no_val", True) + self.trainer = None + + label_save_dir = Context.get_parameters("INFERENCE_RESULT_DIR", "./inference_results") + self.val_args.color_label_save_path = os.path.join(label_save_dir, "color") + self.val_args.merge_label_save_path = os.path.join(label_save_dir, "merge") + self.val_args.label_save_path = os.path.join(label_save_dir, "label") + self.validator = Validator(self.val_args) + + def train(self, train_data, valid_data=None, **kwargs): + self.trainer = Trainer(self.train_args, train_data=train_data) + print("Total epoches:", self.trainer.args.epochs) + for epoch in range(self.trainer.args.start_epoch, self.trainer.args.epochs): + if epoch == 0 and self.trainer.val_loader: + self.trainer.validation(epoch) + self.trainer.training(epoch) + + if self.trainer.args.no_val and \ + (epoch % self.trainer.args.eval_interval == (self.trainer.args.eval_interval - 1) + or epoch == self.trainer.args.epochs - 1): + # save checkpoint when it meets eval_interval or the training finished + is_best = False + checkpoint_path = self.trainer.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.trainer.model.state_dict(), + 'optimizer': self.trainer.optimizer.state_dict(), + 'best_pred': self.trainer.best_pred, + }, is_best) + + self.trainer.writer.close() + + return checkpoint_path + + def predict(self, data, **kwargs): + if not isinstance(data[0][0], dict): + data = preprocess(data) + + if type(data) is np.ndarray: + data = data.tolist() + + self.validator.test_loader = DataLoader(data, batch_size=self.val_args.test_batch_size, shuffle=False, + pin_memory=True) + return self.validator.validate() + + def evaluate(self, data, **kwargs): + self.val_args.save_predicted_image = kwargs.get("save_predicted_image", True) + samples = preprocess(data.x) + predictions = self.predict(samples) + return accuracy(data.y, predictions) + + def load(self, model_url, **kwargs): + if model_url: + self.validator.new_state_dict = torch.load(model_url, map_location=torch.device("cpu")) + self.train_args.resume = model_url + else: + raise Exception("model url does not exist.") + self.validator.model = load_my_state_dict(self.validator.model, 
self.validator.new_state_dict['state_dict']) + + def save(self, model_path=None): + # TODO: how to save unstructured data model + pass + +def train_args(): + parser = argparse.ArgumentParser(description="PyTorch ERFNet Training") + parser.add_argument('--depth', action="store_true", default=False, + help='training with depth image or not (default: False)') + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'cityrand', 'target', 'xrlab', 'e1', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--loss-type', type=str, default='ce', + choices=['ce', 'focal'], + help='loss func type (default: ce)') + # training hyper params + # parser.add_argument('--epochs', type=int, default=None, metavar='N', + # help='number of epochs to train (default: auto)') + parser.add_argument('--epochs', type=int, default=None, metavar='N', + help='number of epochs to train (default: auto)') + parser.add_argument('--start_epoch', type=int, default=0, + metavar='N', help='start epochs (default:0)') + parser.add_argument('--batch-size', type=int, default=None, + metavar='N', help='input batch size for \ + training (default: auto)') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--use-balanced-weights', action='store_true', default=False, + help='whether to use balanced weights (default: True)') + parser.add_argument('--num-class', type=int, default=24, + help='number of training classes (default: 24') + # optimizer params + parser.add_argument('--lr', type=float, default=1e-4, metavar='LR', + help='learning rate (default: auto)') + parser.add_argument('--lr-scheduler', type=str, default='cos', + choices=['poly', 'step', 'cos', 'inv'], + help='lr scheduler mode: (default: cos)') + parser.add_argument('--momentum', type=float, default=0.9, + metavar='M', help='momentum (default: 0.9)') + parser.add_argument('--weight-decay', type=float, default=2.5e-5, + metavar='M', help='w-decay (default: 5e-4)') + # cuda, seed and logging + parser.add_argument('--no-cuda', action='store_true', default= + False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + # checking point + parser.add_argument('--resume', type=str, + default=None, + help='put the path to resuming file if needed') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + # finetuning pre-trained models + parser.add_argument('--ft', action='store_true', default=True, + help='finetuning on a different dataset') + # evaluation option + parser.add_argument('--eval-interval', type=int, default=1, + help='evaluation interval (default: 1)') + parser.add_argument('--no-val', action='store_true', default=False, + help='skip validation during training') + + args = parser.parse_args() + args.cuda = not 
args.no_cuda and torch.cuda.is_available() + print(torch.cuda.is_available()) + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + if args.epochs is None: + epoches = { + 'cityscapes': 200, + 'citylostfound': 200, + } + args.epochs = epoches[args.dataset.lower()] + + if args.batch_size is None: + args.batch_size = 4 * len(args.gpu_ids) + + if args.test_batch_size is None: + args.test_batch_size = args.batch_size + + if args.lr is None: + lrs = { + 'cityscapes': 0.0001, + 'citylostfound': 0.0001, + 'cityrand': 0.0001 + } + args.lr = lrs[args.dataset.lower()] / (4 * len(args.gpu_ids)) * args.batch_size + + if args.checkname is None: + args.checkname = 'ERFNet' + print(args) + torch.manual_seed(args.seed) + + return args + +def val_args(): + parser = argparse.ArgumentParser(description="PyTorch RFNet validation") + parser.add_argument('--dataset', type=str, default='cityscapes', + choices=['citylostfound', 'cityscapes', 'xrlab', 'mapillary'], + help='dataset name (default: cityscapes)') + parser.add_argument('--workers', type=int, default=4, + metavar='N', help='dataloader threads') + parser.add_argument('--base-size', type=int, default=1024, + help='base image size') + parser.add_argument('--crop-size', type=int, default=768, + help='crop image size') + parser.add_argument('--batch-size', type=int, default=6, + help='batch size for training') + parser.add_argument('--val-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + validating (default: auto)') + parser.add_argument('--test-batch-size', type=int, default=1, + metavar='N', help='input batch size for \ + testing (default: auto)') + parser.add_argument('--num-class', type=int, default=24, + help='number of training classes (default: 24') + parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training') + parser.add_argument('--gpu-ids', type=str, default='0', + help='use which gpu to train, must be a \ + comma-separated list of integers only (default=0)') + parser.add_argument('--checkname', type=str, default=None, + help='set the checkpoint name') + parser.add_argument('--weight-path', type=str, default="./models/530_exp3_2.pth", + help='enter your path of the weight') + parser.add_argument('--save-predicted-image', action='store_true', default=False, + help='save predicted images') + parser.add_argument('--color-label-save-path', type=str, + default='./test/color/', + help='path to save label') + parser.add_argument('--merge-label-save-path', type=str, + default='./test/merge/', + help='path to save merged label') + parser.add_argument('--label-save-path', type=str, default='./test/label/', + help='path to save merged label') + parser.add_argument('--merge', action='store_true', default=True, help='merge image and label') + parser.add_argument('--depth', action='store_true', default=False, help='add depth image or not') + + args = parser.parse_args() + args.cuda = not args.no_cuda and torch.cuda.is_available() + if args.cuda: + try: + args.gpu_ids = [int(s) for s in args.gpu_ids.split(',')] + except ValueError: + raise ValueError('Argument --gpu_ids must be a comma-separated list of integers only') + + return args + +def accuracy(y_true, y_pred, **kwargs): + args = val_args() + _, _, test_loader, num_class = make_data_loader(args, test_data=y_true) + evaluator = Evaluator(num_class) + + tbar = tqdm(test_loader, desc='\r') + for i, 
(sample, img_path) in enumerate(tbar): + if args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if args.cuda: + image, target = image.cuda(args.gpu_ids), target.cuda(args.gpu_ids) + if args.depth: + depth = depth.cuda(args.gpu_ids) + + target[target > evaluator.num_class-1] = 255 + target = target.cpu().numpy() + # Add batch sample into evaluator + evaluator.add_batch(target, y_pred[i]) + + # Test during the training + # Acc = evaluator.Pixel_Accuracy() + CPA = evaluator.Pixel_Accuracy_Class() + mIoU = evaluator.Mean_Intersection_over_Union() + FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() + + print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU)) + return CPA + +if __name__ == '__main__': + model_path = "/tmp/RFNet/" + if not os.path.exists(model_path): + os.makedirs(model_path) + + p1 = Process(target=exp_train, args=(10,)) + p1.start() + p1.join() diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/__init__.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/__init__.py new file mode 100644 index 00000000..ec1f25bf --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/__init__.py @@ -0,0 +1,116 @@ +from dataloaders.datasets import cityscapes, citylostfound, cityrand, target, xrlab, e1, mapillary +from torch.utils.data import DataLoader + +def make_data_loader(args, train_data=None, valid_data=None, test_data=None, **kwargs): + + if args.dataset == 'cityscapes': + if train_data is not None: + train_set = cityscapes.CityscapesSegmentation(args, data=train_data, split='train') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + else: + train_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + if valid_data is not None: + val_set = cityscapes.CityscapesSegmentation(args, data=valid_data, split='val') + num_class = val_set.NUM_CLASSES + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + else: + val_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + if test_data is not None: + test_set = cityscapes.CityscapesSegmentation(args, data=test_data, split='test') + num_class = test_set.NUM_CLASSES + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + else: + test_loader, num_class = None, cityscapes.CityscapesSegmentation.NUM_CLASSES + + return train_loader, val_loader, test_loader, num_class + + if args.dataset == 'citylostfound': + if args.depth: + train_set = citylostfound.CitylostfoundSegmentation(args, split='train') + val_set = citylostfound.CitylostfoundSegmentation(args, split='val') + test_set = citylostfound.CitylostfoundSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + else: + train_set = citylostfound.CitylostfoundSegmentation_rgb(args, split='train') + val_set = citylostfound.CitylostfoundSegmentation_rgb(args, split='val') + test_set = 
citylostfound.CitylostfoundSegmentation_rgb(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, num_class + if args.dataset == 'cityrand': + train_set = cityrand.CityscapesSegmentation(args, split='train') + val_set = cityrand.CityscapesSegmentation(args, split='val') + test_set = cityrand.CityscapesSegmentation(args, split='test') + custom_set = cityrand.CityscapesSegmentation(args, split='custom_resize') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'target': + train_set = target.CityscapesSegmentation(args, split='train') + val_set = target.CityscapesSegmentation(args, split='val') + test_set = target.CityscapesSegmentation(args, split='test') + custom_set = target.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'xrlab': + train_set = xrlab.CityscapesSegmentation(args, split='train') + val_set = xrlab.CityscapesSegmentation(args, split='val') + test_set = xrlab.CityscapesSegmentation(args, split='test') + custom_set = xrlab.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 'e1': + train_set = e1.CityscapesSegmentation(args, split='train') + val_set = e1.CityscapesSegmentation(args, split='val') + test_set = e1.CityscapesSegmentation(args, split='test') + custom_set = e1.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + if args.dataset == 
'mapillary': + train_set = mapillary.CityscapesSegmentation(args, split='train') + val_set = mapillary.CityscapesSegmentation(args, split='val') + test_set = mapillary.CityscapesSegmentation(args, split='test') + custom_set = mapillary.CityscapesSegmentation(args, split='test') + num_class = train_set.NUM_CLASSES + train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) + val_loader = DataLoader(val_set, batch_size=args.val_batch_size, shuffle=False, **kwargs) + test_loader = DataLoader(test_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + custom_loader = DataLoader(custom_set, batch_size=args.test_batch_size, shuffle=False, **kwargs) + + return train_loader, val_loader, test_loader, custom_loader, num_class + + else: + raise NotImplementedError + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms.py new file mode 100644 index 00000000..ab61821b --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms.py @@ -0,0 +1,237 @@ +import torch +import random +import numpy as np + +from PIL import Image, ImageOps, ImageFilter + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. + """ + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + img = np.array(img).astype(np.float32) + depth = np.array(depth).astype(np.float32) + mask = np.array(mask).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + # mean and std for original depth images, indicate the mean and standard deviation values for original depth images. + mean_depth = 0.12176 + std_depth = 0.09752 + + depth /= 255.0 + depth -= mean_depth + depth /= std_depth + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class ToTensor(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + depth = np.array(depth).astype(np.float32) + mask = np.array(mask).astype(np.float32) + + img = torch.from_numpy(img).float() + depth = torch.from_numpy(depth).float() + mask = torch.from_numpy(mask).float() + + return {'image': img, + 'depth': depth, + 'label': mask} + +class CropBlackArea(object): + """ + crop black area for depth image + """ + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + width, height = img.size + # coordinate of the left, right, top, bottom boundary of the cropping region. 
+ left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + img = img.crop((left, top, right, bottom)) + depth = depth.crop((left, top, right, bottom)) + mask = mask.crop((left, top, right, bottom)) + # resize + img = img.resize((width,height), Image.BILINEAR) + depth = depth.resize((width,height), Image.BILINEAR) + mask = mask.resize((width,height), Image.NEAREST) + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomHorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + depth = depth.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + rotate_degree = random.uniform(-1*self.degree, self.degree) + img = img.rotate(rotate_degree, Image.BILINEAR) + depth = depth.rotate(rotate_degree, Image.BILINEAR) + mask = mask.rotate(rotate_degree, Image.NEAREST) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomGaussianBlur(object): + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur( + radius=random.random())) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class RandomScaleCrop(object): + def __init__(self, base_size, crop_size, fill=0): + self.base_size = base_size + self.crop_size = crop_size + self.fill = fill + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + # random scale (short edge) + short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + depth = depth.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # pad crop + if short_size < self.crop_size: + padh = self.crop_size - oh if oh < self.crop_size else 0 + padw = self.crop_size - ow if ow < self.crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + depth = ImageOps.expand(depth, border=(0, 0, padw, padh), fill=0) # depth多余的部分填0 + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - self.crop_size) + y1 = random.randint(0, h - self.crop_size) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'depth': depth, + 'label': mask} + + +class FixScaleCrop(object): + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + w, h = img.size + if w > h: + oh = self.crop_size + ow = int(1.0 * w * oh / h) + else: + ow = self.crop_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + depth = depth.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + 
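+        # offsets of the top-left corner so that the crop_size window is centered in the resized image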
x1 = int(round((w - self.crop_size) / 2.)) + y1 = int(round((h - self.crop_size) / 2.)) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'depth': depth, + 'label': mask} + +class FixedResize(object): + def __init__(self, size): + self.size = (size, size) # size: (h, w) + + def __call__(self, sample): + img = sample['image'] + depth = sample['depth'] + mask = sample['label'] + + assert img.size == depth.size == mask.size + + img = img.resize(self.size, Image.BILINEAR) + depth = depth.resize(self.size, Image.BILINEAR) + mask = mask.resize(self.size, Image.NEAREST) + + return {'image': img, + 'depth': depth, + 'label': mask} + +class Relabel(object): + def __init__(self, olabel, nlabel): # change trainid label from olabel to nlabel + self.olabel = olabel + self.nlabel = nlabel + + def __call__(self, tensor): + # assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, + # torch.ByteTensor)), 'tensor needs to be LongTensor' + tensor[tensor == self.olabel] = self.nlabel + return tensor \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms_rgb.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms_rgb.py new file mode 100644 index 00000000..e04ef5a3 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/custom_transforms_rgb.py @@ -0,0 +1,230 @@ +import torch +import random +import numpy as np + +from PIL import Image, ImageOps, ImageFilter + +class Normalize(object): + """Normalize a tensor image with mean and standard deviation. + Args: + mean (tuple): means for each channel. + std (tuple): standard deviations for each channel. 
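+        Pixel values are first scaled to [0, 1] (divided by 255) before the mean/std normalization is applied.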
+ """ + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + img = np.array(img).astype(np.float32) + mask = np.array(mask).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + return {'image': img, + 'label': mask} + + +class Normalize_test(object): + def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)): + self.mean = mean + self.std = std + + def __call__(self, sample): + img = sample + img = np.array(img).astype(np.float32) + img /= 255.0 + img -= self.mean + img /= self.std + + return img + + +class ToTensor(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample['image'] + mask = sample['label'] + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + mask = np.array(mask).astype(np.float32) + + img = torch.from_numpy(img).float() + mask = torch.from_numpy(mask).float() + + return {'image': img, + 'label': mask} + +class CropBlackArea(object): + """ + crop black area for depth image + """ + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + width, height = img.size + left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + img = img.crop((left, top, right, bottom)) + mask = mask.crop((left, top, right, bottom)) + # resize + img = img.resize((width,height), Image.BILINEAR) + mask = mask.resize((width,height), Image.NEAREST) + # img = img.resize((512,1024), Image.BILINEAR) + # mask = mask.resize((512,1024), Image.NEAREST) + print(img.size) + + return {'image': img, + 'label': mask} + +class ToTensor_test(object): + """Convert Image object in sample to Tensors.""" + + def __call__(self, sample): + # swap color axis because + # numpy image: H x W x C + # torch image: C X H X W + img = sample + img = np.array(img).astype(np.float32).transpose((2, 0, 1)) + + img = torch.from_numpy(img).float() + + return img + + +class RandomHorizontalFlip(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if random.random() < 0.5: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + mask = mask.transpose(Image.FLIP_LEFT_RIGHT) + + return {'image': img, + 'label': mask} + + +class RandomRotate(object): + def __init__(self, degree): + self.degree = degree + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + rotate_degree = random.uniform(-1*self.degree, self.degree) + img = img.rotate(rotate_degree, Image.BILINEAR) + mask = mask.rotate(rotate_degree, Image.NEAREST) + + return {'image': img, + 'label': mask} + + +class RandomGaussianBlur(object): + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + if random.random() < 0.5: + img = img.filter(ImageFilter.GaussianBlur( + radius=random.random())) + + return {'image': img, + 'label': mask} + + +class RandomScaleCrop(object): + def __init__(self, base_size, crop_size, fill=0): + self.base_size = base_size + self.crop_size = crop_size + self.fill = fill + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + # random scale (short edge) + short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0)) + w, h = img.size + if h > w: + ow = short_size + oh = int(1.0 * h * ow / w) + else: + oh = short_size + ow = int(1.0 * w * oh / h) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), 
Image.NEAREST) + # pad crop + if short_size < self.crop_size: + padh = self.crop_size - oh if oh < self.crop_size else 0 + padw = self.crop_size - ow if ow < self.crop_size else 0 + img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0) + mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=self.fill) + # random crop crop_size + w, h = img.size + x1 = random.randint(0, w - self.crop_size) + y1 = random.randint(0, h - self.crop_size) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'label': mask} + + +class FixScaleCrop(object): + def __init__(self, crop_size): + self.crop_size = crop_size + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + w, h = img.size + if w > h: + oh = self.crop_size + ow = int(1.0 * w * oh / h) + else: + ow = self.crop_size + oh = int(1.0 * h * ow / w) + img = img.resize((ow, oh), Image.BILINEAR) + mask = mask.resize((ow, oh), Image.NEAREST) + # center crop + w, h = img.size + x1 = int(round((w - self.crop_size) / 2.)) + y1 = int(round((h - self.crop_size) / 2.)) + img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size)) + + return {'image': img, + 'label': mask} + +class FixedResize(object): + def __init__(self, size): + self.size = (size, size) # size: (h, w) + + def __call__(self, sample): + img = sample['image'] + mask = sample['label'] + + assert img.size == mask.size + + img = img.resize(self.size, Image.BILINEAR) + mask = mask.resize(self.size, Image.NEAREST) + + return {'image': img, + 'label': mask} + +class Relabel(object): + def __init__(self, olabel, nlabel): # change trainid label from olabel to nlabel + self.olabel = olabel + self.nlabel = nlabel + + def __call__(self, tensor): + # assert (isinstance(tensor, torch.LongTensor) or isinstance(tensor, + # torch.ByteTensor)), 'tensor needs to be LongTensor' + tensor[tensor == self.olabel] = self.nlabel + return tensor \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/__init__.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/citylostfound.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/citylostfound.py new file mode 100644 index 00000000..ff46a6d9 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/citylostfound.py @@ -0,0 +1,273 @@ +import os +import numpy as np +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr +from dataloaders import custom_transforms_rgb as tr_rgb + +class CitylostfoundSegmentation(data.Dataset): + NUM_CLASSES = 20 + + def __init__(self, args, root=Path.db_root_dir('citylostfound'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + 
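+        # disparity maps serve as the depth channel that accompanies each RGB image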
self.disparities_base = os.path.join(self.root,'disparity',self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix= '.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix= '.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, + suffix='labelTrainIds.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + lbl_path = self.labels[self.split][index].rstrip() + + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) + if self.split == 'train': + if index < 1036: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + elif self.split == 'val': + if index < 1203: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + _target = Image.fromarray(_tmp) + + sample = {'image': _img, 'depth': _depth, 'label': _target} + + # data augment + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample) + + + def relabel_lostandfound(self, input): + input = tr.Relabel(0, self.ignore_index)(input) # background->255 ignore + input = tr.Relabel(1, 0)(input) # road 1->0 + input = tr.Relabel(2, 19)(input) # obstacle 19 + return input + + def recursive_glob(self, rootdir='.', suffix=None): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + if isinstance(suffix, str): + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + elif isinstance(suffix, list): + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for x in suffix for filename in filenames if filename.startswith(x)] + + + def transform_tr(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # help standardize the pixel values to have a mean of (0, 0, 0) and a standard deviation of (1, 1, 1). 
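+            # note: the values below are the ImageNet channel mean and std, not (0, 0, 0) / (1, 1, 1)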
+ tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + +class CitylostfoundSegmentation_rgb(data.Dataset): + NUM_CLASSES = 19 + + def __init__(self, args, root=Path.db_root_dir('citylostfound'), split="train"): + + self.root = root + self.split = split + self.args = args + self.files = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.files[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.files[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='labelTrainIds.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.files[split]: + raise Exception("No files for split=[%s] found in %s" % (split, self.images_base)) + + print("Found %d %s images" % (len(self.files[split]), split)) + + def __len__(self): + return len(self.files[self.split]) + + def __getitem__(self, index): + + img_path = self.files[self.split][index].rstrip() + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _tmp = np.array(Image.open(lbl_path), dtype=np.uint8) + if self.split == 'train': + if index < 1036: # threshold for lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + elif self.split == 'val': + if index < 1203: # lostandfound + _tmp = self.relabel_lostandfound(_tmp) + else: # cityscapes + pass + _target = Image.fromarray(_tmp) + + sample = {'image': _img, 'label': _target} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample) + + + def relabel_lostandfound(self, input): + input = tr.Relabel(0, self.ignore_index)(input) + input = tr.Relabel(1, 0)(input) # road 1->0 + input = tr.Relabel(2, 19)(input) # obstacle 19 + return input + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr_rgb.CropBlackArea(), + tr_rgb.RandomHorizontalFlip(), + tr_rgb.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr_rgb.CropBlackArea(), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + 
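+        # test-time pipeline: fixed resize to the configured crop size, ImageNet normalization, tensor conversion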
composed_transforms = transforms.Compose([ + tr_rgb.FixedResize(size=self.args.crop_size), + tr_rgb.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr_rgb.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CitylostfoundSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityrand.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityrand.py new file mode 100644 index 00000000..74eddb67 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityrand.py @@ -0,0 +1,151 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 19 + + def __init__(self, args, root=Path.db_root_dir('cityrand'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='TrainIds.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = 
self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityscapes.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityscapes.py new file mode 100644 index 00000000..19b9f51a --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/cityscapes.py @@ -0,0 +1,156 @@ +import os +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" 
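+# keeping OpenMP/MKL single-threaded helps avoid CPU oversubscription when several dataloader workers run in parallel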
+import numpy as np +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 30 # 25 + + def __init__(self, args, root=Path.db_root_dir('cityscapes'), data=None, split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.disparities_base = os.path.join(self.root, self.split, "depth", "cityscapes_real") + self.images[split] = [img[0] for img in data.x] if hasattr(data, "x") else data + + + if hasattr(data, "x") and len(data.x[0]) == 1: + # TODO: fit the case that depth images don't exist. + self.disparities[split] = self.images[split] + elif hasattr(data, "x") and len(data.x[0]) == 2: + self.disparities[split] = [img[1] for img in data.x] + else: + if len(data[0]) == 2: + self.images[split] = [img[0] for img in data] + self.disparities[split] = [img[1] for img in data] + elif len(data[0]) == 1: + self.images[split] = [img[0] for img in data] + self.disparities[split] = [img[0] for img in data] + else: + self.images[split] = data + self.disparities[split] = data + + self.labels[split] = data.y if hasattr(data, "y") else data + + self.ignore_index = 255 + + if len(self.images[split]) == 0: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if len(self.disparities[split]) == 0: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return 
composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/e1.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/e1.py new file mode 100644 index 00000000..40e06e98 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/e1.py @@ -0,0 +1,151 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('e1'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = 
Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + #tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/mapillary.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/mapillary.py new file mode 100644 index 00000000..d665649b --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/mapillary.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import 
transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('mapillary'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils 
import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 768 + args.crop_size = 768 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/target.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/target.py new file mode 100644 index 00000000..739e85f8 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/target.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 24 + + def __init__(self, args, root=Path.db_root_dir('target'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='TrainIds.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': 
_img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/xrlab.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/xrlab.py new file mode 100644 index 00000000..4b261fcd --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/datasets/xrlab.py @@ -0,0 +1,152 @@ +import os +import numpy as np +import scipy.misc as m +from PIL import Image +from torch.utils import data +from mypath import Path +from torchvision import transforms +from dataloaders import custom_transforms as tr + +class CityscapesSegmentation(data.Dataset): + NUM_CLASSES = 25 + + def __init__(self, args, root=Path.db_root_dir('xrlab'), split="train"): + + self.root = root + self.split = split + self.args = args + self.images = {} + 
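+        # per-split lists of RGB image, disparity and label file paths, filled below via recursive_glob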
self.disparities = {} + self.labels = {} + + self.images_base = os.path.join(self.root, 'leftImg8bit', self.split) + self.disparities_base = os.path.join(self.root, 'disparity', self.split) + self.annotations_base = os.path.join(self.root, 'gtFine', self.split) + + self.images[split] = self.recursive_glob(rootdir=self.images_base, suffix='.png') + self.images[split].sort() + + self.disparities[split] = self.recursive_glob(rootdir=self.disparities_base, suffix='.png') + self.disparities[split].sort() + + self.labels[split] = self.recursive_glob(rootdir=self.annotations_base, suffix='.png') + self.labels[split].sort() + + + self.ignore_index = 255 + + if not self.images[split]: + raise Exception("No RGB images for split=[%s] found in %s" % (split, self.images_base)) + if not self.disparities[split]: + raise Exception("No depth images for split=[%s] found in %s" % (split, self.disparities_base)) + + print("Found %d %s RGB images" % (len(self.images[split]), split)) + print("Found %d %s disparity images" % (len(self.disparities[split]), split)) + + + def __len__(self): + return len(self.images[self.split]) + + def __getitem__(self, index): + + img_path = self.images[self.split][index].rstrip() + disp_path = self.disparities[self.split][index].rstrip() + #print(index) + try: + lbl_path = self.labels[self.split][index].rstrip() + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + _target = Image.open(lbl_path) + sample = {'image': _img,'depth':_depth, 'label': _target} + except: + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(disp_path) + sample = {'image': _img,'depth':_depth, 'label': _img} + + if self.split == 'train': + return self.transform_tr(sample) + elif self.split == 'val': + return self.transform_val(sample), img_path + elif self.split == 'test': + return self.transform_ts(sample), img_path + elif self.split == 'custom_resize': + return self.transform_ts(sample), img_path + + + def recursive_glob(self, rootdir='.', suffix=''): + """Performs recursive glob with given suffix and rootdir + :param rootdir is the root directory + :param suffix is the suffix to be searched + """ + return [os.path.join(looproot, filename) + for looproot, _, filenames in os.walk(rootdir) + for filename in filenames if filename.endswith(suffix)] + + def transform_tr(self, sample): + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.RandomHorizontalFlip(), + tr.RandomScaleCrop(base_size=self.args.base_size, crop_size=self.args.crop_size, fill=255), + # tr.RandomGaussianBlur(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_val(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + + def transform_ts(self, sample): + + composed_transforms = transforms.Compose([ + tr.CropBlackArea(), + #tr.FixedResize(size=self.args.crop_size), + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + return composed_transforms(sample) + +if __name__ == '__main__': + from dataloaders.utils import decode_segmap + from torch.utils.data import DataLoader + import matplotlib.pyplot as plt + import argparse + + parser = argparse.ArgumentParser() + args = parser.parse_args() + args.base_size = 513 + args.crop_size = 513 + + cityscapes_train = CityscapesSegmentation(args, 
split='train') + + dataloader = DataLoader(cityscapes_train, batch_size=2, shuffle=True, num_workers=2) + + for ii, sample in enumerate(dataloader): + for jj in range(sample["image"].size()[0]): + img = sample['image'].numpy() + gt = sample['label'].numpy() + tmp = np.array(gt[jj]).astype(np.uint8) + segmap = decode_segmap(tmp, dataset='cityscapes') + img_tmp = np.transpose(img[jj], axes=[1, 2, 0]) + img_tmp *= (0.229, 0.224, 0.225) + img_tmp += (0.485, 0.456, 0.406) + img_tmp *= 255.0 + img_tmp = img_tmp.astype(np.uint8) + plt.figure() + plt.title('display') + plt.subplot(211) + plt.imshow(img_tmp) + plt.subplot(212) + plt.imshow(segmap) + + if ii == 1: + break + + plt.show(block=True) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/utils.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/utils.py new file mode 100644 index 00000000..ef572332 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/dataloaders/utils.py @@ -0,0 +1,244 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + +def decode_seg_map_sequence(label_masks, dataset='pascal'): + rgb_masks = [] + for label_mask in label_masks: + rgb_mask = decode_segmap(label_mask, dataset) + rgb_masks.append(rgb_mask) + rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) # change for val + return rgb_masks + + +def decode_segmap(label_mask, dataset, plot=False): + """Decode segmentation class labels into a color image + Args: + label_mask (np.ndarray): an (M,N) array of integer values denoting + the class label at each spatial location. + plot (bool, optional): whether to show the resulting color image + in a figure. + Returns: + (np.ndarray, optional): the resulting decoded color image. 
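+            Only returned when ``plot`` is False; otherwise the image is displayed with matplotlib.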
+ """ + if dataset == 'pascal' or dataset == 'coco': + n_classes = 21 + label_colours = get_pascal_labels() + elif dataset == 'cityscapes': + n_classes = 19 + label_colours = get_cityscapes_labels() + elif dataset == 'target': + n_classes = 24 + label_colours = get_cityscapes_labels() + elif dataset == 'cityrand': + n_classes = 19 + label_colours = get_cityscapes_labels() + elif dataset == 'citylostfound': + n_classes = 20 + label_colours = get_citylostfound_labels() + elif dataset == 'xrlab': + n_classes = 25 + label_colours = get_cityscapes_labels() + elif dataset == 'e1': + n_classes = 24 + label_colours = get_cityscapes_labels() + elif dataset == 'mapillary': + n_classes = 24 + label_colours = get_cityscapes_labels() + else: + raise NotImplementedError + + r = label_mask.copy() + g = label_mask.copy() + b = label_mask.copy() + for ll in range(0, n_classes): + r[label_mask == ll] = label_colours[ll, 0] + g[label_mask == ll] = label_colours[ll, 1] + b[label_mask == ll] = label_colours[ll, 2] + rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) # change for val + # rgb = torch.ByteTensor(3, label_mask.shape[0], label_mask.shape[1]).fill_(0) + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + # r = torch.from_numpy(r) + # g = torch.from_numpy(g) + # b = torch.from_numpy(b) + + rgb[:, :, 0] = r / 255.0 + rgb[:, :, 1] = g / 255.0 + rgb[:, :, 2] = b / 255.0 + if plot: + plt.imshow(rgb) + plt.show() + else: + return rgb + + +def encode_segmap(mask): + """Encode segmentation label images as pascal classes + Args: + mask (np.ndarray): raw segmentation label image of dimension + (M, N, 3), in which the Pascal classes are encoded as colours. + Returns: + (np.ndarray): class map with dimensions (M,N), where the value at + a given location is the integer denoting the class index. 
+ """ + mask = mask.astype(int) + label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16) + for ii, label in enumerate(get_pascal_labels()): + label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii + label_mask = label_mask.astype(int) + return label_mask + + +def get_cityscapes_labels(): + return np.array([ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [0, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [119, 11, 119], + [128, 64, 64], + [102, 10, 156], + [102, 102, 15], + [10, 102, 156], + [10, 102, 156], + [10, 102, 156], + [10, 102, 156]]) + +def get_citylostfound_labels(): + return np.array([ + [128, 64, 128], + [244, 35, 232], + [70, 70, 70], + [102, 102, 156], + [190, 153, 153], + [153, 153, 153], + [250, 170, 30], + [220, 220, 0], + [107, 142, 35], + [152, 251, 152], + [0, 130, 180], + [220, 20, 60], + [255, 0, 0], + [0, 0, 142], + [0, 0, 70], + [0, 60, 100], + [0, 80, 100], + [0, 0, 230], + [119, 11, 32], + [111, 74, 0]]) + + +def get_pascal_labels(): + """Load the mapping that associates pascal classes with label colors + Returns: + np.ndarray with dimensions (21, 3) + """ + return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0], + [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128], + [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0], + [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128], + [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0], + [0, 64, 128]]) + + +def colormap_bdd(n): + cmap=np.zeros([n, 3]).astype(np.uint8) + cmap[0,:] = np.array([128, 64, 128]) + cmap[1,:] = np.array([244, 35, 232]) + cmap[2,:] = np.array([ 70, 70, 70]) + cmap[3,:] = np.array([102, 102, 156]) + cmap[4,:] = np.array([190, 153, 153]) + cmap[5,:] = np.array([153, 153, 153]) + + cmap[6,:] = np.array([250, 170, 30]) + cmap[7,:] = np.array([220, 220, 0]) + cmap[8,:] = np.array([107, 142, 35]) + cmap[9,:] = np.array([152, 251, 152]) + cmap[10,:]= np.array([70, 130, 180]) + + cmap[11,:]= np.array([220, 20, 60]) + cmap[12,:]= np.array([255, 0, 0]) + cmap[13,:]= np.array([0, 0, 142]) + cmap[14,:]= np.array([0, 0, 70]) + cmap[15,:]= np.array([0, 60, 100]) + + cmap[16,:]= np.array([0, 80, 100]) + cmap[17,:]= np.array([0, 0, 230]) + cmap[18,:]= np.array([119, 11, 32]) + cmap[19,:]= np.array([111, 74, 0]) #多加了一类small obstacle + + return cmap + +def colormap_bdd0(n): + cmap=np.zeros([n, 3]).astype(np.uint8) + cmap[0,:] = np.array([0, 0, 0]) + cmap[1,:] = np.array([70, 130, 180]) + cmap[2,:] = np.array([70, 70, 70]) + cmap[3,:] = np.array([128, 64, 128]) + cmap[4,:] = np.array([244, 35, 232]) + cmap[5,:] = np.array([64, 64, 128]) + + cmap[6,:] = np.array([107, 142, 35]) + cmap[7,:] = np.array([153, 153, 153]) + cmap[8,:] = np.array([0, 0, 142]) + cmap[9,:] = np.array([220, 220, 0]) + cmap[10,:]= np.array([220, 20, 60]) + + cmap[11,:]= np.array([119, 11, 32]) + cmap[12,:]= np.array([0, 0, 230]) + cmap[13,:]= np.array([250, 170, 160]) + cmap[14,:]= np.array([128, 64, 64]) + cmap[15,:]= np.array([250, 170, 30]) + + cmap[16,:]= np.array([152, 251, 152]) + cmap[17,:]= np.array([255, 0, 0]) + cmap[18,:]= np.array([0, 0, 70]) + cmap[19,:]= np.array([0, 60, 100]) #small obstacle + cmap[20,:]= np.array([0, 80, 100]) + cmap[21,:]= np.array([102, 102, 156]) + cmap[22,:]= np.array([102, 102, 156]) + + return cmap + +class Colorize: + + def 
__init__(self, n=24): # n = nClasses + # self.cmap = colormap(256) + self.cmap = colormap_bdd(256) + self.cmap[n] = self.cmap[-1] + self.cmap = torch.from_numpy(self.cmap[:n]) + + def __call__(self, gray_image): + size = gray_image.size() + # print(size) + color_images = torch.ByteTensor(size[0], 3, size[1], size[2]).fill_(0) + # color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0) + + # for label in range(1, len(self.cmap)): + for i in range(color_images.shape[0]): + for label in range(0, len(self.cmap)): + mask = gray_image[0] == label + # mask = gray_image == label + + color_images[i][0][mask] = self.cmap[label][0] + color_images[i][1][mask] = self.cmap[label][1] + color_images[i][2][mask] = self.cmap[label][2] + + return color_images diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/eval.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/eval.py new file mode 100644 index 00000000..1db7c73e --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/eval.py @@ -0,0 +1,194 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import time +import torch +from torchvision.transforms import ToPILImage +from PIL import Image + +from dataloaders import make_data_loader +from dataloaders.utils import decode_seg_map_sequence, Colorize +from utils.metrics import Evaluator +from models.erfnet_RA_parallel import Net as Net_RAP +import torch.backends.cudnn as cudnn + +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" + +class Validator(object): + def __init__(self, args, data=None, unseen_detection=False): + self.args = args + self.time_train = [] + self.num_class = args.num_class # [13, 30, 30] + self.current_domain = args.current_domain # 0 when start + self.next_domain = args.next_domain # 1 when start + + if self.current_domain <= 0: + self.current_class = [self.num_class[0]] + elif self.current_domain == 1: + self.current_class = self.num_class[:2] + elif self.current_domain >= 2: + self.current_class = self.num_class + else: + pass + + # Define Dataloader + kwargs = {'num_workers': args.workers, 'pin_memory': False} + _, _, self.test_loader, _ = make_data_loader(args, test_data=data, **kwargs) + + # Define evaluator + self.evaluator = Evaluator(self.num_class[self.current_domain]) + + # Define network + self.model = Net_RAP(num_classes=self.current_class, nb_tasks=self.current_domain + 1, cur_task=self.current_domain) + + args.current_domain = self.next_domain + args.next_domain += 1 + if args.cuda: + #self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids) + self.model = self.model.cuda(args.gpu_ids) + cudnn.benchmark = True # accelarate speed + print('Model loaded successfully!') + + def validate(self): + self.model.eval() + self.evaluator.reset() + tbar = tqdm(self.test_loader, desc='\r') + predictions = [] + for i, (sample, image_name) in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + # spec = time.time() + image, target = sample['image'], sample['label'] + #print(self.args.cuda, self.args.gpu_ids) + if self.args.cuda: + image = image.cuda(self.args.gpu_ids) + if self.args.depth: + depth = depth.cuda(self.args.gpu_ids) + + with torch.no_grad(): + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image,self.current_domain) + + if 
self.args.cuda: + torch.cuda.synchronize() + + pred = output.data.cpu().numpy() + # todo + pred = np.argmax(pred, axis=1) + predictions.append(pred) + + if not self.args.save_predicted_image: + continue + + pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte()) + pre_labels = torch.max(output, 1)[1].detach().cpu().byte() + print(pre_labels.shape) + # save + for i in range(pre_colors.shape[0]): + print(image_name[0]) + + if not image_name[0]: + img_name = "test.png" + else: + img_name = os.path.basename(image_name[0]) + + color_label_name = os.path.join(self.args.color_label_save_path, img_name) + label_name = os.path.join(self.args.label_save_path, img_name) + merge_label_name = os.path.join(self.args.merge_label_save_path, img_name) + + os.makedirs(os.path.dirname(color_label_name), exist_ok=True) + os.makedirs(os.path.dirname(merge_label_name), exist_ok=True) + os.makedirs(os.path.dirname(label_name), exist_ok=True) + + pre_color_image = ToPILImage()(pre_colors[i]) # pre_colors.dtype = float64 + pre_color_image.save(color_label_name) + + pre_label_image = ToPILImage()(pre_labels[i]) + pre_label_image.save(label_name) + + if (self.args.merge): + image_merge(image[i], pre_color_image, merge_label_name) + print('save image: {}'.format(merge_label_name)) + #print("start validating 120") + return predictions + + def task_divide(self): + seen_task_samples, unseen_task_samples = [], [] + self.model.eval() + self.evaluator.reset() + tbar = tqdm(self.test_loader, desc='\r') + for i, (sample, image_name) in enumerate(tbar): + + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if self.args.cuda: + image = image.cuda(self.args.gpu_ids) + if self.args.depth: + depth = depth.cuda(self.args.gpu_ids) + start_time = time.time() + with torch.no_grad(): + if self.args.depth: + output_, output, _ = self.model(image, depth) + else: + output_, output, _ = self.model(image) + if self.args.cuda: + torch.cuda.synchronize() + if i != 0: + fwt = time.time() - start_time + self.time_train.append(fwt) + print("Forward time per img (bath size=%d): %.3f (Mean: %.3f)" % ( + self.args.val_batch_size, fwt / self.args.val_batch_size, + sum(self.time_train) / len(self.time_train) / self.args.val_batch_size)) + time.sleep(0.1) # to avoid overheating the GPU too much + + # pred colorize + pre_colors = Colorize()(torch.max(output, 1)[1].detach().cpu().byte()) + pre_labels = torch.max(output, 1)[1].detach().cpu().byte() + for i in range(pre_colors.shape[0]): + task_sample = dict() + task_sample.update(image=sample["image"][i]) + task_sample.update(label=sample["label"][i]) + if self.args.depth: + task_sample.update(depth=sample["depth"][i]) + + if torch.max(pre_labels) == output.shape[1] - 1: + unseen_task_samples.append((task_sample, image_name[i])) + else: + seen_task_samples.append((task_sample, image_name[i])) + + return seen_task_samples, unseen_task_samples + +def image_merge(image, label, save_name): + image = ToPILImage()(image.detach().cpu().byte()) + # width, height = image.size + left = 140 + top = 30 + right = 2030 + bottom = 900 + # crop + image = image.crop((left, top, right, bottom)) + # resize + image = image.resize(label.size, Image.BILINEAR) + + image = image.convert('RGBA') + label = label.convert('RGBA') + image = Image.blend(image, label, 0.6) + image.save(save_name) + +def load_my_state_dict(model, state_dict): # custom function to load model when not all dict elements + own_state = 
model.state_dict() + for name, param in state_dict.items(): + if name not in own_state: + # print('{} not in model_state'.format(name)) + continue + else: + own_state[name].copy_(param) + + return model diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/erfnet.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/erfnet.py new file mode 100644 index 00000000..fe9b0f7d --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/erfnet.py @@ -0,0 +1,210 @@ +# ERFNet full model definition for Pytorch +# Sept 2017 +# Eduardo Romera +####################### + +import torch +import torch.nn as nn +import torch.nn.init as init +import torch.nn.functional as F + +class DownsamplerBlock (nn.Module): + def __init__(self, ninput, noutput, nb_tasks=1): + super().__init__() + + self.conv = nn.Conv2d(ninput, noutput-ninput, (3, 3), stride=2, padding=1, bias=True) + self.pool = nn.MaxPool2d(2, stride=2) + self.bn_ini = nn.ModuleList([nn.BatchNorm2d(noutput, eps=1e-3) for i in range(nb_tasks)]) + + def forward(self, input): + task = current_task + output = torch.cat([self.conv(input), self.pool(input)], 1) + output = self.bn_ini[task](output) + return F.relu(output) + + +class non_bottleneck_1d (nn.Module): + def __init__(self, chann, dropprob, dilated): + super().__init__() + + self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) + + self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) + + self.bn1 = nn.BatchNorm2d(chann, eps=1e-03) + + self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=( + 1*dilated, 0), bias=True, dilation=(dilated, 1)) + + self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=( + 0, 1*dilated), bias=True, dilation=(1, dilated)) + + self.bn2 = nn.BatchNorm2d(chann, eps=1e-03) + + self.dropout = nn.Dropout2d(dropprob) + + def forward(self, input): + + output = self.conv3x1_1(input) + output = F.relu(output) + output = self.conv1x3_1(output) + output = self.bn1(output) + output = F.relu(output) + + output = self.conv3x1_2(output) + output = F.relu(output) + output = self.conv1x3_2(output) + output = self.bn2(output) + + if (self.dropout.p != 0): + output = self.dropout(output) + + return F.relu(output+input) # +input = identity (residual connection) + + +class non_bottleneck_1d_RAP (nn.Module): + def __init__(self, chann, dropprob, dilated, nb_tasks=1): + #chann = #channels, dropprob=dropout probability, dilated=dilation rate + super().__init__() + + self.conv3x1_1 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=(1, 0), bias=True) + self.conv1x3_1 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=(0, 1), bias=True) + + # domain-specific 1x1conv + self.parallel_conv_1 = nn.ModuleList([nn.Conv2d(chann, chann, kernel_size=1, stride=1, padding=0, bias=True) for i in range(nb_tasks)]) #nb_tasks=1 for 1st time, its only on CS + self.bns_1 = nn.ModuleList([nn.BatchNorm2d(chann, eps=1e-03) for i in range(nb_tasks)]) + + self.conv3x1_2 = nn.Conv2d(chann, chann, (3, 1), stride=1, padding=( + 1*dilated, 0), bias=True, dilation=(dilated, 1)) + + self.conv1x3_2 = nn.Conv2d(chann, chann, (1, 3), stride=1, padding=( + 0, 1*dilated), bias=True, dilation=(1, dilated)) + + self.parallel_conv_2 = nn.ModuleList([nn.Conv2d(chann, chann, kernel_size=1, stride=1, padding=0, bias=True) for i in range(nb_tasks)]) + 
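+        # one domain-specific 1x1 adapter conv and one BatchNorm per task; forward() indexes both lists with the current task id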
self.bns_2 = nn.ModuleList([nn.BatchNorm2d(chann, eps=1e-03) for i in range(nb_tasks)]) + + self.dropout = nn.Dropout2d(dropprob) + + def forward(self, input): + task = current_task + # print('input: ', input.size()) + output = self.conv3x1_1(input) + output = F.relu(output) + output = self.conv1x3_1(output) + # print('output 2nd 1x3: ', output.size()) + + output = output + self.parallel_conv_1[task](input) # RAP skip connection for conv2 + output = self.bns_1[task](output) + + output_ = F.relu(output) + + output = self.conv3x1_2(output_) + output = F.relu(output) + output = self.conv1x3_2(output) + + output = output + self.parallel_conv_2[task](output_) # RAP skip connection for conv2 + output = self.bns_2[task](output) + + if (self.dropout.p != 0): + output = self.dropout(output) + + return F.relu(output+input) # +input = identity (residual connection) + +''' +ENCODER will use the non_bottleneck_1d_RAP modules as they have the parallel residual adapters. +DECODER will use non_bottleneck_1d modules as they don't have RAPs and we need RAPs only in the encoder. + +only encoder has shared and domain-specific RAPs. decoder is domain specific +it'll be like decoder.0, decoder.1 +for domain-specific RAPs and bns, it'll be like parallel_conv_2.0.weight, parallel_conv_2.1.weight +''' +class Encoder(nn.Module): + def __init__(self, nb_tasks=1): + super().__init__() + self.initial_block = DownsamplerBlock(3, 16, nb_tasks) + + self.layers = nn.ModuleList() + + self.layers.append(DownsamplerBlock(16, 64, nb_tasks)) + + for x in range(0, 5): # 5 times + self.layers.append(non_bottleneck_1d_RAP(64, 0.03, 1, nb_tasks)) + + self.layers.append(DownsamplerBlock(64, 128, nb_tasks)) + + for x in range(0, 2): # 2 times + self.layers.append(non_bottleneck_1d_RAP(128, 0.3, 2, nb_tasks)) # dropprob for imagenet pretrained encoder is 0.1 not 0.3, here using 0.3 for imagenet pretrained encoder + self.layers.append(non_bottleneck_1d_RAP(128, 0.3, 4, nb_tasks)) + self.layers.append(non_bottleneck_1d_RAP(128, 0.3, 8, nb_tasks)) + self.layers.append(non_bottleneck_1d_RAP(128, 0.3, 16, nb_tasks)) + + def forward(self, input, predict=False): + output = self.initial_block(input) + + for layer in self.layers: + output = layer(output) + + return output + + +class UpsamplerBlock (nn.Module): + def __init__(self, ninput, noutput): + super().__init__() + self.conv = nn.ConvTranspose2d(ninput, noutput, 3, stride=2, + padding=1, output_padding=1, bias=True) + self.bn = nn.BatchNorm2d(noutput, eps=1e-3) + + def forward(self, input): + output = self.conv(input) + output = self.bn(output) + return F.relu(output) + + +class Decoder (nn.Module): + def __init__(self, num_classes): + super().__init__() + + self.layers = nn.ModuleList() + + self.layers.append(UpsamplerBlock(128, 64)) + self.layers.append(non_bottleneck_1d(64, 0, 1)) + self.layers.append(non_bottleneck_1d(64, 0, 1)) + + self.layers.append(UpsamplerBlock(64, 16)) + self.layers.append(non_bottleneck_1d(16, 0, 1)) + self.layers.append(non_bottleneck_1d(16, 0, 1)) + + self.output_conv = nn.ConvTranspose2d( + 16, num_classes, 2, stride=2, padding=0, output_padding=0, bias=True) + + def forward(self, input): + output = input + + for layer in self.layers: + output = layer(output) + + output = self.output_conv(output) + + return output + +# ERFNet + +class Net(nn.Module): + def __init__(self, num_classes = [13], nb_tasks=1, cur_task=0): # use encoder to pass pretrained encoder + # the encoder has been passed here in this manner because we needed an Imagenet pretrained encoder. 
so used erfnet_imagenet to initialize encoder and read it from saved pretrained model. we want to attach that encoder to our decoder + # encoder is not being passed. figure out another way of initialising with the imagenet pretrained encoder weights, on this encoder. init to be handled. + super().__init__() + + global current_task + current_task = cur_task + + self.encoder = Encoder(nb_tasks) + + self.decoder = nn.ModuleList([Decoder(num_classes[i]) for i in range(nb_tasks)]) + + def forward(self, input, task): + global current_task + current_task = task # chose which branch of forward pass you need based on if training on current dataset or validating on a previous dataset. + output = self.encoder(input) + output = self.decoder[task].forward(output) + return output diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/replicate.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/replicate.py new file mode 100644 index 00000000..3734266e --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/replicate.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# File : replicate.py +# Author : Jiayuan Mao +# Email : maojiayuan@gmail.com +# Date : 27/01/2018 +# +# This file is part of Synchronized-BatchNorm-PyTorch. +# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch +# Distributed under MIT License. + +import functools + +from torch.nn.parallel.data_parallel import DataParallel + +__all__ = [ + 'CallbackContext', + 'execute_replication_callbacks', + 'DataParallelWithCallback', + 'patch_replication_callback' +] + + +class CallbackContext(object): + pass + + +def execute_replication_callbacks(modules): + """ + Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Note that, as all modules are isomorphism, we assign each sub-module with a context + (shared among multiple copies of this module on different devices). + Through this context, different copies can share some information. + We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback + of any slave copies. + """ + master_copy = modules[0] + nr_modules = len(list(master_copy.modules())) + ctxs = [CallbackContext() for _ in range(nr_modules)] + + for i, module in enumerate(modules): + for j, m in enumerate(module.modules()): + if hasattr(m, '__data_parallel_replicate__'): + m.__data_parallel_replicate__(ctxs[j], i) + + +class DataParallelWithCallback(DataParallel): + """ + Data Parallel with a replication callback. + An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by + original `replicate` function. + The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + # sync_bn.__data_parallel_replicate__ will be invoked. + """ + + def replicate(self, module, device_ids): + modules = super(DataParallelWithCallback, self).replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + +def patch_replication_callback(data_parallel): + """ + Monkey-patch an existing `DataParallel` object. 
Add the replication callback. + Useful when you have customized `DataParallel` implementation. + Examples: + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) + > patch_replication_callback(sync_bn) + # this is equivalent to + > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) + > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) + """ + + assert isinstance(data_parallel, DataParallel) + + old_replicate = data_parallel.replicate + + @functools.wraps(old_replicate) + def new_replicate(module, device_ids): + modules = old_replicate(module, device_ids) + execute_replication_callbacks(modules) + return modules + + data_parallel.replicate = new_replicate \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/util.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/util.py new file mode 100644 index 00000000..5c86e759 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/models/util.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +upsample = lambda x, size: F.interpolate(x, size, mode='bilinear', align_corners=False) +batchnorm_momentum = 0.01 / 2 + + +def get_n_params(parameters): + pp = 0 + for p in parameters: + nn = 1 + for s in list(p.size()): + nn = nn * s + pp += nn + return pp + + +class _BNReluConv(nn.Sequential): + def __init__(self, num_maps_in, num_maps_out, k=3, batch_norm=True, bn_momentum=0.1, bias=False, dilation=1): + super(_BNReluConv, self).__init__() + if batch_norm: + self.add_module('norm', nn.BatchNorm2d(num_maps_in, momentum=bn_momentum)) + self.add_module('relu', nn.ReLU(inplace=batch_norm is True)) + padding = k // 2 # same conv + self.add_module('conv', nn.Conv2d(num_maps_in, num_maps_out, + kernel_size=k, padding=padding, bias=bias, dilation=dilation)) + + +class _Upsample(nn.Module): + def __init__(self, num_maps_in, skip_maps_in, num_maps_out, use_bn=True, k=3): + super(_Upsample, self).__init__() + print(f'Upsample layer: in = {num_maps_in}, skip = {skip_maps_in}, out = {num_maps_out}') + self.bottleneck = _BNReluConv(skip_maps_in, num_maps_in, k=1, batch_norm=use_bn) + self.blend_conv = _BNReluConv(num_maps_in, num_maps_out, k=k, batch_norm=use_bn) + + def forward(self, x, skip): + skip = self.bottleneck.forward(skip) + skip_size = skip.size()[2:4] + x = upsample(x, skip_size) + x = x + skip + x = self.blend_conv.forward(x) + return x + + +class SpatialPyramidPooling(nn.Module): + def __init__(self, num_maps_in, num_levels, bt_size=512, level_size=128, out_size=128, + grids=(6, 3, 2, 1), square_grid=False, bn_momentum=0.1, use_bn=True): + super(SpatialPyramidPooling, self).__init__() + self.grids = grids + self.square_grid = square_grid + self.spp = nn.Sequential() + self.spp.add_module('spp_bn', + _BNReluConv(num_maps_in, bt_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + num_features = bt_size + final_size = num_features + for i in range(num_levels): + final_size += level_size + self.spp.add_module('spp' + str(i), + _BNReluConv(num_features, level_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + self.spp.add_module('spp_fuse', + _BNReluConv(final_size, out_size, k=1, bn_momentum=bn_momentum, batch_norm=use_bn)) + + def forward(self, x): + levels = [] + target_size = x.size()[2:4] + + ar = 
target_size[1] / target_size[0] + + x = self.spp[0].forward(x) + levels.append(x) + num = len(self.spp) - 1 + + for i in range(1, num): + if not self.square_grid: + grid_size = (self.grids[i - 1], max(1, round(ar * self.grids[i - 1]))) + x_pooled = F.adaptive_avg_pool2d(x, grid_size) + else: + x_pooled = F.adaptive_avg_pool2d(x, self.grids[i - 1]) + level = self.spp[i].forward(x_pooled) + + level = upsample(level, target_size) + levels.append(level) + x = torch.cat(levels, 1) + x = self.spp[-1].forward(x) + return x + + +class _UpsampleBlend(nn.Module): + def __init__(self, num_features, use_bn=True): + super(_UpsampleBlend, self).__init__() + self.blend_conv = _BNReluConv(num_features, num_features, k=3, batch_norm=use_bn) + + def forward(self, x, skip): + skip_size = skip.size()[2:4] + x = upsample(x, skip_size) + x = x + skip + x = self.blend_conv.forward(x) + return x diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/mypath.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/mypath.py new file mode 100644 index 00000000..fa667bef --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/mypath.py @@ -0,0 +1,20 @@ +class Path(object): + @staticmethod + def db_root_dir(dataset): + if dataset == 'cityscapes': + return './ianvs/project/RFNet-master/Data/cityscapes/' # folder that contains leftImg8bit/ + elif dataset == 'citylostfound': + return './ianvs/project/RFNet-master/Data/cityscapesandlostandfound/' # folder that mixes Cityscapes and Lost and Found + elif dataset == 'cityrand': + return './ianvs/project/RFNet-master/Data/cityrand/' + elif dataset == 'target': + return './ianvs/project/RFNet-master/Data/target/' + elif dataset == 'xrlab': + return './ianvs/project/RFNet-master/Data/xrlab/' + elif dataset == 'e1': + return './ianvs/project/RFNet-master/Data/e1/' + elif dataset == 'mapillary': + return './ianvs/project/RFNet-master/Data/mapillary/' + else: + print('Dataset {} not available.'.format(dataset)) + raise NotImplementedError diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/predict.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/predict.py new file mode 100644 index 00000000..ed56fffd --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/predict.py @@ -0,0 +1,100 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' +# set at yaml +# os.environ["PREDICT_RESULT_DIR"] = "./inference_results" +# os.environ["EDGE_OUTPUT_URL"] = "./edge_kb" +# os.environ["video_url"] = "./video/radio.mp4" +# os.environ["MODEL_URLS"] = "./cloud_next_kb/index.pkl" + + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) +import time +import torch +import numpy as np +from PIL import Image +import base64 +import tempfile +import warnings +from io import BytesIO + +from sedna.datasources import BaseDataSource +from sedna.core.lifelong_learning import LifelongLearning +from sedna.common.config import Context + +from dataloaders import custom_transforms as tr +from torchvision import transforms + +from accuracy import accuracy +from basemodel import preprocess, val_args, Model + +def preprocess(samples): + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) 
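+    # wrap the transformed frame in a BaseDataSource so that ll_job.inference() can consume it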
+ + data = BaseDataSource(data_type="test") + data.x = [(composed_transforms(samples), "")] + return data + +def init_ll_job(): + estimator = Model() + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"], + "default": "real" + } + } + unseen_task_allocation = { + "method": "UnseenTaskAllocationDefault" + } + + ll_job = LifelongLearning( + estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=unseen_task_allocation, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None) + + return ll_job + +def predict(): + ll_job = init_ll_job() + + camera_address = Context.get_parameters('video_url') + # use video streams for testing + camera = cv2.VideoCapture(camera_address) + fps = 10 + nframe = 0 + while 1: + ret, input_yuv = camera.read() + if not ret: + time.sleep(5) + camera = cv2.VideoCapture(camera_address) + continue + + if nframe % fps: + nframe += 1 + continue + + img_rgb = cv2.cvtColor(input_yuv, cv2.COLOR_BGR2RGB) + nframe += 1 + if nframe % 1000 == 1: # logs every 1000 frames + warnings.warn(f"camera is open, current frame index is {nframe}") + + img_rgb = cv2.resize(np.array(img_rgb), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_rgb = Image.fromarray(img_rgb) + sample = {'image': img_rgb, "depth": img_rgb, "label": img_rgb} + data = preprocess(sample) + print("Inference results:", ll_job.inference(data=data)) + +if __name__ == '__main__': + predict() diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/run_server.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/run_server.py new file mode 100644 index 00000000..0f0b2c88 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/run_server.py @@ -0,0 +1,254 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from io import BytesIO +from typing import Optional, Any + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) +import numpy as np +from PIL import Image +import uvicorn +import time +from pydantic import BaseModel +from fastapi import FastAPI, UploadFile, File +from fastapi.routing import APIRoute +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse +import sedna_predict +from sedna.common.utils import get_host_ip +from dataloaders.datasets.cityscapes import CityscapesSegmentation + + +class ImagePayload(BaseModel): + image: UploadFile = File(...) 
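+    # the depth image is optional and only used when the model expects RGB-D input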
+ depth: Optional[UploadFile] = None + + +class ResultModel(BaseModel): + type: int = 0 + box: Any = None + curr: str = None + future: str = None + img: str = None + + +class ResultResponse(BaseModel): + msg: str = "" + result: Optional[ResultModel] = None + code: int + + +class BaseServer: + # pylint: disable=too-many-instance-attributes,too-many-arguments + DEBUG = True + WAIT_TIME = 15 + + def __init__( + self, + servername: str, + host: str, + http_port: int = 8080, + grpc_port: int = 8081, + workers: int = 1, + ws_size: int = 16 * 1024 * 1024, + ssl_key=None, + ssl_cert=None, + timeout=300): + self.server_name = servername + self.app = None + self.host = host or '0.0.0.0' + self.http_port = http_port or 80 + self.grpc_port = grpc_port + self.workers = workers + self.keyfile = ssl_key + self.certfile = ssl_cert + self.ws_size = int(ws_size) + self.timeout = int(timeout) + protocal = "https" if self.certfile else "http" + self.url = f"{protocal}://{self.host}:{self.http_port}" + + def run(self, app, **kwargs): + if hasattr(app, "add_middleware"): + app.add_middleware( + CORSMiddleware, allow_origins=["*"], allow_credentials=True, + allow_methods=["*"], allow_headers=["*"], + ) + + uvicorn.run( + app, + host=self.host, + port=self.http_port, + ssl_keyfile=self.keyfile, + ssl_certfile=self.certfile, + workers=self.workers, + timeout_keep_alive=self.timeout, + **kwargs) + + def get_all_urls(self): + url_list = [{"path": route.path, "name": route.name} + for route in getattr(self.app, 'routes', [])] + return url_list + + +class InferenceServer(BaseServer): # pylint: disable=too-many-arguments + """ + rest api server for inference + """ + + def __init__( + self, + servername, + host: str, + http_port: int = 5000, + max_buffer_size: int = 104857600, + workers: int = 1): + super( + InferenceServer, + self).__init__( + servername=servername, + host=host, + http_port=http_port, + workers=workers) + + self.job, self.detection_validator = sedna_predict.init_ll_job() + + self.max_buffer_size = max_buffer_size + self.app = FastAPI( + routes=[ + APIRoute( + f"/{servername}", + self.model_info, + methods=["GET"], + ), + APIRoute( + f"/{servername}/predict", + self.predict, + methods=["POST"], + response_model=ResultResponse + ), + ], + log_level="trace", + timeout=600, + ) + self.index_frame = 0 + + def start(self): + return self.run(self.app) + + @staticmethod + def model_info(): + return HTMLResponse( + """

+            Welcome to the RestNet API!
+
+            To use this service, send a POST HTTP request to {this-url}/predict
+
+            The JSON payload has the following format:
+            {"image": "BASE64_STRING_OF_IMAGE", "depth": "BASE64_STRING_OF_DEPTH"}

+ """) + + async def predict(self, image: UploadFile = File(...), depth: Optional[UploadFile] = None) -> ResultResponse: + contents = await image.read() + recieve_img_time = time.time() + print("Recieve image from the robo:", recieve_img_time) + + image = Image.open(BytesIO(contents)).convert('RGB') + + img_dep = None + self.index_frame = self.index_frame + 1 + + if depth: + depth_contents = await depth.read() + depth = Image.open(BytesIO(depth_contents)).convert('RGB') + img_dep = cv2.resize(np.array(depth), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_dep = Image.fromarray(img_dep) + + img_rgb = cv2.resize(np.array(image), (2048, 1024), interpolation=cv2.INTER_CUBIC) + img_rgb = Image.fromarray(img_rgb) + + sample = {'image': img_rgb, "depth": img_dep, "label": img_rgb} + results = sedna_predict.predict(self.job, data=sample, validator=self.detection_validator) + + predict_finish_time = time.time() + print(f"Prediction costs {predict_finish_time - recieve_img_time} seconds") + + post_process = True + if results["result"]["box"] is None: + results["result"]["curr"] = None + results["result"]["future"] = None + elif post_process: + curr, future = get_curb(results["result"]["box"]) + results["result"]["curr"] = curr + results["result"]["future"] = future + results["result"]["box"] = None + print("Post process cost at worker:", (time.time()-predict_finish_time)) + else: + results["result"]["curr"] = None + results["result"]["future"] = None + + print("Result transmit to robo time:", time.time()) + return results + +def parse_result(label, count): + label_map = ['road', 'sidewalk', ] + count_d = dict(zip(label, count)) + curb_count = count_d.get(19, 0) + if curb_count / np.sum(count) > 0.3: + return "curb" + r = sorted(label, key=count_d.get, reverse=True)[0] + try: + c = label_map[r] + except: + c = "other" + + return c + +def get_curb(results): + results = np.array(results[0]) + input_height, input_width = results.shape + + closest = np.array([ + [0, int(input_height)], + [int(input_width), + int(input_height)], + [int(0.118 * input_width + .5), + int(.8 * input_height + .5)], + [int(0.882 * input_width + .5), + int(.8 * input_height + .5)], + ]) + + future = np.array([ + [int(0.118 * input_width + .5), + int(.8 * input_height + .5)], + [int(0.882 * input_width + .5), + int(.8 * input_height + .5)], + [int(.765 * input_width + .5), + int(.66 * input_height + .5)], + [int(.235 * input_width + .5), + int(.66 * input_height + .5)] + ]) + + mask = np.zeros((input_height, input_width), dtype=np.uint8) + mask = cv2.fillPoly(mask, [closest], 1) + mask = cv2.fillPoly(mask, [future], 2) + d1, c1 = np.unique(results[mask == 1], return_counts=True) + d2, c2 = np.unique(results[mask == 2], return_counts=True) + c = parse_result(d1, c1) + f = parse_result(d2, c2) + + return c, f + +if __name__ == '__main__': + web_app = InferenceServer("lifelong-learning-robo", host=get_host_ip()) + web_app.start() diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_evaluate.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_evaluate.py new file mode 100644 index 00000000..b596a06d --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_evaluate.py @@ -0,0 +1,45 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' + +from sedna.core.lifelong_learning import LifelongLearning +from sedna.datasources import IndexDataParse 
+from sedna.common.config import Context + +from accuracy import accuracy +from basemodel import Model + +def _load_txt_dataset(dataset_url): + # use original dataset url + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def eval(): + estimator = Model() + eval_dataset_url = Context.get_parameters("test_dataset_url") + eval_data = IndexDataParse(data_type="eval", func=_load_txt_dataset) + eval_data.parse(eval_dataset_url, use_raw=False) + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + ll_job = LifelongLearning(estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + ll_job.evaluate(eval_data, metrics=accuracy) + + +if __name__ == '__main__': + print(eval()) diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_predict.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_predict.py new file mode 100644 index 00000000..cfcc6048 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_predict.py @@ -0,0 +1,129 @@ +import os + +os.environ['BACKEND_TYPE'] = 'PYTORCH' +os.environ["TEST_DATASET_URL"] = "./data_txt/door_test.txt" +os.environ["EDGE_OUTPUT_URL"] = "./edge_kb" +os.environ["ORIGINAL_DATASET_URL"] = "/tmp" + +import torch +import numpy as np +from PIL import Image +import base64 +import tempfile +from io import BytesIO +from torchvision.transforms import ToPILImage +from torchvision import transforms +from torch.utils.data import DataLoader + +from sedna.datasources import IndexDataParse +from sedna.core.lifelong_learning import LifelongLearning +from sedna.common.config import Context + +from eval import Validator +from accuracy import accuracy +from basemodel import preprocess, val_args, Model +from dataloaders.utils import Colorize +from dataloaders import custom_transforms as tr +from dataloaders.datasets.cityscapes import CityscapesSegmentation + +def _load_txt_dataset(dataset_url): + # use original dataset url, + # see https://github.com/kubeedge/sedna/issues/35 + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def fetch_data(): + test_dataset_url = Context.get_parameters("test_dataset_url") + test_data = IndexDataParse(data_type="test", func=_load_txt_dataset) + test_data.parse(test_dataset_url, use_raw=False) + return test_data + +def pre_data_process(samples): + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + data = BaseDataSource(data_type="test") + data.x = [(composed_transforms(samples), "")] + return data + +def post_process(res, is_unseen_task): + if is_unseen_task: + res, base64_string = None, None + else: + res = res[0].tolist() + + type = 0 if not is_unseen_task else 1 + mesg = { + "msg": "", + "result": { + "type": type, + "box": res + }, + "code": 0 + } + return mesg + +def image_merge(raw_img, result): + raw_img = ToPILImage()(raw_img) + + pre_colors = Colorize()(torch.from_numpy(result)) + 
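+    # Colorize maps each class index to its RGB colour; index 0 selects the single image in the batch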
pre_color_image = ToPILImage()(pre_colors[0]) # pre_colors.dtype = float64 + + image = raw_img.resize(pre_color_image.size, Image.BILINEAR) + image = image.convert('RGBA') + label = pre_color_image.convert('RGBA') + image = Image.blend(image, label, 0.6) + with tempfile.NamedTemporaryFile(suffix='.png') as f: + image.save(f.name) + + with open(f.name, 'rb') as open_file: + byte_content = open_file.read() + base64_bytes = base64.b64encode(byte_content) + base64_string = base64_bytes.decode('utf-8') + return base64_string + +def init_ll_job(): + estimator = Model() + inference_integrate = { + "method": "BBoxInferenceIntegrate" + } + unseen_task_allocation = { + "method": "UnseenTaskAllocationDefault" + } + unseen_sample_recognition = { + "method": "SampleRegonitionByRFNet" + } + + ll_job = LifelongLearning( + estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=None, + task_remodeling=None, + inference_integrate=inference_integrate, + task_update_decision=None, + unseen_task_allocation=unseen_task_allocation, + unseen_sample_recognition=unseen_sample_recognition, + unseen_sample_re_recognition=None) + + args = val_args() + args.weight_path = "./models/detection_model.pth" + args.num_class = 31 + + return ll_job, Validator(args, unseen_detection=True) + +def predict(ll_job, data=None, validator=None): + if data: + data = pre_data_process(data) + else: + data = fetch_data() + data.x = preprocess(data.x) + + res, is_unseen_task, _ = ll_job.inference( + data, validator=validator, initial=False) + return post_process(res, is_unseen_task) + +if __name__ == '__main__': + ll_job, validator = init_ll_job() + print("Inference result:", predict(ll_job, validator=validator)) diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_train.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_train.py new file mode 100644 index 00000000..1c99361a --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/sedna_train.py @@ -0,0 +1,78 @@ +import os +os.environ['BACKEND_TYPE'] = 'PYTORCH' +os.environ["OUTPUT_URL"] = "./cloud_kb/" +# os.environ['CLOUD_KB_INDEX'] = "./cloud_kb/index.pkl" +os.environ["TRAIN_DATASET_URL"] = "./data_txt/sedna_data.txt" +os.environ["KB_SERVER"] = "http://0.0.0.0:9020" +os.environ["HAS_COMPLETED_INITIAL_TRAINING"] = "false" + +from sedna.common.file_ops import FileOps +from sedna.datasources import IndexDataParse +from sedna.common.config import Context, BaseConfig +from sedna.core.lifelong_learning import LifelongLearning + +from basemodel import Model + +def _load_txt_dataset(dataset_url): + # use original dataset url + original_dataset_url = Context.get_parameters('original_dataset_url') + return os.path.join(os.path.dirname(original_dataset_url), dataset_url) + +def train(estimator, train_data): + task_definition = { + "method": "TaskDefinitionByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + task_allocation = { + "method": "TaskAllocationByOrigin", + "param": { + "origins": ["real", "sim"] + } + } + + ll_job = LifelongLearning(estimator, + task_definition=task_definition, + task_relationship_discovery=None, + task_allocation=task_allocation, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + 
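+    # first round: train() builds the initial knowledge base; subsequent rounds go through update() instead (see run() below)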
ll_job.train(train_data) + +def update(estimator, train_data): + ll_job = LifelongLearning(estimator, + task_definition=None, + task_relationship_discovery=None, + task_allocation=None, + task_remodeling=None, + inference_integrate=None, + task_update_decision=None, + unseen_task_allocation=None, + unseen_sample_recognition=None, + unseen_sample_re_recognition=None + ) + + ll_job.update(train_data) + +def run(): + estimator = Model() + train_dataset_url = BaseConfig.train_dataset_url + train_data = IndexDataParse(data_type="train") + train_data.parse(train_dataset_url, use_raw=False) + + is_completed_initilization = str(Context.get_parameters("HAS_COMPLETED_INITIAL_TRAINING", "false")).lower() + if is_completed_initilization == "false": + train(estimator, train_data) + else: + update(estimator, train_data) + +if __name__ == '__main__': + run() diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/test.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/test.py new file mode 100644 index 00000000..7350632b --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/test.py @@ -0,0 +1,25 @@ +import numpy as np +import seaborn as sns +import pandas as pd +import matplotlib.pyplot as plt + +CPA_results = np.load("./cpa_results.npy").T +ratios = [0.3, 0.5, 0.6, 0.7, 0.8, 0.9] +ratio_counts = np.zeros((len(CPA_results), len(ratios)), dtype=float) + +for i in range(len(CPA_results)): + for j in range(len(ratios)): + result = CPA_results[i] + result = result[result <= ratios[j]] + + ratio_counts[i][j] = len(result) / 275 + +plt.figure(figsize=(45, 10)) +ratio_counts = pd.DataFrame(data=ratio_counts.T, index=ratios) +sns.heatmap(data=ratio_counts, annot=True, cmap="YlGnBu", annot_kws={'fontsize': 15}) +plt.xticks(fontsize=20) +plt.yticks(fontsize=25) +plt.xlabel("Test images", fontsize=25) +plt.ylabel("Ratio of PA ranges", fontsize=25) +plt.savefig("./figs/ratio_count.png") +plt.show() diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/train.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/train.py new file mode 100644 index 00000000..ecd87553 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/train.py @@ -0,0 +1,361 @@ +import argparse +import os +import numpy as np +from tqdm import tqdm +import torch +import copy + +from mypath import Path +from dataloaders import make_data_loader + +from models.erfnet_RA_parallel import Net as Net_RAP + +from utils.loss import SegmentationLosses +from models.replicate import patch_replication_callback +from utils.calculate_weights import calculate_weigths_labels +from utils.lr_scheduler import LR_Scheduler +from utils.saver import Saver +from utils.summaries import TensorboardSummary +from utils.metrics import Evaluator +from sedna.datasources import BaseDataSource + +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" + +class Trainer(object): + def __init__(self, args, train_data=None, valid_data=None): + self.args = args + # Define Saver + self.saver = Saver(args) + self.saver.save_experiment_config() + # Define Tensorboard Summary + self.summary = TensorboardSummary(self.saver.experiment_dir) + self.writer = self.summary.create_summary() + # denormalize for detph image + 
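+        # dataset depth statistics, used to undo the normalization before visualizing depth maps in TensorBoard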
self.mean_depth = torch.as_tensor(0.12176, dtype=torch.float32, device='cpu') + self.std_depth = torch.as_tensor(0.09752, dtype=torch.float32, device='cpu') + + self.nclass = args.num_class # [13, 30, 30] + self.current_domain = min(args.current_domain, 2) # current domain start from 0 and maximum is 2 + self.next_domain = args.next_domain # next_domain start from 1 + + if self.current_domain <= 0: + self.current_class = [self.nclass[0]] + elif self.current_domain == 1: + self.current_class = self.nclass[:2] + elif self.current_domain >= 2: + self.current_class = self.nclass + else: + pass + + # Define Dataloader + kwargs = {'num_workers': args.workers, 'pin_memory': True} + self.train_loader, self.val_loader, self.test_loader, _ = make_data_loader(args, train_data=train_data, + valid_data=valid_data, **kwargs) + + self.print_domain_info() + + # Define network + model = Net_RAP(num_classes=self.current_class, nb_tasks=self.current_domain + 1, cur_task=self.current_domain) + model_old = Net_RAP(num_classes=self.current_class, nb_tasks=self.current_domain, cur_task=max(self.current_domain-1, 0)) + args.current_domain = self.next_domain + args.next_domain += 1 + + optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, + weight_decay=args.weight_decay) + # Define Criterion + # whether to use class balanced weights + if args.use_balanced_weights: + classes_weights_path = os.path.join(Path.db_root_dir(args.dataset), args.dataset+'_classes_weights.npy') + if os.path.isfile(classes_weights_path): + weight = np.load(classes_weights_path) + else: + weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass) + weight = torch.from_numpy(weight.astype(np.float32)) + else: + weight = None + # Define loss function + self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda, gpu_ids=args.gpu_ids).build_loss(mode=args.loss_type) + self.model, self.model_old, self.optimizer = model, model_old, optimizer + # Define Evaluator + self.evaluator = Evaluator(self.nclass[self.current_domain]) + # # Define lr scheduler + self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader)) + # Using cuda + if args.cuda: + self.model = torch.nn.DataParallel(self.model) + # patch_replication_callback(self.model) + self.model = self.model.cuda(args.gpu_ids) + self.gpu_ids = args.gpu_ids + # Resuming checkpoint + self.best_pred = 0.0 + if args.resume is not None: + if not os.path.isfile(args.resume): + raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) + print(f"Training: load model from {args.resume}") + checkpoint = torch.load(args.resume, map_location=torch.device('cuda:0')) + args.start_epoch = checkpoint['epoch'] + + self.model.load_state_dict(checkpoint['state_dict'], False) + + if not args.ft: + self.optimizer.load_state_dict(checkpoint['optimizer']) + self.best_pred = checkpoint['best_pred'] + print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])) + + # Clear start epoch if fine-tuning + if args.ft: + args.start_epoch = 0 + + def get_weight(self): + print("get weight") + current_model = copy.deepcopy(self.model) + return current_model.parameters() + + def set_weight(self, weights): + length = len(weights) + print("set weight", length) + print("model:", self.args.resume) + tau = 0.2 + if length == 1: + for param, target_param in zip(weights[0], self.model.parameters()): + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + elif length == 2: + for param1, param2, target_param 
in zip(weights[0], weights[1], self.model.parameters()): + target_param.data.copy_(0.5 * tau * param1.data + 0.5 * tau * param2.data + (1 - tau) * target_param.data) + + def my_training(self, epoch): + train_loss = 0.0 + print(self.optimizer.state_dict()['param_groups'][0]['lr']) + current_model = copy.deepcopy(self.model) + self.model.train() + tbar = tqdm(self.train_loader) + num_img_tr = len(self.train_loader) + + for i, sample in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + #print(target.shape) + else: + image, target = sample['image'], sample['label'] + print(image.shape) + if self.args.cuda: + image, target = image.cuda(self.args.gpu_ids), target.cuda(self.args.gpu_ids) + if self.args.depth: + depth = depth.cuda(self.args.gpu_ids) + self.scheduler(self.optimizer, i, epoch, self.best_pred) + self.optimizer.zero_grad() + + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image) + target[target > self.nclass[2]-1] = 255 + loss = self.criterion(output, target) + loss.backward() + self.optimizer.step() + #print(self.optimizer.state_dict()['param_groups'][0]['lr']) + train_loss += loss.item() + tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) + self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) + # Show 10 * 3 inference results each epoch + if i % (num_img_tr // 10 + 1) == 0: + global_step = i + num_img_tr * epoch + if self.args.depth: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + depth_display = depth[0].cpu().unsqueeze(0) + depth_display = depth_display.mul_(self.std_depth).add_(self.mean_depth) + depth_display = depth_display.numpy() + depth_display = depth_display*255 + depth_display = depth_display.astype(np.uint8) + self.writer.add_image('Depth', depth_display, global_step) + + else: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) + print('Loss: %.3f' % train_loss) + tau = 0.3 + flag = True + for param, target_param in zip(current_model.parameters(), self.model.parameters()): + if flag: + flag = False + target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data) + + del current_model + return train_loss + + def training(self, epoch): + train_loss = 0.0 + print(self.optimizer.state_dict()['param_groups'][0]['lr']) + + self.model.train() + self.model_old.eval() + + for name, m in self.model_old.named_parameters(): + m.requires_grad = False + + for name, m in self.model.named_parameters(): + if 'decoder' in name: + if 'decoder.{}'.format(self.current_domain) in name: + m.requires_grad = True + else: + m.requires_grad = False + + elif 'encoder' in name: + if 'bn' in name or 'parallel_conv' in name: + if '.{}.weight'.format(self.current_domain) in name or '.{}.bias'.format(self.current_domain) in name: + m.requires_grad = True + else: + m.requires_grad = False + + tbar = tqdm(self.train_loader) + num_img_tr = len(self.train_loader) + + for i, sample in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + #print(target.shape) + else: + image, target = sample['image'], sample['label'] + # print(image.shape) + if self.args.cuda: + image, target = image.cuda(self.args.gpu_ids), 
target.cuda(self.args.gpu_ids) + if self.args.depth: + depth = depth.cuda(self.args.gpu_ids) + self.scheduler(self.optimizer, i, epoch, self.best_pred) + self.optimizer.zero_grad() + + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image, self.current_domain) + + output = torch.tensor(output, dtype=torch.float32) + target[target > self.nclass[self.current_domain]-1] = 255 + + target = self.my_to_label(target) + target = self.my_relabel(target, 255, self.nclass[self.current_domain] - 1) + + target = target.squeeze(0) + target = target.cuda(self.gpu_ids) + + outputs_prev_task = self.model(image, max(self.current_domain-1, 0)) + loss = self.criterion(output, target) + + loss.requires_grad_(True) + loss.backward() + self.optimizer.step() + + train_loss += loss.item() + tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) + self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) + + # Show 10 * 3 inference results each epoch + if i % (num_img_tr // 10 + 1) == 0: + global_step = i + num_img_tr * epoch + if self.args.depth: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + depth_display = depth[0].cpu().unsqueeze(0) + depth_display = depth_display.mul_(self.std_depth).add_(self.mean_depth) + depth_display = depth_display.numpy() + depth_display = depth_display*255 + depth_display = depth_display.astype(np.uint8) + self.writer.add_image('Depth', depth_display, global_step) + + else: + self.summary.visualize_image(self.writer, self.args.dataset, image, target, output, global_step) + + self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) + print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) + print('Loss: %.3f' % train_loss) + + # save checkpoint every epoch + checkpoint_path = self.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'best_pred': self.best_pred, + }, True) + return train_loss + + def validation(self, epoch): + self.model.eval() + self.evaluator.reset() + tbar = tqdm(self.val_loader, desc='\r') + test_loss = 0.0 + for i, (sample, img_path) in enumerate(tbar): + if self.args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + # print(f"val image is {image}") + if self.args.cuda: + image, target = image.cuda(self.args.gpu_ids), target.cuda(self.args.gpu_ids) + if self.args.depth: + depth = depth.cuda(self.args.gpu_ids) + with torch.no_grad(): + if self.args.depth: + output = self.model(image, depth) + else: + output = self.model(image) + target[target > self.nclass[2]-1] = 255 + loss = self.criterion(output, target) + test_loss += loss.item() + tbar.set_description('Test loss: %.3f' % (test_loss / (i + 1))) + pred = output.data.cpu().numpy() + target = target.cpu().numpy() + pred = np.argmax(pred, axis=1) + # Add batch sample into evaluator + self.evaluator.add_batch(target, pred) + + # Fast test during the training + Acc = self.evaluator.Pixel_Accuracy() + Acc_class = self.evaluator.Pixel_Accuracy_Class() + mIoU = self.evaluator.Mean_Intersection_over_Union() + FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union() + self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch) + self.writer.add_scalar('val/mIoU', mIoU, epoch) + self.writer.add_scalar('val/Acc', Acc, epoch) + 
self.writer.add_scalar('val/Acc_class', Acc_class, epoch) + self.writer.add_scalar('val/fwIoU', FWIoU, epoch) + print('Validation:') + print('[Epoch: %d, numImages: %5d]' % (epoch, i * self.args.batch_size + image.data.shape[0])) + print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) + print('Loss: %.3f' % test_loss) + + new_pred = mIoU + if new_pred > self.best_pred: + is_best = True + self.best_pred = new_pred + self.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.model.state_dict(), + 'optimizer': self.optimizer.state_dict(), + 'best_pred': self.best_pred, + }, is_best) + + + def print_domain_info(self): + + domain_map = { + 0: "Synthia", + 1: "CityScapes", + 2: "Cloud-Robotics" + } + + domain_name = domain_map.get(self.current_domain, "Unknown Domain") + + print("We are in domain", self.current_domain, "which is", domain_name) + + def my_relabel(self, tensor, olabel, nlabel): + tensor[tensor == olabel] = nlabel + return tensor + + def my_to_label(self, image): + image = image.cpu() + return torch.from_numpy(np.array(image)).long().unsqueeze(0) + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/__init__.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/args.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/args.py new file mode 100644 index 00000000..edbecd20 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/args.py @@ -0,0 +1,61 @@ +class TrainArgs: + def __init__(self, **kwargs): + self.depth = False + self.dataset = 'cityscapes' + self.workers = 4 + self.base_size = 1024 + self.crop_size = 768 + self.loss_type = 'ce' + self.epochs = kwargs.get("epochs", 1) + self.start_epoch = 0 + + self.num_class = [13, 30, 30] + self.current_domain = 0 + self.next_domain = 1 + + self.state = None + self.batch_size = 2 + self.val_batch_size = 1 + self.use_balanced_weights = False + + self.lr = kwargs.get("learning_rate", 1e-4) + self.lr_scheduler = 'cos' + self.momentum = 0.9 + self.weight_decay = 2.5e-5 + self.no_cuda = False + self.gpu_ids = 0 + + self.seed = 1 + self.resume = None + self.checkname = 'erfnet_RA_parallel' + self.ft = True + self.eval_interval = kwargs.get("eval_interval", 50) + self.no_val = kwargs.get("no_val", True) + self.cuda = True + self.savedir = './dataset/mdil-ss/save' + +class ValArgs: + def __init__(self, **kwargs): + self.dataset = 'cityscapes' + self.workers = 0 + self.base_size = 1024 + self.crop_size = 768 + self.batch_size = 6 + self.val_batch_size = 1 + self.test_batch_size = 1 + + self.num_class = [13, 30, 30] + self.current_domain = 0 + self.next_domain = 1 + + self.no_cuda = False + self.gpu_ids = 0 + self.checkname = None + self.weight_path = "./models/530_exp3_2.pth" + self.save_predicted_image = False + self.color_label_save_path = './test/color' + self.merge_label_save_path = './test/merge' + self.label_save_path = './test/label' + self.merge = True + self.depth = False + self.cuda = True diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/calculate_weights.py 
b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/calculate_weights.py new file mode 100644 index 00000000..2c2c9821 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/calculate_weights.py @@ -0,0 +1,29 @@ +import os +from tqdm import tqdm +import numpy as np +from mypath import Path + +def calculate_weigths_labels(dataset, dataloader, num_classes): + # Create an instance from the data loader + z = np.zeros((num_classes,)) + # Initialize tqdm + tqdm_batch = tqdm(dataloader) + print('Calculating classes weights') + for sample in tqdm_batch: + y = sample['label'] + y = y.detach().cpu().numpy() + mask = (y >= 0) & (y < num_classes) + labels = y[mask].astype(np.uint8) + count_l = np.bincount(labels, minlength=num_classes) + z += count_l + tqdm_batch.close() + total_frequency = np.sum(z) + class_weights = [] + for frequency in z: + class_weight = 1 / (np.log(1.02 + (frequency / total_frequency))) + class_weights.append(class_weight) + ret = np.array(class_weights) + classes_weights_path = os.path.join(Path.db_root_dir(dataset), dataset+'_classes_weights.npy') + np.save(classes_weights_path, ret) + + return ret \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/iouEval.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/iouEval.py new file mode 100644 index 00000000..93f029b6 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/iouEval.py @@ -0,0 +1,129 @@ +import torch + +class iouEval: + + def __init__(self, nClasses, ignoreIndex=20): + + self.nClasses = nClasses + self.ignoreIndex = ignoreIndex if nClasses > ignoreIndex else -1 # if ignoreIndex is larger than nClasses, consider no ignoreIndex + self.reset() + + def reset(self): + classes = self.nClasses if self.ignoreIndex == -1 else self.nClasses - 1 + self.tp = torch.zeros(classes).double() + self.fp = torch.zeros(classes).double() + self.fn = torch.zeros(classes).double() + self.cdp_obstacle = torch.zeros(1).double() + self.tp_obstacle = torch.zeros(1).double() + self.idp_obstacle = torch.zeros(1).double() + self.tp_nonobstacle = torch.zeros(1).double() + + def addBatch(self, x, y): # x=preds, y=targets + # sizes should be "batch_size x nClasses x H x W" + + if (x.is_cuda or y.is_cuda): + x = x.cuda() + y = y.cuda() + + # if size is "batch_size x 1 x H x W" scatter to onehot + if (x.size(1) == 1): + x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3)) + if x.is_cuda: + x_onehot = x_onehot.cuda() + x_onehot.scatter_(1, x, 1).float() # dim index src 按照列用1替换0,索引为x + else: + x_onehot = x.float() + + if (y.size(1) == 1): + y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3)) + if y.is_cuda: + y_onehot = y_onehot.cuda() + y_onehot.scatter_(1, y, 1).float() + else: + y_onehot = y.float() + + if (self.ignoreIndex != -1): + ignores = y_onehot[:, self.ignoreIndex].unsqueeze(1) # 加一维 + x_onehot = x_onehot[:, :self.ignoreIndex] # ignoreIndex后的都不要 + y_onehot = y_onehot[:, :self.ignoreIndex] + else: + ignores = 0 + + + tpmult = x_onehot * y_onehot # times prediction and gt coincide is 1 + tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + fpmult = x_onehot * ( + 1 - y_onehot - ignores) # times 
prediction says its that class and gt says its not (subtracting cases when its ignore label!) + fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + fnmult = (1 - x_onehot) * (y_onehot) # times prediction says its not that class and gt says it is + fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3, + keepdim=True).squeeze() + + self.tp += tp.double().cpu() + self.fp += fp.double().cpu() + self.fn += fn.double().cpu() + + cdp_obstacle = tpmult[:, 19].sum() # obstacle index 19 + tp_obstacle = y_onehot[:, 19].sum() + + idp_obstacle = (x_onehot[:, 19] - tpmult[:, 19]).sum() + tp_nonobstacle = (-1*y_onehot+1).sum() + + self.cdp_obstacle += cdp_obstacle.double().cpu() + self.tp_obstacle += tp_obstacle.double().cpu() + self.idp_obstacle += idp_obstacle.double().cpu() + self.tp_nonobstacle += tp_nonobstacle.double().cpu() + + + + def getIoU(self): + num = self.tp + den = self.tp + self.fp + self.fn + 1e-15 + iou = num / den + iou_not_zero = list(filter(lambda x: x != 0, iou)) + iou_mean = sum(iou_not_zero) / len(iou_not_zero) + tfp = self.tp + self.fp + 1e-15 + acc = num / tfp + acc_not_zero = list(filter(lambda x: x != 0, acc)) + acc_mean = sum(acc_not_zero) / len(acc_not_zero) + + return iou_mean, iou, acc_mean, acc # returns "iou mean", "iou per class" + + def getObstacleEval(self): + + pdr_obstacle = self.cdp_obstacle / (self.tp_obstacle+1e-15) + pfp_obstacle = self.idp_obstacle / (self.tp_nonobstacle+1e-15) + + return pdr_obstacle, pfp_obstacle + + +# Class for colors +class colors: + RED = '\033[31;1m' + GREEN = '\033[32;1m' + YELLOW = '\033[33;1m' + BLUE = '\033[34;1m' + MAGENTA = '\033[35;1m' + CYAN = '\033[36;1m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + ENDC = '\033[0m' + + +# Colored value output if colorized flag is activated. 
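# Usage sketch for iouEval (illustrative only; class name, method names and shapes are taken
# from the code above, everything else is assumed). The scatter_ calls in addBatch build
# one-hot tensors along the class dimension; the inline Chinese comments say the same:
# scatter 1s column-wise indexed by the label map, unsqueeze to add a channel, and drop the
# channels from ignoreIndex onward.
#
#   evaluator = iouEval(nClasses=20, ignoreIndex=20)   # ignoreIndex >= nClasses, so nothing is ignored
#   evaluator.addBatch(preds, targets)                 # both batch_size x 1 x H x W integer label maps
#   iou_mean, iou_per_class, acc_mean, acc_per_class = evaluator.getIoU()
#   pdr_obstacle, pfp_obstacle = evaluator.getObstacleEval()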
+def getColorEntry(val): + if not isinstance(val, float): + return colors.ENDC + if (val < .20): + return colors.RED + elif (val < .40): + return colors.YELLOW + elif (val < .60): + return colors.BLUE + elif (val < .80): + return colors.CYAN + else: + return colors.GREEN + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/loss.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/loss.py new file mode 100644 index 00000000..e4a85a38 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/loss.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn + +class SegmentationLosses(object): + def __init__(self, weight=None, size_average=True, batch_average=True, ignore_index=255, cuda=False, gpu_ids=0): # ignore_index=255 + self.ignore_index = ignore_index + self.weight = weight + self.size_average = size_average + self.batch_average = batch_average + self.cuda = cuda + self.gpu_ids = gpu_ids + + def build_loss(self, mode='ce'): + """Choices: ['ce' or 'focal']""" + if mode == 'ce': + return self.CrossEntropyLoss + elif mode == 'focal': + return self.FocalLoss + else: + raise NotImplementedError + + def CrossEntropyLoss(self, logit, target): + criterion = nn.CrossEntropyLoss() + if self.cuda: + criterion = criterion.cuda(self.gpu_ids) + + loss = criterion(logit, target.long()) + + return loss + + def FocalLoss(self, logit, target, gamma=2, alpha=0.5): + n, c, h, w = logit.size() + criterion = nn.CrossEntropyLoss(weight=self.weight, ignore_index=self.ignore_index, + size_average=self.size_average) + if self.cuda: + criterion = criterion.cuda(self.gpu_ids) + + logpt = -criterion(logit, target.long()) + pt = torch.exp(logpt) + if alpha is not None: + logpt *= alpha + loss = -((1 - pt) ** gamma) * logpt + + if self.batch_average: + loss /= n + + return loss + +if __name__ == "__main__": + loss = SegmentationLosses(cuda=True) diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/lr_scheduler.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/lr_scheduler.py new file mode 100644 index 00000000..47124028 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/lr_scheduler.py @@ -0,0 +1,70 @@ +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +## Created by: Hang Zhang +## ECE Department, Rutgers University +## Email: zhang.hang@rutgers.edu +## Copyright (c) 2017 +## +## This source code is licensed under the MIT-style license found in the +## LICENSE file in the root directory of this source tree +##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ + +import math + +class LR_Scheduler(object): + """Learning Rate Scheduler + + Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` + + Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` + + Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` + + Args: + args: + :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), + :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, + :attr:`args.lr_step` + + iters_per_epoch: number of iterations per epoch + """ + def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, + lr_step=0, warmup_epochs=0): + self.mode = mode + print('Using 
{} LR Scheduler!'.format(self.mode)) + self.lr = base_lr + if mode == 'step': + assert lr_step + self.lr_step = lr_step + self.iters_per_epoch = iters_per_epoch + self.N = num_epochs * iters_per_epoch + self.epoch = -1 + self.warmup_iters = warmup_epochs * iters_per_epoch + + def __call__(self, optimizer, i, epoch, best_pred): + T = epoch * self.iters_per_epoch + i + if self.mode == 'cos': + lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) + elif self.mode == 'poly': + lr = self.lr * pow((1 - 1.0 * T / self.N), 2) + elif self.mode == 'step': + lr = self.lr * (0.1 ** (epoch // self.lr_step)) + else: + raise NotImplemented + # warm up lr schedule + if self.warmup_iters > 0 and T < self.warmup_iters: + lr = lr * 1.0 * T / self.warmup_iters + if epoch > self.epoch: + print('\n=>Epoches %i, learning rate = %.4f, \ + previous best = %.4f' % (epoch, lr, best_pred)) + self.epoch = epoch + assert lr >= 0 + self._adjust_learning_rate(optimizer, lr) + + def _adjust_learning_rate(self, optimizer, lr): + if len(optimizer.param_groups) == 1: + optimizer.param_groups[0]['lr'] = lr * 4 + else: + # enlarge the lr at the head + optimizer.param_groups[0]['lr'] = lr * 4 + for i in range(1, len(optimizer.param_groups)): + optimizer.param_groups[i]['lr'] = lr diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/metrics.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/metrics.py new file mode 100644 index 00000000..2900ec2b --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/metrics.py @@ -0,0 +1,78 @@ +import numpy as np + + +class Evaluator(object): + def __init__(self, num_class): + self.num_class = num_class + self.confusion_matrix = np.zeros((self.num_class,)*2) # shape:(num_class, num_class) + + def Pixel_Accuracy(self): + Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() + return Acc + + def Pixel_Accuracy_Class_Curb(self): + Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) + Acc = np.nanmean(Acc[:2]) + return Acc + + + def Pixel_Accuracy_Class(self): + Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) + Acc = np.nanmean(Acc) + return Acc + + def Mean_Intersection_over_Union(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + MIoU = np.nanmean(MIoU) + return MIoU + + def Mean_Intersection_over_Union_Curb(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + MIoU = np.nanmean(MIoU[:2]) + return MIoU + + def Frequency_Weighted_Intersection_over_Union(self): + freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) + iu = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() + CFWIoU = freq[freq > 0] * iu[freq > 0] + return FWIoU + + def Frequency_Weighted_Intersection_over_Union_Curb(self): + freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) + iu = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + CFWIoU = 
freq[freq > 0] * iu[freq > 0] + + return np.nanmean(CFWIoU[:2]) + + def _generate_matrix(self, gt_image, pre_image): + mask = (gt_image >= 0) & (gt_image < self.num_class) + label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] + count = np.bincount(label, minlength=self.num_class**2) + confusion_matrix = count.reshape(self.num_class, self.num_class) + return confusion_matrix + + def add_batch(self, gt_image, pre_image): + gt_image = np.array(gt_image) + pre_image = np.array(pre_image) + print(gt_image.shape, pre_image.shape) + if gt_image.shape != pre_image.shape: + pre_image = pre_image[0] + assert gt_image.shape == pre_image.shape + self.confusion_matrix += self._generate_matrix(gt_image, pre_image) + + def reset(self): + self.confusion_matrix = np.zeros((self.num_class,) * 2) + + + + diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/saver.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/saver.py new file mode 100644 index 00000000..03866432 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/saver.py @@ -0,0 +1,68 @@ +import os +import time +import shutil +import tempfile +import torch +from collections import OrderedDict +import glob + +class Saver(object): + + def __init__(self, args): + self.args = args + self.directory = os.path.join('/tmp', args.dataset, args.checkname) + self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) + run_id = int(self.runs[-1].split('_')[-1]) + 1 if self.runs else 0 + + self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) + if not os.path.exists(self.experiment_dir): + os.makedirs(self.experiment_dir) + + def save_checkpoint(self, state, is_best): # filename from .pth.tar change to .pth? 
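        # Behaviour note (described from the body below): the checkpoint is written to
        # self.experiment_dir under a timestamped .pth name; when is_best is True, best_pred
        # is logged to best_pred.txt and the checkpoint is promoted to <directory>/model_best.pth
        # whenever it beats the best mIoU recorded by earlier experiment_* runs.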
+ """Saves checkpoint to disk""" + filename = f'checkpoint_{time.time()}.pth' + checkpoint_path = os.path.join(self.experiment_dir, filename) + torch.save(state, checkpoint_path) + if is_best: + best_pred = state['best_pred'] + with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: + f.write(str(best_pred)) + if self.runs: + previous_miou = [0.0] + for run in self.runs: + run_id = run.split('_')[-1] + path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') + if os.path.exists(path): + with open(path, 'r') as f: + miou = float(f.readline()) + previous_miou.append(miou) + else: + continue + max_miou = max(previous_miou) + if best_pred > max_miou: + checkpoint_path_best = os.path.join(self.directory, 'model_best.pth') + shutil.copyfile(checkpoint_path, checkpoint_path_best) + checkpoint_path = checkpoint_path_best + else: + checkpoint_path_best = os.path.join(self.directory, 'model_best.pth') + shutil.copyfile(checkpoint_path, checkpoint_path_best) + checkpoint_path = checkpoint_path_best + + return checkpoint_path + + def save_experiment_config(self): + logfile = os.path.join(self.experiment_dir, 'parameters.txt') + log_file = open(logfile, 'w') + p = OrderedDict() + p['datset'] = self.args.dataset + # p['out_stride'] = self.args.out_stride + p['lr'] = self.args.lr + p['lr_scheduler'] = self.args.lr_scheduler + p['loss_type'] = self.args.loss_type + p['epoch'] = self.args.epochs + p['base_size'] = self.args.base_size + p['crop_size'] = self.args.crop_size + + for key, val in p.items(): + log_file.write(key + ':' + str(val) + '\n') + log_file.close() \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/summaries.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/summaries.py new file mode 100644 index 00000000..04bcdb82 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/ERFNet/utils/summaries.py @@ -0,0 +1,39 @@ +import os +import torch +from torchvision.utils import make_grid +# from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter +from dataloaders.utils import decode_seg_map_sequence + +class TensorboardSummary(object): + def __init__(self, directory): + self.directory = directory + + def create_summary(self): + writer = SummaryWriter(log_dir=os.path.join(self.directory)) + return writer + + def visualize_image(self, writer, dataset, image, target, output, global_step, depth=None): + if depth is None: + grid_image = make_grid(image[:3].clone().cpu().data, 3, normalize=True) + writer.add_image('Image', grid_image, global_step) + + grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), + dataset=dataset), 3, normalize=False, range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), + dataset=dataset), 3, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) + else: + grid_image = make_grid(image[:3].clone().cpu().data, 4, normalize=True) + writer.add_image('Image', grid_image, global_step) + + grid_image = make_grid(depth[:3].clone().cpu().data, 4, normalize=True) # normalize=False? 
+ writer.add_image('Depth', grid_image, global_step) + + grid_image = make_grid(decode_seg_map_sequence(torch.max(output[:3], 1)[1].detach().cpu().numpy(), + dataset=dataset), 4, normalize=False, range=(0, 255)) + writer.add_image('Predicted label', grid_image, global_step) + grid_image = make_grid(decode_seg_map_sequence(torch.squeeze(target[:3], 1).detach().cpu().numpy(), + dataset=dataset), 4, normalize=False, range=(0, 255)) + writer.add_image('Groundtruth label', grid_image, global_step) \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/basemodel.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/basemodel.py new file mode 100644 index 00000000..97aca3de --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/basemodel.py @@ -0,0 +1,142 @@ +import os +import gc +import numpy as np +import torch +from torch.utils.data import DataLoader +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.common.config import Context +from sedna.common.file_ops import FileOps +from sedna.common.log import LOGGER +from PIL import Image +from torchvision import transforms + +from ERFNet.train import Trainer +from ERFNet.eval import Validator, load_my_state_dict +from ERFNet.dataloaders import custom_transforms as tr +from ERFNet.dataloaders import make_data_loader +from ERFNet.utils.args import TrainArgs, ValArgs + +# set backend +os.environ['BACKEND_TYPE'] = 'PYTORCH' + +os.environ["OMP_NUM_THREADS"] = "1" +os.environ["MKL_NUM_THREADS"] = "1" + +@ClassFactory.register(ClassType.GENERAL, alias="BaseModel") +class BaseModel: + def __init__(self, **kwargs): + self.train_args = TrainArgs(**kwargs) + self.trainer = None + + self.val_args = ValArgs(**kwargs) + label_save_dir = Context.get_parameters("INFERENCE_RESULT_DIR", "./inference_results") + self.val_args.color_label_save_path = os.path.join(label_save_dir, "color") + self.val_args.merge_label_save_path = os.path.join(label_save_dir, "merge") + self.val_args.label_save_path = os.path.join(label_save_dir, "label") + self.validator = Validator(self.val_args) + + def get_weights(self): + return self.trainer.get_weight() + + def set_weights(self, weights): + self.trainer.set_weight(weights) + + epoch_num = 0 + print("Total epoch: ", epoch_num) + loss_all = [] + for epoch in range(epoch_num): + train_loss = self.trainer.my_training(epoch) + loss_all.append(train_loss) + + def train(self, train_data, valid_data=None, **kwargs): + self.trainer = Trainer(self.train_args, train_data=train_data) + print("Total epoches:", self.trainer.args.epochs) + loss_all = [] + for epoch in range( + self.trainer.args.start_epoch, + self.trainer.args.epochs): + if epoch == 0 and self.trainer.val_loader: + self.trainer.validation(epoch) + loss = self.trainer.training(epoch) + loss_all.append(loss) + if self.trainer.args.no_val and ( + epoch % + self.trainer.args.eval_interval == ( + self.trainer.args.eval_interval - + 1) or epoch == self.trainer.args.epochs - + 1): + is_best = False + self.train_model_url = self.trainer.saver.save_checkpoint({ + 'epoch': epoch + 1, + 'state_dict': self.trainer.model.state_dict(), + 'optimizer': self.trainer.optimizer.state_dict(), + 'best_pred': self.trainer.best_pred, + }, is_best) + + self.trainer.writer.close() + return self.train_model_url + + def predict(self, data, **kwargs): + if len(data) > 10: + print("predict start for big data") 
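            # Inputs with more than 10 samples are routed through make_data_loader so the
            # standard dataset pipeline builds the validator's test_loader; smaller inputs fall
            # through to the else branch below, where they are preprocessed in-process and
            # wrapped in a plain torch DataLoader.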
+ my_kwargs = {'num_workers': self.val_args.workers, 'pin_memory': True} + _, _, self.validator.test_loader, _ = make_data_loader(self.val_args, test_data=data, **my_kwargs) + else: + print("predict start for small data") + if not isinstance(data[0][0], dict): + data = self._preprocess(data) + if type(data) is np.ndarray: + data = data.tolist() + self.validator.test_loader = DataLoader(data, batch_size=self.val_args.test_batch_size, shuffle=False, + pin_memory=True) + + return self.validator.validate() + + def evaluate(self, data, **kwargs): + self.val_args.save_predicted_image = kwargs.get("save_predicted_image", True) + samples = self._preprocess(data.x) + predictions = self.predict(samples) + metric_name, metric_func = kwargs.get("metric") + if callable(metric_func): + return metric_func(data.y, predictions) + else: + raise Exception(f"not found model metric func(name={metric_name}) in model eval phase") + + def load(self, model_url, **kwargs): + if model_url: + print("load model url: ",model_url) + self.validator.new_state_dict = torch.load(model_url, map_location=torch.device("cpu")) + self.train_args.resume = model_url + else: + raise Exception("model url does not exist.") + self.validator.model = load_my_state_dict(self.validator.model, self.validator.new_state_dict['state_dict']) + + def save(self, model_path=None): + if not model_path: + LOGGER.warning(f"Not specify model path.") + return self.train_model_url + + return FileOps.upload(self.train_model_url, model_path) + + def _preprocess(self, image_urls): + transformed_images = [] + for paths in image_urls: + if len(paths) == 2: + img_path, depth_path = paths + _img = Image.open(img_path).convert('RGB') + _depth = Image.open(depth_path) + else: + img_path = paths[0] + _img = Image.open(img_path).convert('RGB') + _depth = _img + + sample = {'image': _img, 'depth': _depth, 'label': _img} + del _img + gc.collect() + composed_transforms = transforms.Compose([ + tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), + tr.ToTensor()]) + + transformed_images.append((composed_transforms(sample), img_path)) + + return transformed_images diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_allocation_by_domain.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_allocation_by_domain.py new file mode 100644 index 00000000..dc4d9522 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_allocation_by_domain.py @@ -0,0 +1,46 @@ +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('TaskAllocationByDomain',) + + +@ClassFactory.register(ClassType.STP, alias="TaskAllocationByDomain") +class TaskAllocationByOrigin: + """ + Corresponding to `TaskDefinitionByOrigin` + + Parameters + ---------- + task_extractor : Dict + used to match target tasks + origins: List[Metadata] + metadata is usually a class feature + label with finite values. 
+ """ + + def __init__(self, **kwargs): + self.default_origin = kwargs.get("default", None) + + def __call__(self, task_extractor, samples: BaseDataSource): + self.task_extractor = {"Synthia": 0, "Cityscapes": 1, "Cloud-Robotics": 2} # Mapping of origins to task indices + + if self.default_origin: + return samples, [int(self.task_extractor.get(self.default_origin))] * len(samples.x) + + categories = ["Cityscapes", "Synthia", "Cloud-Robotics"] # List of all possible origins + + sample_origins = [] + for _x in samples.x: + sample_origin = None + for category in categories: + if category in _x[0]: + sample_origin = category + break + if sample_origin is None: + # If none of the categories match, assign a default origin + sample_origin = self.default_origin if self.default_origin else categories[0] + sample_origins.append(sample_origin) + + allocations = [int(self.task_extractor.get(sample_origin)) for sample_origin in sample_origins] + + return samples, allocations diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_definition_by_domain.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_definition_by_domain.py new file mode 100644 index 00000000..eba0f74f --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_definition_by_domain.py @@ -0,0 +1,56 @@ +from typing import List, Any, Tuple + +from sedna.datasources import BaseDataSource +from sedna.common.class_factory import ClassType, ClassFactory +from sedna.algorithms.seen_task_learning.artifact import Task + +__all__ = ('TaskDefinitionByDomain',) + + +@ClassFactory.register(ClassType.STP, alias="TaskDefinitionByDomain") +class TaskDefinitionByOrigin: + """ + Dividing datasets based on their origins. + + Parameters + ---------- + origins: List[Metadata] + metadata is usually a class feature label with finite values. 
+ """ + + def __init__(self, **kwargs): + self.origins = kwargs.get("origins", ["Cityscapes", "Synthia", "Cloud-Robotics"]) + + def __call__(self, + samples: BaseDataSource, **kwargs) -> Tuple[List[Task], + Any, + BaseDataSource]: + categories = self.origins + + tasks = [] + d_type = samples.data_type + x_data = samples.x + y_data = samples.y + + task_index = dict(zip(categories, range(len(categories)))) + + data_sources = {category: BaseDataSource(data_type=d_type) for category in categories} + print(data_sources) + + for category in data_sources.values(): + category.x = [] + category.y = [] + + for i in range(samples.num_examples()): + for category in categories: + if category in x_data[i]: + data_sources[category].x.append(x_data[i]) + data_sources[category].y.append(y_data[i]) + break + + for category, data_source in data_sources.items(): + task_name = f"{category}_semantic_segmentation_model" + task_obj = Task(entry=task_name, samples=data_source, meta_attr=category) + tasks.append(task_obj) + + return tasks, task_index, samples \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/test_algorithm.yaml b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/test_algorithm.yaml new file mode 100644 index 00000000..d39634c3 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/test_algorithm.yaml @@ -0,0 +1,63 @@ +algorithm: + # paradigm type; string type; + # currently the options of value are as follows: + # 1> "singletasklearning" + # 2> "incrementallearning" + # 3> "lifelonglearning" + paradigm_type: "lifelonglearning" + lifelong_learning_data_setting: + # ratio of training dataset; float type; + # the default value is 0.8. + train_ratio: 0.9 + # the method of splitting dataset; string type; optional; + # currently the options of value are as follows: + # 1> "default": the dataset is evenly divided based train_ratio; + # splitting_method: "default" + splitting_method: "fwt_splitting" + + # algorithm module configuration in the paradigm; list type; + modules: + # type of algorithm module; string type; + # currently the options of value are as follows: + # 1> "basemodel": contains important interfaces such as train、 eval、 predict and more; required module; + - type: "basemodel" + # name of python module; string type; + # example: basemodel.py has BaseModel module that the alias is "FPN" for this benchmarking; + name: "BaseModel" + # the url address of python module; string type; + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/basemodel.py" + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + - learning_rate: + values: + - 0.0001 + - epochs: + values: + - 1 + # 2> "task_definition": define lifelong task ; optional module; + - type: "task_definition" + # name of python module; string type; + name: "TaskDefinitionByDomain" + # the url address of python module; string type; + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_definition_by_domain.py" + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + # origins of data; value is ["real", "sim"], this means that data from real camera and simulator. 
+ - origins: + values: + - ["Cityscapes", "Synthia", "Cloud-Robotics"] + # 3> "task_allocation": allocate lifelong task ; optional module; + - type: "task_allocation" + # name of python module; string type; + name: "TaskAllocationByDomain" + # the url address of python module; string type; + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testalgorithms/erfnet/task_allocation_by_domain.py" + # hyperparameters configuration for the python module; list type; + hyperparameters: + # name of the hyperparameter; string type; + # origins of data; value is ["real", "sim"], this means that data from real camera and simulator. + - origins: + values: + - ["Cityscapes", "Synthia", "Cloud-Robotics"] \ No newline at end of file diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/accuracy.py b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/accuracy.py new file mode 100644 index 00000000..4dd63613 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/accuracy.py @@ -0,0 +1,54 @@ +# Copyright 2022 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tqdm import tqdm + +from sedna.common.class_factory import ClassType, ClassFactory + +from ERFNet.dataloaders import make_data_loader +from ERFNet.utils.metrics import Evaluator +from ERFNet.utils.args import ValArgs + +__all__ = ('accuracy') + + +@ClassFactory.register(ClassType.GENERAL, alias="accuracy") +def accuracy(y_true, y_pred, **kwargs): + args = ValArgs() + _, _, test_loader, num_class = make_data_loader(args, test_data=y_true) + evaluator = Evaluator(num_class) + #print(y_true) + tbar = tqdm(test_loader, desc='\r') + for i, (sample, img_path) in enumerate(tbar): + if args.depth: + image, depth, target = sample['image'], sample['depth'], sample['label'] + else: + image, target = sample['image'], sample['label'] + if args.cuda: + image, target = image.cuda(args.gpu_ids), target.cuda(args.gpu_ids) + if args.depth: + depth = depth.cuda(args.gpu_ids) + + target[target > evaluator.num_class-1] = 255 + target = target.cpu().numpy() + evaluator.add_batch(target, y_pred[i]) + + # Test during the training + # Acc = evaluator.Pixel_Accuracy() + CPA = evaluator.Pixel_Accuracy_Class() + mIoU = evaluator.Mean_Intersection_over_Union() + FWIoU = evaluator.Frequency_Weighted_Intersection_over_Union() + + print("CPA:{}, mIoU:{}, fwIoU: {}".format(CPA, mIoU, FWIoU)) + return mIoU diff --git a/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/testenv.yaml b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/testenv.yaml new file mode 100644 index 00000000..d00b4aa6 --- /dev/null +++ b/examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/testenv.yaml @@ -0,0 +1,38 @@ +testenv: + # dataset configuration + dataset: + # the url address of train dataset index; string type; + train_url: 
"/home/QXY/dataset/mdil-ss/train/mdil-ss-train-index-small.txt" + # the url address of test dataset index; string type; + test_url: "/home/QXY/dataset/mdil-ss/test/mdil-ss-test-index-small.txt" + + # model eval configuration of incremental learning; + model_eval: + # metric used for model evaluation + model_metric: + # metric name; string type; + name: "accuracy" + # the url address of python file + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/accuracy.py" + mode: "no-inference" + + # condition of triggering inference model to update + # threshold of the condition; types are float/int + threshold: 0 + # operator of the condition; string type; + # values are ">=", ">", "<=", "<" and "="; + operator: "<" + + # metrics configuration for test case's evaluation; list type; + metrics: + # metric name; string type; + - name: "accuracy" + # the url address of python file + url: "./examples/class_increment_semantic_segmentation/lifelong_learning_bench/testenv/accuracy.py" + - name: "samples_transfer_ratio" + - name: "BWT" + - name: "FWT" + - name: "Matrix" + + # incremental rounds setting; int type; default value is 2; + incremental_rounds: 3 \ No newline at end of file