Merge pull request #231 from Mingosnake/pl_version_update
feat: update pl version supports
gwkrsrch committed Jul 31, 2023
2 parents 681d9aa + 1e66b65 commit 4cfcf97
Showing 3 changed files with 51 additions and 22 deletions.
1 change: 1 addition & 0 deletions donut/model.py
@@ -69,6 +69,7 @@ def __init__(
             num_heads=[4, 8, 16, 32],
             num_classes=0,
         )
+        self.model.norm = None

         # weight init with swin
         if not name_or_path:
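A note on this hunk: Donut's SwinEncoder drives the timm SwinTransformer through its own forward pass and, as far as this codebase goes, never calls the backbone's trailing norm layer, so assigning self.model.norm = None appears to prune an unused module (likely a compatibility tweak for newer timm releases) rather than change the computation. A minimal sketch of the pattern, with a hypothetical SwinFeatureEncoder standing in for the real class:

    import timm
    import torch
    from torch import nn

    class SwinFeatureEncoder(nn.Module):
        """Hypothetical stand-in for Donut's SwinEncoder, not the real class."""

        def __init__(self):
            super().__init__()
            # Attribute names (patch_embed, layers, norm) follow timm's SwinTransformer.
            self.model = timm.create_model("swin_base_patch4_window12_384", pretrained=False)
            self.model.norm = None  # safe to drop: the forward below never calls it

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Reuse the backbone's patch embedding and transformer stages,
            # skipping its final LayerNorm and classification head.
            x = self.model.patch_embed(x)
            x = self.model.layers(x)
            return x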
35 changes: 18 additions & 17 deletions lightning_module.py
@@ -44,6 +44,8 @@ def __init__(self, config):
                     # encoder_layer=[2,2,14,2], decoder_layer=4, ...
                 )
             )
+        self.pytorch_lightning_version_is_1 = int(pl.__version__[0]) < 2
+        self.num_of_loaders = len(self.config.dataset_name_or_paths)

     def training_step(self, batch, batch_idx):
         image_tensors, decoder_input_ids, decoder_labels = list(), list(), list()
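The new pytorch_lightning_version_is_1 flag keys off int(pl.__version__[0]), which reads only the first character of the version string. That is fine for the 1.x/2.x split this PR targets, but it would misparse a double-digit major release; a slightly more defensive variant (not part of the PR) splits on the dot instead:

    import pytorch_lightning as pl

    # Parse the full major version rather than the first character;
    # equivalent for 1.x/2.x, robust to a hypothetical 10.x release.
    pl_major = int(pl.__version__.split(".")[0])
    pytorch_lightning_version_is_1 = pl_major < 2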
@@ -56,9 +58,16 @@ def training_step(self, batch, batch_idx):
         decoder_labels = torch.cat(decoder_labels)
         loss = self.model(image_tensors, decoder_input_ids, decoder_labels)[0]
         self.log_dict({"train_loss": loss}, sync_dist=True)
+        if not self.pytorch_lightning_version_is_1:
+            self.log('loss', loss, prog_bar=True)
         return loss

-    def validation_step(self, batch, batch_idx, dataset_idx=0):
+    def on_validation_epoch_start(self) -> None:
+        super().on_validation_epoch_start()
+        self.validation_step_outputs = [[] for _ in range(self.num_of_loaders)]
+        return
+
+    def validation_step(self, batch, batch_idx, dataloader_idx=0):
         image_tensors, decoder_input_ids, prompt_end_idxs, answers = batch
         decoder_prompts = pad_sequence(
             [input_id[: end_idx + 1] for input_id, end_idx in zip(decoder_input_ids, prompt_end_idxs)],
@@ -84,17 +93,16 @@ def validation_step(self, batch, batch_idx, dataset_idx=0):
                 self.print(f"    Answer: {answer}")
                 self.print(f" Normed ED: {scores[0]}")

+        self.validation_step_outputs[dataloader_idx].append(scores)
+
         return scores

-    def validation_epoch_end(self, validation_step_outputs):
-        num_of_loaders = len(self.config.dataset_name_or_paths)
-        if num_of_loaders == 1:
-            validation_step_outputs = [validation_step_outputs]
-        assert len(validation_step_outputs) == num_of_loaders
-        cnt = [0] * num_of_loaders
-        total_metric = [0] * num_of_loaders
-        val_metric = [0] * num_of_loaders
-        for i, results in enumerate(validation_step_outputs):
+    def on_validation_epoch_end(self):
+        assert len(self.validation_step_outputs) == self.num_of_loaders
+        cnt = [0] * self.num_of_loaders
+        total_metric = [0] * self.num_of_loaders
+        val_metric = [0] * self.num_of_loaders
+        for i, results in enumerate(self.validation_step_outputs):
             for scores in results:
                 cnt[i] += len(scores)
                 total_metric[i] += np.sum(scores)
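This hunk follows the standard Lightning 2.x migration recipe: validation_epoch_end(outputs) no longer exists, so the module buffers per-step results itself (reset in on_validation_epoch_start, appended in validation_step) and aggregates them in on_validation_epoch_end. A self-contained toy sketch of the same pattern, independent of Donut's metrics:

    import torch
    import pytorch_lightning as pl

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            self.validation_step_outputs = []

        def on_validation_epoch_start(self):
            self.validation_step_outputs.clear()  # fresh buffer every epoch

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self.validation_step_outputs.append(loss)  # buffer manually in PL 2.x
            return loss

        def on_validation_epoch_end(self):
            # Aggregate what validation_epoch_end(outputs) used to receive.
            epoch_loss = torch.stack(self.validation_step_outputs).mean()
            self.log("val_loss", epoch_loss, sync_dist=True)

        def configure_optimizers(self):
            return torch.optim.SGD(self.parameters(), lr=1e-3)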
@@ -136,13 +144,6 @@ def lr_lambda(current_step):

         return LambdaLR(optimizer, lr_lambda)

-    def get_progress_bar_dict(self):
-        items = super().get_progress_bar_dict()
-        items.pop("v_num", None)
-        items["exp_name"] = f"{self.config.get('exp_name', '')}"
-        items["exp_version"] = f"{self.config.get('exp_version', '')}"
-        return items
-
     @rank_zero_only
     def on_save_checkpoint(self, checkpoint):
         save_path = Path(self.config.result_path) / self.config.exp_name / self.config.exp_version
37 changes: 32 additions & 5 deletions train.py
@@ -51,8 +51,34 @@ def save_config_file(config, path):
     print(f"Config is saved at {save_path}")


+class ProgressBar(pl.callbacks.TQDMProgressBar):
+    def __init__(self, config):
+        super().__init__()
+        self.enable = True
+        self.config = config
+
+    def disable(self):
+        self.enable = False
+
+    def get_metrics(self, trainer, model):
+        items = super().get_metrics(trainer, model)
+        items.pop("v_num", None)
+        items["exp_name"] = f"{self.config.get('exp_name', '')}"
+        items["exp_version"] = f"{self.config.get('exp_version', '')}"
+        return items
+
+
+def set_seed(seed):
+    pytorch_lightning_version = int(pl.__version__[0])
+    if pytorch_lightning_version < 2:
+        pl.utilities.seed.seed_everything(seed, workers=True)
+    else:
+        import lightning_fabric
+        lightning_fabric.utilities.seed.seed_everything(seed, workers=True)
+
+
 def train(config):
-    pl.utilities.seed.seed_everything(config.get("seed", 42), workers=True)
+    set_seed(config.get("seed", 42))

     model_module = DonutModelPLModule(config)
     data_module = DonutDataPLModule(config)
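The new ProgressBar callback is the 2.x home for what get_progress_bar_dict used to do on the LightningModule (the method deleted from lightning_module.py above): custom bar fields now come from overriding get_metrics on a TQDMProgressBar. A hedged usage sketch, with a plain dict standing in for Donut's config object (the real one is loaded from YAML but exposes the same .get interface):

    import pytorch_lightning as pl

    config = {"exp_name": "donut_demo", "exp_version": "v1"}  # hypothetical values

    bar = ProgressBar(config)  # the callback defined in the hunk above
    trainer = pl.Trainer(callbacks=[bar], max_epochs=1)
    # The bar now shows exp_name/exp_version in place of the default v_num.

As for set_seed: pl.utilities.seed.seed_everything was removed in 2.0 (the implementation moved into lightning_fabric), which is why the helper branches on the major version; the top-level pl.seed_everything shortcut appears to exist in both major lines and would be a one-line alternative.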
@@ -111,11 +137,12 @@ def train(config):
         mode="min",
     )

+    bar = ProgressBar(config)
+
     custom_ckpt = CustomCheckpointIO()
     trainer = pl.Trainer(
-        resume_from_checkpoint=config.get("resume_from_checkpoint_path", None),
         num_nodes=config.get("num_nodes", 1),
-        gpus=torch.cuda.device_count(),
+        devices=torch.cuda.device_count(),
         strategy="ddp",
         accelerator="gpu",
         plugins=custom_ckpt,
@@ -127,10 +154,10 @@
         precision=16,
         num_sanity_val_steps=0,
         logger=logger,
-        callbacks=[lr_callback, checkpoint_callback],
+        callbacks=[lr_callback, checkpoint_callback, bar],
     )

-    trainer.fit(model_module, data_module)
+    trainer.fit(model_module, data_module, ckpt_path=config.get("resume_from_checkpoint_path", None))


 if __name__ == "__main__":
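Two more Lightning 2.x renames are visible in these last hunks: the gpus= Trainer argument became accelerator=/devices=, and resume_from_checkpoint= moved off the Trainer onto trainer.fit(..., ckpt_path=...). A minimal sketch of the 2.x-style call, with hypothetical names for the module and path:

    import pytorch_lightning as pl

    resume_path = None  # hypothetical; a checkpoint path string to resume from, or None

    trainer = pl.Trainer(
        accelerator="auto",  # train.py pins accelerator="gpu" with devices=torch.cuda.device_count()
        devices="auto",
        max_epochs=1,
    )
    # PL 2.x: resuming is requested at fit time, not at Trainer construction.
    # trainer.fit(model_module, data_module, ckpt_path=resume_path)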
