From 6123fdd3ff91c3bba4f75e44be156f0f7f1b8a22 Mon Sep 17 00:00:00 2001 From: Jacky Fang Date: Mon, 29 Jun 2026 08:37:55 +0000 Subject: [PATCH] docs: resolve leftover lora_llama3_demo.ipynb comments and relative code links in lora_model_bringup.md --- docs/guides/lora_model_bringup.md | 8 ++++---- src/maxtext/examples/lora_llama3_demo.ipynb | 16 +++++----------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/docs/guides/lora_model_bringup.md b/docs/guides/lora_model_bringup.md index ecb85ced89..dbc4c2ff0e 100644 --- a/docs/guides/lora_model_bringup.md +++ b/docs/guides/lora_model_bringup.md @@ -30,12 +30,12 @@ To enable LoRA support for a new model, follow these two simple steps: The target model architecture must already be implemented and supported as a base model in MaxText. -- The JAX/NNX model definition should be located under `src/maxtext/models/` (e.g., \[gemma3.py\](../../src/maxtext/models/gemma3.py)). +- The JAX/NNX model definition should be located under `src/maxtext/models/` (e.g., [gemma3.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/models/gemma3.py)). - The model configurations must be registered and runnable for baseline pre-training or full fine-tuning. ### Step 1.2: Define Trainable LoRA Target Modules -Add a recommended target pattern for your model architecture prefix in \[src/maxtext/configs/post_train/lora_module_path.yml\](../../src/maxtext/configs/post_train/lora_module_path.yml): +Add a recommended target pattern for your model architecture prefix in [src/maxtext/configs/post_train/lora_module_path.yml](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/configs/post_train/lora_module_path.yml): ```yaml your_model_prefix: "decoder/layers/.*(self_attention/(query|key|value|out)|mlp/(wi_0|wi_1|wo))" @@ -69,7 +69,7 @@ If you want to perform decoding or run high-performance serving on your adapted To add weight mapping for vLLM decode: 1. **Create a Weight Mapping Config**: - Create a new file in \[src/maxtext/integration/tunix/weight_mapping/\](../../src/maxtext/integration/tunix/weight_mapping/) (e.g., `your_model.py`) defining a mapping dataclass. You can refer to \[gemma3.py\](../../src/maxtext/integration/tunix/weight_mapping/gemma3.py) or \[llama3.py\](../../src/maxtext/integration/tunix/weight_mapping/llama3.py) as templates. + Create a new file in [src/maxtext/integration/tunix/weight_mapping/](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/integration/tunix/weight_mapping/) (e.g., `your_model.py`) defining a mapping dataclass. You can refer to [llama3.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/integration/tunix/weight_mapping/llama3.py) as a template. Your class should specify: @@ -78,7 +78,7 @@ To add weight mapping for vLLM decode: - `lora_to_hf_mappings()`: Custom mapping for LoRA weights if they require different handling. 2. **Register the Mapping**: - Register your new class in \[src/maxtext/integration/tunix/weight_mapping/__init__.py\](../../src/maxtext/integration/tunix/weight_mapping/__init__.py) inside the `StandaloneVllmWeightMapping` class: + Register your new class in [src/maxtext/integration/tunix/weight_mapping/__init__.py](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/integration/tunix/weight_mapping/__init__.py) inside the `StandaloneVllmWeightMapping` class: ```python # Inside StandaloneVllmWeightMapping diff --git a/src/maxtext/examples/lora_llama3_demo.ipynb b/src/maxtext/examples/lora_llama3_demo.ipynb index 3777d6f96a..d324419093 100644 --- a/src/maxtext/examples/lora_llama3_demo.ipynb +++ b/src/maxtext/examples/lora_llama3_demo.ipynb @@ -79,6 +79,10 @@ " # Install uv, a fast Python package installer\n", " !pip install uv\n", " \n", + " # Set the torch backend to CPU for uv\n", + " import os\n", + " os.environ[\"UV_TORCH_BACKEND\"] = \"cpu\"\n", + " \n", " # Install MaxText and post-training dependencies\n", " !uv pip install -e .[tpu-post-train] --resolution=lowest\n", " !install_tpu_post_train_extra_deps" @@ -215,16 +219,6 @@ "outputs": [], "source": [ "if not epath.Path(MODEL_CHECKPOINT_PATH).exists():\n", - " # Install torch for the conversion script\n", - " print(\"Installing torch...\")\n", - " subprocess.run(\n", - " [\n", - " sys.executable, \"-m\", \"pip\", \"install\",\n", - " \"torch\", \"--index-url\", \"https://download.pytorch.org/whl/cpu\"\n", - " ],\n", - " check=True\n", - " )\n", - "\n", " print(\"Converting checkpoint from HuggingFace...\")\n", " env = os.environ.copy()\n", " env[\"JAX_PLATFORMS\"] = \"cpu\"\n", @@ -607,4 +601,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +}