Skip to content

Commit

Permalink
feat: add FourRooms env
Browse files Browse the repository at this point in the history
  • Loading branch information
epignatelli committed Mar 10, 2024
1 parent 1dffa5a commit 043618a
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 1 deletion.
1 change: 1 addition & 0 deletions navix/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
from .environment import Environment, Timestep
from .room import Room
from .keydoor import KeyDoor
from .four_rooms import FourRooms
116 changes: 116 additions & 0 deletions navix/environments/four_rooms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2023 The Navix Authors.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


from __future__ import annotations
from typing import Union

import jax
import jax.numpy as jnp
from jax import Array
from flax import struct

from ..components import EMPTY_POCKET_ID
from ..entities import Entities, Goal, Player, State, Wall
from ..grid import (
random_positions,
random_directions,
room,
horizontal_wall,
vertical_wall,
)
from ..rendering.cache import RenderingCache
from .environment import Environment, Timestep
from .registry import register_env


class FourRooms(Environment):
    """Square-ish grid split into four rooms by two crossing walls.

    A vertical and a horizontal wall are placed through the middle of the
    grid. Each wall is built with two randomly sampled openings, one on each
    side of the crossing. The player and the goal are sampled at random
    positions that exclude the wall cells.
    """

    def reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Timestep:
        """Sample a fresh layout and return the initial `Timestep`.

        Args:
            key: PRNG key used for every random draw of the reset.
            cache: Optional rendering cache to reuse; when falsy a new one is
                initialised from the grid.

        Returns:
            The first `Timestep` of the episode (t=0, zero reward/action/step
            type) carrying the newly built `State`.
        """
        assert self.height > 4, f"Insufficient height for room {self.height} < 4"
        assert self.width > 4, f"Insufficient width for room {self.width} < 4"

        # One independent subkey per random draw. Reusing a subkey for several
        # draws (as was done before) makes those draws statistically correlated.
        key, k_vert_1, k_vert_2, k_hor_1, k_hor_2, k_pos, k_dir = jax.random.split(
            key, 7
        )

        # map
        grid = room(height=self.height, width=self.width)

        # The partitions cross in the middle of the grid. For the registered
        # 19x19 environments both indices evaluate to 9.
        mid_row = self.height // 2
        mid_col = self.width // 2

        # vertical partition: openings are sampled along the height
        # NOTE(review): the second range is [mid + 2, size); confirm that the
        # upper bound cannot select the border cell.
        opening_1 = jax.random.randint(k_vert_1, shape=(), minval=1, maxval=mid_row)
        opening_2 = jax.random.randint(
            k_vert_2, shape=(), minval=mid_row + 2, maxval=self.height
        )
        openings = jnp.stack([opening_1, opening_2])
        wall_pos_vert = vertical_wall(grid, mid_col, openings)

        # horizontal partition: openings are sampled along the width
        opening_1 = jax.random.randint(k_hor_1, shape=(), minval=1, maxval=mid_col)
        opening_2 = jax.random.randint(
            k_hor_2, shape=(), minval=mid_col + 2, maxval=self.width
        )
        openings = jnp.stack([opening_1, opening_2])
        wall_pos_hor = horizontal_wall(grid, mid_row, openings)

        walls_pos = jnp.concatenate([wall_pos_vert, wall_pos_hor])
        walls = Wall(position=walls_pos)

        # player
        player_pos, goal_pos = random_positions(k_pos, grid, n=2, exclude=walls_pos)
        direction = random_directions(k_dir, n=1)
        player = Player(
            position=player_pos,
            direction=direction,
            pocket=EMPTY_POCKET_ID,
        )
        # goal
        goal = Goal(position=goal_pos, probability=jnp.asarray(1.0))

        entities = {
            Entities.PLAYER: player[None],
            Entities.GOAL: goal[None],
            Entities.WALL: walls,
        }

        # systems
        state = State(
            key=key,
            grid=grid,
            cache=cache or RenderingCache.init(grid),
            entities=entities,
        )

        return Timestep(
            t=jnp.asarray(0, dtype=jnp.int32),
            observation=self.observation(state),
            action=jnp.asarray(0, dtype=jnp.int32),
            reward=jnp.asarray(0.0, dtype=jnp.float32),
            step_type=jnp.asarray(0, dtype=jnp.int32),
            state=state,
        )


def _four_rooms_19x19(*args, **kwargs):
    """Build a 19x19 `FourRooms`, forwarding any extra arguments unchanged."""
    return FourRooms(*args, **kwargs, height=19, width=19)


# Both ids are registered with the same 19x19 factory.
register_env("MiniGrid-FourRooms-v0", _four_rooms_19x19)
register_env("MiniGrid-FourRooms-19x19-v0", _four_rooms_19x19)
24 changes: 23 additions & 1 deletion navix/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def positions_equal(a: Array, b: Array) -> Array:


def room(height: int, width: int):
    """Return a `height` x `width` grid of ids, including the surrounding walls.

    The interior cells are filled with `0` and the one-cell-thick border is
    filled with `-1`, so the returned array has shape `(height, width)`.
    """
    grid = jnp.zeros((height - 2, width - 2), dtype=jnp.int32)
    return jnp.pad(grid, 1, mode="constant", constant_values=-1)

Expand All @@ -152,6 +152,28 @@ def two_rooms(height: int, width: int, key: Array) -> Tuple[Array, Array]:
return grid, wall_at


def vertical_wall(grid: Array, row_idx: int, opening_col_idx: Array | None= None) -> Array:
rows = jnp.arange(1, grid.shape[0] - 1)
cols = jnp.asarray([row_idx] * (grid.shape[0] - 2))
positions = jnp.stack((rows, cols), axis=1)
if opening_col_idx is not None:
positions = jnp.delete(
positions, opening_col_idx - 1, axis=0, assume_unique_indices=True
)
return positions


def horizontal_wall(grid: Array, col_idx: int, opening_row_idx: Array | None= None) -> Array:
rows = jnp.asarray([col_idx] * (grid.shape[1] - 2))
cols = jnp.arange(1, grid.shape[1] - 1)
positions = jnp.stack((rows, cols), axis=1)
if opening_row_idx is not None:
positions = jnp.delete(
positions, opening_row_idx - 1, axis=0, assume_unique_indices=True
)
return positions


def crop(grid: Array, origin: Array, direction: Array, radius: int) -> Array:
input_shape = grid.shape

Expand Down

0 comments on commit 043618a

Please sign in to comment.