TensorFlow-Slim "ResNet V2" to PyTorch conversion example

Model: "ResNet V2 152" for ImageNet

Source: TensorFlow

Destination: PyTorch


Prepare the TensorFlow model

First, you need a pre-trained TensorFlow model. MMdnn provides a pre-trained model extractor, mmdownload, that can fetch models for several frameworks. Use it to download and extract the TensorFlow checkpoint files:

$ mmdownload -f tensorflow -n resnet_v2_152

Downloading file [./resnet_v2_152_2017_04_14.tar.gz] from [http://download.tensorflow.org/models/resnet_v2_152_2017_04_14.tar.gz]
100% [......................................................................] 675629399 / 675629399

Model saved in file: ./imagenet_resnet_v2_152.ckpt

You now have the TensorFlow checkpoint files for the ResNet V2 152 model in your current working directory. imagenet_resnet_v2_152.ckpt.meta holds the architecture. imagenet_resnet_v2_152.ckpt.data-00000-of-00001 and imagenet_resnet_v2_152.ckpt.index hold the weights.
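A TensorFlow checkpoint is identified by its prefix (here imagenet_resnet_v2_152.ckpt) rather than by a single file. The sketch below uses hypothetical helpers that are not part of MMdnn. It lists the files expected for a prefix and reports any that are missing before you run the converter. It assumes the usual `<prefix>.data-NNNNN-of-MMMMM` shard naming.

```python
from __future__ import annotations

from pathlib import Path


def checkpoint_files(prefix: str, num_shards: int = 1) -> list[str]:
    """Files expected for a TensorFlow checkpoint plus its exported .meta graph.

    Hypothetical helper: assumes the usual `<prefix>.data-NNNNN-of-MMMMM` naming.
    """
    files = [f"{prefix}.meta", f"{prefix}.index"]
    # Weight shards are numbered from 0 and zero-padded to five digits.
    files += [f"{prefix}.data-{i:05d}-of-{num_shards:05d}" for i in range(num_shards)]
    return files


def missing_checkpoint_files(prefix: str, directory: str = ".", num_shards: int = 1) -> list[str]:
    """Return the expected checkpoint files that are not present in `directory`."""
    root = Path(directory)
    return [name for name in checkpoint_files(prefix, num_shards) if not (root / name).is_file()]
```

For example, `missing_checkpoint_files("imagenet_resnet_v2_152.ckpt")` should return an empty list after a successful mmdownload.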

Find the output node of the model

The original TensorFlow checkpoint files contain many operators that our toolkit does not use. You can see them if you visualize the graph with TensorBoard. Prune them by specifying the output node of your model.

$ mmvismeta imagenet_resnet_v2_152.ckpt.meta ./log/
.
.
.
TensorBoard 1.5.1 at http://kit-station:6006 (Press CTRL+C to quit)

Then open the URL above to find the output node of your model. An example is the squeeze node named MMdnn_Output, which our TensorFlow model extractor sets up.

[Screenshot: TensorBoard graph view]

Detailed information is in the TensorFlow README.
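Pruning by output node means keeping only the nodes that the output transitively depends on. The sketch below illustrates that idea on a toy graph, given as a plain dict from each node name to its input names. It is not MMdnn's implementation; the converters do this for you when you pass --dstNode. The dict format and node names are assumptions for illustration.

```python
from __future__ import annotations

from collections import deque


def base_name(tensor_ref: str) -> str:
    """Map a GraphDef-style input reference such as '^node' or 'node:1' to 'node'."""
    return tensor_ref.lstrip("^").split(":")[0]


def prune_to_output(graph: dict[str, list[str]], output: str) -> set[str]:
    """Return every node that `output` depends on, including `output` itself.

    `graph` maps each node name to the names of its inputs (hypothetical format).
    """
    if output not in graph:
        raise KeyError(f"output node {output!r} not found")
    keep: set[str] = set()
    queue = deque([output])
    while queue:
        name = base_name(queue.popleft())
        if name in keep:
            continue
        keep.add(name)
        # Walk backwards through the inputs; nodes never reached are pruned.
        queue.extend(graph.get(name, []))
    return keep
```

Nodes such as checkpoint-saving ops or a global step counter are never reached from MMdnn_Output, so they drop out.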

Convert TensorFlow Model to PyTorch

We provide two ways to convert models.

One-step Command

Since MMdnn@0.1.4, a single command performs the whole conversion:

$ mmconvert -sf tensorflow -in imagenet_resnet_v2_152.ckpt.meta -iw imagenet_resnet_v2_152.ckpt --dstNode MMdnn_Output -df pytorch -om tf_to_pytorch_resnet_152.pth
.
.
.
PyTorch model file is saved as [tf_to_pytorch_resnet_152.pth], generated by [052eb72db9934edc90d8e1ffa48144d7.py] and [052eb72db9934edc90d8e1ffa48144d7.npy].

You now have the PyTorch model tf_to_pytorch_resnet_152.pth, converted from TensorFlow. The files 052eb72db9934edc90d8e1ffa48144d7.py and 052eb72db9934edc90d8e1ffa48144d7.npy are temporary and are removed automatically.
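You may want to run the conversion from a Python script rather than a shell. In that case, a thin wrapper can assemble the same arguments. The sketch below uses a hypothetical run_mmconvert wrapper that is not part of MMdnn. It passes only the flags from the command above and assumes mmconvert is on your PATH. The dry_run option, also an assumption of this sketch, lets you inspect the command without executing it.

```python
from __future__ import annotations

import shlex
import subprocess


def mmconvert_command(meta: str, weights: str, dst_node: str, output: str) -> list[str]:
    """Build the one-step TensorFlow-to-PyTorch command (flags as in the example above)."""
    return [
        "mmconvert",
        "-sf", "tensorflow",    # source framework
        "-in", meta,            # network structure (.ckpt.meta)
        "-iw", weights,         # checkpoint prefix holding the weights
        "--dstNode", dst_node,  # output node found in TensorBoard
        "-df", "pytorch",       # destination framework
        "-om", output,          # output PyTorch model file
    ]


def run_mmconvert(meta: str, weights: str, dst_node: str, output: str,
                  dry_run: bool = False) -> list[str]:
    """Print the command and, unless dry_run is set, execute it (raises on failure)."""
    cmd = mmconvert_command(meta, weights, dst_node, output)
    print(shlex.join(cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)
    return cmd
```

For this model you would call `run_mmconvert("imagenet_resnet_v2_152.ckpt.meta", "imagenet_resnet_v2_152.ckpt", "MMdnn_Output", "tf_to_pytorch_resnet_152.pth")`.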

Step-by-step Command for debugging

Convert the pre-trained model files to intermediate representation

$ mmtoir -f tensorflow -n imagenet_resnet_v2_152.ckpt.meta -w imagenet_resnet_v2_152.ckpt --dstNode MMdnn_Output -o converted

Parse file [imagenet_resnet_v2_152.ckpt.meta] with binary format successfully.
Tensorflow model file [imagenet_resnet_v2_152.ckpt.meta] loaded successfully.
Tensorflow checkpoint file [imagenet_resnet_v2_152.ckpt] loaded successfully. [816] variables loaded.
IR network structure is saved as [converted.json].
IR network structure is saved as [converted.pb].
IR weights are saved as [converted.npy].

You now have the intermediate representation (IR) files. converted.json is for visualization, and converted.pb and converted.npy are used in the next steps.
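converted.json is the human-readable form of the IR, so a quick sanity check is to count the operator types it contains. The sketch below assumes the JSON mirrors a protobuf graph: a top-level node list whose entries carry an op field. Check your own converted.json before relying on that layout. The function name summarize_ir is hypothetical.

```python
from __future__ import annotations

import json
from collections import Counter


def summarize_ir(json_text: str) -> Counter:
    """Count operator types in an IR JSON dump.

    Assumed layout: {"node": [{"name": ..., "op": ..., ...}, ...]}.
    Nodes without an "op" field are counted under "?".
    """
    graph = json.loads(json_text)
    return Counter(node.get("op", "?") for node in graph.get("node", []))
```

To use it, read converted.json with `open("converted.json", encoding="utf-8")` and print `summarize_ir(f.read()).most_common(10)`.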

Convert the IR files to PyTorch code

$ mmtocode -f pytorch -n converted.pb -w converted.npy -d converted_pytorch.py -dw converted_pytorch.npy

Parse file [converted.pb] with binary format successfully.
Target network code snippet is saved as [converted_pytorch.py].
Target weights are saved as [converted_pytorch.npy].

You will get converted_pytorch.py and converted_pytorch.npy. converted_pytorch.py contains the PyTorch code that builds the ResNet V2 152 network. converted_pytorch.npy holds the weights loaded while the network is built.

After these steps, you have converted the pre-trained TensorFlow ResNet V2 152 model into two files: a PyTorch network-building file, converted_pytorch.py, and a weights file, converted_pytorch.npy. You can use these two files for fine-tuning or inference.

Dump the original PyTorch model

$ mmtomodel -f pytorch -in converted_pytorch.py -iw converted_pytorch.npy -o converted_pytorch.pth

PyTorch model file is saved as [converted_pytorch.pth], generated by [converted_pytorch.py] and [converted_pytorch.npy].

The file converted_pytorch.pth can be loaded by PyTorch directly.
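When debugging, it can help to script the three commands so that each step's output file names always match the next step's inputs. The sketch below uses hypothetical helpers that are not part of MMdnn. It reuses exactly the flags shown above and assumes the MMdnn command-line tools are on your PATH.

```python
from __future__ import annotations

import subprocess


def step_by_step_commands(meta: str, ckpt: str, dst_node: str,
                          ir: str = "converted",
                          code: str = "converted_pytorch") -> list[list[str]]:
    """The three debugging steps, with output names threaded from one step to the next."""
    return [
        # 1. TensorFlow checkpoint -> IR (writes <ir>.json, <ir>.pb, <ir>.npy)
        ["mmtoir", "-f", "tensorflow", "-n", meta, "-w", ckpt,
         "--dstNode", dst_node, "-o", ir],
        # 2. IR -> PyTorch network code and weights
        ["mmtocode", "-f", "pytorch", "-n", f"{ir}.pb", "-w", f"{ir}.npy",
         "-d", f"{code}.py", "-dw", f"{code}.npy"],
        # 3. PyTorch code + weights -> model file loadable by PyTorch
        ["mmtomodel", "-f", "pytorch", "-in", f"{code}.py", "-iw", f"{code}.npy",
         "-o", f"{code}.pth"],
    ]


def run_steps(commands: list[list[str]]) -> None:
    """Run each step in order; check=True raises at the first failing step."""
    for cmd in commands:
        subprocess.run(cmd, check=True)
```

Running all three entries with run_steps reproduces the manual sequence. Running a single entry lets you repeat just the stage you are debugging.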