Skip to content

Commit

Permalink
Fixing the state.timestamp.batch.value issue in loss v len callback (#1232)
Browse files Browse the repository at this point in the history

* adding print statements

* testing fix

* fix

* removing print statements

* minor fix
  • Loading branch information
ShashankMosaicML authored May 23, 2024
1 parent c213ea8 commit 6fa6026
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions llmfoundry/callbacks/loss_perp_v_len_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,24 @@ def after_backward(self, state: State, logger: Logger) -> None:
)

def batch_end(self, state: State, logger: Logger) -> None:
if state.timestamp.batch.value % self.compute_batch_interval == 0:
if (
state.timestamp.batch.value - 1
) % self.compute_batch_interval == 0: # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
current_metric_dict = self.loss_perp_v_len.compute()
if dist.get_global_rank() == 0:
for k, v in current_metric_dict.items():
v = v.tolist()
v.append(
state.timestamp.batch.value,
state.timestamp.batch.value -
1, # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
) # Add the current batch index as the last column
if k not in self.metric_dict:
self.metric_dict[k] = []
self.metric_dict[k].append(v)
if state.timestamp.batch.value % self.log_batch_interval == 0 and dist.get_global_rank(
) == 0:
if (
state.timestamp.batch.value - 1
) % self.log_batch_interval == 0 and dist.get_global_rank(
) == 0: # state.timestamp.batch.value - 1 because batch is incremented before batch_end (https://github.com/mosaicml/composer/blob/57c7b72b9df41b0c9777bad1c2bec17f3103c31f/composer/trainer/trainer.py#L2478C1-L2484C55)
for k, v in self.metric_dict.items():
columns = []
columns = [
Expand Down

0 comments on commit 6fa6026

Please sign in to comment.