From b0952a8a0f5a5f0a818915ea9634a33cf24d257a Mon Sep 17 00:00:00 2001 From: Jooho Yeo Date: Wed, 1 Apr 2026 19:08:05 -0700 Subject: [PATCH 1/3] Add statement-level query tag support via context Previously, query tags could only be set at the session level via WithSessionParams during connection creation. This adds per-statement query tag support, allowing different tags for each query execution. Users pass query tags through context using the new driverctx.NewContextWithQueryTags function. The tags are serialized into the TExecuteStatementReq.ConfOverlay["query_tags"] field, consistent with the Python and NodeJS connector implementations. Co-authored-by: Isaac Signed-off-by: Jooho Yeo --- connection.go | 11 ++++ connection_test.go | 116 ++++++++++++++++++++++++++++++++++++ driverctx/ctx.go | 25 ++++++++ driverctx/ctx_test.go | 46 ++++++++++++++ examples/query_tags/main.go | 41 ++++++++++--- query_tags.go | 32 ++++++++++ query_tags_test.go | 90 ++++++++++++++++++++++++++++ 7 files changed, 354 insertions(+), 7 deletions(-) create mode 100644 query_tags.go create mode 100644 query_tags_test.go diff --git a/connection.go b/connection.go index c297d5bd..b04336ad 100644 --- a/connection.go +++ b/connection.go @@ -323,6 +323,17 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req.Parameters = parameters } + // Add per-statement query tags if provided via context + if queryTags := driverctx.QueryTagsFromContext(ctx); len(queryTags) > 0 { + serialized := SerializeQueryTags(queryTags) + if serialized != "" { + if req.ConfOverlay == nil { + req.ConfOverlay = make(map[string]string) + } + req.ConfOverlay["query_tags"] = serialized + } + } + resp, err := c.client.ExecuteStatement(ctx, &req) var log *logger.DBSQLLogger log, ctx = client.LoggerAndContext(ctx, resp) diff --git a/connection_test.go b/connection_test.go index 202c8283..2e005f69 100644 --- a/connection_test.go +++ b/connection_test.go @@ -10,6 +10,7 @@ import ( "github.com/apache/thrift/lib/go/thrift" "github.com/pkg/errors" + "github.com/databricks/databricks-sql-go/driverctx" dbsqlerr "github.com/databricks/databricks-sql-go/errors" "github.com/databricks/databricks-sql-go/internal/cli_service" "github.com/databricks/databricks-sql-go/internal/client" @@ -493,6 +494,121 @@ func TestConn_executeStatement_ProtocolFeatures(t *testing.T) { } } +func TestConn_executeStatement_QueryTags(t *testing.T) { + t.Parallel() + + makeTestConn := func(captureReq *(*cli_service.TExecuteStatementReq)) *conn { + executeStatement := func(ctx context.Context, req *cli_service.TExecuteStatementReq) (r *cli_service.TExecuteStatementResp, err error) { + *captureReq = req + return &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + Secret: []byte("secret"), + }, + }, + DirectResults: &cli_service.TSparkDirectResults{ + OperationStatus: &cli_service.TGetOperationStatusResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + }, + }, + }, nil + } + + return &conn{ + session: getTestSession(), + client: &client.TestClient{ + FnExecuteStatement: executeStatement, + }, + cfg: config.WithDefaults(), + } + } + + t.Run("query tags from context are set in ConfOverlay", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "team": "engineering", + "app": "etl", + }) + + _, err := testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + assert.NotNil(t, capturedReq.ConfOverlay) + // Map iteration is non-deterministic, so check both possible orderings + queryTags := capturedReq.ConfOverlay["query_tags"] + assert.True(t, + queryTags == "team:engineering,app:etl" || queryTags == "app:etl,team:engineering", + "unexpected query_tags value: %s", queryTags) + }) + + t.Run("no query tags in context means no ConfOverlay", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + _, err := testConn.executeStatement(context.Background(), "SELECT 1", nil) + assert.NoError(t, err) + assert.Nil(t, capturedReq.ConfOverlay) + }) + + t.Run("empty query tags map means no ConfOverlay", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{}) + + _, err := testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + assert.Nil(t, capturedReq.ConfOverlay) + }) + + t.Run("single query tag", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "team": "data-eng", + }) + + _, err := testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + assert.Equal(t, "team:data-eng", capturedReq.ConfOverlay["query_tags"]) + }) + + t.Run("query tags with special characters in values", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "url": "http://host:8080", + }) + + _, err := testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + assert.Equal(t, `url:http\://host\:8080`, capturedReq.ConfOverlay["query_tags"]) + }) + + t.Run("query tags with empty value", func(t *testing.T) { + var capturedReq *cli_service.TExecuteStatementReq + testConn := makeTestConn(&capturedReq) + + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "flag": "", + }) + + _, err := testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + assert.Equal(t, "flag", capturedReq.ConfOverlay["query_tags"]) + }) +} + func TestConn_pollOperation(t *testing.T) { t.Parallel() t.Run("pollOperation returns finished state response when query finishes", func(t *testing.T) { diff --git a/driverctx/ctx.go b/driverctx/ctx.go index f8f4674d..b3631295 100644 --- a/driverctx/ctx.go +++ b/driverctx/ctx.go @@ -15,6 +15,7 @@ const ( QueryIdCallbackKey ConnIdCallbackKey StagingAllowedLocalPathKey + QueryTagsContextKey ) type IdCallbackFunc func(string) @@ -107,16 +108,40 @@ func NewContextWithStagingInfo(ctx context.Context, stagingAllowedLocalPath []st return context.WithValue(ctx, StagingAllowedLocalPathKey, stagingAllowedLocalPath) } +// NewContextWithQueryTags creates a new context with per-statement query tags. +// These tags are serialized and passed via confOverlay as "query_tags" in TExecuteStatementReq. +// They apply only to the statement executed with this context and do not persist across queries. +func NewContextWithQueryTags(ctx context.Context, queryTags map[string]string) context.Context { + return context.WithValue(ctx, QueryTagsContextKey, queryTags) +} + +// QueryTagsFromContext retrieves the per-statement query tags stored in context. +func QueryTagsFromContext(ctx context.Context) map[string]string { + if ctx == nil { + return nil + } + + queryTags, ok := ctx.Value(QueryTagsContextKey).(map[string]string) + if !ok { + return nil + } + return queryTags +} + func NewContextFromBackground(ctx context.Context) context.Context { connId := ConnIdFromContext(ctx) corrId := CorrelationIdFromContext(ctx) queryId := QueryIdFromContext(ctx) stagingPaths := StagingPathsFromContext(ctx) + queryTags := QueryTagsFromContext(ctx) newCtx := NewContextWithConnId(context.Background(), connId) newCtx = NewContextWithCorrelationId(newCtx, corrId) newCtx = NewContextWithQueryId(newCtx, queryId) newCtx = NewContextWithStagingInfo(newCtx, stagingPaths) + if queryTags != nil { + newCtx = NewContextWithQueryTags(newCtx, queryTags) + } return newCtx } diff --git a/driverctx/ctx_test.go b/driverctx/ctx_test.go index 4d4800aa..8b961021 100644 --- a/driverctx/ctx_test.go +++ b/driverctx/ctx_test.go @@ -8,6 +8,52 @@ import ( "github.com/stretchr/testify/assert" ) +func TestNewContextWithQueryTags(t *testing.T) { + t.Run("stores and retrieves query tags", func(t *testing.T) { + tags := map[string]string{"team": "engineering", "app": "etl"} + ctx := NewContextWithQueryTags(context.Background(), tags) + result := QueryTagsFromContext(ctx) + assert.Equal(t, tags, result) + }) + + t.Run("returns nil for context without query tags", func(t *testing.T) { + result := QueryTagsFromContext(context.Background()) + assert.Nil(t, result) + }) + + t.Run("it maintains timeout", func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + tags := map[string]string{"team": "eng"} + ctx1 := NewContextWithQueryTags(ctx, tags) + result := QueryTagsFromContext(ctx1) + assert.Equal(t, tags, result) + dl, ok := ctx.Deadline() + dl1, ok1 := ctx1.Deadline() + assert.Equal(t, dl, dl1) + assert.True(t, ok) + assert.True(t, ok1) + }) + + t.Run("NewContextFromBackground preserves query tags", func(t *testing.T) { + tags := map[string]string{"team": "eng"} + ctx := NewContextWithConnId(context.Background(), "conn-1") + ctx = NewContextWithCorrelationId(ctx, "corr-1") + ctx = NewContextWithQueryTags(ctx, tags) + + newCtx := NewContextFromBackground(ctx) + assert.Equal(t, tags, QueryTagsFromContext(newCtx)) + assert.Equal(t, "conn-1", ConnIdFromContext(newCtx)) + assert.Equal(t, "corr-1", CorrelationIdFromContext(newCtx)) + }) + + t.Run("NewContextFromBackground without query tags", func(t *testing.T) { + ctx := NewContextWithConnId(context.Background(), "conn-1") + newCtx := NewContextFromBackground(ctx) + assert.Nil(t, QueryTagsFromContext(newCtx)) + }) +} + func TestNewContextWithCorrelationId(t *testing.T) { t.Run("base case", func(t *testing.T) { diff --git a/examples/query_tags/main.go b/examples/query_tags/main.go index 9997aeec..d26b1c67 100644 --- a/examples/query_tags/main.go +++ b/examples/query_tags/main.go @@ -9,6 +9,7 @@ import ( "strconv" dbsql "github.com/databricks/databricks-sql-go" + "github.com/databricks/databricks-sql-go/driverctx" "github.com/joho/godotenv" ) @@ -21,6 +22,7 @@ func main() { log.Fatal(err.Error()) } + // Session-level query tags: applied to all queries in this session. connector, err := dbsql.NewConnector( dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")), dbsql.WithPort(port), @@ -38,16 +40,41 @@ func main() { db := sql.OpenDB(connector) defer db.Close() + // Example 1: Session-level query tags (set during connection) + fmt.Println("=== Session-level query tags ===") ctx := context.Background() var result int err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result) if err != nil { - if err == sql.ErrNoRows { - fmt.Println("not found") - return - } else { - fmt.Printf("err: %v\n", err) - } + log.Printf("err: %v\n", err) + } else { + fmt.Println(result) + } + + // Example 2: Statement-level query tags (per-query override via context) + fmt.Println("=== Statement-level query tags ===") + ctx = driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "team": "data-eng", + "application": "etl-pipeline", + "env": "production", + }) + err = db.QueryRowContext(ctx, "SELECT 2").Scan(&result) + if err != nil { + log.Printf("err: %v\n", err) + } else { + fmt.Println(result) + } + + // Example 3: Different query tags for a different statement + fmt.Println("=== Different statement-level query tags ===") + ctx = driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "team": "analytics", + "job": "daily-report", + }) + err = db.QueryRowContext(ctx, "SELECT 3").Scan(&result) + if err != nil { + log.Printf("err: %v\n", err) + } else { + fmt.Println(result) } - fmt.Println(result) } diff --git a/query_tags.go b/query_tags.go new file mode 100644 index 00000000..6f7ec844 --- /dev/null +++ b/query_tags.go @@ -0,0 +1,32 @@ +package dbsql + +import "strings" + +// SerializeQueryTags converts a map of query tags to the wire format string. +// The format is comma-separated key:value pairs (e.g., "team:engineering,app:etl"). +// +// Escaping rules (consistent with Python and NodeJS connectors): +// - Keys: only backslashes are escaped +// - Values: backslashes, colons, and commas are escaped with a leading backslash +// - Empty string values result in just the key being emitted (no colon) +// +// Returns empty string if the map is nil or empty. +func SerializeQueryTags(tags map[string]string) string { + if len(tags) == 0 { + return "" + } + + parts := make([]string, 0, len(tags)) + for k, v := range tags { + escapedKey := strings.ReplaceAll(k, `\`, `\\`) + if v == "" { + parts = append(parts, escapedKey) + } else { + escapedValue := strings.ReplaceAll(v, `\`, `\\`) + escapedValue = strings.ReplaceAll(escapedValue, `:`, `\:`) + escapedValue = strings.ReplaceAll(escapedValue, `,`, `\,`) + parts = append(parts, escapedKey+":"+escapedValue) + } + } + return strings.Join(parts, ",") +} diff --git a/query_tags_test.go b/query_tags_test.go new file mode 100644 index 00000000..fb5ee1d4 --- /dev/null +++ b/query_tags_test.go @@ -0,0 +1,90 @@ +package dbsql + +import ( + "sort" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSerializeQueryTags(t *testing.T) { + t.Parallel() + + t.Run("nil input returns empty string", func(t *testing.T) { + result := SerializeQueryTags(nil) + assert.Equal(t, "", result) + }) + + t.Run("empty map returns empty string", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{}) + assert.Equal(t, "", result) + }) + + t.Run("single tag", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"team": "engineering"}) + assert.Equal(t, "team:engineering", result) + }) + + t.Run("multiple tags", func(t *testing.T) { + // Go map iteration order is non-deterministic, so we check the parts + result := SerializeQueryTags(map[string]string{"team": "engineering", "app": "etl"}) + parts := strings.Split(result, ",") + sort.Strings(parts) + assert.Equal(t, []string{"app:etl", "team:engineering"}, parts) + }) + + t.Run("empty value emits key only without colon", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"flag": ""}) + assert.Equal(t, "flag", result) + }) + + t.Run("mixed empty and non-empty values", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"team": "eng", "flag": "", "app": "etl"}) + parts := strings.Split(result, ",") + sort.Strings(parts) + assert.Equal(t, []string{"app:etl", "flag", "team:eng"}, parts) + }) + + t.Run("escape backslash in value", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"path": `a\b`}) + assert.Equal(t, `path:a\\b`, result) + }) + + t.Run("escape colon in value", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"url": "http://host"}) + assert.Equal(t, `url:http\://host`, result) + }) + + t.Run("escape comma in value", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"list": "a,b"}) + assert.Equal(t, `list:a\,b`, result) + }) + + t.Run("escape multiple special chars in value", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"val": `a\b:c,d`}) + assert.Equal(t, `val:a\\b\:c\,d`, result) + }) + + t.Run("escape backslash in key", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{`a\b`: "value"}) + assert.Equal(t, `a\\b:value`, result) + }) + + t.Run("escape backslash in key with empty value", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{`a\b`: ""}) + assert.Equal(t, `a\\b`, result) + }) + + t.Run("colons and commas in keys are not escaped", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"key:name": "value"}) + assert.Equal(t, "key:name:value", result) + }) + + t.Run("all empty values", func(t *testing.T) { + result := SerializeQueryTags(map[string]string{"key1": "", "key2": "", "key3": ""}) + parts := strings.Split(result, ",") + sort.Strings(parts) + assert.Equal(t, []string{"key1", "key2", "key3"}, parts) + }) +} From 95776cb1df20c273c0438716190c70c8287d7096 Mon Sep 17 00:00:00 2001 From: Jooho Yeo Date: Wed, 1 Apr 2026 20:55:35 -0700 Subject: [PATCH 2/3] Add test for session-level and statement-level query tags coexistence Verifies that session-level tags (TOpenSessionReq.Configuration) and statement-level tags (TExecuteStatementReq.ConfOverlay) are independent: session params don't leak into ConfOverlay, and statement-level tags are correctly set even when session-level tags exist. Co-authored-by: Isaac Signed-off-by: Jooho Yeo --- connection_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/connection_test.go b/connection_test.go index 2e005f69..a9a0e8b1 100644 --- a/connection_test.go +++ b/connection_test.go @@ -607,6 +607,94 @@ func TestConn_executeStatement_QueryTags(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "flag", capturedReq.ConfOverlay["query_tags"]) }) + + t.Run("session-level and statement-level query tags coexist", func(t *testing.T) { + // Session-level tags are sent via TOpenSessionReq.Configuration at connect time. + // Statement-level tags are sent via TExecuteStatementReq.ConfOverlay at query time. + // They are independent fields on different requests, so both should work together. + + var capturedOpenReq *cli_service.TOpenSessionReq + var capturedExecReq *cli_service.TExecuteStatementReq + + testClient := &client.TestClient{ + FnOpenSession: func(ctx context.Context, req *cli_service.TOpenSessionReq) (*cli_service.TOpenSessionResp, error) { + capturedOpenReq = req + return &cli_service.TOpenSessionResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + SessionHandle: &cli_service.TSessionHandle{ + SessionId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + }, + }, + }, nil + }, + FnExecuteStatement: func(ctx context.Context, req *cli_service.TExecuteStatementReq) (*cli_service.TExecuteStatementResp, error) { + capturedExecReq = req + return &cli_service.TExecuteStatementResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationHandle: &cli_service.TOperationHandle{ + OperationId: &cli_service.THandleIdentifier{ + GUID: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}, + Secret: []byte("secret"), + }, + }, + DirectResults: &cli_service.TSparkDirectResults{ + OperationStatus: &cli_service.TGetOperationStatusResp{ + Status: &cli_service.TStatus{ + StatusCode: cli_service.TStatusCode_SUCCESS_STATUS, + }, + OperationState: cli_service.TOperationStatePtr(cli_service.TOperationState_FINISHED_STATE), + }, + }, + }, nil + }, + } + + // Simulate what connector.Connect() does: pass session params to OpenSession + sessionParams := map[string]string{ + "QUERY_TAGS": "team:platform,env:prod", + "ansi_mode": "false", + } + protocolVersion := int64(cli_service.TProtocolVersion_SPARK_CLI_SERVICE_PROTOCOL_V8) + session, err := testClient.OpenSession(context.Background(), &cli_service.TOpenSessionReq{ + ClientProtocolI64: &protocolVersion, + Configuration: sessionParams, + }) + assert.NoError(t, err) + + // Verify session-level tags were sent in OpenSession + assert.Equal(t, "team:platform,env:prod", capturedOpenReq.Configuration["QUERY_TAGS"]) + assert.Equal(t, "false", capturedOpenReq.Configuration["ansi_mode"]) + + // Create conn with session that has session-level tags + cfg := config.WithDefaults() + cfg.SessionParams = sessionParams + testConn := &conn{ + session: session, + client: testClient, + cfg: cfg, + } + + // Execute with statement-level tags + ctx := driverctx.NewContextWithQueryTags(context.Background(), map[string]string{ + "job": "nightly-etl", + }) + _, err = testConn.executeStatement(ctx, "SELECT 1", nil) + assert.NoError(t, err) + + // Statement-level tags should be in ConfOverlay + assert.Equal(t, "job:nightly-etl", capturedExecReq.ConfOverlay["query_tags"]) + + // ConfOverlay should ONLY have query_tags, not session params + _, hasAnsiMode := capturedExecReq.ConfOverlay["ansi_mode"] + assert.False(t, hasAnsiMode, "session params should not leak into ConfOverlay") + _, hasSessionQueryTags := capturedExecReq.ConfOverlay["QUERY_TAGS"] + assert.False(t, hasSessionQueryTags, "session-level QUERY_TAGS should not be in ConfOverlay") + }) } func TestConn_pollOperation(t *testing.T) { From 04704e10c675e2a4f9cf1df703ec9ff9ed62f589 Mon Sep 17 00:00:00 2001 From: Jooho Yeo Date: Wed, 1 Apr 2026 21:23:03 -0700 Subject: [PATCH 3/3] Add WithQueryTags connector option for session-level map support Addresses review feedback from jiabin-hu: 1. Add WithQueryTags(map[string]string) as a connector option that accepts a structured map and serializes it internally, consistent with the statement-level API and the Python connector approach (databricks/databricks-sql-python@e916f71). 2. Context values pattern is the idiomatic Go approach for per-request metadata in database/sql drivers (same pattern used by ConnId, CorrelationId, QueryId, and StagingInfo in this driver). Co-authored-by: Isaac Signed-off-by: Jooho Yeo --- connector.go | 17 +++++++++++ connector_test.go | 59 +++++++++++++++++++++++++++++++++++++ examples/query_tags/main.go | 14 +++++---- 3 files changed, 84 insertions(+), 6 deletions(-) diff --git a/connector.go b/connector.go index 1f77ac3f..41772fd1 100644 --- a/connector.go +++ b/connector.go @@ -235,6 +235,23 @@ func WithSessionParams(params map[string]string) ConnOption { } } +// WithQueryTags sets session-level query tags from a map. +// Tags are serialized and passed as QUERY_TAGS in the session configuration. +// All queries in the session will carry these tags unless overridden at the statement level. +// This is the preferred way to set session-level query tags, as it handles serialization +// and escaping automatically (consistent with the statement-level API). +func WithQueryTags(tags map[string]string) ConnOption { + return func(c *config.Config) { + serialized := SerializeQueryTags(tags) + if serialized != "" { + if c.SessionParams == nil { + c.SessionParams = make(map[string]string) + } + c.SessionParams["QUERY_TAGS"] = serialized + } + } +} + // WithSkipTLSHostVerify disables the verification of the hostname in the TLS certificate. // WARNING: // When this option is used, TLS is susceptible to machine-in-the-middle attacks. diff --git a/connector_test.go b/connector_test.go index bba5db1f..c89b74e0 100644 --- a/connector_test.go +++ b/connector_test.go @@ -268,6 +268,65 @@ func TestNewConnector(t *testing.T) { }) } +func TestWithQueryTags(t *testing.T) { + t.Run("WithQueryTags serializes map into SessionParams QUERY_TAGS", func(t *testing.T) { + con, err := NewConnector( + WithQueryTags(map[string]string{ + "team": "data-eng", + }), + ) + require.NoError(t, err) + coni, ok := con.(*connector) + require.True(t, ok) + assert.Equal(t, "team:data-eng", coni.cfg.SessionParams["QUERY_TAGS"]) + }) + + t.Run("WithQueryTags with multiple tags", func(t *testing.T) { + con, err := NewConnector( + WithQueryTags(map[string]string{ + "team": "eng", + "app": "etl", + }), + ) + require.NoError(t, err) + coni, ok := con.(*connector) + require.True(t, ok) + // Map iteration is non-deterministic + qt := coni.cfg.SessionParams["QUERY_TAGS"] + assert.True(t, qt == "team:eng,app:etl" || qt == "app:etl,team:eng", "got: %s", qt) + }) + + t.Run("WithQueryTags with empty map does not set QUERY_TAGS", func(t *testing.T) { + con, err := NewConnector( + WithQueryTags(map[string]string{}), + ) + require.NoError(t, err) + coni, ok := con.(*connector) + require.True(t, ok) + _, exists := coni.cfg.SessionParams["QUERY_TAGS"] + assert.False(t, exists) + }) + + t.Run("WithQueryTags overrides WithSessionParams QUERY_TAGS", func(t *testing.T) { + con, err := NewConnector( + WithSessionParams(map[string]string{ + "QUERY_TAGS": "old:value", + "ansi_mode": "false", + }), + WithQueryTags(map[string]string{ + "team": "new-team", + }), + ) + require.NoError(t, err) + coni, ok := con.(*connector) + require.True(t, ok) + // WithQueryTags should override the QUERY_TAGS from WithSessionParams + assert.Equal(t, "team:new-team", coni.cfg.SessionParams["QUERY_TAGS"]) + // Other session params should be preserved + assert.Equal(t, "false", coni.cfg.SessionParams["ansi_mode"]) + }) +} + type mockRoundTripper struct{} var _ http.RoundTripper = mockRoundTripper{} diff --git a/examples/query_tags/main.go b/examples/query_tags/main.go index d26b1c67..750a5ede 100644 --- a/examples/query_tags/main.go +++ b/examples/query_tags/main.go @@ -22,15 +22,17 @@ func main() { log.Fatal(err.Error()) } - // Session-level query tags: applied to all queries in this session. + // Connection-level query tags: applied to all queries in this session. + // WithQueryTags accepts a map and handles serialization automatically. connector, err := dbsql.NewConnector( dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")), dbsql.WithPort(port), dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")), dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")), - dbsql.WithSessionParams(map[string]string{ - "QUERY_TAGS": "team:engineering,test:query-tags,driver:go", - "ansi_mode": "false", + dbsql.WithQueryTags(map[string]string{ + "team": "engineering", + "test": "query-tags", + "driver": "go", }), ) if err != nil { @@ -40,8 +42,8 @@ func main() { db := sql.OpenDB(connector) defer db.Close() - // Example 1: Session-level query tags (set during connection) - fmt.Println("=== Session-level query tags ===") + // Example 1: Connection-level query tags (set during connection) + fmt.Println("=== Connection-level query tags ===") ctx := context.Background() var result int err = db.QueryRowContext(ctx, "SELECT 1").Scan(&result)