From 2bff3c520fa2dd5dd0fdb311594ae11764b72f12 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 10 Dec 2025 14:06:16 +0000 Subject: [PATCH 1/4] Add nabla support --- comfy/ldm/kandinsky5/model.py | 101 ++++++++++++++++--- comfy/ldm/kandinsky5/utils_nabla.py | 147 ++++++++++++++++++++++++++++ 2 files changed, 234 insertions(+), 14 deletions(-) create mode 100644 comfy/ldm/kandinsky5/utils_nabla.py diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index 1509de2f83d7..df10e4496675 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -6,6 +6,12 @@ from comfy.ldm.modules.attention import optimized_attention from comfy.ldm.flux.math import apply_rope1 from comfy.ldm.flux.layers import EmbedND +from comfy.ldm.kandinsky5.utils_nabla import ( + fractal_flatten, + fractal_unflatten, + fast_sta_nabla, + nabla, +) def attention(q, k, v, heads, transformer_options={}): return optimized_attention( @@ -116,14 +122,17 @@ def _compute_qk(self, x, freqs, proj_fn, norm_fn): result = proj_fn(x).view(*x.shape[:-1], self.num_heads, -1) return apply_rope1(norm_fn(result), freqs) - def _forward(self, x, freqs, transformer_options={}): + def _forward(self, x, freqs, sparse_params=None, transformer_options={}): q = self._compute_qk(x, freqs, self.to_query, self.query_norm) k = self._compute_qk(x, freqs, self.to_key, self.key_norm) v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) - out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + if sparse_params is None: + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + else: + out = nabla(q, k, v, sparse_params) return self.out_layer(out) - def _forward_chunked(self, x, freqs, transformer_options={}): + def _forward_chunked(self, x, freqs, sparse_params=None, transformer_options={}): def process_chunks(proj_fn, norm_fn): x_chunks = torch.chunk(x, self.num_chunks, dim=1) freqs_chunks = torch.chunk(freqs, self.num_chunks, dim=1) @@ -135,14 +144,17 @@ def process_chunks(proj_fn, norm_fn): q = process_chunks(self.to_query, self.query_norm) k = process_chunks(self.to_key, self.key_norm) v = self.to_value(x).view(*x.shape[:-1], self.num_heads, -1) - out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + if sparse_params is None: + out = attention(q, k, v, self.num_heads, transformer_options=transformer_options) + else: + out = nabla(q, k, v, sparse_params) return self.out_layer(out) - def forward(self, x, freqs, transformer_options={}): + def forward(self, x, freqs, sparse_params=None, transformer_options={}): if x.shape[1] > 8192: - return self._forward_chunked(x, freqs, transformer_options=transformer_options) + return self._forward_chunked(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options) else: - return self._forward(x, freqs, transformer_options=transformer_options) + return self._forward(x, freqs, sparse_params=sparse_params, transformer_options=transformer_options) class CrossAttention(SelfAttention): @@ -251,12 +263,12 @@ def __init__(self, model_dim, time_dim, ff_dim, head_dim, operation_settings=Non self.feed_forward_norm = operations.LayerNorm(model_dim, elementwise_affine=False, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")) self.feed_forward = FeedForward(model_dim, ff_dim, operation_settings=operation_settings) - def forward(self, visual_embed, text_embed, time_embed, freqs, transformer_options={}): + def forward(self, visual_embed, 
text_embed, time_embed, freqs, sparse_params=None, transformer_options={}): self_attn_params, cross_attn_params, ff_params = torch.chunk(self.visual_modulation(time_embed), 3, dim=-1) # self attention shift, scale, gate = get_shift_scale_gate(self_attn_params) visual_out = apply_scale_shift_norm(self.self_attention_norm, visual_embed, scale, shift) - visual_out = self.self_attention(visual_out, freqs, transformer_options=transformer_options) + visual_out = self.self_attention(visual_out, freqs, sparse_params=sparse_params, transformer_options=transformer_options) visual_embed = apply_gate_sum(visual_embed, visual_out, gate) # cross attention shift, scale, gate = get_shift_scale_gate(cross_attn_params) @@ -369,21 +381,80 @@ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_o visual_embed = self.visual_embeddings(x) visual_shape = visual_embed.shape[:-1] - visual_embed = visual_embed.flatten(1, -2) blocks_replace = patches_replace.get("dit", {}) transformer_options["total_blocks"] = len(self.visual_transformer_blocks) transformer_options["block_type"] = "double" + + B, _, T, H, W = x.shape + if T > 30: # 10 sec generation + assert self.patch_size[0] == 1 + + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] + visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:]) + visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0) + + pt, ph, pw = self.patch_size + T, H, W = T // pt, H // ph, W // pw + + wT, wW, wH = 11, 11, 3 + sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device) + + sparse_params = dict( + sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0), + attention_type="nabla", + to_fractal=True, + P=0.8, + wT=wT, wW=wW, wH=wH, + add_sta=True, + visual_shape=(T, H, W), + method="topcdf", + ) + else: + sparse_params = None + visual_embed = visual_embed.flatten(1, -2) + for i, block in enumerate(self.visual_transformer_blocks): transformer_options["block_index"] = i if ("double_block", i) in blocks_replace: def block_wrap(args): - return block(x=args["x"], context=args["context"], time_embed=args["time_embed"], freqs=args["freqs"], transformer_options=args.get("transformer_options")) - visual_embed = blocks_replace[("double_block", i)]({"x": visual_embed, "context": context, "time_embed": time_embed, "freqs": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})["x"] + return block( + x=args["x"], + context=args["context"], + time_embed=args["time_embed"], + freqs=args["freqs"], + sparse_params=args.get("sparse_params"), + transformer_options=args.get("transformer_options"), + ) + visual_embed = blocks_replace[("double_block", i)]( + { + "x": visual_embed, + "context": context, + "time_embed": time_embed, + "freqs": freqs, + "sparse_params": sparse_params, + "transformer_options": transformer_options, + }, + {"original_block": block_wrap}, + )["x"] else: - visual_embed = block(visual_embed, context, time_embed, freqs=freqs, transformer_options=transformer_options) + visual_embed = block( + visual_embed, + context, + time_embed, + freqs=freqs, + sparse_params=sparse_params, + transformer_options=transformer_options, + ) + + if T > 30: + visual_embed = fractal_unflatten( + visual_embed[0], + visual_shape[1:], + ).unsqueeze(0) + else: + visual_embed = visual_embed.reshape(*visual_shape, -1) - visual_embed = visual_embed.reshape(*visual_shape, -1) return self.out_layer(visual_embed, time_embed) def _forward(self, x, timestep, context, y, time_dim_replace=None, 
transformer_options={}, **kwargs): @@ -411,3 +482,5 @@ def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_op self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) + + diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py new file mode 100644 index 000000000000..705b1d75e324 --- /dev/null +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -0,0 +1,147 @@ +import math + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import BlockMask, flex_attention + + +def fractal_flatten(x, rope, shape): + pixel_size = 8 + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) + x = x.flatten(0, 1) + rope = rope.flatten(0, 1) + return x, rope + + +def fractal_unflatten(x, shape): + pixel_size = 8 + x = x.reshape(-1, pixel_size**2, x.shape[-1]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + return x + + +def local_patching(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + g1, + height // g2, + g2, + width // g3, + g3, + *x.shape[dim + 3 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 2, + dim + 4, + dim + 1, + dim + 3, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 2).flatten(dim + 1, dim + 3) + return x + + +def local_merge(x, shape, group_size, dim=0): + duration, height, width = shape + g1, g2, g3 = group_size + x = x.reshape( + *x.shape[:dim], + duration // g1, + height // g2, + width // g3, + g1, + g2, + g3, + *x.shape[dim + 2 :] + ) + x = x.permute( + *range(len(x.shape[:dim])), + dim, + dim + 3, + dim + 1, + dim + 4, + dim + 2, + dim + 5, + *range(dim + 6, len(x.shape)) + ) + x = x.flatten(dim, dim + 1).flatten(dim + 1, dim + 2).flatten(dim + 2, dim + 3) + return x + +def fast_sta_nabla(T: int, H: int, W: int, wT: int = 3, wH: int = 3, wW: int = 3, device="cuda") -> Tensor: + l = torch.Tensor([T, H, W]).amax() + r = torch.arange(0, l, 1, dtype=torch.int16, device=device) + mat = (r.unsqueeze(1) - r.unsqueeze(0)).abs() + sta_t, sta_h, sta_w = ( + mat[:T, :T].flatten(), + mat[:H, :H].flatten(), + mat[:W, :W].flatten(), + ) + sta_t = sta_t <= wT // 2 + sta_h = sta_h <= wH // 2 + sta_w = sta_w <= wW // 2 + sta_hw = ( + (sta_h.unsqueeze(1) * sta_w.unsqueeze(0)) + .reshape(H, H, W, W) + .transpose(1, 2) + .flatten() + ) + sta = ( + (sta_t.unsqueeze(1) * sta_hw.unsqueeze(0)) + .reshape(T, T, H * W, H * W) + .transpose(1, 2) + ) + return sta.reshape(T * H * W, T * H * W) + +def nablaT_v2(q: Tensor, k: Tensor, sta: Tensor, thr: float = 0.9) -> BlockMask: + # Map estimation + B, h, S, D = q.shape + s1 = S // 64 + qa = q.reshape(B, h, s1, 64, D).mean(-2) + ka = k.reshape(B, h, s1, 64, D).mean(-2).transpose(-2, -1) + map = qa @ ka + + map = torch.softmax(map / math.sqrt(D), dim=-1) + # Map binarization + vals, inds = map.sort(-1) + cvals = vals.cumsum_(-1) + mask = (cvals >= 1 - thr).int() + mask = mask.gather(-1, inds.argsort(-1)) + mask = torch.logical_or(mask, sta) + + # BlockMask creation + kv_nb = mask.sum(-1).to(torch.int32) + kv_inds = mask.argsort(dim=-1, descending=True).to(torch.int32) + return BlockMask.from_kv_blocks( + torch.zeros_like(kv_nb), kv_inds, kv_nb, kv_inds, BLOCK_SIZE=64, 
mask_mod=None + ) + +@torch.compile(mode="max-autotune-no-cudagraphs", dynamic=True) +def nabla(query, key, value, sparse_params=None): + query = query.transpose(1, 2).contiguous() + key = key.transpose(1, 2).contiguous() + value = value.transpose(1, 2).contiguous() + block_mask = nablaT_v2( + query, + key, + sparse_params["sta_mask"], + thr=sparse_params["P"], + ) + out = ( + flex_attention( + query, + key, + value, + block_mask=block_mask + ) + .transpose(1, 2) + .contiguous() + ) + out = out.flatten(-2, -1) + return out \ No newline at end of file From 0c84b7650fca765162a33028237603c3ada81664 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Tue, 16 Dec 2025 11:15:59 +0000 Subject: [PATCH 2/4] Add batch support for nabla --- comfy/ldm/kandinsky5/model.py | 15 +++++++-------- comfy/ldm/kandinsky5/utils_nabla.py | 13 ++++++------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index df10e4496675..ac935286758d 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,13 +387,12 @@ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_o transformer_options["block_type"] = "double" B, _, T, H, W = x.shape - if T > 30: # 10 sec generation + NABLA_THR = 40 # long (10 sec) generation + if T > NABLA_THR: assert self.patch_size[0] == 1 - freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:])[0] - visual_embed_4d, freqs = fractal_flatten(visual_embed[0], freqs, visual_shape[1:]) - visual_embed, freqs = visual_embed_4d.unsqueeze(0), freqs.unsqueeze(0) - + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:]) + visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:]) pt, ph, pw = self.patch_size T, H, W = T // pt, H // ph, W // pw @@ -447,11 +446,11 @@ def block_wrap(args): transformer_options=transformer_options, ) - if T > 30: + if T > NABLA_THR: visual_embed = fractal_unflatten( - visual_embed[0], + visual_embed, visual_shape[1:], - ).unsqueeze(0) + ) else: visual_embed = visual_embed.reshape(*visual_shape, -1) diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py index 705b1d75e324..5e2bc4076f8d 100644 --- a/comfy/ldm/kandinsky5/utils_nabla.py +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -7,20 +7,19 @@ def fractal_flatten(x, rope, shape): pixel_size = 8 - x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=0) - rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=0) - x = x.flatten(0, 1) - rope = rope.flatten(0, 1) + x = local_patching(x, shape, (1, pixel_size, pixel_size), dim=1) + rope = local_patching(rope, shape, (1, pixel_size, pixel_size), dim=1) + x = x.flatten(1, 2) + rope = rope.flatten(1, 2) return x, rope def fractal_unflatten(x, shape): pixel_size = 8 - x = x.reshape(-1, pixel_size**2, x.shape[-1]) - x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=0) + x = x.reshape(x.shape[0], -1, pixel_size**2, x.shape[-1]) + x = local_merge(x, shape, (1, pixel_size, pixel_size), dim=1) return x - def local_patching(x, shape, group_size, dim=0): duration, height, width = shape g1, g2, g3 = group_size From a3f78be5c27b9b5a2c292b5a0edf3b65c3dd1121 Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 17 Dec 2025 07:37:46 +0000 Subject: [PATCH 3/4] Add 128 divisibility for nabla --- comfy/ldm/kandinsky5/model.py | 4 +--- comfy_extras/nodes_kandinsky5.py | 3 +++ 2 files changed, 4 insertions(+), 3 deletions(-) diff --git 
a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index ac935286758d..f4e02af70aaf 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -387,7 +387,7 @@ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_o transformer_options["block_type"] = "double" B, _, T, H, W = x.shape - NABLA_THR = 40 # long (10 sec) generation + NABLA_THR = 31 # long (10 sec) generation if T > NABLA_THR: assert self.patch_size[0] == 1 @@ -481,5 +481,3 @@ def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_op self, comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options) ).execute(x, timestep, context, y, time_dim_replace=time_dim_replace, transformer_options=transformer_options, **kwargs) - - diff --git a/comfy_extras/nodes_kandinsky5.py b/comfy_extras/nodes_kandinsky5.py index 9cb234be11d0..aaaf83566711 100644 --- a/comfy_extras/nodes_kandinsky5.py +++ b/comfy_extras/nodes_kandinsky5.py @@ -34,6 +34,9 @@ def define_schema(cls): @classmethod def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None) -> io.NodeOutput: + if length > 121: # 10 sec generation, for nabla + height = 128 * round(height / 128) + width = 128 * round(width / 128) latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device()) cond_latent_out = {} if start_image is not None: From 296b7c7b6d7111c5ee07a09576e7e1e6255aeeff Mon Sep 17 00:00:00 2001 From: Mihail Karaev Date: Wed, 17 Dec 2025 11:40:14 +0000 Subject: [PATCH 4/4] Small fixes --- comfy/ldm/kandinsky5/model.py | 7 +++++-- comfy/ldm/kandinsky5/utils_nabla.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/ldm/kandinsky5/model.py b/comfy/ldm/kandinsky5/model.py index f4e02af70aaf..24a06da0a772 100644 --- a/comfy/ldm/kandinsky5/model.py +++ b/comfy/ldm/kandinsky5/model.py @@ -391,19 +391,22 @@ def forward_orig(self, x, timestep, context, y, freqs, freqs_text, transformer_o if T > NABLA_THR: assert self.patch_size[0] == 1 + # pro video model uses lower P at higher resolutions + P = 0.7 if self.model_dim == 4096 and H * W >= 14080 else 0.9 + freqs = freqs.view(freqs.shape[0], *visual_shape[1:], *freqs.shape[2:]) visual_embed, freqs = fractal_flatten(visual_embed, freqs, visual_shape[1:]) pt, ph, pw = self.patch_size T, H, W = T // pt, H // ph, W // pw - wT, wW, wH = 11, 11, 3 + wT, wW, wH = 11, 3, 3 sta_mask = fast_sta_nabla(T, H // 8, W // 8, wT, wH, wW, device=x.device) sparse_params = dict( sta_mask=sta_mask.unsqueeze_(0).unsqueeze_(0), attention_type="nabla", to_fractal=True, - P=0.8, + P=P, wT=wT, wW=wW, wH=wH, add_sta=True, visual_shape=(T, H, W), diff --git a/comfy/ldm/kandinsky5/utils_nabla.py b/comfy/ldm/kandinsky5/utils_nabla.py index 5e2bc4076f8d..a346736b20ef 100644 --- a/comfy/ldm/kandinsky5/utils_nabla.py +++ b/comfy/ldm/kandinsky5/utils_nabla.py @@ -143,4 +143,4 @@ def nabla(query, key, value, sparse_params=None): .contiguous() ) out = out.flatten(-2, -1) - return out \ No newline at end of file + return out
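
For reference, a minimal sketch of how the new utilities in comfy/ldm/kandinsky5/utils_nabla.py are wired together, mirroring the call pattern SelfAttention uses for long (10 sec) generations. It assumes a CUDA device and a flex-attention-capable PyTorch build; the tensor sizes (heads, head_dim, block counts) are hypothetical placeholders, not the real Kandinsky 5 dimensions. Note that nabla() only reads the "sta_mask" and "P" entries of sparse_params; the remaining keys set in model.py are not consumed by it.

import torch
from comfy.ldm.kandinsky5.utils_nabla import fast_sta_nabla, nabla

# Hypothetical sizes: T x H x W latent blocks, each block being one 8x8 spatial
# group of 64 tokens, which matches BLOCK_SIZE=64 used for flex attention.
T, H, W = 8, 4, 4
S = T * H * W * 64                       # total visual tokens (multiple of 64)
B, heads, head_dim = 1, 8, 128           # placeholder head layout

q = torch.randn(B, S, heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Spatio-temporal window mask over blocks, same window sizes as model.py.
wT, wH, wW = 11, 3, 3
sta_mask = fast_sta_nabla(T, H, W, wT, wH, wW, device="cuda")

sparse_params = {
    "sta_mask": sta_mask.unsqueeze(0).unsqueeze(0),  # (1, 1, T*H*W, T*H*W)
    "P": 0.9,                                        # top-CDF threshold for nablaT_v2
}

out = nabla(q, k, v, sparse_params)      # (B, S, heads * head_dim)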
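
The map binarization inside nablaT_v2() is the "topcdf" selection referenced in sparse_params: for each query block, the key blocks carrying the top P of the softmax mass are kept and the low-mass tail is dropped, then the result is OR-ed with the STA window mask and turned into a BlockMask via BlockMask.from_kv_blocks. A toy, self-contained illustration of that selection step (sizes are arbitrary):

import torch

P = 0.75
attn = torch.softmax(torch.randn(1, 1, 4, 4), dim=-1)  # toy block-level attention map

vals, inds = attn.sort(-1)                       # ascending per query-block row
cdf = vals.cumsum(-1)                            # cumulative mass, smallest first
keep_sorted = (cdf >= 1 - P).int()               # keep blocks past the tail of mass 1 - P
keep = keep_sorted.gather(-1, inds.argsort(-1))  # scatter back to original key-block order
print(keep[0, 0])                                # 1 = key block kept for that query block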