Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
csuhan committed Apr 13, 2021
1 parent c8b0846 commit 88f8170
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tools/publish_model.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import argparse
import subprocess
from collections import OrderedDict

import torch


def parse_args():
parser = argparse.ArgumentParser(
description='Process a checkpoint to be published')
Expand All @@ -18,6 +18,13 @@ def process_checkpoint(in_file, out_file):
# remove optimizer for smaller file size
if 'optimizer' in checkpoint:
del checkpoint['optimizer']
if 'state_dict' in checkpoint:
in_state_dict = checkpoint.pop('state_dict')
out_state_dict = OrderedDict()
for key, val in in_state_dict.items():
key = key.replace('backbone.','')
out_state_dict[key] = val
checkpoint['state_dict'] = out_state_dict
# if it is necessary to remove some sensitive data in checkpoint['meta'],
# add the code here.
torch.save(checkpoint, out_file)
Expand Down

0 comments on commit 88f8170

Please sign in to comment.