-
Notifications
You must be signed in to change notification settings - Fork 5
/
potsdam.py
66 lines (56 loc) · 2.42 KB
/
potsdam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import DATASETS
from .custom import CustomDataset
from mmseg.core import intersect_and_union
from .pipelines import LoadAnnotationsReduceIgnoreIndex
@DATASETS.register_module()
class PotsdamDataset(CustomDataset):
    """ISPRS Potsdam dataset.

    In segmentation map annotation for Potsdam dataset, 0 is the ignore
    index, so ``reduce_zero_label`` is always passed as True to the base
    class. The ``img_suffix`` and ``seg_map_suffix`` are both fixed to
    '.png'.

    After the base class has been initialised, ``gt_seg_map_loader`` is
    replaced by a ``LoadAnnotationsReduceIgnoreIndex`` instance built with
    ``ignore_index=6`` and ``reduce_zero_label=True``; ``pre_eval`` then
    calls ``intersect_and_union`` with ``reduce_zero_label=False``.
    """

    CLASSES = (
        'impervious_surface',
        'building',
        'low_vegetation',
        'tree',
        'car',
    )

    # One RGB colour per entry of ``CLASSES``, in the same order.
    PALETTE = [
        [255, 255, 255],
        [0, 0, 255],
        [0, 255, 255],
        [0, 255, 0],
        [255, 255, 0],
    ]

    def __init__(self, **kwargs):
        """Build the dataset with the Potsdam-specific fixed settings.

        Args:
            **kwargs: Forwarded unchanged to ``CustomDataset.__init__``
                together with the fixed ``img_suffix``,
                ``seg_map_suffix`` and ``reduce_zero_label`` values.
        """
        super().__init__(
            img_suffix='.png',
            seg_map_suffix='.png',
            reduce_zero_label=True,
            **kwargs)
        # Override the ground-truth loader set up by the base class.
        # NOTE(review): presumably this loader also maps label 6 onto the
        # ignore index while reducing the zero label -- confirm against
        # ``LoadAnnotationsReduceIgnoreIndex`` in ``.pipelines``.
        self.gt_seg_map_loader = LoadAnnotationsReduceIgnoreIndex(
            ignore_index=6, reduce_zero_label=True)

    def pre_eval(self, preds, indices):
        """Collect eval result from each iteration.

        Args:
            preds (list[torch.Tensor] | torch.Tensor): the segmentation logit
                after argmax, shape (N, H, W).
            indices (list[int] | int): the prediction related ground truth
                indices.

        Returns:
            list[torch.Tensor]: (area_intersect, area_union, area_prediction,
                area_ground_truth).
        """
        # Wrap bare values in a list so that single-sample and batched
        # inference go through the same code path.
        index_list = indices if isinstance(indices, list) else [indices]
        pred_list = preds if isinstance(preds, list) else [preds]

        # The labels have already been converted when the dataset was
        # initialized in `get_palette_for_custom_classes`, so `label_map`
        # must be an empty dict here, see
        # https://github.com/open-mmlab/mmsegmentation/issues/1415
        # for more details.
        # NOTE(review): `reduce_zero_label` is False here, presumably
        # because the ground-truth loader already applied the reduction --
        # confirm against `LoadAnnotationsReduceIgnoreIndex`.
        return [
            intersect_and_union(
                pred,
                self.get_gt_seg_map_by_idx(idx),
                len(self.CLASSES),
                self.ignore_index,
                label_map={},
                reduce_zero_label=False)
            for pred, idx in zip(pred_list, index_list)
        ]