Skip to content

Commit

Permalink
fix: flatpack was training with ints (#234)
Browse files Browse the repository at this point in the history
  • Loading branch information
sash-a authored Mar 18, 2024
1 parent d462d54 commit f025e4f
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jumanji/training/networks/flat_pack/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def __call__(self, grid_observation: chex.Array) -> chex.Array:
# Grid observation is of shape (B, num_rows, num_cols)

# Add a channel dimension
grid_observation = grid_observation[..., jnp.newaxis]
grid_observation = grid_observation[..., jnp.newaxis].astype(float)

# Down colvolve with strided convolutions
down_1 = hk.Conv2D(32, kernel_shape=3, stride=2, padding="SAME")(
Expand Down Expand Up @@ -155,9 +155,10 @@ def __call__(self, observation: Observation) -> Tuple[chex.Array, chex.Array]:
# observation.grid (B, num_rows, num_cols)

# Flatten the blocks
# (B, num_blocks, 9)
flattened_blocks = jnp.reshape(
observation.blocks, (-1, self.num_blocks, 9)
) # (B, num_blocks, 9)
).astype(float)

# Encode the blocks with an MLP
block_encoder = hk.nets.MLP(output_sizes=[self.model_size])
Expand Down

0 comments on commit f025e4f

Please sign in to comment.