-
Notifications
You must be signed in to change notification settings - Fork 11
/
myimgfolder.py
49 lines (40 loc) · 1.57 KB
/
myimgfolder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from torchvision import datasets, transforms
from skimage.color import rgb2lab, rgb2gray
import torch
import numpy as np
#import matplotlib.pyplot as plt
# Downscaling pipeline used by ValImageFolder to build the small network input:
# resize so the shorter side is 256 px, then take a 224x224 crop.
# `transforms.Resize` replaces the deprecated `transforms.Scale` (same behavior
# for an int size), which no longer exists in current torchvision releases.
# No ToTensor step: the output stays a PIL image and the dataset converts it
# with np.asarray itself.
# NOTE(review): RandomCrop makes validation inputs non-deterministic between
# runs; CenterCrop may have been intended — confirm before changing.
scale_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
])
class TrainImageFolder(datasets.ImageFolder):
    """ImageFolder yielding (grayscale, ab-chroma) pairs for colorization training.

    Each item is ``((gray, ab), target)``: the grayscale version of the
    (optionally transformed) image as input, and the rescaled a/b channels of
    its CIELAB conversion as the regression target.
    """

    def __getitem__(self, index):
        """Load sample ``index`` and return ``((gray, ab), target)``.

        Args:
            index: Position in ``self.imgs``.

        Returns:
            ``((gray, ab), target)`` where
            * ``gray`` is a float64 tensor produced by ``rgb2gray`` (2-D, H x W,
              for an RGB input);
            * ``ab`` is a float64 tensor of shape (2, H, W): the a/b channels of
              ``rgb2lab`` shifted by +128 and divided by 255;
            * ``target`` is the class index, passed through
              ``self.target_transform`` when one is set.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        # Use the raw image when no transform is configured; previously the
        # working variable was only bound inside this branch, so a dataset
        # created without `transform` raised UnboundLocalError.
        if self.transform is not None:
            img = self.transform(img)
        # NOTE(review): assumes `transform` returns a PIL image / HWC array,
        # not a CHW tensor (e.g. no ToTensor) — rgb2lab expects channels last.
        rgb = np.asarray(img)

        img_lab = rgb2lab(rgb)
        # Shift and scale so the a/b channels (roughly [-128, 127]) land near [0, 1].
        img_lab = (img_lab + 128) / 255
        img_ab = img_lab[:, :, 1:3]
        # HWC -> CHW, the channel-first layout used by PyTorch.
        img_ab = torch.from_numpy(img_ab.transpose((2, 0, 1)))

        img_gray = torch.from_numpy(rgb2gray(rgb))

        if self.target_transform is not None:
            target = self.target_transform(target)
        return (img_gray, img_ab), target
class ValImageFolder(datasets.ImageFolder):
    """ImageFolder yielding full-size and downscaled grayscale views for validation.

    Each item is ``((gray_full, gray_scaled), target)``. ``gray_full`` is the
    grayscale of the image as loaded. ``gray_scaled`` is the grayscale of the
    image after the module-level ``scale_transform`` (resize + 224 crop).
    """

    def __getitem__(self, index):
        """Load sample ``index`` and return ``((gray_full, gray_scaled), target)``.

        Args:
            index: Position in ``self.imgs``.

        Returns:
            ``((gray_full, gray_scaled), target)`` where both images are float64
            tensors from ``rgb2gray``, and ``target`` is the class index passed
            through ``self.target_transform`` when one is set.

        NOTE(review): ``self.transform`` is not applied here; the fixed
        ``scale_transform`` is used instead — confirm this is intended.
        """
        path, target = self.imgs[index]
        img = self.loader(path)
        # scale_transform (resize, crop) returns new images and leaves `img`
        # untouched, so no defensive copy is needed before reusing `img` below.
        img_scale = np.asarray(scale_transform(img))
        img_original = np.asarray(img)

        img_scale = torch.from_numpy(rgb2gray(img_scale))
        img_original = torch.from_numpy(rgb2gray(img_original))

        # Honor target_transform like TrainImageFolder and ImageFolder do.
        if self.target_transform is not None:
            target = self.target_transform(target)
        return (img_original, img_scale), target