Skip to content

Commit

Permalink
Proper relative paths handling
Browse files Browse the repository at this point in the history
  • Loading branch information
CristiFati committed May 28, 2021
1 parent 7981f16 commit 39d26ae
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions examples/rknn_convert/rknn_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ def parse_model_config(yaml_config_file):
return model_configs


def convert_model(model_path, out_path, pre_compile):
if os.path.isfile(model_path):
yaml_config_file = model_path
model_path = os.path.dirname(yaml_config_file)
def convert_model(config_path, out_path, pre_compile):
if os.path.isfile(config_path):
config_file = os.path.abspath(config_path)
config_path = os.path.dirname(config_file)
else:
yaml_config_file = os.path.join(model_path, 'model_config.yml')
if not os.path.exists(yaml_config_file):
print('model config {} not exist!'.format(yaml_config_file))
config_file = os.path.join(config_path, 'model_config.yml')
if not os.path.exists(config_file):
print('model config {} not exist!'.format(config_file))
exit(-1)

model_configs = parse_model_config(yaml_config_file)
model_configs = parse_model_config(config_file)

exported_rknn_model_path_list = []

Expand All @@ -39,7 +39,7 @@ def convert_model(model_path, out_path, pre_compile):

print('--> Loading model...')
if model['platform'] == 'tensorflow':
model_file_path = os.path.join(model_path, model['model_file_path'])
model_file_path = os.path.join(config_path, model['model_file_path'])
input_size_list = []
for input_size_str in model['subgraphs']['input-size-list']:
input_size = list(map(int, input_size_str.split(',')))
Expand All @@ -50,22 +50,22 @@ def convert_model(model_path, out_path, pre_compile):
outputs=model['subgraphs']['outputs'],
input_size_list=input_size_list)
elif model['platform'] == 'tflite':
model_file_path = os.path.join(model_path, model['model_file_path'])
model_file_path = os.path.join(config_path, model['model_file_path'])
rknn.load_tflite(model=model_file_path)
elif model['platform'] == 'caffe':
prototxt_file_path = os.path.join(model_path,model['prototxt_file_path'])
caffemodel_file_path = os.path.join(model_path,model['caffemodel_file_path'])
prototxt_file_path = os.path.join(config_path,model['prototxt_file_path'])
caffemodel_file_path = os.path.join(config_path,model['caffemodel_file_path'])
rknn.load_caffe(model=prototxt_file_path, proto='caffe', blobs=caffemodel_file_path)
elif model['platform'] == 'onnx':
model_file_path = os.path.join(model_path, model['model_file_path'])
model_file_path = os.path.join(config_path, model['model_file_path'])
rknn.load_onnx(model=model_file_path)
else:
print("Platform {:} is not supported! Moving on.".format(model['platform']))
continue
print('done')

if model['quantize']:
dataset_path = os.path.join(model_path, model['dataset'])
dataset_path = os.path.join(config_path, model['dataset'])
else:
dataset_path = './dataset'

Expand All @@ -83,8 +83,8 @@ def convert_model(model_path, out_path, pre_compile):


if __name__ == '__main__':
model_path = sys.argv[1]
config_path = sys.argv[1]
out_path = sys.argv[2]
pre_compile = sys.argv[3] in ['true', '1', 'True']

convert_model(model_path, out_path, pre_compile)
convert_model(config_path, out_path, pre_compile)

0 comments on commit 39d26ae

Please sign in to comment.