YoloV5 in Inf1 #330

Closed · diazGT94 opened this issue Sep 23, 2021 · 4 comments
@diazGT94

Hi, I'm trying to replicate the steps indicated in #253 to convert YOLOv5s to Neuron on Inf1.

I'm using the Ubuntu 18.04 DLAMI with the aws_neuron_pytorch_p36 Python env activated.

1. Installed the requirements: pip install -r https://github.com/raw/ultralytics/yolov5/master/requirements.txt
2. Imported the model from PyTorch Hub: model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
3. Created a fake image: fake_image = torch.zeros([1, 3, 608, 608], dtype=torch.float32)
4. Ran model inspection: model_neuron_for_inspection = torch.neuron.trace(model, fake_image, skip_compiler=True)
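
Put together, the steps above as a single script (the traceback below refers to this file as neuron_converter.py):

import torch
import torch_neuron  # importing torch_neuron registers the torch.neuron namespace

# Steps 2-4 from above, unchanged
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', force_reload=True)
fake_image = torch.zeros([1, 3, 608, 608], dtype=torch.float32)

# skip_compiler=True only traces the graph for inspection, without Neuron compilation
model_neuron_for_inspection = torch.neuron.trace(model, fake_image, skip_compiler=True)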

But this gives me the following error:

Fusing layers...
Model Summary: 224 layers, 7266973 parameters, 0 gradients
Adding AutoShape...
/home/ubuntu/.cache/torch/hub/ultralytics_yolov5_master/models/yolo.py:60: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.onnx_dynamic:
/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py:940: TracerWarning: Encountering a list at the output of the tracer might cause the trace to be incorrect, this is only valid if the container structure does not change based on the module's inputs. Consider using a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior.
  _force_outplace,
[the two TracerWarnings above are printed a second time]
Traceback (most recent call last):
  File "neuron_converter.py", line 11, in <module>
    model_neuron_for_inspection = torch.neuron.trace(model, fake_image, skip_compiler=True)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 103, in trace
    neuron_graph, jit_trace = to_graph(func, example_inputs, return_trace=True, **kwargs)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch_neuron/convert.py", line 182, in to_graph
    jit_trace = torch.jit.trace(func_or_mod, example_inputs, **kwargs)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py", line 742, in trace
    _module_class,
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py", line 966, in trace_module
    _module_class,
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/aws_neuron_pytorch_p36/lib/python3.6/site-packages/torch/jit/_trace.py", line 519, in _check_trace
    raise TracingCheckError(*diag_info)
torch.jit._trace.TracingCheckError: Tracing failed sanity checks!
ERROR: Graphs differed across invocations!

Could you please guide me through how to perform the conversion for deployment on Inf1? @Ownmarc

Thanks,

@diazGT94 (Author)

The problem was solved with this:

fake_image = torch.zeros([1, 3, 608, 608], dtype=torch.float32)

# analyze_model runs a forward pass; doing it (and retrying once on failure)
# before tracing is what made the trace succeed here
try:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])
except Exception:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])

model_neuron = torch.neuron.trace(model,
                                  example_inputs=[fake_image])
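
If I had to guess why this works: analyze_model runs the model, and YOLOv5's Detect head rebuilds its grids on the first forward pass at a new input size (the yolo.py:60 branch shown in the warnings above); once the grids are cached, the traced graph no longer differs across invocations. Under that assumption, a plain warm-up forward pass should do the same job (untested sketch):

# Untested sketch: one warm-up forward pass so the Detect head caches its
# grids for this input size, then trace as usual
with torch.no_grad():
    _ = model(fake_image)

model_neuron = torch.neuron.trace(model, example_inputs=[fake_image])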

@josebenitezg

Hi!

I was able to convert the model from YOLOv5 to Neuron with the following code:

import torch
import torch_neuron

# 'yolo5' here is the path to a local clone of the yolov5 repo
model = torch.hub.load('yolo5',
                       'custom',
                       path='yolov5.pt',
                       source='local',
                       force_reload=True)  # local repo

fake_image = torch.zeros([1, 3, 640, 640], dtype=torch.float32)

# analyze_model runs a forward pass; retry once on failure before tracing
try:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])
except Exception:
    torch.neuron.analyze_model(model, example_inputs=[fake_image])

model_neuron = torch.neuron.trace(model,
                                  example_inputs=[fake_image])

# Export to saved model
model_neuron.save("model_converted.pt")
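
As a quick sanity check, the saved artifact can be reloaded with plain torch.jit and run on the same fake input (a minimal sketch, not part of the original conversion):

# Sketch: reload the compiled model and run the trace-time input through it
loaded = torch.jit.load('model_converted.pt')
with torch.no_grad():
    out = loaded(fake_image)
print([t.shape for t in out] if isinstance(out, (list, tuple)) else out.shape)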

Now that I'm trying to test and compare, the output tensors differ from stock YOLOv5, as follows:

Neuron Yolov5 Model:

[tensor([[-0.0356,  0.1790,  0.7456,  0.6292,  0.9359, 13.0000],
        [ 0.5830,  0.1404,  1.1279,  0.6628,  0.9359, 13.0000],
        [ 0.0823,  0.6350,  0.6272,  1.1599,  0.9315, 13.0000],
        [-0.1443,  0.1416,  0.2542,  0.5107,  0.9224, 13.0000],
        [ 0.3516,  0.6426,  0.7500,  1.0137,  0.9188, 13.0000],
        [ 0.3555,  0.1436,  0.7539,  0.5127,  0.9147, 13.0000]])]

Yolov5:

[tensor([[334.57495, 176.98302, 407.46155, 213.81169,   0.93721,  13.00000]])]

Inference script:

import cv2
import numpy as np
import torch
# non_max_suppression is the same NMS helper yolov5's detect.py uses
# (utils.general in the yolov5 repo); CLASSES (used below) is assumed
# to be the model's list of class names, defined elsewhere
from utils.general import non_max_suppression

im = cv2.imread('test_img.jpg')
img0 = im.copy()
im = cv2.resize(im, (640, 640), interpolation=cv2.INTER_AREA)
# Convert
im = im.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
im = np.ascontiguousarray(im)
# Convert into torch
im = torch.from_numpy(im)
im = im.float()  # uint8 to fp16/32
im /= 255  # 0 - 255 to 0.0 - 1.0
if len(im.shape) == 3:
    im = im[None]  # expand for batch dim

# Load the compiled model
model = torch.jit.load('model_converted.pt')

# Inference
pred = model(im)
pred = non_max_suppression(pred)  # same NMS function as yolov5 detect.py

# Process predictions
for i, det in enumerate(pred):  # per image
    im0 = img0.copy()
    color = (30, 30, 30)
    txt_color = (255, 255, 255)
    h_size, w_size = im.shape[-2:]
    print(h_size, w_size)
    lw = max(round(sum(im.shape) / 2 * 0.003), 2)  # line width

    if len(det):
        # Write results
        for *xyxy, conf, cls in reversed(det):
            c = int(cls)  # integer class
            label = f'{CLASSES[c]} {conf:.2f}'
            print(label)
            box = xyxy
            # Boxes come out normalized, so scale by the network input size
            p1, p2 = (int(box[0] * w_size), int(box[1] * h_size)), (int(box[2] * w_size), int(box[3] * h_size))
            cv2.rectangle(im0, p1, p2, color, thickness=lw, lineType=cv2.LINE_AA)
            tf = max(lw - 1, 1)  # font thickness
            w, h = cv2.getTextSize(label, 0, fontScale=lw / 3, thickness=tf)[0]  # text width, height
            outside = p1[1] - h - 3 >= 0  # label fits outside box
            p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
            cv2.rectangle(im0, p1, p2, color, -1, cv2.LINE_AA)  # filled
            cv2.putText(im0,
                        label, (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
                        0,
                        lw / 3,
                        txt_color,
                        thickness=tf,
                        lineType=cv2.LINE_AA)
    # Save results (image with detections)
    status = cv2.imwrite('out.jpg', im0)

Is there something wrong with the model conversion or with the inference? The label and the confidence seem to match the expected output, but the tensors do not.
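
For what it's worth, the Neuron boxes above sit roughly in the 0-1 range while the stock YOLOv5 boxes are in input pixels, and the script already multiplies by w_size/h_size, so one thing to check is whether the compiled graph is returning coordinates normalized to the 640x640 input. A small diagnostic along those lines (a sketch; pred as in the script above, 640 being the trace-time input size):

# Sketch: check whether the Neuron boxes look normalized (roughly 0..1);
# if so, scaling by the network input size makes them comparable in units
# to the stock YOLOv5 output
neuron_boxes = pred[0][:, :4]
print('neuron box value range:', neuron_boxes.min().item(), neuron_boxes.max().item())
print('scaled to input pixels:\n', neuron_boxes * 640)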

@RobinFrcd

👋 @josebenitezg Did you manage to solve your problem?

@josebenitezg

This is probably a very late reply! We were able to compile both YOLOv5 and YOLOv8.
