@@ -50,6 +50,21 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te
5050
5151
5252class CFGDenoisedParams :
53+ def __init__ (self , x , sampling_step , total_sampling_steps , inner_model ):
54+ self .x = x
55+ """Latent image representation in the process of being denoised"""
56+
57+ self .sampling_step = sampling_step
58+ """Current Sampling step number"""
59+
60+ self .total_sampling_steps = total_sampling_steps
61+ """Total number of sampling steps planned"""
62+
63+ self .inner_model = inner_model
64+ """Inner model reference used for denoising"""
65+
66+
67+ class AfterCFGCallbackParams :
5368 def __init__ (self , x , sampling_step , total_sampling_steps ):
5469 self .x = x
5570 """Latent image representation in the process of being denoised"""
@@ -60,6 +75,10 @@ def __init__(self, x, sampling_step, total_sampling_steps):
6075 self .total_sampling_steps = total_sampling_steps
6176 """Total number of sampling steps planned"""
6277
78+ self .output_altered = False
79+ """A flag for CFGDenoiser indicating whether the output has been altered by the callback"""
80+
81+
6382
6483class UiTrainTabParams :
6584 def __init__ (self , txt2img_preview_params ):
@@ -84,6 +103,7 @@ def __init__(self, imgs, cols, rows):
84103 callbacks_image_saved = [],
85104 callbacks_cfg_denoiser = [],
86105 callbacks_cfg_denoised = [],
106+ callbacks_cfg_after_cfg = [],
87107 callbacks_before_component = [],
88108 callbacks_after_component = [],
89109 callbacks_image_grid = [],
@@ -174,6 +194,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
174194 report_exception (e , c , 'cfg_denoised_callback' )
175195
176196
197+ def cfg_after_cfg_callback (params : AfterCFGCallbackParams ):
198+ for c in callback_map ['callbacks_cfg_after_cfg' ]:
199+ try :
200+ c .callback (params )
201+ except Exception as e :
202+ report_exception (e , c , 'cfg_after_cfg_callback' )
203+
204+
177205def before_component_callback (component , ** kwargs ):
178206 for c in callback_map ['callbacks_before_component' ]:
179207 try :
@@ -315,6 +343,14 @@ def on_cfg_denoised(callback):
315343 add_callback (callback_map ['callbacks_cfg_denoised' ], callback )
316344
317345
346+ def on_cfg_after_cfg (callback ):
347+ """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
348+ The callback is called with one argument:
349+ - params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
350+ """
351+ add_callback (callback_map ['callbacks_cfg_after_cfg' ], callback )
352+
353+
318354def on_before_component (callback ):
319355 """register a function to be called before a component is created.
320356 The callback is called with arguments:
0 commit comments