 from keras.src.api_export import keras_export
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
+from keras.src.quantizers.quantization_config import QuantizationConfig
+from keras.src.quantizers.quantization_config import validate_and_resolve_config
 from keras.src.quantizers.quantizers import dequantize_with_sz_map


@@ -378,9 +380,9 @@ def variable_serialization_spec(self):

     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
-            self._int8_build(kernel_shape)
+            self._int8_build(kernel_shape, config)
         elif mode == "int4":
-            self._int4_build(kernel_shape)
+            self._int4_build(kernel_shape, config)
         elif mode == "float8":
             self._float8_build()
         elif mode == "gptq":
@@ -389,8 +391,13 @@ def quantized_build(self, kernel_shape, mode, config=None):
             raise self._quantization_mode_error(mode)
         self._is_quantized = True

-    def _int8_build(self, kernel_shape):
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(axis=-1)
+    def _int8_build(self, kernel_shape, config=None):
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
+        )
+
         self._kernel = self.add_weight(
             name="kernel",
             shape=kernel_shape,
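For context on the default preserved above: `AbsMaxQuantizer(axis=-1)` scales each slice along the last axis so that its largest magnitude maps onto the int8 limit, and the quantized call paths later divide that scale back out. A minimal NumPy sketch of the arithmetic (illustrative only, not the Keras implementation):

import numpy as np

def abs_max_quantize(x, axis=-1, value_range=(-127, 127)):
    # Scale so the largest magnitude along `axis` lands on the int8 limit.
    scale = value_range[1] / (np.max(np.abs(x), axis=axis, keepdims=True) + 1e-7)
    q = np.clip(np.round(x * scale), *value_range).astype("int8")
    return q, scale

x = np.array([[0.5, -2.0, 1.0]])
q, scale = abs_max_quantize(x)  # scale ~= 63.5, q = [[32, -127, 64]]
x_hat = q / scale               # dequantize: roughly [[0.504, -2.0, 1.008]]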
@@ -489,7 +496,7 @@ def _gptq_call(self, inputs, training=False):
             y = self.activation(y)
         return y

-    def _int4_build(self, kernel_shape):
+    def _int4_build(self, kernel_shape, config=None):
         """Build variables for int4 quantization.

         `kernel_shape` is the *original* float32 kernel shape
@@ -498,8 +505,10 @@ def _int4_build(self, kernel_shape):
         int8 byte.
         """
         # Per-channel int8 quantizer for the last axis (features).
-        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
-            axis=-1,
+        self.inputs_quantizer = (
+            QuantizationConfig.activation_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=-1)
+            )
         )
         input_dim, output_dim = kernel_shape
         packed_rows = (input_dim + 1) // 2  # ceil for odd dims
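The `packed_rows = (input_dim + 1) // 2` context above exists because two signed int4 values (range [-8, 7]) share one int8 byte; an odd `input_dim` simply leaves the last high nibble unused. A small standalone sketch of that nibble packing (illustrative; the layer itself uses `quantizers.pack_int4`, whose exact layout may differ):

def pack_pair(lo, hi):
    # Two signed int4 values in one byte: `lo` in the low nibble, `hi` in the high nibble.
    return ((hi & 0x0F) << 4) | (lo & 0x0F)

def unpack_pair(byte):
    def sign_extend(nibble):  # map nibbles 8..15 back to -8..-1
        return nibble - 16 if nibble >= 8 else nibble
    return sign_extend(byte & 0x0F), sign_extend((byte >> 4) & 0x0F)

assert unpack_pair(pack_pair(-3, 7)) == (-3, 7)
assert unpack_pair(pack_pair(-8, 5)) == (-8, 5)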
@@ -588,11 +597,15 @@ def grad_fn(*args, upstream=None):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, kernel)
             # De-scale outputs
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
@@ -639,10 +652,15 @@ def grad_fn(*args, upstream=None):
                 inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
                 return (inputs_grad, None, None)

-            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            output_scale = kernel_scale
+
+            if self.inputs_quantizer:
+                inputs, inputs_scale = self.inputs_quantizer(inputs)
+                output_scale = ops.multiply(output_scale, inputs_scale)
+
             x = ops.matmul(inputs, unpacked_kernel)
             x = ops.cast(x, self.compute_dtype)
-            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            x = ops.divide(x, output_scale)
             return x, grad_fn

         x = matmul_with_inputs_gradient(
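Both call-path hunks above apply the same de-scaling identity: with quantized inputs `q_x ~= x * s_x` and quantized kernel `q_W ~= W * s_W`, `matmul(q_x, q_W) / (s_x * s_W) ~= matmul(x, W)`; when `self.inputs_quantizer` is falsy (weight-only quantization, which the new guard permits), only `kernel_scale` is divided out. A quick NumPy check of that identity with per-tensor scales (illustrative; the layer uses per-channel scales, but the algebra is the same):

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4)).astype("float32")
w = rng.normal(size=(4, 3)).astype("float32")

s_x = 127.0 / np.max(np.abs(x))
s_w = 127.0 / np.max(np.abs(w))
q_x = np.round(x * s_x)
q_w = np.round(w * s_w)

y_full = (q_x @ q_w) / (s_x * s_w)  # inputs and kernel both quantized
y_weight_only = (x @ q_w) / s_w     # inputs_quantizer is None: kernel scale only
print(np.max(np.abs(y_full - x @ w)), np.max(np.abs(y_weight_only - x @ w)))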
@@ -754,38 +772,46 @@ def grad(*args, upstream=None, variables=None):
             x = self.activation(x)
         return x

-    def quantize(self, mode, type_check=True, config=None):
+    def quantize(self, mode=None, type_check=True, config=None):
         # Prevent quantization of the subclasses
         if type_check and (type(self) is not Dense):
             raise self._not_implemented_error(self.quantize)

+        config = validate_and_resolve_config(mode, config)
+        mode = config.mode
+
         kernel_shape = self._kernel.shape
         if mode == "int8":
-            kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel, axis=0, to_numpy=True
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config, quantizers.AbsMaxQuantizer(axis=0)
+            )
+            kernel_value, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
             # Build variables for int8 mode
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
         elif mode == "int4":
             # 1. Quantize to int4 values (still int8 dtype, range [-8, 7])
-            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
-                self._kernel,
-                axis=0,
-                value_range=(-8, 7),
-                dtype="int8",
-                to_numpy=True,
+            weight_quantizer = QuantizationConfig.weight_quantizer_or_default(
+                config,
+                quantizers.AbsMaxQuantizer(
+                    axis=0, value_range=(-8, 7), output_dtype="int8"
+                ),
+            )
+            kernel_value_int4, kernel_scale = weight_quantizer(
+                self._kernel, to_numpy=True
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             # 2. Pack two int4 values into a single int8 byte.
             packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
             del self._kernel
             # Build variables using the original kernel shape; _int4_build will
             # compute the packed shape internally.
-            self.quantized_build(kernel_shape, mode)
+            self.quantized_build(kernel_shape, mode, config)
             # Assign packed values.
             self._kernel.assign(packed_kernel_value)
             self.kernel_scale.assign(kernel_scale)
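With `mode` now optional, `validate_and_resolve_config(mode, config)` reconciles the two arguments and the resolved config carries the mode (`mode = config.mode`). The existing string entry point keeps working; a sketch of both call styles (the config-driven constructor arguments are assumptions, they are not shown in this diff):

import keras

layer = keras.layers.Dense(16)
layer.build((None, 8))

# Existing entry point: a mode string with the default quantizers.
layer.quantize("int8")

# Config-driven entry point enabled by this change (argument names are
# illustrative; see keras.src.quantizers.quantization_config for the real API):
# config = QuantizationConfig(mode="int8", weight_quantizer=..., activation_quantizer=...)
# layer.quantize(config=config)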