diff --git a/docs/en/advance/update_weights.md b/docs/en/advance/update_weights.md
new file mode 100644
index 0000000000..3237b0f688
--- /dev/null
+++ b/docs/en/advance/update_weights.md
@@ -0,0 +1,122 @@
+# Update Weights
+
+LMDeploy supports updating model weights online, which is useful for scenarios such as RL training. Here are the steps.
+
+## Step 1: Launch the server
+
+For the PyTorch backend, you have to add `--distributed-executor-backend ray`.
+
+```shell
+lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend
+```
+
+## Step 2: Offload weights & KV cache
+
+Before updating the model weights, the server should offload its weights and KV cache:
+
+```python
+import requests
+
+BASE_URL = 'http://0.0.0.0:23333'
+api_key = 'sk-xxx'
+
+headers = {
+    "Content-Type": "application/json",
+    "Authorization": f"Bearer {api_key}",
+}
+
+# offload weights and kv cache with level=2
+response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
+assert response.status_code == 200, response.status_code
+
+# wake up the weights; the server is then ready for a weight update
+response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
+assert response.status_code == 200, response.status_code
+```
+
+## Step 3: Update weights
+
+Split the model weights into multiple segments and upload each segment through the `update_weights` endpoint.
+
+```python
+from typing import Dict, List
+
+import torch
+
+from lmdeploy.utils import serialize_state_dict
+
+segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
+num_segment = len(segmented_state_dict)
+for seg_idx in range(num_segment):
+    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
+    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
+    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
+    assert response.status_code == 200, f"response.status_code = {response.status_code}"
+```
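+
+In the example above, `segmented_state_dict` is left unspecified. As a minimal sketch (one possible approach, not part of LMDeploy's API), you could chunk the trainer's `state_dict` into roughly equal groups of named tensors; `split_state_dict` below is a hypothetical helper:
+
+```python
+from typing import Dict, List
+
+import torch
+
+
+def split_state_dict(state_dict: Dict[str, torch.Tensor], num_segments: int) -> List[Dict[str, torch.Tensor]]:
+    """Chunk a state dict into roughly equal groups of named tensors (hypothetical helper)."""
+    names = list(state_dict.keys())
+    seg_len = (len(names) + num_segments - 1) // num_segments  # ceiling division
+    return [{name: state_dict[name] for name in names[i:i + seg_len]} for i in range(0, len(names), seg_len)]
+
+
+# e.g. with the model held by the RL trainer:
+# segmented_state_dict = split_state_dict(model.state_dict(), num_segments=8)
+```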
+
+**Note**: For the PyTorch backend, LMDeploy also supports transferring the weights as flattened bucket tensors:
+
+```python
+from typing import Dict, List
+
+import torch
+
+from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict
+
+segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
+num_segment = len(segmented_state_dict)
+for seg_idx in range(num_segment):
+    named_tensors = list(segmented_state_dict[seg_idx].items())
+    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
+    metadata = bucket.get_metadata()
+    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
+    serialized_data = serialize_state_dict(flattened_tensor_data)
+    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
+    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
+    assert response.status_code == 200, f"response.status_code = {response.status_code}"
+```
+
+## Step 4: Wake up the server
+
+After the weights are updated, the server should onload the KV cache again so that it can resume serving with the new weights.
+
+```python
+response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
+assert response.status_code == 200, response.status_code
+```
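+
+Optionally, poll the server until it reports healthy before routing new requests to it. This is a minimal sketch that assumes the `/health` probe exposed by `api_server`; the retry policy here is arbitrary:
+
+```python
+import time
+
+import requests
+
+# poll until the server is serving again
+for _ in range(30):
+    try:
+        if requests.get(f"{BASE_URL}/health", timeout=5).status_code == 200:
+            break
+    except requests.RequestException:
+        pass
+    time.sleep(1)
+```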
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 46d3cb34ee..2199ec7833 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -105,6 +105,7 @@ Documentation
    advance/metrics.md
    advance/context_parallel.md
    advance/spec_decoding.md
+   advance/update_weights.md
 
 .. toctree::
    :maxdepth: 1
diff --git a/docs/zh_cn/advance/update_weights.md b/docs/zh_cn/advance/update_weights.md
new file mode 100644
index 0000000000..f4843f7dc7
--- /dev/null
+++ b/docs/zh_cn/advance/update_weights.md
@@ -0,0 +1,86 @@
+# Update Weights
+
+LMDeploy supports updating model weights online, which is convenient for scenarios such as RL training. The steps are as follows:
+
+## Step 1: Launch the server
+
+For the PyTorch backend, you have to add `--distributed-executor-backend ray`.
+
+```shell
+lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --distributed-executor-backend ray # for pytorch backend
+```
+
+## Step 2: Offload weights & KV cache
+
+Before updating the weights, call the API to offload the weights and KV cache so that the inference engine is ready for the update:
+
+```python
+import requests
+
+BASE_URL = 'http://0.0.0.0:23333'
+api_key = 'sk-xxx'
+
+headers = {
+    "Content-Type": "application/json",
+    "Authorization": f"Bearer {api_key}",
+}
+
+# offload weights and kv cache with level=2
+response = requests.post(f"{BASE_URL}/sleep", headers=headers, params=dict(tags=['weights', 'kv_cache'], level=2))
+assert response.status_code == 200, response.status_code
+
+# wake up the weights; the server is then ready for a weight update
+response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['weights']))
+assert response.status_code == 200, response.status_code
+```
+
+## Step 3: Update weights
+
+Split the model weights into segments and update them through the `update_weights` API.
+
+```python
+from typing import Dict, List
+
+import torch
+
+from lmdeploy.utils import serialize_state_dict
+
+segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
+num_segment = len(segmented_state_dict)
+for seg_idx in range(num_segment):
+    serialized_data = serialize_state_dict(segmented_state_dict[seg_idx])
+    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1)
+    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
+    assert response.status_code == 200, f"response.status_code = {response.status_code}"
+```
+
+**Note**: For the PyTorch backend, LMDeploy also supports transferring the weights as flattened bucket tensors:
+
+```python
+from typing import Dict, List
+
+import torch
+
+from lmdeploy.utils import FlattenedTensorBucket, serialize_state_dict
+
+segmented_state_dict: List[Dict[str, torch.Tensor]] = ...
+num_segment = len(segmented_state_dict)
+for seg_idx in range(num_segment):
+    named_tensors = list(segmented_state_dict[seg_idx].items())
+    bucket = FlattenedTensorBucket(named_tensors=named_tensors)
+    metadata = bucket.get_metadata()
+    flattened_tensor_data = dict(flattened_tensor=bucket.get_flattened_tensor(), metadata=metadata)
+    serialized_data = serialize_state_dict(flattened_tensor_data)
+    data = dict(serialized_named_tensors=serialized_data, finished=seg_idx == num_segment - 1, load_format='flattened_bucket')
+    response = requests.post(f"{BASE_URL}/update_weights", headers=headers, json=data)
+    assert response.status_code == 200, f"response.status_code = {response.status_code}"
+```
+
+## Step 4: Wake up the engine
+
+After the weights are updated, call the API to rebuild the KV cache and wake up the engine so that it can serve requests again.
+
+```python
+response = requests.post(f"{BASE_URL}/wakeup", headers=headers, params=dict(tags=['kv_cache']))
+assert response.status_code == 200, response.status_code
+```
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 8b3eec8360..fe3a4b53f9 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -106,6 +106,7 @@ LMDeploy 工具箱提供以下核心功能:
    advance/metrics.md
    advance/context_parallel.md
    advance/spec_decoding.md
+   advance/update_weights.md
 
 .. toctree::
    :maxdepth: 1
diff --git a/lmdeploy/cli/serve.py b/lmdeploy/cli/serve.py
index a28e5d29c0..e37874766f 100644
--- a/lmdeploy/cli/serve.py
+++ b/lmdeploy/cli/serve.py
@@ -99,6 +99,7 @@ def add_parser_api_server():
         ArgumentHelper.dllm_denoising_steps(pt_group)
         ArgumentHelper.dllm_confidence_threshold(pt_group)
         ArgumentHelper.enable_return_routed_experts(pt_group)
+        ArgumentHelper.distributed_executor_backend(pt_group)
 
         # common engine args
         dtype_act = ArgumentHelper.dtype(pt_group)
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index e7afd6ed16..27af0f3fbc 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -699,6 +699,7 @@ def enable_return_routed_experts(parser):
                             default=False,
                             help='Whether to output routed expert ids for replay')
 
+    @staticmethod
     def add_spec_group(parser):
         spec_group = parser.add_argument_group('Speculative decoding arguments')
         spec_group.add_argument('--speculative-algorithm',
@@ -719,6 +720,15 @@ def add_spec_group(parser):
         return spec_group
 
+    @staticmethod
+    def distributed_executor_backend(parser):
+        """Distributed executor backend."""
+        return parser.add_argument('--distributed-executor-backend',
+                                   type=str,
+                                   default=None,
+                                   choices=['uni', 'mp', 'ray'],
+                                   help='The distributed executor backend for pytorch engine.')
+
 
 # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/utils/__init__.py
 class FlexibleArgumentParser(argparse.ArgumentParser):