From 1c68e85d843476f27b1b152501028359b19a0c94 Mon Sep 17 00:00:00 2001 From: Mihai Mitrea Date: Wed, 1 Apr 2026 07:57:25 +0000 Subject: [PATCH 1/2] Add --force-refresh support for Databricks CLI token fetching Try `--force-refresh` before the regular CLI command so the SDK can bypass the CLI's own token cache when the SDK considers its token stale. If the CLI is too old to recognise `--force-refresh` (or `--profile`), gracefully fall back to the next command in the chain. Chain order: - with profile: forceCmd (--profile --force-refresh) -> profileCmd (--profile) -> fallbackCmd (--host) - without profile: forceCmd (--host --force-refresh) -> profileCmd (--host) Azure CLI callers are unchanged; they use constructors that leave forceCmd null, preserving existing behavior. Signed-off-by: Mihai Mitrea --- NEXT_CHANGELOG.md | 1 + .../databricks/sdk/core/CliTokenSource.java | 63 +++-- .../DatabricksCliCredentialsProvider.java | 31 ++- .../sdk/core/CliTokenSourceTest.java | 216 +++++++++++++++--- .../DatabricksCliCredentialsProviderTest.java | 11 + 5 files changed, 265 insertions(+), 57 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 0bd9e224b..4983a3468 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features and Improvements * Added automatic detection of AI coding agents (Antigravity, Claude Code, Cline, Codex, Copilot CLI, Cursor, Gemini CLI, OpenCode) in the user-agent string. The SDK now appends `agent/` to HTTP request headers when running inside a known AI agent environment. +* Pass `--force-refresh` to Databricks CLI `auth token` command so the SDK always receives a fresh token instead of a potentially stale one from the CLI's internal cache. Falls back gracefully on older CLIs that do not support this flag. ### Bug Fixes * Fixed Databricks CLI authentication to detect when the cached token's scopes don't match the SDK's configured scopes. Previously, a scope mismatch was silently ignored, causing requests to use wrong permissions. The SDK now raises an error with instructions to re-authenticate. diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 58aaf7655..582811bd6 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -26,15 +26,12 @@ public class CliTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(CliTokenSource.class); private List cmd; + private List fallbackCmd; + private List secondFallbackCmd; private String tokenTypeField; private String accessTokenField; private String expiryField; private Environment env; - // fallbackCmd is tried when the primary command fails with "unknown flag: --profile", - // indicating the CLI is too old to support --profile. Can be removed once support - // for CLI versions predating --profile is dropped. - // See: https://github.com/databricks/databricks-sdk-go/pull/1497 - private List fallbackCmd; /** * Internal exception that carries the clean stderr message but exposes full output for checks. @@ -58,7 +55,7 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env) { - this(cmd, tokenTypeField, accessTokenField, expiryField, env, null); + this(cmd, tokenTypeField, accessTokenField, expiryField, env, null, null); } public CliTokenSource( @@ -67,8 +64,8 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env, - List fallbackCmd) { - super(); + List fallbackCmd, + List secondFallbackCmd) { this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd); this.tokenTypeField = tokenTypeField; this.accessTokenField = accessTokenField; @@ -76,6 +73,10 @@ public CliTokenSource( this.env = env; this.fallbackCmd = fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null; + this.secondFallbackCmd = + secondFallbackCmd != null + ? OSUtils.get(env).getCliExecutableCommand(secondFallbackCmd) + : null; } /** @@ -153,27 +154,47 @@ private Token execCliCommand(List cmdToRun) throws IOException { } } + private String getErrorText(IOException e) { + return e instanceof CliCommandException + ? ((CliCommandException) e).getFullOutput() + : e.getMessage(); + } + + private boolean isUnknownFlagError(String errorText) { + return errorText != null && errorText.contains("unknown flag:"); + } + @Override public Token getToken() { try { return execCliCommand(this.cmd); } catch (IOException e) { - String textToCheck = - e instanceof CliCommandException - ? ((CliCommandException) e).getFullOutput() - : e.getMessage(); - if (fallbackCmd != null - && textToCheck != null - && textToCheck.contains("unknown flag: --profile")) { + if (fallbackCmd != null && isUnknownFlagError(getErrorText(e))) { LOG.warn( - "Databricks CLI does not support --profile flag. Falling back to --host. " + "CLI does not support some flags used by this SDK. " + + "Falling back to a compatible command. " + "Please upgrade your CLI to the latest version."); - try { - return execCliCommand(this.fallbackCmd); - } catch (IOException fallbackException) { - throw new DatabricksException(fallbackException.getMessage(), fallbackException); - } + } else { + throw new DatabricksException(e.getMessage(), e); + } + } + + try { + return execCliCommand(this.fallbackCmd); + } catch (IOException e) { + if (secondFallbackCmd != null && isUnknownFlagError(getErrorText(e))) { + LOG.warn( + "CLI does not support some flags used by this SDK. " + + "Falling back to a compatible command. " + + "Please upgrade your CLI to the latest version."); + } else { + throw new DatabricksException(e.getMessage(), e); } + } + + try { + return execCliCommand(this.secondFallbackCmd); + } catch (IOException e) { throw new DatabricksException(e.getMessage(), e); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index b312e1a3e..49963601c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -69,6 +69,17 @@ List buildHostArgs(String cliPath, DatabricksConfig config) { return cmd; } + List buildProfileArgs(String cliPath, DatabricksConfig config) { + return new ArrayList<>( + Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); + } + + private static List withForceRefresh(List cmd) { + List forceCmd = new ArrayList<>(cmd); + forceCmd.add("--force-refresh"); + return forceCmd; + } + private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { String cliPath = config.getDatabricksCliPath(); if (cliPath == null) { @@ -81,23 +92,27 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { List cmd; List fallbackCmd = null; + List secondFallbackCmd = null; if (config.getProfile() != null) { - // When profile is set, use --profile as the primary command. - // The profile contains the full config (host, account_id, etc.). - cmd = - new ArrayList<>( - Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); - // Build a --host fallback for older CLIs that don't support --profile. + List profileArgs = buildProfileArgs(cliPath, config); + cmd = withForceRefresh(profileArgs); + fallbackCmd = profileArgs; if (config.getHost() != null) { - fallbackCmd = buildHostArgs(cliPath, config); + secondFallbackCmd = buildHostArgs(cliPath, config); } } else { cmd = buildHostArgs(cliPath, config); } return new CliTokenSource( - cmd, "token_type", "access_token", "expiry", config.getEnv(), fallbackCmd); + cmd, + "token_type", + "access_token", + "expiry", + config.getEnv(), + fallbackCmd, + secondFallbackCmd); } @Override diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 8476c6de5..689b530d3 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -40,6 +40,13 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { + private static final List FORCE_CMD = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile", "--force-refresh"); + private static final List PROFILE_CMD = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + private static final List HOST_CMD = + Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); + String getExpiryStr(String dateFormat, Duration offset) { ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); @@ -217,16 +224,21 @@ public void testParseExpiry(String input, Instant expectedInstant, String descri } } - // ---- Fallback tests for --profile flag handling ---- + // ---- Fallback tests for --profile and --force-refresh flag handling ---- private CliTokenSource makeTokenSource( - Environment env, List primaryCmd, List fallbackCmd) { + Environment env, List cmd, List fallbackCmd) { + return makeTokenSource(env, cmd, fallbackCmd, null); + } + + private CliTokenSource makeTokenSource( + Environment env, List cmd, List fallbackCmd, List secondFallbackCmd) { OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0)); try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); return new CliTokenSource( - primaryCmd, "token_type", "access_token", "expiry", env, fallbackCmd); + cmd, "token_type", "access_token", "expiry", env, fallbackCmd, secondFallbackCmd); } } @@ -245,12 +257,7 @@ public void testFallbackOnUnknownProfileFlagInStderr() { Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); AtomicInteger callCount = new AtomicInteger(0); try (MockedConstruction mocked = @@ -285,16 +292,10 @@ public void testFallbackOnUnknownProfileFlagInStderr() { @Test public void testFallbackTriggeredWhenUnknownFlagInStdout() { - // Fallback triggers even when "unknown flag" appears in stdout rather than stderr. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); AtomicInteger callCount = new AtomicInteger(0); try (MockedConstruction mocked = @@ -329,16 +330,10 @@ public void testFallbackTriggeredWhenUnknownFlagInStdout() { @Test public void testNoFallbackOnRealAuthError() { - // When the primary fails with a real error (not unknown flag), no fallback is attempted. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); try (MockedConstruction mocked = mockConstruction( @@ -361,14 +356,10 @@ public void testNoFallbackOnRealAuthError() { @Test public void testNoFallbackWhenFallbackCmdNotSet() { - // When fallbackCmd is null and the primary fails with unknown flag, original error propagates. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, null); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, null); try (MockedConstruction mocked = mockConstruction( @@ -387,4 +378,173 @@ public void testNoFallbackWhenFallbackCmdNotSet() { assertEquals(1, mocked.constructed().size()); } } + + // ---- Force-refresh tests ---- + + @Test + public void testForceCmdSucceedsAndFallbacksNotRun() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD, HOST_CMD); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(validTokenJson("forced-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + })) { + Token token = tokenSource.getToken(); + assertEquals("forced-token", token.getAccessToken()); + assertEquals(1, mocked.constructed().size()); + } + } + + @Test + public void testCmdFailsWithUnknownFlagFallsBackToFallbackCmd() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "Error: unknown flag: --force-refresh".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("profile-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("profile-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } + + @Test + public void testCmdAndFallbackBothFailFallsThroughToSecondFallback() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD, HOST_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + int call = callCount.getAndIncrement(); + if (call <= 1) { + // Both forceCmd and profileCmd fail with unknown --profile + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(validTokenJson("host-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("host-token", token.getAccessToken()); + assertEquals(3, mocked.constructed().size()); + } + } + + @Test + public void testRealAuthErrorDoesNotFallBack() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "databricks OAuth is not configured for this host".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + })) { + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("databricks OAuth is not configured")); + assertEquals(1, mocked.constructed().size()); + } + } + + @Test + public void testTwoLevelFallbackWithNoSecondFallback() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("fallback-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java index bac4a766b..837e3c65b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java @@ -139,4 +139,15 @@ void testBuildHostArgs_UnifiedHostFalse_WithAccountHost() { CLI_PATH, "auth", "token", "--host", ACCOUNT_HOST, "--account-id", ACCOUNT_ID), cmd); } + + // ---- Profile args construction tests ---- + + @Test + void testBuildProfileArgs() { + DatabricksConfig config = new DatabricksConfig().setProfile("my-profile"); + + List cmd = provider.buildProfileArgs(CLI_PATH, config); + + assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile"), cmd); + } } From 0dd48d184cfc1e4a88c510492f53a25bc4df67ce Mon Sep 17 00:00:00 2001 From: Mihai Mitrea Date: Wed, 1 Apr 2026 08:22:13 +0000 Subject: [PATCH 2/2] Refactor CliTokenSource to use an ordered attempt chain --- NEXT_CHANGELOG.md | 1 + .../databricks/sdk/core/CliTokenSource.java | 161 +++++++++++++----- .../DatabricksCliCredentialsProvider.java | 59 ++++--- .../sdk/core/CliTokenSourceTest.java | 72 +++++++- .../DatabricksCliCredentialsProviderTest.java | 84 +++++++++ 5 files changed, 305 insertions(+), 72 deletions(-) diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 4983a3468..e89c3ce73 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -14,6 +14,7 @@ ### Documentation ### Internal Changes +* Generalized CLI token source into a progressive command attempt list, replacing the fixed three-field approach with an extensible chain. ### API Changes * Add `createCatalog()`, `createSyncedTable()`, `deleteCatalog()`, `deleteSyncedTable()`, `getCatalog()` and `getSyncedTable()` methods for `workspaceClient.postgres()` service. diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 582811bd6..2bf7260c4 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -16,7 +16,10 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeParseException; import java.util.Arrays; +import java.util.Collections; import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; import org.apache.commons.io.IOUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -25,13 +28,25 @@ public class CliTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(CliTokenSource.class); - private List cmd; - private List fallbackCmd; - private List secondFallbackCmd; - private String tokenTypeField; - private String accessTokenField; - private String expiryField; - private Environment env; + /** + * Describes a CLI command with an optional warning message emitted when falling through to the + * next command in the chain. + */ + static class CliCommand { + final List cmd; + + // Flags used by this command (e.g. "--force-refresh", "--profile"). Used to distinguish + // "unknown flag" errors (which trigger fallback) from real auth errors (which propagate). + final List usedFlags; + + final String fallbackMessage; + + CliCommand(List cmd, List usedFlags, String fallbackMessage) { + this.cmd = cmd; + this.usedFlags = usedFlags != null ? usedFlags : Collections.emptyList(); + this.fallbackMessage = fallbackMessage; + } + } /** * Internal exception that carries the clean stderr message but exposes full output for checks. @@ -49,34 +64,72 @@ String getFullOutput() { } } + private final List commands; + + // Index of the CLI command known to work, or -1 if not yet resolved. Once + // resolved it never changes — older CLIs don't gain new flags. We use + // AtomicInteger instead of synchronization because probing must be retryable + // on transient errors: concurrent callers may redundantly probe, but all + // converge to the same index. + private final AtomicInteger activeCommandIndex = new AtomicInteger(-1); + + private final String tokenTypeField; + private final String accessTokenField; + private final String expiryField; + private final Environment env; + + /** Constructs a single-attempt source. Used by Azure CLI and simple callers. */ public CliTokenSource( List cmd, String tokenTypeField, String accessTokenField, String expiryField, Environment env) { - this(cmd, tokenTypeField, accessTokenField, expiryField, env, null, null); + this(cmd, null, tokenTypeField, accessTokenField, expiryField, env); } - public CliTokenSource( + /** Creates a CliTokenSource from a pre-built command chain. */ + static CliTokenSource fromCommands( + List commands, + String tokenTypeField, + String accessTokenField, + String expiryField, + Environment env) { + return new CliTokenSource(null, commands, tokenTypeField, accessTokenField, expiryField, env); + } + + private CliTokenSource( List cmd, + List commands, String tokenTypeField, String accessTokenField, String expiryField, - Environment env, - List fallbackCmd, - List secondFallbackCmd) { - this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd); + Environment env) { + if (commands != null && !commands.isEmpty()) { + this.commands = + commands.stream() + .map( + a -> + new CliCommand( + OSUtils.get(env).getCliExecutableCommand(a.cmd), + a.usedFlags, + a.fallbackMessage)) + .collect(Collectors.toList()); + } else if (cmd != null) { + if (commands != null && commands.isEmpty()) { + LOG.warn("No CLI commands configured. Falling back to the default command."); + } + this.commands = + Collections.singletonList( + new CliCommand( + OSUtils.get(env).getCliExecutableCommand(cmd), Collections.emptyList(), null)); + } else { + throw new DatabricksException("cannot get access token: no CLI commands configured"); + } this.tokenTypeField = tokenTypeField; this.accessTokenField = accessTokenField; this.expiryField = expiryField; this.env = env; - this.fallbackCmd = - fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null; - this.secondFallbackCmd = - secondFallbackCmd != null - ? OSUtils.get(env).getCliExecutableCommand(secondFallbackCmd) - : null; } /** @@ -137,8 +190,9 @@ private Token execCliCommand(List cmdToRun) throws IOException { if (stderr.contains("not found")) { throw new DatabricksException(stderr); } - // getMessage() returns the clean stderr-based message; getFullOutput() exposes - // both streams so the caller can check for "unknown flag: --profile" in either. + // getMessage() carries the clean stderr message for user-facing errors; + // getFullOutput() includes both streams so isUnknownFlagError can detect + // "unknown flag:" regardless of which stream the CLI wrote it to. throw new CliCommandException("cannot get access token: " + stderr, stdout + "\n" + stderr); } JsonNode jsonNode = new ObjectMapper().readTree(stdout); @@ -154,48 +208,61 @@ private Token execCliCommand(List cmdToRun) throws IOException { } } - private String getErrorText(IOException e) { + private static String getErrorText(IOException e) { return e instanceof CliCommandException ? ((CliCommandException) e).getFullOutput() : e.getMessage(); } - private boolean isUnknownFlagError(String errorText) { - return errorText != null && errorText.contains("unknown flag:"); + private static boolean isUnknownFlagError(String errorText, List flags) { + if (errorText == null) { + return false; + } + for (String flag : flags) { + if (errorText.contains("unknown flag: " + flag)) { + return true; + } + } + return false; } @Override public Token getToken() { - try { - return execCliCommand(this.cmd); - } catch (IOException e) { - if (fallbackCmd != null && isUnknownFlagError(getErrorText(e))) { - LOG.warn( - "CLI does not support some flags used by this SDK. " - + "Falling back to a compatible command. " - + "Please upgrade your CLI to the latest version."); - } else { + int idx = activeCommandIndex.get(); + if (idx >= 0) { + try { + return execCliCommand(commands.get(idx).cmd); + } catch (IOException e) { throw new DatabricksException(e.getMessage(), e); } } + return probeAndExec(); + } - try { - return execCliCommand(this.fallbackCmd); - } catch (IOException e) { - if (secondFallbackCmd != null && isUnknownFlagError(getErrorText(e))) { - LOG.warn( - "CLI does not support some flags used by this SDK. " - + "Falling back to a compatible command. " - + "Please upgrade your CLI to the latest version."); - } else { + /** + * Walks the command list from most-featured to simplest, looking for a CLI command that succeeds. + * When a command fails with "unknown flag" for one of its {@link CliCommand#usedFlags}, it logs a + * warning and tries the next. On success, {@link #activeCommandIndex} is stored so future calls + * skip probing. + */ + private Token probeAndExec() { + for (int i = 0; i < commands.size(); i++) { + CliCommand command = commands.get(i); + try { + Token token = execCliCommand(command.cmd); + activeCommandIndex.set(i); + return token; + } catch (IOException e) { + if (i + 1 < commands.size() && isUnknownFlagError(getErrorText(e), command.usedFlags)) { + if (command.fallbackMessage != null) { + LOG.warn(command.fallbackMessage); + } + continue; + } throw new DatabricksException(e.getMessage(), e); } } - try { - return execCliCommand(this.secondFallbackCmd); - } catch (IOException e) { - throw new DatabricksException(e.getMessage(), e); - } + throw new DatabricksException("cannot get access token: all CLI commands failed"); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index 49963601c..8c02a781a 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -80,6 +80,40 @@ private static List withForceRefresh(List cmd) { return forceCmd; } + List buildCommands(String cliPath, DatabricksConfig config) { + List commands = new ArrayList<>(); + + boolean hasProfile = config.getProfile() != null; + boolean hasHost = config.getHost() != null; + + if (hasProfile) { + List profileCmd = buildProfileArgs(cliPath, config); + + commands.add( + new CliTokenSource.CliCommand( + withForceRefresh(profileCmd), + Arrays.asList("--force-refresh", "--profile"), + "Databricks CLI does not support --force-refresh flag. " + + "Falling back to regular token fetch. " + + "Please upgrade your CLI to the latest version.")); + + commands.add( + new CliTokenSource.CliCommand( + profileCmd, + Collections.singletonList("--profile"), + "Databricks CLI does not support --profile flag. Falling back to --host. " + + "Please upgrade your CLI to the latest version.")); + } + + if (hasHost) { + commands.add( + new CliTokenSource.CliCommand( + buildHostArgs(cliPath, config), Collections.emptyList(), null)); + } + + return commands; + } + private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { String cliPath = config.getDatabricksCliPath(); if (cliPath == null) { @@ -90,29 +124,8 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { return null; } - List cmd; - List fallbackCmd = null; - List secondFallbackCmd = null; - - if (config.getProfile() != null) { - List profileArgs = buildProfileArgs(cliPath, config); - cmd = withForceRefresh(profileArgs); - fallbackCmd = profileArgs; - if (config.getHost() != null) { - secondFallbackCmd = buildHostArgs(cliPath, config); - } - } else { - cmd = buildHostArgs(cliPath, config); - } - - return new CliTokenSource( - cmd, - "token_type", - "access_token", - "expiry", - config.getEnv(), - fallbackCmd, - secondFallbackCmd); + return CliTokenSource.fromCommands( + buildCommands(cliPath, config), "token_type", "access_token", "expiry", config.getEnv()); } @Override diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 689b530d3..3a1efed19 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -24,6 +24,7 @@ import java.time.format.DateTimeParseException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -233,12 +234,35 @@ private CliTokenSource makeTokenSource( private CliTokenSource makeTokenSource( Environment env, List cmd, List fallbackCmd, List secondFallbackCmd) { + List commands = new ArrayList<>(); + + commands.add( + new CliTokenSource.CliCommand( + cmd, + fallbackCmd != null + ? Arrays.asList("--force-refresh", "--profile") + : Collections.emptyList(), + fallbackCmd != null ? "fallback" : null)); + + if (fallbackCmd != null) { + commands.add( + new CliTokenSource.CliCommand( + fallbackCmd, + secondFallbackCmd != null + ? Collections.singletonList("--profile") + : Collections.emptyList(), + secondFallbackCmd != null ? "second fallback" : null)); + } + + if (secondFallbackCmd != null) { + commands.add(new CliTokenSource.CliCommand(secondFallbackCmd, Collections.emptyList(), null)); + } + OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0)); try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); - return new CliTokenSource( - cmd, "token_type", "access_token", "expiry", env, fallbackCmd, secondFallbackCmd); + return CliTokenSource.fromCommands(commands, "token_type", "access_token", "expiry", env); } } @@ -547,4 +571,48 @@ public void testTwoLevelFallbackWithNoSecondFallback() { assertEquals(2, mocked.constructed().size()); } } + + @Test + public void testActiveCommandIndexPersists() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + int call = callCount.getAndIncrement(); + if (call == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "Error: unknown flag: --force-refresh".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("profile-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token first = tokenSource.getToken(); + assertEquals("profile-token", first.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + + Token second = tokenSource.getToken(); + assertEquals("profile-token", second.getAccessToken()); + assertEquals(3, mocked.constructed().size()); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java index 837e3c65b..196fd27cd 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java @@ -3,6 +3,7 @@ import static org.junit.jupiter.api.Assertions.*; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.junit.jupiter.api.Test; @@ -15,6 +16,14 @@ class DatabricksCliCredentialsProviderTest { private static final String ACCOUNT_ID = "test-account-123"; private static final String WORKSPACE_ID = "987654321"; + private static final String FORCE_REFRESH_FALLBACK_MSG = + "Databricks CLI does not support --force-refresh flag. " + + "Falling back to regular token fetch. " + + "Please upgrade your CLI to the latest version."; + private static final String PROFILE_FALLBACK_MSG = + "Databricks CLI does not support --profile flag. Falling back to --host. " + + "Please upgrade your CLI to the latest version."; + private final DatabricksCliCredentialsProvider provider = new DatabricksCliCredentialsProvider(); @Test @@ -150,4 +159,79 @@ void testBuildProfileArgs() { assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile"), cmd); } + + // ---- Command chain construction tests ---- + + @Test + void testBuildAttempts_WithProfileAndHost() { + DatabricksConfig config = new DatabricksConfig().setHost(HOST).setProfile("my-profile"); + + List commands = provider.buildCommands(CLI_PATH, config); + + assertEquals(3, commands.size()); + + assertEquals( + Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile", "--force-refresh"), + commands.get(0).cmd); + assertEquals(Arrays.asList("--force-refresh", "--profile"), commands.get(0).usedFlags); + assertEquals(FORCE_REFRESH_FALLBACK_MSG, commands.get(0).fallbackMessage); + + assertEquals( + Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile"), commands.get(1).cmd); + assertEquals(Collections.singletonList("--profile"), commands.get(1).usedFlags); + assertEquals(PROFILE_FALLBACK_MSG, commands.get(1).fallbackMessage); + + assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST), commands.get(2).cmd); + assertEquals(Collections.emptyList(), commands.get(2).usedFlags); + assertNull(commands.get(2).fallbackMessage); + } + + @Test + void testBuildAttempts_WithProfileOnly() { + DatabricksConfig config = new DatabricksConfig().setProfile("my-profile"); + + List commands = provider.buildCommands(CLI_PATH, config); + + assertEquals(2, commands.size()); + + assertEquals( + Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile", "--force-refresh"), + commands.get(0).cmd); + assertEquals(Arrays.asList("--force-refresh", "--profile"), commands.get(0).usedFlags); + assertEquals(FORCE_REFRESH_FALLBACK_MSG, commands.get(0).fallbackMessage); + + assertEquals( + Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile"), commands.get(1).cmd); + assertEquals(Collections.singletonList("--profile"), commands.get(1).usedFlags); + assertEquals(PROFILE_FALLBACK_MSG, commands.get(1).fallbackMessage); + } + + @Test + void testBuildAttempts_WithHostOnly() { + DatabricksConfig config = new DatabricksConfig().setHost(HOST); + + List commands = provider.buildCommands(CLI_PATH, config); + + assertEquals(1, commands.size()); + + assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--host", HOST), commands.get(0).cmd); + assertEquals(Collections.emptyList(), commands.get(0).usedFlags); + assertNull(commands.get(0).fallbackMessage); + } + + @Test + void testBuildAttempts_WithAccountHost() { + DatabricksConfig config = new DatabricksConfig().setHost(ACCOUNT_HOST).setAccountId(ACCOUNT_ID); + + List commands = provider.buildCommands(CLI_PATH, config); + + assertEquals(1, commands.size()); + + assertEquals( + Arrays.asList( + CLI_PATH, "auth", "token", "--host", ACCOUNT_HOST, "--account-id", ACCOUNT_ID), + commands.get(0).cmd); + assertEquals(Collections.emptyList(), commands.get(0).usedFlags); + assertNull(commands.get(0).fallbackMessage); + } }