From 8992c9e83484658c1fb786ecf47125cf40e135be Mon Sep 17 00:00:00 2001 From: Surbhi Jain Date: Thu, 11 Apr 2024 18:24:56 +0000 Subject: [PATCH] Call max_utils.get_project() only when Vertex Tensorboard is enabled --- MaxText/tests/train_gpu_smoke_test.py | 1 - MaxText/tests/train_int8_smoke_test.py | 1 - MaxText/tests/train_smoke_test.py | 1 - MaxText/train.py | 3 ++- 4 files changed, 2 insertions(+), 4 deletions(-) diff --git a/MaxText/tests/train_gpu_smoke_test.py b/MaxText/tests/train_gpu_smoke_test.py index 80a4676bb4..d7f9df3a55 100644 --- a/MaxText/tests/train_gpu_smoke_test.py +++ b/MaxText/tests/train_gpu_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([ None, "third_party/py/maxtext/configs/gpu_smoke_test.yml", diff --git a/MaxText/tests/train_int8_smoke_test.py b/MaxText/tests/train_int8_smoke_test.py index 89b8e5b554..5efc5ebeb3 100644 --- a/MaxText/tests/train_int8_smoke_test.py +++ b/MaxText/tests/train_int8_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([None, "third_party/py/maxtext/configs/base.yml", f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", r"dataset_path=gs://maxtext-dataset", diff --git a/MaxText/tests/train_smoke_test.py b/MaxText/tests/train_smoke_test.py index 71a0c58134..8cd41fb339 100644 --- a/MaxText/tests/train_smoke_test.py +++ b/MaxText/tests/train_smoke_test.py @@ -25,7 +25,6 @@ class Train(unittest.TestCase): def test_tiny_config(self): test_tmpdir = os.environ.get("TEST_TMPDIR") - os.environ["TENSORBOARD_PROJECT"] = "test-project" train_main([None, "third_party/py/maxtext/configs/base.yml", f"base_output_directory=gs://runner-maxtext-logs", "run_name=runner_test", r"dataset_path=gs://maxtext-dataset", diff --git a/MaxText/train.py b/MaxText/train.py index 3faaca4f78..4a19669457 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -508,7 +508,8 @@ def main(argv: Sequence[str]) -> None: validate_train_config(config) os.environ["TFDS_DATA_DIR"] = config.dataset_path vertex_tensorboard_manager = VertexTensorboardManager() - vertex_tensorboard_manager.configure_vertex_tensorboard(config) + if config.use_vertex_tensorboard or os.environ.get("UPLOAD_DATA_TO_TENSORBOARD"): + vertex_tensorboard_manager.configure_vertex_tensorboard(config) debug_config = debug_configuration.DebugConfig( stack_trace_config = stack_trace_configuration.StackTraceConfig(