Skip to content

Commit

Permalink
Merge pull request #84 from epignatelli/fix-walkable
Browse files Browse the repository at this point in the history
Fix walkable
  • Loading branch information
epignatelli committed Jul 8, 2024
2 parents e38b4a0 + 8a09379 commit d926b98
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 11 deletions.
2 changes: 1 addition & 1 deletion navix/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@
# under the License.


__version__ = "0.6.15"
__version__ = "0.6.16"
__version_info__ = tuple(int(i) for i in __version__.split(".") if i.isdigit())
2 changes: 1 addition & 1 deletion navix/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _can_walk_there(state: State, position: Array) -> Tuple[Array, EventsManager
obstructs = jnp.logical_and(
jnp.logical_not(state.entities[k].walkable), same_position
)
walkable = jnp.logical_and(walkable, jnp.any(jnp.logical_not(obstructs)))
walkable = jnp.logical_and(walkable, jnp.all(jnp.logical_not(obstructs)))
return jnp.asarray(walkable, dtype=jnp.bool_), events


Expand Down
36 changes: 27 additions & 9 deletions navix/environments/lava_gap.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,36 +96,54 @@ def _reset(self, key: Array, cache: Union[RenderingCache, None] = None) -> Times
register_env(
"Navix-LavaGapS5-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
height=5,
width=5,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
termination_fn=kwargs.pop(
"termination_fn",
terminations.compose(
terminations.on_goal_reached,
terminations.on_lava_fall,
),
),
*args,
**kwargs,
),
)
register_env(
"Navix-LavaGapS6-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
height=6,
width=6,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
termination_fn=kwargs.pop(
"termination_fn",
terminations.compose(
terminations.on_goal_reached,
terminations.on_lava_fall,
),
),
*args,
**kwargs,
),
)
register_env(
"Navix-LavaGapS7-v0",
lambda *args, **kwargs: LavaGap.create(
*args,
**kwargs,
height=7,
width=7,
observation_fn=kwargs.pop("observation_fn", observations.symbolic),
reward_fn=kwargs.pop("reward_fn", rewards.on_goal_reached),
termination_fn=kwargs.pop("termination_fn", terminations.on_goal_reached),
termination_fn=kwargs.pop(
"termination_fn",
terminations.compose(
terminations.on_goal_reached,
terminations.on_lava_fall,
),
),
*args,
**kwargs,
),
)
53 changes: 53 additions & 0 deletions tests/test_issues.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright 2023 The Navix Authors.

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import jax
import jax.numpy as jnp

import navix as nx
from navix import observations


def test_82():
env = nx.make(
"Navix-DoorKey-5x5-v0",
max_steps=100,
observation_fn=observations.rgb,
)
key = jax.random.PRNGKey(5)
timestep = env.reset(key)
# Seed 5 is:
# # # # #
# P # . #
# . # . #
# K D G #
# # # # #

# start agent direction = EAST
prev_pos = timestep.state.entities["player"].position
# action 2 is forward
timestep = env.step(timestep, 2) # should not walk into wall
pos = timestep.state.entities["player"].position
assert jnp.array_equal(prev_pos, pos)


if __name__ == "__main__":
test_82()

0 comments on commit d926b98

Please sign in to comment.