From ecb520cd557b2f9c52562cf688307b0a98bed01b Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 May 2024 09:05:37 -0400 Subject: [PATCH 1/3] Adds callable option to get secrets Signed-off-by: Thomas J. Fan --- .../flytekitplugins/wandb/tracking.py | 20 ++++++--- .../flytekit-wandb/tests/test_wandb_init.py | 42 ++++++++++++++++--- 2 files changed, 50 insertions(+), 12 deletions(-) diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py index 216bf176c6..aa405842bc 100644 --- a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional +from typing import Callable, Optional, Union import wandb from flytekit import Secret @@ -21,7 +21,7 @@ def __init__( task_function: Optional[Callable] = None, project: Optional[str] = None, entity: Optional[str] = None, - secret: Optional[Secret] = None, + secret: Optional[Union[Secret, Callable]] = None, id: Optional[str] = None, host: str = "https://wandb.ai", **init_kwargs: dict, @@ -31,7 +31,8 @@ def __init__( task_function (function, optional): The user function to be decorated. Defaults to None. project (str): The name of the project where you're sending the new run. (Required) entity (str): An entity is a username or team name where you're sending runs. (Required) - secret (Secret): Secret with your `WANDB_API_KEY`. (Required) + secret (Secret or Callable): Secret with your `WANDB_API_KEY` or callable that returns the API key. + (Required) id (str, optional): A unique id for this wandb run. host (str, optional): URL to your wandb service. The default is "https://wandb.ai". **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see @@ -72,9 +73,16 @@ def execute(self, *args, **kwargs): # will generate it's own id. 
wand_id = self.id else: - # Set secret for remote execution - secrets = ctx.user_space_params.secrets - os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group) + if isinstance(self.secret, Secret): + # Set secret for remote execution + secrets = ctx.user_space_params.secrets + wandb_api_key = secrets.get(key=self.secret.key, group=self.secret.group) + else: + # Get API key with callable + wandb_api_key = self.secret() + + os.environ["WANDB_API_KEY"] = wandb_api_key + if self.id is None: # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt} # If HOSTNAME is not defined, use the execution name as a fallback diff --git a/plugins/flytekit-wandb/tests/test_wandb_init.py b/plugins/flytekit-wandb/tests/test_wandb_init.py index 67d866b5c1..3520acf01b 100644 --- a/plugins/flytekit-wandb/tests/test_wandb_init.py +++ b/plugins/flytekit-wandb/tests/test_wandb_init.py @@ -4,7 +4,9 @@ from flytekitplugins.wandb import wandb_init from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE -from flytekit import task +from flytekit import Secret, task + +secret = Secret(key="abc", group="xyz") @pytest.mark.parametrize("id", [None, "abc123"]) @@ -12,11 +14,12 @@ def test_wandb_extra_config(id): wandb_decorator = wandb_init( project="abc", entity="xyz", - secret_key="my-secret-key", + secret=secret, id=id, host="https://my_org.wandb.org", ) + assert wandb_decorator.secret is secret extra_config = wandb_decorator.get_extra_config() if id is None: @@ -29,7 +32,7 @@ def test_wandb_extra_config(id): @task -@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", secret_group="my-secret-group", tags=["my_tag"]) +@wandb_init(project="abc", entity="xyz", secret=secret, tags=["my_tag"]) def train_model(): pass @@ -42,7 +45,7 @@ def test_local_execution(wandb_mock): @task -@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", tags=["my_tag"], id="1234") +@wandb_init(project="abc", 
entity="xyz", secret=secret, tags=["my_tag"], id="1234") def train_model_with_id(): pass @@ -71,7 +74,7 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock): train_model() wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"]) - ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group") + ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz") assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret" @@ -82,5 +85,32 @@ def test_errors(): with pytest.raises(ValueError, match="entity must be set"): wandb_init(project="abc") - with pytest.raises(ValueError, match="secret_key must be set"): + with pytest.raises(ValueError, match="secret must be set"): wandb_init(project="abc", entity="xyz") + + +def get_secret(): + return "my-wandb-api-key" + + +@task +@wandb_init(project="my_project", entity="my_entity", secret=get_secret, tags=["my_tag"], id="1234") +def train_model_with_id_callable_secret(): + pass + + +@patch("flytekitplugins.wandb.tracking.os") +@patch("flytekitplugins.wandb.tracking.FlyteContextManager") +@patch("flytekitplugins.wandb.tracking.wandb") +def test_secret_callable_remote(wandb_mock, manager_mock, os_mock): + # Pretend that the execution is remote + ctx_mock = Mock() + ctx_mock.execution_state.is_local_execution.return_value = False + + manager_mock.current_context.return_value = ctx_mock + os_mock.environ = {} + + train_model_with_id_callable_secret() + + wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"]) + assert os_mock.environ["WANDB_API_KEY"] == get_secret() From f54688fb5c7836ddd2871cf9c83dfe4d4e1ae2e3 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 May 2024 09:12:34 -0400 Subject: [PATCH 2/3] DOC Improve docstring Signed-off-by: Thomas J. 
Fan --- plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py index aa405842bc..2723b11a17 100644 --- a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -31,8 +31,8 @@ def __init__( task_function (function, optional): The user function to be decorated. Defaults to None. project (str): The name of the project where you're sending the new run. (Required) entity (str): An entity is a username or team name where you're sending runs. (Required) - secret (Secret or Callable): Secret with your `WANDB_API_KEY` or callable that returns the API key. - (Required) + secret (Secret or Callable): Secret with your `WANDB_API_KEY` or a callable that returns the API key. + The callable takes no arguments and returns a string. (Required) id (str, optional): A unique id for this wandb run. host (str, optional): URL to your wandb service. The default is "https://wandb.ai". **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see From 848084979d0eb2d793e6865f6f8335e5422b8d34 Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 29 May 2024 12:42:08 -0400 Subject: [PATCH 3/3] Use wandb.login instead of environment variable Signed-off-by: Thomas J. 
Fan --- plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py | 6 +++++- plugins/flytekit-wandb/tests/test_wandb_init.py | 4 ++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py index 2723b11a17..3d0a4ac894 100644 --- a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -24,6 +24,7 @@ def __init__( secret: Optional[Union[Secret, Callable]] = None, id: Optional[str] = None, host: str = "https://wandb.ai", + api_host: str = "https://api.wandb.ai", **init_kwargs: dict, ): """Weights and Biases plugin. @@ -35,6 +36,7 @@ def __init__( The callable takes no arguments and returns a string. (Required) id (str, optional): A unique id for this wandb run. host (str, optional): URL to your wandb service. The default is "https://wandb.ai". + api_host (str, optional): URL to your API host. The default is "https://api.wandb.ai". **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see [the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details. 
""" @@ -51,6 +53,7 @@ def __init__( self.init_kwargs = init_kwargs self.secret = secret self.host = host + self.api_host = api_host # All kwargs need to be passed up so that the function wrapping works for both # `@wandb_init` and `@wandb_init(...)` @@ -61,6 +64,7 @@ def __init__( secret=secret, id=id, host=host, + api_host=api_host, **init_kwargs, ) @@ -81,7 +85,7 @@ def execute(self, *args, **kwargs): # Get API key with callable wandb_api_key = self.secret() - os.environ["WANDB_API_KEY"] = wandb_api_key + wandb.login(key=wandb_api_key, host=self.api_host) if self.id is None: # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt} diff --git a/plugins/flytekit-wandb/tests/test_wandb_init.py b/plugins/flytekit-wandb/tests/test_wandb_init.py index 3520acf01b..664e4a77ac 100644 --- a/plugins/flytekit-wandb/tests/test_wandb_init.py +++ b/plugins/flytekit-wandb/tests/test_wandb_init.py @@ -75,7 +75,7 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock): wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"]) ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz") - assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret" + wandb_mock.login.assert_called_with(key="this_is_the_secret", host="https://api.wandb.ai") def test_errors(): @@ -113,4 +113,4 @@ def test_secret_callable_remote(wandb_mock, manager_mock, os_mock): train_model_with_id_callable_secret() wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"]) - assert os_mock.environ["WANDB_API_KEY"] == get_secret() + wandb_mock.login.assert_called_with(key=get_secret(), host="https://api.wandb.ai")