diff --git a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py index 216bf176c6..3d0a4ac894 100644 --- a/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py +++ b/plugins/flytekit-wandb/flytekitplugins/wandb/tracking.py @@ -1,5 +1,5 @@ import os -from typing import Callable, Optional +from typing import Callable, Optional, Union import wandb from flytekit import Secret @@ -21,9 +21,10 @@ def __init__( task_function: Optional[Callable] = None, project: Optional[str] = None, entity: Optional[str] = None, - secret: Optional[Secret] = None, + secret: Optional[Union[Secret, Callable]] = None, id: Optional[str] = None, host: str = "https://wandb.ai", + api_host: str = "https://api.wandb.ai", **init_kwargs: dict, ): """Weights and Biases plugin. @@ -31,9 +32,11 @@ def __init__( task_function (function, optional): The user function to be decorated. Defaults to None. project (str): The name of the project where you're sending the new run. (Required) entity (str): An entity is a username or team name where you're sending runs. (Required) - secret (Secret): Secret with your `WANDB_API_KEY`. (Required) + secret (Secret or Callable): Secret with your `WANDB_API_KEY` or a callable that returns the API key. + The callable takes no arguments and returns a string. (Required) id (str, optional): A unique id for this wandb run. host (str, optional): URL to your wandb service. The default is "https://wandb.ai". + api_host (str, optional): URL to your API Host, The default is "https://api.wandb.ai". **init_kwargs (dict): The rest of the arguments are passed directly to `wandb.init`. Please see [the `wandb.init` docs](https://docs.wandb.ai/ref/python/init) for details. """ @@ -50,6 +53,7 @@ def __init__( self.init_kwargs = init_kwargs self.secret = secret self.host = host + self.api_host = api_host # All kwargs need to be passed up so that the function wrapping works for both # `@wandb_init` and `@wandb_init(...)` @@ -60,6 +64,7 @@ def __init__( secret=secret, id=id, host=host, + api_host=api_host, **init_kwargs, ) @@ -72,9 +77,16 @@ def execute(self, *args, **kwargs): # will generate it's own id. wand_id = self.id else: - # Set secret for remote execution - secrets = ctx.user_space_params.secrets - os.environ["WANDB_API_KEY"] = secrets.get(key=self.secret.key, group=self.secret.group) + if isinstance(self.secret, Secret): + # Set secret for remote execution + secrets = ctx.user_space_params.secrets + wandb_api_key = secrets.get(key=self.secret.key, group=self.secret.group) + else: + # Get API key with callable + wandb_api_key = self.secret() + + wandb.login(key=wandb_api_key, host=self.api_host) + if self.id is None: # The HOSTNAME is set to {.executionName}-{.nodeID}-{.taskRetryAttempt} # If HOSTNAME is not defined, use the execution name as a fallback diff --git a/plugins/flytekit-wandb/tests/test_wandb_init.py b/plugins/flytekit-wandb/tests/test_wandb_init.py index 67d866b5c1..664e4a77ac 100644 --- a/plugins/flytekit-wandb/tests/test_wandb_init.py +++ b/plugins/flytekit-wandb/tests/test_wandb_init.py @@ -4,7 +4,9 @@ from flytekitplugins.wandb import wandb_init from flytekitplugins.wandb.tracking import WANDB_CUSTOM_TYPE_VALUE, WANDB_EXECUTION_TYPE_VALUE -from flytekit import task +from flytekit import Secret, task + +secret = Secret(key="abc", group="xyz") @pytest.mark.parametrize("id", [None, "abc123"]) @@ -12,11 +14,12 @@ def test_wandb_extra_config(id): wandb_decorator = wandb_init( project="abc", entity="xyz", - secret_key="my-secret-key", + secret=secret, id=id, host="https://my_org.wandb.org", ) + assert wandb_decorator.secret is secret extra_config = wandb_decorator.get_extra_config() if id is None: @@ -29,7 +32,7 @@ def test_wandb_extra_config(id): @task -@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", secret_group="my-secret-group", tags=["my_tag"]) +@wandb_init(project="abc", entity="xyz", secret=secret, tags=["my_tag"]) def train_model(): pass @@ -42,7 +45,7 @@ def test_local_execution(wandb_mock): @task -@wandb_init(project="abc", entity="xyz", secret_key="my-secret-key", tags=["my_tag"], id="1234") +@wandb_init(project="abc", entity="xyz", secret=secret, tags=["my_tag"], id="1234") def train_model_with_id(): pass @@ -71,8 +74,8 @@ def test_non_local_execution(wandb_mock, manager_mock, os_mock): train_model() wandb_mock.init.assert_called_with(project="abc", entity="xyz", id="my_execution_id", tags=["my_tag"]) - ctx_mock.user_space_params.secrets.get.assert_called_with(key="my-secret-key", group="my-secret-group") - assert os_mock.environ["WANDB_API_KEY"] == "this_is_the_secret" + ctx_mock.user_space_params.secrets.get.assert_called_with(key="abc", group="xyz") + wandb_mock.login.assert_called_with(key="this_is_the_secret", host="https://api.wandb.ai") def test_errors(): @@ -82,5 +85,32 @@ def test_errors(): with pytest.raises(ValueError, match="entity must be set"): wandb_init(project="abc") - with pytest.raises(ValueError, match="secret_key must be set"): + with pytest.raises(ValueError, match="secret must be set"): wandb_init(project="abc", entity="xyz") + + +def get_secret(): + return "my-wandb-api-key" + + +@task +@wandb_init(project="my_project", entity="my_entity", secret=get_secret, tags=["my_tag"], id="1234") +def train_model_with_id_callable_secret(): + pass + + +@patch("flytekitplugins.wandb.tracking.os") +@patch("flytekitplugins.wandb.tracking.FlyteContextManager") +@patch("flytekitplugins.wandb.tracking.wandb") +def test_secret_callable_remote(wandb_mock, manager_mock, os_mock): + # Pretend that the execution is remote + ctx_mock = Mock() + ctx_mock.execution_state.is_local_execution.return_value = False + + manager_mock.current_context.return_value = ctx_mock + os_mock.environ = {} + + train_model_with_id_callable_secret() + + wandb_mock.init.assert_called_with(project="my_project", entity="my_entity", id="1234", tags=["my_tag"]) + wandb_mock.login.assert_called_with(key=get_secret(), host="https://api.wandb.ai")