@@ -289,12 +289,14 @@ def forward_hook(module, args, output):
 
 class HpuModelAdapter:
 
-    def __init__(self, model, block_size, dtype, enforce_eager):
+    def __init__(self, model, vllm_config):
         self.model = model
         self.prefill_use_fusedsdpa = os.getenv('VLLM_PROMPT_USE_FUSEDSDPA',
                                                '0').lower() in ['1', 'true']
-        self.block_size = block_size
-        self.dtype = dtype
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.dtype = vllm_config.model_config.dtype
+        enforce_eager = vllm_config.model_config.enforce_eager
         if not htorch.utils.internal.is_lazy() and not enforce_eager:
             self.model = torch.compile(self.model,
                                        backend='hpu_backend',
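The constructor now takes the whole `VllmConfig` and derives `block_size`, `dtype`, and `enforce_eager` from its sub-configs instead of receiving them as loose arguments. Below is a minimal sketch of that pattern, using `SimpleNamespace` as a stand-in for vLLM's config objects; the names are illustrative, not the vLLM source:

```python
from types import SimpleNamespace

import torch


class AdapterSketch:
    """Illustrative stand-in for HpuModelAdapter's new constructor shape."""

    def __init__(self, model, vllm_config):
        self.model = model
        self.vllm_config = vllm_config
        # Settings are read off the single config object rather than passed
        # individually, so new fields can be plumbed through without touching
        # every call site.
        self.block_size = vllm_config.cache_config.block_size
        self.dtype = vllm_config.model_config.dtype
        self.enforce_eager = vllm_config.model_config.enforce_eager


# Hypothetical config values, just to show the access pattern:
cfg = SimpleNamespace(
    cache_config=SimpleNamespace(block_size=128),
    model_config=SimpleNamespace(dtype=torch.bfloat16, enforce_eager=False))
adapter = AdapterSketch(torch.nn.Identity(), cfg)
```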
@@ -353,14 +355,20 @@ def forward(self, *args, **kwargs):
         selected_token_indices = kwargs.pop('selected_token_indices')
         if 'warmup_mode' in kwargs:
             kwargs.pop('warmup_mode')
+        virtual_engine = 0
+        if 'virtual_engine' in kwargs:
+            virtual_engine = kwargs.pop('virtual_engine')
         input_ids = kwargs['input_ids']
         kwargs['attn_metadata'] = self._update_metadata(
             kwargs['attn_metadata'], input_ids.size(0), input_ids.size(1),
             input_ids.device, self.dtype)
         LoraMask.setLoraMask(kwargs.pop('lora_mask'))
-        hidden_states = self.model(*args, **kwargs)
-        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        hidden_states = hidden_states.index_select(0, selected_token_indices)
+        with set_forward_context(kwargs['attn_metadata'], self.vllm_config,
+                                 virtual_engine):
+            hidden_states = self.model(*args, **kwargs)
+            hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+            hidden_states = hidden_states.index_select(0,
+                                                       selected_token_indices)
         return hidden_states
 
     def compute_logits(self, *args, **kwargs):
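`forward()` now pops an optional `virtual_engine` (defaulting to 0) and runs the model inside `set_forward_context`, so the attention metadata and engine index are published for the duration of the call rather than being set up by the caller. A hedged sketch of what such a context manager does; this is a stand-in, not vLLM's `vllm.forward_context` implementation:

```python
from contextlib import contextmanager

_FORWARD_CONTEXT = None  # module-global slot, read during the forward pass


@contextmanager
def set_forward_context_sketch(attn_metadata, vllm_config, virtual_engine=0):
    # Publish the metadata for the duration of the forward pass and restore
    # whatever was there before, even if the model call raises.
    global _FORWARD_CONTEXT
    prev = _FORWARD_CONTEXT
    _FORWARD_CONTEXT = (attn_metadata, vllm_config, virtual_engine)
    try:
        yield
    finally:
        _FORWARD_CONTEXT = prev
```

Hoisting the `with` into the adapter presumably guarantees the context is entered on every path through the wrapped model (eager, `torch.compile`, or HPU-graph replay), instead of only around the eager call in `execute_model`.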
@@ -660,10 +668,7 @@ def load_model(self) -> None:
 
         with HabanaMemoryProfiler() as m_wrap:
             self.model = _maybe_wrap_in_hpu_graph(
-                self.model,
-                self.block_size,
-                dtype=self.model_config.dtype,
-                enforce_eager=self.enforce_eager)
+                self.model, vllm_config=self.vllm_config)
         msg = f"Wrapping in HPU Graph took {m_wrap.get_summary_string()}"
         logger.info(msg)
 
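`load_model` now forwards only `vllm_config`; the adapter reads everything else off it. A short sketch of the simplified call shape, reusing `AdapterSketch` from the first sketch above (the helper name is illustrative; the real `_maybe_wrap_in_hpu_graph` also wraps the adapter in an HPU graph under lazy mode):

```python
def maybe_wrap_in_hpu_graph_sketch(model, *, vllm_config):
    # One keyword now carries what used to be three separate arguments
    # (block_size, dtype, enforce_eager).
    return AdapterSketch(model, vllm_config)
```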
@@ -1934,6 +1939,7 @@ def execute_model(
             "attn_metadata": self.trim_attn_metadata(attn_metadata),
             "intermediate_tensors": intermediate_tensors,
             "lora_mask": lora_mask,
+            "virtual_engine": model_input.virtual_engine,
             **(model_input.multi_modal_kwargs or {}),
         }
         if htorch.utils.internal.is_lazy():
@@ -1948,11 +1954,7 @@ def execute_model(
                                 f"graphs{'T' if use_graphs else 'F'}")
         else:
             model_event_name = 'model_executable'
-        with set_forward_context(
-                model_input.attn_metadata, self.vllm_config,
-                model_input.virtual_engine), \
-                self.profiler.record_event(
-                    'internal', model_event_name):
+        with self.profiler.record_event('internal', model_event_name):
             hidden_states = self.model.forward(
                 **execute_model_kwargs,
                 selected_token_indices=sampling_metadata.selected_token_indices
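Together, the last two hunks move the forward-context handling out of `execute_model`: the worker passes `model_input.virtual_engine` through `execute_model_kwargs`, the adapter's `forward()` pops it and enters `set_forward_context` itself, and only the profiler event remains here. A compact sketch of the resulting plumbing, with illustrative names:

```python
def adapter_forward_sketch(model, vllm_config, set_forward_context, **kwargs):
    # Equivalent to the diff's `if 'virtual_engine' in kwargs: ... pop`,
    # written as a single pop with a default of 0.
    virtual_engine = kwargs.pop('virtual_engine', 0)
    selected = kwargs.pop('selected_token_indices')
    with set_forward_context(kwargs['attn_metadata'], vllm_config,
                             virtual_engine):
        hidden_states = model(**kwargs)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        return hidden_states.index_select(0, selected)
```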