
[Model] Add support for the multi-modal Llama 3.2 model#8811

Merged
simon-mo merged 82 commits into vllm-project:main from heheda12345:mllama_model
Sep 25, 2024

Conversation

@heheda12345 (Collaborator) commented Sep 25, 2024

This PR adds support for the Llama 3.2 vision models.

Currently, we are blocked on a Hugging Face Transformers release that includes huggingface/transformers#33703.

We will make a release as soon as this PR is merged.

Closes #8812
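The dependency described above is the usual minimum-version gate: the model can only be registered once the installed transformers build is new enough. A minimal sketch of that pattern, where the function name and the version strings are placeholders, not taken from this PR:

```python
# Hypothetical sketch (names and versions are placeholders, not from this PR):
# gate a feature on a minimum dependency version by comparing dotted version
# strings segment by segment. Pre-release tags and uneven lengths are ignored.

def meets_min_version(installed: str, required: str) -> bool:
    """Return True if `installed` is at least `required` (numeric segments only)."""
    def parse(v: str):
        return tuple(int(part) for part in v.split("."))
    return parse(installed) >= parse(required)

# A hypothetical guard before registering the model:
if not meets_min_version("4.45.0", "4.45.0"):
    raise ImportError("mllama requires a newer transformers release")
```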

@mgoin (Member) commented Sep 25, 2024

Just leaving this log here as I experiment with quantization on top of this work:

vllm serve mgoin/Llama-3.2-11B-Vision-Instruct-FP8-Dynamic

WARNING 09-25 19:05:44 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
/home/mgoin/code/vllm/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
INFO 09-25 19:05:46 api_server.py:526] vLLM API server version dev
INFO 09-25 19:05:46 api_server.py:527] args: Namespace(model_tag='Llama-3.2-11B-Vision-Instruct-FP8-Dynamic', config='', host=None, port=8000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, prompt_adapters=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], return_tokens_as_token_ids=False, disable_frontend_multiprocessing=False, enable_auto_tool_choice=False, tool_call_parser=None, model='Llama-3.2-11B-Vision-Instruct-FP8-Dynamic', tokenizer=None, skip_tokenizer_init=False, revision=None, code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', config_format='auto', dtype='auto', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=None, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=1, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, cpu_offload_gb=0, gpu_memory_utilization=0.9, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=256, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, limit_mm_per_prompt=None, mm_processor_kwargs=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, enable_prompt_adapter=False, 
max_prompt_adapters=1, max_prompt_adapter_token=0, device='auto', num_scheduler_steps=1, multi_step_stream_outputs=False, scheduler_delay_factor=0.0, enable_chunked_prefill=None, speculative_model=None, speculative_model_quantization=None, num_speculative_tokens=None, speculative_draft_tensor_parallel_size=None, speculative_max_model_len=None, speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, spec_decoding_acceptance_method='rejection_sampler', typical_acceptance_sampler_posterior_threshold=None, typical_acceptance_sampler_posterior_alpha=None, disable_logprobs_during_spec_decoding=None, model_loader_extra_config=None, ignore_patterns=[], preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, otlp_traces_endpoint=None, collect_detailed_traces=None, disable_async_output_proc=False, override_neuron_config=None, disable_log_requests=False, max_log_len=None, disable_fastapi_docs=False, dispatch_function=<function serve at 0x7c76acb9eb00>)
INFO 09-25 19:05:46 api_server.py:164] Multiprocessing frontend to use ipc:///tmp/0ee642b0-ab8e-4620-b5e2-0dd7f4ba90d5 for IPC Path.
INFO 09-25 19:05:46 api_server.py:177] Started engine process with PID 766178
WARNING 09-25 19:05:46 arg_utils.py:940] The model has a long context length (131072). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.
WARNING 09-25 19:05:47 cuda.py:22] You are using a deprecated `pynvml` package. Please install `nvidia-ml-py` instead, and make sure to uninstall `pynvml`. When both of them are installed, `pynvml` will take precedence and cause errors. See https://pypi.org/project/pynvml for more information.
/home/mgoin/code/vllm/vllm/connections.py:8: RuntimeWarning: Failed to read commit hash:
No module named 'vllm._version'
  from vllm.version import __version__ as VLLM_VERSION
WARNING 09-25 19:05:49 arg_utils.py:940] The model has a long context length (131072). This may cause OOM errors during the initial memory profiling phase, or result in low performance due to small KV cache space. Consider setting --max-model-len to a smaller value.
INFO 09-25 19:05:49 llm_engine.py:226] Initializing an LLM engine (vdev) with config: model='Llama-3.2-11B-Vision-Instruct-FP8-Dynamic', speculative_config=None, tokenizer='Llama-3.2-11B-Vision-Instruct-FP8-Dynamic', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, rope_scaling=None, rope_theta=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=131072, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=compressed-tensors, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), observability_config=ObservabilityConfig(otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=Llama-3.2-11B-Vision-Instruct-FP8-Dynamic, use_v2_block_manager=False, num_scheduler_steps=1, multi_step_stream_outputs=False, enable_prefix_caching=False, use_async_output_proc=True, use_cached_outputs=True, mm_processor_kwargs=None)
INFO 09-25 19:05:49 enc_dec_model_runner.py:140] EncoderDecoderModelRunner requires XFormers backend; overriding backend auto-selection and forcing XFormers.
INFO 09-25 19:05:49 selector.py:116] Using XFormers backend.
/home/mgoin/venvs/vllm/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:211: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
/home/mgoin/venvs/vllm/lib/python3.10/site-packages/xformers/ops/fmha/flash.py:344: FutureWarning: `torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
INFO 09-25 19:05:51 model_runner.py:1014] Starting to load model Llama-3.2-11B-Vision-Instruct-FP8-Dynamic...
INFO 09-25 19:05:51 selector.py:116] Using XFormers backend.
Loading safetensors checkpoint shards:   0% Completed | 0/3 [00:00<?, ?it/s]
Process SpawnProcess-1:
Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/mgoin/code/vllm/vllm/engine/multiprocessing/engine.py", line 388, in run_mp_engine
    engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
  File "/home/mgoin/code/vllm/vllm/engine/multiprocessing/engine.py", line 138, in from_engine_args
    return cls(
  File "/home/mgoin/code/vllm/vllm/engine/multiprocessing/engine.py", line 78, in __init__
    self.engine = LLMEngine(*args,
  File "/home/mgoin/code/vllm/vllm/engine/llm_engine.py", line 325, in __init__
    self.model_executor = executor_class(
  File "/home/mgoin/code/vllm/vllm/executor/executor_base.py", line 47, in __init__
    self._init_executor()
  File "/home/mgoin/code/vllm/vllm/executor/gpu_executor.py", line 40, in _init_executor
    self.driver_worker.load_model()
  File "/home/mgoin/code/vllm/vllm/worker/worker.py", line 183, in load_model
    self.model_runner.load_model()
  File "/home/mgoin/code/vllm/vllm/worker/model_runner.py", line 1016, in load_model
    self.model = get_model(model_config=self.model_config,
  File "/home/mgoin/code/vllm/vllm/model_executor/model_loader/__init__.py", line 19, in get_model
    return loader.load_model(model_config=model_config,
  File "/home/mgoin/code/vllm/vllm/model_executor/model_loader/loader.py", line 403, in load_model
    model.load_weights(self._get_all_weights(model_config, model))
  File "/home/mgoin/code/vllm/vllm/model_executor/models/mllama.py", line 1123, in load_weights
    param = params_dict[name]
KeyError: 'language_model.model.layers.13.cross_attn.qkv_proj.weight_scale'
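For context on the traceback above: the FP8 checkpoint carries extra per-tensor quantization parameters (here, a `weight_scale` for the fused `qkv_proj`), and a loader that looks every checkpoint name up directly in `params_dict` raises `KeyError` for any name the model never registered. A pure-Python sketch of the failure mode (this is not vLLM's actual loader, just an illustration of the lookup):

```python
# Sketch of the failure mode above (not vLLM's actual load_weights):
# a strict params_dict[name] lookup crashes on checkpoint names the model
# never registered, such as FP8 weight scales.

def load_weights(params_dict, checkpoint_weights, strict=True):
    """Assign checkpoint tensors to registered params; reject or collect extras."""
    unmatched = []
    for name, tensor in checkpoint_weights:
        if name in params_dict:
            params_dict[name] = tensor
        elif strict:
            raise KeyError(name)  # this is the crash in the log
        else:
            unmatched.append(name)  # tolerant mode: report extras instead
    return unmatched

registered = {"layers.13.cross_attn.qkv_proj.weight": None}
checkpoint = [
    ("layers.13.cross_attn.qkv_proj.weight", "W"),
    ("layers.13.cross_attn.qkv_proj.weight_scale", "S"),  # extra FP8 scale
]
extras = load_weights(registered, checkpoint, strict=False)
# extras holds the one unregistered weight_scale name
```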

@ywang96 (Member) left a comment

Overall LGTM! Thank you for the contribution!

@yukiarimo commented

Is there a fine-tuning colab for this?

@ywang96 (Member) commented Sep 25, 2024

Latest transformers release broke a few tests on our CI. Investigating...

@ArthurZucker (Contributor) commented

Is there anything we can do to help with the failing tests?

@ywang96 (Member) commented Sep 26, 2024

@ArthurZucker Hey, thanks for reaching out and offering to help!

We already made a release and have a good sense of how to fix the failing CI tests, so hopefully they'll all be resolved during the daytime here (most of us are in PST). Thanks!

@kinchahoy commented

Hey, is this ready to try? Happy to try a fork.

@heheda12345 (Collaborator, Author) commented

Yes. It is ready now.
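For anyone trying it out: once a server is up (e.g. `vllm serve meta-llama/Llama-3.2-11B-Vision-Instruct`), requests use the OpenAI-compatible chat format with image content parts. A minimal sketch of the request payload; the helper function, image URL, and token limit are illustrative, not part of this PR:

```python
import json

# Hedged sketch: an OpenAI-style chat-completions payload for a vision request.
# The model name, image URL, and helper below are placeholders for illustration.
def build_vision_request(model, image_url, question):
    """Build a chat payload with one image part and one text part."""
    return {
        "model": model,
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "image_url", "image_url": {"url": image_url}},
                    {"type": "text", "text": question},
                ],
            }
        ],
        "max_tokens": 128,
    }

payload = build_vision_request(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    "https://example.com/duck.jpg",
    "What is in this image?",
)
body = json.dumps(payload)  # POST this body to /v1/chat/completions
```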

Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
LeiWang1999 pushed a commit to LeiWang1999/vllm-bitblas that referenced this pull request Mar 26, 2025

Labels

ready ONLY add when PR is ready to merge/full CI is needed

Development

Successfully merging this pull request may close these issues.

[New Model]: Llama 3.2

8 participants