[Bug]: difference in output of model exported to onnx #1211

Closed · Tuxliri opened this issue Dec 12, 2022 · 2 comments
Labels: bug (Something isn't working)

Tuxliri commented Dec 12, 2022

🐛 Bug

I'm trying to export a trained model to ONNX so I can use it in a Simulink model. When I test the output of the exported ONNX model, it differs from the output obtained by running the SB3 model. I cannot share my trained model, but I have reproduced the problem with the Pendulum-v1 environment. In the code below, the ONNX model outputs [[0.]] for the observation [[0., 0., 0.]], whereas the SB3 model outputs a non-zero value.

To Reproduce

import torch as th

from stable_baselines3 import PPO

class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super().__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, observation):
        # NOTE: You may have to process (normalize) observation in the correct
        #       way before using this. See `common.preprocessing.preprocess_obs`
        action_hidden, value_hidden = self.extractor(observation)
        return self.action_net(action_hidden), self.value_net(value_hidden)


# An untrained PPO model on Pendulum-v1 is enough to reproduce the issue
model = PPO("MlpPolicy", "Pendulum-v1")
observation_size = model.observation_space.shape
dummy_input = th.randn(1, *observation_size)

model.predict(dummy_input)  # sanity check: run the policy once on the dummy input
onnxable_model = OnnxablePolicy(
    model.policy.mlp_extractor,
    model.policy.action_net,
    model.policy.value_net
)


th.onnx.export(
    onnxable_model,
    dummy_input,
    "my_ppo_pendulum_model.onnx",
    opset_version=9,
    input_names=["input"],
)

##### Load and test with onnx

import onnx
import onnxruntime as ort
import numpy as np

onnx_path = "my_ppo_pendulum_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = np.zeros((1, *observation_size)).astype(np.float32)
ort_sess = ort.InferenceSession(onnx_path)
action, value = ort_sess.run(None, {"input": observation})
action1, __ = model.predict(observation=observation)

np.testing.assert_allclose(action, action1, rtol=1e-03, atol=1e-05)

Relevant log output / Error message

Exception has occurred: AssertionError

Not equal to tolerance rtol=0.001, atol=1e-05

Mismatched elements: 1 / 1 (100%)
Max absolute difference: 2.
Max relative difference: 1.
 x: array([[0.]], dtype=float32)
 y: array([[2.]], dtype=float32)
  File "/home/diafrate/RL_rocket_6DOF/export_model_to_onnx.py", line 55, in <module>
    np.testing.assert_allclose(action, action1, rtol=1e-03, atol=1e-05)

System Info

  • 'OS': 'Linux-5.15.74.2-microsoft-standard-WSL2-x86_64-with-debian-bullseye-sid #1 SMP Wed Nov 2 19:50:29 UTC 2022'
  • 'Python': '3.7.13'
  • 'Stable-Baselines3': '1.6.0'
  • 'PyTorch': '1.12.0+cu102'
  • 'GPU Enabled': 'False'
  • 'Numpy': '1.21.6'
  • 'Gym': '0.21.0'

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.
Tuxliri added the bug label Dec 12, 2022
araffin (Member) commented Dec 12, 2022

Hello,
you are missing deterministic=True when calling predict, otherwise it will sample from the action distribution.
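
For reference, a minimal sketch of the corrected comparison, assuming the export script above has already been run and model, observation and action are still in scope:

# deterministic=True makes predict() return the mode of the action
# distribution (i.e. the action_net output) instead of a random sample,
# so it can be compared against the exported ONNX graph.
action1, _ = model.predict(observation=observation, deterministic=True)
np.testing.assert_allclose(action, action1, rtol=1e-03, atol=1e-05)

Note that predict() additionally clips the action to the action-space bounds, whereas the ONNX graph returns the raw network output, so a discrepancy can remain if the unclipped mean falls outside the bounds.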

Tuxliri (Author) commented Dec 12, 2022

Thanks, that was the issue!

Tuxliri closed this as completed Dec 12, 2022