diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 0bd9e224b..4983a3468 100755 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -4,6 +4,7 @@ ### New Features and Improvements * Added automatic detection of AI coding agents (Antigravity, Claude Code, Cline, Codex, Copilot CLI, Cursor, Gemini CLI, OpenCode) in the user-agent string. The SDK now appends `agent/` to HTTP request headers when running inside a known AI agent environment. +* Pass `--force-refresh` to Databricks CLI `auth token` command so the SDK always receives a fresh token instead of a potentially stale one from the CLI's internal cache. Falls back gracefully on older CLIs that do not support this flag. ### Bug Fixes * Fixed Databricks CLI authentication to detect when the cached token's scopes don't match the SDK's configured scopes. Previously, a scope mismatch was silently ignored, causing requests to use wrong permissions. The SDK now raises an error with instructions to re-authenticate. diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java index 58aaf7655..582811bd6 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/CliTokenSource.java @@ -26,15 +26,12 @@ public class CliTokenSource implements TokenSource { private static final Logger LOG = LoggerFactory.getLogger(CliTokenSource.class); private List cmd; + private List fallbackCmd; + private List secondFallbackCmd; private String tokenTypeField; private String accessTokenField; private String expiryField; private Environment env; - // fallbackCmd is tried when the primary command fails with "unknown flag: --profile", - // indicating the CLI is too old to support --profile. Can be removed once support - // for CLI versions predating --profile is dropped. - // See: https://github.com/databricks/databricks-sdk-go/pull/1497 - private List fallbackCmd; /** * Internal exception that carries the clean stderr message but exposes full output for checks. @@ -58,7 +55,7 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env) { - this(cmd, tokenTypeField, accessTokenField, expiryField, env, null); + this(cmd, tokenTypeField, accessTokenField, expiryField, env, null, null); } public CliTokenSource( @@ -67,8 +64,8 @@ public CliTokenSource( String accessTokenField, String expiryField, Environment env, - List fallbackCmd) { - super(); + List fallbackCmd, + List secondFallbackCmd) { this.cmd = OSUtils.get(env).getCliExecutableCommand(cmd); this.tokenTypeField = tokenTypeField; this.accessTokenField = accessTokenField; @@ -76,6 +73,10 @@ public CliTokenSource( this.env = env; this.fallbackCmd = fallbackCmd != null ? OSUtils.get(env).getCliExecutableCommand(fallbackCmd) : null; + this.secondFallbackCmd = + secondFallbackCmd != null + ? OSUtils.get(env).getCliExecutableCommand(secondFallbackCmd) + : null; } /** @@ -153,27 +154,47 @@ private Token execCliCommand(List cmdToRun) throws IOException { } } + private String getErrorText(IOException e) { + return e instanceof CliCommandException + ? ((CliCommandException) e).getFullOutput() + : e.getMessage(); + } + + private boolean isUnknownFlagError(String errorText) { + return errorText != null && errorText.contains("unknown flag:"); + } + @Override public Token getToken() { try { return execCliCommand(this.cmd); } catch (IOException e) { - String textToCheck = - e instanceof CliCommandException - ? ((CliCommandException) e).getFullOutput() - : e.getMessage(); - if (fallbackCmd != null - && textToCheck != null - && textToCheck.contains("unknown flag: --profile")) { + if (fallbackCmd != null && isUnknownFlagError(getErrorText(e))) { LOG.warn( - "Databricks CLI does not support --profile flag. Falling back to --host. " + "CLI does not support some flags used by this SDK. " + + "Falling back to a compatible command. " + "Please upgrade your CLI to the latest version."); - try { - return execCliCommand(this.fallbackCmd); - } catch (IOException fallbackException) { - throw new DatabricksException(fallbackException.getMessage(), fallbackException); - } + } else { + throw new DatabricksException(e.getMessage(), e); + } + } + + try { + return execCliCommand(this.fallbackCmd); + } catch (IOException e) { + if (secondFallbackCmd != null && isUnknownFlagError(getErrorText(e))) { + LOG.warn( + "CLI does not support some flags used by this SDK. " + + "Falling back to a compatible command. " + + "Please upgrade your CLI to the latest version."); + } else { + throw new DatabricksException(e.getMessage(), e); } + } + + try { + return execCliCommand(this.secondFallbackCmd); + } catch (IOException e) { throw new DatabricksException(e.getMessage(), e); } } diff --git a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java index b312e1a3e..49963601c 100644 --- a/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java +++ b/databricks-sdk-java/src/main/java/com/databricks/sdk/core/DatabricksCliCredentialsProvider.java @@ -69,6 +69,17 @@ List buildHostArgs(String cliPath, DatabricksConfig config) { return cmd; } + List buildProfileArgs(String cliPath, DatabricksConfig config) { + return new ArrayList<>( + Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); + } + + private static List withForceRefresh(List cmd) { + List forceCmd = new ArrayList<>(cmd); + forceCmd.add("--force-refresh"); + return forceCmd; + } + private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { String cliPath = config.getDatabricksCliPath(); if (cliPath == null) { @@ -81,23 +92,27 @@ private CliTokenSource getDatabricksCliTokenSource(DatabricksConfig config) { List cmd; List fallbackCmd = null; + List secondFallbackCmd = null; if (config.getProfile() != null) { - // When profile is set, use --profile as the primary command. - // The profile contains the full config (host, account_id, etc.). - cmd = - new ArrayList<>( - Arrays.asList(cliPath, "auth", "token", "--profile", config.getProfile())); - // Build a --host fallback for older CLIs that don't support --profile. + List profileArgs = buildProfileArgs(cliPath, config); + cmd = withForceRefresh(profileArgs); + fallbackCmd = profileArgs; if (config.getHost() != null) { - fallbackCmd = buildHostArgs(cliPath, config); + secondFallbackCmd = buildHostArgs(cliPath, config); } } else { cmd = buildHostArgs(cliPath, config); } return new CliTokenSource( - cmd, "token_type", "access_token", "expiry", config.getEnv(), fallbackCmd); + cmd, + "token_type", + "access_token", + "expiry", + config.getEnv(), + fallbackCmd, + secondFallbackCmd); } @Override diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java index 8476c6de5..689b530d3 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/CliTokenSourceTest.java @@ -40,6 +40,13 @@ import org.mockito.MockedStatic; public class CliTokenSourceTest { + private static final List FORCE_CMD = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile", "--force-refresh"); + private static final List PROFILE_CMD = + Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); + private static final List HOST_CMD = + Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); + String getExpiryStr(String dateFormat, Duration offset) { ZonedDateTime futureExpiry = ZonedDateTime.now().plus(offset); return futureExpiry.format(DateTimeFormatter.ofPattern(dateFormat)); @@ -217,16 +224,21 @@ public void testParseExpiry(String input, Instant expectedInstant, String descri } } - // ---- Fallback tests for --profile flag handling ---- + // ---- Fallback tests for --profile and --force-refresh flag handling ---- private CliTokenSource makeTokenSource( - Environment env, List primaryCmd, List fallbackCmd) { + Environment env, List cmd, List fallbackCmd) { + return makeTokenSource(env, cmd, fallbackCmd, null); + } + + private CliTokenSource makeTokenSource( + Environment env, List cmd, List fallbackCmd, List secondFallbackCmd) { OSUtilities osUtils = mock(OSUtilities.class); when(osUtils.getCliExecutableCommand(any())).thenAnswer(inv -> inv.getArgument(0)); try (MockedStatic mockedOSUtils = mockStatic(OSUtils.class)) { mockedOSUtils.when(() -> OSUtils.get(any())).thenReturn(osUtils); return new CliTokenSource( - primaryCmd, "token_type", "access_token", "expiry", env, fallbackCmd); + cmd, "token_type", "access_token", "expiry", env, fallbackCmd, secondFallbackCmd); } } @@ -245,12 +257,7 @@ public void testFallbackOnUnknownProfileFlagInStderr() { Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); AtomicInteger callCount = new AtomicInteger(0); try (MockedConstruction mocked = @@ -285,16 +292,10 @@ public void testFallbackOnUnknownProfileFlagInStderr() { @Test public void testFallbackTriggeredWhenUnknownFlagInStdout() { - // Fallback triggers even when "unknown flag" appears in stdout rather than stderr. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); AtomicInteger callCount = new AtomicInteger(0); try (MockedConstruction mocked = @@ -329,16 +330,10 @@ public void testFallbackTriggeredWhenUnknownFlagInStdout() { @Test public void testNoFallbackOnRealAuthError() { - // When the primary fails with a real error (not unknown flag), no fallback is attempted. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - List fallbackCmdList = - Arrays.asList("databricks", "auth", "token", "--host", "https://workspace.databricks.com"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, fallbackCmdList); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); try (MockedConstruction mocked = mockConstruction( @@ -361,14 +356,10 @@ public void testNoFallbackOnRealAuthError() { @Test public void testNoFallbackWhenFallbackCmdNotSet() { - // When fallbackCmd is null and the primary fails with unknown flag, original error propagates. Environment env = mock(Environment.class); when(env.getEnv()).thenReturn(new HashMap<>()); - List primaryCmd = - Arrays.asList("databricks", "auth", "token", "--profile", "my-profile"); - - CliTokenSource tokenSource = makeTokenSource(env, primaryCmd, null); + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, null); try (MockedConstruction mocked = mockConstruction( @@ -387,4 +378,173 @@ public void testNoFallbackWhenFallbackCmdNotSet() { assertEquals(1, mocked.constructed().size()); } } + + // ---- Force-refresh tests ---- + + @Test + public void testForceCmdSucceedsAndFallbacksNotRun() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD, HOST_CMD); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(validTokenJson("forced-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + })) { + Token token = tokenSource.getToken(); + assertEquals("forced-token", token.getAccessToken()); + assertEquals(1, mocked.constructed().size()); + } + } + + @Test + public void testCmdFailsWithUnknownFlagFallsBackToFallbackCmd() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "Error: unknown flag: --force-refresh".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("profile-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("profile-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } + + @Test + public void testCmdAndFallbackBothFailFallsThroughToSecondFallback() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD, HOST_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + int call = callCount.getAndIncrement(); + if (call <= 1) { + // Both forceCmd and profileCmd fail with unknown --profile + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(validTokenJson("host-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("host-token", token.getAccessToken()); + assertEquals(3, mocked.constructed().size()); + } + } + + @Test + public void testRealAuthErrorDoesNotFallBack() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, FORCE_CMD, PROFILE_CMD); + + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream( + "databricks OAuth is not configured for this host".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + })) { + DatabricksException ex = assertThrows(DatabricksException.class, tokenSource::getToken); + assertTrue(ex.getMessage().contains("databricks OAuth is not configured")); + assertEquals(1, mocked.constructed().size()); + } + } + + @Test + public void testTwoLevelFallbackWithNoSecondFallback() { + Environment env = mock(Environment.class); + when(env.getEnv()).thenReturn(new HashMap<>()); + + CliTokenSource tokenSource = makeTokenSource(env, PROFILE_CMD, HOST_CMD); + + AtomicInteger callCount = new AtomicInteger(0); + try (MockedConstruction mocked = + mockConstruction( + ProcessBuilder.class, + (pb, context) -> { + if (callCount.getAndIncrement() == 0) { + Process failProcess = mock(Process.class); + when(failProcess.getInputStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(failProcess.getErrorStream()) + .thenReturn( + new ByteArrayInputStream("Error: unknown flag: --profile".getBytes())); + when(failProcess.waitFor()).thenReturn(1); + when(pb.start()).thenReturn(failProcess); + } else { + Process successProcess = mock(Process.class); + when(successProcess.getInputStream()) + .thenReturn( + new ByteArrayInputStream(validTokenJson("fallback-token").getBytes())); + when(successProcess.getErrorStream()) + .thenReturn(new ByteArrayInputStream(new byte[0])); + when(successProcess.waitFor()).thenReturn(0); + when(pb.start()).thenReturn(successProcess); + } + })) { + Token token = tokenSource.getToken(); + assertEquals("fallback-token", token.getAccessToken()); + assertEquals(2, mocked.constructed().size()); + } + } } diff --git a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java index bac4a766b..837e3c65b 100644 --- a/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java +++ b/databricks-sdk-java/src/test/java/com/databricks/sdk/core/DatabricksCliCredentialsProviderTest.java @@ -139,4 +139,15 @@ void testBuildHostArgs_UnifiedHostFalse_WithAccountHost() { CLI_PATH, "auth", "token", "--host", ACCOUNT_HOST, "--account-id", ACCOUNT_ID), cmd); } + + // ---- Profile args construction tests ---- + + @Test + void testBuildProfileArgs() { + DatabricksConfig config = new DatabricksConfig().setProfile("my-profile"); + + List cmd = provider.buildProfileArgs(CLI_PATH, config); + + assertEquals(Arrays.asList(CLI_PATH, "auth", "token", "--profile", "my-profile"), cmd); + } }