-
Notifications
You must be signed in to change notification settings - Fork 0
/
infer.py
55 lines (45 loc) · 1.84 KB
/
infer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import os
from tqdm import tqdm
import torchvision.transforms as transforms
from recon_model import Recon_Transformer
from dataset import get_loader
from PIL import Image
# Default CUDA device configuration for this script.
# CUDA_DEVICE_ORDER=PCI_BUS_ID enumerates GPUs by PCI bus ID (the order nvidia-smi
# shows), so index 0 below refers to a predictable physical card.
# setdefault keeps any value the caller already exported
# (e.g. `CUDA_VISIBLE_DEVICES=2 python infer.py`) instead of silently overriding it.
# NOTE: these variables only take effect if CUDA has not been initialised yet in
# this process.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", '0,1,2')
# Index (among the *visible* devices) of the GPU that main() selects.
gpu_number = 0
def main(save_path='/home/ponoma/workspace/Basic_Transformer/checkpoint/',
         infer_results='./infer_results/'):
    """Reconstruct every image of the test split and save each result as a JPEG.

    Builds a ``Recon_Transformer`` with the hyper-parameters below, loads its
    weights from ``<save_path>/model.pth``, runs it over the test loader
    returned by ``get_loader`` (batch size 1) and writes each prediction to
    ``<infer_results>/<img_name>.jpg``.

    Args:
        save_path: Directory containing the ``model.pth`` checkpoint. The
            default is the original author's machine-specific path.
        infer_results: Output directory; created if it does not exist.

    Note:
        The architecture hyper-parameters must match the ones the checkpoint
        was trained with, otherwise ``load_state_dict`` (strict by default)
        raises on missing keys or shape mismatches.
    """
    dataset = "Mirflickr"
    num_heads = 4
    num_blocks = 6
    embed_dim = 128  # dimension of embedding/hidden layer in Transformer
    patch_size = 15  # 270 / 15 = 18 patches per side
    n_channels = 3
    ffn_multiplier = 2
    min_side_len = 270
    dropout_rate = 0.1
    num_workers = 4

    # exist_ok replaces the racy exists()-then-makedirs() check.
    os.makedirs(infer_results, exist_ok=True)

    if torch.cuda.is_available():
        torch.cuda.set_device(gpu_number)
        device = torch.device("cuda")  # i.e. the current device selected above
    else:
        device = torch.device("cpu")

    model = Recon_Transformer(min_side_len, patch_size, n_channels, num_heads,
                              num_blocks, embed_dim, ffn_multiplier, dropout_rate)
    # Deserialise onto the CPU first: without map_location, tensors saved on
    # e.g. cuda:2 are restored to cuda:2 and loading fails if that GPU is not
    # visible. It also avoids a second copy of the weights on the GPU; the
    # model is moved to the target device afterwards.
    state_dict = torch.load(os.path.join(save_path, 'model.pth'), map_location='cpu')
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()  # inference behaviour for layers such as dropout / batch-norm

    _, _, test_loader = get_loader(dataset, min_side_len, batch_size=1,
                                   num_workers=num_workers)
    to_pil = transforms.ToPILImage()  # build once instead of once per image

    with torch.no_grad():
        for inputs, _target, img_names in tqdm(test_loader):
            # The target is not needed for inference, so it is not moved to the device.
            outputs = model(inputs.to(device))
            # ToPILImage multiplies float tensors by 255 and casts to uint8, so
            # values outside [0, 1] would wrap around; clamp first. Assumes the
            # model predicts intensities in [0, 1] — TODO confirm.
            image = outputs.squeeze().clamp(0.0, 1.0).cpu()
            to_pil(image).save(os.path.join(infer_results, img_names[0] + '.jpg'),
                               format='JPEG')
if __name__ == "__main__":
    # Script entry point: run inference only when this file is executed
    # directly, not when it is imported as a module.
    main()