Skip to content

Commit

Permalink
Fix: issue in dtype of grid in cleaner env (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelavalos authored Jan 16, 2024
1 parent d21d23b commit 8168c5c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/environments/cleaner.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ always start in the top left corner of the maze.
## Observation
The **observation** seen by the agent is a `NamedTuple` containing the following:

- `grid`: jax array (int) of shape `(num_rows, num_cols)`, array representing the grid, each tile is
- `grid`: jax array (int8) of shape `(num_rows, num_cols)`, array representing the grid, each tile is
either dirty (0), clean (1), or a wall (2).

- `agents_locations`: jax array (int) of shape `(num_agents, 2)`, array specifying the x and y
Expand Down
8 changes: 4 additions & 4 deletions jumanji/environments/routing/cleaner/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Cleaner(Environment[State]):
a maze.
- observation: `Observation`
- grid: jax array (int32) of shape (num_rows, num_cols)
- grid: jax array (int8) of shape (num_rows, num_cols)
contains the state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
- agents_locations: jax array (int32) of shape (num_agents, 2)
contains the location of each agent on the board.
Expand All @@ -57,7 +57,7 @@ class Cleaner(Environment[State]):
- An invalid action is selected for any of the agents.
- state: `State`
- grid: jax array (int32) of shape (num_rows, num_cols)
- grid: jax array (int8) of shape (num_rows, num_cols)
contains the current state of the board: 0 for dirty tile, 1 for clean tile, 2 for wall.
- agents_locations: jax array (int32) of shape (num_agents, 2)
contains the location of each agent on the board.
Expand Down Expand Up @@ -127,15 +127,15 @@ def observation_spec(self) -> specs.Spec[Observation]:
Returns:
Spec for the `Observation`, consisting of the fields:
- grid: BoundedArray (int32) of shape (num_rows, num_cols). Values
- grid: BoundedArray (int8) of shape (num_rows, num_cols). Values
are between 0 and 2 (inclusive).
- agent_locations_spec: BoundedArray (int32) of shape (num_agents, 2).
Maximum value for the first column is num_rows, and maximum value
for the second is num_cols.
- action_mask: BoundedArray (bool) of shape (num_agent, 4).
- step_count: BoundedArray (int32) of shape ().
"""
grid = specs.BoundedArray(self.grid_shape, jnp.int32, 0, 2, "grid")
grid = specs.BoundedArray(self.grid_shape, jnp.int8, 0, 2, "grid")
agents_locations = specs.BoundedArray(
(self.num_agents, 2), jnp.int32, [0, 0], self.grid_shape, "agents_locations"
)
Expand Down
3 changes: 2 additions & 1 deletion jumanji/environments/routing/cleaner/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
[DIRTY, DIRTY, DIRTY, DIRTY, WALL],
[DIRTY, WALL, WALL, DIRTY, WALL],
[DIRTY, WALL, DIRTY, DIRTY, DIRTY],
]
],
dtype=jnp.int8,
)


Expand Down

0 comments on commit 8168c5c

Please sign in to comment.