@@ -24,6 +24,7 @@
 
 import functools
 import os
+import threading
 
 os.environ["TLLM_DISABLE_MPI"] = "1"
 
@@ -38,9 +39,12 @@
 try:
     from tensorrt_llm._torch.visual_gen.attention_backend import UlyssesAttention
     from tensorrt_llm._torch.visual_gen.attention_backend.trtllm import TrtllmAttention
+    from tensorrt_llm._torch.visual_gen.config import create_attention_metadata_state
     from tensorrt_llm._utils import get_free_port
 
     MODULES_AVAILABLE = True
+    ATTENTION_META_DICT = threading.local()
+    ATTENTION_META_DICT.metadata = create_attention_metadata_state()
 except ImportError:
     MODULES_AVAILABLE = False
 
@@ -133,6 +137,7 @@ def _logic_sage_ulysses_forward(rank, world_size, *, sage_attn_qk_int8: bool):
         sage_attn_num_elts_per_blk_k=blk_k,
         sage_attn_num_elts_per_blk_v=1,
         sage_attn_qk_int8=sage_attn_qk_int8,
+        attention_metadata_state=ATTENTION_META_DICT.metadata,
     )
     attention = UlyssesAttention(inner_backend=inner, process_group=None)
 
@@ -189,6 +194,7 @@ def _logic_sage_ulysses_vs_reference(
         sage_attn_num_elts_per_blk_k=sage_attn_num_elts_per_blk_k,
         sage_attn_num_elts_per_blk_v=1,
         sage_attn_qk_int8=sage_attn_qk_int8,
+        attention_metadata_state=ATTENTION_META_DICT.metadata,
     )
     attention = UlyssesAttention(inner_backend=inner, process_group=None)
 
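
Note on the thread-local pattern in this diff: attributes set on a threading.local() instance are visible only in the thread that set them, so ATTENTION_META_DICT.metadata, assigned at import time, exists only in the importing thread (spawned test processes re-import the module, so each gets its own copy). A minimal, self-contained sketch of the pattern follows; make_state is a hypothetical stand-in for create_attention_metadata_state from the diff, and the lazy getter is one way to make the state safe to read from any thread, not the library's own API.

import threading

_LOCAL = threading.local()

def make_state() -> dict:
    # Hypothetical stand-in for create_attention_metadata_state():
    # returns a fresh, mutable per-thread metadata container.
    return {"seq_lens": []}

def get_metadata() -> dict:
    # threading.local() attributes only exist in the thread that set
    # them, so create the state lazily on first access per thread.
    if not hasattr(_LOCAL, "metadata"):
        _LOCAL.metadata = make_state()
    return _LOCAL.metadata

if __name__ == "__main__":
    results = []

    def worker(i: int) -> None:
        # Each thread mutates only its own state object.
        get_metadata()["seq_lens"].append(i)
        results.append(len(get_metadata()["seq_lens"]))

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    print(results)  # [1, 1, 1, 1]: no thread sees another thread's appends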
|
|