From d97af6dd9f78553b77893498141bd1212b029b64 Mon Sep 17 00:00:00 2001 From: Akshith Gunasekaran Date: Fri, 5 Jun 2026 15:24:54 -0700 Subject: [PATCH 1/3] Escape generated SQL identifiers --- db.go | 14 ++- dialect_mysql.go | 4 +- dialect_mysql_test.go | 4 +- dialect_postgres.go | 6 +- dialect_postgres_test.go | 5 +- dialect_sqlite.go | 3 +- identifier_quote_test.go | 196 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 221 insertions(+), 11 deletions(-) create mode 100644 identifier_quote_test.go diff --git a/db.go b/db.go index 4e07588f..a88ed5f1 100644 --- a/db.go +++ b/db.go @@ -101,7 +101,11 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type, s.WriteString(" unique") } s.WriteString(" index") - s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) + s.WriteString(fmt.Sprintf( + " %s on %s", + m.Dialect.QuoteField(index.IndexName), + m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName), + )) if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) } @@ -129,10 +133,14 @@ func (t *TableMap) DropIndex(ctx context.Context, name string) error { for _, idx := range t.indexes { if idx.IndexName == name { s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName)) + s.WriteString(fmt.Sprintf("DROP INDEX %s", t.dbmap.Dialect.QuoteField(idx.IndexName))) if dname := dialect.Name(); dname == "MySQLDialect" { - s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName)) + s.WriteString(fmt.Sprintf( + " %s %s", + t.dbmap.Dialect.DropIndexSuffix(), + t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName), + )) } s.WriteString(";") _, e := t.dbmap.ExecContext(ctx, s.String()) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1dfc2be6..adead487 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -146,7 +146,7 @@ func (d MySQLDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, inse } func (d MySQLDialect) QuoteField(f string) string { - return "`" + f + "`" + return "`" + strings.ReplaceAll(f, "`", "``") + "`" } func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { @@ -154,7 +154,7 @@ func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 966162ef..660ae85a 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -141,6 +141,7 @@ func TestMySQLDialect(t *testing.T) { o.Spec("QuoteField", func(tcx testContext) { tcx.expect(tcx.dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) + tcx.expect(tcx.dialect.QuoteField("fo`o")).To(matchers.Equal("`fo``o`")) }) o.Group("QuotedTableForQuery", func() { @@ -149,7 +150,8 @@ func TestMySQLDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx testContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("`foo`.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("fo`o", "ba`r")).To(matchers.Equal("`fo``o`.`ba``r`")) }) }) diff --git a/dialect_postgres.go b/dialect_postgres.go index 937f81e6..02113d1f 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -124,9 +124,9 @@ func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExe func (d PostgresDialect) QuoteField(f string) string { if d.LowercaseFields { - return `"` + strings.ToLower(f) + `"` + f = strings.ToLower(f) } - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { @@ -134,7 +134,7 @@ func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 4a2f6749..25e46048 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -120,6 +120,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("By default, case is preserved", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) tcx.expect(tcx.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"o`)).To(matchers.Equal(`"Fo""o"`)) }) o.Group("With LowercaseFields set to true", func() { @@ -130,6 +131,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("fields are lowercased", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"O`)).To(matchers.Equal(`"fo""o"`)) }) }) }) @@ -140,7 +142,8 @@ func TestPostgresDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx postgresTestContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`"foo"."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery(`fo"o`, `ba"r`)).To(matchers.Equal(`"fo""o"."ba""r"`)) }) }) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index a0ba40d6..d4221740 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "reflect" + "strings" ) type SqliteDialect struct { @@ -92,7 +93,7 @@ func (d SqliteDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, ins } func (d SqliteDialect) QuoteField(f string) string { - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } // sqlite does not have schemas like PostgreSQL does, so just escape it like normal diff --git a/identifier_quote_test.go b/identifier_quote_test.go new file mode 100644 index 00000000..f250f85b --- /dev/null +++ b/identifier_quote_test.go @@ -0,0 +1,196 @@ +package borp + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "sync" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +type identifierCapturedExec struct { + query string + args []driver.NamedValue +} + +var identifierCaptureState = struct { + sync.Mutex + registered sync.Once + execs []identifierCapturedExec +}{} + +type identifierCaptureDriver struct{} + +func (identifierCaptureDriver) Open(string) (driver.Conn, error) { + return identifierCaptureConn{}, nil +} + +type identifierCaptureConn struct{} + +func (identifierCaptureConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("identifier capture driver does not prepare statements") +} + +func (identifierCaptureConn) Close() error { + return nil +} + +func (identifierCaptureConn) Begin() (driver.Tx, error) { + return nil, errors.New("identifier capture driver does not begin transactions") +} + +func (identifierCaptureConn) ExecContext( + _ context.Context, + query string, + args []driver.NamedValue, +) (driver.Result, error) { + argsCopy := append([]driver.NamedValue(nil), args...) + + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + identifierCaptureState.execs = append(identifierCaptureState.execs, identifierCapturedExec{ + query: query, + args: argsCopy, + }) + return driver.RowsAffected(0), nil +} + +func newIdentifierCaptureDbMap(t *testing.T, dialect Dialect) *DbMap { + t.Helper() + + identifierCaptureState.registered.Do(func() { + sql.Register("borp_identifier_capture", identifierCaptureDriver{}) + }) + + identifierCaptureState.Lock() + identifierCaptureState.execs = nil + identifierCaptureState.Unlock() + + db, err := sql.Open("borp_identifier_capture", "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + closeErr := db.Close() + if closeErr != nil { + t.Fatal(closeErr) + } + }) + + return &DbMap{Db: db, Dialect: dialect} +} + +func identifierCapturedExecs() []identifierCapturedExec { + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + return append([]identifierCapturedExec(nil), identifierCaptureState.execs...) +} + +func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) { + dialect := SqliteDialect{} + got := dialect.QuoteField(`fo"o`) + want := `"fo""o"` + if got != want { + t.Fatalf("QuoteField() = %q, want %q", got, want) + } + got = dialect.QuotedTableForQuery("", `ta"ble`) + want = `"ta""ble"` + if got != want { + t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want) + } +} + +func TestQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)") + if err != nil { + t.Fatal(err) + } + + type row struct { + ID int64 `db:"ID"` + Value string `db:"Value"` + } + + dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} + injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- ` + dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID") + + _, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"}) + if err == nil { + t.Fatal("Update succeeded for escaped malicious table name") + } + + var admin int + err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin) + if err != nil { + t.Fatal(err) + } + if admin != 0 { + t.Fatalf("victim.admin = %d, want 0", admin) + } +} + +func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) { + type indexedRow struct { + ID int64 `db:"ID"` + } + + dbmap := newIdentifierCaptureDbMap(t, PostgresDialect{}) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, `sche"ma`, `security"rows`) + table.SetKeys(false, "ID") + table.AddIndex(`idx"name`, "btree", []string{"ID"}) + + err := dbmap.CreateIndex(context.Background()) + if err != nil { + t.Fatal(err) + } + + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) + } + + want := `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");` + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) + } +} + +func TestDropIndexQuotesIdentifierMetadata(t *testing.T) { + type indexedRow struct { + ID int64 `db:"ID"` + } + + dbmap := newIdentifierCaptureDbMap(t, MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, "sche`ma", "security`rows") + table.SetKeys(false, "ID") + table.AddIndex("idx`name", "Btree", []string{"ID"}) + + err := table.DropIndex(context.Background(), "idx`name") + if err != nil { + t.Fatal(err) + } + + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) + } + + want := "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;" + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) + } +} From a7171261842694bbb553b8e107a0ac4c142ba871 Mon Sep 17 00:00:00 2001 From: Akshith Gunasekaran Date: Thu, 11 Jun 2026 00:44:19 -0700 Subject: [PATCH 2/3] Address identifier quote review nits --- identifier_quote_test.go | 271 +++++++++++++++++++++++++++++---------- 1 file changed, 200 insertions(+), 71 deletions(-) diff --git a/identifier_quote_test.go b/identifier_quote_test.go index f250f85b..9dc34600 100644 --- a/identifier_quote_test.go +++ b/identifier_quote_test.go @@ -11,6 +11,60 @@ import ( _ "github.com/mattn/go-sqlite3" ) +func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) { + dialect := SqliteDialect{} + got := dialect.QuoteField(`fo"o`) + want := `"fo""o"` + if got != want { + t.Fatalf("QuoteField() = %q, want %q", got, want) + } + got = dialect.QuotedTableForQuery("", `ta"ble`) + want = `"ta""ble"` + if got != want { + t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want) + } +} + +func TestSqliteQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)") + if err != nil { + t.Fatal(err) + } + + type row struct { + ID int64 `db:"ID"` + Value string `db:"Value"` + } + + dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} + injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- ` + dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID") + + _, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"}) + if err == nil { + t.Fatal("Update succeeded for escaped malicious table name") + } + + var admin int + err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin) + if err != nil { + t.Fatal(err) + } + if admin != 0 { + t.Fatalf("victim.admin = %d, want 0", admin) + } +} + type identifierCapturedExec struct { query string args []driver.NamedValue @@ -89,57 +143,66 @@ func identifierCapturedExecs() []identifierCapturedExec { return append([]identifierCapturedExec(nil), identifierCaptureState.execs...) } -func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) { - dialect := SqliteDialect{} - got := dialect.QuoteField(`fo"o`) - want := `"fo""o"` - if got != want { - t.Fatalf("QuoteField() = %q, want %q", got, want) +func requireIdentifierCapturedQuery(t *testing.T, want string) { + t.Helper() + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) } - got = dialect.QuotedTableForQuery("", `ta"ble`) - want = `"ta""ble"` - if got != want { - t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want) + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) } } -func TestQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - t.Fatal(err) - } - defer db.Close() - - _, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)") - if err != nil { - t.Fatal(err) - } - _, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)") - if err != nil { - t.Fatal(err) - } - - type row struct { +func TestUpdateQuotesIdentifierMetadata(t *testing.T) { + type updatedRow struct { ID int64 `db:"ID"` Value string `db:"Value"` } - dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} - injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- ` - dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID") - - _, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"}) - if err == nil { - t.Fatal("Update succeeded for escaped malicious table name") + tests := []struct { + name string + dialect Dialect + schema string + table string + want string + }{ + { + name: "sqlite", + dialect: SqliteDialect{}, + schema: `sche"ma`, + table: `security"rows`, + want: `update "security""rows" set "Value"=? where "ID"=?;`, + }, + { + name: "postgres", + dialect: PostgresDialect{}, + schema: `sche"ma`, + table: `security"rows`, + want: `update "sche""ma"."security""rows" set "Value"=$1 where "ID"=$2;`, + }, + { + name: "mysql", + dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, + schema: "sche`ma", + table: "security`rows", + want: "update `sche``ma`.`security``rows` set `Value`=? where `ID`=?;", + }, } - var admin int - err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin) - if err != nil { - t.Fatal(err) - } - if admin != 0 { - t.Fatalf("victim.admin = %d, want 0", admin) + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + table := dbmap.AddTableWithNameAndSchema(updatedRow{}, tc.schema, tc.table) + table.SetKeys(true, "ID") + + _, err := dbmap.Update(context.Background(), &updatedRow{ID: 1, Value: "unused"}) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.want) + }) } } @@ -148,24 +211,57 @@ func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) { ID int64 `db:"ID"` } - dbmap := newIdentifierCaptureDbMap(t, PostgresDialect{}) - table := dbmap.AddTableWithNameAndSchema(indexedRow{}, `sche"ma`, `security"rows`) - table.SetKeys(false, "ID") - table.AddIndex(`idx"name`, "btree", []string{"ID"}) - - err := dbmap.CreateIndex(context.Background()) - if err != nil { - t.Fatal(err) + tests := []struct { + name string + dialect Dialect + schema string + table string + indexName string + indexType string + want string + }{ + { + name: "sqlite", + dialect: SqliteDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + want: `create index "idx""name" on "security""rows" ("ID");`, + }, + { + name: "postgres", + dialect: PostgresDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + indexType: "btree", + want: `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");`, + }, + { + name: "mysql", + dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, + schema: "sche`ma", + table: "security`rows", + indexName: "idx`name", + indexType: "Btree", + want: "create index `idx``name` on `sche``ma`.`security``rows` (`ID`) using Btree;", + }, } - execs := identifierCapturedExecs() - if len(execs) != 1 { - t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, tc.schema, tc.table) + table.SetKeys(false, "ID") + table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) - want := `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");` - if execs[0].query != want { - t.Fatalf("generated %q, want %q", execs[0].query, want) + err := dbmap.CreateIndex(context.Background()) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.want) + }) } } @@ -174,23 +270,56 @@ func TestDropIndexQuotesIdentifierMetadata(t *testing.T) { ID int64 `db:"ID"` } - dbmap := newIdentifierCaptureDbMap(t, MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}) - table := dbmap.AddTableWithNameAndSchema(indexedRow{}, "sche`ma", "security`rows") - table.SetKeys(false, "ID") - table.AddIndex("idx`name", "Btree", []string{"ID"}) - - err := table.DropIndex(context.Background(), "idx`name") - if err != nil { - t.Fatal(err) + tests := []struct { + name string + dialect Dialect + schema string + table string + indexName string + indexType string + want string + }{ + { + name: "sqlite", + dialect: SqliteDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + want: `DROP INDEX "idx""name";`, + }, + { + name: "postgres", + dialect: PostgresDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + indexType: "btree", + want: `DROP INDEX "idx""name";`, + }, + { + name: "mysql", + dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, + schema: "sche`ma", + table: "security`rows", + indexName: "idx`name", + indexType: "Btree", + want: "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;", + }, } - execs := identifierCapturedExecs() - if len(execs) != 1 { - t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) - } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + table := dbmap.AddTableWithNameAndSchema(indexedRow{}, tc.schema, tc.table) + table.SetKeys(false, "ID") + table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) - want := "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;" - if execs[0].query != want { - t.Fatalf("generated %q, want %q", execs[0].query, want) + err := table.DropIndex(context.Background(), tc.indexName) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.want) + }) } } From 4b7b15ea36892e4cf7bda80ef1287f924ffa62c7 Mon Sep 17 00:00:00 2001 From: Akshith Gunasekaran Date: Thu, 11 Jun 2026 14:16:04 -0700 Subject: [PATCH 3/3] Simplify identifier quote test matrix --- identifier_quote_test.go | 210 +++++++++++++++------------------------ 1 file changed, 80 insertions(+), 130 deletions(-) diff --git a/identifier_quote_test.go b/identifier_quote_test.go index 9dc34600..4300c704 100644 --- a/identifier_quote_test.go +++ b/identifier_quote_test.go @@ -154,172 +154,122 @@ func requireIdentifierCapturedQuery(t *testing.T, want string) { } } -func TestUpdateQuotesIdentifierMetadata(t *testing.T) { - type updatedRow struct { - ID int64 `db:"ID"` - Value string `db:"Value"` - } +type identifierUpdatedRow struct { + ID int64 `db:"ID"` + Value string `db:"Value"` +} - tests := []struct { - name string - dialect Dialect - schema string - table string - want string - }{ - { - name: "sqlite", - dialect: SqliteDialect{}, - schema: `sche"ma`, - table: `security"rows`, - want: `update "security""rows" set "Value"=? where "ID"=?;`, - }, - { - name: "postgres", - dialect: PostgresDialect{}, - schema: `sche"ma`, - table: `security"rows`, - want: `update "sche""ma"."security""rows" set "Value"=$1 where "ID"=$2;`, - }, - { - name: "mysql", - dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, - schema: "sche`ma", - table: "security`rows", - want: "update `sche``ma`.`security``rows` set `Value`=? where `ID`=?;", - }, - } +type identifierIndexedRow struct { + ID int64 `db:"ID"` +} + +type identifierSQLCase struct { + name string + dialect Dialect + schema string + table string + indexName string + indexType string + wantUpdate string + wantCreateIndex string + wantDropIndex string +} + +var identifierSQLCases = []identifierSQLCase{ + { + name: "sqlite", + dialect: SqliteDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + wantUpdate: `update "security""rows" set "Value"=? where "ID"=?;`, + wantCreateIndex: `create index "idx""name" on "security""rows" ("ID");`, + wantDropIndex: `DROP INDEX "idx""name";`, + }, + { + name: "postgres", + dialect: PostgresDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + indexType: "btree", + wantUpdate: `update "sche""ma"."security""rows" set ` + + `"Value"=$1 where "ID"=$2;`, + wantCreateIndex: `create index "idx""name" on ` + + `"sche""ma"."security""rows" using btree ("ID");`, + wantDropIndex: `DROP INDEX "idx""name";`, + }, + { + name: "mysql", + dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, + schema: "sche`ma", + table: "security`rows", + indexName: "idx`name", + indexType: "Btree", + wantUpdate: "update `sche``ma`.`security``rows` set " + + "`Value`=? where `ID`=?;", + wantCreateIndex: "create index `idx``name` on " + + "`sche``ma`.`security``rows` (`ID`) using Btree;", + wantDropIndex: "DROP INDEX `idx``name` on " + + "`sche``ma`.`security``rows`;", + }, +} - for _, tc := range tests { +func addIdentifierIndexTable(dbmap *DbMap, tc identifierSQLCase) *TableMap { + table := dbmap.AddTableWithNameAndSchema(identifierIndexedRow{}, tc.schema, tc.table) + table.SetKeys(false, "ID") + table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) + return table +} + +func TestUpdateQuotesIdentifierMetadata(t *testing.T) { + for _, tc := range identifierSQLCases { t.Run(tc.name, func(t *testing.T) { dbmap := newIdentifierCaptureDbMap(t, tc.dialect) - table := dbmap.AddTableWithNameAndSchema(updatedRow{}, tc.schema, tc.table) + table := dbmap.AddTableWithNameAndSchema(identifierUpdatedRow{}, tc.schema, tc.table) table.SetKeys(true, "ID") - _, err := dbmap.Update(context.Background(), &updatedRow{ID: 1, Value: "unused"}) + _, err := dbmap.Update( + context.Background(), + &identifierUpdatedRow{ID: 1, Value: "unused"}, + ) if err != nil { t.Fatal(err) } - requireIdentifierCapturedQuery(t, tc.want) + requireIdentifierCapturedQuery(t, tc.wantUpdate) }) } } func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) { - type indexedRow struct { - ID int64 `db:"ID"` - } - - tests := []struct { - name string - dialect Dialect - schema string - table string - indexName string - indexType string - want string - }{ - { - name: "sqlite", - dialect: SqliteDialect{}, - schema: `sche"ma`, - table: `security"rows`, - indexName: `idx"name`, - want: `create index "idx""name" on "security""rows" ("ID");`, - }, - { - name: "postgres", - dialect: PostgresDialect{}, - schema: `sche"ma`, - table: `security"rows`, - indexName: `idx"name`, - indexType: "btree", - want: `create index "idx""name" on "sche""ma"."security""rows" using btree ("ID");`, - }, - { - name: "mysql", - dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, - schema: "sche`ma", - table: "security`rows", - indexName: "idx`name", - indexType: "Btree", - want: "create index `idx``name` on `sche``ma`.`security``rows` (`ID`) using Btree;", - }, - } - - for _, tc := range tests { + for _, tc := range identifierSQLCases { t.Run(tc.name, func(t *testing.T) { dbmap := newIdentifierCaptureDbMap(t, tc.dialect) - table := dbmap.AddTableWithNameAndSchema(indexedRow{}, tc.schema, tc.table) - table.SetKeys(false, "ID") - table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) + addIdentifierIndexTable(dbmap, tc) err := dbmap.CreateIndex(context.Background()) if err != nil { t.Fatal(err) } - requireIdentifierCapturedQuery(t, tc.want) + requireIdentifierCapturedQuery(t, tc.wantCreateIndex) }) } } func TestDropIndexQuotesIdentifierMetadata(t *testing.T) { - type indexedRow struct { - ID int64 `db:"ID"` - } - - tests := []struct { - name string - dialect Dialect - schema string - table string - indexName string - indexType string - want string - }{ - { - name: "sqlite", - dialect: SqliteDialect{}, - schema: `sche"ma`, - table: `security"rows`, - indexName: `idx"name`, - want: `DROP INDEX "idx""name";`, - }, - { - name: "postgres", - dialect: PostgresDialect{}, - schema: `sche"ma`, - table: `security"rows`, - indexName: `idx"name`, - indexType: "btree", - want: `DROP INDEX "idx""name";`, - }, - { - name: "mysql", - dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, - schema: "sche`ma", - table: "security`rows", - indexName: "idx`name", - indexType: "Btree", - want: "DROP INDEX `idx``name` on `sche``ma`.`security``rows`;", - }, - } - - for _, tc := range tests { + for _, tc := range identifierSQLCases { t.Run(tc.name, func(t *testing.T) { dbmap := newIdentifierCaptureDbMap(t, tc.dialect) - table := dbmap.AddTableWithNameAndSchema(indexedRow{}, tc.schema, tc.table) - table.SetKeys(false, "ID") - table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) + table := addIdentifierIndexTable(dbmap, tc) err := table.DropIndex(context.Background(), tc.indexName) if err != nil { t.Fatal(err) } - requireIdentifierCapturedQuery(t, tc.want) + requireIdentifierCapturedQuery(t, tc.wantDropIndex) }) } }