Skip to content

Commit

Permalink
Implement global pooling layer (keras-team#74)
Browse files Browse the repository at this point in the history
* Add max and poolig layer

* fix tests

* handle TF transpose

* renaming

* initial

* add something

* rename tests

* more

* add docstring

* add mask for global average pooling 1d
  • Loading branch information
chenmoneygithub authored May 3, 2023
1 parent e8fdf09 commit 066d4dd
Show file tree
Hide file tree
Showing 10 changed files with 788 additions and 0 deletions.
12 changes: 12 additions & 0 deletions keras_core/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,18 @@
from keras_core.layers.pooling.average_pooling1d import AveragePooling1D
from keras_core.layers.pooling.average_pooling2d import AveragePooling2D
from keras_core.layers.pooling.average_pooling3d import AveragePooling3D
from keras_core.layers.pooling.global_average_pooling1d import (
GlobalAveragePooling1D,
)
from keras_core.layers.pooling.global_average_pooling2d import (
GlobalAveragePooling2D,
)
from keras_core.layers.pooling.global_average_pooling3d import (
GlobalAveragePooling3D,
)
from keras_core.layers.pooling.global_max_pooling1d import GlobalMaxPooling1D
from keras_core.layers.pooling.global_max_pooling2d import GlobalMaxPooling2D
from keras_core.layers.pooling.global_max_pooling3d import GlobalMaxPooling3D
from keras_core.layers.pooling.max_pooling1d import MaxPooling1D
from keras_core.layers.pooling.max_pooling2d import MaxPooling2D
from keras_core.layers.pooling.max_pooling3d import MaxPooling3D
Expand Down
50 changes: 50 additions & 0 deletions keras_core/layers/pooling/base_global_pooling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from keras_core.backend import image_data_format
from keras_core.layers.input_spec import InputSpec
from keras_core.layers.layer import Layer


class BaseGlobalPooling(Layer):
"""Base global pooling layer."""

def __init__(
self, pool_dimensions, data_format=None, keepdims=False, **kwargs
):
super().__init__(**kwargs)

self.data_format = (
image_data_format() if data_format is None else data_format
)
self.keepdims = keepdims
self.input_spec = InputSpec(ndim=pool_dimensions + 2)

def call(self, inputs):
raise NotImplementedError

def compute_output_shape(self, input_shape):
num_spatial_dims = len(input_shape) - 2
if self.data_format == "channels_last":
if self.keepdims:
return (
(input_shape[0],)
+ (1,) * num_spatial_dims
+ (input_shape[-1],)
)
else:
return (input_shape[0],) + (input_shape[-1],)
else:
if self.keepdims:
return (input_shape[0], input_shape[1]) + (
1,
) * num_spatial_dims
else:
return (input_shape[0], input_shape[1])

def get_config(self):
config = super().get_config()
config.update(
{
"data_format": self.data_format,
"keepdims": self.keepdims,
}
)
return config
84 changes: 84 additions & 0 deletions keras_core/layers/pooling/global_average_pooling1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from keras_core import operations as ops
from keras_core import backend
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_global_pooling import BaseGlobalPooling


@keras_core_export(
[
"keras_core.layers.GlobalAveragePooling1D",
"keras_core.layers.GlobalAvgPool1D",
]
)
class GlobalAveragePooling1D(BaseGlobalPooling):
"""Global average pooling operation for temporal data.
Args:
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, steps, features)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, steps)`. It defaults to the `image_data_format`
value found in your Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be `"channels_last"`.
keepdims: A boolean, whether to keep the temporal dimension or not.
If `keepdims` is `False` (default), the rank of the tensor is
reduced for spatial dimensions. If `keepdims` is `True`, the
temporal dimension are retained with length 1.
The behavior is the same as for `tf.reduce_mean` or `np.mean`.
Call arguments:
inputs: A 3D tensor.
mask: Binary tensor of shape `(batch_size, steps)` indicating whether
a given step should be masked (excluded from the average).
Input shape:
- If `data_format='channels_last'`:
3D tensor with shape:
`(batch_size, steps, features)`
- If `data_format='channels_first'`:
3D tensor with shape:
`(batch_size, features, steps)`
Output shape:
- If `keepdims`=False:
2D tensor with shape `(batch_size, features)`.
- If `keepdims`=True:
- If `data_format="channels_last"`:
3D tensor with shape `(batch_size, 1, features)`
- If `data_format="channels_first"`:
3D tensor with shape `(batch_size, features, 1)`
Examples:
>>> x = np.random.rand(2, 3, 4)
>>> y = keras_core.layers.GlobalAveragePooling1D()(x)
>>> print(y.shape)
(2, 4)
"""

def __init__(self, data_format=None, keepdims=False, **kwargs):
super().__init__(
pool_dimensions=1,
data_format=data_format,
keepdims=keepdims,
**kwargs,
)
self.supports_masking = True

def call(self, inputs, mask=None):
steps_axis = 1 if self.data_format == "channels_last" else 2
if mask is not None:
mask = backend.cast(mask, inputs[0].dtype)
mask = ops.expand_dims(
mask, 2 if self.data_format == "channels_last" else 1
)
inputs *= mask
return ops.sum(
inputs, axis=steps_axis, keepdims=self.keepdims
) / ops.sum(mask, axis=steps_axis, keepdims=self.keepdims)
else:
return ops.mean(inputs, axis=steps_axis, keepdims=self.keepdims)

def compute_mask(self, inputs, mask=None):
return None
66 changes: 66 additions & 0 deletions keras_core/layers/pooling/global_average_pooling2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_global_pooling import BaseGlobalPooling


@keras_core_export(
[
"keras_core.layers.GlobalAveragePooling2D",
"keras_core.layers.GlobalAvgPool2D",
]
)
class GlobalAveragePooling2D(BaseGlobalPooling):
"""Global average pooling operation for 2D data.
Args:
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape `(batch, height, width, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, features, height, weight)`. It defaults to the
`image_data_format` value found in your Keras config file at
`~/.keras/keras.json`. If you never set it, then it will be
`"channels_last"`.
keepdims: A boolean, whether to keep the temporal dimension or not.
If `keepdims` is `False` (default), the rank of the tensor is
reduced for spatial dimensions. If `keepdims` is `True`, the
spatial dimension are retained with length 1.
The behavior is the same as for `tf.reduce_mean` or `np.mean`.
Input shape:
- If `data_format='channels_last'`:
4D tensor with shape:
`(batch_size, height, width, channels)`
- If `data_format='channels_first'`:
4D tensor with shape:
`(batch_size, channels, height, width)`
Output shape:
- If `keepdims`=False:
2D tensor with shape `(batch_size, channels)`.
- If `keepdims`=True:
- If `data_format="channels_last"`:
4D tensor with shape `(batch_size, 1, 1, channels)`
- If `data_format="channels_first"`:
4D tensor with shape `(batch_size, channels, 1, 1)`
Examples:
>>> x = np.random.rand(2, 4, 5, 3)
>>> y = keras_core.layers.GlobalAveragePooling2D()(x)
>>> print(y.shape)
(2, 3)
"""

def __init__(self, data_format=None, keepdims=False, **kwargs):
super().__init__(
pool_dimensions=2,
data_format=data_format,
keepdims=keepdims,
**kwargs,
)

def call(self, inputs):
if self.data_format == "channels_last":
return ops.mean(inputs, axis=[1, 2], keepdims=self.keepdims)
return ops.mean(inputs, axis=[2, 3], keepdims=self.keepdims)
67 changes: 67 additions & 0 deletions keras_core/layers/pooling/global_average_pooling3d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from keras_core import operations as ops
from keras_core.api_export import keras_core_export
from keras_core.layers.pooling.base_global_pooling import BaseGlobalPooling


@keras_core_export(
[
"keras_core.layers.GlobalAveragePooling3D",
"keras_core.layers.GlobalAvgPool3D",
]
)
class GlobalAveragePooling3D(BaseGlobalPooling):
"""Global average pooling operation for 3D data.
Args:
data_format: string, either `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs. `"channels_last"`
corresponds to inputs with shape
`(batch, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
while `"channels_first"` corresponds to inputs with shape
`(batch, channels, spatial_dim1, spatial_dim2, spatial_dim3)`.
It defaults to the `image_data_format` value found in your Keras
config file at `~/.keras/keras.json`. If you never set it, then it
will be `"channels_last"`.
keepdims: A boolean, whether to keep the temporal dimension or not.
If `keepdims` is `False` (default), the rank of the tensor is
reduced for spatial dimensions. If `keepdims` is `True`, the
spatial dimension are retained with length 1.
The behavior is the same as for `tf.reduce_mean` or `np.mean`.
Input shape:
- If `data_format='channels_last'`:
5D tensor with shape:
`(batch_size, spatial_dim1, spatial_dim2, spatial_dim3, channels)`
- If `data_format='channels_first'`:
5D tensor with shape:
`(batch_size, channels, spatial_dim1, spatial_dim2, spatial_dim3)`
Output shape:
- If `keepdims`=False:
2D tensor with shape `(batch_size, channels)`.
- If `keepdims`=True:
- If `data_format="channels_last"`:
5D tensor with shape `(batch_size, 1, 1, 1, channels)`
- If `data_format="channels_first"`:
5D tensor with shape `(batch_size, channels, 1, 1, 1)`
Examples:
>>> x = np.random.rand(2, 4, 5, 4, 3)
>>> y = keras_core.layers.GlobalAveragePooling3D()(x)
>>> print(y.shape)
(2, 3)
"""

def __init__(self, data_format=None, keepdims=False, **kwargs):
super().__init__(
pool_dimensions=3,
data_format=data_format,
keepdims=keepdims,
**kwargs,
)

def call(self, inputs):
if self.data_format == "channels_last":
return ops.mean(inputs, axis=[1, 2, 3], keepdims=self.keepdims)
return ops.mean(inputs, axis=[2, 3, 4], keepdims=self.keepdims)
Loading

0 comments on commit 066d4dd

Please sign in to comment.