Skip to content

Commit cb465b1

Browse files
authored
Merge pull request #352 from ashen-sensored/master
CFGDenoiser and script_callbacks modification for SAG
2 parents d3fbd79 + 0e39aa7 commit cb465b1

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

modules/script_callbacks.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps, te
5050

5151

5252
class CFGDenoisedParams:
53+
def __init__(self, x, sampling_step, total_sampling_steps, inner_model):
54+
self.x = x
55+
"""Latent image representation in the process of being denoised"""
56+
57+
self.sampling_step = sampling_step
58+
"""Current Sampling step number"""
59+
60+
self.total_sampling_steps = total_sampling_steps
61+
"""Total number of sampling steps planned"""
62+
63+
self.inner_model = inner_model
64+
"""Inner model reference used for denoising"""
65+
66+
67+
class AfterCFGCallbackParams:
5368
def __init__(self, x, sampling_step, total_sampling_steps):
5469
self.x = x
5570
"""Latent image representation in the process of being denoised"""
@@ -60,6 +75,10 @@ def __init__(self, x, sampling_step, total_sampling_steps):
6075
self.total_sampling_steps = total_sampling_steps
6176
"""Total number of sampling steps planned"""
6277

78+
self.output_altered = False
79+
"""A flag for CFGDenoiser indicating whether the output has been altered by the callback"""
80+
81+
6382

6483
class UiTrainTabParams:
6584
def __init__(self, txt2img_preview_params):
@@ -84,6 +103,7 @@ def __init__(self, imgs, cols, rows):
84103
callbacks_image_saved=[],
85104
callbacks_cfg_denoiser=[],
86105
callbacks_cfg_denoised=[],
106+
callbacks_cfg_after_cfg=[],
87107
callbacks_before_component=[],
88108
callbacks_after_component=[],
89109
callbacks_image_grid=[],
@@ -174,6 +194,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams):
174194
report_exception(e, c, 'cfg_denoised_callback')
175195

176196

197+
def cfg_after_cfg_callback(params: AfterCFGCallbackParams):
198+
for c in callback_map['callbacks_cfg_after_cfg']:
199+
try:
200+
c.callback(params)
201+
except Exception as e:
202+
report_exception(e, c, 'cfg_after_cfg_callback')
203+
204+
177205
def before_component_callback(component, **kwargs):
178206
for c in callback_map['callbacks_before_component']:
179207
try:
@@ -315,6 +343,14 @@ def on_cfg_denoised(callback):
315343
add_callback(callback_map['callbacks_cfg_denoised'], callback)
316344

317345

346+
def on_cfg_after_cfg(callback):
347+
"""register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations are completed.
348+
The callback is called with one argument:
349+
- params: AfterCFGCallbackParams - parameters to be passed to the script for post-processing after cfg calculation.
350+
"""
351+
add_callback(callback_map['callbacks_cfg_after_cfg'], callback)
352+
353+
318354
def on_before_component(callback):
319355
"""register a function to be called before a component is created.
320356
The callback is called with arguments:

modules/sd_samplers_kdiffusion.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import modules.shared as shared
99
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
1010
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
11+
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
1112

1213
samplers_k_diffusion = [
1314
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -145,7 +146,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
145146

146147
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
147148

148-
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
149+
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps, self.inner_model)
149150
cfg_denoised_callback(denoised_params)
150151

151152
devices.test_for_nans(x_out, "unet")
@@ -165,6 +166,11 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
165166
if self.mask is not None:
166167
denoised = self.init_latent * self.mask + self.nmask * denoised
167168

169+
after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
170+
cfg_after_cfg_callback(after_cfg_callback_params)
171+
if after_cfg_callback_params.output_altered:
172+
denoised = after_cfg_callback_params.x
173+
168174
self.step += 1
169175

170176
return denoised

0 commit comments

Comments
 (0)