# synth_img.py (117 lines, 2.98 KB)
import argparse
import pathlib
import yaml
import torch
import torchvision.transforms.functional as TF
import matplotlib.colors
from PIL import Image
import numpy as np
import options.gan
import datasets.nrw
from utils import unwrap_state_dict
from IPython import embed
def invert_colormap(img, cmap, norm):
    """Map an RGB image back to integer class labels.

    Each entry of ``cmap.colors`` corresponds to one integer label in
    ``[norm.vmin, norm.vmax]``; pixels that exactly match that entry's
    RGB value receive that label.  Pixels matching no colormap entry
    keep the initial label 0.

    Args:
        img: H x W x 3 RGB array with 0-255 channel values.
        cmap: matplotlib listed colormap whose ``colors`` encode the classes.
        norm: matplotlib normalization supplying the label range (vmin..vmax).

    Returns:
        H x W int32 array of class indices.
    """
    labels = np.zeros(img.shape[:2], dtype=np.int32)
    label_range = range(int(norm.vmin), int(norm.vmax) + 1)
    for color, label in zip(cmap.colors, label_range):
        # colormap entries may be hex strings; convert to 0-255 RGB
        red, green, blue = (255 * channel for channel in matplotlib.colors.to_rgb(color))
        match = (img[:, :, 0] == red) & (img[:, :, 1] == green)
        match &= img[:, :, 2] == blue
        labels[match] = label
    return labels
###################
#                 #
# Parse arguments #
#                 #
###################
# CLI: optional conditioning inputs (segmentation map and/or DEM),
# plus the generator checkpoint and the output image path.
parser = argparse.ArgumentParser(
    formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
    "--seg", type=str, help="segmentation map"
)
parser.add_argument(
    # fixed typo in help text: "digitial" -> "digital"
    "--dem", type=str, help="digital elevation model"
)
parser.add_argument("model", type=str)
parser.add_argument("output", type=str)
args = parser.parse_args()
########################
#                      #
# Get config and model #
#                      #
########################
# The checkpoint's directory also holds the training config.
OUT_DIR = pathlib.Path(args.model).absolute().parents[0]
# loading config
# yaml.load() without a Loader is a TypeError on PyYAML >= 6 and unsafe
# on older versions; safe_load handles plain config files correctly.
with open(OUT_DIR / "config.yml", "r") as stream:
    CONFIG = yaml.safe_load(stream)
print("config: {}".format(CONFIG))
# prefer GPU when any CUDA device is available
if torch.cuda.device_count() >= 1:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("loading model {}".format(args.model))
model = options.gan.get_generator(CONFIG)
# remove distributed wrapping, i.e. module. from keynames
state_dict = unwrap_state_dict(torch.load(args.model))
model.load_state_dict(state_dict)
model.eval()
model.to(device)
##############
# #
# Load image #
# #
##############
def seg2tensor(seg, num_classes=11):
    """Load a color-coded segmentation map and one-hot encode it.

    The image's colors are inverted back to integer labels via the
    NRW land-cover colormap, then one-hot encoded.

    Args:
        seg: path to the segmentation image file.
        num_classes: number of land-cover classes (default 11, the NRW
            land-cover scheme; generalized from the former hard-coded value).

    Returns:
        Float tensor of shape (1, num_classes, H, W).
    """
    seg = np.array(Image.open(seg))
    seg_inv = invert_colormap(seg, datasets.nrw.lcov_cmap, datasets.nrw.lcov_norm)
    seg_inv_one_hot = (
        torch.nn.functional.one_hot(TF.to_tensor(seg_inv).long(), num_classes)
        .squeeze()
        .permute(2, 0, 1)
        .float()
    )
    return seg_inv_one_hot.unsqueeze(0)
def dem2tensor(dem):
    """Load a digital elevation model image as a batched tensor.

    Args:
        dem: path to the DEM image file.

    Returns:
        Tensor of shape (1, C, H, W) as produced by ``to_tensor``.
    """
    array = np.array(Image.open(dem))
    tensor = TF.to_tensor(array)
    return tensor.unsqueeze(0)
# Assemble only the conditioning modalities the user supplied on the CLI.
loaders = {"seg": seg2tensor, "dem": dem2tensor}
paths = {"seg": args.seg, "dem": args.dem}
sample = {
    key: loaders[key](path) for key, path in paths.items() if path
}
# Inference only: no gradients needed.
with torch.no_grad():
    fake_rgb = model({key: tensor.to(device) for key, tensor in sample.items()})
def sar2rgb(sar):
    """Convert a [0, 1]-scaled SAR array to uint8, dropping singleton dims.

    Values outside [0, 1] saturate at 0 / 255 instead of wrapping.
    """
    scaled = np.clip(sar * 255, 0, 255)
    return np.squeeze(scaled.astype(np.uint8))
# for SAR
# fake_rgb = sar2rgb(fake_rgb.squeeze().cpu().numpy())
# for RGB: scale [0, 1] floats to bytes.  Clip before the uint8 cast so
# generator outputs outside [0, 1] saturate instead of wrapping around
# (matches the clipping sar2rgb already applies).
fake_rgb = np.clip(fake_rgb.squeeze().cpu().numpy() * 255, 0, 255).astype(np.uint8)
# CHW -> HWC, the layout PIL expects
fake_rgb = np.moveaxis(fake_rgb, 0, 2)
result = Image.fromarray(fake_rgb)
result.save(args.output)