diff --git a/experimental/aitools/cmd/batch.go b/experimental/aitools/cmd/batch.go index f63599c6e82..cd1890ceb10 100644 --- a/experimental/aitools/cmd/batch.go +++ b/experimental/aitools/cmd/batch.go @@ -12,6 +12,7 @@ import ( "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "golang.org/x/sync/errgroup" ) @@ -58,6 +59,8 @@ type batchResultError struct { // reused across the batch, so callers must ensure each SQL uses only markers // that are covered. func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult { + client := sqlexec.New(api, warehouseID) + pollCtx, pollCancel := context.WithCancel(ctx) defer pollCancel() @@ -101,17 +104,17 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware g.SetLimit(concurrency) for i, sqlStr := range sqls { g.Go(func() error { - results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i) + results[i] = runOneBatchQuery(pollCtx, client, sqlStr, params, statementIDs, i) completed.Add(1) return nil }) } _ = g.Wait() - // pollStatement is a pure helper that returns ctx.Err() on cancellation - // without touching the server. Sweep any not-yet-terminal statements here. + // Poll returns ctx.Err() on cancellation without touching the server. + // Sweep any not-yet-terminal statements here. if pollCtx.Err() != nil { - cancelInFlight(ctx, api, statementIDs, results) + cancelInFlight(ctx, client, statementIDs, results) } return results @@ -119,33 +122,27 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware // runOneBatchQuery submits one SQL, polls to completion, and returns its // batchResult. All errors are encoded into the result; never returns an error. 
-func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult { +func runOneBatchQuery(ctx context.Context, client *sqlexec.Client, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult { start := time.Now() result := batchResult{SQL: sqlStr} - resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: sqlStr, - Parameters: params, - WaitTimeout: "0s", - OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, - }) + stmt, err := client.Submit(ctx, sqlStr, sqlexec.WithParameters(params)) if err != nil { if ctx.Err() != nil { result.State = sql.StatementStateCanceled result.Error = &batchResultError{Message: "submission cancelled"} } else { result.State = sql.StatementStateFailed - result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)} + result.Error = &batchResultError{Message: err.Error()} } result.ElapsedMs = time.Since(start).Milliseconds() return result } - statementIDs[idx] = resp.StatementId - result.StatementID = resp.StatementId + statementIDs[idx] = stmt.ID + result.StatementID = stmt.ID - pollResp, err := pollStatement(ctx, api, resp) + stmt, err = client.Poll(ctx, stmt) if err != nil { if ctx.Err() != nil { result.State = sql.StatementStateCanceled @@ -158,30 +155,26 @@ func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, return result } - if pollResp.Status != nil { - result.State = pollResp.Status.State - } + result.State = stmt.State - if result.State != sql.StatementStateSucceeded { - result.Error = &batchResultError{} - if pollResp.Status != nil && pollResp.Status.Error != nil { - result.Error.Message = pollResp.Status.Error.Message - result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode) - } else { - result.Error.Message = fmt.Sprintf("query reached 
terminal state %s", result.State) + if err := stmt.Err(); err != nil { + result.Error = &batchResultError{Message: err.Error()} + if se, ok := errors.AsType[*sqlexec.StatementError](err); ok { + result.Error.Message = se.Message + result.Error.ErrorCode = string(se.Code) } result.ElapsedMs = time.Since(start).Milliseconds() return result } - result.Columns = extractColumns(pollResp.Manifest) - rows, err := fetchAllRows(ctx, api, pollResp) + res, err := client.Results(ctx, stmt) if err != nil { result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)} result.ElapsedMs = time.Since(start).Milliseconds() return result } - result.Rows = rows + result.Columns = res.Columns + result.Rows = res.Rows result.ElapsedMs = time.Since(start).Milliseconds() return result } @@ -189,7 +182,7 @@ func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, // cancelInFlight sends CancelExecution for every statement that didn't reach // a terminal state server-side before context cancellation. Best effort: errors // are logged at warn but don't fail the batch. -func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) { +func cancelInFlight(ctx context.Context, client *sqlexec.Client, statementIDs []string, results []batchResult) { var cancelled int for i, sid := range statementIDs { if sid == "" { @@ -208,7 +201,7 @@ func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, st // values but drops the cancellation signal so the cancel RPC actually // reaches the warehouse instead of short-circuiting on ctx.Err().
cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) - if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil { + if err := client.Cancel(cancelCtx, sid); err != nil { log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err) } cancel() diff --git a/experimental/aitools/cmd/discover_schema.go b/experimental/aitools/cmd/discover_schema.go index 418ab78e257..28f8bcb730b 100644 --- a/experimental/aitools/cmd/discover_schema.go +++ b/experimental/aitools/cmd/discover_schema.go @@ -19,6 +19,7 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go" dbsql "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -45,8 +46,10 @@ func newSQLGate(limit int) *sqlGate { // run executes a SQL statement asynchronously, polls until terminal, and // records the statement_id so it can be cancelled if the parent context is // cancelled. Acquires a slot from the gate before submitting and releases it -// when polling completes (or the caller's context is cancelled). -func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) { +// when polling completes (or the caller's context is cancelled). On success it +// returns the assembled result; a terminal non-success state is surfaced as the +// CLI-facing query error. +func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*sqlexec.Result, error) { // If the caller cancelled before we even tried, don't enter the select: // when the gate has free slots both cases are ready and Go picks one // pseudo-randomly. 
Without this early-out we'd occasionally submit a @@ -61,28 +64,25 @@ func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, wareho return nil, ctx.Err() } - resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - WaitTimeout: "0s", - OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue, - }) + client := sqlexec.New(w.StatementExecution, warehouseID) + + stmt, err := client.Submit(ctx, statement) if err != nil { - return nil, fmt.Errorf("execute statement: %w", err) + return nil, err } g.mu.Lock() - g.ids = append(g.ids, resp.StatementId) + g.ids = append(g.ids, stmt.ID) g.mu.Unlock() - pollResp, err := pollStatement(ctx, w.StatementExecution, resp) + stmt, err = client.Poll(ctx, stmt) if err != nil { return nil, err } - if err := checkFailedState(pollResp.Status); err != nil { + if err := presentQueryError(stmt.Err()); err != nil { return nil, err } - return pollResp, nil + return client.Results(ctx, stmt) } // trackedIDs returns a snapshot of statement_ids submitted through this gate. @@ -235,9 +235,11 @@ func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInt cmdio.LogString(ctx, "discover-schema cancelled.") return } + // Cancel/Poll/Get don't use the warehouse ID, so an empty one is fine here. + client := sqlexec.New(api, "") for _, id := range ids { cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout) - if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil { + if err := client.Cancel(cancelCtx, id); err != nil { log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err) } cancel() @@ -252,12 +254,12 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl } // 1. 
describe table - get columns and types - descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted) + descResult, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted) if err != nil { return "", fmt.Errorf("describe table: %w", err) } - columns, types := parseDescribeResult(descResp) + columns, types := parseDescribeResult(descResult) if len(columns) == 0 { return "", errors.New("no columns found") } @@ -281,16 +283,16 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s", strings.Join(nullCountExprs, ", "), quoted) - var sampleResp, nullResp *dbsql.StatementResponse + var sampleResult, nullResult *sqlexec.Result var sampleErr, nullErr error g := new(errgroup.Group) g.Go(func() error { - sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL) + sampleResult, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL) return nil }) g.Go(func() error { - nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL) + nullResult, nullErr = gate.run(ctx, w, warehouseID, nullSQL) return nil }) _ = g.Wait() @@ -306,25 +308,21 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr) } else { sb.WriteString("\nSAMPLE DATA:\n") - sb.WriteString(formatTableData(sampleResp)) + sb.WriteString(formatTableData(sampleResult)) } if nullErr != nil { fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr) } else { sb.WriteString("\nNULL COUNTS:\n") - sb.WriteString(formatNullCounts(nullResp, columns)) + sb.WriteString(formatNullCounts(nullResult, columns)) } return sb.String(), nil } -func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) { - if resp.Result == nil || resp.Result.DataArray == nil { - return nil, nil - } - - for _, row := range resp.Result.DataArray { +func parseDescribeResult(result *sqlexec.Result) (columns, types []string) { + for _, 
row := range result.Rows { if len(row) < 2 { continue } @@ -340,20 +338,15 @@ func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string return columns, types } -func formatTableData(resp *dbsql.StatementResponse) string { - if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 { +func formatTableData(result *sqlexec.Result) string { + if len(result.Rows) == 0 { return " (no data)\n" } var sb strings.Builder - var columns []string - if resp.Manifest != nil && resp.Manifest.Schema != nil { - for _, col := range resp.Manifest.Schema.Columns { - columns = append(columns, col.Name) - } - } + columns := result.Columns - for i, row := range resp.Result.DataArray { + for i, row := range result.Rows { fmt.Fprintf(&sb, " Row %d:\n", i+1) for j, val := range row { colName := fmt.Sprintf("col%d", j) @@ -366,12 +359,12 @@ func formatTableData(resp *dbsql.StatementResponse) string { return sb.String() } -func formatNullCounts(resp *dbsql.StatementResponse, columns []string) string { - if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 { +func formatNullCounts(result *sqlexec.Result, columns []string) string { + if len(result.Rows) == 0 { return " (no data)\n" } - row := resp.Result.DataArray[0] + row := result.Rows[0] var sb strings.Builder // first value is total_rows diff --git a/experimental/aitools/cmd/discover_schema_test.go b/experimental/aitools/cmd/discover_schema_test.go index b76004367c9..d0fae4b24c2 100644 --- a/experimental/aitools/cmd/discover_schema_test.go +++ b/experimental/aitools/cmd/discover_schema_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go" mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" dbsql "github.com/databricks/databricks-sdk-go/service/sql" @@ -50,17 +51,17 @@ func TestQuoteTableName(t *testing.T) { } func 
TestParseDescribeResultSkipsMetadataRows(t *testing.T) { - resp := &dbsql.StatementResponse{ - Result: &dbsql.ResultData{DataArray: [][]string{ + result := &sqlexec.Result{ + Rows: [][]string{ {"id", "BIGINT", ""}, {"name", "STRING", ""}, {"# Partition Information", "", ""}, {"region", "STRING", ""}, {"", "STRING", ""}, - }}, + }, } - cols, types := parseDescribeResult(resp) + cols, types := parseDescribeResult(result) assert.Equal(t, []string{"id", "name", "region"}, cols) assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types) } @@ -82,9 +83,9 @@ func TestSQLGateRunPinsOnWaitTimeoutAndRecordsID(t *testing.T) { w := &databricks.WorkspaceClient{StatementExecution: mockAPI} gate := newSQLGate(2) - resp, err := gate.run(ctx, w, "wh-1", "SELECT 1") + result, err := gate.run(ctx, w, "wh-1", "SELECT 1") require.NoError(t, err) - assert.Equal(t, "stmt-1", resp.StatementId) + assert.Equal(t, [][]string{{"1"}}, result.Rows) assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs()) } diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 873932b5158..1cb2c02da99 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -17,17 +17,12 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/log" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) const ( - // pollIntervalInitial is the starting interval between status polls. - pollIntervalInitial = 1 * time.Second - - // pollIntervalMax is the maximum interval between status polls. - pollIntervalMax = 5 * time.Second - // cancelTimeout is how long to wait for server-side cancellation. 
cancelTimeout = 10 * time.Second ) @@ -158,16 +153,19 @@ multi-query mode, the same parameter set is applied to every statement.`, return runBatch(ctx, cmd, w.StatementExecution, wID, sqls, params, concurrency) } - resp, err := executeAndPoll(ctx, w.StatementExecution, wID, sqls[0], params) + client := sqlexec.New(w.StatementExecution, wID) + + stmt, err := executeAndPoll(ctx, client, sqls[0], params) if err != nil { return err } - columns := extractColumns(resp.Manifest) - rows, err := fetchAllRows(ctx, w.StatementExecution, resp) + result, err := client.Results(ctx, stmt) if err != nil { return err } + columns := result.Columns + rows := result.Rows // CSV bypasses the normal output mode selection. if format == sqlcli.OutputCSV { @@ -271,20 +269,14 @@ func resolveWarehouseID(ctx context.Context, w any, flagValue string) (string, e // executeAndPoll submits a SQL statement asynchronously and polls until completion. // It shows a spinner in interactive mode and supports Ctrl+C cancellation. -func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, statement string, params []sql.StatementParameterListItem) (*sql.StatementResponse, error) { +func executeAndPoll(ctx context.Context, client *sqlexec.Client, statement string, params []sql.StatementParameterListItem) (*sqlexec.Statement, error) { // Submit asynchronously to get the statement ID immediately for cancellation. - resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - Parameters: params, - WaitTimeout: "0s", - OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, - }) + stmt, err := client.Submit(ctx, statement, sqlexec.WithParameters(params)) if err != nil { - return nil, fmt.Errorf("execute statement: %w", err) + return nil, err } - statementID := resp.StatementId + statementID := stmt.ID // Set up Ctrl+C: signal cancels the poll context, cleanup is unified below. 
pollCtx, pollCancel := context.WithCancel(ctx) @@ -312,9 +304,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa // reaches the warehouse. cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout) defer cancel() - if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{ - StatementId: statementID, - }); err != nil { + if err := client.Cancel(cancelCtx, statementID); err != nil { log.Warnf(ctx, "Failed to cancel statement %s: %v", statementID, err) } } @@ -339,7 +329,7 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } }() - pollResp, err := pollStatement(pollCtx, api, resp) + stmt, err = client.Poll(pollCtx, stmt) if err != nil { if pollCtx.Err() != nil { cancelStatement() @@ -350,111 +340,34 @@ func executeAndPoll(ctx context.Context, api sql.StatementExecutionInterface, wa } sp.Close() - if err := checkFailedState(pollResp.Status); err != nil { + if err := presentQueryError(stmt.Err()); err != nil { return nil, err } - return pollResp, nil + return stmt, nil } -// pollStatement polls until the statement reaches a terminal state. -// -// On context cancellation it returns the context error WITHOUT cancelling the -// server-side statement. Callers that want server-side cancellation should -// invoke CancelExecution explicitly. -// -// If the input response is already in a terminal state, it is returned without -// further polling. -func pollStatement(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) (*sql.StatementResponse, error) { - if isTerminalState(resp.Status) { - return resp, nil - } - - statementID := resp.StatementId - start := time.Now() - - // Poll with additive backoff: 1s, 2s, 3s, 4s, 5s (capped). 
- interval := pollIntervalInitial - for { - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-time.After(interval): - } - - log.Debugf(ctx, "Polling statement %s: %s elapsed", statementID, time.Since(start).Truncate(time.Second)) - - pollResp, err := api.GetStatementByStatementId(ctx, statementID) - if err != nil { - if ctx.Err() != nil { - return nil, ctx.Err() - } - return nil, fmt.Errorf("poll statement status: %w", err) - } - - if isTerminalState(pollResp.Status) { - return &sql.StatementResponse{ - StatementId: pollResp.StatementId, - Status: pollResp.Status, - Manifest: pollResp.Manifest, - Result: pollResp.Result, - }, nil - } - - interval = min(interval+time.Second, pollIntervalMax) - } -} - -// fetchAllRows collects all result rows, fetching additional chunks if needed. -func fetchAllRows(ctx context.Context, api sql.StatementExecutionInterface, resp *sql.StatementResponse) ([][]string, error) { - if resp.Result == nil { - return nil, nil - } - - rows := append([][]string{}, resp.Result.DataArray...) - - totalChunks := 0 - if resp.Manifest != nil { - totalChunks = resp.Manifest.TotalChunkCount - } - - for chunk := 1; chunk < totalChunks; chunk++ { - log.Debugf(ctx, "Fetching result chunk %d/%d for statement %s", chunk+1, totalChunks, resp.StatementId) - chunkResp, err := api.GetStatementResultChunkNByStatementIdAndChunkIndex(ctx, resp.StatementId, chunk) - if err != nil { - return nil, fmt.Errorf("fetch result chunk %d: %w", chunk, err) - } - rows = append(rows, chunkResp.DataArray...) - } - - return rows, nil -} - -// isTerminalState returns true if the statement has reached a final state. -func isTerminalState(status *sql.StatementStatus) bool { - if status == nil { - return false +// presentQueryError converts the engine's structured statement error into the +// CLI-facing message for the query and discover-schema commands. 
It returns nil +// for a nil error; an error that is not a *sqlexec.StatementError (the engine +// only produces one on terminal non-success states) is returned unchanged. +func presentQueryError(err error) error { + if err == nil { + return nil } - switch status.State { - case sql.StatementStateSucceeded, sql.StatementStateFailed, - sql.StatementStateCanceled, sql.StatementStateClosed: - return true - case sql.StatementStatePending, sql.StatementStateRunning: - return false + se, ok := errors.AsType[*sqlexec.StatementError](err) + if !ok { + return err } - return false -} -// checkFailedState returns an error if the statement is in a non-success terminal state. -func checkFailedState(status *sql.StatementStatus) error { - if status == nil { - return nil - } - switch status.State { + switch se.State { case sql.StatementStateFailed: msg := "query failed" - if status.Error != nil { - msg = fmt.Sprintf("query failed: %s %s", status.Error.ErrorCode, status.Error.Message) - if strings.Contains(status.Error.Message, "UNRESOLVED_MAP_KEY") { + // The engine populates Code only when the backend returned a + // ServiceError; otherwise Message is a synthesized state string we + // don't surface here, matching the original "query failed" fallback. + if se.Code != "" { + msg = fmt.Sprintf("query failed: %s %s", se.Code, se.Message) + if strings.Contains(se.Message, "UNRESOLVED_MAP_KEY") { msg += "\n\nHint: your shell may have stripped quotes from the SQL string. " + "Use single quotes for map keys (e.g. info['key']) or pass the query via --file." } @@ -464,10 +377,9 @@ func checkFailedState(status *sql.StatementStatus) error { return errors.New("query was cancelled") case sql.StatementStateClosed: return errors.New("query was closed before results could be fetched") - case sql.StatementStatePending, sql.StatementStateRunning, sql.StatementStateSucceeded: - return nil + default: + return err } - return nil } // cleanSQL removes surrounding quotes, empty lines, and SQL comments.
diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 50197e12d92..66e70b238ee 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -7,12 +7,12 @@ import ( "path/filepath" "strings" "testing" - "time" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/sqlexec" mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -59,10 +59,10 @@ func TestExecuteAndPollImmediateSuccess(t *testing.T) { Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, }, nil) - resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) + stmt, err := executeAndPoll(ctx, sqlexec.New(mockAPI, "wh-123"), "SELECT 1", nil) require.NoError(t, err) - assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) - assert.Equal(t, "stmt-1", resp.StatementId) + assert.Equal(t, sql.StatementStateSucceeded, stmt.State) + assert.Equal(t, "stmt-1", stmt.ID) } func TestExecuteAndPollPassesParameters(t *testing.T) { @@ -81,7 +81,7 @@ func TestExecuteAndPollPassesParameters(t *testing.T) { Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, }, nil) - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT * FROM t WHERE name = :name AND ts > :since", params) + _, err := executeAndPoll(ctx, sqlexec.New(mockAPI, "wh-123"), "SELECT * FROM t WHERE name = :name AND ts > :since", params) require.NoError(t, err) } @@ -100,7 +100,7 @@ func TestExecuteAndPollImmediateFailure(t *testing.T) { }, }, nil) - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELCT 1", nil) + _, err := executeAndPoll(ctx, sqlexec.New(mockAPI, "wh-123"), "SELCT 1", nil) require.Error(t, err) assert.Contains(t, err.Error(), "SYNTAX_ERROR") assert.Contains(t, err.Error(), 
"syntax error") @@ -129,10 +129,14 @@ func TestExecuteAndPollWithPolling(t *testing.T) { Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, }, nil).Once() - resp, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 42", nil) + client := sqlexec.New(mockAPI, "wh-123") + stmt, err := executeAndPoll(ctx, client, "SELECT 42", nil) require.NoError(t, err) - assert.Equal(t, sql.StatementStateSucceeded, resp.Status.State) - assert.Equal(t, [][]string{{"42"}}, resp.Result.DataArray) + assert.Equal(t, sql.StatementStateSucceeded, stmt.State) + + result, err := client.Results(ctx, stmt) + require.NoError(t, err) + assert.Equal(t, [][]string{{"42"}}, result.Rows) } func TestExecuteAndPollFailsDuringPolling(t *testing.T) { @@ -152,7 +156,7 @@ func TestExecuteAndPollFailsDuringPolling(t *testing.T) { }, }, nil).Once() - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) + _, err := executeAndPoll(ctx, sqlexec.New(mockAPI, "wh-123"), "SELECT 1", nil) require.Error(t, err) assert.Contains(t, err.Error(), "RESOURCE_EXHAUSTED") } @@ -177,115 +181,51 @@ func TestExecuteAndPollCancelledContextCallsCancelExecution(t *testing.T) { cancel() - _, err := executeAndPoll(ctx, mockAPI, "wh-123", "SELECT 1", nil) + _, err := executeAndPoll(ctx, sqlexec.New(mockAPI, "wh-123"), "SELECT 1", nil) require.ErrorIs(t, err, root.ErrAlreadyPrinted) } -func TestPollStatementImmediateTerminal(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - resp := &sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, - Manifest: &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "1"}}}}, - Result: &sql.ResultData{DataArray: [][]string{{"1"}}}, - } - - pollResp, err := pollStatement(ctx, mockAPI, resp) - require.NoError(t, err) - assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) - assert.Equal(t, "stmt-1", 
pollResp.StatementId) -} - -func TestPollStatementTerminalFailureNotErrored(t *testing.T) { - // pollStatement returns the response without erroring on failed terminal - // states; callers (e.g. executeAndPoll) decide what to do via checkFailedState. - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - resp := &sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{ - State: sql.StatementStateFailed, - Error: &sql.ServiceError{ErrorCode: "ERR", Message: "boom"}, - }, - } - - pollResp, err := pollStatement(ctx, mockAPI, resp) - require.NoError(t, err) - assert.Equal(t, sql.StatementStateFailed, pollResp.Status.State) -} - -func TestPollStatementEventualSuccess(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - initial := &sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStatePending}, - } - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateRunning}, - }, nil).Once() - - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").Return(&sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, - Result: &sql.ResultData{DataArray: [][]string{{"42"}}}, - }, nil).Once() - - pollResp, err := pollStatement(ctx, mockAPI, initial) +func TestResolveWarehouseIDWithFlag(t *testing.T) { + ctx := t.Context() + id, err := resolveWarehouseID(ctx, nil, "explicit-id") require.NoError(t, err) - assert.Equal(t, sql.StatementStateSucceeded, pollResp.Status.State) - assert.Equal(t, [][]string{{"42"}}, pollResp.Result.DataArray) + assert.Equal(t, "explicit-id", id) } -func TestPollStatementContextCancellationDoesNotCancelServerSide(t *testing.T) { - // The mock asserts (via t.Cleanup) that no unexpected 
calls are made. - // Specifically, pollStatement must NOT call CancelExecution on context - // cancellation; that is the caller's responsibility. - ctx, cancel := context.WithCancel(cmdio.MockDiscard(t.Context())) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - initial := &sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStatePending}, - } - - cancel() +func TestPresentQueryError(t *testing.T) { + assert.NoError(t, presentQueryError(nil)) - pollResp, err := pollStatement(ctx, mockAPI, initial) - require.ErrorIs(t, err, context.Canceled) - assert.Nil(t, pollResp) -} - -func TestPollStatementGetErrorPropagated(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) + // Non-StatementError passes through unchanged. + plain := errors.New("boom") + assert.Equal(t, plain, presentQueryError(plain)) - initial := &sql.StatementResponse{ - StatementId: "stmt-1", - Status: &sql.StatementStatus{State: sql.StatementStatePending}, - } + err := presentQueryError(&sqlexec.StatementError{ + State: sql.StatementStateFailed, + Code: "ERR", + Message: "bad", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "query failed: ERR bad") - mockAPI.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1"). 
- Return(nil, errors.New("network unreachable")).Once() + err = presentQueryError(&sqlexec.StatementError{State: sql.StatementStateCanceled}) + require.Error(t, err) + assert.Contains(t, err.Error(), "cancelled") - pollResp, err := pollStatement(ctx, mockAPI, initial) + err = presentQueryError(&sqlexec.StatementError{State: sql.StatementStateClosed}) require.Error(t, err) - assert.Contains(t, err.Error(), "poll statement status") - assert.Contains(t, err.Error(), "network unreachable") - assert.Nil(t, pollResp) + assert.Contains(t, err.Error(), "closed") } -func TestResolveWarehouseIDWithFlag(t *testing.T) { - ctx := t.Context() - id, err := resolveWarehouseID(ctx, nil, "explicit-id") - require.NoError(t, err) - assert.Equal(t, "explicit-id", id) +func TestPresentQueryErrorMapKeyHint(t *testing.T) { + err := presentQueryError(&sqlexec.StatementError{ + State: sql.StatementStateFailed, + Code: "BAD_REQUEST", + Message: "[UNRESOLVED_MAP_KEY.WITH_SUGGESTION] Cannot resolve column", + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "Hint:") + assert.Contains(t, err.Error(), "single quotes") + assert.Contains(t, err.Error(), "--file") } func TestSelectQueryOutputMode(t *testing.T) { @@ -347,116 +287,6 @@ func TestSelectQueryOutputMode(t *testing.T) { } } -func TestFetchAllRowsSingleChunk(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - resp := &sql.StatementResponse{ - StatementId: "stmt-1", - Manifest: &sql.ResultManifest{TotalChunkCount: 1}, - Result: &sql.ResultData{DataArray: [][]string{{"1", "alice"}, {"2", "bob"}}}, - } - - rows, err := fetchAllRows(ctx, mockAPI, resp) - require.NoError(t, err) - assert.Equal(t, [][]string{{"1", "alice"}, {"2", "bob"}}, rows) -} - -func TestFetchAllRowsMultiChunk(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - resp := &sql.StatementResponse{ - StatementId: "stmt-1", - Manifest: 
&sql.ResultManifest{TotalChunkCount: 3}, - Result: &sql.ResultData{DataArray: [][]string{{"1", "a"}}}, - } - - mockAPI.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1). - Return(&sql.ResultData{DataArray: [][]string{{"2", "b"}}}, nil).Once() - mockAPI.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 2). - Return(&sql.ResultData{DataArray: [][]string{{"3", "c"}}}, nil).Once() - - rows, err := fetchAllRows(ctx, mockAPI, resp) - require.NoError(t, err) - assert.Equal(t, [][]string{{"1", "a"}, {"2", "b"}, {"3", "c"}}, rows) -} - -func TestFetchAllRowsNilResult(t *testing.T) { - ctx := cmdio.MockDiscard(t.Context()) - mockAPI := mocksql.NewMockStatementExecutionInterface(t) - - resp := &sql.StatementResponse{StatementId: "stmt-1"} - - rows, err := fetchAllRows(ctx, mockAPI, resp) - require.NoError(t, err) - assert.Nil(t, rows) -} - -func TestIsTerminalState(t *testing.T) { - tests := []struct { - state sql.StatementState - terminal bool - }{ - {sql.StatementStateSucceeded, true}, - {sql.StatementStateFailed, true}, - {sql.StatementStateCanceled, true}, - {sql.StatementStateClosed, true}, - {sql.StatementStatePending, false}, - {sql.StatementStateRunning, false}, - } - - for _, tc := range tests { - t.Run(string(tc.state), func(t *testing.T) { - status := &sql.StatementStatus{State: tc.state} - assert.Equal(t, tc.terminal, isTerminalState(status)) - }) - } - - assert.False(t, isTerminalState(nil)) -} - -func TestCheckFailedState(t *testing.T) { - assert.NoError(t, checkFailedState(nil)) - assert.NoError(t, checkFailedState(&sql.StatementStatus{State: sql.StatementStateSucceeded})) - - err := checkFailedState(&sql.StatementStatus{ - State: sql.StatementStateFailed, - Error: &sql.ServiceError{ErrorCode: "ERR", Message: "bad"}, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "ERR") - assert.Contains(t, err.Error(), "bad") - - err = checkFailedState(&sql.StatementStatus{State: 
sql.StatementStateCanceled}) - require.Error(t, err) - assert.Contains(t, err.Error(), "cancelled") - - err = checkFailedState(&sql.StatementStatus{State: sql.StatementStateClosed}) - require.Error(t, err) - assert.Contains(t, err.Error(), "closed") -} - -func TestCheckFailedStateMapKeyHint(t *testing.T) { - err := checkFailedState(&sql.StatementStatus{ - State: sql.StatementStateFailed, - Error: &sql.ServiceError{ - ErrorCode: "BAD_REQUEST", - Message: "[UNRESOLVED_MAP_KEY.WITH_SUGGESTION] Cannot resolve column", - }, - }) - require.Error(t, err) - assert.Contains(t, err.Error(), "Hint:") - assert.Contains(t, err.Error(), "single quotes") - assert.Contains(t, err.Error(), "--file") -} - -func TestPollingConstants(t *testing.T) { - assert.Equal(t, 1*time.Second, pollIntervalInitial) - assert.Equal(t, 5*time.Second, pollIntervalMax) - assert.Equal(t, 10*time.Second, cancelTimeout) -} - // newTestCmd creates a minimal cobra.Command for testing resolveSQLs. func newTestCmd() *cobra.Command { return &cobra.Command{Use: "test"} diff --git a/experimental/aitools/cmd/render.go b/experimental/aitools/cmd/render.go index d0b62926c20..22abd516b0b 100644 --- a/experimental/aitools/cmd/render.go +++ b/experimental/aitools/cmd/render.go @@ -9,7 +9,6 @@ import ( "text/tabwriter" "github.com/databricks/cli/libs/tableview" - "github.com/databricks/databricks-sdk-go/service/sql" ) const ( @@ -17,18 +16,6 @@ const ( maxColumnWidth = 40 ) -// extractColumns returns column names from the query result manifest. -func extractColumns(manifest *sql.ResultManifest) []string { - if manifest == nil || manifest.Schema == nil { - return nil - } - columns := make([]string, len(manifest.Schema.Columns)) - for i, col := range manifest.Schema.Columns { - columns[i] = col.Name - } - return columns -} - // renderBatchJSON writes batch results as a JSON array. The array preserves // input order and includes one object per submitted statement. 
func renderBatchJSON(w io.Writer, results []batchResult) error { diff --git a/experimental/aitools/cmd/render_test.go b/experimental/aitools/cmd/render_test.go index 6d9cf760eef..4e26eff85f3 100644 --- a/experimental/aitools/cmd/render_test.go +++ b/experimental/aitools/cmd/render_test.go @@ -4,40 +4,10 @@ import ( "bytes" "testing" - "github.com/databricks/databricks-sdk-go/service/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestExtractColumns(t *testing.T) { - tests := []struct { - name string - manifest *sql.ResultManifest - want []string - }{ - { - "with columns", - &sql.ResultManifest{Schema: &sql.ResultSchema{ - Columns: []sql.ColumnInfo{{Name: "id"}, {Name: "name"}}, - }}, - []string{"id", "name"}, - }, - {"nil manifest", nil, nil}, - {"nil schema", &sql.ResultManifest{}, nil}, - { - "empty columns", - &sql.ResultManifest{Schema: &sql.ResultSchema{}}, - []string{}, - }, - } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - got := extractColumns(tc.manifest) - assert.Equal(t, tc.want, got) - }) - } -} - func TestRenderJSON(t *testing.T) { var buf bytes.Buffer columns := []string{"id", "name"} diff --git a/experimental/aitools/cmd/statement.go b/experimental/aitools/cmd/statement.go index e1c48a7ddbe..13d2f70e316 100644 --- a/experimental/aitools/cmd/statement.go +++ b/experimental/aitools/cmd/statement.go @@ -2,9 +2,11 @@ package aitools import ( "encoding/json" + "errors" "fmt" "io" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) @@ -31,25 +33,26 @@ func renderStatementInfo(w io.Writer, info statementInfo) error { return nil } -// statementErrorFromStatus builds a batchResultError for any terminal non-success -// state (FAILED, CANCELED, CLOSED), populating it from the server's ServiceError -// when available and synthesizing a message when it isn't. Returns nil for -// SUCCEEDED, non-terminal states, and nil status. 
The synthesized fallback -// matters because the Statements API can hand back a non-success terminal state -// with `Error == nil`, and skill consumers should be able to branch on +// statementError converts the engine's structured statement error into the +// batchResultError shape emitted in JSON output. It returns nil only for a nil +// error. For a *sqlexec.StatementError (the engine's typed error for terminal +// non-success states) the Message and Code are surfaced directly rather than the +// formatted Error() string; any other error is reported by its Error() text. The +// engine synthesizes a "statement reached terminal state <STATE>" message when the +// backend reports no ServiceError, so skill consumers can branch on // `error == null` alone instead of inspecting `state`. -func statementErrorFromStatus(status *sql.StatementStatus) *batchResultError { - if status == nil || !isTerminalState(status) || status.State == sql.StatementStateSucceeded { +func statementError(err error) *batchResultError { + if err == nil { return nil } - out := &batchResultError{} - if status.Error != nil { - out.Message = status.Error.Message - out.ErrorCode = string(status.Error.ErrorCode) - } else { - out.Message = fmt.Sprintf("statement reached terminal state %s", status.State) + se, ok := errors.AsType[*sqlexec.StatementError](err) + if !ok { + return &batchResultError{Message: err.Error()} + } + return &batchResultError{ + Message: se.Message, + ErrorCode: string(se.Code), } - return out } func newStatementCmd() *cobra.Command { diff --git a/experimental/aitools/cmd/statement_cancel.go b/experimental/aitools/cmd/statement_cancel.go index 1774b7abe6a..912f20d83d6 100644 --- a/experimental/aitools/cmd/statement_cancel.go +++ b/experimental/aitools/cmd/statement_cancel.go @@ -2,10 +2,10 @@ package aitools import ( "context" - "fmt" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/sqlexec" 
"github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) @@ -41,10 +41,10 @@ the server-side state if you need certainty.`, // CancelExecution returns no body; the actual server-side state is verified // asynchronously. Use 'statement status' to confirm if certainty is required. func cancelStatementExecution(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { - if err := api.CancelExecution(ctx, sql.CancelExecutionRequest{ - StatementId: statementID, - }); err != nil { - return statementInfo{}, fmt.Errorf("cancel statement: %w", err) + // Cancel doesn't use the warehouse ID. + client := sqlexec.New(api, "") + if err := client.Cancel(ctx, statementID); err != nil { + return statementInfo{}, err } return statementInfo{ StatementID: statementID, diff --git a/experimental/aitools/cmd/statement_get.go b/experimental/aitools/cmd/statement_get.go index 617b5c274dd..fe011464a0c 100644 --- a/experimental/aitools/cmd/statement_get.go +++ b/experimental/aitools/cmd/statement_get.go @@ -6,6 +6,7 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) @@ -54,43 +55,48 @@ invoked the synchronous path.)`, // getStatementResult polls a statement until terminal, then assembles a // statementInfo with rows on success or an error object on failure. // -// Context cancellation propagates from pollStatement WITHOUT cancelling the +// Context cancellation propagates from the poll WITHOUT cancelling the // server-side statement (intentional: 'get' is a poll-only operation; use // 'cancel' to terminate explicitly). func getStatementResult(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { - // Fetch the current state first so pollStatement can short-circuit if - // the statement is already terminal. 
- resp, err := api.GetStatementByStatementId(ctx, statementID) + // Get/Poll don't use the warehouse ID. + client := sqlexec.New(api, "") + + // Fetch the current state first so Poll can short-circuit if the statement + // is already terminal. + stmt, err := client.Get(ctx, statementID) if err != nil { - return statementInfo{}, fmt.Errorf("get statement: %w", err) + return statementInfo{}, err } - pollResp, err := pollStatement(ctx, api, resp) + stmt, err = client.Poll(ctx, stmt) if err != nil { return statementInfo{}, err } - info := statementInfo{StatementID: pollResp.StatementId} - if pollResp.Status != nil { - info.State = pollResp.Status.State + info := statementInfo{ + StatementID: stmt.ID, + State: stmt.State, + Error: statementError(stmt.Err()), } - info.Error = statementErrorFromStatus(pollResp.Status) if info.State == sql.StatementStateSucceeded { - info.Columns = extractColumns(pollResp.Manifest) - rows, err := fetchAllRows(ctx, api, pollResp) + result, err := client.Results(ctx, stmt) if err != nil { // The query succeeded server-side but a later chunk fetch failed // (network blip, throttling, transient 5xx). Surface this as a // structured error on the same statementInfo so the caller still - // gets a parseable JSON response with the statement_id; RunE then - // signals exit-non-zero based on info.Error. + // gets a parseable JSON response with the statement_id and the + // column metadata (known from the manifest before any chunk fetch); + // RunE then signals exit-non-zero based on info.Error. 
+ info.Columns = stmt.Columns() info.Error = &batchResultError{ Message: fmt.Sprintf("fetch result rows: %v", err), } return info, nil } - info.Rows = rows + info.Columns = result.Columns + info.Rows = result.Rows } return info, nil } diff --git a/experimental/aitools/cmd/statement_status.go b/experimental/aitools/cmd/statement_status.go index 9981f49aa63..91ea9c36408 100644 --- a/experimental/aitools/cmd/statement_status.go +++ b/experimental/aitools/cmd/statement_status.go @@ -2,10 +2,10 @@ package aitools import ( "context" - "fmt" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) @@ -38,15 +38,16 @@ without blocking. For a blocking poll-until-terminal call, use // getStatementStatus performs a single GET against the Statements API, no polling. func getStatementStatus(ctx context.Context, api sql.StatementExecutionInterface, statementID string) (statementInfo, error) { - resp, err := api.GetStatementByStatementId(ctx, statementID) + // Get doesn't use the warehouse ID. 
+ client := sqlexec.New(api, "") + stmt, err := client.Get(ctx, statementID) if err != nil { - return statementInfo{}, fmt.Errorf("get statement: %w", err) + return statementInfo{}, err } - info := statementInfo{StatementID: resp.StatementId} - if resp.Status != nil { - info.State = resp.Status.State - } - info.Error = statementErrorFromStatus(resp.Status) - return info, nil + return statementInfo{ + StatementID: stmt.ID, + State: stmt.State, + Error: statementError(stmt.Err()), + }, nil } diff --git a/experimental/aitools/cmd/statement_submit.go b/experimental/aitools/cmd/statement_submit.go index c578590b50f..716f6550ee0 100644 --- a/experimental/aitools/cmd/statement_submit.go +++ b/experimental/aitools/cmd/statement_submit.go @@ -3,10 +3,10 @@ package aitools import ( "context" "errors" - "fmt" "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/sqlexec" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" ) @@ -86,23 +86,15 @@ bind values.`, // submitStatement issues an asynchronous ExecuteStatement and returns the handle. 
func submitStatement(ctx context.Context, api sql.StatementExecutionInterface, statement, warehouseID string, params []sql.StatementParameterListItem) (statementInfo, error) { - resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{ - WarehouseId: warehouseID, - Statement: statement, - Parameters: params, - WaitTimeout: "0s", - OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, - }) + client := sqlexec.New(api, warehouseID) + stmt, err := client.Submit(ctx, statement, sqlexec.WithParameters(params)) if err != nil { - return statementInfo{}, fmt.Errorf("execute statement: %w", err) + return statementInfo{}, err } - info := statementInfo{ - StatementID: resp.StatementId, + return statementInfo{ + StatementID: stmt.ID, + State: stmt.State, WarehouseID: warehouseID, - } - if resp.Status != nil { - info.State = resp.Status.State - } - return info, nil + }, nil } diff --git a/experimental/aitools/cmd/statement_test.go b/experimental/aitools/cmd/statement_test.go index ff1e9fd4b25..e6905b3f20d 100644 --- a/experimental/aitools/cmd/statement_test.go +++ b/experimental/aitools/cmd/statement_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/sqlexec" mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/stretchr/testify/assert" @@ -161,10 +162,12 @@ func TestGetStatementResultChunkFetchFailureRendersPartialInfo(t *testing.T) { info, err := getStatementResult(ctx, mockAPI, "stmt-1") require.NoError(t, err) assert.Equal(t, sql.StatementStateSucceeded, info.State) - assert.Equal(t, []string{"n"}, info.Columns, "columns from the initial response are still surfaced") require.NotNil(t, info.Error) assert.Contains(t, info.Error.Message, "fetch result rows") assert.Contains(t, info.Error.Message, "network blip") + // Column metadata is known from the manifest before the failed chunk fetch, + // so it 
is still surfaced alongside the error. + assert.Equal(t, []string{"n"}, info.Columns) } func TestGetStatementStatusSinglePoll(t *testing.T) { @@ -303,63 +306,45 @@ func TestStatementSubmitRejectsMultipleSQLsBeforeWorkspaceClient(t *testing.T) { assert.Contains(t, err.Error(), "exactly one") } -func TestStatementErrorFromStatus(t *testing.T) { +func TestStatementError(t *testing.T) { tests := []struct { name string - status *sql.StatementStatus + err error wantNil bool wantMsg string wantCode string }{ { - name: "nil status", - status: nil, - wantNil: true, - }, - { - name: "succeeded never produces an error", - status: &sql.StatementStatus{State: sql.StatementStateSucceeded}, + name: "nil error", + err: nil, wantNil: true, }, { - name: "running is not terminal", - status: &sql.StatementStatus{State: sql.StatementStateRunning}, - wantNil: true, - }, - { - name: "pending is not terminal", - status: &sql.StatementStatus{State: sql.StatementStatePending}, - wantNil: true, - }, - { - name: "failed with backend error preserves both fields", - status: &sql.StatementStatus{ - State: sql.StatementStateFailed, - Error: &sql.ServiceError{ErrorCode: "SYNTAX_ERROR", Message: "near 'bad'"}, - }, + name: "failed with backend error preserves both fields", + err: &sqlexec.StatementError{State: sql.StatementStateFailed, Code: "SYNTAX_ERROR", Message: "near 'bad'"}, wantMsg: "near 'bad'", wantCode: "SYNTAX_ERROR", }, { - name: "failed without backend error synthesizes message", - status: &sql.StatementStatus{State: sql.StatementStateFailed}, + name: "failed without backend error surfaces synthesized message", + err: &sqlexec.StatementError{State: sql.StatementStateFailed, Message: "statement reached terminal state FAILED"}, wantMsg: "statement reached terminal state FAILED", }, { - name: "canceled without backend error synthesizes message", - status: &sql.StatementStatus{State: sql.StatementStateCanceled}, + name: "canceled without backend error surfaces synthesized message", + err: 
&sqlexec.StatementError{State: sql.StatementStateCanceled, Message: "statement reached terminal state CANCELED"}, wantMsg: "statement reached terminal state CANCELED", }, { - name: "closed without backend error synthesizes message", - status: &sql.StatementStatus{State: sql.StatementStateClosed}, + name: "closed without backend error surfaces synthesized message", + err: &sqlexec.StatementError{State: sql.StatementStateClosed, Message: "statement reached terminal state CLOSED"}, wantMsg: "statement reached terminal state CLOSED", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got := statementErrorFromStatus(tc.status) + got := statementError(tc.err) if tc.wantNil { assert.Nil(t, got) return diff --git a/integration/libs/sqlexec/sqlexec_test.go b/integration/libs/sqlexec/sqlexec_test.go new file mode 100644 index 00000000000..538bb673806 --- /dev/null +++ b/integration/libs/sqlexec/sqlexec_test.go @@ -0,0 +1,97 @@ +package sqlexec_test + +import ( + "testing" + + "github.com/databricks/cli/integration/internal/acc" + "github.com/databricks/cli/internal/testutil" + "github.com/databricks/cli/libs/sqlexec" + "github.com/databricks/databricks-sdk-go/service/sql" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// These tests exercise libs/sqlexec against a live SQL warehouse to confirm the +// engine keeps working against the real Statement Execution API. They run in the +// nightly integration suite and skip without CLOUD_ENV / TEST_DEFAULT_WAREHOUSE_ID. +// +// They are limited to UC workspaces: gating on TEST_METASTORE_ID (set only on the +// *-ucws environments) keeps them off the non-UC workspaces, whose shared classic +// warehouses are unreliable to start. CLOUD_ENV can't make this distinction +// because azure-prod-ucws and the non-UC azure-prod both report CLOUD_ENV=azure. 
+func newClient(t *testing.T) (*acc.WorkspaceT, *sqlexec.Client) { + t.Helper() + testutil.GetEnvOrSkipTest(t, "TEST_METASTORE_ID") + _, wt := acc.WorkspaceTest(t) + warehouseID := testutil.GetEnvOrSkipTest(t, "TEST_DEFAULT_WAREHOUSE_ID") + return wt, sqlexec.New(wt.W.StatementExecution, warehouseID) +} + +func TestSQLExecScalar(t *testing.T) { + wt, c := newClient(t) + got, err := c.ExecuteScalar(wt.Context(), "SELECT 1") + require.NoError(t, err) + assert.Equal(t, "1", got) +} + +func TestSQLExecColumnsAndRows(t *testing.T) { + wt, c := newClient(t) + r, err := c.Execute(wt.Context(), "SELECT id, id * 2 AS doubled FROM range(3) ORDER BY id") + require.NoError(t, err) + assert.Equal(t, []string{"id", "doubled"}, r.Columns) + assert.Equal(t, [][]string{{"0", "0"}, {"1", "2"}, {"2", "4"}}, r.Rows) +} + +func TestSQLExecParameters(t *testing.T) { + wt, c := newClient(t) + r, err := c.Execute(wt.Context(), "SELECT :n AS n, :s AS s", sqlexec.WithParameters([]sql.StatementParameterListItem{ + {Name: "n", Type: "INT", Value: "42"}, + {Name: "s", Value: "hello"}, + })) + require.NoError(t, err) + assert.Equal(t, [][]string{{"42", "hello"}}, r.Rows) +} + +func TestSQLExecNullParameter(t *testing.T) { + wt, c := newClient(t) + // An empty value is sent as SQL NULL (StatementParameterListItem.Value is omitempty). 
+ got, err := c.ExecuteScalar(wt.Context(), "SELECT :maybe IS NULL", sqlexec.WithParameters([]sql.StatementParameterListItem{ + {Name: "maybe", Value: ""}, + })) + require.NoError(t, err) + assert.Equal(t, "true", got) +} + +func TestSQLExecFailedStatement(t *testing.T) { + wt, c := newClient(t) + _, err := c.Execute(wt.Context(), "SELECT * FROM a_table_that_does_not_exist_zzz") + var se *sqlexec.StatementError + require.ErrorAs(t, err, &se) + assert.Equal(t, sql.StatementStateFailed, se.State) + assert.NotEmpty(t, se.Code) + assert.NotEmpty(t, se.Message) +} + +func TestSQLExecSubmitAndCancel(t *testing.T) { + wt, c := newClient(t) + ctx := wt.Context() + // Submit returns immediately with the statement ID so we can cancel it. + stmt, err := c.Submit(ctx, "SELECT count(*) FROM range(100000000000)") + require.NoError(t, err) + require.NotEmpty(t, stmt.ID) + + require.NoError(t, c.Cancel(ctx, stmt.ID)) + + // The cancel is best-effort; poll to a terminal state and accept either a + // CANCELED statement or one that finished before the cancel landed. + stmt, err = c.Poll(ctx, stmt) + require.NoError(t, err) + if stmt.State == sql.StatementStateCanceled { + require.Error(t, stmt.Err()) + var se *sqlexec.StatementError + require.ErrorAs(t, stmt.Err(), &se) + assert.Equal(t, sql.StatementStateCanceled, se.State) + } else { + assert.Equal(t, sql.StatementStateSucceeded, stmt.State) + } +} diff --git a/libs/sqlexec/sqlexec.go b/libs/sqlexec/sqlexec.go new file mode 100644 index 00000000000..6c3f534e0f7 --- /dev/null +++ b/libs/sqlexec/sqlexec.go @@ -0,0 +1,349 @@ +// Package sqlexec runs SQL statements through the Databricks SQL Statement +// Execution API. It is a general-purpose, non-interactive executor: it submits +// statements, polls them to a terminal state, assembles paginated results, and +// turns failures into typed errors. 
Programmatic callers such as bundle deploy +// resources (metric views, which have no REST API and are managed via SQL DDL) +// and the experimental aitools query commands share this engine instead of each +// re-implementing the submit/poll/fetch loop. +// +// The engine speaks only the INLINE disposition with the JSON_ARRAY format, +// which the API caps at 25 MiB per result set. That covers every caller today. +// EXTERNAL_LINKS (presigned downloads for larger results, optionally Arrow or +// CSV) is a separate concern and intentionally not implemented here. +// +// A Client holds no mutable state and is safe for concurrent use; aitools fans +// many statements out through a single Client. +package sqlexec + +import ( + "context" + "fmt" + "time" + + "github.com/databricks/databricks-sdk-go/service/sql" +) + +const ( + // asyncWaitTimeout is the wait applied to Submit. "0s" makes ExecuteStatement + // return immediately with a statement ID (state PENDING) so callers can wire + // up cancellation before the statement has a chance to finish. + asyncWaitTimeout = "0s" + + // defaultWaitTimeout is the synchronous wait Execute applies. Within this + // window ExecuteStatement blocks server-side, so fast statements (most DDL) + // return in a single round-trip and never enter the poll loop. The API + // accepts "0s" or 5s–50s; 10s keeps interactive deploys responsive while + // absorbing typical warehouse latency. + defaultWaitTimeout = "10s" + + // defaultPollInterval and defaultPollMax bound the additive backoff Poll + // applies between GetStatement calls while a statement is PENDING or RUNNING. + defaultPollInterval = 1 * time.Second + defaultPollMax = 5 * time.Second + + // pollIntervalStep is how much the poll interval grows after each poll. + pollIntervalStep = 1 * time.Second +) + +// Client executes SQL statements against a single SQL warehouse. 
+type Client struct { + api sql.StatementExecutionInterface + warehouseID string + + waitTimeout string + pollInterval time.Duration + pollMax time.Duration +} + +// Option configures a Client. +type Option func(*Client) + +// WithWaitTimeout sets the synchronous wait Execute applies before falling back +// to polling. Must be "0s" or "5s".."50s" per the API; values outside that range +// are rejected by the backend at submit time. +func WithWaitTimeout(d string) Option { + return func(c *Client) { c.waitTimeout = d } +} + +// WithPollInterval sets the initial and maximum delay between status polls. The +// delay grows additively from initial to max. Tests use a small interval to +// avoid real sleeps. +func WithPollInterval(initial, max time.Duration) Option { + return func(c *Client) { + c.pollInterval = initial + c.pollMax = max + } +} + +// New returns a Client that runs statements on warehouseID via api. +func New(api sql.StatementExecutionInterface, warehouseID string, opts ...Option) *Client { + c := &Client{ + api: api, + warehouseID: warehouseID, + waitTimeout: defaultWaitTimeout, + pollInterval: defaultPollInterval, + pollMax: defaultPollMax, + } + for _, opt := range opts { + opt(c) + } + return c +} + +// RequestOption mutates the ExecuteStatementRequest for a single submission. +type RequestOption func(*sql.ExecuteStatementRequest) + +// WithParameters binds named parameters (`:name` markers) on the statement. +// Parameter binding is server-side, so values need no manual quoting or +// escaping; prefer it over string interpolation. +func WithParameters(params []sql.StatementParameterListItem) RequestOption { + return func(req *sql.ExecuteStatementRequest) { req.Parameters = params } +} + +// Statement is a handle to a submitted statement and its latest known state. +type Statement struct { + ID string + State sql.StatementState + + // resp is the most recent response observed for the statement. 
It carries the + // manifest and first result chunk needed by Results. + resp *sql.StatementResponse +} + +// newStatement wraps a response into a Statement handle. +func newStatement(resp *sql.StatementResponse) *Statement { + s := &Statement{ID: resp.StatementId, resp: resp} + if resp.Status != nil { + s.State = resp.Status.State + } + return s +} + +// Err returns a *StatementError if the statement is in a terminal non-success +// state (FAILED, CANCELED, CLOSED) or carries no status, and nil otherwise. +// Calling Err on a still-running statement returns nil. +func (s *Statement) Err() error { + status := s.resp.Status + if status == nil { + // The API always populates status; a nil here means a malformed or + // partial response, which we surface rather than silently treat as empty. + return &StatementError{Message: "statement response had no status"} + } + switch status.State { + case sql.StatementStateFailed, sql.StatementStateCanceled, sql.StatementStateClosed: + return newStatementError(status) + default: + // SUCCEEDED, PENDING, RUNNING: no error. + return nil + } +} + +// Columns returns the result column names from the statement's manifest. They +// are known once the statement has succeeded, before any row chunk is fetched, +// so callers can still report column metadata when a later chunk fetch fails. +func (s *Statement) Columns() []string { + return columns(s.resp.Manifest) +} + +// StatementError describes a statement that reached a terminal non-success +// state. A FAILED status normally carries a backend error code and message (plus +// an SQLSTATE); when the status has no error object, as is usual for CANCELED +// and CLOSED, the message is synthesized from the state. +type StatementError struct { + State sql.StatementState + Code sql.ServiceErrorCode + Message string + SQLState string +} + +// newStatementError builds a StatementError from a terminal non-success status. 
+func newStatementError(status *sql.StatementStatus) *StatementError { + e := &StatementError{State: status.State, SQLState: status.SqlState} + if status.Error != nil { + e.Code = status.Error.ErrorCode + e.Message = status.Error.Message + } else { + e.Message = fmt.Sprintf("statement reached terminal state %s", status.State) + } + return e +} + +func (e *StatementError) Error() string { + if e.Code != "" { + return fmt.Sprintf("statement failed: %s: %s", e.Code, e.Message) + } + return e.Message +} + +// Result is the assembled result set of a statement: column names and every row +// across all chunks. Statements that return no result set (DDL) yield empty +// Columns and Rows. +type Result struct { + Columns []string + Rows [][]string +} + +// Scalar returns the top-left cell of the result, or "" when there are no rows. +func (r *Result) Scalar() string { + if len(r.Rows) == 0 || len(r.Rows[0]) == 0 { + return "" + } + return r.Rows[0][0] +} + +// Submit starts a statement asynchronously and returns immediately with its +// handle (state PENDING). Use Poll to wait for completion and Cancel to stop it. +func (c *Client) Submit(ctx context.Context, statement string, opts ...RequestOption) (*Statement, error) { + return c.submit(ctx, statement, asyncWaitTimeout, opts) +} + +// submit issues ExecuteStatement with the given synchronous wait timeout. 
+func (c *Client) submit(ctx context.Context, statement, waitTimeout string, opts []RequestOption) (*Statement, error) { + req := sql.ExecuteStatementRequest{ + WarehouseId: c.warehouseID, + Statement: statement, + WaitTimeout: waitTimeout, + OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue, + Disposition: sql.DispositionInline, + Format: sql.FormatJsonArray, + } + for _, opt := range opts { + opt(&req) + } + + resp, err := c.api.ExecuteStatement(ctx, req) + if err != nil { + return nil, fmt.Errorf("execute statement: %w", err) + } + return newStatement(resp), nil +} + +// Poll waits for a statement to reach a terminal state, returning the updated +// handle. A statement that is already terminal is returned without an API call. +// +// On context cancellation Poll returns the context error WITHOUT cancelling the +// statement server-side; callers that want server-side cancellation must call +// Cancel explicitly. This keeps Poll usable both for "stop watching" (statement +// get) and "stop the query" (interactive query) callers. +func (c *Client) Poll(ctx context.Context, s *Statement) (*Statement, error) { + interval := c.pollInterval + for isPending(s.resp.Status) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(interval): + } + + resp, err := c.api.GetStatementByStatementId(ctx, s.ID) + if err != nil { + if ctx.Err() != nil { + return nil, ctx.Err() + } + return nil, fmt.Errorf("poll statement %s: %w", s.ID, err) + } + s = newStatement(resp) + interval = min(interval+pollIntervalStep, c.pollMax) + } + return s, nil +} + +// Get returns the current state of a statement with a single GET, no polling. 
+func (c *Client) Get(ctx context.Context, statementID string) (*Statement, error) { + resp, err := c.api.GetStatementByStatementId(ctx, statementID) + if err != nil { + return nil, fmt.Errorf("get statement %s: %w", statementID, err) + } + return newStatement(resp), nil +} + +// Cancel requests server-side cancellation of a statement. A successful return +// only means the request was accepted; the statement may have already finished. +// Poll or Get to observe the resulting state. +func (c *Client) Cancel(ctx context.Context, statementID string) error { + if err := c.api.CancelExecution(ctx, sql.CancelExecutionRequest{StatementId: statementID}); err != nil { + return fmt.Errorf("cancel statement %s: %w", statementID, err) + } + return nil +} + +// Results assembles the full result set of a statement, fetching every chunk +// beyond the first that the manifest reports. It does not check the statement +// state; call Err first (or use Execute) to reject non-success statements. +func (c *Client) Results(ctx context.Context, s *Statement) (*Result, error) { + r := &Result{Columns: columns(s.resp.Manifest)} + if s.resp.Result == nil { + return r, nil + } + r.Rows = append(r.Rows, s.resp.Result.DataArray...) + + total := 0 + if s.resp.Manifest != nil { + total = s.resp.Manifest.TotalChunkCount + } + // Chunk 0 is already in resp.Result; fetch the rest in order. + for chunk := 1; chunk < total; chunk++ { + data, err := c.api.GetStatementResultChunkNByStatementIdAndChunkIndex(ctx, s.ID, chunk) + if err != nil { + return nil, fmt.Errorf("fetch result chunk %d of statement %s: %w", chunk, s.ID, err) + } + r.Rows = append(r.Rows, data.DataArray...) + } + return r, nil +} + +// Execute submits a statement synchronously, polls it to a terminal state, and +// returns its assembled result. A terminal non-success state is returned as a +// *StatementError. 
+func (c *Client) Execute(ctx context.Context, statement string, opts ...RequestOption) (*Result, error) { + s, err := c.submit(ctx, statement, c.waitTimeout, opts) + if err != nil { + return nil, err + } + s, err = c.Poll(ctx, s) + if err != nil { + return nil, err + } + if err := s.Err(); err != nil { + return nil, err + } + return c.Results(ctx, s) +} + +// ExecuteScalar runs a statement returning at most one row and one column and +// returns that cell, or "" when there are no rows. +func (c *Client) ExecuteScalar(ctx context.Context, statement string, opts ...RequestOption) (string, error) { + r, err := c.Execute(ctx, statement, opts...) + if err != nil { + return "", err + } + return r.Scalar(), nil +} + +// isPending reports whether a statement is still PENDING or RUNNING, i.e. Poll +// should keep waiting. A nil status (only possible from a malformed response) is +// not pending, so Poll stops and Err surfaces the missing status to the caller +// rather than looping forever. +func isPending(status *sql.StatementStatus) bool { + if status == nil { + return false + } + switch status.State { + case sql.StatementStatePending, sql.StatementStateRunning: + return true + default: + return false + } +} + +// columns returns the column names from a result manifest, or nil when the +// manifest carries no schema (e.g. a statement with no result set). 
+func columns(manifest *sql.ResultManifest) []string {
+	if manifest == nil || manifest.Schema == nil {
+		return nil
+	}
+	// One name per schema column, in schema order.
+	out := make([]string, len(manifest.Schema.Columns))
+	for i, col := range manifest.Schema.Columns {
+		out[i] = col.Name
+	}
+	return out
+}
diff --git a/libs/sqlexec/sqlexec_http_test.go b/libs/sqlexec/sqlexec_http_test.go
new file mode 100644
index 00000000000..83328a74687
--- /dev/null
+++ b/libs/sqlexec/sqlexec_http_test.go
@@ -0,0 +1,146 @@
+package sqlexec_test
+
+import (
+	"testing"
+	"time"
+
+	"github.com/databricks/cli/libs/sqlexec"
+	"github.com/databricks/cli/libs/testserver"
+	"github.com/databricks/databricks-sdk-go"
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// These tests drive the engine through a real SDK client over HTTP against the
+// in-process testserver, with the statement-execution endpoints programmed per
+// test. Unlike the mock-interface unit tests they exercise the full
+// request/response JSON serialization, and unlike the integration tests they are
+// hermetic and run on every PR without a warehouse.
+//
+// httpClient builds a workspace client pointed at the given test server and
+// wraps its StatementExecution API in a sqlexec.Client for warehouse "wh-1".
+func httpClient(t *testing.T, server *testserver.Server) *sqlexec.Client {
+	t.Helper()
+	w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"})
+	require.NoError(t, err)
+	// Tiny poll interval so the polling tests don't sleep.
+	return sqlexec.New(w.StatementExecution, "wh-1", sqlexec.WithPollInterval(time.Millisecond, time.Millisecond))
+}
+
+// TestHTTPExecuteSuccess checks that a submit response that is already
+// SUCCEEDED is decoded into the expected column names and rows.
+func TestHTTPExecuteSuccess(t *testing.T) {
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.0/sql/statements", func(testserver.Request) any {
+		return sql.StatementResponse{
+			StatementId: "s1",
+			Status:      &sql.StatementStatus{State: sql.StatementStateSucceeded},
+			Manifest:    &sql.ResultManifest{Schema: &sql.ResultSchema{Columns: []sql.ColumnInfo{{Name: "a"}, {Name: "b"}}}, TotalChunkCount: 1},
+			Result:      &sql.ResultData{DataArray: [][]string{{"1", "2"}}},
+		}
+	})
+
+	r, err := httpClient(t, server).Execute(t.Context(), "SELECT 1 AS a, 2 AS b")
+	require.NoError(t, err)
+	assert.Equal(t, []string{"a", "b"}, r.Columns)
+	assert.Equal(t, [][]string{{"1", "2"}}, r.Rows)
+}
+
+// TestHTTPExecutePolls checks that a PENDING submit response makes the client
+// hit the GET endpoint repeatedly until the statement reports SUCCEEDED.
+func TestHTTPExecutePolls(t *testing.T) {
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.0/sql/statements", func(testserver.Request) any {
+		return sql.StatementResponse{StatementId: "s1", Status: &sql.StatementStatus{State: sql.StatementStatePending}}
+	})
+	// Counts GET calls: the first answers RUNNING, later ones answer SUCCEEDED.
+	polls := 0
+	server.Handle("GET", "/api/2.0/sql/statements/{statement_id}", func(req testserver.Request) any {
+		assert.Equal(t, "s1", req.Vars["statement_id"])
+		polls++
+		if polls < 2 {
+			return sql.StatementResponse{StatementId: "s1", Status: &sql.StatementStatus{State: sql.StatementStateRunning}}
+		}
+		return sql.StatementResponse{
+			StatementId: "s1",
+			Status:      &sql.StatementStatus{State: sql.StatementStateSucceeded},
+			Result:      &sql.ResultData{DataArray: [][]string{{"done"}}},
+		}
+	})
+
+	got, err := httpClient(t, server).ExecuteScalar(t.Context(), "SELECT 1")
+	require.NoError(t, err)
+	assert.Equal(t, "done", got)
+	assert.GreaterOrEqual(t, polls, 2)
+}
+
+// TestHTTPExecutePaginatesChunks checks that when the manifest reports three
+// chunks, the rows of the inline chunk and of the two fetched chunks are
+// returned together in chunk order.
+func TestHTTPExecutePaginatesChunks(t *testing.T) {
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.0/sql/statements", func(testserver.Request) any {
+		return sql.StatementResponse{
+			StatementId: "s1",
+			Status:      &sql.StatementStatus{State: sql.StatementStateSucceeded},
+			Manifest:    &sql.ResultManifest{TotalChunkCount: 3},
+			Result:      &sql.ResultData{DataArray: [][]string{{"0"}}},
+		}
+	})
+	// Each chunk's single row is the requested chunk index, so the final row
+	// order shows which chunks were fetched and in what order.
+	server.Handle("GET", "/api/2.0/sql/statements/{statement_id}/result/chunks/{chunk_index}", func(req testserver.Request) any {
+		return sql.ResultData{DataArray: [][]string{{req.Vars["chunk_index"]}}}
+	})
+
+	r, err := httpClient(t, server).Execute(t.Context(), "SELECT * FROM big")
+	require.NoError(t, err)
+	assert.Equal(t, [][]string{{"0"}, {"1"}, {"2"}}, r.Rows)
+}
+
+// TestHTTPExecuteFailedReturns200 checks that a FAILED statement delivered in
+// a normal response body is surfaced as a *sqlexec.StatementError carrying the
+// state, error code, message and SQL state.
+func TestHTTPExecuteFailedReturns200(t *testing.T) {
+	server := testserver.New(t)
+	// A failed statement comes back as HTTP 200 with state=FAILED, not an HTTP
+	// error; the engine must inspect the body and surface a *StatementError.
+	server.Handle("POST", "/api/2.0/sql/statements", func(testserver.Request) any {
+		return sql.StatementResponse{
+			StatementId: "s1",
+			Status: &sql.StatementStatus{
+				State:    sql.StatementStateFailed,
+				SqlState: "42P01",
+				Error:    &sql.ServiceError{ErrorCode: sql.ServiceErrorCodeBadRequest, Message: "no such table"},
+			},
+		}
+	})
+
+	_, err := httpClient(t, server).Execute(t.Context(), "SELECT * FROM nope")
+	var se *sqlexec.StatementError
+	require.ErrorAs(t, err, &se)
+	assert.Equal(t, sql.StatementStateFailed, se.State)
+	assert.Equal(t, sql.ServiceErrorCodeBadRequest, se.Code)
+	assert.Equal(t, "no such table", se.Message)
+	assert.Equal(t, "42P01", se.SQLState)
+}
+
+// TestHTTPSubmitAndCancel walks the asynchronous path: Submit returns a
+// pending statement, Cancel hits the cancel endpoint, and a following Poll
+// observes the CANCELED state, which Err reports as an error.
+func TestHTTPSubmitAndCancel(t *testing.T) {
+	server := testserver.New(t)
+	server.Handle("POST", "/api/2.0/sql/statements", func(testserver.Request) any {
+		return sql.StatementResponse{StatementId: "s1", Status: &sql.StatementStatus{State: sql.StatementStatePending}}
+	})
+	// Set by the cancel handler; the GET handler reads it to pick the state.
+	canceled := false
+	server.Handle("POST", "/api/2.0/sql/statements/{statement_id}/cancel", func(req testserver.Request) any {
+		assert.Equal(t, "s1", req.Vars["statement_id"])
+		canceled = true
+		return map[string]string{}
+	})
+	server.Handle("GET", "/api/2.0/sql/statements/{statement_id}", func(testserver.Request) any {
+		state := sql.StatementStatePending
+		if canceled {
+			state = sql.StatementStateCanceled
+		}
+		return sql.StatementResponse{StatementId: "s1", Status: &sql.StatementStatus{State: state}}
+	})
+
+	c := httpClient(t, server)
+	ctx := t.Context()
+
+	stmt, err := c.Submit(ctx, "SELECT 1")
+	require.NoError(t, err)
+	assert.Equal(t, "s1", stmt.ID)
+
+	require.NoError(t, c.Cancel(ctx, stmt.ID))
+	assert.True(t, canceled)
+
+	stmt, err = c.Poll(ctx, stmt)
+	require.NoError(t, err)
+	assert.Equal(t, sql.StatementStateCanceled, stmt.State)
+	require.Error(t, stmt.Err())
+}
diff --git a/libs/sqlexec/sqlexec_test.go b/libs/sqlexec/sqlexec_test.go
new file mode 100644
index 00000000000..ce170d1043e
--- /dev/null
+++ b/libs/sqlexec/sqlexec_test.go
@@ -0,0 +1,378 @@
+package sqlexec
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql"
+	"github.com/databricks/databricks-sdk-go/service/sql"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+)
+
+// testClient returns a Client wired to a fresh mock with a near-zero poll
+// interval so polling tests don't sleep.
+func testClient(t *testing.T) (*Client, *mocksql.MockStatementExecutionInterface) {
+	api := mocksql.NewMockStatementExecutionInterface(t)
+	c := New(api, "wh-1", WithPollInterval(time.Millisecond, time.Millisecond))
+	return c, api
+}
+
+// statusResp builds a response for statement "stmt-1" that carries only a
+// status in the given state, with no manifest or result block.
+func statusResp(state sql.StatementState) *sql.StatementResponse {
+	return &sql.StatementResponse{StatementId: "stmt-1", Status: &sql.StatementStatus{State: state}}
+}
+
+// succeededResp builds a SUCCEEDED response for statement "stmt-1" holding the
+// given rows. A nil columns slice produces a manifest whose schema is nil.
+func succeededResp(columns []string, dataArray [][]string) *sql.StatementResponse {
+	var schema *sql.ResultSchema
+	if columns != nil {
+		cols := make([]sql.ColumnInfo, len(columns))
+		for i, name := range columns {
+			cols[i] = sql.ColumnInfo{Name: name}
+		}
+		schema = &sql.ResultSchema{Columns: cols}
+	}
+	return &sql.StatementResponse{
+		StatementId: "stmt-1",
+		Status:      &sql.StatementStatus{State: sql.StatementStateSucceeded},
+		Manifest:    &sql.ResultManifest{Schema: schema},
+		Result:      &sql.ResultData{DataArray: dataArray},
+	}
+}
+
+// TestExecuteScalar is a table-driven test: each case feeds one submit
+// response to ExecuteScalar and checks either the returned cell or a substring
+// of the returned error.
+func TestExecuteScalar(t *testing.T) {
+	tests := []struct {
+		name      string
+		resp      *sql.StatementResponse
+		want      string
+		errSubstr string
+	}{
+		{
+			name: "succeeded with cell",
+			resp: succeededResp([]string{"table_type"}, [][]string{{"METRIC_VIEW"}}),
+			want: "METRIC_VIEW",
+		},
+		{
+			name: "succeeded with no rows",
+			resp: succeededResp([]string{"table_type"}, nil),
+			want: "",
+		},
+		{
+			name: "failed surfaces backend code and message",
+			resp: &sql.StatementResponse{
+				StatementId: "stmt-1",
+				Status: &sql.StatementStatus{
+					State:    sql.StatementStateFailed,
+					SqlState: "42000",
+					Error:    &sql.ServiceError{ErrorCode: sql.ServiceErrorCodeBadRequest, Message: "boom"},
+				},
+			},
+			errSubstr: "boom",
+		},
+		{
+			name:      "canceled is not treated as success",
+			resp:      statusResp(sql.StatementStateCanceled),
+			errSubstr: "terminal state CANCELED",
+		},
+		{
+			name:      "closed is not treated as success",
+			resp:      statusResp(sql.StatementStateClosed),
+			errSubstr: "terminal state CLOSED",
+		},
+		{
+			name:      "missing status surfaces as error",
+			resp:      &sql.StatementResponse{StatementId: "stmt-1"},
+			errSubstr: "no status",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c, api := testClient(t)
+			api.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).Return(tt.resp, nil).Once()
+
+			got, err := c.ExecuteScalar(t.Context(), "SELECT 1")
+			// A non-empty errSubstr marks the case as an expected failure.
+			if tt.errSubstr != "" {
+				require.Error(t, err)
+				assert.Contains(t, err.Error(), tt.errSubstr)
+				return
+			}
+			require.NoError(t, err)
+			assert.Equal(t, tt.want, got)
+		})
+	}
+}
+
+// TestExecuteDDLReturnsEmptyResult checks that a SUCCEEDED response without a
+// manifest or result yields an empty Result instead of an error.
+func TestExecuteDDLReturnsEmptyResult(t *testing.T) {
+	c, api := testClient(t)
+	// DDL responses carry a status but no manifest or result block.
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).
+		Return(statusResp(sql.StatementStateSucceeded), nil).Once()
+
+	r, err := c.Execute(t.Context(), "CREATE VIEW v")
+	require.NoError(t, err)
+	assert.Empty(t, r.Columns)
+	assert.Empty(t, r.Rows)
+	assert.Empty(t, r.Scalar())
+}
+
+// TestExecuteSubmitTransportError checks that an error returned by
+// ExecuteStatement is reported with an "execute statement" prefix and the
+// underlying message.
+func TestExecuteSubmitTransportError(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).
+		Return(nil, errors.New("network down")).Once()
+
+	_, err := c.Execute(t.Context(), "SELECT 1")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "execute statement")
+	assert.Contains(t, err.Error(), "network down")
+}
+
+// TestExecutePollsUntilTerminal checks the PENDING -> RUNNING -> SUCCEEDED
+// sequence: the result comes from the final polled response.
+func TestExecutePollsUntilTerminal(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).
+		Return(statusResp(sql.StatementStatePending), nil).Once()
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
+		Return(statusResp(sql.StatementStateRunning), nil).Once()
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
+		Return(succeededResp([]string{"n"}, [][]string{{"polled"}}), nil).Once()
+
+	got, err := c.ExecuteScalar(t.Context(), "SELECT 1")
+	require.NoError(t, err)
+	assert.Equal(t, "polled", got)
+}
+
+// TestExecuteUsesSyncWaitTimeout checks the request Execute sends when the
+// client is built with WithWaitTimeout("25s"): that wait timeout, CONTINUE on
+// timeout, inline disposition, JSON-array format and the client's warehouse ID.
+func TestExecuteUsesSyncWaitTimeout(t *testing.T) {
+	api := mocksql.NewMockStatementExecutionInterface(t)
+	c := New(api, "wh-1", WithWaitTimeout("25s"))
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
+		return req.WaitTimeout == "25s" &&
+			req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue &&
+			req.Disposition == sql.DispositionInline &&
+			req.Format == sql.FormatJsonArray &&
+			req.WarehouseId == "wh-1"
+	})).Return(succeededResp(nil, nil), nil).Once()
+
+	_, err := c.Execute(t.Context(), "SELECT 1")
+	require.NoError(t, err)
+}
+
+// TestSubmitIsAsyncAndForwardsParameters checks that Submit sends a "0s" wait
+// timeout with CONTINUE on timeout, passes the parameters through unchanged,
+// and returns the statement ID and state from the response.
+func TestSubmitIsAsyncAndForwardsParameters(t *testing.T) {
+	c, api := testClient(t)
+	params := []sql.StatementParameterListItem{{Name: "since", Type: "DATE", Value: "2026-01-01"}}
+
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.MatchedBy(func(req sql.ExecuteStatementRequest) bool {
+		return req.WaitTimeout == "0s" &&
+			req.OnWaitTimeout == sql.ExecuteStatementRequestOnWaitTimeoutContinue &&
+			assert.ObjectsAreEqual(params, req.Parameters)
+	})).Return(statusResp(sql.StatementStatePending), nil).Once()
+
+	s, err := c.Submit(t.Context(), "SELECT :since", WithParameters(params))
+	require.NoError(t, err)
+	assert.Equal(t, "stmt-1", s.ID)
+	assert.Equal(t, sql.StatementStatePending, s.State)
+}
+
+// TestPollImmediateTerminalDoesNotCallAPI checks that Poll hands back an
+// already-succeeded statement without calling the API.
+func TestPollImmediateTerminalDoesNotCallAPI(t *testing.T) {
+	c, _ := testClient(t)
+	// No GetStatement expectation: a terminal statement must not be polled.
+	s, err := c.Poll(t.Context(), newStatement(succeededResp(nil, nil)))
+	require.NoError(t, err)
+	assert.Equal(t, sql.StatementStateSucceeded, s.State)
+}
+
+// TestPollContextCancellationDoesNotCancelServerSide checks that Poll on an
+// already-cancelled context returns context.Canceled and makes no API calls.
+func TestPollContextCancellationDoesNotCancelServerSide(t *testing.T) {
+	c, api := testClient(t)
+	_ = api // no CancelExecution expectation: Poll must not cancel server-side.
+
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel()
+
+	_, err := c.Poll(ctx, newStatement(statusResp(sql.StatementStatePending)))
+	require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestPollWrapsGetError checks that a failed status fetch is reported with a
+// "poll statement stmt-1" prefix and the underlying message.
+func TestPollWrapsGetError(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
+		Return(nil, errors.New("boom")).Once()
+
+	_, err := c.Poll(t.Context(), newStatement(statusResp(sql.StatementStatePending)))
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "poll statement stmt-1")
+	assert.Contains(t, err.Error(), "boom")
+}
+
+// TestPollGetErrorAfterCancellationReturnsContextErr checks that when the
+// context is cancelled during a failing status fetch, Poll returns
+// context.Canceled.
+func TestPollGetErrorAfterCancellationReturnsContextErr(t *testing.T) {
+	api := mocksql.NewMockStatementExecutionInterface(t)
+	c := New(api, "wh-1", WithPollInterval(time.Millisecond, time.Millisecond))
+
+	ctx, cancel := context.WithCancel(t.Context())
+	// Simulate the context being cancelled while the GET is in flight: the call
+	// fails, but Poll must surface the context error, not the transport error.
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
+		RunAndReturn(func(context.Context, string) (*sql.StatementResponse, error) {
+			cancel()
+			return nil, errors.New("cancelled mid-flight")
+		}).Once()
+
+	_, err := c.Poll(ctx, newStatement(statusResp(sql.StatementStatePending)))
+	require.ErrorIs(t, err, context.Canceled)
+}
+
+// TestExecuteReturnsPollError checks that Execute passes a polling failure
+// through to its caller.
+func TestExecuteReturnsPollError(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().ExecuteStatement(mock.Anything, mock.Anything).
+		Return(statusResp(sql.StatementStatePending), nil).Once()
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-1").
+		Return(nil, errors.New("boom")).Once()
+
+	_, err := c.Execute(t.Context(), "SELECT 1")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "poll statement stmt-1")
+}
+
+// TestResultsPaginatesChunks checks that with a total chunk count of 3,
+// Results fetches chunks 1 and 2 and appends their rows after the inline rows.
+func TestResultsPaginatesChunks(t *testing.T) {
+	c, api := testClient(t)
+	resp := succeededResp([]string{"n"}, [][]string{{"0"}})
+	resp.Manifest.TotalChunkCount = 3
+	api.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1).
+		Return(&sql.ResultData{DataArray: [][]string{{"1"}}}, nil).Once()
+	api.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 2).
+		Return(&sql.ResultData{DataArray: [][]string{{"2"}}}, nil).Once()
+
+	r, err := c.Results(t.Context(), newStatement(resp))
+	require.NoError(t, err)
+	assert.Equal(t, []string{"n"}, r.Columns)
+	assert.Equal(t, [][]string{{"0"}, {"1"}, {"2"}}, r.Rows)
+}
+
+// TestResultsChunkFetchError checks that a failed chunk fetch is reported with
+// the chunk index, the statement ID and the underlying message.
+func TestResultsChunkFetchError(t *testing.T) {
+	c, api := testClient(t)
+	resp := succeededResp([]string{"n"}, [][]string{{"0"}})
+	resp.Manifest.TotalChunkCount = 2
+	api.EXPECT().GetStatementResultChunkNByStatementIdAndChunkIndex(mock.Anything, "stmt-1", 1).
+		Return(nil, errors.New("throttled")).Once()
+
+	_, err := c.Results(t.Context(), newStatement(resp))
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "fetch result chunk 1 of statement stmt-1")
+	assert.Contains(t, err.Error(), "throttled")
+}
+
+// TestResultsNoResultBlock checks that a statement without manifest or result
+// yields nil columns and nil rows, with no API calls.
+func TestResultsNoResultBlock(t *testing.T) {
+	c, _ := testClient(t)
+	r, err := c.Results(t.Context(), newStatement(statusResp(sql.StatementStateSucceeded)))
+	require.NoError(t, err)
+	assert.Nil(t, r.Columns)
+	assert.Nil(t, r.Rows)
+}
+
+// TestGet checks that Get looks the statement up by the ID it was given and
+// returns the state from the response.
+func TestGet(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-9").
+		Return(statusResp(sql.StatementStateRunning), nil).Once()
+
+	s, err := c.Get(t.Context(), "stmt-9")
+	require.NoError(t, err)
+	assert.Equal(t, sql.StatementStateRunning, s.State)
+}
+
+// TestGetWrapsError checks that a lookup failure is reported with a
+// "get statement stmt-9" prefix.
+func TestGetWrapsError(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().GetStatementByStatementId(mock.Anything, "stmt-9").
+		Return(nil, errors.New("nope")).Once()
+
+	_, err := c.Get(t.Context(), "stmt-9")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "get statement stmt-9")
+}
+
+// TestCancel checks that Cancel issues a CancelExecution request for the given
+// statement ID.
+func TestCancel(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().CancelExecution(mock.Anything, sql.CancelExecutionRequest{StatementId: "stmt-1"}).
+		Return(nil).Once()
+	require.NoError(t, c.Cancel(t.Context(), "stmt-1"))
+}
+
+// TestCancelWrapsError checks that a cancel failure is reported with a
+// "cancel statement stmt-1" prefix and the underlying message.
+func TestCancelWrapsError(t *testing.T) {
+	c, api := testClient(t)
+	api.EXPECT().CancelExecution(mock.Anything, mock.Anything).Return(errors.New("gone")).Once()
+	err := c.Cancel(t.Context(), "stmt-1")
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "cancel statement stmt-1")
+	assert.Contains(t, err.Error(), "gone")
+}
+
+// TestStatementErr is a table-driven test of Statement.Err: SUCCEEDED, PENDING
+// and RUNNING give no error, while FAILED, CANCELED and a missing status give a
+// *StatementError with the expected code and message.
+func TestStatementErr(t *testing.T) {
+	tests := []struct {
+		name     string
+		resp     *sql.StatementResponse
+		wantErr  bool
+		wantCode sql.ServiceErrorCode
+		wantMsg  string
+	}{
+		{name: "succeeded", resp: statusResp(sql.StatementStateSucceeded)},
+		{name: "pending", resp: statusResp(sql.StatementStatePending)},
+		{name: "running", resp: statusResp(sql.StatementStateRunning)},
+		{
+			name: "failed with service error",
+			resp: &sql.StatementResponse{Status: &sql.StatementStatus{
+				State:    sql.StatementStateFailed,
+				SqlState: "42000",
+				Error:    &sql.ServiceError{ErrorCode: sql.ServiceErrorCodeBadRequest, Message: "bad"},
+			}},
+			wantErr:  true,
+			wantCode: sql.ServiceErrorCodeBadRequest,
+			wantMsg:  "bad",
+		},
+		{
+			name:    "canceled without error object",
+			resp:    statusResp(sql.StatementStateCanceled),
+			wantErr: true,
+			wantMsg: "statement reached terminal state CANCELED",
+		},
+		{
+			name:    "nil status",
+			resp:    &sql.StatementResponse{},
+			wantErr: true,
+			wantMsg: "statement response had no status",
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			err := newStatement(tt.resp).Err()
+			if !tt.wantErr {
+				require.NoError(t, err)
+				return
+			}
+			require.Error(t, err)
+			var se *StatementError
+			require.ErrorAs(t, err, &se)
+			// Cases that set no wantCode expect the zero-value code.
+			assert.Equal(t, tt.wantCode, se.Code)
+			assert.Equal(t, tt.wantMsg, se.Message)
+		})
+	}
+}
+
+// TestStatementErrorError pins the two Error() renderings exercised here: with
+// a code set the text is "statement failed: <code>: <message>", and with no
+// code it is the message alone.
+func TestStatementErrorError(t *testing.T) {
+	withCode := &StatementError{Code: sql.ServiceErrorCodeBadRequest, Message: "bad"}
+	assert.Equal(t, "statement failed: BAD_REQUEST: bad", withCode.Error())
+
+	noCode := &StatementError{State: sql.StatementStateCanceled, Message: "statement reached terminal state CANCELED"}
+	assert.Equal(t, "statement reached terminal state CANCELED", noCode.Error())
+}
+
+// TestResultScalar checks that Scalar returns the first cell of the first row,
+// and an empty string when there are no rows or the first row has no cells.
+func TestResultScalar(t *testing.T) {
+	assert.Empty(t, (&Result{}).Scalar())
+	assert.Empty(t, (&Result{Rows: [][]string{{}}}).Scalar())
+	assert.Equal(t, "x", (&Result{Rows: [][]string{{"x", "y"}}}).Scalar())
+}
+
+// TestStatementColumns checks that Columns returns the manifest's column names
+// and nil when the response has no manifest.
+func TestStatementColumns(t *testing.T) {
+	assert.Equal(t, []string{"a", "b"}, newStatement(succeededResp([]string{"a", "b"}, nil)).Columns())
+	// No manifest (e.g. a DDL response): no columns.
+	assert.Nil(t, newStatement(statusResp(sql.StatementStateSucceeded)).Columns())
+}