-
Notifications
You must be signed in to change notification settings - Fork 3
/
dataset.py
86 lines (66 loc) · 2.08 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""
Define a Dataset class that inherits from PyTorch's base Dataset class, plus transforms for its samples.
"""
import os
import random
import numpy as np
import torch
from torchvision import transforms
from torch.utils.data import Dataset
from scipy.io import loadmat
class Data(Dataset):
    """Paired (x, y) dataset loaded from two MATLAB ``.mat`` files.

    Each file must contain a 3-D array stored under a variable with the same
    name as the file. The last axis of each array indexes samples; the arrays
    are transposed so that samples are indexed along the first axis instead.
    """

    # Sentinel meaning "no transform argument was given". It lets ``None``
    # (or any falsy value) keep its existing meaning of "apply no transform".
    _DEFAULT_TRANSFORM = object()

    def __init__(self, filename_x='data_25', filename_y='data_125',
                 directory="Data/", transform=_DEFAULT_TRANSFORM):
        """Load both arrays and record the per-sample shapes.

        Args:
            filename_x: input file name without the ``.mat`` suffix (scipy's
                ``loadmat`` appends it); also the variable read from the file.
            filename_y: same as ``filename_x``, for the target array.
            directory: directory containing both files.
            transform: callable applied to every ``{'x': ..., 'y': ...}``
                sample dict returned by ``__getitem__``. A falsy value such as
                ``None`` disables it. When omitted, this module's dict-aware
                ``ToTensor`` is used.

        Raises:
            ValueError: if the two arrays hold different numbers of samples.
        """
        x = loadmat(os.path.join(directory, filename_x))[filename_x]
        y = loadmat(os.path.join(directory, filename_y))[filename_y]

        if transform is Data._DEFAULT_TRANSFORM:
            # torchvision's transforms.ToTensor accepts only a single image or
            # ndarray and raises on the dict samples built in __getitem__.
            # Default instead to the dict-aware ToTensor defined later in this
            # module; the name is looked up at call time, after the module has
            # fully loaded.
            transform = ToTensor()
        self.transform = transform

        # Move the last (sample) axis to the front: (A, B, N) -> (N, A, B).
        x = x.transpose(2, 0, 1)
        y = y.transpose(2, 0, 1)
        if x.shape[0] != y.shape[0]:
            raise ValueError(
                f"{filename_x!r} has {x.shape[0]} samples but "
                f"{filename_y!r} has {y.shape[0]}; they must be paired."
            )
        self.data = {'X': x, 'Y': y}

        # Per-sample 2-D shapes, saved for creating models.
        self.input_dim = x.shape[-2:]
        self.output_dim = y.shape[-2:]
        # Target shape with its last axis reduced to n // 2 + 1, the output
        # length of a real FFT along that axis (presumably for an f-k
        # representation -- TODO confirm with the model code).
        self.output_dim_fk = list(self.output_dim)
        self.output_dim_fk[-1] = self.output_dim_fk[-1] // 2 + 1

    def __len__(self):
        """Return the number of samples."""
        return self.data['X'].shape[0]

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'x': ..., 'y': ...}``.

        The dict is passed through ``self.transform`` when that is truthy.
        """
        sample = {'x': self.data['X'][idx], 'y': self.data['Y'][idx]}
        if self.transform:
            sample = self.transform(sample)
        return sample
class ToTensor(object):
    """Turn the numpy arrays of a ``{'x', 'y'}`` sample into torch tensors.

    Each array gains a leading axis of size 1 (``unsqueeze(0)``). Only the
    ``'x'`` and ``'y'`` entries are kept in the returned dict.
    """

    def __call__(self, sample):
        tensors = {}
        for key in ('x', 'y'):
            # copy() yields a fresh positive-stride array that the tensor
            # owns. torch.from_numpy rejects negatively-strided views such as
            # the output of np.fliplr.
            owned = sample[key].copy()
            tensors[key] = torch.from_numpy(owned).unsqueeze(0)
        return tensors
class RandomHorizontalFlip(object):
    """Randomly mirror a ``{'x', 'y'}`` sample left-to-right.

    ``x`` and ``y`` are flipped together or not at all. A single
    ``random.random()`` draw is made per call, and the sample is flipped when
    that draw is below ``flip_p``. Only the ``'x'`` and ``'y'`` entries are
    kept in the returned dict.
    """

    def __init__(self, flip_p=0.5):
        # Flip probability, compared against random.random() in __call__.
        self.flip_p = flip_p

    def __call__(self, sample):
        inp, target = sample['x'], sample['y']
        should_flip = random.random() < self.flip_p
        if should_flip:
            # np.fliplr reverses axis 1 and returns a view (negative stride).
            return {'x': np.fliplr(inp), 'y': np.fliplr(target)}
        return {'x': inp, 'y': target}