allow val dataloader
Nicki Skafte authored and Borda committed May 4, 2020
1 parent b6c271c · commit 24817e9
Showing 1 changed file with 5 additions and 2 deletions.
pytorch_lightning/trainer/lr_finder.py: 5 additions & 2 deletions

@@ -53,6 +53,7 @@ def _run_lr_finder_internally(self, model: LightningModule):
     def lr_find(self,
                 model: LightningModule,
                 train_dataloader: Optional[DataLoader] = None,
+                val_dataloaders: Optional[DataLoader] = None,
                 min_lr: float = 1e-8,
                 max_lr: float = 1,
                 num_training: int = 100,
@@ -133,7 +134,7 @@ def lr_find(self,
         self.checkpoint_callback = False
         self.early_stop_callback = None
         self.enable_early_stop = False
 
         # Required for saving the model
         self.optimizers, self.schedulers = [], [],
         self.model = model
@@ -152,7 +153,9 @@ def lr_find(self,
         model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])
 
         # Fit, lr & loss logged in callback
-        self.fit(model, train_dataloader=train_dataloader)
+        self.fit(model,
+                 train_dataloader=train_dataloader,
+                 val_dataloaders=val_dataloaders)
 
         # Prompt if we stopped early
         if self.global_step != num_training:
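
For context, a short usage sketch of the widened signature. It is illustrative only: TinyModel and random_loader are made-up helpers, and the surrounding API (Trainer.lr_find as a method, the finder's suggestion() helper) is assumed to match the Lightning version this commit targets (around 0.7.x); the diff itself guarantees only the new val_dataloaders parameter.

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    import pytorch_lightning as pl


    class TinyModel(pl.LightningModule):
        """Illustrative module; not part of the commit."""

        def __init__(self, lr=1e-3):
            super().__init__()
            self.lr = lr  # initial lr, used by configure_optimizers below
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.cross_entropy(self(x), y)
            return {'loss': loss}

        def validation_step(self, batch, batch_idx):
            x, y = batch
            return {'val_loss': torch.nn.functional.cross_entropy(self(x), y)}

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)


    def random_loader(n=64):
        # Random classification data, just to keep the sketch self-contained.
        x = torch.randn(n, 32)
        y = torch.randint(0, 2, (n,))
        return DataLoader(TensorDataset(x, y), batch_size=16)


    trainer = pl.Trainer()
    model = TinyModel()

    # After this commit, a validation loader passed to lr_find is forwarded
    # to the internal fit() call instead of being unavailable entirely.
    lr_finder = trainer.lr_find(model,
                                train_dataloader=random_loader(),
                                val_dataloaders=random_loader())
    print(lr_finder.suggestion())  # suggested learning rate from the loss curve

The patch itself is deliberately small: lr_find gains a val_dataloaders parameter and forwards it to the internal self.fit(...) call, so a validation loader supplied by the caller is no longer dropped.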
