-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
75 lines (65 loc) · 2.73 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from clip import clip
def get_model(model_name, device):
    """Load a CLIP model by name onto the requested device.

    Thin wrapper that forwards both arguments to ``clip.load`` and hands
    back its result unchanged.

    Args:
        model_name: Model identifier understood by ``clip.load``.
        device: Target device, forwarded as the ``device`` keyword.

    Returns:
        Whatever ``clip.load`` returns (presumably the model together with
        its preprocessing transform -- confirm against the local ``clip``
        package).
    """
    loaded = clip.load(model_name, device=device)
    return loaded
import torch
from dataset import VisaDataset, MVTecDataset
import torch.nn.functional as F
import kornia as K
import torch
import numpy as np
def get_rot_mat(theta):
    """Return the 2x3 affine matrix of a planar rotation.

    Args:
        theta: Rotation angle in radians (anything ``torch.tensor`` accepts
            as a scalar).

    Returns:
        A 2x3 tensor ``[[cos, -sin, 0], [sin, cos, 0]]`` with a zero
        translation column.
    """
    theta = torch.tensor(theta)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)
    return torch.tensor([[cos_t, -sin_t, 0],
                         [sin_t, cos_t, 0]])


def get_translation_mat(a, b):
    """Return the 2x3 affine matrix of a pure translation.

    Args:
        a: Offset placed in the first row's translation entry.
        b: Offset placed in the second row's translation entry.

    Returns:
        A 2x3 tensor ``[[1, 0, a], [0, 1, b]]``. Its dtype is whatever
        ``torch.tensor`` infers from ``a`` and ``b``.
    """
    return torch.tensor([[1, 0, a],
                         [0, 1, b]])


def _affine_warp(x, affine_mat):
    """Resample every image in ``x`` through one shared 2x3 affine matrix."""
    # Build the sampling grid on x's device. The previous hard-coded
    # torch.FloatTensor pinned it to the CPU, which breaks grid_sample for
    # inputs living on any other device.
    theta = affine_mat.to(device=x.device, dtype=torch.float32)
    theta = theta[None, ...].repeat(x.shape[0], 1, 1)
    grid = F.affine_grid(theta, x.size(), align_corners=True)
    # grid_sample requires the grid and the input to share a dtype.
    grid = grid.to(dtype=x.dtype)
    return F.grid_sample(x, grid, padding_mode="reflection", align_corners=True)


def rot_img(x, theta):
    """Rotate a batch of images by ``theta`` radians about the image centre.

    The same rotation is applied to every item of the batch. Pixels that
    would be sampled from outside the image are filled by reflection.

    Args:
        x: Floating-point image batch; ``x.size()`` is handed to
            ``F.affine_grid`` together with a 2x3 matrix, so a 4-D
            (N, C, H, W) layout is expected.
        theta: Rotation angle in radians.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.
    """
    return _affine_warp(x, get_rot_mat(theta))


def translation_img(x, a, b):
    """Translate a batch of images by the offsets ``(a, b)``.

    The offsets are expressed in the normalized coordinates used by
    ``F.affine_grid`` (the image spans [-1, 1] along each axis). Pixels
    that would be sampled from outside the image are filled by reflection.

    Args:
        x: Floating-point image batch, expected as (N, C, H, W).
        a: Offset added to the sampling x-coordinate.
        b: Offset added to the sampling y-coordinate.

    Returns:
        A tensor with the same shape, dtype and device as ``x``.
    """
    return _affine_warp(x, get_translation_mat(a, b))
def hflip_img(x):
    """Return ``x`` flipped horizontally via kornia's ``hflip``.

    Args:
        x: Image tensor accepted by ``K.geometry.transform.hflip``.

    Returns:
        The tensor produced by ``hflip``; the input is passed through
        without any other processing.
    """
    return K.geometry.transform.hflip(x)
def rot90_img(x, k):
    """Rotate ``x`` by a multiple of 90 degrees using kornia's ``rotate``.

    Args:
        x: Image tensor accepted by ``K.geometry.transform.rotate``.
        k: Index into the angle table ``[0, 90, 180, 270, 360]`` degrees.
            Callers in this file use 1, 2 and 3; index 4 (a full turn) is
            kept for backward compatibility.

    Returns:
        The rotated tensor, with regions sampled from outside the image
        filled by reflection.
    """
    # Every entry is a float so the angle tensor is always floating point;
    # the former integer ``360`` produced an int64 angle for k == 4.
    degrees_by_k = [0., 90., 180., 270., 360.]
    # Create the angle next to the image so it never lives on a different
    # device than ``x``.
    degrees = torch.tensor(degrees_by_k[k], device=x.device)
    return K.geometry.transform.rotate(x, angle=degrees, padding_mode='reflection')
def grey_img(x):
    """Convert ``x`` to greyscale and tile it back to three channels.

    Args:
        x: RGB image batch accepted by ``K.color.rgb_to_grayscale``
            (presumably (N, 3, H, W) -- confirm against callers).

    Returns:
        The greyscale result repeated three times along dim 1, so it can
        stand in for an RGB batch.
    """
    grey = K.color.rgb_to_grayscale(x)
    return grey.repeat(1, 3, 1, 1)
def aug(support_img):
    """Build a shuffled batch of augmented views of ``support_img``.

    For the input batch the following views are produced and stacked along
    dim 0 (22 views per input batch in total):

    * the unmodified input,
    * 8 small-angle rotations (multiples of pi/16 up to +/- pi/4),
    * 8 translations by +/- 0.1 and +/- 0.2 in normalized coordinates,
    * 1 horizontal flip,
    * 1 greyscale copy,
    * 3 rotations by 90, 180 and 270 degrees.

    Args:
        support_img: Image batch tensor accepted by all of the helpers
            above (presumably (N, 3, H, W) float -- confirm against
            callers).

    Returns:
        All views concatenated along dim 0 and then randomly permuted along
        that dimension with ``torch.randperm``.
    """
    small_angles = [-np.pi / 4, -3 * np.pi / 16, -np.pi / 8, -np.pi / 16,
                    np.pi / 16, np.pi / 8, 3 * np.pi / 16, np.pi / 4]
    shifts = [(0.2, 0.2), (-0.2, 0.2), (-0.2, -0.2), (0.2, -0.2),
              (0.1, 0.1), (-0.1, 0.1), (-0.1, -0.1), (0.1, -0.1)]

    # Collect every view first and concatenate once; concatenating inside
    # the loops re-copied the whole growing batch on each iteration.
    views = [support_img]
    # rotate img with small angle
    views.extend(rot_img(support_img, angle) for angle in small_angles)
    # translate img
    views.extend(translation_img(support_img, a, b) for a, b in shifts)
    # hflip img
    views.append(hflip_img(support_img))
    # rgb to grey img
    views.append(grey_img(support_img))
    # rotate img in 90 degree steps
    views.extend(rot90_img(support_img, k) for k in (1, 2, 3))

    augment_support_img = torch.cat(views, dim=0)
    # Shuffle so the augmentation types are interleaved along dim 0.
    shuffled_order = torch.randperm(augment_support_img.size(0))
    return augment_support_img[shuffled_order]