Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,17 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
req.Parameters = parameters
}

// Add per-statement query tags if provided via context
if queryTags := driverctx.QueryTagsFromContext(ctx); len(queryTags) > 0 {
serialized := SerializeQueryTags(queryTags)
if serialized != "" {
if req.ConfOverlay == nil {
req.ConfOverlay = make(map[string]string)
}
req.ConfOverlay["query_tags"] = serialized
}
}

resp, err := c.client.ExecuteStatement(ctx, &req)
var log *logger.DBSQLLogger
log, ctx = client.LoggerAndContext(ctx, resp)
Expand Down
204 changes: 204 additions & 0 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/apache/thrift/lib/go/thrift"
"github.com/pkg/errors"

"github.com/databricks/databricks-sql-go/driverctx"
dbsqlerr "github.com/databricks/databricks-sql-go/errors"
"github.com/databricks/databricks-sql-go/internal/cli_service"
"github.com/databricks/databricks-sql-go/internal/client"
Expand Down Expand Up @@ -493,6 +494,209 @@ func TestConn_executeStatement_ProtocolFeatures(t *testing.T) {
}
}

func TestConn_executeStatement_QueryTags(t *testing.T) {
t.Parallel()

makeTestConn := func(captureReq *(*cli_service.TExecuteStatementReq)) *conn {
executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) {
*captureReq = req
return &cli_service.TExecuteStatementResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationHandle: &cli_service.TOperationHandle{
OperationId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
Secret: []byte("secret"),
},
},
DirectResults: &cli_service.TSparkDirectResults{
OperationStatus: &cli_service.TGetOperationStatusResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
},
},
}, nil
}

return &conn{
session: getTestSession(),
client: &client.TestClient{
FnExecuteStatement: executeStatement,
},
cfg: config.WithDefaults(),
}
}

t.Run("query tags from context are set in ConfOverlay", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
"team": "engineering",
"app": "etl",
})

_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)
assert.NotNil(t, capturedReq.ConfOverlay)
// Map iteration is non-deterministic, so check both possible orderings
queryTags := capturedReq.ConfOverlay["query_tags"]
assert.True(t,
queryTags == "team:engineering,app:etl" || queryTags == "app:etl,team:engineering",
"unexpected query_tags value: %s", queryTags)
})

t.Run("no query tags in context means no ConfOverlay", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

_, err := testConn.executeStatement(context.Background(), "SELECT 1", nil)
assert.NoError(t, err)
assert.Nil(t, capturedReq.ConfOverlay)
})

t.Run("empty query tags map means no ConfOverlay", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{})

_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)
assert.Nil(t, capturedReq.ConfOverlay)
})

t.Run("single query tag", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
"team": "data-eng",
})

_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)
assert.Equal(t, "team:data-eng", capturedReq.ConfOverlay["query_tags"])
})

t.Run("query tags with special characters in values", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
"url": "http://host:8080",
})

_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)
assert.Equal(t, `url:http\://host\:8080`, capturedReq.ConfOverlay["query_tags"])
})

t.Run("query tags with empty value", func(t *testing.T) {
var capturedReq *cli_service.TExecuteStatementReq
testConn := makeTestConn(&capturedReq)

ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
"flag": "",
})

_, err := testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)
assert.Equal(t, "flag", capturedReq.ConfOverlay["query_tags"])
})

t.Run("session-level and statement-level query tags coexist", func(t *testing.T) {
// Session-level tags are sent via TOpenSessionReq.Configuration at connect time.
// Statement-level tags are sent via TExecuteStatementReq.ConfOverlay at query time.
// They are independent fields on different requests, so both should work together.

var capturedOpenReq *cli_service.TOpenSessionReq
var capturedExecReq *cli_service.TExecuteStatementReq

testClient := &client.TestClient{
FnOpenSession: func(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) {
capturedOpenReq = req
return &cli_service.TOpenSessionResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
SessionHandle: &cli_service.TSessionHandle{
SessionId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
},
},
}, nil
},
FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) {
capturedExecReq = req
return &cli_service.TExecuteStatementResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationHandle: &cli_service.TOperationHandle{
OperationId: &cli_service.THandleIdentifier{
GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
Secret: []byte("secret"),
},
},
DirectResults: &cli_service.TSparkDirectResults{
OperationStatus: &cli_service.TGetOperationStatusResp{
Status: &cli_service.TStatus{
StatusCode: cli_service.TStatusCode_SUCCESS_STATUS,
},
OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE),
},
},
}, nil
},
}

// Simulate what connector.Connect() does: pass session params to OpenSession
sessionParams := map[string]string{
"QUERY_TAGS": "team:platform,env:prod",
"ansi_mode": "false",
}
protocolVersion := int64(cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8)
session, err := testClient.OpenSession(context.Background(), &cli_service.TOpenSessionReq{
ClientProtocolI64: &protocolVersion,
Configuration: sessionParams,
})
assert.NoError(t, err)

// Verify session-level tags were sent in OpenSession
assert.Equal(t, "team:platform,env:prod", capturedOpenReq.Configuration["QUERY_TAGS"])
assert.Equal(t, "false", capturedOpenReq.Configuration["ansi_mode"])

// Create conn with session that has session-level tags
cfg := config.WithDefaults()
cfg.SessionParams = sessionParams
testConn := &conn{
session: session,
client: testClient,
cfg: cfg,
}

// Execute with statement-level tags
ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{
"job": "nightly-etl",
})
_, err = testConn.executeStatement(ctx, "SELECT 1", nil)
assert.NoError(t, err)

// Statement-level tags should be in ConfOverlay
assert.Equal(t, "job:nightly-etl", capturedExecReq.ConfOverlay["query_tags"])

// ConfOverlay should ONLY have query_tags, not session params
_, hasAnsiMode := capturedExecReq.ConfOverlay["ansi_mode"]
assert.False(t, hasAnsiMode, "session params should not leak into ConfOverlay")
_, hasSessionQueryTags := capturedExecReq.ConfOverlay["QUERY_TAGS"]
assert.False(t, hasSessionQueryTags, "session-level QUERY_TAGS should not be in ConfOverlay")
})
}

func TestConn_pollOperation(t *testing.T) {
t.Parallel()
t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) {
Expand Down
17 changes: 17 additions & 0 deletions connector.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,23 @@ func WithSessionParams(params map[string]string) ConnOption {
}
}

// WithQueryTags sets session-level query tags from a map.
// Tags are serialized and passed as QUERY_TAGS in the session configuration.
// All queries in the session will carry these tags unless overridden at the statement level.
// This is the preferred way to set session-level query tags, as it handles serialization
// and escaping automatically (consistent with the statement-level API).
func WithQueryTags(tags map[string]string) ConnOption {
return func(c *config.Config) {
serialized := SerializeQueryTags(tags)
if serialized != "" {
if c.SessionParams == nil {
c.SessionParams = make(map[string]string)
}
c.SessionParams["QUERY_TAGS"] = serialized
}
}
}

// WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate.
// WARNING:
// When this option is used, TLS is susceptible to machine-in-the-middle attacks.
Expand Down
59 changes: 59 additions & 0 deletions connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,65 @@ func TestNewConnector(t *testing.T) {
})
}

func TestWithQueryTags(t *testing.T) {
t.Run("WithQueryTags serializes map into SessionParams QUERY_TAGS", func(t *testing.T) {
con, err := NewConnector(
WithQueryTags(map[string]string{
"team": "data-eng",
}),
)
require.NoError(t, err)
coni, ok := con.(*connector)
require.True(t, ok)
assert.Equal(t, "team:data-eng", coni.cfg.SessionParams["QUERY_TAGS"])
})

t.Run("WithQueryTags with multiple tags", func(t *testing.T) {
con, err := NewConnector(
WithQueryTags(map[string]string{
"team": "eng",
"app": "etl",
}),
)
require.NoError(t, err)
coni, ok := con.(*connector)
require.True(t, ok)
// Map iteration is non-deterministic
qt := coni.cfg.SessionParams["QUERY_TAGS"]
assert.True(t, qt == "team:eng,app:etl" || qt == "app:etl,team:eng", "got: %s", qt)
})

t.Run("WithQueryTags with empty map does not set QUERY_TAGS", func(t *testing.T) {
con, err := NewConnector(
WithQueryTags(map[string]string{}),
)
require.NoError(t, err)
coni, ok := con.(*connector)
require.True(t, ok)
_, exists := coni.cfg.SessionParams["QUERY_TAGS"]
assert.False(t, exists)
})

t.Run("WithQueryTags overrides WithSessionParams QUERY_TAGS", func(t *testing.T) {
con, err := NewConnector(
WithSessionParams(map[string]string{
"QUERY_TAGS": "old:value",
"ansi_mode": "false",
}),
WithQueryTags(map[string]string{
"team": "new-team",
}),
)
require.NoError(t, err)
coni, ok := con.(*connector)
require.True(t, ok)
// WithQueryTags should override the QUERY_TAGS from WithSessionParams
assert.Equal(t, "team:new-team", coni.cfg.SessionParams["QUERY_TAGS"])
// Other session params should be preserved
assert.Equal(t, "false", coni.cfg.SessionParams["ansi_mode"])
})
}

type mockRoundTripper struct{}

var _ http.RoundTripper = mockRoundTripper{}
Expand Down
25 changes: 25 additions & 0 deletions driverctx/ctx.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ const (
QueryIdCallbackKey
ConnIdCallbackKey
StagingAllowedLocalPathKey
QueryTagsContextKey
)

type IdCallbackFunc func(string)
Expand Down Expand Up @@ -107,16 +108,40 @@ func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []st
return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath)
}

// NewContextWithQueryTags creates a new context with per-statement query tags.
// These tags are serialized and passed via confOverlay as "query_tags" in TExecuteStatementReq.
// They apply only to the statement executed with this context and do not persist across queries.
func NewContextWithQueryTags(ctx context.Context, queryTags map[string]string) context.Context {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if this is the best way to create a context object with an optional param in Go. I'm not familiar w/ Go. Perhaps ask Claude to double check.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question! This is actually the idiomatic Go pattern for per-request metadata in database/sql drivers. Since the QueryContext(ctx, query, args) interface is fixed by the standard library, context values are the standard way to pass per-request options.

This driver already uses the same pattern for other per-request data:

  • driverctx.NewContextWithConnId
  • driverctx.NewContextWithCorrelationId
  • driverctx.NewContextWithQueryId
  • driverctx.NewContextWithStagingInfo

So NewContextWithQueryTags follows the established convention.

return context.WithValue(ctx, QueryTagsContextKey, queryTags)
}

// QueryTagsFromContext retrieves the per-statement query tags stored in context.
func QueryTagsFromContext(ctx context.Context) map[string]string {
if ctx == nil {
return nil
}

queryTags, ok := ctx.Value(QueryTagsContextKey).(map[string]string)
if !ok {
return nil
}
return queryTags
}

func NewContextFromBackground(ctx context.Context) context.Context {
connId := ConnIdFromContext(ctx)
corrId := CorrelationIdFromContext(ctx)
queryId := QueryIdFromContext(ctx)
stagingPaths := StagingPathsFromContext(ctx)
queryTags := QueryTagsFromContext(ctx)

newCtx := NewContextWithConnId(context.Background(), connId)
newCtx = NewContextWithCorrelationId(newCtx, corrId)
newCtx = NewContextWithQueryId(newCtx, queryId)
newCtx = NewContextWithStagingInfo(newCtx, stagingPaths)
if queryTags != nil {
newCtx = NewContextWithQueryTags(newCtx, queryTags)
}

return newCtx
}
Loading
Loading