-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
42 lines (33 loc) · 1.18 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import matplotlib.pyplot as plt
import torch
from model_design import AutoEncoder
from data import get_dataset
def _show_image_grid(images, rows=10, cols=10, image_shape=(28, 28)):
    """Draw a batch of images as a ``rows`` x ``cols`` grid in a new figure.

    Only the first ``rows * cols`` samples are drawn, because
    ``plt.subplot`` raises ``ValueError`` for an index past the grid size.
    ``plt.show()`` is called once the grid is populated.

    Args:
        images: Tensor indexed by sample along its first dimension; each
            sample must hold exactly enough elements to be reshaped to
            ``image_shape``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        image_shape: ``(height, width)`` each sample is reshaped to.
    """
    plt.figure(figsize=(10, 10))
    # detach() replaces the deprecated ``.data`` access; cpu() is a no-op for
    # tensors that already live on the CPU.
    shown = images[: rows * cols].detach().cpu()
    for cell, sample in enumerate(shown, start=1):
        plt.subplot(rows, cols, cell)
        plt.imshow(sample.numpy().reshape(image_shape), cmap='gray')
        plt.axis('off')
    plt.subplots_adjust(hspace=0.1, wspace=0.1)
    plt.show()


def main():
    """Load the saved autoencoder and plot one test batch and its output.

    Two figures are shown in sequence: the first batch from the test loader,
    then the second value returned by the model for that batch (presumably
    the decoder's reconstruction -- confirm against ``AutoEncoder.forward``).
    """
    # Only the test loader is used here; the other returned values are ignored.
    _, _, test_loader, _ = get_dataset()

    auto_encoder_model = AutoEncoder()
    # The model is built on the CPU and its outputs are converted with
    # .numpy(), so map the checkpoint to the CPU as well; otherwise a
    # checkpoint saved from a GPU fails to load on a CPU-only machine.
    checkpoint = torch.load("./checkpoint/auto_encoder.pth", map_location="cpu")
    auto_encoder_model.load_state_dict(checkpoint["state_dict"])
    auto_encoder_model.eval()
    print("Load model done")

    # Take just the first batch without pulling a second one from the loader.
    test_data = next(iter(test_loader), None)
    if test_data is None:
        print("Test loader is empty; nothing to plot")
        return

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, test_decoder = auto_encoder_model(test_data)

    _show_image_grid(test_data)
    _show_image_grid(test_decoder)


if __name__ == '__main__':
    main()