Skip to content

Commit

Permalink
feat: deprecate Connect4 (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
clement-bonnet authored Feb 16, 2023
1 parent 1095886 commit 1898254
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
rev: 5.11.5
hooks:
- id: isort

Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def is_notebook() -> bool:
except ImportError as exc:
import warnings

warnings.warn(f"Error importing IPython: {exc}")
warnings.warn(f"Error importing IPython: {exc}", stacklevel=2)
21 changes: 15 additions & 6 deletions jumanji/environments/games/connect4/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple

from chex import Array, PRNGKey
Expand All @@ -34,7 +35,8 @@


class Connect4(Environment[State]):
"""A JAX implementation of the 'Connect 4' game.
"""This environment is DEPRECATED and will be REMOVED in release 0.2.0.
A JAX implementation of the 'Connect 4' game.
- observation: Observation
- board: jax array (int8) of shape (6, 7):
Expand Down Expand Up @@ -67,7 +69,14 @@ class Connect4(Environment[State]):
"""

n_players: int = 2
def __init__(self, n_players: int = 2):
"""Throw a deprecation warning on initialization of Connect4."""
warnings.warn(
f"{self.__class__.__name__} is deprecated and will be removed in release 0.2.0.",
DeprecationWarning,
stacklevel=2,
)
self.n_players = n_players

def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""Resets the environment.
Expand All @@ -82,7 +91,7 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""
del key
board = jnp.zeros((BOARD_HEIGHT, BOARD_WIDTH), dtype=jnp.int8)
action_mask = jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)
action_mask = jnp.ones((BOARD_WIDTH,), dtype=bool)

obs = Observation(
board=board, action_mask=action_mask, current_player=jnp.int8(0)
Expand Down Expand Up @@ -177,9 +186,9 @@ def observation_spec(self) -> ObservationSpec:
action_mask=specs.BoundedArray(
shape=(7,),
dtype=bool,
minimum=0,
maximum=1,
name="invalid_mask",
minimum=False,
maximum=True,
name="action_mask",
),
current_player=specs.DiscreteArray(
num_values=self.n_players, dtype=jnp.int8, name="current_player"
Expand Down
12 changes: 9 additions & 3 deletions jumanji/environments/games/connect4/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_connect4__reset(connect4_env: Connect4, empty_board: Array) -> None:
assert timestep.observation.current_player == 0
assert jnp.array_equal(state.board, empty_board)
assert jnp.array_equal(
timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)
timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=bool)
)
# Check that the state is made of DeviceArrays, this is false for the non-jitted
# reset function since unpacking random.split returns numpy arrays and not device arrays.
Expand Down Expand Up @@ -122,9 +122,9 @@ def test_connect4__invalid_action(connect4_env: Connect4) -> None:
state, timestep = connect4_env.step(state, action)

# check that the action is flagged as illegal
assert timestep.observation.action_mask[0] == 0
assert not timestep.observation.action_mask[0]
# check the other actions are still legal
assert jnp.all(timestep.observation.action_mask[1:] == 1)
assert jnp.all(timestep.observation.action_mask[1:])

bad_player = state.current_player
good_player = (bad_player + 1) % 2
Expand Down Expand Up @@ -183,3 +183,9 @@ def test_connect4__render(connect4_env: Connect4, empty_board: Array) -> None:
expected_board_render = str(empty_board)

assert expected_board_render in render


def test_connect4__deprecation() -> None:
"""Check that instantiating the environment triggers a deprecation warning."""
with pytest.deprecated_call():
Connect4()
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mkdocs-material
mkdocs-mermaid2-plugin==0.6.0
mkdocstrings==0.18.0
mknotebooks==0.7.1
mypy==0.942
mypy==0.991
nbmake
optax>=0.0.9
pre-commit==2.17.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ jaxlib>=0.1.74
matplotlib>=3.3.4
numpy>=1.19.5
Pillow>=9.0.0
pygame>=2.0.0
pygame==2.0.2
typing-extensions>=4.0.0
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ ignore =
E266 # Do not require block comments to only have a single leading #.
E731 # Do not assign a lambda expression, use a def.
W503 # Line break before binary operator (not compatible with black).
B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil.

0 comments on commit 1898254

Please sign in to comment.