-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
35 lines (25 loc) · 1020 Bytes
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from utils import save_checkpoint, load_checkpoint, inference
import torch.nn as nn
import torch.optim as optim
import config
from dataset import MapDataset
from generator import Generator
from discriminator import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image
import numpy as np
def inference_image():
    """Run the generator on the single image named in ``config``.

    Builds a ``Generator``, restores its state from ``config.CHECKPOINT_GEN``
    via ``load_checkpoint``, reads the image at
    ``config.IMAGE_INFERENCE_PATH``, applies ``config.both_transform``
    followed by ``config.transform_only_input``, and hands the result to
    ``inference`` together with ``config.SAVE_INFER``.

    NOTE(review): ``config.SAVE_INFER`` is presumably the output location
    used by ``utils.inference`` -- confirm against that helper.
    """
    gen = Generator(in_channels=3, features=64).to(config.DEVICE)

    # load_checkpoint is called with an optimizer and a learning rate next to
    # the model, so an Adam instance is created only to satisfy that call; it
    # is never stepped in this function.
    opt_gen = optim.Adam(
        gen.parameters(), lr=config.LEARNING_RATE, betas=(0.5, 0.999)
    )
    load_checkpoint(
        config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE,
    )

    # Pillow opens files lazily and keeps the handle until the image is
    # closed, so use a context manager. Converting to RGB guarantees the
    # three channels the generator was built for (in_channels=3), even when
    # the file is grayscale, palette-based, or carries an alpha channel.
    with Image.open(config.IMAGE_INFERENCE_PATH) as img:
        input_image = np.array(img.convert("RGB"))

    augmentations = config.both_transform(image=input_image)
    input_image = augmentations["image"]
    input_image = config.transform_only_input(image=input_image)["image"]

    inference(gen, input_image, config.SAVE_INFER)
if __name__ == "__main__":
    # Only run inference when this file is executed as a script, so that
    # importing the module has no side effects.
    inference_image()