Skip to content

Commit 71f4997

Browse files
Introduces QuantizationConfig for fine-grained quantization control (#21896)
* Introduces customizable quantization API
* Mixed-precision einsum fix for torch + fixed tf/jax tests
* Fixed minor errors + API export
* Removed redundant matmuls + added docs
* Minor cleanup + docstring improvements
* Address comments
* Refactor validation
* Make mode optional
1 parent 3989d64 commit 71f4997

File tree

18 files changed

+940
-137
lines changed

18 files changed

+940
-137
lines changed

keras/api/_tf_keras/keras/quantizers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
1010
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11+
from keras.src.quantizers.quantization_config import (
12+
Float8QuantizationConfig as Float8QuantizationConfig,
13+
)
14+
from keras.src.quantizers.quantization_config import (
15+
Int4QuantizationConfig as Int4QuantizationConfig,
16+
)
17+
from keras.src.quantizers.quantization_config import (
18+
Int8QuantizationConfig as Int8QuantizationConfig,
19+
)
20+
from keras.src.quantizers.quantization_config import (
21+
QuantizationConfig as QuantizationConfig,
22+
)
1123
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1224
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1325
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/api/quantizers/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,18 @@
88
from keras.src.quantizers import get as get
99
from keras.src.quantizers import serialize as serialize
1010
from keras.src.quantizers.gptq_config import GPTQConfig as GPTQConfig
11+
from keras.src.quantizers.quantization_config import (
12+
Float8QuantizationConfig as Float8QuantizationConfig,
13+
)
14+
from keras.src.quantizers.quantization_config import (
15+
Int4QuantizationConfig as Int4QuantizationConfig,
16+
)
17+
from keras.src.quantizers.quantization_config import (
18+
Int8QuantizationConfig as Int8QuantizationConfig,
19+
)
20+
from keras.src.quantizers.quantization_config import (
21+
QuantizationConfig as QuantizationConfig,
22+
)
1123
from keras.src.quantizers.quantizers import AbsMaxQuantizer as AbsMaxQuantizer
1224
from keras.src.quantizers.quantizers import Quantizer as Quantizer
1325
from keras.src.quantizers.quantizers import abs_max_quantize as abs_max_quantize

keras/src/layers/core/dense.py

Lines changed: 48 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from keras.src.api_export import keras_export
1212
from keras.src.layers.input_spec import InputSpec
1313
from keras.src.layers.layer import Layer
14+
from keras.src.quantizers.quantization_config import QuantizationConfig
15+
from keras.src.quantizers.quantization_config import validate_and_resolve_config
1416
from keras.src.quantizers.quantizers import dequantize_with_sz_map
1517

1618

@@ -378,9 +380,9 @@ def variable_serialization_spec(self):
378380

379381
def quantized_build(self, kernel_shape, mode, config=None):
380382
if mode == "int8":
381-
self._int8_build(kernel_shape)
383+
self._int8_build(kernel_shape, config)
382384
elif mode == "int4":
383-
self._int4_build(kernel_shape)
385+
self._int4_build(kernel_shape, config)
384386
elif mode == "float8":
385387
self._float8_build()
386388
elif mode == "gptq":
@@ -389,8 +391,13 @@ def quantized_build(self, kernel_shape, mode, config=None):
389391
raise self._quantization_mode_error(mode)
390392
self._is_quantized = True
391393

392-
def _int8_build(self, kernel_shape):
393-
self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
394+
def _int8_build(self, kernel_shape, config=None):
395+
self.inputs_quantizer = (
396+
QuantizationConfig.activation_quantizer_or_default(
397+
config, quantizers.AbsMaxQuantizer(axis=-1)
398+
)
399+
)
400+
394401
self._kernel = self.add_weight(
395402
name="kernel",
396403
shape=kernel_shape,
@@ -489,7 +496,7 @@ def _gptq_call(self, inputs, training=False):
489496
y = self.activation(y)
490497
return y
491498

492-
def _int4_build(self, kernel_shape):
499+
def _int4_build(self, kernel_shape, config=None):
493500
"""Build variables for int4 quantization.
494501
495502
`kernel_shape` is the *original* float32 kernel shape
@@ -498,8 +505,10 @@ def _int4_build(self, kernel_shape):
498505
int8 byte.
499506
"""
500507
# Per-channel int8 quantizer for the last axis (features).
501-
self.inputs_quantizer = quantizers.AbsMaxQuantizer(
502-
axis=-1,
508+
self.inputs_quantizer = (
509+
QuantizationConfig.activation_quantizer_or_default(
510+
config, quantizers.AbsMaxQuantizer(axis=-1)
511+
)
503512
)
504513
input_dim, output_dim = kernel_shape
505514
packed_rows = (input_dim + 1) // 2 # ceil for odd dims
@@ -588,11 +597,15 @@ def grad_fn(*args, upstream=None):
588597
inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
589598
return (inputs_grad, None, None)
590599

591-
inputs, inputs_scale = self.inputs_quantizer(inputs)
600+
output_scale = kernel_scale
601+
if self.inputs_quantizer:
602+
inputs, inputs_scale = self.inputs_quantizer(inputs)
603+
output_scale = ops.multiply(output_scale, inputs_scale)
604+
592605
x = ops.matmul(inputs, kernel)
593606
# De-scale outputs
594607
x = ops.cast(x, self.compute_dtype)
595-
x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
608+
x = ops.divide(x, output_scale)
596609
return x, grad_fn
597610

598611
x = matmul_with_inputs_gradient(
@@ -639,10 +652,15 @@ def grad_fn(*args, upstream=None):
639652
inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
640653
return (inputs_grad, None, None)
641654

642-
inputs, inputs_scale = self.inputs_quantizer(inputs)
655+
output_scale = kernel_scale
656+
657+
if self.inputs_quantizer:
658+
inputs, inputs_scale = self.inputs_quantizer(inputs)
659+
output_scale = ops.multiply(output_scale, inputs_scale)
660+
643661
x = ops.matmul(inputs, unpacked_kernel)
644662
x = ops.cast(x, self.compute_dtype)
645-
x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
663+
x = ops.divide(x, output_scale)
646664
return x, grad_fn
647665

648666
x = matmul_with_inputs_gradient(
@@ -754,38 +772,46 @@ def grad(*args, upstream=None, variables=None):
754772
x = self.activation(x)
755773
return x
756774

757-
def quantize(self, mode, type_check=True, config=None):
775+
def quantize(self, mode=None, type_check=True, config=None):
758776
# Prevent quantization of the subclasses
759777
if type_check and (type(self) is not Dense):
760778
raise self._not_implemented_error(self.quantize)
761779

780+
config = validate_and_resolve_config(mode, config)
781+
mode = config.mode
782+
762783
kernel_shape = self._kernel.shape
763784
if mode == "int8":
764-
kernel_value, kernel_scale = quantizers.abs_max_quantize(
765-
self._kernel, axis=0, to_numpy=True
785+
weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
786+
config, quantizers.AbsMaxQuantizer(axis=0)
787+
)
788+
kernel_value, kernel_scale = weight_quantizer(
789+
self._kernel, to_numpy=True
766790
)
767791
kernel_scale = ops.squeeze(kernel_scale, axis=0)
768792
del self._kernel
769793
# Build variables for int8 mode
770-
self.quantized_build(kernel_shape, mode)
794+
self.quantized_build(kernel_shape, mode, config)
771795
self._kernel.assign(kernel_value)
772796
self.kernel_scale.assign(kernel_scale)
773797
elif mode == "int4":
774798
# 1. Quantize to int4 values (still int8 dtype, range [-8,7])
775-
kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
776-
self._kernel,
777-
axis=0,
778-
value_range=(-8, 7),
779-
dtype="int8",
780-
to_numpy=True,
799+
weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
800+
config,
801+
quantizers.AbsMaxQuantizer(
802+
axis=0, value_range=(-8, 7), output_dtype="int8"
803+
),
804+
)
805+
kernel_value_int4, kernel_scale = weight_quantizer(
806+
self._kernel, to_numpy=True
781807
)
782808
kernel_scale = ops.squeeze(kernel_scale, axis=0)
783809
# 2. Pack two int4 values into a single int8 byte.
784810
packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
785811
del self._kernel
786812
# Build variables using the original kernel shape; _int4_build will
787813
# compute the packed shape internally.
788-
self.quantized_build(kernel_shape, mode)
814+
self.quantized_build(kernel_shape, mode, config)
789815
# Assign packed values.
790816
self._kernel.assign(packed_kernel_value)
791817
self.kernel_scale.assign(kernel_scale)

keras/src/layers/core/dense_test.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,67 @@
1717
from keras.src import testing
1818
from keras.src.backend.common import keras_tensor
1919
from keras.src.quantizers.gptq_config import GPTQConfig
20+
from keras.src.quantizers.quantization_config import Int4QuantizationConfig
21+
from keras.src.quantizers.quantization_config import Int8QuantizationConfig
22+
from keras.src.quantizers.quantizers import AbsMaxQuantizer
2023

2124

2225
class DenseTest(testing.TestCase):
26+
@parameterized.named_parameters(
27+
("int8", "int8", {"axis": 0}, {"axis": -1}),
28+
(
29+
"int4",
30+
"int4",
31+
{"axis": 0, "value_range": (-8, 7), "output_dtype": "int8"},
32+
{"axis": -1},
33+
),
34+
("int8_weight_only", "int8", {"axis": 0}, None),
35+
)
36+
def test_dense_quantize_config(
37+
self, mode, weight_quantizer_args, activation_quantizer_args
38+
):
39+
"""Test Dense quantization with QuantizationConfig."""
40+
layer = layers.Dense(units=32)
41+
layer.build((None, 8))
42+
43+
weight_quantizer = AbsMaxQuantizer(**weight_quantizer_args)
44+
if activation_quantizer_args is not None:
45+
activation_quantizer = AbsMaxQuantizer(**activation_quantizer_args)
46+
else:
47+
activation_quantizer = None
48+
49+
if mode == "int8":
50+
config = Int8QuantizationConfig(
51+
weight_quantizer=weight_quantizer,
52+
activation_quantizer=activation_quantizer,
53+
)
54+
elif mode == "int4":
55+
config = Int4QuantizationConfig(
56+
weight_quantizer=weight_quantizer,
57+
activation_quantizer=activation_quantizer,
58+
)
59+
60+
layer.quantize(mode, config=config)
61+
62+
if activation_quantizer_args is not None:
63+
# Verify inputs_quantizer is set correctly
64+
self.assertIsInstance(layer.inputs_quantizer, AbsMaxQuantizer)
65+
self.assertEqual(layer.inputs_quantizer.axis, (-1,))
66+
else:
67+
# Verify inputs_quantizer is None
68+
self.assertIsNone(layer.inputs_quantizer)
69+
70+
# Verify call works
71+
x = np.random.random((2, 8)).astype("float32")
72+
y = layer(x)
73+
self.assertEqual(y.shape, (2, 32))
74+
75+
if mode == "int4":
76+
# Verify kernel is int8 (packed int4)
77+
self.assertEqual(
78+
backend.standardize_dtype(layer._kernel.dtype), "int8"
79+
)
80+
2381
@pytest.mark.requires_trainable_backend
2482
def test_dense_basics(self):
2583
# 2D case, no bias.

0 commit comments

Comments (0)