`#queue-req` indicates the number of requests in the queue. If you frequently see `#queue-req == 0`, it suggests you are bottlenecked by the request submission speed.
A healthy range for `#queue-req` is `50 - 500`.
On the other hand, do not make `#queue-req` too large because it will also increase the scheduling overhead on the server, especially when using the default longest-prefix-match schedule policy (`--schedule-policy lpm`).
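
As a quick check, you can grep these stats out of the server log (the log path below is a placeholder, and the exact line format is an assumption based on the stat names above):

```shell
LOG=server.log  # placeholder: path to the server's log output

# Show the most recent scheduler stats lines that report the queue depth:
grep '#queue-req' "$LOG" | tail -n 20 || true

# Count how often the queue was empty, which points at a client-side
# submission bottleneck rather than a server-side one:
grep -c '#queue-req: 0' "$LOG" || true
```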
### Tune `--schedule-conservativeness`
`token usage` indicates the KV cache memory utilization of the server. `token usage > 0.9` means good utilization.
If you frequently see `token usage < 0.9` and `#queue-req > 0`, it means the server is too conservative about taking in new requests. You can decrease `--schedule-conservativeness` to a value like 0.3.
The server can be too conservative when users send many requests with a large `max_new_tokens` but the requests stop very early due to EOS or stop strings.

On the other hand, if you see `token usage` very high and you frequently see warnings like `decode out of memory happened`, you can increase `--schedule-conservativeness` to a value like 1.3.

If you see `decode out of memory happened` occasionally but not frequently, it is okay.
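
For illustration, both adjustments are made at launch time (the model path is a placeholder, and `1.3` is just an illustrative value in the more-conservative direction):

```shell
# Too conservative (token usage < 0.9 while #queue-req > 0): admit more work.
python -m sglang.launch_server --model-path <your-model> \
  --schedule-conservativeness 0.3

# Too aggressive (frequent "decode out of memory happened" warnings): back off.
python -m sglang.launch_server --model-path <your-model> \
  --schedule-conservativeness 1.3
```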
### Tune `--dp-size` and `--tp-size`
Data parallelism is better for throughput. When there is enough GPU memory, always favor data parallelism. For data parallelism, prefer the [sglang router](../backend/sglang_router.md) over the `dp_size` parameter.
## Avoid out-of-memory by Tuning `--chunked-prefill-size`, `--mem-fraction-static`, `--max-running-requests`
If you see out of memory (OOM) errors, you can try to tune the following parameters.
- If OOM happens during prefill, try to decrease `--chunked-prefill-size` to `4096` or `2048`.
- If OOM happens during decoding, try to decrease `--max-running-requests`.
- You can also try to decrease `--mem-fraction-static`, which reduces the memory usage of the KV cache memory pool and helps both prefill and decoding.
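
The three mitigations can be combined in a single launch command (a sketch; the model path is a placeholder, and the values other than `4096` are illustrative rather than from this guide):

```shell
python -m sglang.launch_server --model-path <your-model> \
  --chunked-prefill-size 4096 \
  --max-running-requests 128 \
  --mem-fraction-static 0.8
```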
## Enabling cache for `torch.compile`
To enable `torch.compile` acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. By default, `torch.compile` automatically caches the FX graph and Triton kernels in `/tmp/torchinductor_root`, which might be cleared according to the [system policy](https://serverfault.com/questions/377348/when-does-tmp-get-cleared). You can export the environment variable `TORCHINDUCTOR_CACHE_DIR` to save the compilation cache in a directory of your choice and avoid unwanted deletion. You can also share the cache with other machines to reduce compilation time.
SGLang uses `max-autotune-no-cudagraphs` mode of `torch.compile`. The auto-tuning can be slow.
If you want to deploy a model on many different machines, you can ship the `torch.compile` cache to these machines and skip the compilation steps. This is based on [PyTorch official documentation](https://pytorch.org/tutorials/recipes/torch_compile_caching_tutorial.html).
*Examples*:
1. Generate the cache by setting `TORCHINDUCTOR_CACHE_DIR` and running the model once.
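
The workflow can be sketched as follows (the cache directory and remote host are placeholders; the launch and copy commands are commented out because they require a model and a second machine):

```shell
# Step 1: point torch.compile at a persistent directory and run the model once
# to populate the cache:
export TORCHINDUCTOR_CACHE_DIR="$HOME/.cache/sglang_inductor"
mkdir -p "$TORCHINDUCTOR_CACHE_DIR"
# python -m sglang.launch_server --model-path <your-model> --enable-torch-compile

# Step 2: ship the populated cache to other machines so they skip compilation:
# rsync -a "$TORCHINDUCTOR_CACHE_DIR"/ other-host:.cache/sglang_inductor/
```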

---

**`docs/backend/server_arguments.md`**

## Common launch commands
- To enable multi-GPU tensor parallelism, add `--tp 2`. If it reports the error "peer access is not supported between these two devices", add `--enable-p2p-check` to the server launch command.
- To enable multi-GPU data parallelism, add `--dp 2`. Data parallelism is better for throughput if there is enough memory. It can also be used together with tensor parallelism. The following command uses 4 GPUs in total. We recommend [SGLang Router](../router/router.md) for data parallelism.
- If you see out-of-memory errors during serving, try to reduce the memory usage of the KV cache pool by setting a smaller value of `--mem-fraction-static`. The default value is `0.9`.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently.
- To enable torch.compile acceleration, add `--enable-torch-compile`. It accelerates small models on small batch sizes. This does not work for FP8 currently. You can refer to [`Enabling cache for torch.compile`](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile) for more details.
- To enable torchao quantization, add `--torchao-config int4wo-128`. It supports other [quantization strategies (INT8/FP8)](https://github.com/sgl-project/sglang/blob/v0.3.6/python/sglang/srt/server_args.py#L671) as well.
- To enable fp8 weight quantization, add `--quantization fp8` on an fp16 checkpoint, or directly load an fp8 checkpoint without specifying any arguments.
- To enable fp8 kv cache quantization, add `--kv-cache-dtype fp8_e5m2`.
- If the model does not have a chat template in the Hugging Face tokenizer, you can specify a [custom chat template](custom_chat_template.md).
- To run tensor parallelism on multiple nodes, add `--nnodes 2`. If you have two nodes with two GPUs each and want to run TP=4, let `sgl-dev-0` be the hostname of the first node and `50000` be an available port; then you can use the following commands. If you encounter a deadlock, try adding `--disable-cuda-graph`.
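
A sketch of that two-node launch (`--node-rank` and `--dist-init-addr` are SGLang server arguments; treat the exact invocation as an assumption, and the model path as a placeholder):

```shell
# On node 0 (hostname sgl-dev-0):
python -m sglang.launch_server --model-path <your-model> --tp 4 \
  --nnodes 2 --node-rank 0 --dist-init-addr sgl-dev-0:50000

# On node 1:
python -m sglang.launch_server --model-path <your-model> --tp 4 \
  --nnodes 2 --node-rank 1 --dist-init-addr sgl-dev-0:50000
```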

Please consult the documentation below to learn more about the parameters you may provide when launching a server.

* `kv_cache_dtype`: Dtype of the kv cache, defaults to the `dtype`.
* `context_length`: The number of tokens the model can process, *including the input*. Note that extending the default might lead to strange behavior.
* `device`: The device to put the model on, defaults to `cuda`.
* `chat_template`: The chat template to use. Deviating from the default might lead to unexpected responses. For multi-modal chat templates, refer to [here](https://docs.sglang.ai/backend/openai_api_vision.ipynb#Chat-Template). **Make sure the correct `chat_template` is passed, or performance degradation may occur!**
* `is_embedding`: Set to true to perform [embedding](./openai_api_embeddings.ipynb) / [encode](https://docs.sglang.ai/backend/native_api#Encode-(embedding-model)) and [reward](https://docs.sglang.ai/backend/native_api#Classify-(reward-model)) tasks.
* `revision`: Adjust if a specific version of the model should be used.
* `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
* `json_model_override_args`: Override the model config with the provided JSON.
* `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.

## Serving: HTTP & API
* `enable_mixed_chunk`: Enables mixing prefill and decode; see [this discussion](https://github.com/sgl-project/sglang/discussions/1163).
* `enable_dp_attention`: Enable [Data Parallelism Attention](https://lmsys.org/blog/2024-12-04-sglang-v0-4/#data-parallelism-attention-for-deepseek-models) for Deepseek models. Note that you need to choose `dp_size = tp_size` for this.
* `enable_torch_compile`: Torch compile the model. Note that compiling a model takes a long time but gives a great performance boost. The compiled model can also be [cached for future use](https://docs.sglang.ai/backend/hyperparameter_tuning.html#enabling-cache-for-torch-compile).
* `torch_compile_max_bs`: The maximum batch size when using `torch_compile`.
* `cuda_graph_max_bs`: Adjust the maximum batch size when using CUDA graphs. By default this is chosen for you based on GPU specifics.
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.

---

**`docs/references/deepseek.md`**

If you encounter errors when starting the server, ensure the weights have finished downloading.

### Caching `torch.compile`
The DeepSeek series has huge model weights, so it takes some time to compile the model with `torch.compile` for the first time if you have added the flag `--enable-torch-compile`. You can refer to [this section](https://docs.sglang.ai/backend/hyperparameter_tuning.html#try-advanced-options) to cache the compilation results, so that the cache can speed up the next startup.
### Launch with One node of 8 H200
Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/benchmark/deepseek_v3#using-docker-recommended). **Note that Deepseek V3 is already in FP8, so we should not run it with any quantization arguments like `--quantization fp8 --kv-cache-dtype fp8_e5m2`.** Also, `--enable-dp-attention` can be useful to improve Deepseek V3/R1's throughput. Please refer to [Data Parallelism Attention](https://docs.sglang.ai/references/deepseek.html#multi-head-latent-attention-mla-throughput-optimizations) for details.