Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for inclined planes and heightmap visualization (from #122) #125

Merged
merged 9 commits into from
Apr 2, 2024

Conversation

flferretti
Copy link
Collaborator

@flferretti flferretti commented Mar 29, 2024

This pull request adds support for inclined planes and heightmap visualization using the MuJoCo visualizer. In particular, in the case of an inclined plane, the zaxis attribute of the terrain element in the XML gets changed, while for the heightmap, a callable should be passed when creating the MujocoModelHelper. This, will populate at runtime the hfield attribute which gets allocated when generating the XML string if the flag heightmap is passed to the converter. (See MuJoCo XML Reference)

Default behavior:

video.mp4

To run the examples:

MWE Inclined Plane
import pathlib

import jax.numpy as jnp
import numpy as np
import resolve_robotics_uri_py
import rod

import jaxsim
import jaxsim.api as js
from jaxsim import VelRepr, integrators

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(uri="file://stickbot.urdf")
# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Create a terrain.
plane_normal = jnp.array([0.0, -0.1, 0.2])
terrain = jaxsim.terrain.PlaneTerrain(
    plane_normal=plane_normal,
)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
    terrain=terrain,
)


# Build the model's data.
# Set already here the initial base position.
data0 = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0, 0, 0.85]),
    velocity_representation=VelRepr.Inertial,
)

# Update the soft-contact parameters.
# By default, only 1 support point is used as worst-case scenario.
# Feel free to tune this with more points to get a less stiff system.
data0 = data0.replace(
    soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(
        model, number_of_active_collidable_points_steady_state=2
    )
)

# =====================
# Create the integrator
# =====================

# Create a RK4 integrator integrating the quaternion on SO(3).
integrator = integrators.fixed_step.RungeKutta4SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data0,
        system_dynamics=js.ode.system_dynamics,
    ),
)

# =========================================
# Visualization in Mujoco viewer / renderer
# =========================================

from jaxsim.mujoco import MujocoModelHelper, MujocoVideoRecorder, RodModelToMjcf

# Convert the ROD model to a Mujoco model.
mjcf_string, assets = RodModelToMjcf.convert(
    rod_model=rod_sdf.models()[0],
    considered_joints=list(model.joint_names()),
    plane_normal=plane_normal,
)

# Build the Mujoco model helper.
mj_model_helper = self = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string, assets=assets
)

# Create the video recorder.
recorder = MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=int(1 / 0.010),
    width=320 * 4,
    height=240 * 4,
)

# ==============
# Recording loop
# ==============

# Initialize the integrator.
t0 = 0.0
tf = 5.0
dt = 0.001_000
integrator_state = integrator.init(x0=data0.state, t0=t0, dt=dt)

# Initialize the loop.
data = data0.copy()
joint_names = list(model.joint_names())

while data.time_ns < tf * 1e9:

    # Integrate the dynamics.
    data, integrator_state = js.model.step(
        dt=dt,
        model=model,
        data=data,
        integrator=integrator,
        integrator_state=integrator_state,
        # Optional inputs
        joint_forces=None,
        external_forces=None,
    )

    # Extract the generalized position.
    s = data.state.physics_model.joint_positions
    W_p_B = data.state.physics_model.base_position
    W_Q_B = data.state.physics_model.base_quaternion

    # Update the data object stored in the helper, which is shared with the recorder.
    mj_model_helper.set_base_position(position=np.array(W_p_B))
    mj_model_helper.set_base_orientation(orientation=np.array(W_Q_B), dcm=False)
    mj_model_helper.set_joint_positions(positions=np.array(s), joint_names=joint_names)

    # Record the frame if the time is right to get the desired fps.
    if data.time_ns % jnp.array(1e9 / recorder.fps).astype(jnp.uint64) == 0:
        recorder.record_frame(camera_name=None)

# Store the video.
video_path = pathlib.Path("~/video.mp4").expanduser()
recorder.write_video(path=video_path, exist_ok=True)

# Clean up the recorder.
recorder.frames = []
recorder.renderer.close()
video.mp4
MWE Heightmap with Ridged Multifractal
import pathlib

import jax.numpy as jnp
import numpy as np
import resolve_robotics_uri_py
import rod
import numpy as np

import jaxsim
import jaxsim.api as js
from jaxsim import VelRepr, integrators

# Find the urdf file.
urdf_path = resolve_robotics_uri_py.resolve_robotics_uri(uri="file://stickbot.urdf")

# Build the ROD model.
rod_sdf = rod.Sdf.load(sdf=urdf_path)

# Create a terrain.
ridged_multifractal = lambda x, y: 1 + np.sum(
    [
        0.25**i
        * (
            0.25
            - np.abs(x * 2**i - control_points[i][0]) ** 0.5
            - np.abs(y * 2**i - control_points[i][1]) ** 0.5
        )
        for i in range(num_octaves)
    ],
    axis=0,
)
num_octaves = 5
control_points = np.random.uniform(0, 1, size=(num_octaves, 2))

hfield_fun = lambda x, y: ridged_multifractal(x, y)

# Build the model.
model = js.model.JaxSimModel.build_from_model_description(
    model_description=rod_sdf.model,
)


# Build the model's data.
# Set already here the initial base position.
data0 = js.data.JaxSimModelData.build(
    model=model,
    base_position=jnp.array([0, 0, 0.85]),
    velocity_representation=VelRepr.Inertial,
)

# Update the soft-contact parameters.
# By default, only 1 support point is used as worst-case scenario.
# Feel free to tune this with more points to get a less stiff system.
data0 = data0.replace(
    soft_contacts_params=js.contact.estimate_good_soft_contacts_parameters(
        model, number_of_active_collidable_points_steady_state=2
    )
)

# =====================
# Create the integrator
# =====================

# Create a RK4 integrator integrating the quaternion on SO(3).
integrator = integrators.fixed_step.RungeKutta4SO3.build(
    dynamics=js.ode.wrap_system_dynamics_for_integration(
        model=model,
        data=data0,
        system_dynamics=js.ode.system_dynamics,
    ),
)

# =========================================
# Visualization in Mujoco viewer / renderer
# =========================================

from jaxsim.mujoco import (
    MujocoModelHelper,
    MujocoVideoRecorder,
    RodModelToMjcf,
    MujocoVisualizer,
)

# Convert the ROD model to a Mujoco model.
mjcf_string, assets = RodModelToMjcf.convert(
    rod_model=rod_sdf.models()[0],
    considered_joints=list(model.joint_names()),
    heightmap=True,
)

# Build the Mujoco model helper.
mj_model_helper = self = MujocoModelHelper.build_from_xml(
    mjcf_description=mjcf_string,
    assets=assets,
    heightmap=hfield_fun,
)

# Create the video recorder.
recorder = MujocoVideoRecorder(
    model=mj_model_helper.model,
    data=mj_model_helper.data,
    fps=int(1 / 0.010),
    width=320 * 4,
    height=240 * 4,
)

viz = MujocoVisualizer(model=mj_model_helper.model, data=mj_model_helper.data)

handle = viz.open_viewer()
viz.sync(handle)

input()
# ==============
# Recording loop
# ==============

# Initialize the integrator.
t0 = 0.0
tf = 5.0
dt = 0.001_000
integrator_state = integrator.init(x0=data0.state, t0=t0, dt=dt)

# Initialize the loop.
data = data0.copy()
joint_names = list(model.joint_names())

while data.time_ns < tf * 1e9:

    # Integrate the dynamics.
    data, integrator_state = js.model.step(
        dt=dt,
        model=model,
        data=data,
        integrator=integrator,
        integrator_state=integrator_state,
        # Optional inputs
        joint_forces=None,
        external_forces=None,
    )

    # Extract the generalized position.
    s = data.state.physics_model.joint_positions
    W_p_B = data.state.physics_model.base_position
    W_Q_B = data.state.physics_model.base_quaternion

    # Update the data object stored in the helper, which is shared with the recorder.
    mj_model_helper.set_base_position(position=np.array(W_p_B))
    mj_model_helper.set_base_orientation(orientation=np.array(W_Q_B), dcm=False)
    mj_model_helper.set_joint_positions(positions=np.array(s), joint_names=joint_names)

    # Record the frame if the time is right to get the desired fps.
    if data.time_ns % jnp.array(1e9 / recorder.fps).astype(jnp.uint64) == 0:
        recorder.record_frame(camera_name=None)

# Store the video.
video_path = pathlib.Path("~/video.mp4").expanduser()
recorder.write_video(path=video_path, exist_ok=True)

# Clean up the recorder.
recorder.frames = []
recorder.renderer.close()

image


📚 Documentation preview 📚: https://jaxsim--125.org.readthedocs.build//125/

@DanielePucci
Copy link
Member

This PR is relevant for @lorycontixd and the work he will be conducting in https://github.com/ami-iit/element_lower-leg-morphology-optimization

@flferretti flferretti marked this pull request as ready for review April 2, 2024 07:53
Copy link
Member

@diegoferigo diegoferigo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments

src/jaxsim/mujoco/loaders.py Outdated Show resolved Hide resolved
src/jaxsim/mujoco/loaders.py Show resolved Hide resolved
Co-authored-by: Diego Ferigo <diego.ferigo@iit.it>
src/jaxsim/mujoco/model.py Outdated Show resolved Hide resolved
flferretti and others added 2 commits April 2, 2024 11:44
Co-authored-by: Diego Ferigo <dgferigo@gmail.com>
@flferretti flferretti merged commit a2e0ca9 into ami-iit:main Apr 2, 2024
11 checks passed
@flferretti flferretti deleted the feature/plane_visualization branch April 2, 2024 10:24
This pull request was closed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants