From 346c6739afbfa89d49e239ec305516b4db5da9d2 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 20 May 2024 21:56:18 +0000 Subject: [PATCH 1/5] Add profiling support and update docs --- README.md | 4 +- ...-prometheus-metrics-in-jetstream-server.md | 51 ++++++++++++++++ docs/online-inference-with-maxtext-engine.md | 35 ----------- ...iling-with-jax-profiler-and-tensorboard.md | 51 ++++++++++++++++ jetstream/core/server_lib.py | 11 ++++ jetstream/tests/core/test_server.py | 61 ++++++++++++------- 6 files changed, 156 insertions(+), 57 deletions(-) create mode 100644 docs/observability-prometheus-metrics-in-jetstream-server.md create mode 100644 docs/profiling-with-jax-profiler-and-tensorboard.md diff --git a/README.md b/README.md index 758c9640..ee0b1eee 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,11 @@ Currently, there are two reference engine implementations available -- one for J - README: https://github.com/google/jetstream-pytorch/blob/main/README.md ## Documentation -- [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](#jetstream-maxtext-inference-on-v5e-cloud-tpu-vm-user-guide)] +- [Online Inference with MaxText on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream) [[README](https://github.com/google/JetStream/blob/main/docs/online-inference-with-maxtext-engine.md)] - [Online Inference with Pytorch on v5e Cloud TPU VM](https://cloud.google.com/tpu/docs/tutorials/LLM/jetstream-pytorch) [[README](https://github.com/google/jetstream-pytorch/tree/main?tab=readme-ov-file#jetstream-pytorch)] - [Serve Gemma using TPUs on GKE with JetStream](https://cloud.google.com/kubernetes-engine/docs/tutorials/serve-gemma-tpu-jetstream) +- [Observability in JetStream Server](https://github.com/google/JetStream/blob/main/docs/observability-prometheus-metrics-in-jetstream-server.md) +- [Profiling in JetStream Server](https://github.com/google/JetStream/blob/main/docs/profiling-with-jax-profiler-and-tensorboard.md) - [JetStream Standalone Local Setup](#jetstream-standalone-local-setup) diff --git a/docs/observability-prometheus-metrics-in-jetstream-server.md b/docs/observability-prometheus-metrics-in-jetstream-server.md new file mode 100644 index 00000000..b61cf081 --- /dev/null +++ b/docs/observability-prometheus-metrics-in-jetstream-server.md @@ -0,0 +1,51 @@ +# Observability in JetStream Server + +In JetStream Server, we use [Prometheus](https://prometheus.io/docs/introduction/overview/) to collect key metrics within JetStream orchestrator and engines. We implemented a [Prometheus client server](https://prometheus.github.io/client_python/exporting/http/) in JetStream `server_lib.py` and use `MetricsServerConfig` (by passing `prometheus_port` in server entrypoint) to gaurd the metrics observability feature. + +## Enable Prometheus server to observe Jetstream metrics + +Metrics are not exported by default, here is an example to run JetStream MaxText server with metrics observability: + +```bash +# Refer to JetStream MaxText User Guide for the following server config. +export TOKENIZER_PATH=assets/tokenizer.gemma +export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} +export MAX_PREFILL_PREDICT_LENGTH=1024 +export MAX_TARGET_LENGTH=2048 +export MODEL_NAME=gemma-7b +export ICI_FSDP_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=-1 +export ICI_TENSOR_PARALLELISM=1 +export SCAN_LAYERS=false +export WEIGHT_DTYPE=bfloat16 +export PER_DEVICE_BATCH_SIZE=11 +# Set PROMETHEUS_PORT to enable Prometheus metrics. +export PROMETHEUS_PORT=9090 + +cd ~/maxtext +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + prometheus_port=${PROMETHEUS_PORT} +``` + +Now that we configured `prometheus_port=9090` above, we can observe various Jetstream metrics via HTTP requests to `0.0.0.0:9000`. Towards the end, the response should have content similar to the following: + +``` +# HELP jetstream_prefill_backlog_size Size of prefill queue +# TYPE jetstream_prefill_backlog_size gauge +jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0 +# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch +# TYPE jetstream_slots_available_percentage gauge +jetstream_slots_available_percentage{id="",idx="0"} 0.96875 +``` \ No newline at end of file diff --git a/docs/online-inference-with-maxtext-engine.md b/docs/online-inference-with-maxtext-engine.md index 5c6aef5d..95fc84cc 100644 --- a/docs/online-inference-with-maxtext-engine.md +++ b/docs/online-inference-with-maxtext-engine.md @@ -205,41 +205,6 @@ Prompt: Today is a good day Response: to be a fan ``` -### (optional) Observe Jetstream metrics - -Metrics are not exported by default, to configure Jetstream to emit metrics start this guide again from step four and replace the `Run the following command to start the JetStream MaxText server` step with the following: - -```bash -export PROMETHEUS_PORT=9090 - -cd ~/maxtext -python MaxText/maxengine_server.py \ - MaxText/configs/base.yml \ - tokenizer_path=${TOKENIZER_PATH} \ - load_parameters_path=${LOAD_PARAMETERS_PATH} \ - max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ - max_target_length=${MAX_TARGET_LENGTH} \ - model_name=${MODEL_NAME} \ - ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ - ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ - ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ - scan_layers=${SCAN_LAYERS} \ - weight_dtype=${WEIGHT_DTYPE} \ - per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ - prometheus_port=${PROMETHEUS_PORT} -``` - -Now that we configured `prometheus_port=9090` above, we can observe various Jetstream metrics via HTTP requests to `0.0.0.0:9000`. Towards the end, the response should have content similar to the following: - -``` -# HELP jetstream_prefill_backlog_size Size of prefill queue -# TYPE jetstream_prefill_backlog_size gauge -jetstream_prefill_backlog_size{id="SOME-HOSTNAME-HERE>"} 0.0 -# HELP jetstream_slots_available_percentage The percentage of available slots in decode batch -# TYPE jetstream_slots_available_percentage gauge -jetstream_slots_available_percentage{id="",idx="0"} 0.96875 -``` - ## Step 6: Run benchmarks with JetStream MaxText server Note: The JetStream MaxText Server is not running with quantization optimization in Step 3. To get best benchmark results, we need to enable quantization (Please use AQT trained or fine tuned checkpoints to ensure accuracy) for both weights and KV cache, please add the quantization flags and restart the server as following: diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md new file mode 100644 index 00000000..21b323c4 --- /dev/null +++ b/docs/profiling-with-jax-profiler-and-tensorboard.md @@ -0,0 +1,51 @@ +# Profiling in JetStream Server + +In JetStream server, we have implemented JAX profiler server to support profiling JAX program with tensorboard. + +## Profiling with JAX profiler server and tenorboard server + +Following the [JAX official manual profiling approach](https://jax.readthedocs.io/en/latest/profiling.html#manual-capture-via-tensorboard), here is an example of JetStream MaxText server profiling with tensorboard: + +1. Start a TensorBoard server: +```bash +tensorboard --logdir /tmp/tensorboard/ +``` +You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the --port flag. + +2. Start JetStream MaxText server: +```bash +# Refer to JetStream MaxText User Guide for the following server config. +export TOKENIZER_PATH=assets/tokenizer.gemma +export LOAD_PARAMETERS_PATH=${UNSCANNED_CKPT_PATH} +export MAX_PREFILL_PREDICT_LENGTH=1024 +export MAX_TARGET_LENGTH=2048 +export MODEL_NAME=gemma-7b +export ICI_FSDP_PARALLELISM=1 +export ICI_AUTOREGRESSIVE_PARALLELISM=-1 +export ICI_TENSOR_PARALLELISM=1 +export SCAN_LAYERS=false +export WEIGHT_DTYPE=bfloat16 +export PER_DEVICE_BATCH_SIZE=11 +# Set ENABLE_JAX_PROFILER to enable JAX profiler server at port 9999. +export ENABLE_JAX_PROFILER=true + +cd ~/maxtext +python MaxText/maxengine_server.py \ + MaxText/configs/base.yml \ + tokenizer_path=${TOKENIZER_PATH} \ + load_parameters_path=${LOAD_PARAMETERS_PATH} \ + max_prefill_predict_length=${MAX_PREFILL_PREDICT_LENGTH} \ + max_target_length=${MAX_TARGET_LENGTH} \ + model_name=${MODEL_NAME} \ + ici_fsdp_parallelism=${ICI_FSDP_PARALLELISM} \ + ici_autoregressive_parallelism=${ICI_AUTOREGRESSIVE_PARALLELISM} \ + ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \ + scan_layers=${SCAN_LAYERS} \ + weight_dtype=${WEIGHT_DTYPE} \ + per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ + enable_jax_profiler=${ENABLE_JAX_PROFILER} +``` + +3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”. + +4. After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select trace_viewer. \ No newline at end of file diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index d50f3cc7..8c732fb5 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -95,6 +95,7 @@ def run( threads: int | None = None, jax_padding: bool = True, metrics_server_config: config_lib.MetricsServerConfig | None = None, + enable_jax_profiler: bool = False, ) -> JetStreamServer: """Runs a server with a specified config. @@ -105,6 +106,9 @@ def run( credentials: Should use grpc credentials by default. threads: Number of RPC handlers worker threads. This should be at least equal to the decoding batch size to fully saturate the decoding queue. + jax_padding: The flag to enable JAX padding during tokenization. + metrics_server_config: The config to enable Promethus metric server. + enable_jax_profiler: The flag to enable JAX profiler server. Returns: JetStreamServer that wraps the grpc server and orchestrator driver. @@ -148,6 +152,13 @@ def run( logging.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() + + # Setup Jax Profiler + if enable_jax_profiler: + logging.info("Starting JAX profiler server on port 9999") + jax.profiler.start_server(9999) + else: + logging.info(f"Not starting JAX profiler server: {enable_jax_profiler=}") return jetstream_server diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 68faef15..5f0980e4 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -62,10 +62,8 @@ async def test_server( """Sets up a server and requests token responses.""" ######################### Server side ###################################### port = portpicker.pick_unused_port() - metrics_port = portpicker.pick_unused_port() print("port: " + str(port)) - print("metrics port: " + str(metrics_port)) credentials = grpc.local_server_credentials() server = server_lib.run( @@ -105,25 +103,46 @@ async def test_server( counter += 1 server.stop() - # Now test server with prometheus config - server = server_lib.run( - port=port, - config=config, - devices=devices, - credentials=credentials, - metrics_server_config=config_lib.MetricsServerConfig( - port=metrics_port - ), - ) - # assert prometheus server is running and responding - assert server._driver._metrics_collector is not None # pylint: disable=protected-access - assert ( - requests.get( - f"http://localhost:{metrics_port}", timeout=5 - ).status_code - == requests.status_codes.codes["ok"] - ) - server.stop() + def test_prometheus_server(self): + port = portpicker.pick_unused_port() + metrics_port = portpicker.pick_unused_port() + + print("port: " + str(port)) + print("metrics port: " + str(metrics_port)) + credentials = grpc.local_server_credentials() + # Now test server with prometheus config + server = server_lib.run( + port=port, + config=config_lib.InterleavedCPUTestServer, + devices=[None], + credentials=credentials, + metrics_server_config=config_lib.MetricsServerConfig(port=metrics_port), + ) + # assert prometheus server is running and responding + assert server._driver._metrics_collector is not None # pylint: disable=protected-access + assert ( + requests.get(f"http://localhost:{metrics_port}", timeout=5).status_code + == requests.status_codes.codes["ok"] + ) + server.stop() + + def test_jax_profiler_server(self): + port = portpicker.pick_unused_port() + print("port: " + str(port)) + credentials = grpc.local_server_credentials() + # Now test server with prometheus config + server = server_lib.run( + port=port, + config=config_lib.InterleavedCPUTestServer, + devices=[None], + credentials=credentials, + enable_jax_profiler=True, + ) + assert ( + requests.get(f"http://localhost:9999", timeout=5).status_code + == requests.status_codes.codes["ok"] + ) + server.stop() def test_get_devices(self): assert len(server_lib.get_devices()) == 1 From ca3fee627c9c078c91c252f2a077e9f0bf4e1331 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 20 May 2024 22:01:07 +0000 Subject: [PATCH 2/5] docs format --- docs/profiling-with-jax-profiler-and-tensorboard.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md index 21b323c4..cfb77723 100644 --- a/docs/profiling-with-jax-profiler-and-tensorboard.md +++ b/docs/profiling-with-jax-profiler-and-tensorboard.md @@ -10,7 +10,7 @@ Following the [JAX official manual profiling approach](https://jax.readthedocs.i ```bash tensorboard --logdir /tmp/tensorboard/ ``` -You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the --port flag. +You should be able to load TensorBoard at http://localhost:6006/. You can specify a different port with the `--port` flag. 2. Start JetStream MaxText server: ```bash @@ -48,4 +48,4 @@ python MaxText/maxengine_server.py \ 3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”. -4. After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select trace_viewer. \ No newline at end of file +4. After the capture finishes, TensorBoard should automatically refresh. (Not all of the TensorBoard profiling features are hooked up with JAX, so it may initially look like nothing was captured.) On the left under “Tools”, select `trace_viewer`. \ No newline at end of file From 327719a0f89e937aff2ad9866cddbee844757ee3 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 20 May 2024 22:05:38 +0000 Subject: [PATCH 3/5] pylint --- jetstream/core/server_lib.py | 2 +- jetstream/tests/core/test_server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 8c732fb5..544977ac 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -158,7 +158,7 @@ def run( logging.info("Starting JAX profiler server on port 9999") jax.profiler.start_server(9999) else: - logging.info(f"Not starting JAX profiler server: {enable_jax_profiler=}") + logging.info("Not starting JAX profiler server: %s", enable_jax_profiler) return jetstream_server diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 5f0980e4..31ca3b7e 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -139,7 +139,7 @@ def test_jax_profiler_server(self): enable_jax_profiler=True, ) assert ( - requests.get(f"http://localhost:9999", timeout=5).status_code + requests.get("http://localhost:9999", timeout=5).status_code == requests.status_codes.codes["ok"] ) server.stop() From 90068ec96583db2e1d917b24347f7dd93aff4a93 Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Mon, 20 May 2024 22:35:55 +0000 Subject: [PATCH 4/5] fix test --- jetstream/tests/core/test_server.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/jetstream/tests/core/test_server.py b/jetstream/tests/core/test_server.py index 31ca3b7e..f083a823 100644 --- a/jetstream/tests/core/test_server.py +++ b/jetstream/tests/core/test_server.py @@ -138,10 +138,7 @@ def test_jax_profiler_server(self): credentials=credentials, enable_jax_profiler=True, ) - assert ( - requests.get("http://localhost:9999", timeout=5).status_code - == requests.status_codes.codes["ok"] - ) + assert server server.stop() def test_get_devices(self): From 7698edba781b73d9cf17c58abe6c0d5b773c920b Mon Sep 17 00:00:00 2001 From: Zijun Zhou Date: Tue, 21 May 2024 18:08:48 +0000 Subject: [PATCH 5/5] configurable port --- docs/profiling-with-jax-profiler-and-tensorboard.md | 5 ++++- jetstream/core/server_lib.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/profiling-with-jax-profiler-and-tensorboard.md b/docs/profiling-with-jax-profiler-and-tensorboard.md index cfb77723..3727c387 100644 --- a/docs/profiling-with-jax-profiler-and-tensorboard.md +++ b/docs/profiling-with-jax-profiler-and-tensorboard.md @@ -28,6 +28,8 @@ export WEIGHT_DTYPE=bfloat16 export PER_DEVICE_BATCH_SIZE=11 # Set ENABLE_JAX_PROFILER to enable JAX profiler server at port 9999. export ENABLE_JAX_PROFILER=true +# Set JAX_PROFILER_PORT to customize JAX profiler server port. +export JAX_PROFILER_PORT=9999 cd ~/maxtext python MaxText/maxengine_server.py \ @@ -43,7 +45,8 @@ python MaxText/maxengine_server.py \ scan_layers=${SCAN_LAYERS} \ weight_dtype=${WEIGHT_DTYPE} \ per_device_batch_size=${PER_DEVICE_BATCH_SIZE} \ - enable_jax_profiler=${ENABLE_JAX_PROFILER} + enable_jax_profiler=${ENABLE_JAX_PROFILER} \ + jax_profiler_port=${JAX_PROFILER_PORT} ``` 3. Open http://localhost:6006/#profile, and click the “CAPTURE PROFILE” button in the upper left. Enter “localhost:9999” as the profile service URL (this is the address of the profiler server you started in the previous step). Enter the number of milliseconds you’d like to profile for, and click “CAPTURE”. diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index 544977ac..5935f5b6 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -96,6 +96,7 @@ def run( jax_padding: bool = True, metrics_server_config: config_lib.MetricsServerConfig | None = None, enable_jax_profiler: bool = False, + jax_profiler_port: int = 9999, ) -> JetStreamServer: """Runs a server with a specified config. @@ -109,6 +110,7 @@ def run( jax_padding: The flag to enable JAX padding during tokenization. metrics_server_config: The config to enable Promethus metric server. enable_jax_profiler: The flag to enable JAX profiler server. + jax_profiler_port: The port JAX profiler server (default to 9999). Returns: JetStreamServer that wraps the grpc server and orchestrator driver. @@ -155,8 +157,8 @@ def run( # Setup Jax Profiler if enable_jax_profiler: - logging.info("Starting JAX profiler server on port 9999") - jax.profiler.start_server(9999) + logging.info("Starting JAX profiler server on port %s", jax_profiler_port) + jax.profiler.start_server(jax_profiler_port) else: logging.info("Not starting JAX profiler server: %s", enable_jax_profiler) return jetstream_server