
Getting started: using the new features of MIGraphX 0.4


New Features in MIGraphX 0.4

MIGraphX 0.4 supports the following new features:

  • Quantization support for fp16 and int8
  • Support for NLP models, particularly BERT with both Tensorflow and ONNX examples

This page provides examples and pointers on how to use these new features.

Quantization

Release 0.4 adds support for INT8 quantization, in addition to the FP16 support introduced in release 0.3. One way int8 quantization differs from fp16 is that MIGraphX needs to determine "scale factors" to convert between fp32 and int8 values. There are two methods of determining these scale factors:

  • MIGraphX has built-in heuristics to pick the factors, or
  • MIGraphX's int8 quantization functions can accept a set of "calibration data" as input. The model is run with this calibration data, and scale factors are determined by measuring intermediate inputs. The format of the calibration data is the same as that of the data later used for evaluation (see the sketch below).
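Putting this together, here is a minimal sketch of how the int8 path might be driven from C++. It assumes quantization entry points (migraphx::quantize_fp16 and migraphx::quantize_int8) exposed through a migraphx/quantization.hpp header; treat the exact names and signatures as illustrative until the API summary below is filled in.

#include <vector>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/quantization.hpp> // assumed header for the quantization entry points
#include <migraphx/gpu/target.hpp>

int main()
{
    auto prog = migraphx::parse_onnx("model.onnx");

    // fp16 quantization: no scale factors are needed
    // migraphx::quantize_fp16(prog);

    // int8 quantization: calibration data uses the same parameter_map format
    // as data later passed to prog.eval(); an empty list makes MIGraphX fall
    // back to its built-in heuristics for the scale factors
    std::vector<migraphx::program::parameter_map> calibration;
    migraphx::quantize_int8(prog, migraphx::gpu::target{}, calibration);

    prog.compile(migraphx::gpu::target{});
    return 0;
}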

The APIs MIGraphX provides for quantization have been updated to the following:

...to be added...

BERT, natural language processing (NLP) model

Release 0.4 includes improvements so that MIGraphX can optimize the BERT NLP model. Examples are included for both ONNX and Tensorflow frozen graphs, based on existing BERT repositories that are referenced in the sections below.

These models are somewhat complex: multiple steps are needed to prepare input, create a frozen model, and then run the resultant combination. The sections below describe several aspects.

  • Background information: BERT model, GLUE input and how input data can be constructed
  • ONNX: How to create a PyTorch frozen ONNX model from existing BERT repository
  • Tensorflow: How to create a frozen Tensorflow model from an existing BERT repository
  • Test examples: Program script examples that demonstrate using frozen BERT models (PyTorch ONNX or Tensorflow) to run a sample input

These descriptions are intended as an example for getting started with MIGraphX and BERT language models; more complete program development is not covered.

BERT model, GLUE benchmark basics

The General Language Understanding Evaluation (GLUE) benchmark is a set of test resources that can be used for evaluating general Natural Language Processing (NLP) systems. For more background on GLUE, see the GLUE project site (https://gluebenchmark.com).

For our demonstration purposes, we use the Microsoft Research Paraphrase Corpus (MRPC) subtest from GLUE. This test compares two sentences to determine whether they are paraphrases of each other. For example, the first two sentences from the "dev" list are:

He said the foodservice pie business doesn 't fit the company 's long-term growth strategy . 

and

The foodservice pie business does not fit our long-term growth strategy .

which MRPC labels as paraphrases of each other.

BERT models that process MRPC take four inputs for training or evaluation:

  • input_ids - a sequence of "tokens": the words from the sentences above turned into numeric indices from a token vocabulary. The BERT model we use takes as input a starting token [CLS], the tokens of the first sentence, a separator token [SEP], the tokens of the second sentence, and a final separator token [SEP]. This tokenization is not unique to MIGraphX, so we refer to the BERT repositories for details. The tokenized input for our two sentences is shown below; the trailing zeros are unused (padding) tokens.
101 1124 1163 1103 11785 1200 14301 16288 1671 2144 112 189 4218 1103 1419 112 188 1263 118 1858 3213 5564 119 102 107 1109 11785 1200 14301 16288 1671 1674 1136 4218 1412 1263 118 1858 3213 5564 119 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  • input_mask - a sequence of 1s and 0s indicating whether the corresponding token in input_ids is valid. For our two sentences it is:
1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  • sequence_ids - a sequence of 1s and 0s indicating whether the corresponding (valid) token is part of the first or the second sentence (some repositories call this input segment_ids). For our two sentences it is:
0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
  • label - for training data, the label for the two sentences. In MRPC, a 1 indicates the two sentences are paraphrases and a 0 indicates they are not.

The scripts that run BERT on benchmarks such as GLUE include pre-processing that turns the raw GLUE inputs into the arguments described above. This pre-processing is not unique to MIGraphX, so we refer to the BERT repositories for additional details and sample code. Below, we assume this pre-processing has already been done and use its output in a test driver that illustrates running BERT with MIGraphX.

ONNX: How to create a PyTorch frozen ONNX model from existing BERT repository

Start by creating an ONNX file saved from the pytorch-transformers repository. The first step is to get the sources of the repository:

prompt% git clone https://github.com/huggingface/pytorch-transformers

The next step is to modify the pytorch-transformers/examples/run_glue.py script to dump an ONNX file after the training step completes. We do this by finding the following code

if args.output_mode == "classification":
   preds = np.argmax(preds, axis=1)
elif args.output_mode == "regression":
   preds = np.squeeze(preds)

and adding the following immediately after it:

with torch.no_grad():
   model.eval()
   torch.onnx.export(model, (batch[0], batch[1], batch[2]),
                     'bert_'+args.task_name.lower()+str(args.eval_batch_size)+'.onnx',
                     verbose=True)

The torch.onnx.export call creates an ONNX model that expects the following three inputs, each of the maximum sequence length (128 in our example):

  • input_ids - the sequence of tokens (batch[0])
  • input_mask - 1 means the corresponding input_id is valid, 0 means it is not (batch[1])
  • segment_ids - 0 means the token is part of the first segment, 1 means it is part of the second segment (batch[2])

Note that the exported graph assigns its own parameter names to these inputs; in the driver below they appear as "input.1", "2" and "input.3".

Once these modifications are made to the run_glue.py script, we execute it with the following options to fine-tune an MRPC model and then dump an ONNX file:

#!/bin/bash
#
# Script to fine-tune glue tasks and export to ONNX
#
#    git clone https://github.com/huggingface/pytorch-transformers
#
# Then run this script in the pytorch-transformers/examples directory
GLUE_TASK=${GLUE_TASK:="MRPC"}
BERT_MODEL=${BERT_MODEL:="bert-base-cased"}
GLUE_DATADIR=${GLUE_DATADIR:="/home/mev/source/migraphx_sample/glue/glue_data/${GLUE_TASK}"}
OUTPUT_DIR=${OUTPUT_DIR:="./checkpoint/${GLUE_TASK}"}

# run model to create checkpoints
python3 run_glue.py \
	--model_type bert \
	--model_name_or_path ${BERT_MODEL} \
	--per_gpu_eval_batch_size 1 \
	--task_name ${GLUE_TASK} \
	--do_eval \
	--do_train \
	--output_dir ${OUTPUT_DIR} \
	--data_dir ${GLUE_DATADIR}

By default, this results in "bert_mrpc1.onnx", but ONNX files for other GLUE benchmark tasks or batch sizes can be created by overriding the environment variables at the top of the script.
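For example, to produce an ONNX file for a different task, override the environment variables when invoking the script (here assuming it has been saved as export_glue_onnx.sh; the file name is arbitrary):

prompt% GLUE_TASK=SST-2 ./export_glue_onnx.sh

With the default evaluation batch size of 1, this writes "bert_sst-21.onnx" following the same naming scheme.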

Tensorflow: How to create a frozen Tensorflow model from an existing BERT repository

Test examples: Program script examples that demonstrate using frozen BERT models

Below is an example driver that uses the hard-coded tokenized MRPC input from above, together with either an ONNX file or a frozen Tensorflow protobuf, to run the first data point. Turning this into a more realistic application mostly involves replacing the hard-coded tokenized input with actual tokenization code from a BERT repository.

#include <iostream>
#include <string>
#include <vector>
#include <migraphx/program.hpp>
#include <migraphx/onnx.hpp>
#include <migraphx/tf.hpp>
#include <migraphx/gpu/target.hpp>
#include <migraphx/gpu/hip.hpp>

// hard-coded tokenized inputs for the 1st GLUE MRPC entry,
// padded with zeros to the model's sequence length of 128
std::vector<int64_t> input_ids = []{
    std::vector<int64_t> v{101, 1124, 1163, 1103, 11785, 1200, 14301, 16288, 1671, 2144,
                           112, 189, 4218, 1103, 1419, 112, 188, 1263, 118, 1858, 3213,
                           5564, 119, 102, 107, 1109, 11785, 1200, 14301, 16288, 1671,
                           1674, 1136, 4218, 1412, 1263, 118, 1858, 3213, 5564, 119, 102};
    v.resize(128, 0); // remaining positions are padding
    return v;
}();
std::vector<int64_t> input_mask = []{
    std::vector<int64_t> v(42, 1); // the first 42 tokens are valid
    v.resize(128, 0);
    return v;
}();
std::vector<int64_t> sequence_ids = []{
    std::vector<int64_t> v(24, 0); // [CLS] + first sentence + [SEP]
    v.insert(v.end(), 18, 1);      // second sentence + trailing [SEP]
    v.resize(128, 0);              // padding
    return v;
}();

void mrpc_test_onnx(std::string filename){
  // load ONNX file                                                                                             
  auto prog = migraphx::parse_onnx(filename);
  std::cout << prog << std::endl;

  // compile                                                                                                    
  prog.compile(migraphx::gpu::target{});

  // pass in arguments                                                                                          
  for (auto&& x: prog.get_parameter_shapes()){
    std::cout << "parameter: " << x.first << " shape: " << x.second << std::endl;
  }

  migraphx::program::parameter_map pmap;
  pmap["scratch"] = migraphx::gpu::allocate_gpu(prog.get_parameter_shape("scratch"));
  pmap["output"] = migraphx::gpu::allocate_gpu(prog.get_parameter_shape("output"));

  migraphx::argument arg{};
  arg = migraphx::argument(prog.get_parameter_shape("input.1"),input_ids.data());
  pmap["input.1"] = migraphx::gpu::to_gpu(arg);
  arg = migraphx::argument(prog.get_parameter_shape("2"),input_mask.data());
  pmap["2"] = migraphx::gpu::to_gpu(arg);
  arg = migraphx::argument(prog.get_parameter_shape("input.3"),sequence_ids.data());
  pmap["input.3"] = migraphx::gpu::to_gpu(arg);

  // evaluate                                                                                                   
  auto result = migraphx::gpu::from_gpu(prog.eval(pmap));
  std::vector<float> vec_output;
  result.visit([&](auto output){ vec_output.assign(output.begin(),output.end()); });
  std::cout << "result = " << vec_output[0] << ", " << vec_output[1] << std::endl;
}

void mrpc_test_tf(std::string filename){
  // load TF file                                                                                               
  auto prog = migraphx::parse_tf(filename,true);
  std::cout << prog << std::endl;

  // compile                                                                                                    
  prog.compile(migraphx::gpu::target{});

  // pass in arguments                                                                                          
  for (auto&& x: prog.get_parameter_shapes()){
    std::cout << "parameter: " << x.first << " shape: " << x.second << std::endl;
  }
  // TF model uses int32_t for tokens while ONNX uses int64_t, do a quick conversion
  std::vector<int32_t> input_ids32(input_ids.begin(),input_ids.end());
  std::vector<int32_t> input_mask32(input_mask.begin(),input_mask.end());
  std::vector<int32_t> sequence_ids32(sequence_ids.begin(),sequence_ids.end());

  migraphx::program::parameter_map pmap;
  pmap["scratch"] = migraphx::gpu::allocate_gpu(prog.get_parameter_shape("scratch"));
  pmap["output"] = migraphx::gpu::allocate_gpu(prog.get_parameter_shape("output"));

  migraphx::argument arg{};
  arg = migraphx::argument(prog.get_parameter_shape("input_ids_1"),input_ids32.data());
  pmap["input_ids_1"] = migraphx::gpu::to_gpu(arg);
  arg = migraphx::argument(prog.get_parameter_shape("input_mask_1"),input_mask32.data());
  pmap["input_mask_1"] = migraphx::gpu::to_gpu(arg);
  arg = migraphx::argument(prog.get_parameter_shape("segment_ids_1"),sequence_ids32.data());
  pmap["segment_ids_1"] = migraphx::gpu::to_gpu(arg);

  // evaluate                                                                                                   
  auto result = migraphx::gpu::from_gpu(prog.eval(pmap));
  std::vector<float> vec_output;
  result.visit([&](auto output){ vec_output.assign(output.begin(),output.end()); });
  std::cout << "result = " << vec_output[0] << ", " << vec_output[1] << std::endl;
}

int main(int argc, char **argv){
  if (argc != 3){
    std::cout << "Usage: " << argv[0] << " onnx|tf filename" << std::endl;
    return 1;
  }
  if (std::string(argv[1]) == "onnx")
    mrpc_test_onnx(argv[2]);
  else if (std::string(argv[1]) == "tf")
    mrpc_test_tf(argv[2]);
  return 0;
}
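
To build the driver, link against the MIGraphX libraries from your ROCm installation. A typical compile and run sequence looks like the following (the source file name is arbitrary, and the include/library paths assume a default /opt/rocm install):

prompt% g++ -std=c++14 bert_driver.cpp -o bert_driver -I/opt/rocm/include -L/opt/rocm/lib -lmigraphx -lmigraphx_gpu -lmigraphx_onnx -lmigraphx_tf
prompt% ./bert_driver onnx bert_mrpc1.onnx

The program prints the parameter shapes of the compiled model, followed by the two output logits for the first MRPC entry.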