Skip to content

Commit f48f480

Browse files
Add adaptive pooling (1D, 2D, 3D) support across JAX, NumPy, TensorFlow, and PyTorch backends (#21820)
* Add AdaptiveAveragePooling2D and AdaptiveMaxPooling2D layers * Add adaptive pooling (adaptive_avg_pool and adaptive_max_pool) for JAX, NumPy, PyTorch, and TensorFlow backends * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Fix adaptive pooling implementation * Refactor adaptive pooling with shared utils and base classes * Update adaptive pooling implementation per review feedback * Update adaptive pooling implementation per review feedback * Fix config imports and regenerate API. * Fix exports. * Fix tests, in particular with Torch on GPU. --------- Co-authored-by: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com>
1 parent 0771c80 commit f48f480

24 files changed

+2377
-0
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,24 @@
113113
from keras.src.layers.normalization.unit_normalization import (
114114
UnitNormalization as UnitNormalization,
115115
)
116+
from keras.src.layers.pooling.adaptive_average_pooling1d import (
117+
AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118+
)
119+
from keras.src.layers.pooling.adaptive_average_pooling2d import (
120+
AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121+
)
122+
from keras.src.layers.pooling.adaptive_average_pooling3d import (
123+
AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124+
)
125+
from keras.src.layers.pooling.adaptive_max_pooling1d import (
126+
AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127+
)
128+
from keras.src.layers.pooling.adaptive_max_pooling2d import (
129+
AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130+
)
131+
from keras.src.layers.pooling.adaptive_max_pooling3d import (
132+
AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133+
)
116134
from keras.src.layers.pooling.average_pooling1d import (
117135
AveragePooling1D as AveragePooling1D,
118136
)

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
from keras.src.ops.math import top_k as top_k
6565
from keras.src.ops.math import view_as_complex as view_as_complex
6666
from keras.src.ops.math import view_as_real as view_as_real
67+
from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68+
from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
6769
from keras.src.ops.nn import average_pool as average_pool
6870
from keras.src.ops.nn import batch_normalization as batch_normalization
6971
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/api/_tf_keras/keras/ops/nn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8+
from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
79
from keras.src.ops.nn import average_pool as average_pool
810
from keras.src.ops.nn import batch_normalization as batch_normalization
911
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/api/layers/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,24 @@
113113
from keras.src.layers.normalization.unit_normalization import (
114114
UnitNormalization as UnitNormalization,
115115
)
116+
from keras.src.layers.pooling.adaptive_average_pooling1d import (
117+
AdaptiveAveragePooling1D as AdaptiveAveragePooling1D,
118+
)
119+
from keras.src.layers.pooling.adaptive_average_pooling2d import (
120+
AdaptiveAveragePooling2D as AdaptiveAveragePooling2D,
121+
)
122+
from keras.src.layers.pooling.adaptive_average_pooling3d import (
123+
AdaptiveAveragePooling3D as AdaptiveAveragePooling3D,
124+
)
125+
from keras.src.layers.pooling.adaptive_max_pooling1d import (
126+
AdaptiveMaxPooling1D as AdaptiveMaxPooling1D,
127+
)
128+
from keras.src.layers.pooling.adaptive_max_pooling2d import (
129+
AdaptiveMaxPooling2D as AdaptiveMaxPooling2D,
130+
)
131+
from keras.src.layers.pooling.adaptive_max_pooling3d import (
132+
AdaptiveMaxPooling3D as AdaptiveMaxPooling3D,
133+
)
116134
from keras.src.layers.pooling.average_pooling1d import (
117135
AveragePooling1D as AveragePooling1D,
118136
)

keras/api/ops/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,8 @@
6464
from keras.src.ops.math import top_k as top_k
6565
from keras.src.ops.math import view_as_complex as view_as_complex
6666
from keras.src.ops.math import view_as_real as view_as_real
67+
from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
68+
from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
6769
from keras.src.ops.nn import average_pool as average_pool
6870
from keras.src.ops.nn import batch_normalization as batch_normalization
6971
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/api/ops/nn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
since your modifications would be overwritten.
55
"""
66

7+
from keras.src.ops.nn import adaptive_average_pool as adaptive_average_pool
8+
from keras.src.ops.nn import adaptive_max_pool as adaptive_max_pool
79
from keras.src.ops.nn import average_pool as average_pool
810
from keras.src.ops.nn import batch_normalization as batch_normalization
911
from keras.src.ops.nn import binary_crossentropy as binary_crossentropy

keras/src/backend/common/backend_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import functools
2+
import math
23
import operator
34
import re
45
import warnings
@@ -539,3 +540,10 @@ def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
539540
-1 - axis
540541
)
541542
return x[tuple(slices)]
543+
544+
545+
def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
546+
"""Compute small and big window sizes for adaptive pooling."""
547+
small = math.ceil(input_dim / output_dim)
548+
big = small + 1
549+
return small, big

0 commit comments

Comments
 (0)