Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: deprecate Connect4 #67

Merged
merged 13 commits into from
Feb 16, 2023
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ default_language_version:

repos:
- repo: https://github.com/timothycrosley/isort
rev: 5.10.1
rev: 5.11.5
hooks:
- id: isort

Expand Down
2 changes: 1 addition & 1 deletion jumanji/environments/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,4 +53,4 @@ def is_notebook() -> bool:
except ImportError as exc:
import warnings

warnings.warn(f"Error importing IPython: {exc}")
warnings.warn(f"Error importing IPython: {exc}", stacklevel=2)
21 changes: 15 additions & 6 deletions jumanji/environments/games/connect4/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple

from chex import Array, PRNGKey
Expand All @@ -34,7 +35,8 @@


class Connect4(Environment[State]):
"""A JAX implementation of the 'Connect 4' game.
"""This environment is DEPRECATED and will be REMOVED in release 0.2.0.
A JAX implementation of the 'Connect 4' game.

- observation: Observation
- board: jax array (int8) of shape (6, 7):
Expand Down Expand Up @@ -67,7 +69,14 @@ class Connect4(Environment[State]):

"""

n_players: int = 2
def __init__(self, n_players: int = 2):
"""Throw a deprecation warning on initialization of Connect4."""
warnings.warn(
f"{self.__class__.__name__} is deprecated and will be removed in release 0.2.0.",
DeprecationWarning,
stacklevel=2,
)
self.n_players = n_players

def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""Resets the environment.
Expand All @@ -82,7 +91,7 @@ def reset(self, key: PRNGKey) -> Tuple[State, TimeStep[Observation]]:
"""
del key
board = jnp.zeros((BOARD_HEIGHT, BOARD_WIDTH), dtype=jnp.int8)
action_mask = jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)
action_mask = jnp.ones((BOARD_WIDTH,), dtype=bool)

obs = Observation(
board=board, action_mask=action_mask, current_player=jnp.int8(0)
Expand Down Expand Up @@ -177,9 +186,9 @@ def observation_spec(self) -> ObservationSpec:
action_mask=specs.BoundedArray(
shape=(7,),
dtype=bool,
minimum=0,
maximum=1,
name="invalid_mask",
minimum=False,
maximum=True,
name="action_mask",
),
current_player=specs.DiscreteArray(
num_values=self.n_players, dtype=jnp.int8, name="current_player"
Expand Down
12 changes: 9 additions & 3 deletions jumanji/environments/games/connect4/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def test_connect4__reset(connect4_env: Connect4, empty_board: Array) -> None:
assert timestep.observation.current_player == 0
assert jnp.array_equal(state.board, empty_board)
assert jnp.array_equal(
timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=jnp.int8)
timestep.observation.action_mask, jnp.ones((BOARD_WIDTH,), dtype=bool)
)
# Check that the state is made of DeviceArrays, this is false for the non-jitted
# reset function since unpacking random.split returns numpy arrays and not device arrays.
Expand Down Expand Up @@ -122,9 +122,9 @@ def test_connect4__invalid_action(connect4_env: Connect4) -> None:
state, timestep = connect4_env.step(state, action)

# check that the action is flagged as illegal
assert timestep.observation.action_mask[0] == 0
assert not timestep.observation.action_mask[0]
# check the other actions are still legal
assert jnp.all(timestep.observation.action_mask[1:] == 1)
assert jnp.all(timestep.observation.action_mask[1:])

bad_player = state.current_player
good_player = (bad_player + 1) % 2
Expand Down Expand Up @@ -183,3 +183,9 @@ def test_connect4__render(connect4_env: Connect4, empty_board: Array) -> None:
expected_board_render = str(empty_board)

assert expected_board_render in render


def test_connect4__deprecation() -> None:
"""Check that instantiating the environment triggers a deprecation warning."""
with pytest.deprecated_call():
Connect4()
2 changes: 1 addition & 1 deletion requirements/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mkdocs-material
mkdocs-mermaid2-plugin==0.6.0
mkdocstrings==0.18.0
mknotebooks==0.7.1
mypy==0.942
mypy==0.991
nbmake
optax>=0.0.9
pre-commit==2.17.0
Expand Down
2 changes: 1 addition & 1 deletion requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ jaxlib>=0.1.74
matplotlib>=3.3.4
numpy>=1.19.5
Pillow>=9.0.0
pygame>=2.0.0
pygame==2.0.2
typing-extensions>=4.0.0
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ ignore =
E266 # Do not require block comments to only have a single leading #.
E731 # Do not assign a lambda expression, use a def.
W503 # Line break before binary operator (not compatible with black).
B017 # assertRaises(Exception): or pytest.raises(Exception) should be considered evil.