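# nst.py
# Neural style transfer (Gatys et al.): the content and style images stay fixed
# while gradient descent updates only the generated image, using features from
# a pretrained VGG-19.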
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.utils import save_image

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        # In convx_y, x increments by 1 after each maxpool and y counts the
        # conv layers between maxpools. These module indices (0, 5, 10, 19, 28)
        # correspond to conv1_1, conv2_1, conv3_1, conv4_1, conv5_1 from the NST paper.
        self.chosen_features = ["0", "5", "10", "19", "28"]
        # We don't need to run anything past conv5_1 (module index 28 in vgg19).
        # Remember, we don't actually care about the output of VGG: the only thing
        # that is modified is the generated image (i.e., the input).
        # Note: newer torchvision versions prefer
        # models.vgg19(weights=models.VGG19_Weights.DEFAULT) over pretrained=True.
        self.model = models.vgg19(pretrained=True).features[:29]
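        # Optional (not in the original script): freeze the VGG weights so
        # backward() doesn't accumulate gradients we never use; only the
        # generated image is optimized.
        # for param in self.model.parameters():
        #     param.requires_grad_(False)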

    def forward(self, x):
        # Store relevant features
        features = []
        # Go through each layer in the model; if the layer index is in
        # chosen_features, store its activation. At the end we return the
        # activations for exactly the layers listed in chosen_features.
        for layer_num, layer in enumerate(self.model):
            x = layer(x)
            if str(layer_num) in self.chosen_features:
                features.append(x)
        return features
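
# Quick shape check (illustrative, not in the original script):
# feats = VGG()(torch.randn(1, 3, 356, 356))
# print([f.shape for f in feats])  # 64, 128, 256, 512, 512 channels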

def load_image(image_name):
    # Convert to RGB so images with an alpha channel (e.g. RGBA PNGs)
    # don't break VGG's 3-channel input.
    image = Image.open(image_name).convert("RGB")
    image = loader(image).unsqueeze(0)
    return image.to(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
imsize = 356
# Here we may want to use the normalization constants of the original
# VGG training data (so values resemble what the net was trained on), but
# I found it didn't matter too much, so I didn't end up using it. If you
# do use it, make sure to denormalize before saving so the images don't look weird.
loader = transforms.Compose(
    [
        transforms.Resize((imsize, imsize)),
        transforms.ToTensor(),
        # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
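
# If you enable Normalize above, something like this (a sketch, not in the
# original) undoes it before save_image; mean/std are the ImageNet constants:
# def denormalize(img):
#     mean = torch.tensor([0.485, 0.456, 0.406], device=img.device).view(1, 3, 1, 1)
#     std = torch.tensor([0.229, 0.224, 0.225], device=img.device).view(1, 3, 1, 1)
#     return img * std + mean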
original_img = load_image("annahathaway.png")
style_img = load_image("style.jpg")

# Initialize the generated image as white noise or as a clone of the original
# image; the clone seemed to work better for me.
# generated = torch.randn(original_img.data.shape, device=device, requires_grad=True)
generated = original_img.clone().requires_grad_(True)
model = VGG().to(device).eval()
# Hyperparameters
total_steps = 6000
learning_rate = 0.001
alpha = 1
beta = 0.01
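# alpha weights the content (original image) loss and beta the style loss in
# total_loss below; the ratio alpha/beta controls the content/style trade-off.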
optimizer = optim.Adam([generated], lr=learning_rate)

for step in range(total_steps):
    # Obtain the convolutional features at the specifically chosen layers
    generated_features = model(generated)
    original_img_features = model(original_img)
    style_features = model(style_img)

    # Both losses start at 0
    style_loss = original_loss = 0
    # Iterate through all the features for the chosen layers
    for gen_feature, orig_feature, style_feature in zip(
        generated_features, original_img_features, style_features
    ):
        # batch_size will just be 1
        batch_size, channel, height, width = gen_feature.shape
        original_loss += torch.mean((gen_feature - orig_feature) ** 2)

        # Compute Gram matrix of the generated image's features
        G = gen_feature.view(channel, height * width).mm(
            gen_feature.view(channel, height * width).t()
        )
        # Compute Gram matrix of the style image's features
        A = style_feature.view(channel, height * width).mm(
            style_feature.view(channel, height * width).t()
        )
        style_loss += torch.mean((G - A) ** 2)
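        # Note (not in the original): the NST paper instead scales the Gram
        # difference by 1 / (4 * channel**2 * (height * width)**2); torch.mean
        # here averages over the channel x channel entries of G - A, which
        # folds part of that normalization into beta.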

    total_loss = alpha * original_loss + beta * style_loss
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    if step % 200 == 0:
        print(total_loss.item())
        save_image(generated, "generated.png")
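
# To keep intermediate results instead of overwriting generated.png each time,
# one option (a sketch, not in the original) is a step-tagged filename:
# save_image(generated, f"generated_{step}.png")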