-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils.py
59 lines (46 loc) · 1.78 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import matplotlib.pyplot as plt
import numpy as np
from pytorch_toolbelt.utils import fs
from sklearn.utils import compute_sample_weight
def get_balanced_weights(dataset):
    """Compute a class-balanced sample weight for each mask in the dataset.

    Every mask is summarised by the set of distinct label values it contains.
    Masks that share the same set of labels are treated as one class, and
    ``compute_sample_weight('balanced', ...)`` is applied to those classes.

    Args:
        dataset: Object exposing ``masks_fps``, an iterable of mask locations
            that ``fs.read_image_as_is`` can read.

    Returns:
        The result of ``compute_sample_weight``: one weight per entry of
        ``dataset.masks_fps``.
    """
    labels = []
    for mask_path in dataset.masks_fps:
        mask = fs.read_image_as_is(mask_path)
        # np.unique returns the distinct values sorted, so the key does not
        # depend on where the labels occur in the mask.
        # Join with a separator: without one, the label sets {1, 12} and
        # {11, 2} would both collapse to "112" and be weighted as one class.
        labels.append('_'.join(str(int(value)) for value in np.unique(mask)))
    return compute_sample_weight('balanced', labels)
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument is one image. Its keyword name, with underscores
    replaced by spaces and title-cased, is used as the subplot title.
    """
    column_count = len(images)
    plt.figure(figsize=(16, 5))
    for position, key in enumerate(images, start=1):
        plt.subplot(1, column_count, position)
        # Hide the axis ticks; only the picture and its title are shown.
        plt.xticks([])
        plt.yticks([])
        plt.title(key.replace('_', ' ').title())
        plt.imshow(images[key])
    plt.show()
# Class label -> RGB color.
palette = {
    0: (0, 0, 0),        # no data
    1: (77, 255, 0),     # improved all-weather road
    2: (204, 0, 0),      # railway
    3: (230, 128, 0),    # footpath
    4: (255, 0, 0),      # caravan route
    5: (0, 204, 242),    # mule track
}

# Reverse lookup: RGB color -> class label.
invert_palette = dict(zip(palette.values(), palette.keys()))
def convert_to_color(arr_2d, palette=palette):
    """Convert an array of numeric labels into an RGB-encoded image.

    Args:
        arr_2d: Array of label values; its first two dimensions give the
            height and width of the output.
        palette: Mapping from label value to an RGB color triple.

    Returns:
        A ``uint8`` array of shape ``(rows, cols, 3)``. Positions whose label
        is not a key of ``palette`` keep the initial value ``(0, 0, 0)``.
    """
    rows = arr_2d.shape[0]
    cols = arr_2d.shape[1]
    rgb = np.zeros((rows, cols, 3), dtype=np.uint8)
    for label, color in palette.items():
        # Paint every position carrying this label with its color.
        rgb[arr_2d == label] = color
    return rgb
def convert_from_color(arr_3d, palette=invert_palette):
    """Convert an RGB-encoded image back into an array of numeric labels.

    Args:
        arr_3d: Array whose first two dimensions are rows and columns and
            whose third dimension holds the three color channels.
        palette: Mapping from an RGB color triple to a label value.

    Returns:
        A ``uint8`` array of shape ``(rows, cols)``. Pixels whose color is not
        a key of ``palette`` keep the initial value ``0``.
    """
    rows = arr_3d.shape[0]
    cols = arr_3d.shape[1]
    label_map = np.zeros((rows, cols), dtype=np.uint8)
    for color, label in palette.items():
        # A pixel matches only when all three channels equal the color.
        channel_hits = arr_3d == np.array(color).reshape(1, 1, 3)
        matches = np.all(channel_hits, axis=2)
        label_map[matches] = label
    return label_map