@@ -523,10 +523,10 @@ def get_qdq_module(
             "export_training_ir_rollout_check",
             return_value=False,
         ):
-            m = torch.export.export(
+            m_with_patch = torch.export.export(
                 module, inputs, dynamic_shapes=dynamic_shapes, strict=False
             ).module()
-        draw_graph("export_with_patch", ".", m)
+        draw_graph("export_with_patch", ".", m_with_patch)
         m = torch.export.export(
             module, inputs, dynamic_shapes=dynamic_shapes, strict=False
         ).module()
@@ -541,6 +541,9 @@ def get_qdq_module(
         )
         if block_size_map is not None:
             quantizer.set_block_size_map(block_size_map)
+        prepared_with_patch = prepare_pt2e(m_with_patch, quantizer)
+        prepared_with_patch(*inputs)
+        quantized_module_with_patch = convert_pt2e(prepared_with_patch)
         prepared = prepare_pt2e(m, quantizer)
         prepared(*inputs)
         quantized_module = convert_pt2e(prepared)
@@ -555,6 +558,8 @@ def get_qdq_module(
         }
         if not bypass_check:
             self.assertTrue(nodes.intersection(q_and_dq))
+        draw_graph("convert_pt2e_with_patch", ".", quantized_module_with_patch)
+        draw_graph("convert_pt2e", ".", quantized_module)
         return quantized_module
 
     def get_prepared_qat_module(
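For context on what the duplicated calls exercise: the PT2E quantization flow used here is export → `prepare_pt2e` → calibrate → `convert_pt2e`. Below is a minimal, self-contained sketch of that flow, not this repo's implementation: `SmallModel` is a toy module, and the stock `XNNPACKQuantizer` stands in for the QNN quantizer the test builds via `make_quantizer`; `draw_graph` and the `export_training_ir_rollout_check` patch are repo-specific and omitted.

```python
import torch
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e
# XNNPACKQuantizer is an illustrative stand-in for the QNN quantizer in the diff.
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)


class SmallModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


module = SmallModel().eval()
inputs = (torch.randn(1, 8),)

# Export to an ATen-level GraphModule (strict=False, matching the diff).
m = torch.export.export(module, inputs, strict=False).module()

# Annotate the graph with quantization specs.
quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())
prepared = prepare_pt2e(m, quantizer)

# Calibrate: run representative inputs through the observed model.
prepared(*inputs)

# Lower observers into quantize/dequantize (Q/DQ) node pairs.
quantized_module = convert_pt2e(prepared)
print(quantized_module.graph)
```

The diff runs this flow twice: once on a module exported under `patch(..., "export_training_ir_rollout_check", return_value=False)` and once on a plain export, then dumps both converted graphs with `draw_graph` so the two results can be compared side by side.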