Skip to content

Commit

Permalink
Unsets profile_duration_ms in Flax ResNet example to prioritize `nu…
Browse files Browse the repository at this point in the history
…m_profile_steps`.

PiperOrigin-RevId: 648828498
  • Loading branch information
allenwang28 authored and Flax Authors committed Jul 2, 2024
1 parent 37123d5 commit 3898afd
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion examples/imagenet/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,8 @@ def train_and_evaluate(
train_metrics = []
hooks = []
if jax.process_index() == 0 and config.profile:
hooks += [periodic_actions.Profile(num_profile_steps=3, logdir=workdir)]
hooks += [periodic_actions.Profile(
num_profile_steps=3, logdir=workdir, profile_duration_ms=None)]
train_metrics_last_t = time.time()
logging.info('Initial compilation, this might take some minutes...')
for step, batch in zip(range(step_offset, num_steps), train_iter):
Expand Down

0 comments on commit 3898afd

Please sign in to comment.