Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding "count_include_pad" argument to flax.linen.pooling.avg_pool #2451

Merged
merged 5 commits into from
Sep 23, 2022

Conversation

dslisleedh
Copy link
Contributor

@dslisleedh dslisleedh commented Sep 9, 2022

What does this PR do?

Now version's flax.linen.pooling.avg_pool average window_sum result include padded tokens. I add argument whether to include padded tokens or not

Checklist

  • This change is discussed in a Github issue/
    issues
  • The documentation and docstrings adhere to the
    documentation guidelines.
  • This change includes necessary high-coverage tests.
    (No quality testing = no merge!)

@codecov-commenter
Copy link

codecov-commenter commented Sep 9, 2022

Codecov Report

Merging #2451 (d4fa16d) into main (8687673) will decrease coverage by 0.83%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main    #2451      +/-   ##
==========================================
- Coverage   79.66%   78.83%   -0.84%     
==========================================
  Files          49       49              
  Lines        4982     5070      +88     
==========================================
+ Hits         3969     3997      +28     
- Misses       1013     1073      +60     
Impacted Files Coverage Δ
flax/linen/pooling.py 86.48% <100.00%> (+2.11%) ⬆️
flax/training/checkpoints.py 61.35% <0.00%> (-11.86%) ⬇️
flax/errors.py 87.20% <0.00%> (-0.75%) ⬇️
flax/traverse_util.py 98.52% <0.00%> (-0.50%) ⬇️
flax/linen/linear.py 97.50% <0.00%> (ø)
flax/core/lift.py 95.81% <0.00%> (+<0.01%) ⬆️
flax/linen/transforms.py 94.06% <0.00%> (+0.02%) ⬆️
flax/linen/module.py 92.75% <0.00%> (+0.03%) ⬆️
flax/serialization.py 69.27% <0.00%> (+1.39%) ⬆️

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

@cgarciae
Copy link
Collaborator

Hey @dslisleedh, thanks for creating this PR!
I think this looks good, however, before merging we should probably add a test using this flag.

@cgarciae cgarciae self-assigned this Sep 13, 2022
@dslisleedh
Copy link
Contributor Author

To @cgarciae,

Sorry, I tested my self but didn't share result in PR. Here's result.

for i, shape in enumerate([(1, 5, 3), (1, 5, 5, 3), (1, 5, 5, 5, 3)]):
    inputs = jnp.ones(shape=shape)
    for kernel_size in [(1,), (3,), (5,)]:
        k_s = kernel_size * (i + 1)
        for strides in [(1,), (2,), (3,)]:
            s = strides * (i + 1)
            for padding in ['VALID','SAME']:
                res = avg_pool(inputs, k_s, strides=s, padding=padding, count_include_pad=False)

                print(jnp.min(res), jnp.max(res))

1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0
1.0 1.0

This is first time for me to create PR, so tell me anything if I missed something. Thank you.

@cgarciae
Copy link
Collaborator

Just noticed there aren't any tests for pooling.py, I'll create an issue, we can add a test for this in the future.

@dslisleedh
Copy link
Contributor Author

dslisleedh commented Sep 14, 2022

I found PoolTest class from ./tests/linen/test_linen.py

class PoolTest(absltest.TestCase):

  def test_pool_custom_reduce(self):
    x = jnp.full((1, 3, 3, 1), 2.)
    mul_reduce = lambda x, y: x * y
    y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

  def test_avg_pool(self):
    x = jnp.full((1, 3, 3, 1), 2.)
    pool = lambda x: nn.avg_pool(x, (2, 2))
    y = pool(x)
    np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
    y_grad = jax.grad(lambda x: pool(x).sum())(x)
    expected_grad = jnp.array([
        [0.25, 0.5, 0.25],
        [0.5, 1., 0.5],
        [0.25, 0.5, 0.25],
    ]).reshape((1, 3, 3, 1))
    np.testing.assert_allclose(y_grad, expected_grad)

  def test_avg_pool_no_batch(self):
    x = jnp.full((3, 3, 1), 2.)
    pool = lambda x: nn.avg_pool(x, (2, 2))
    y = pool(x)
    np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
    y_grad = jax.grad(lambda x: pool(x).sum())(x)
    expected_grad = jnp.array([
        [0.25, 0.5, 0.25],
        [0.5, 1., 0.5],
        [0.25, 0.5, 0.25],
    ]).reshape((3, 3, 1))
    np.testing.assert_allclose(y_grad, expected_grad)

  def test_max_pool(self):
    ...

When I ran this code with the edits from this PR there were no problems.

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] IdsTest.test_hashable
[       OK ] IdsTest.test_hashable
[ RUN      ] NormalizationTest.test_batch_norm
I0915 01:20:42.378859 4314512704 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0915 01:20:42.379297 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0915 01:20:42.379443 4314512704 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.379549 4314512704 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0915 01:20:42.380259 4314512704 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] NormalizationTest.test_batch_norm
[ RUN      ] NormalizationTest.test_batch_norm_complex
[       OK ] NormalizationTest.test_batch_norm_complex
[ RUN      ] NormalizationTest.test_batch_norm_multi_init
[       OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN      ] NormalizationTest.test_group_norm
[       OK ] NormalizationTest.test_group_norm
[ RUN      ] NormalizationTest.test_group_norm_raises
[       OK ] NormalizationTest.test_group_norm_raises
[ RUN      ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[       OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN      ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[       OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN      ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[       OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN      ] PoolTest.test_avg_pool
[       OK ] PoolTest.test_avg_pool
[ RUN      ] PoolTest.test_avg_pool_no_batch
[       OK ] PoolTest.test_avg_pool_no_batch
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
[ RUN      ] RecurrentTest.test_complex_input_gru
[       OK ] RecurrentTest.test_complex_input_gru
[ RUN      ] RecurrentTest.test_convlstm
[       OK ] RecurrentTest.test_convlstm
[ RUN      ] RecurrentTest.test_gru
[       OK ] RecurrentTest.test_gru
[ RUN      ] RecurrentTest.test_lstm
[       OK ] RecurrentTest.test_lstm
[ RUN      ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
  warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[       OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN      ] StochasticTest.test_dropout
[       OK ] StochasticTest.test_dropout
[ RUN      ] StochasticTest.test_dropout_rate_limits
[       OK ] StochasticTest.test_dropout_rate_limits
[ RUN      ] StochasticTest.test_dropout_rate_stats
[       OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 21 tests in 13.907s

OK

@cgarciae
Copy link
Collaborator

@dslisleedh can you create one of more tests under PoolTest using this new argument? You can copy an existing test, rename it, and change it to use the new count_include_pad argument. This would help immensely as these will be checked automatically before merging new code.

…iv_shape for it raises error when there's no batch dimension
@dslisleedh
Copy link
Contributor Author

To @cgarciae

I add some codes to TestPool and here's code I tested.

from absl.testing import absltest, parameterized

from flax import ids
from flax import linen as nn

import jax
from jax import random
from jax import test_util as jtu
from jax.nn import initializers
import jax.numpy as jnp

import numpy as np

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()


class PoolTest(parameterized.TestCase):

    def test_pool_custom_reduce(self):
        x = jnp.full((1, 3, 3, 1), 2.)
        mul_reduce = lambda x, y: x * y
        y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2. ** 4))

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool(self, count_include_pad):
        x = jnp.full((1, 3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((1, 2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool_no_batch(self, count_include_pad):
        x = jnp.full((3, 3, 1), 2.)
        pool = lambda x: nn.avg_pool(x, (2, 2), count_include_pad=count_include_pad)
        y = pool(x)
        np.testing.assert_allclose(y, np.full((2, 2, 1), 2.))
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0.25, 0.5, 0.25],
            [0.5, 1., 0.5],
            [0.25, 0.5, 0.25],
        ]).reshape((3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)

    def test_max_pool(self):
        x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
        pool = lambda x: nn.max_pool(x, (2, 2))
        expected_y = jnp.array([
            [4., 5.],
            [7., 8.],
        ]).reshape((1, 2, 2, 1))
        y = pool(x)
        np.testing.assert_allclose(y, expected_y)
        y_grad = jax.grad(lambda x: pool(x).sum())(x)
        expected_grad = jnp.array([
            [0., 0., 0.],
            [0., 1., 1.],
            [0., 1., 1.],
        ]).reshape((1, 3, 3, 1))
        np.testing.assert_allclose(y_grad, expected_grad)


if __name__ == '__main__':
  absltest.main()

When I tested with the previous code, an error occurred in the non-batch avg_pool, so I corrected the PR. Thanks for telling me to test it for sure.

Below is the test result with the modified code.

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python pool_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] PoolTest.test_avg_pool0 (count_include_pad=True)
I0920 20:09:02.610634 4371807552 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0920 20:09:02.610964 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0920 20:09:02.611104 4371807552 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611212 4371807552 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0920 20:09:02.611675 4371807552 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
----------------------------------------------------------------------
Ran 6 tests in 1.586s

OK

Thank you.

@cgarciae
Copy link
Collaborator

Awesome @dslisleedh! Can you commit changes to the tests?

…iv_shape for it raises error when there's no batch dimension
@dslisleedh
Copy link
Contributor Author

To @cgarciae

Sure :)

@cgarciae
Copy link
Collaborator

@dslisleedh can you add this test? None of the other tests used padding="SAME" which is the more interesting case.

    @parameterized.parameters(
        {'count_include_pad': True},
        {'count_include_pad': False})
    def test_avg_pool_padding_same(self, count_include_pad):
        x = jnp.array([1.0, 2.0, 3.0, 4.0]).reshape((1, 2, 2, 1))
        pool = lambda x: nn.avg_pool(x, (2, 2), padding="SAME", count_include_pad=count_include_pad)
        y = pool(x)
        if count_include_pad:
          expected_y = jnp.array([10.0 / 4, 6.0 / 4, 7.0 / 4, 4.0 / 4]).reshape((1, 2, 2, 1))
        else:
          expected_y = jnp.array([10.0 / 4, 6.0 / 2, 7.0 / 2, 4.0 / 1]).reshape((1, 2, 2, 1))
        np.testing.assert_allclose(y, expected_y)

@dslisleedh
Copy link
Contributor Author

@cgarciae Oh, I forgot that. Thank you.

and here is result of your code.

(flax_test) idongheon@idongheon-ui-MacBookAir flax % python ./tests/linen/linen_test.py 
Running tests under Python 3.10.5: /Users/idongheon/miniforge3/envs/flax_test/bin/python
[ RUN      ] IdsTest.test_hashable
[       OK ] IdsTest.test_hashable
[ RUN      ] NormalizationTest.test_batch_norm
I0921 22:23:30.726801 4309613888 xla_bridge.py:169] Remote TPU is not linked into jax; skipping remote TPU.
I0921 22:23:30.727137 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': Could not initialize backend 'tpu_driver'
I0921 22:23:30.727280 4309613888 xla_bridge.py:345] Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727386 4309613888 xla_bridge.py:345] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I0921 22:23:30.727836 4309613888 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'
[       OK ] NormalizationTest.test_batch_norm
[ RUN      ] NormalizationTest.test_batch_norm_complex
[       OK ] NormalizationTest.test_batch_norm_complex
[ RUN      ] NormalizationTest.test_batch_norm_multi_init
[       OK ] NormalizationTest.test_batch_norm_multi_init
[ RUN      ] NormalizationTest.test_group_norm
[       OK ] NormalizationTest.test_group_norm
[ RUN      ] NormalizationTest.test_group_norm_raises
[       OK ] NormalizationTest.test_group_norm_raises
[ RUN      ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[       OK ] NormalizationTest.test_layer_norm0 (reduction_axes=-1)
[ RUN      ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[       OK ] NormalizationTest.test_layer_norm1 (reduction_axes=1)
[ RUN      ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[       OK ] NormalizationTest.test_layer_norm2 (reduction_axes=(1, 2))
[ RUN      ] PoolTest.test_avg_pool0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_no_batch0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_no_batch1 (count_include_pad=False)
[ RUN      ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[       OK ] PoolTest.test_avg_pool_padding_same0 (count_include_pad=True)
[ RUN      ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[       OK ] PoolTest.test_avg_pool_padding_same1 (count_include_pad=False)
[ RUN      ] PoolTest.test_max_pool
[       OK ] PoolTest.test_max_pool
[ RUN      ] PoolTest.test_pool_custom_reduce
[       OK ] PoolTest.test_pool_custom_reduce
[ RUN      ] RecurrentTest.test_complex_input_gru
[       OK ] RecurrentTest.test_complex_input_gru
[ RUN      ] RecurrentTest.test_convlstm
[       OK ] RecurrentTest.test_convlstm
[ RUN      ] RecurrentTest.test_gru
[       OK ] RecurrentTest.test_gru
[ RUN      ] RecurrentTest.test_lstm
[       OK ] RecurrentTest.test_lstm
[ RUN      ] RecurrentTest.test_optimized_lstm_cell_matches_regular
/Users/idongheon/miniforge3/envs/flax_test/lib/python3.10/site-packages/jax/test_util.py:44: FutureWarning: jax.test_util.check_eq is deprecated and will soon be removed.
  warnings.warn(f"jax.test_util.{attr} is deprecated and will soon be removed.", FutureWarning)
[       OK ] RecurrentTest.test_optimized_lstm_cell_matches_regular
[ RUN      ] StochasticTest.test_dropout
[       OK ] StochasticTest.test_dropout
[ RUN      ] StochasticTest.test_dropout_rate_limits
[       OK ] StochasticTest.test_dropout_rate_limits
[ RUN      ] StochasticTest.test_dropout_rate_stats
[       OK ] StochasticTest.test_dropout_rate_stats
----------------------------------------------------------------------
Ran 25 tests in 14.023s

OK

Copy link
Collaborator

@cgarciae cgarciae left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great @dslisleedh, thanks for going through with this!

@copybara-service copybara-service bot merged commit b344022 into google:main Sep 23, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants