basic_triton_client.py (forked from nyuolab/NYUTron)
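
"""Basic Triton HTTP client for the 'nemo-tokenizer' model.

Reads newline-delimited queries from a text file and sends each one to a
Triton inference server as a single-element BYTES tensor, printing the
decoded prediction. A sketch of the expected invocation, assuming a reachable
Triton server with the model already deployed (defaults taken from the
argparse flags below):

    python basic_triton_client.py --url localhost:8000 --filename queries.txt
"""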
import argparse

import numpy as np
import tritonclient.http as httpclient


def to_numpy(tensor):
    """Convert a PyTorch tensor to a NumPy array (helper; unused below and
    would require a torch import if called)."""
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def postprocessing(results, labels):
    """Map raw result ids to human-readable label strings (helper; unused below)."""
    return [labels[str(r)] for r in results]
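
# A hypothetical use of postprocessing (the labels dict here is illustrative,
# not taken from the deployed model):
#     postprocessing([0, 1], {'0': 'negative', '1': 'positive'})
#     -> ['negative', 'positive']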


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-u',
                        '--url',
                        type=str,
                        required=False,
                        default='localhost:8000',
                        help='Inference server URL. Default is localhost:8000.')
    parser.add_argument('-f',
                        '--filename',
                        type=str,
                        required=False,
                        default='queries.txt',
                        help='Text file containing inputs for the model.')
    args = parser.parse_args()

    # Connect to the Triton server at the URL given by --url
    triton_client = httpclient.InferenceServerClient(url=args.url)

    # Set up model and input/output metadata
    model_name = 'nemo-tokenizer'
    model_mode = 'model_trt'
    input_names = ['rawtext']
    output_names = ['textout']
    model_dict = {'model_onnx': 0,  # ONNX as is
                  'model_trt': 1,   # TensorRT plan
                  }
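    # model_dict maps each backend choice to the integer that the optional
    # 'model_mode' input would carry (see the commented-out lines below);
    # per the TODO there, the mode input is disabled and every request
    # currently goes to the TensorRT plan.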

    with open(args.filename, 'r') as f:
        for input_data in f:
            # Strip the trailing newline so it is neither sent nor printed
            input_data = input_data.strip()
            # Build a batch-of-one BYTES tensor holding the raw query text
            input0 = np.array([[input_data]], dtype=object)
            inputs = [httpclient.InferInput(input_names[0], input0.shape, "BYTES")]
            ## TODO defaulted everything to the TRT model for now
            # inputs.append(httpclient.InferInput('model_mode', [1, 1], "INT64"))
            outputs = [httpclient.InferRequestedOutput(output_names[0])]
            # Initialize the input data
            inputs[0].set_data_from_numpy(input0)
            # inputs[1].set_data_from_numpy(np.array([[model_dict[model_mode]]], dtype=np.int64))
            results = triton_client.infer(model_name,
                                          inputs,
                                          outputs=outputs)
            out_strings = results.as_numpy(output_names[0])
            output = out_strings.item().decode()
            print(f'Query: {input_data}')
            print(f'Predicted label: {output}')
            print('------------------------------')