Skip to content

Commit

Permalink
Revert "AdaptivePooling Fixed"
Browse files Browse the repository at this point in the history
This reverts commit 720adf7.
  • Loading branch information
awsaf49 committed Oct 16, 2023
1 parent 2257ff4 commit 02a9181
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 126 deletions.
4 changes: 1 addition & 3 deletions gcvit/layers/feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ def __init__(self, oup=None, expansion=0.25, **kwargs):
def build(self, input_shape):
inp = input_shape[-1]
self.oup = self.oup or inp
self.avg_pool = AdaptiveAveragePooling2D(
output_size=(1, 1), name="avg_pool"
)
self.avg_pool = AdaptiveAveragePooling2D(1, name="avg_pool")
self.fc = [
tf.keras.layers.Dense(
int(inp * self.expansion), use_bias=False, name="fc/0"
Expand Down
253 changes: 130 additions & 123 deletions gcvit/layers/pooling.py
Original file line number Diff line number Diff line change
@@ -1,135 +1,142 @@
import math
from typing import Tuple
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import Callable
from typing import Iterable
from typing import Union

import tensorflow as tf

from ..utils import normalize_data_format
from ..utils import normalize_tuple


@tf.keras.utils.register_keras_serializable(package="gcvit")
class AdaptivePooling2D(tf.keras.layers.Layer):
"""Parent class for 2D pooling layers with adaptive kernel size.
Implementation is based on tensorflow-addons:
https://github.com/tensorflow/addons/blob/v0.17.0/tensorflow_addons/layers/adaptive_pooling.py#LL157C1-L234C41
This class only exists for code reuse. It will never be an exposed API.
Args:
reduce_function: The reduction method to apply, e.g. `tf.reduce_max`.
output_size: An integer or tuple/list of 2 integers specifying
(pooled_rows, pooled_cols). The new size of output channels.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape
`(batch, channels, height, width)`.
"""

class AdaptiveAveragePooling2D(tf.keras.layers.Layer):
def __init__(
self,
output_size: Tuple[int, int],
input_ordering: str = "NHWC",
**kwargs
reduce_function: Callable,
output_size: Union[int, Iterable[int]],
data_format=None,
**kwargs,
):
self.data_format = normalize_data_format(data_format)
self.reduce_function = reduce_function
self.output_size = normalize_tuple(output_size, 2, "output_size")
super().__init__(**kwargs)
self.output_size = output_size
self.input_ordering = input_ordering
if input_ordering not in ("NCHW", "NHWC"):
raise ValueError(
"Unrecognized input_ordering, should be 'NCHW' or 'NHWC'!"
)
self.h_axis = input_ordering.index("H")
self.w_axis = input_ordering.index("W")

def pseudo_1d_pool(self, inputs: tf.Tensor, h_pooling: bool):
# Figure out which axis we're pooling on
if h_pooling:
axis = self.h_axis
output_dim = self.output_size[0]
else:
axis = self.w_axis
output_dim = self.output_size[1]
input_dim = inputs.shape[axis]

# Figure out the potential pooling windows
# This is the key idea - the torch op will always use only two
# consecutive pooling window sizes, like 3 and 4. Therefore,
# if we pool with both possible sizes, we simply need to gather
# the 'correct' pool at each position to reimplement the torch op.
small_window = math.ceil(input_dim / output_dim)
big_window = small_window + 1
if h_pooling:
output_dim = self.output_size[0]
small_window_shape = (small_window, 1)
big_window_shape = (big_window, 1)
else:
output_dim = self.output_size[1]
small_window_shape = (1, small_window)
big_window_shape = (1, big_window)

# For integer resizes, we can take a very quick shortcut
if input_dim % output_dim == 0:
return tf.nn.avg_pool2d(
inputs,
ksize=small_window_shape,
strides=small_window_shape,
padding="VALID",
data_format=self.input_ordering,
)

# For non-integer resizes, we pool with both possible window sizes
# and concatenate them
small_pool = tf.nn.avg_pool2d(
inputs,
ksize=small_window_shape,
strides=1,
padding="VALID",
data_format=self.input_ordering,
)
big_pool = tf.nn.avg_pool2d(
inputs,
ksize=big_window_shape,
strides=1,
padding="VALID",
data_format=self.input_ordering,
)
both_pool = tf.concat([small_pool, big_pool], axis=axis)

# We compute vectors of the start and end positions
# for each pooling window
# Each (start, end) pair here corresponds to a single output position
window_starts = tf.math.floor(
(tf.range(output_dim, dtype=tf.float32) * input_dim) / output_dim
)
window_starts = tf.cast(window_starts, tf.int64)
window_ends = tf.math.ceil(
(tf.range(1, output_dim + 1, dtype=tf.float32) * input_dim)
/ output_dim
)
window_ends = tf.cast(window_ends, tf.int64)

# pool_selector is a boolean array of shape (output_dim,)
# where 1 indicates that output position
# has a big receptive field and 0 indicates that that output
# position has a small receptive field
pool_selector = tf.cast(
window_ends - window_starts - small_window, tf.bool
)

# Since we concatenated the small and big pools, we need to do a bit of
# pointer arithmetic to get the indices of the big pools
small_indices = window_starts
big_indices = window_starts + small_pool.shape[axis]

# Finally, we use the pool_selector to generate a list of indices,
# one per output position
gather_indices = tf.where(pool_selector, big_indices, small_indices)

# Gathering from those indices yields the final, correct pooling
return tf.gather(both_pool, gather_indices, axis=axis)

def call(self, inputs: tf.Tensor):
if self.input_ordering == "NHWC":
input_shape = inputs.shape[1:3]
def call(self, inputs, *args):
h_bins = self.output_size[0]
w_bins = self.output_size[1]
if self.data_format == "channels_last":
split_cols = tf.split(inputs, h_bins, axis=1)
split_cols = tf.stack(split_cols, axis=1)
split_rows = tf.split(split_cols, w_bins, axis=3)
split_rows = tf.stack(split_rows, axis=3)
out_vect = self.reduce_function(split_rows, axis=[2, 4])
else:
input_shape = inputs.shape[2:]

if (
input_shape[0] % self.output_size[0] == 0
and input_shape[1] % self.output_size[1] == 0
):
# If we're resizing by an integer factor on both dimensions,
# we can take a very quick shortcut.
h_resize = int(input_shape[0] // self.output_size[0])
w_resize = int(input_shape[1] // self.output_size[1])
return tf.nn.avg_pool2d(
inputs,
ksize=(h_resize, w_resize),
strides=(h_resize, w_resize),
padding="VALID",
data_format=self.input_ordering,
split_cols = tf.split(inputs, h_bins, axis=2)
split_cols = tf.stack(split_cols, axis=2)
split_rows = tf.split(split_cols, w_bins, axis=4)
split_rows = tf.stack(split_rows, axis=4)
out_vect = self.reduce_function(split_rows, axis=[3, 5])
return out_vect

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
if self.data_format == "channels_last":
shape = tf.TensorShape(
[
input_shape[0],
self.output_size[0],
self.output_size[1],
input_shape[3],
]
)
else:
# If we can't take the shortcut, we do a 1D pool on each axis
h_pooled = self.pseudo_1d_pool(inputs, h_pooling=True)
return self.pseudo_1d_pool(h_pooled, h_pooling=False)
shape = tf.TensorShape(
[
input_shape[0],
input_shape[1],
self.output_size[0],
self.output_size[1],
]
)

return shape

def get_config(self):
config = {
"output_size": self.output_size,
"data_format": self.data_format,
}
base_config = super().get_config()
return {**base_config, **config}


@tf.keras.utils.register_keras_serializable(package="gcvit")
class AdaptiveAveragePooling2D(AdaptivePooling2D):
"""Average Pooling with adaptive kernel size.
Class is borrowed from tensorflow-addons:
https://github.com/tensorflow/addons/blob/v0.17.0/tensorflow_addons/layers/adaptive_pooling.py#L238
Args:
output_size: Tuple of integers specifying (pooled_rows, pooled_cols).
The new size of output channels.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first`
corresponds to inputs with shape `(batch, channels, height, width)`.
Input shape:
- If `data_format='channels_last'`:
4D tensor with shape `(batch_size, height, width, channels)`.
- If `data_format='channels_first'`:
4D tensor with shape `(batch_size, channels, height, width)`.
Output shape:
- If `data_format='channels_last'`:
4D tensor with shape `(batch_size, pooled_rows, pooled_cols, channels)`.
- If `data_format='channels_first'`:
4D tensor with shape `(batch_size, channels, pooled_rows, pooled_cols)`.
"""

def __init__(
self, output_size: Union[int, Iterable[int]], data_format=None, **kwargs
):
super().__init__(tf.reduce_mean, output_size, data_format, **kwargs)

0 comments on commit 02a9181

Please sign in to comment.