Skip to content

Commit

Permalink
🎨 cleanup dm
Browse files Browse the repository at this point in the history
  • Loading branch information
nateraw committed Jul 29, 2020
1 parent e01975d commit 55146c4
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 13 deletions.
5 changes: 4 additions & 1 deletion pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ def __init__(
# link default data
if datamodule is None:
datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)

self.datamodule = datamodule
self.datamodule.prepare_data()

self.img_dim = self.datamodule.size()

Expand Down Expand Up @@ -121,6 +121,9 @@ def test_epoch_end(self, outputs):
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

def prepare_data(self):
self.datamodule.prepare_data()

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
Expand Down
9 changes: 0 additions & 9 deletions pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,6 @@ def configure_optimizers(self):
def prepare_data(self):
self.datamodule.prepare_data()

def train_dataloader(self):
return self.datamodule.train_dataloader(self.hparams.batch_size)

def val_dataloader(self):
return self.datamodule.val_dataloader(self.hparams.batch_size)

def test_dataloader(self):
return self.datamodule.test_dataloader(self.hparams.batch_size)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
Expand Down
3 changes: 0 additions & 3 deletions pl_bolts/models/gans/basic/basic_gan_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,9 +184,6 @@ def configure_optimizers(self):
def prepare_data(self):
self.datamodule.prepare_data()

def train_dataloader(self):
return self.datamodule.train_dataloader(self.hparams.batch_size)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
Expand Down

0 comments on commit 55146c4

Please sign in to comment.