diff --git a/db.go b/db.go index 4e07588..a88ed5f 100644 --- a/db.go +++ b/db.go @@ -101,7 +101,11 @@ func (m *DbMap) createIndexImpl(ctx context.Context, dialect reflect.Type, s.WriteString(" unique") } s.WriteString(" index") - s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) + s.WriteString(fmt.Sprintf( + " %s on %s", + m.Dialect.QuoteField(index.IndexName), + m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName), + )) if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) } @@ -129,10 +133,14 @@ func (t *TableMap) DropIndex(ctx context.Context, name string) error { for _, idx := range t.indexes { if idx.IndexName == name { s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName)) + s.WriteString(fmt.Sprintf("DROP INDEX %s", t.dbmap.Dialect.QuoteField(idx.IndexName))) if dname := dialect.Name(); dname == "MySQLDialect" { - s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName)) + s.WriteString(fmt.Sprintf( + " %s %s", + t.dbmap.Dialect.DropIndexSuffix(), + t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName), + )) } s.WriteString(";") _, e := t.dbmap.ExecContext(ctx, s.String()) diff --git a/dialect_mysql.go b/dialect_mysql.go index 1dfc2be..adead48 100644 --- a/dialect_mysql.go +++ b/dialect_mysql.go @@ -146,7 +146,7 @@ func (d MySQLDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, inse } func (d MySQLDialect) QuoteField(f string) string { - return "`" + f + "`" + return "`" + strings.ReplaceAll(f, "`", "``") + "`" } func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { @@ -154,7 +154,7 @@ func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_mysql_test.go b/dialect_mysql_test.go index 966162e..660ae85 100644 --- a/dialect_mysql_test.go +++ b/dialect_mysql_test.go @@ -141,6 +141,7 @@ func TestMySQLDialect(t *testing.T) { o.Spec("QuoteField", func(tcx testContext) { tcx.expect(tcx.dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) + tcx.expect(tcx.dialect.QuoteField("fo`o")).To(matchers.Equal("`fo``o`")) }) o.Group("QuotedTableForQuery", func() { @@ -149,7 +150,8 @@ func TestMySQLDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx testContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("`foo`.`bar`")) + tcx.expect(tcx.dialect.QuotedTableForQuery("fo`o", "ba`r")).To(matchers.Equal("`fo``o`.`ba``r`")) }) }) diff --git a/dialect_postgres.go b/dialect_postgres.go index 937f81e..02113d1 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -124,9 +124,9 @@ func (d PostgresDialect) InsertAutoIncrToTarget(ctx context.Context, exec SqlExe func (d PostgresDialect) QuoteField(f string) string { if d.LowercaseFields { - return `"` + strings.ToLower(f) + `"` + f = strings.ToLower(f) } - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { @@ -134,7 +134,7 @@ func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string return d.QuoteField(table) } - return schema + "." + d.QuoteField(table) + return d.QuoteField(schema) + "." + d.QuoteField(table) } func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 4a2f674..25e4604 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -120,6 +120,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("By default, case is preserved", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) tcx.expect(tcx.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"o`)).To(matchers.Equal(`"Fo""o"`)) }) o.Group("With LowercaseFields set to true", func() { @@ -130,6 +131,7 @@ func TestPostgresDialect(t *testing.T) { o.Spec("fields are lowercased", func(tcx postgresTestContext) { tcx.expect(tcx.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) + tcx.expect(tcx.dialect.QuoteField(`Fo"O`)).To(matchers.Equal(`"fo""o"`)) }) }) }) @@ -140,7 +142,8 @@ func TestPostgresDialect(t *testing.T) { }) o.Spec("with a supplied schema", func(tcx postgresTestContext) { - tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`"foo"."bar"`)) + tcx.expect(tcx.dialect.QuotedTableForQuery(`fo"o`, `ba"r`)).To(matchers.Equal(`"fo""o"."ba""r"`)) }) }) diff --git a/dialect_sqlite.go b/dialect_sqlite.go index a0ba40d..d422174 100644 --- a/dialect_sqlite.go +++ b/dialect_sqlite.go @@ -8,6 +8,7 @@ import ( "context" "fmt" "reflect" + "strings" ) type SqliteDialect struct { @@ -92,7 +93,7 @@ func (d SqliteDialect) InsertAutoIncr(ctx context.Context, exec SqlExecutor, ins } func (d SqliteDialect) QuoteField(f string) string { - return `"` + f + `"` + return `"` + strings.ReplaceAll(f, `"`, `""`) + `"` } // sqlite does not have schemas like PostgreSQL does, so just escape it like normal diff --git a/identifier_quote_test.go b/identifier_quote_test.go new file mode 100644 index 0000000..4300c70 --- /dev/null +++ b/identifier_quote_test.go @@ -0,0 +1,275 @@ +package borp + +import ( + "context" + "database/sql" + "database/sql/driver" + "errors" + "sync" + "testing" + + _ "github.com/mattn/go-sqlite3" +) + +func TestSqliteDialectEscapesIdentifierQuotes(t *testing.T) { + dialect := SqliteDialect{} + got := dialect.QuoteField(`fo"o`) + want := `"fo""o"` + if got != want { + t.Fatalf("QuoteField() = %q, want %q", got, want) + } + got = dialect.QuotedTableForQuery("", `ta"ble`) + want = `"ta""ble"` + if got != want { + t.Fatalf("QuotedTableForQuery() = %q, want %q", got, want) + } +} + +func TestSqliteQuotedTableNameCannotRewriteUpdateTarget(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Exec("CREATE TABLE victim (id integer primary key, value text, admin integer)") + if err != nil { + t.Fatal(err) + } + _, err = db.Exec("INSERT INTO victim (id, value, admin) VALUES (1, 'unchanged', 0)") + if err != nil { + t.Fatal(err) + } + + type row struct { + ID int64 `db:"ID"` + Value string `db:"Value"` + } + + dbmap := &DbMap{Db: db, Dialect: SqliteDialect{}} + injectedTable := `victim" SET admin = 1 WHERE ? <> ? -- ` + dbmap.AddTableWithName(row{}, injectedTable).SetKeys(false, "ID") + + _, err = dbmap.Update(context.Background(), &row{ID: 1, Value: "unused"}) + if err == nil { + t.Fatal("Update succeeded for escaped malicious table name") + } + + var admin int + err = db.QueryRow("SELECT admin FROM victim WHERE id = 1").Scan(&admin) + if err != nil { + t.Fatal(err) + } + if admin != 0 { + t.Fatalf("victim.admin = %d, want 0", admin) + } +} + +type identifierCapturedExec struct { + query string + args []driver.NamedValue +} + +var identifierCaptureState = struct { + sync.Mutex + registered sync.Once + execs []identifierCapturedExec +}{} + +type identifierCaptureDriver struct{} + +func (identifierCaptureDriver) Open(string) (driver.Conn, error) { + return identifierCaptureConn{}, nil +} + +type identifierCaptureConn struct{} + +func (identifierCaptureConn) Prepare(string) (driver.Stmt, error) { + return nil, errors.New("identifier capture driver does not prepare statements") +} + +func (identifierCaptureConn) Close() error { + return nil +} + +func (identifierCaptureConn) Begin() (driver.Tx, error) { + return nil, errors.New("identifier capture driver does not begin transactions") +} + +func (identifierCaptureConn) ExecContext( + _ context.Context, + query string, + args []driver.NamedValue, +) (driver.Result, error) { + argsCopy := append([]driver.NamedValue(nil), args...) + + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + identifierCaptureState.execs = append(identifierCaptureState.execs, identifierCapturedExec{ + query: query, + args: argsCopy, + }) + return driver.RowsAffected(0), nil +} + +func newIdentifierCaptureDbMap(t *testing.T, dialect Dialect) *DbMap { + t.Helper() + + identifierCaptureState.registered.Do(func() { + sql.Register("borp_identifier_capture", identifierCaptureDriver{}) + }) + + identifierCaptureState.Lock() + identifierCaptureState.execs = nil + identifierCaptureState.Unlock() + + db, err := sql.Open("borp_identifier_capture", "") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + closeErr := db.Close() + if closeErr != nil { + t.Fatal(closeErr) + } + }) + + return &DbMap{Db: db, Dialect: dialect} +} + +func identifierCapturedExecs() []identifierCapturedExec { + identifierCaptureState.Lock() + defer identifierCaptureState.Unlock() + return append([]identifierCapturedExec(nil), identifierCaptureState.execs...) +} + +func requireIdentifierCapturedQuery(t *testing.T, want string) { + t.Helper() + execs := identifierCapturedExecs() + if len(execs) != 1 { + t.Fatalf("expected one captured exec, got %d: %+v", len(execs), execs) + } + if execs[0].query != want { + t.Fatalf("generated %q, want %q", execs[0].query, want) + } +} + +type identifierUpdatedRow struct { + ID int64 `db:"ID"` + Value string `db:"Value"` +} + +type identifierIndexedRow struct { + ID int64 `db:"ID"` +} + +type identifierSQLCase struct { + name string + dialect Dialect + schema string + table string + indexName string + indexType string + wantUpdate string + wantCreateIndex string + wantDropIndex string +} + +var identifierSQLCases = []identifierSQLCase{ + { + name: "sqlite", + dialect: SqliteDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + wantUpdate: `update "security""rows" set "Value"=? where "ID"=?;`, + wantCreateIndex: `create index "idx""name" on "security""rows" ("ID");`, + wantDropIndex: `DROP INDEX "idx""name";`, + }, + { + name: "postgres", + dialect: PostgresDialect{}, + schema: `sche"ma`, + table: `security"rows`, + indexName: `idx"name`, + indexType: "btree", + wantUpdate: `update "sche""ma"."security""rows" set ` + + `"Value"=$1 where "ID"=$2;`, + wantCreateIndex: `create index "idx""name" on ` + + `"sche""ma"."security""rows" using btree ("ID");`, + wantDropIndex: `DROP INDEX "idx""name";`, + }, + { + name: "mysql", + dialect: MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}, + schema: "sche`ma", + table: "security`rows", + indexName: "idx`name", + indexType: "Btree", + wantUpdate: "update `sche``ma`.`security``rows` set " + + "`Value`=? where `ID`=?;", + wantCreateIndex: "create index `idx``name` on " + + "`sche``ma`.`security``rows` (`ID`) using Btree;", + wantDropIndex: "DROP INDEX `idx``name` on " + + "`sche``ma`.`security``rows`;", + }, +} + +func addIdentifierIndexTable(dbmap *DbMap, tc identifierSQLCase) *TableMap { + table := dbmap.AddTableWithNameAndSchema(identifierIndexedRow{}, tc.schema, tc.table) + table.SetKeys(false, "ID") + table.AddIndex(tc.indexName, tc.indexType, []string{"ID"}) + return table +} + +func TestUpdateQuotesIdentifierMetadata(t *testing.T) { + for _, tc := range identifierSQLCases { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + table := dbmap.AddTableWithNameAndSchema(identifierUpdatedRow{}, tc.schema, tc.table) + table.SetKeys(true, "ID") + + _, err := dbmap.Update( + context.Background(), + &identifierUpdatedRow{ID: 1, Value: "unused"}, + ) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.wantUpdate) + }) + } +} + +func TestCreateIndexQuotesIdentifierMetadata(t *testing.T) { + for _, tc := range identifierSQLCases { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + addIdentifierIndexTable(dbmap, tc) + + err := dbmap.CreateIndex(context.Background()) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.wantCreateIndex) + }) + } +} + +func TestDropIndexQuotesIdentifierMetadata(t *testing.T) { + for _, tc := range identifierSQLCases { + t.Run(tc.name, func(t *testing.T) { + dbmap := newIdentifierCaptureDbMap(t, tc.dialect) + table := addIdentifierIndexTable(dbmap, tc) + + err := table.DropIndex(context.Background(), tc.indexName) + if err != nil { + t.Fatal(err) + } + + requireIdentifierCapturedQuery(t, tc.wantDropIndex) + }) + } +}