Skip to content

Commit

Permalink
refactor: remove brax (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-bonnet authored Feb 22, 2023
1 parent 119a4ed commit 07aae63
Show file tree
Hide file tree
Showing 6 changed files with 2 additions and 178 deletions.
1 change: 0 additions & 1 deletion docs/api/wrappers.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
- JumanjiToDMEnvWrapper
- MultiToSingleWrapper
- VmapWrapper
- BraxToJumanjiWrapper
- AutoResetWrapper
- JumanjiToGymWrapper
- JumanjiToGymWrapper
Expand Down
21 changes: 1 addition & 20 deletions docs/guides/wrappers.md
Original file line number Diff line number Diff line change
@@ -1,26 +1,7 @@
# Wrappers

The `Wrapper` interface is used to extend a Jumanji `Environment` with additional features such as auto-reset and vectorised environments.
Jumanji provides wrappers to convert a Jumanji `Environment` to a DeepMind or Gym environment, and a Brax environment into a Jumanji `Environment`.

## Brax to Jumanji
Below is an example of how to convert a [Brax](https://github.com/google/brax) environment into a Jumanji environment. In this example Walker2d
terminates when 1000 steps are reached.

```python
import brax.envs
import jax.random
import jumanji.wrappers

brax_env = brax.envs.create("walker2d")
env = jumanji.wrappers.BraxToJumanjiWrapper(brax_env)

key = jax.random.PRNGKey(0)
# Split the key so that resetting the environment and sampling the action each
# consume their own key; reusing a single key would correlate the two draws.
reset_key, action_key = jax.random.split(key)
state, timestep = env.reset(reset_key)
action = jax.random.normal(action_key, [brax_env.action_size])
state, timestep = env.step(state, action)
...
```
Jumanji provides wrappers to convert a Jumanji `Environment` to a DeepMind or Gym environment.

## Jumanji to DeepMind Environment
We can also convert our Jumanji environments to a DeepMind environment:
Expand Down
86 changes: 1 addition & 85 deletions jumanji/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,12 @@
import jax
import jax.numpy as jnp
import numpy as np
from brax.envs import Env as BraxEnv
from brax.envs import State as BraxState
from chex import Array, PRNGKey
from jax import jit, random

from jumanji import specs, tree_utils
from jumanji.env import Environment, State
from jumanji.types import Action, TimeStep, restart, termination, transition
from jumanji.types import Action, TimeStep

Observation = TypeVar("Observation")

Expand Down Expand Up @@ -349,88 +347,6 @@ def render(self, state: State) -> Any:
return super().render(state_0)


class BraxToJumanjiWrapper(Environment):
    """A wrapper that converts a Brax environment to an `Environment` for standardisation and to
    augment the API (add timesteps, metrics...).

    The wrapped Brax state is passed through unchanged; each call additionally builds a
    `TimeStep` from the state's `obs`, `reward` and `metrics` fields.
    """

    def __init__(self, brax_env: BraxEnv):
        """Creates the Environment wrapper for Brax environments.

        Args:
            brax_env: Brax Env object that is not wrapped by a ResetWrapper.
        """
        self._env = brax_env

    def reset(self, key: PRNGKey) -> Tuple[BraxState, TimeStep]:
        """Resets the environment to an initial state.

        Args:
            key: random key used to reset the environment.

        Returns:
            state: Brax State object corresponding to the new state of the environment,
            timestep: TimeStep object corresponding to the first timestep returned by the
                environment. Its extras are the Brax state's metrics.
        """
        state = self._env.reset(key)
        timestep = restart(observation=state.obs, extras=state.metrics)
        return state, timestep

    def step(self, state: BraxState, action: Action) -> Tuple[BraxState, TimeStep]:
        """Run one timestep of the environment's dynamics.

        Args:
            state: Brax State object containing the dynamics of the environment.
            action: Array containing the action to take.

        Returns:
            state: Brax State object corresponding to the next state of the environment,
            timestep: TimeStep object corresponding to the timestep returned by the
                environment: a termination timestep if the new state is done, a transition
                timestep otherwise.
        """
        state = self._env.step(state, action)
        # Pick the timestep type with `jax.lax.cond` rather than a Python `if` so that the
        # method still works when `state.done` is a traced value (e.g. under `jax.jit`).
        timestep = jax.lax.cond(
            state.done,
            lambda _state: termination(
                reward=_state.reward, observation=_state.obs, extras=_state.metrics
            ),
            lambda _state: transition(
                reward=_state.reward, observation=_state.obs, extras=_state.metrics
            ),
            state,
        )
        return state, timestep

    def observation_spec(self) -> specs.Array:
        """Returns the observation spec.

        Returns:
            observation_spec: a `specs.Array` spec: a rank-1 float array whose length is the
                Brax environment's `observation_size`.
        """
        return specs.Array(
            shape=(self._env.observation_size,),
            dtype=float,
            name="observation",
        )

    def action_spec(self) -> specs.BoundedArray:
        """Returns the action spec.

        Returns:
            action_spec: a `specs.BoundedArray` spec: a rank-1 float array whose length is the
                Brax environment's `action_size`.
        """
        # NOTE(review): the bounds are hard-coded to [-1, 1]; presumably Brax actions are
        # normalised to that range - confirm against the Brax environment in use.
        return specs.BoundedArray(
            shape=(self._env.action_size,),
            dtype=float,
            minimum=-1.0,
            maximum=1.0,
            name="action",
        )

    @property
    def unwrapped(self) -> BraxEnv:
        """Returns the underlying Brax environment passed at construction."""
        return self._env


class AutoResetWrapper(Wrapper):
"""Automatically resets environments that are done. Once the terminal state is reached,
the state, observation, and step_type are reset. The observation and step_type of the
Expand Down
70 changes: 0 additions & 70 deletions jumanji/wrappers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from collections import namedtuple
from typing import Tuple, Type, TypeVar

import brax
import dm_env.specs
import gym
import jax
Expand All @@ -24,8 +23,6 @@
import numpy as np
import pytest
import pytest_mock
from brax.envs import Env as BraxEnv
from brax.envs import State as BraxState
from chex import assert_trees_all_equal, dataclass

from jumanji import specs
Expand All @@ -37,7 +34,6 @@
from jumanji.types import StepType, TimeStep
from jumanji.wrappers import (
AutoResetWrapper,
BraxToJumanjiWrapper,
JumanjiToDMEnvWrapper,
JumanjiToGymWrapper,
MultiToSingleWrapper,
Expand Down Expand Up @@ -503,72 +499,6 @@ def test_vmap_env__unwrapped(
assert fake_vmap_environment._env == fake_environment


class TestBraxEnvToJumanjiEnvironment:
    """Test the `BraxToJumanjiWrapper` wrapper that transforms a Brax Env into an `Environment`
    format.
    """

    @pytest.fixture
    def fake_brax_env(self, time_limit: int = 10) -> BraxEnv:
        """Creates a trivial Brax Env meant for unit testing."""
        # `auto_reset=False` because the wrapper expects a Brax env that is not wrapped by a
        # reset wrapper.
        return brax.envs.create("fast", auto_reset=False, episode_length=time_limit)

    @pytest.fixture
    def jumanji_environment_from_brax(self, fake_brax_env: BraxEnv) -> Environment:
        """Instantiates an Environment wrapped from a Brax env."""
        return BraxToJumanjiWrapper(fake_brax_env)

    def test_brax_env_to_jumanji_environment__init(
        self, fake_brax_env: BraxEnv
    ) -> None:
        """Validates initialization of the wrapper."""
        environment = BraxToJumanjiWrapper(fake_brax_env)
        assert isinstance(environment, Environment)

    def test_brax_env_to_jumanji_environment__reset(
        self, jumanji_environment_from_brax: Environment
    ) -> None:
        """Validates (jitted) reset function and timestep type of the wrapped environment."""
        state, timestep = jax.jit(jumanji_environment_from_brax.reset)(
            jax.random.PRNGKey(0)
        )
        assert isinstance(state, BraxState)
        assert isinstance(timestep, TimeStep)
        assert timestep.step_type == StepType.FIRST
        assert timestep.extras == {}

    def test_brax_env_to_jumanji_environment__step(
        self, jumanji_environment_from_brax: Environment
    ) -> None:
        """Validates (jitted) step function of the wrapped environment."""
        state, timestep = jumanji_environment_from_brax.reset(jax.random.PRNGKey(0))
        action = jumanji_environment_from_brax.action_spec().generate_value()
        next_state, next_timestep = jax.jit(jumanji_environment_from_brax.step)(
            state, action
        )
        # A step must change both the Brax state and the emitted timestep.
        assert_trees_are_different(timestep, next_timestep)
        assert_trees_are_different(state, next_state)

    def test_brax_env_to_jumanji_environment__observation_spec(
        self, jumanji_environment_from_brax: Environment
    ) -> None:
        """Validates observation_spec method of the wrapped environment."""
        assert isinstance(jumanji_environment_from_brax.observation_spec(), specs.Array)

    def test_brax_env_to_jumanji_environment__action_spec(
        self, jumanji_environment_from_brax: Environment
    ) -> None:
        """Validates action_spec method of the wrapped environment."""
        # The wrapper declares a bounded action space, so check for the bounded spec type
        # rather than the more general `specs.Array`.
        assert isinstance(
            jumanji_environment_from_brax.action_spec(), specs.BoundedArray
        )

    def test_brax_env_to_jumanji_environment__unwrapped(
        self, jumanji_environment_from_brax: Environment
    ) -> None:
        """Validates unwrapped property of the wrapped environment."""
        assert isinstance(jumanji_environment_from_brax.unwrapped, BraxEnv)


class TestAutoResetWrapper:
@pytest.fixture
def fake_auto_reset_environment(
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ follow_imports = "skip"
module = [
"matplotlib.*",
"mpl_toolkits.*",
"brax.*",
"gym.*",
"pytest_mock.*",
"numpy.*",
Expand Down
1 change: 0 additions & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
brax @ git+https://github.com/google/brax
chex>=0.1.3
dm-env>=1.5
gym>=0.22.0
Expand Down

0 comments on commit 07aae63

Please sign in to comment.