Commit 85ee87c

temporarily disable validation
LoicGrobol committed May 29, 2020
1 parent 0f68169 commit 85ee87c
Showing 2 changed files with 43 additions and 34 deletions.
54 changes: 27 additions & 27 deletions zeldarose/mlm.py
@@ -171,33 +171,33 @@ def training_epoch_end(self, outputs):
         results = {"avg_train_loss": avg_loss, "train_perplexity": perplexity}
         return results
 
-    def validation_step(self, batch: zeldarose.data.TextBatch, batch_idx: int):
-        tokens, attention_mask, internal_tokens_mask, token_type_ids = batch
-        with torch.no_grad():
-            masked = mask_tokens(
-                inputs=tokens,
-                change_ratio=self.task_config.change_ratio,
-                keep_mask=internal_tokens_mask,
-                mask_ratio=self.task_config.mask_ratio,
-                input_mask_index=self.mask_token_index,
-                switch_ratio=self.task_config.switch_ratio,
-                vocabulary_size=self.vocabulary_size,
-            )
-
-        outputs = self.forward(
-            tokens=masked.inputs,
-            attention_mask=attention_mask,
-            mlm_labels=masked.labels,
-            token_type_ids=token_type_ids,
-        )
-
-        loss = outputs[0]
-
-        preds = torch.argmax(outputs[1], dim=-1)
-        correct_preds = preds.eq(masked.labels) & masked.labels.ne(-100)
-        accuracy = correct_preds.float().mean()
-
-        return {"val_loss": loss, "val_acc": accuracy}
+    # def validation_step(self, batch: zeldarose.data.TextBatch, batch_idx: int):
+    #     tokens, attention_mask, internal_tokens_mask, token_type_ids = batch
+    #     with torch.no_grad():
+    #         masked = mask_tokens(
+    #             inputs=tokens,
+    #             change_ratio=self.task_config.change_ratio,
+    #             keep_mask=internal_tokens_mask,
+    #             mask_ratio=self.task_config.mask_ratio,
+    #             input_mask_index=self.mask_token_index,
+    #             switch_ratio=self.task_config.switch_ratio,
+    #             vocabulary_size=self.vocabulary_size,
+    #         )
+
+    #     outputs = self.forward(
+    #         tokens=masked.inputs,
+    #         attention_mask=attention_mask,
+    #         mlm_labels=masked.labels,
+    #         token_type_ids=token_type_ids,
+    #     )
+
+    #     loss = outputs[0]
+
+    #     preds = torch.argmax(outputs[1], dim=-1)
+    #     correct_preds = preds.eq(masked.labels) & masked.labels.ne(-100)
+    #     accuracy = correct_preds.float().mean()
+
+    #     return {"val_loss": loss, "val_acc": accuracy}
 
     def validation_end(self, outputs):
         avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()
23 changes: 16 additions & 7 deletions zeldarose/train_transformer.py
@@ -5,7 +5,7 @@
 import sys
 import tempfile
 
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union
 
 import click
 import click_pathlib
@@ -371,6 +371,7 @@ def main(
     )
     val_set: Optional[data.TextDataset]
     if val_path is not None:
+        raise NotImplementedError("Epoch validation is not implemented yet")
         val_set = dataset_type(
             tokenizer=tokenizer,
             text_path=val_path,
@@ -428,13 +429,18 @@ def main(
     train_loader = data.TextLoader(
         train_set, batch_size=loader_batch_size, num_workers=n_workers, shuffle=True,
     )
-    val_loader: Optional[data.TextLoader]
+    val_loaders: Optional[List[data.TextLoader]]
     if val_set is not None:
-        val_loader = data.TextLoader(
-            val_set, batch_size=loader_batch_size, num_workers=n_workers, shuffle=False,
-        )
+        val_loaders = [
+            data.TextLoader(
+                val_set,
+                batch_size=loader_batch_size,
+                num_workers=n_workers,
+                shuffle=False,
+            )
+        ]
     else:
-        val_loader = None
+        val_loaders = None
 
     logger.info(f"Creating trainer")
     if profile:
@@ -455,6 +461,7 @@ def main(
         max_epochs=max_epochs,
         max_steps=max_steps,
         track_grad_norm=2,
+        val_percent_check=1.0 if val_loaders is not None else 0.0,
         **profile_kwargs,
     )
 
@@ -463,7 +470,9 @@ def main(
     else:
         logger.info(f"Training the model on CPU")
 
-    trainer.fit(finetuning_model, train_dataloader=train_loader, val_dataloaders=[val_loader])
+    trainer.fit(
+        finetuning_model, train_dataloader=train_loader, val_dataloaders=val_loaders
+    )
 
     # TODO: only on rank 0
     # TODO: this saves the model with the LM head but we might want the base model
