Skip to content

Commit

Permalink
added option for user to define data transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
IsaacCorley committed Jan 26, 2020
1 parent e581893 commit 2365ce7
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 20 deletions.
13 changes: 10 additions & 3 deletions torch_enhance/datasets/bsds300.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,17 @@ def __init__(
image_size=256,
color_space='RGB',
train=True,
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(BSDS300, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space
self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)
Expand All @@ -30,8 +35,10 @@ def __init__(
self.set_dir = os.path.join(self.root_dir, 'train' if train else 'test')
self.file_names = self.get_files(self.set_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down
13 changes: 10 additions & 3 deletions torch_enhance/datasets/bsds500.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ def __init__(
image_size=256,
color_space='RGB',
set_type='train',
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(BSDS500, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space
self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)
Expand All @@ -31,8 +36,10 @@ def __init__(
self.set_dir = os.path.join(self.root_dir, set_type)
self.file_names = self.get_files(self.set_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down
12 changes: 9 additions & 3 deletions torch_enhance/datasets/historical.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ def __init__(
scale_factor=2,
image_size=256,
color_space='L',
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(Historical, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space
self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)
Expand All @@ -29,8 +33,10 @@ def __init__(
self.download(data_dir)
self.file_names = self.get_files(self.root_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down
14 changes: 10 additions & 4 deletions torch_enhance/datasets/set14.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,29 @@ def __init__(
scale_factor=2,
image_size=256,
color_space='RGB',
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(Set14, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space

self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)

self.root_dir = os.path.join(data_dir, 'Set14')
self.download(data_dir)
self.file_names = self.get_files(self.root_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down
12 changes: 9 additions & 3 deletions torch_enhance/datasets/set5.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ def __init__(
scale_factor=2,
image_size=256,
color_space='RGB',
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(Set5, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space
self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)
Expand All @@ -29,8 +33,10 @@ def __init__(
self.download(data_dir)
self.file_names = self.get_files(self.root_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down
15 changes: 11 additions & 4 deletions torch_enhance/datasets/t91.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,29 @@ def __init__(
scale_factor=2,
image_size=256,
color_space='RGB',
data_dir=None
data_dir=None,
lr_transforms=None,
hr_transforms=None
):
super(T91, self).__init__()

self.scale_factor = scale_factor
self.image_size = image_size
self.color_space = color_space

self.lr_transforms = lr_transforms
self.hr_transforms = hr_transforms

if data_dir is None:
data_dir = os.path.join(os.getcwd(), self.base_dir)

self.root_dir = os.path.join(data_dir, 'T91')
self.download(data_dir)
self.file_names = self.get_files(self.root_dir)

self.lr_transform = self.get_lr_transforms()
self.hr_transform = self.get_hr_transforms()
if self.lr_transforms is None:
self.lr_transform = self.get_lr_transforms()
if self.hr_transforms is None:
self.hr_transform = self.get_hr_transforms()

def download(self, data_dir):

Expand Down

0 comments on commit 2365ce7

Please sign in to comment.