Skip to content

Commit

Permalink
fix static prune bug (PaddlePaddle#2933)
Browse files Browse the repository at this point in the history
  • Loading branch information
yghstill authored May 11, 2021
1 parent d9f8d3b commit d5da5e6
Showing 1 changed file with 13 additions and 14 deletions.
27 changes: 13 additions & 14 deletions static/slim/prune/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,7 @@ def main():
assert FLAGS.prune_criterion in ['l1_norm', 'geometry_median'], \
"unsupported prune criterion {}".format(FLAGS.prune_criterion)
pruner = Pruner(criterion=FLAGS.prune_criterion)
train_prog = pruner.prune(
train_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=False)[0]

compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)

if FLAGS.eval:

base_flops = flops(eval_prog)
eval_prog = pruner.prune(
eval_prog,
Expand All @@ -232,6 +218,19 @@ def main():
pruned_flops))
compiled_eval_prog = fluid.CompiledProgram(eval_prog)

train_prog = pruner.prune(
train_prog,
fluid.global_scope(),
params=pruned_params,
ratios=pruned_ratios,
place=place,
only_graph=False)[0]

compiled_train_prog = fluid.CompiledProgram(train_prog).with_data_parallel(
loss_name=loss.name,
build_strategy=build_strategy,
exec_strategy=exec_strategy)

if FLAGS.resume_checkpoint:
checkpoint.load_checkpoint(exe, train_prog, FLAGS.resume_checkpoint)
start_iter = checkpoint.global_step()
Expand Down

0 comments on commit d5da5e6

Please sign in to comment.