different image size w/ torchscript windows c++ #1920

Closed
alexyuisme opened this issue Jan 13, 2021 · 7 comments

Comments

@alexyuisme

Hi, all:

I am running into an issue while testing different images in a C++ Windows environment. The following is my Windows program (Windows 10, Visual Studio 2017 with TorchScript):

```cpp
#include <torch/script.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <string>
#include <vector>

using namespace std;
using namespace cv;

int main()
{
    torch::jit::script::Module module;
    try {
        // Deserialize the ScriptModule from a file using torch::jit::load().
        module = torch::jit::load("F:/yolov5s.torchscript.pt");

        string img_path = "F:/image.jpg";
        Mat img = imread(img_path);
        cvtColor(img, img, COLOR_BGR2RGB);            // BGR -> RGB
        img.convertTo(img, CV_32FC3, 1.0f / 255.0f);  // scale pixel values to [0, 1]
        // Wrap the HWC float buffer (no copy), then reorder to NCHW.
        auto tensor_img = torch::from_blob(img.data, { img.rows, img.cols, img.channels() });
        tensor_img = tensor_img.permute({ 2, 0, 1 });  // HWC -> CHW
        tensor_img = tensor_img.unsqueeze(0);          // add the batch dimension

        std::vector<torch::jit::IValue> inputs;
        inputs.push_back(tensor_img);
        torch::jit::IValue output = module.forward(inputs);
        auto op = output.toList().get(0).toTensor();
    }
    catch (const c10::Error& e) {
        //std::cerr << "error loading the model\n";
        std::cerr << e.what() << std::endl;
        return -1;
    }
}
```

Testing image.jpg at size 384x640 works, but testing image.jpg at 640x360 gives me the following error:

Unhandled exception at 0x00007FFEDA1FF218 in test_libtorch.exe: Microsoft C++ exception: std::runtime_error at memory location 0x000000E5255E9C58.

Any ideas? Thanks!

@github-actions
Contributor

github-actions bot commented Jan 13, 2021

👋 Hello @alexyuisme, thank you for your interest in 🚀 YOLOv5! Please visit our ⭐️ Tutorials to get started, where you can find quickstart guides for simple tasks like Custom Data Training all the way to advanced concepts like Hyperparameter Evolution.

If this is a 🐛 Bug Report, please provide screenshots and minimum viable code to reproduce your issue, otherwise we can not help you.

If this is a custom training ❓ Question, please provide as much information as possible, including dataset images, training logs, screenshots, and a public link to online W&B logging if available.

For business inquiries or professional support requests please visit https://www.ultralytics.com or email Glenn Jocher at glenn.jocher@ultralytics.com.

Requirements

Python 3.8 or later with all requirements.txt dependencies installed, including torch>=1.7. To install run:

$ pip install -r requirements.txt

Environments

YOLOv5 may be run in any of the following up-to-date verified environments (with all dependencies including CUDA/CUDNN, Python and PyTorch preinstalled):

Status

CI CPU testing

If this badge is green, all YOLOv5 GitHub Actions Continuous Integration (CI) tests are currently passing. CI tests verify correct operation of YOLOv5 training (train.py), testing (test.py), inference (detect.py) and export (export.py) on MacOS, Windows, and Ubuntu every 24 hours and on every commit.

@zhiqwang
Contributor

zhiqwang commented Jan 13, 2021

Hi,

It seems that the torch.jit.trace mechanism doesn't support inference with dynamic input sizes. In other words, your image size must equal the size used when you generated the TorchScript model.

I think that one way to address the dynamic size inference issue is through the torch.jit.script mechanism.
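To illustrate the difference, here is a minimal sketch. It is not taken from the YOLOv5 code: `double_if_wide` is a made-up function whose behavior depends on the input width. Tracing records only the operations that ran for the example input, so the shape-dependent branch is frozen, while scripting compiles the branch itself. PyTorch may print a TracerWarning during tracing.

```python
import torch


def double_if_wide(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical shape-dependent logic: the branch depends on the input width.
    if x.shape[-1] > 4:
        return x * 2
    return x + 1


# Tracing runs the function once with a width-8 example and records "x * 2" only.
traced = torch.jit.trace(double_if_wide, torch.zeros(1, 8))
# Scripting compiles the Python source, so the "if" survives in the saved graph.
scripted = torch.jit.script(double_if_wide)

narrow = torch.zeros(1, 2)
out_traced = traced(narrow)      # still multiplies by 2, the branch was baked in
out_scripted = scripted(narrow)  # takes the "x + 1" branch as the Python code would

print(out_traced, out_scripted)
```

A real detection model fails in a different way (usually a shape mismatch inside the network rather than a silently wrong result), but the underlying cause is the same: the trace only reflects the example input.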

@alexyuisme
Author

> Hi,
>
> It seems that torch.jit.trace mechanism doesn't support inferring with dynamic variable input sizes. In other words, your image size must equal to the sizes when you generate the torchscript model.
>
> I think that one way to address the dynamic size inference issue is through the torch.jit.script mechanism.

Hi, zhiqwang:

Thanks for your reply. I am very new to PyTorch and still figuring out the difference between torch.jit.trace and torch.jit.script. How do you know I am using the torch.jit.trace mechanism? I generated "yolov5s.torchscript.pt" by running:

python models/export.py --weights yolov5s.pt --img 640 --batch 1

Thanks.

@zhiqwang
Contributor

zhiqwang commented Jan 14, 2021

Hi @alexyuisme

Because the author uses torch.jit.trace by default, as follows:

yolov5/models/export.py

Lines 58 to 62 in 051e9e8

```python
print('\nStarting TorchScript export with torch %s...' % torch.__version__)
f = opt.weights.replace('.pt', '.torchscript.pt')  # filename
ts = torch.jit.trace(model, img)
ts.save(f)
print('TorchScript export success, saved as %s' % f)
```

Besides, limitations on inference with dynamically sized batches are very common with torch.jit.trace, so I guessed that you were using torch.jit.trace.

@alexyuisme
Author

Hi, @zhiqwang:

Haha, I see. Is it possible to specify a different width and height rather than simply using --img 640?

Thanks,

@zhiqwang
Contributor

zhiqwang commented Jan 14, 2021

Hi, @alexyuisme

The torch.jit.trace mechanism only supports fixed image sizes. If the image size in your dataset is fixed, you can just set it at export time (suppose the size is (640, 360)):

python models/export.py --weights yolov5s.pt --img 640 360 --batch 1
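One thing to check before picking an export size: the letterbox helper quoted in this thread describes its output as a "32-pixel-multiple rectangle". Assuming the network needs each side to be a multiple of 32 (an assumption here, not something confirmed in this thread), a side of 360 would have to be rounded up to 384. The helper name `round_up_to_multiple` below is made up for illustration:

```python
import math


def round_up_to_multiple(size: int, stride: int = 32) -> int:
    """Round `size` up to the nearest multiple of `stride` (hypothetical helper)."""
    return math.ceil(size / stride) * stride


# 640 is already a multiple of 32; 360 is not and rounds up to 384.
export_size = [round_up_to_multiple(s) for s in (640, 360)]
print(export_size)
```

That would be consistent with the report that a 384x640 image runs while a 640x360 one does not, but treat it as a hypothesis to verify against your own exported model.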

If the size is dynamic, the author also provides a letterbox function that auto-pads images to the size you have chosen (when the image is smaller than (640, 360)). Because a CNN is locally sensitive, the padding doesn't influence the inference results. You can refer to the function as follows:

yolov5/utils/datasets.py

Lines 795 to 825 in dd03b20

```python
def letterbox(img, new_shape=(640, 640), color=(114, 114, 114), auto=True, scaleFill=False, scaleup=True):
    # Resize image to a 32-pixel-multiple rectangle https://github.com/ultralytics/yolov3/issues/232
    shape = img.shape[:2]  # current shape [height, width]
    if isinstance(new_shape, int):
        new_shape = (new_shape, new_shape)

    # Scale ratio (new / old)
    r = min(new_shape[0] / shape[0], new_shape[1] / shape[1])
    if not scaleup:  # only scale down, do not scale up (for better test mAP)
        r = min(r, 1.0)

    # Compute padding
    ratio = r, r  # width, height ratios
    new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
    dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1]  # wh padding
    if auto:  # minimum rectangle
        dw, dh = np.mod(dw, 32), np.mod(dh, 32)  # wh padding
    elif scaleFill:  # stretch
        dw, dh = 0.0, 0.0
        new_unpad = (new_shape[1], new_shape[0])
        ratio = new_shape[1] / shape[1], new_shape[0] / shape[0]  # width, height ratios

    dw /= 2  # divide padding into 2 sides
    dh /= 2

    if shape[::-1] != new_unpad:  # resize
        img = cv2.resize(img, new_unpad, interpolation=cv2.INTER_LINEAR)
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    img = cv2.copyMakeBorder(img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color)  # add border
    return img, ratio, (dw, dh)
```
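Note that with auto=True the function above pads only to the minimum 32-multiple rectangle, so the output size still varies from image to image. For a traced model you would presumably want the full target shape every time, which corresponds to auto=False. The sketch below reproduces just the geometry of that case in plain Python so the numbers can be checked without OpenCV; `letterbox_geometry` is a made-up name, and the (384, 640) target is only an example.

```python
def letterbox_geometry(shape, new_shape=(640, 640), scaleup=True):
    """Geometry of letterbox(..., auto=False) without touching pixels.

    shape and new_shape are (height, width). Returns the resized (width, height)
    and the border widths (top, bottom, left, right).
    """
    h, w = shape
    r = min(new_shape[0] / h, new_shape[1] / w)  # scale ratio (new / old)
    if not scaleup:  # only scale down
        r = min(r, 1.0)
    new_unpad = (int(round(w * r)), int(round(h * r)))  # (width, height) after resize
    dw = (new_shape[1] - new_unpad[0]) / 2  # padding per side, width
    dh = (new_shape[0] - new_unpad[1]) / 2  # padding per side, height
    # Same -0.1 / +0.1 rounding as the quoted function, so an odd total splits unevenly.
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    return new_unpad, (top, bottom, left, right)


# A 360x640 (height x width) image padded to a fixed 384x640 input.
resized_wh, borders = letterbox_geometry((360, 640), new_shape=(384, 640))
print(resized_wh, borders)
```

In C++, the same numbers can be passed to cv::resize and cv::copyMakeBorder before building the tensor, so every input matches the size used at export.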

A more flexible approach is torch.jit.script, as I mentioned; you can refer to my repo.

@alexyuisme
Author

You saved my life, bro! I'll definitely check your yolo5rt stack.

Thanks!
