Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 24 additions & 31 deletions experimental/aitools/cmd/batch.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/log"
"github.com/databricks/cli/libs/sqlexec"
"github.com/databricks/databricks-sdk-go/service/sql"
"golang.org/x/sync/errgroup"
)
Expand Down Expand Up @@ -58,6 +59,8 @@ type batchResultError struct {
// reused across the batch, so callers must ensure each SQL uses only markers
// that are covered.
func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, warehouseID string, sqls []string, params []sql.StatementParameterListItem, concurrency int) []batchResult {
client := sqlexec.New(api, warehouseID)

pollCtx, pollCancel := context.WithCancel(ctx)
defer pollCancel()

Expand Down Expand Up @@ -101,51 +104,45 @@ func executeBatch(ctx context.Context, api sql.StatementExecutionInterface, ware
g.SetLimit(concurrency)
for i, sqlStr := range sqls {
g.Go(func() error {
results[i] = runOneBatchQuery(pollCtx, api, warehouseID, sqlStr, params, statementIDs, i)
results[i] = runOneBatchQuery(pollCtx, client, sqlStr, params, statementIDs, i)
completed.Add(1)
return nil
})
}
_ = g.Wait()

// pollStatement is a pure helper that returns ctx.Err() on cancellation
// without touching the server. Sweep any not-yet-terminal statements here.
// Poll returns ctx.Err() on cancellation without touching the server.
// Sweep any not-yet-terminal statements here.
if pollCtx.Err() != nil {
cancelInFlight(ctx, api, statementIDs, results)
cancelInFlight(ctx, client, statementIDs, results)
}

return results
}

// runOneBatchQuery submits one SQL, polls to completion, and returns its
// batchResult. All errors are encoded into the result; never returns an error.
func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface, warehouseID, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
func runOneBatchQuery(ctx context.Context, client *sqlexec.Client, sqlStr string, params []sql.StatementParameterListItem, statementIDs []string, idx int) batchResult {
start := time.Now()
result := batchResult{SQL: sqlStr}

resp, err := api.ExecuteStatement(ctx, sql.ExecuteStatementRequest{
WarehouseId: warehouseID,
Statement: sqlStr,
Parameters: params,
WaitTimeout: "0s",
OnWaitTimeout: sql.ExecuteStatementRequestOnWaitTimeoutContinue,
})
stmt, err := client.Submit(ctx, sqlStr, sqlexec.WithParameters(params))
if err != nil {
if ctx.Err() != nil {
result.State = sql.StatementStateCanceled
result.Error = &batchResultError{Message: "submission cancelled"}
} else {
result.State = sql.StatementStateFailed
result.Error = &batchResultError{Message: fmt.Sprintf("execute statement: %v", err)}
result.Error = &batchResultError{Message: err.Error()}
}
result.ElapsedMs = time.Since(start).Milliseconds()
return result
}

statementIDs[idx] = resp.StatementId
result.StatementID = resp.StatementId
statementIDs[idx] = stmt.ID
result.StatementID = stmt.ID

pollResp, err := pollStatement(ctx, api, resp)
stmt, err = client.Poll(ctx, stmt)
if err != nil {
if ctx.Err() != nil {
result.State = sql.StatementStateCanceled
Expand All @@ -158,38 +155,34 @@ func runOneBatchQuery(ctx context.Context, api sql.StatementExecutionInterface,
return result
}

if pollResp.Status != nil {
result.State = pollResp.Status.State
}
result.State = stmt.State

if result.State != sql.StatementStateSucceeded {
result.Error = &batchResultError{}
if pollResp.Status != nil && pollResp.Status.Error != nil {
result.Error.Message = pollResp.Status.Error.Message
result.Error.ErrorCode = string(pollResp.Status.Error.ErrorCode)
} else {
result.Error.Message = fmt.Sprintf("query reached terminal state %s", result.State)
if err := stmt.Err(); err != nil {
se, _ := errors.AsType[*sqlexec.StatementError](err)
result.Error = &batchResultError{
Message: se.Message,
ErrorCode: string(se.Code),
}
result.ElapsedMs = time.Since(start).Milliseconds()
return result
}

result.Columns = extractColumns(pollResp.Manifest)
rows, err := fetchAllRows(ctx, api, pollResp)
res, err := client.Results(ctx, stmt)
if err != nil {
result.Error = &batchResultError{Message: fmt.Sprintf("fetch rows: %v", err)}
result.ElapsedMs = time.Since(start).Milliseconds()
return result
}
result.Rows = rows
result.Columns = res.Columns
result.Rows = res.Rows
result.ElapsedMs = time.Since(start).Milliseconds()
return result
}

// cancelInFlight sends CancelExecution for every statement that didn't reach
// a terminal state server-side before context cancellation. Best effort: errors
// are logged at warn but don't fail the batch.
func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, statementIDs []string, results []batchResult) {
func cancelInFlight(ctx context.Context, client *sqlexec.Client, statementIDs []string, results []batchResult) {
var cancelled int
for i, sid := range statementIDs {
if sid == "" {
Expand All @@ -208,7 +201,7 @@ func cancelInFlight(ctx context.Context, api sql.StatementExecutionInterface, st
// values but drops the cancellation signal so the cancel RPC actually
// reaches the warehouse instead of short-circuiting on ctx.Err().
cancelCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), cancelTimeout)
if err := api.CancelExecution(cancelCtx, sql.CancelExecutionRequest{StatementId: sid}); err != nil {
if err := client.Cancel(cancelCtx, sid); err != nil {
log.Warnf(ctx, "Failed to cancel statement %s: %v", sid, err)
}
cancel()
Expand Down
71 changes: 32 additions & 39 deletions experimental/aitools/cmd/discover_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/databricks/cli/libs/cmdctx"
"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/log"
"github.com/databricks/cli/libs/sqlexec"
"github.com/databricks/databricks-sdk-go"
dbsql "github.com/databricks/databricks-sdk-go/service/sql"
"github.com/spf13/cobra"
Expand All @@ -45,8 +46,10 @@ func newSQLGate(limit int) *sqlGate {
// run executes a SQL statement asynchronously, polls until terminal, and
// records the statement_id so it can be cancelled if the parent context is
// cancelled. Acquires a slot from the gate before submitting and releases it
// when polling completes (or the caller's context is cancelled).
func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*dbsql.StatementResponse, error) {
// when polling completes (or the caller's context is cancelled). On success it
// returns the assembled result; a terminal non-success state is surfaced as the
// CLI-facing query error.
func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, warehouseID, statement string) (*sqlexec.Result, error) {
// If the caller cancelled before we even tried, don't enter the select:
// when the gate has free slots both cases are ready and Go picks one
// pseudo-randomly. Without this early-out we'd occasionally submit a
Expand All @@ -61,28 +64,25 @@ func (g *sqlGate) run(ctx context.Context, w *databricks.WorkspaceClient, wareho
return nil, ctx.Err()
}

resp, err := w.StatementExecution.ExecuteStatement(ctx, dbsql.ExecuteStatementRequest{
WarehouseId: warehouseID,
Statement: statement,
WaitTimeout: "0s",
OnWaitTimeout: dbsql.ExecuteStatementRequestOnWaitTimeoutContinue,
})
client := sqlexec.New(w.StatementExecution, warehouseID)

stmt, err := client.Submit(ctx, statement)
if err != nil {
return nil, fmt.Errorf("execute statement: %w", err)
return nil, err
}

g.mu.Lock()
g.ids = append(g.ids, resp.StatementId)
g.ids = append(g.ids, stmt.ID)
g.mu.Unlock()

pollResp, err := pollStatement(ctx, w.StatementExecution, resp)
stmt, err = client.Poll(ctx, stmt)
if err != nil {
return nil, err
}
if err := checkFailedState(pollResp.Status); err != nil {
if err := presentQueryError(stmt.Err()); err != nil {
return nil, err
}
return pollResp, nil
return client.Results(ctx, stmt)
}

// trackedIDs returns a snapshot of statement_ids submitted through this gate.
Expand Down Expand Up @@ -235,9 +235,11 @@ func cancelDiscoverInFlight(ctx context.Context, api dbsql.StatementExecutionInt
cmdio.LogString(ctx, "discover-schema cancelled.")
return
}
// Cancel/Poll/Get don't use the warehouse ID, so an empty one is fine here.
client := sqlexec.New(api, "")
for _, id := range ids {
cancelCtx, cancel := context.WithTimeout(ctx, cancelTimeout)
if err := api.CancelExecution(cancelCtx, dbsql.CancelExecutionRequest{StatementId: id}); err != nil {
if err := client.Cancel(cancelCtx, id); err != nil {
log.Warnf(ctx, "Failed to cancel statement %s: %v", id, err)
}
cancel()
Expand All @@ -252,12 +254,12 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
}

// 1. describe table - get columns and types
descResp, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted)
descResult, err := gate.run(ctx, w, warehouseID, "DESCRIBE TABLE "+quoted)
if err != nil {
return "", fmt.Errorf("describe table: %w", err)
}

columns, types := parseDescribeResult(descResp)
columns, types := parseDescribeResult(descResult)
if len(columns) == 0 {
return "", errors.New("no columns found")
}
Expand All @@ -281,16 +283,16 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
nullSQL := fmt.Sprintf("SELECT COUNT(*) AS total_rows, %s FROM %s",
strings.Join(nullCountExprs, ", "), quoted)

var sampleResp, nullResp *dbsql.StatementResponse
var sampleResult, nullResult *sqlexec.Result
var sampleErr, nullErr error

g := new(errgroup.Group)
g.Go(func() error {
sampleResp, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL)
sampleResult, sampleErr = gate.run(ctx, w, warehouseID, sampleSQL)
return nil
})
g.Go(func() error {
nullResp, nullErr = gate.run(ctx, w, warehouseID, nullSQL)
nullResult, nullErr = gate.run(ctx, w, warehouseID, nullSQL)
return nil
})
_ = g.Wait()
Expand All @@ -306,25 +308,21 @@ func discoverTable(ctx context.Context, gate *sqlGate, w *databricks.WorkspaceCl
fmt.Fprintf(&sb, "\nSAMPLE DATA: Error - %v\n", sampleErr)
} else {
sb.WriteString("\nSAMPLE DATA:\n")
sb.WriteString(formatTableData(sampleResp))
sb.WriteString(formatTableData(sampleResult))
}

if nullErr != nil {
fmt.Fprintf(&sb, "\nNULL COUNTS: Error - %v\n", nullErr)
} else {
sb.WriteString("\nNULL COUNTS:\n")
sb.WriteString(formatNullCounts(nullResp, columns))
sb.WriteString(formatNullCounts(nullResult, columns))
}

return sb.String(), nil
}

func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string) {
if resp.Result == nil || resp.Result.DataArray == nil {
return nil, nil
}

for _, row := range resp.Result.DataArray {
func parseDescribeResult(result *sqlexec.Result) (columns, types []string) {
for _, row := range result.Rows {
if len(row) < 2 {
continue
}
Expand All @@ -340,20 +338,15 @@ func parseDescribeResult(resp *dbsql.StatementResponse) (columns, types []string
return columns, types
}

func formatTableData(resp *dbsql.StatementResponse) string {
if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 {
func formatTableData(result *sqlexec.Result) string {
if len(result.Rows) == 0 {
return " (no data)\n"
}

var sb strings.Builder
var columns []string
if resp.Manifest != nil && resp.Manifest.Schema != nil {
for _, col := range resp.Manifest.Schema.Columns {
columns = append(columns, col.Name)
}
}
columns := result.Columns

for i, row := range resp.Result.DataArray {
for i, row := range result.Rows {
fmt.Fprintf(&sb, " Row %d:\n", i+1)
for j, val := range row {
colName := fmt.Sprintf("col%d", j)
Expand All @@ -366,12 +359,12 @@ func formatTableData(resp *dbsql.StatementResponse) string {
return sb.String()
}

func formatNullCounts(resp *dbsql.StatementResponse, columns []string) string {
if resp.Result == nil || resp.Result.DataArray == nil || len(resp.Result.DataArray) == 0 {
func formatNullCounts(result *sqlexec.Result, columns []string) string {
if len(result.Rows) == 0 {
return " (no data)\n"
}

row := resp.Result.DataArray[0]
row := result.Rows[0]
var sb strings.Builder

// first value is total_rows
Expand Down
13 changes: 7 additions & 6 deletions experimental/aitools/cmd/discover_schema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"time"

"github.com/databricks/cli/libs/cmdio"
"github.com/databricks/cli/libs/sqlexec"
"github.com/databricks/databricks-sdk-go"
mocksql "github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql"
dbsql "github.com/databricks/databricks-sdk-go/service/sql"
Expand Down Expand Up @@ -50,17 +51,17 @@ func TestQuoteTableName(t *testing.T) {
}

func TestParseDescribeResultSkipsMetadataRows(t *testing.T) {
resp := &dbsql.StatementResponse{
Result: &dbsql.ResultData{DataArray: [][]string{
result := &sqlexec.Result{
Rows: [][]string{
{"id", "BIGINT", ""},
{"name", "STRING", ""},
{"# Partition Information", "", ""},
{"region", "STRING", ""},
{"", "STRING", ""},
}},
},
}

cols, types := parseDescribeResult(resp)
cols, types := parseDescribeResult(result)
assert.Equal(t, []string{"id", "name", "region"}, cols)
assert.Equal(t, []string{"BIGINT", "STRING", "STRING"}, types)
}
Expand All @@ -82,9 +83,9 @@ func TestSQLGateRunPinsOnWaitTimeoutAndRecordsID(t *testing.T) {
w := &databricks.WorkspaceClient{StatementExecution: mockAPI}
gate := newSQLGate(2)

resp, err := gate.run(ctx, w, "wh-1", "SELECT 1")
result, err := gate.run(ctx, w, "wh-1", "SELECT 1")
require.NoError(t, err)
assert.Equal(t, "stmt-1", resp.StatementId)
assert.Equal(t, [][]string{{"1"}}, result.Rows)
assert.Equal(t, []string{"stmt-1"}, gate.trackedIDs())
}

Expand Down
Loading
Loading