diff --git a/lib/cli/traffic.go b/lib/cli/traffic.go index 7322cee70..e4ec489bc 100644 --- a/lib/cli/traffic.go +++ b/lib/cli/traffic.go @@ -63,6 +63,7 @@ func GetTrafficReplayCmd(ctx *Context) *cobra.Command { username := replayCmd.PersistentFlags().String("username", "", "the username to connect to TiDB for replay") password := replayCmd.PersistentFlags().String("password", "", "the password to connect to TiDB for replay") readonly := replayCmd.PersistentFlags().Bool("read-only", false, "only replay read-only queries, default is false") + format := replayCmd.PersistentFlags().String("format", "", "the format of traffic files") replayCmd.RunE = func(cmd *cobra.Command, args []string) error { username := *username if len(username) == 0 { @@ -83,6 +84,7 @@ func GetTrafficReplayCmd(ctx *Context) *cobra.Command { "username": username, "password": password, "readonly": strconv.FormatBool(*readonly), + "format": *format, }) resp, err := doRequest(cmd.Context(), ctx, http.MethodPost, "/api/traffic/replay", reader) if err != nil { diff --git a/pkg/server/api/traffic.go b/pkg/server/api/traffic.go index ed5425652..99a191506 100644 --- a/pkg/server/api/traffic.go +++ b/pkg/server/api/traffic.go @@ -99,6 +99,7 @@ func (h *Server) TrafficReplay(c *gin.Context) { } cfg.Username = c.PostForm("username") cfg.Password = c.PostForm("password") + cfg.Format = c.PostForm("format") cfg.ReadOnly = strings.EqualFold(c.PostForm("readonly"), "true") cfg.KeyFile = globalCfg.Security.EncryptionKeyPath diff --git a/pkg/sqlreplay/capture/capture.go b/pkg/sqlreplay/capture/capture.go index d2ef42e1a..e216aeba9 100644 --- a/pkg/sqlreplay/capture/capture.go +++ b/pkg/sqlreplay/capture/capture.go @@ -216,9 +216,10 @@ func (c *capture) collectCmds(bufCh chan<- *bytes.Buffer) { defer close(bufCh) buf := bytes.NewBuffer(make([]byte, 0, c.cfg.bufferCap)) + encoder := cmd.NewCmdEncoder(cmd.FormatNative) // Flush all commands even if the context is timeout. for command := range c.cmdCh { - if err := command.Encode(buf); err != nil { + if err := encoder.Encode(command, buf); err != nil { c.stop(errors.Wrapf(err, "failed to encode command")) continue } diff --git a/pkg/sqlreplay/cmd/audit_log_plugin.go b/pkg/sqlreplay/cmd/audit_log_plugin.go new file mode 100644 index 000000000..e391818b4 --- /dev/null +++ b/pkg/sqlreplay/cmd/audit_log_plugin.go @@ -0,0 +1,219 @@ +// Copyright 2025 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "strconv" + "strings" + "time" + + "github.com/pingcap/tiproxy/lib/util/errors" + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" + "github.com/siddontang/go/hack" +) + +const ( + auditPluginKeyTimeStamp = "TIMESTAMP" + auditPluginKeyDatabase = "DATABASES" + auditPluginKeySQL = "SQL_TEXT" + auditPluginKeyConnID = "CONNECTION_ID" + auditPluginKeyClass = "EVENT_CLASS" + auditPluginKeySubClass = "EVENT_SUBCLASS" + auditPluginKeyCommand = "COMMAND" + auditPluginKeyStmtType = "SQL_STATEMENTS" + + auditPluginClassGeneral = "GENERAL" + auditPluginClassTableAccess = "TABLE_ACCESS" + auditPluginClassConnect = "CONNECTION" + + auditPluginSubClassConnected = "Connected" + auditPluginSubClassDisconnect = "Disconnect" + + timeLayout = "2006/01/02 15:04:05.999 -07:00" +) + +func NewAuditLogPluginDecoder() *AuditLogPluginDecoder { + return &AuditLogPluginDecoder{} +} + +var _ CmdDecoder = (*AuditLogPluginDecoder)(nil) + +type AuditLogPluginDecoder struct { +} + +func (*AuditLogPluginDecoder) Decode(reader LineReader) (*Command, error) { + for { + line, filename, lineIdx, err := reader.ReadLine() + if err != nil { + return nil, err + } + kvs, err := parseLog(hack.String(line)) + if err != nil { + return nil, errors.Errorf("%s, line %d: %s", filename, lineIdx, err.Error()) + } + connStr := kvs[auditPluginKeyConnID] + if len(connStr) == 0 { + return nil, errors.Errorf("%s, line %d: no connection id in line: %s", filename, lineIdx, line) + } + connID, err := strconv.ParseUint(connStr, 10, 64) + if err != nil { + return nil, errors.Errorf("%s, line %d: parsing connection id failed: %s", filename, lineIdx, connStr) + } + tsStr := kvs[auditPluginKeyTimeStamp] + if len(tsStr) == 0 { + return nil, errors.Errorf("%s, line %d: no timestamp in line: '%s", filename, lineIdx, line) + } + startTs, err := time.Parse(timeLayout, tsStr) + if err != nil { + return nil, errors.Errorf("%s, line %d: parsing timestamp failed: %s", filename, lineIdx, tsStr) + } + var c *Command + eventClass := kvs[auditPluginKeyClass] + switch eventClass { + case auditPluginClassGeneral, auditPluginClassTableAccess: + c, err = parseGeneralEvent(kvs) + case auditPluginClassConnect: + c, err = parseConnectEvent(kvs) + default: + return nil, errors.Errorf("%s, line %d: unknown event class: %s", filename, lineIdx, eventClass) + } + if err != nil { + return c, err + } + // The log is ignored, skip. + if c == nil { + continue + } + c.Succeess = true + c.ConnID = connID + c.StartTs = startTs + return c, nil + } +} + +// All SQL_TEXT are converted into one line in audit log. +func parseLog(line string) (map[string]string, error) { + kv := make(map[string]string) + for idx := 0; idx < len(line); idx++ { + switch line[idx] { + case '[': + key, value, endIdx, err := parseInBracket(line[idx+1:]) + if err != nil { + return kv, err + } + idx += endIdx + 1 + if len(key) > 0 { + kv[key] = value + } + } + } + return kv, nil +} + +func parseInBracket(line string) (key, value string, idx int, err error) { + valueStart := 0 + for ; idx < len(line); idx++ { + switch line[idx] { + case ']': + value = line[valueStart:idx] + return + case '"', '\'': + endIdx := skipQuotes(line[idx+1:], line[idx] == '\'') + if endIdx == -1 { + return "", "", len(line), errors.Errorf("unterminated quote in line: %s", line[idx+1:]) + } + idx += endIdx + 1 + case '=': + if idx == 0 { + return "", "", idx, errors.Errorf("empty key in line: %s", line) + } + // only care about the first '=' + if len(key) == 0 { + key = line[:idx] + valueStart = idx + 1 + } + } + } + return "", "", len(line), errors.Errorf("unterminated bracket in line: %s", line) +} + +func skipQuotes(line string, singleQuote bool) (endIdx int) { + for idx := 0; idx < len(line); idx++ { + switch line[idx] { + case '"': + if !singleQuote { + return idx + } + case '\'': + if singleQuote { + return idx + } + case '\\': + idx++ + } + } + return -1 +} + +// [DATABASES="[test]"] +func parseDB(value string) []string { + var err error + value, err = strconv.Unquote(value) + if err != nil { + return nil + } + if len(value) == 0 { + return nil + } + if value[0] != '[' || value[len(value)-1] != ']' { + // impossible + return nil + } + value = value[1 : len(value)-1] + if len(value) == 0 { + return nil + } + return strings.Split(value, ",") +} + +func parseGeneralEvent(kvs map[string]string) (*Command, error) { + switch kvs[auditPluginKeyCommand] { + case "Query", "Init DB": + sql, err := strconv.Unquote(kvs[auditPluginKeySQL]) + if err != nil { + return nil, errors.Wrapf(err, "unquote sql failed: %s", kvs[auditPluginKeySQL]) + } + return &Command{ + Type: pnet.ComQuery, + StmtType: kvs[auditPluginKeyStmtType], + Payload: append([]byte{pnet.ComQuery.Byte()}, hack.Slice(sql)...), + }, nil + // Ignore StmtExecute since the params are not outputted. + // Ignore Quit since disconnection is handled in parseConnectEvent. + } + // ignore the rest + return nil, nil +} + +func parseConnectEvent(kvs map[string]string) (*Command, error) { + subclass := kvs[auditPluginKeySubClass] + switch subclass { + case auditPluginSubClassConnected: + db := kvs[auditPluginKeyDatabase] + dbs := parseDB(db) + if len(dbs) == 1 { + return &Command{ + Type: pnet.ComInitDB, + Payload: append([]byte{pnet.ComInitDB.Byte()}, hack.Slice(dbs[0])...), + }, nil + } + return nil, nil + case auditPluginSubClassDisconnect: + return &Command{ + Type: pnet.ComQuit, + Payload: []byte{pnet.ComQuit.Byte()}, + }, nil + } + return nil, nil +} diff --git a/pkg/sqlreplay/cmd/audit_log_plugin_test.go b/pkg/sqlreplay/cmd/audit_log_plugin_test.go new file mode 100644 index 000000000..14e7c2afe --- /dev/null +++ b/pkg/sqlreplay/cmd/audit_log_plugin_test.go @@ -0,0 +1,406 @@ +// Copyright 2025 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "testing" + "time" + + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" + "github.com/stretchr/testify/require" +) + +func TestSkipQuotes(t *testing.T) { + tests := []struct { + line string + singleQuote bool + endIdx int + }{ + { + line: "", + singleQuote: false, + endIdx: -1, + }, + { + line: "\"", + singleQuote: true, + endIdx: -1, + }, + { + line: "'", + singleQuote: false, + endIdx: -1, + }, + { + line: "\"", + singleQuote: false, + endIdx: 0, + }, + { + line: "'", + singleQuote: true, + endIdx: 0, + }, + { + line: "abc'abc", + singleQuote: true, + endIdx: 3, + }, + { + line: "\\'", + singleQuote: true, + endIdx: -1, + }, + { + line: "\\'", + singleQuote: true, + endIdx: -1, + }, + { + line: "\\\\'", + singleQuote: true, + endIdx: 2, + }, + } + + for i, test := range tests { + endIdx := skipQuotes(test.line, test.singleQuote) + require.Equal(t, test.endIdx, endIdx, "case %d", i) + } +} + +func TestParseInBracket(t *testing.T) { + tests := []struct { + line string + key string + val string + endIdx int + hasErr bool + }{ + { + line: "", + key: "", + val: "", + endIdx: 0, + hasErr: true, + }, + { + line: "]", + key: "", + val: "", + endIdx: 0, + hasErr: false, + }, + { + line: "a=b]", + key: "a", + val: "b", + endIdx: 3, + hasErr: false, + }, + { + line: "a=b", + key: "", + val: "", + endIdx: 2, + hasErr: true, + }, + { + line: "abc]", + key: "", + val: "abc", + endIdx: 3, + hasErr: false, + }, + { + line: "abc=]", + key: "abc", + val: "", + endIdx: 4, + hasErr: false, + }, + { + line: "=abc]", + key: "", + val: "abc", + endIdx: 4, + hasErr: true, + }, + { + line: "a\"]", + key: "", + val: "a\"", + endIdx: 2, + hasErr: true, + }, + { + line: "a\"]\"", + key: "", + val: "", + endIdx: 3, + hasErr: true, + }, + { + line: "a\"]\"]", + key: "", + val: "a\"]\"", + endIdx: 4, + hasErr: false, + }, + { + line: "a\"]\"a\"]", + key: "", + val: "", + endIdx: 6, + hasErr: true, + }, + { + line: "a\"]\"]abc", + key: "", + val: "a\"]\"", + endIdx: 4, + hasErr: false, + }, + } + + for i, test := range tests { + key, val, endIdx, err := parseInBracket(test.line) + if test.hasErr { + require.Error(t, err, "case %d", i) + continue + } else { + require.NoError(t, err, "case %d", i) + } + require.Equal(t, test.key, key, "case %d", i) + require.Equal(t, test.val, val, "case %d", i) + require.Equal(t, test.endIdx, endIdx, "case %d", i) + } +} + +func TestParseLog(t *testing.T) { + tests := []struct { + line string + kvs map[string]string + hasErr bool + }{ + { + line: "", + hasErr: false, + }, + { + line: "[abc]", + hasErr: false, + }, + { + line: "[a=b", + hasErr: true, + }, + { + line: "[abc", + hasErr: true, + }, + { + line: "[abc=def", + hasErr: true, + }, + { + line: "[=def", + hasErr: true, + }, + { + line: "[=def]", + hasErr: true, + }, + { + line: "[abc=def]", + kvs: map[string]string{"abc": "def"}, + hasErr: false, + }, + { + line: "[abc=def=ghi", + hasErr: true, + }, + { + line: "[abc=def=ghi]", + kvs: map[string]string{"abc": "def=ghi"}, + hasErr: false, + }, + { + line: "[a=\"b\"]", + kvs: map[string]string{"a": "\"b\""}, + hasErr: false, + }, + { + line: "[a=\"b]", + hasErr: true, + }, + { + line: "[abc][a=b]", + kvs: map[string]string{"a": "b"}, + hasErr: false, + }, + { + line: "[abc][a=b", + hasErr: true, + }, + { + line: "a[abc]a", + hasErr: false, + }, + { + line: "a[a=b]a", + kvs: map[string]string{"a": "b"}, + hasErr: false, + }, + { + line: "a[a=b]a[c=d]", + kvs: map[string]string{"a": "b", "c": "d"}, + hasErr: false, + }, + { + line: "a[a=b]a[c=d", + hasErr: true, + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720 +08:00] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + kvs: map[string]string{"ID": "17571494330", "TIMESTAMP": "2025/09/06 17:03:53.720 +08:00", "EVENT_CLASS": "GENERAL", "EVENT_SUBCLASS": "", "STATUS_CODE": "0", "COST_TIME": "1336.083", "HOST": "127.0.0.1", "CLIENT_IP": "127.0.0.1", "USER": "root", "DATABASES": "\"[]\"", "TABLES": "\"[]\"", "SQL_TEXT": "\"select \\\"[=]\\\"\"", "ROWS": "0", "CONNECTION_ID": "3695181836", "CLIENT_PORT": "63912", "PID": "61215", "COMMAND": "Query", "SQL_STATEMENTS": "Select"}, + hasErr: false, + }, + { + line: `[2025/09/06 17:03:53.717 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.717 +08:00] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=824806376.375] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"\n\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + kvs: map[string]string{"ID": "17571494330", "TIMESTAMP": "2025/09/06 17:03:53.717 +08:00", "EVENT_CLASS": "GENERAL", "EVENT_SUBCLASS": "", "STATUS_CODE": "0", "COST_TIME": "824806376.375", "HOST": "127.0.0.1", "CLIENT_IP": "127.0.0.1", "USER": "root", "DATABASES": "\"[]\"", "TABLES": "\"[]\"", "SQL_TEXT": "\"select \\\"\\n\\\"\"", "ROWS": "0", "CONNECTION_ID": "3695181836", "CLIENT_PORT": "63912", "PID": "61215", "COMMAND": "Query", "SQL_STATEMENTS": "Select"}, + hasErr: false, + }, + { + line: `[2025/09/06 16:50:08.917 +08:00] [INFO] [logger.go:77] [ID=17571486080] [TIMESTAMP=2025/09/06 16:50:08.917 +08:00] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=2442.333] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"\n\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + kvs: map[string]string{"ID": "17571486080", "TIMESTAMP": "2025/09/06 16:50:08.917 +08:00", "EVENT_CLASS": "GENERAL", "EVENT_SUBCLASS": "", "STATUS_CODE": "0", "COST_TIME": "2442.333", "HOST": "127.0.0.1", "CLIENT_IP": "127.0.0.1", "USER": "root", "DATABASES": "\"[]\"", "TABLES": "\"[]\"", "SQL_TEXT": "\"select \\\"\\n\\\"\"", "ROWS": "0", "CONNECTION_ID": "3695181836", "CLIENT_PORT": "63912", "PID": "61215", "COMMAND": "Query", "SQL_STATEMENTS": "Select"}, + hasErr: false, + }, + } + + for i, test := range tests { + kvs, err := parseLog(test.line) + if test.hasErr { + require.Error(t, err, "case %d", i) + continue + } else { + require.NoError(t, err, "case %d", i) + } + if len(test.kvs) == 0 && len(kvs) == 0 { + continue + } + require.EqualValues(t, test.kvs, kvs, "case %d", i) + } +} + +func TestParseDB(t *testing.T) { + tests := []struct { + s string + expect []string + }{ + { + s: `""`, + expect: nil, + }, + { + s: `"[]"`, + expect: nil, + }, + { + s: `"[test]"`, + expect: []string{"test"}, + }, + { + s: `"[hello,world]"`, + expect: []string{"hello", "world"}, + }, + } + for i, test := range tests { + dbs := parseDB(test.s) + if len(dbs) == 0 && len(test.expect) == 0 { + continue + } + require.EqualValues(t, test.expect, parseDB(test.s), "case %d", i) + } +} + +func TestDecodeAuditLogPlugin(t *testing.T) { + tests := []struct { + line string + cmd *Command + errMsg string + }{ + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720 +08:10] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + cmd: &Command{ + Type: pnet.ComQuery, + ConnID: 3695181836, + StartTs: time.Date(2025, 9, 6, 17, 3, 53, 720000000, time.FixedZone("", 8*3600+600)), + Payload: append([]byte{pnet.ComQuery.Byte()}, []byte("select \"[=]\"")...), + StmtType: "Select", + Succeess: true, + }, + }, + { + // connect with an initial database + line: `[2025/09/08 21:15:12.904 +08:00] [INFO] [logger.go:77] [ID=17573373120] [TIMESTAMP=2025/09/08 21:15:12.904 +08:10] [EVENT_CLASS=CONNECTION] [EVENT_SUBCLASS=Connected] [STATUS_CODE=0] [COST_TIME=0] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[test]"] [TABLES="[]"] [SQL_TEXT=] [ROWS=0] [CLIENT_PORT=49278] [CONNECTION_ID=3552575510] [CONNECTION_TYPE=SSL/TLS] [SERVER_ID=1] [SERVER_PORT=4000] [DURATION=0] [SERVER_OS_LOGIN_USER=test] [OS_VERSION=darwin.arm64] [CLIENT_VERSION=] [SERVER_VERSION=v9.0.0] [AUDIT_VERSION=] [SSL_VERSION=TLSv1.3] [PID=89967] [Reason=]`, + cmd: &Command{ + Type: pnet.ComInitDB, + ConnID: 3552575510, + StartTs: time.Date(2025, 9, 8, 21, 15, 12, 904000000, time.FixedZone("", 8*3600+600)), + Payload: append([]byte{pnet.ComInitDB.Byte()}, []byte("test")...), + Succeess: true, + }, + }, + { + // no initial database + line: `[2025/09/08 21:15:12.904 +08:00] [INFO] [logger.go:77] [ID=17573373120] [TIMESTAMP=2025/09/08 21:15:12.904 +08:10] [EVENT_CLASS=CONNECTION] [EVENT_SUBCLASS=Connected] [STATUS_CODE=0] [COST_TIME=0] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT=] [ROWS=0] [CLIENT_PORT=49278] [CONNECTION_ID=3552575510] [CONNECTION_TYPE=SSL/TLS] [SERVER_ID=1] [SERVER_PORT=4000] [DURATION=0] [SERVER_OS_LOGIN_USER=test] [OS_VERSION=darwin.arm64] [CLIENT_VERSION=] [SERVER_VERSION=v9.0.0] [AUDIT_VERSION=] [SSL_VERSION=TLSv1.3] [PID=89967] [Reason=]`, + errMsg: "EOF", + }, + { + line: `[2025/09/08 21:15:35.621 +08:00] [INFO] [logger.go:77] [ID=17573373350] [TIMESTAMP=2025/09/08 21:15:35.621 +08:10] [EVENT_CLASS=CONNECTION] [EVENT_SUBCLASS=Disconnect] [STATUS_CODE=0] [COST_TIME=0] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[test]"] [TABLES="[]"] [SQL_TEXT=] [ROWS=0] [CLIENT_PORT=49278] [CONNECTION_ID=3552575510] [CONNECTION_TYPE=SSL/TLS] [SERVER_ID=1] [SERVER_PORT=4000] [DURATION=22716.871792] [SERVER_OS_LOGIN_USER=test] [OS_VERSION=darwin.arm64] [CLIENT_VERSION=] [SERVER_VERSION=v9.0.0] [AUDIT_VERSION=] [SSL_VERSION=TLSv1.3] [PID=89967] [Reason=]`, + cmd: &Command{ + Type: pnet.ComQuit, + ConnID: 3552575510, + StartTs: time.Date(2025, 9, 8, 21, 15, 35, 621000000, time.FixedZone("", 8*3600+600)), + Payload: []byte{pnet.ComQuit.Byte()}, + Succeess: true, + }, + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + errMsg: "no timestamp", + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + errMsg: "parsing timestamp failed", + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720 +08:10] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + errMsg: "no connection id", + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720 +08:10] [EVENT_CLASS=GENERAL] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=abc] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + errMsg: "parsing connection id failed", + }, + { + line: `[2025/09/06 17:03:53.720 +08:00] [INFO] [logger.go:77] [ID=17571494330] [TIMESTAMP=2025/09/06 17:03:53.720 +08:10] [EVENT_CLASS=HELLO] [EVENT_SUBCLASS=] [STATUS_CODE=0] [COST_TIME=1336.083] [HOST=127.0.0.1] [CLIENT_IP=127.0.0.1] [USER=root] [DATABASES="[]"] [TABLES="[]"] [SQL_TEXT="select \"[=]\""] [ROWS=0] [CONNECTION_ID=3695181836] [CLIENT_PORT=63912] [PID=61215] [COMMAND=Query] [SQL_STATEMENTS=Select]`, + errMsg: "unknown event class", + }, + } + + decoder := NewAuditLogPluginDecoder() + for i, test := range tests { + mr := mockReader{data: append([]byte(test.line), '\n')} + cmd, err := decoder.Decode(&mr) + if len(test.errMsg) > 0 { + require.Error(t, err, "case %d", i) + require.ErrorContains(t, err, test.errMsg, "case %d", i) + continue + } else { + require.NoError(t, err, "case %d", i) + } + require.Equal(t, test.cmd, cmd, "case %d", i) + } +} diff --git a/pkg/sqlreplay/cmd/cmd.go b/pkg/sqlreplay/cmd/cmd.go index eadf70c05..b6d632081 100644 --- a/pkg/sqlreplay/cmd/cmd.go +++ b/pkg/sqlreplay/cmd/cmd.go @@ -6,9 +6,6 @@ package cmd import ( "bytes" "fmt" - "io" - "strconv" - "strings" "time" "github.com/pingcap/tidb/pkg/parser" @@ -18,13 +15,8 @@ import ( ) const ( - commonKeyPrefix = "# " - commonKeySuffix = ": " - keyStartTs = "# Time: " - keyConnID = "# Conn_ID: " - keyType = "# Cmd_type: " - keySuccess = "# Success: " - keyPayloadLen = "# Payload_len: " + FormatNative = "native" + FormatAuditLogPlugin = "audit_log_plugin" ) type LineReader interface { @@ -34,15 +26,40 @@ type LineReader interface { Close() } +func NewCmdEncoder(_ string) CmdEncoder { + // Only support writing native format + return NewNativeEncoder() +} + +type CmdEncoder interface { + Encode(c *Command, writer *bytes.Buffer) error +} + +func NewCmdDecoder(format string) CmdDecoder { + switch format { + case FormatAuditLogPlugin: + return NewAuditLogPluginDecoder() + default: + return NewNativeDecoder() + } +} + +type CmdDecoder interface { + Decode(reader LineReader) (c *Command, err error) +} + type Command struct { PreparedStmt string Params []any digest string // Payload starts with command type so that replay can reuse this byte array. - Payload []byte - StartTs time.Time - ConnID uint64 - Type pnet.Command + Payload []byte + StartTs time.Time + ConnID uint64 + Type pnet.Command + // Logged only in audit log. + StmtType string + // Logged only in native log. Succeess bool } @@ -84,116 +101,6 @@ func (c *Command) Validate(filename string, lineIdx int) error { return nil } -func (c *Command) Encode(writer *bytes.Buffer) error { - var err error - if err = writeString(keyStartTs, c.StartTs.Format(time.RFC3339Nano), writer); err != nil { - return err - } - if err = writeString(keyConnID, strconv.FormatUint(c.ConnID, 10), writer); err != nil { - return err - } - if c.Type != pnet.ComQuery { - if err = writeString(keyType, c.Type.String(), writer); err != nil { - return err - } - } - if !c.Succeess { - if err = writeString(keySuccess, "false", writer); err != nil { - return err - } - } - // `Payload_len` doesn't include the command type. - if err = writeString(keyPayloadLen, strconv.Itoa(len(c.Payload[1:])), writer); err != nil { - return err - } - // Unlike TiDB slow log, the payload is binary because StmtExecute can't be transformed to a SQL. - if len(c.Payload) > 1 { - if _, err = writer.Write(c.Payload[1:]); err != nil { - return errors.WithStack(err) - } - } - if err = writer.WriteByte('\n'); err != nil { - return errors.WithStack(err) - } - return nil -} - -func (c *Command) Decode(reader LineReader) error { - c.Succeess = true - c.Type = pnet.ComQuery - for { - line, filename, lineIdx, err := reader.ReadLine() - if err != nil { - return err - } - if !strings.HasPrefix(hack.String(line), commonKeyPrefix) { - return errors.Errorf("%s, line %d: line doesn't start with '%s': %s", filename, lineIdx, commonKeyPrefix, line) - } - idx := strings.Index(hack.String(line), commonKeySuffix) - if idx < 0 { - return errors.Errorf("%s, line %d: '%s' is not found in line: %s", filename, lineIdx, commonKeySuffix, line) - } - idx += len(commonKeySuffix) - key := hack.String(line[:idx]) - value := hack.String(line[idx:]) - if len(value) == 0 { - return errors.Errorf("%s, line %d: value is empty in line: %s", filename, lineIdx, line) - } - switch key { - case keyStartTs: - if !c.StartTs.IsZero() { - return errors.Errorf("%s, line %d: redundant Time: %s, Time was %v", filename, lineIdx, line, c.StartTs) - } - c.StartTs, err = time.Parse(time.RFC3339Nano, value) - if err != nil { - return errors.Errorf("%s, line %d: parsing Time failed: %s", filename, lineIdx, line) - } - case keyConnID: - if c.ConnID > 0 { - return errors.Errorf("%s, line %d: redundant Conn_ID: %s, Conn_ID was %d", filename, lineIdx, line, c.ConnID) - } - c.ConnID, err = strconv.ParseUint(value, 10, 64) - if err != nil { - return errors.Errorf("%s, line %d: parsing Conn_ID failed: %s", filename, lineIdx, line) - } - case keyType: - if c.Type != pnet.ComQuery { - return errors.Errorf("%s, line %d: redundant Cmd_type: %s, Cmd_type was %v", filename, lineIdx, line, c.Type) - } - c.Type = pnet.CommandFromString(value) - case keySuccess: - c.Succeess = value == "true" - case keyPayloadLen: - var payloadLen int - if payloadLen, err = strconv.Atoi(value); err != nil { - return errors.Errorf("parsing Payload_len failed: %s", line) - } - c.Payload = make([]byte, payloadLen+1) - c.Payload[0] = c.Type.Byte() - if payloadLen > 0 { - if filename, lineIdx, err = reader.Read(c.Payload[1:]); err != nil { - return errors.Errorf("%s, line %d: reading Payload failed: %s", filename, lineIdx, err.Error()) - } - } - // skip '\n' - var data [1]byte - if filename, lineIdx, err = reader.Read(data[:]); err != nil { - if !errors.Is(err, io.EOF) { - return errors.Errorf("%s, line %d: skipping new line failed: %s", filename, lineIdx, err.Error()) - } - return err - } - if data[0] != '\n' { - return errors.Errorf("%s, line %d: expected new line, but got: %s", filename, lineIdx, line) - } - if err = c.Validate(filename, lineIdx); err != nil { - return err - } - return nil - } - } -} - func (c *Command) Digest() string { if len(c.digest) == 0 { switch c.Type { @@ -220,17 +127,3 @@ func (c *Command) QueryText() string { } return "" } - -func writeString(key, value string, writer *bytes.Buffer) error { - var err error - if _, err = writer.WriteString(key); err != nil { - return errors.WithStack(err) - } - if _, err = writer.WriteString(value); err != nil { - return errors.WithStack(err) - } - if err = writer.WriteByte('\n'); err != nil { - return errors.WithStack(err) - } - return nil -} diff --git a/pkg/sqlreplay/cmd/cmd_test.go b/pkg/sqlreplay/cmd/cmd_test.go index ae461700f..98eaaff20 100644 --- a/pkg/sqlreplay/cmd/cmd_test.go +++ b/pkg/sqlreplay/cmd/cmd_test.go @@ -4,7 +4,6 @@ package cmd import ( - "bytes" "testing" "time" @@ -12,118 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestEncode(t *testing.T) { - tests := []struct { - payload []byte - cmd pnet.Command - }{ - { - cmd: pnet.ComQuery, - payload: []byte("select 1"), - }, - { - cmd: pnet.ComStmtSendLongData, - payload: []byte{0x01, 0x02, 0x03}, - }, - { - cmd: pnet.ComStmtSendLongData, - payload: []byte("1\n2\n3"), - }, - { - cmd: pnet.ComStmtExecute, - payload: []byte("1\n2\n"), - }, - { - cmd: pnet.ComQuit, - }, - } - - var buf bytes.Buffer - cmds := make([]*Command, 0, len(tests)) - for i, test := range tests { - packet := append([]byte{byte(test.cmd)}, test.payload...) - now := time.Now() - cmd := NewCommand(packet, now, 100) - require.NoError(t, cmd.Encode(&buf), "case %d", i) - cmds = append(cmds, cmd) - } - - mr := mockReader{data: buf.Bytes()} - for i := range tests { - cmd := cmds[i] - newCmd := &Command{} - require.NoError(t, newCmd.Decode(&mr), "case %d, buf: %s", i, buf.String()) - require.True(t, cmd.Equal(newCmd), "case %d, buf: %s", i, buf.String()) - } -} - -func TestDecodeError(t *testing.T) { - tests := []string{ - `select 1`, - `select 1 -`, - `# Time:2024-08-28T18:51:20.477067+08:00 -`, - `# Time: 100 -# Conn_ID: 100 -# Payload_len: 8 -select 1`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: abc -# Payload_len: 8 -select 1`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -# Type: abc -# Payload_len: 8 -select 1`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -# Payload_len: 8 -select 1`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Payload_len: 8 -select 1 -`, - `# Conn_ID: 100 -# Payload_len: 8 -select 1 -`, - `# Conn_ID: 100 -# Payload_len: 100 -select 1 -`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -# Payload_len: 100 -select 1 -`, - `# Time: -# Conn_ID: 100 -# Payload_len: 8 -select 1 -`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -# Payload_len: 0 -`, - `# Time: 2024-08-28T18:51:20.477067+08:00 -# Conn_ID: 100 -# Payload_len: abc -`, - } - - for _, test := range tests { - mr := mockReader{data: []byte(test)} - cmd := &Command{} - require.Error(t, cmd.Decode(&mr), test) - } -} - func TestDigest(t *testing.T) { cmd1 := NewCommand(append([]byte{pnet.ComQuery.Byte()}, []byte("select 1")...), time.Time{}, 100) require.NotEmpty(t, cmd1.Digest()) diff --git a/pkg/sqlreplay/cmd/native.go b/pkg/sqlreplay/cmd/native.go new file mode 100644 index 000000000..79d0c72f1 --- /dev/null +++ b/pkg/sqlreplay/cmd/native.go @@ -0,0 +1,171 @@ +// Copyright 2025 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "bytes" + "io" + "strconv" + "strings" + "time" + + "github.com/pingcap/tiproxy/lib/util/errors" + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" + "github.com/siddontang/go/hack" +) + +const ( + nativeCommonKeyPrefix = "# " + nativeCommonKeySuffix = ": " + nativeKeyStartTs = "# Time: " + nativeKeyConnID = "# Conn_ID: " + nativeKeyType = "# Cmd_type: " + nativeKeySuccess = "# Success: " + nativeKeyPayloadLen = "# Payload_len: " +) + +func NewNativeEncoder() *NativeEncoder { + // TODO: handle load infile specially + return &NativeEncoder{} +} + +var _ CmdEncoder = (*NativeEncoder)(nil) + +type NativeEncoder struct { +} + +func (rw *NativeEncoder) Encode(c *Command, writer *bytes.Buffer) error { + var err error + if err = writeString(nativeKeyStartTs, c.StartTs.Format(time.RFC3339Nano), writer); err != nil { + return err + } + if err = writeString(nativeKeyConnID, strconv.FormatUint(c.ConnID, 10), writer); err != nil { + return err + } + if c.Type != pnet.ComQuery { + if err = writeString(nativeKeyType, c.Type.String(), writer); err != nil { + return err + } + } + if !c.Succeess { + if err = writeString(nativeKeySuccess, "false", writer); err != nil { + return err + } + } + // `Payload_len` doesn't include the command type. + if err = writeString(nativeKeyPayloadLen, strconv.Itoa(len(c.Payload[1:])), writer); err != nil { + return err + } + // Unlike TiDB slow log, the payload is binary because StmtExecute can't be transformed to a SQL. + if len(c.Payload) > 1 { + if _, err = writer.Write(c.Payload[1:]); err != nil { + return errors.WithStack(err) + } + } + if err = writer.WriteByte('\n'); err != nil { + return errors.WithStack(err) + } + return nil +} + +func NewNativeDecoder() *NativeDecoder { + // TODO: handle load infile specially + return &NativeDecoder{} +} + +var _ CmdDecoder = (*NativeDecoder)(nil) + +type NativeDecoder struct { +} + +func (rw *NativeDecoder) Decode(reader LineReader) (c *Command, err error) { + c = &Command{} + c.Succeess = true + c.Type = pnet.ComQuery + for { + line, filename, lineIdx, err := reader.ReadLine() + if err != nil { + return nil, err + } + if !strings.HasPrefix(hack.String(line), nativeCommonKeyPrefix) { + return nil, errors.Errorf("%s, line %d: line doesn't start with '%s': %s", filename, lineIdx, nativeCommonKeyPrefix, line) + } + idx := strings.Index(hack.String(line), nativeCommonKeySuffix) + if idx < 0 { + return nil, errors.Errorf("%s, line %d: '%s' is not found in line: %s", filename, lineIdx, nativeCommonKeySuffix, line) + } + idx += len(nativeCommonKeySuffix) + key := hack.String(line[:idx]) + value := hack.String(line[idx:]) + if len(value) == 0 { + return nil, errors.Errorf("%s, line %d: value is empty in line: %s", filename, lineIdx, line) + } + switch key { + case nativeKeyStartTs: + if !c.StartTs.IsZero() { + return nil, errors.Errorf("%s, line %d: redundant Time: %s, Time was %v", filename, lineIdx, line, c.StartTs) + } + c.StartTs, err = time.Parse(time.RFC3339Nano, value) + if err != nil { + return nil, errors.Errorf("%s, line %d: parsing Time failed: %s", filename, lineIdx, line) + } + case nativeKeyConnID: + if c.ConnID > 0 { + return nil, errors.Errorf("%s, line %d: redundant Conn_ID: %s, Conn_ID was %d", filename, lineIdx, line, c.ConnID) + } + c.ConnID, err = strconv.ParseUint(value, 10, 64) + if err != nil { + return nil, errors.Errorf("%s, line %d: parsing Conn_ID failed: %s", filename, lineIdx, line) + } + case nativeKeyType: + if c.Type != pnet.ComQuery { + return nil, errors.Errorf("%s, line %d: redundant Cmd_type: %s, Cmd_type was %v", filename, lineIdx, line, c.Type) + } + c.Type = pnet.CommandFromString(value) + case nativeKeySuccess: + c.Succeess = value == "true" + case nativeKeyPayloadLen: + var payloadLen int + if payloadLen, err = strconv.Atoi(value); err != nil { + return nil, errors.Errorf("parsing Payload_len failed: %s", line) + } + c.Payload = make([]byte, payloadLen+1) + c.Payload[0] = c.Type.Byte() + if payloadLen > 0 { + if filename, lineIdx, err = reader.Read(c.Payload[1:]); err != nil { + return nil, errors.Errorf("%s, line %d: reading Payload failed: %s", filename, lineIdx, err.Error()) + } + } + // skip '\n' + var data [1]byte + if filename, lineIdx, err = reader.Read(data[:]); err != nil { + if !errors.Is(err, io.EOF) { + return nil, errors.Errorf("%s, line %d: skipping new line failed: %s", filename, lineIdx, err.Error()) + } + return nil, err + } + if data[0] != '\n' { + return nil, errors.Errorf("%s, line %d: expected new line, but got: %s", filename, lineIdx, line) + } + if err = c.Validate(filename, lineIdx); err != nil { + return nil, err + } + return c, nil + } + } +} + +func writeString(key, value string, writer *bytes.Buffer) error { + var err error + if _, err = writer.WriteString(key); err != nil { + return errors.WithStack(err) + } + if _, err = writer.WriteString(value); err != nil { + return errors.WithStack(err) + } + if err = writer.WriteByte('\n'); err != nil { + return errors.WithStack(err) + } + return nil +} diff --git a/pkg/sqlreplay/cmd/native_test.go b/pkg/sqlreplay/cmd/native_test.go new file mode 100644 index 000000000..2053e8037 --- /dev/null +++ b/pkg/sqlreplay/cmd/native_test.go @@ -0,0 +1,128 @@ +// Copyright 2024 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package cmd + +import ( + "bytes" + "testing" + "time" + + pnet "github.com/pingcap/tiproxy/pkg/proxy/net" + "github.com/stretchr/testify/require" +) + +func TestEncode(t *testing.T) { + tests := []struct { + payload []byte + cmd pnet.Command + }{ + { + cmd: pnet.ComQuery, + payload: []byte("select 1"), + }, + { + cmd: pnet.ComStmtSendLongData, + payload: []byte{0x01, 0x02, 0x03}, + }, + { + cmd: pnet.ComStmtSendLongData, + payload: []byte("1\n2\n3"), + }, + { + cmd: pnet.ComStmtExecute, + payload: []byte("1\n2\n"), + }, + { + cmd: pnet.ComQuit, + }, + } + + var buf bytes.Buffer + cmds := make([]*Command, 0, len(tests)) + encoder := NewCmdEncoder(FormatNative) + for i, test := range tests { + packet := append([]byte{byte(test.cmd)}, test.payload...) + now := time.Now() + cmd := NewCommand(packet, now, 100) + require.NoError(t, encoder.Encode(cmd, &buf), "case %d", i) + cmds = append(cmds, cmd) + } + + mr := mockReader{data: buf.Bytes()} + decoder := NewCmdDecoder(FormatNative) + for i := range tests { + cmd := cmds[i] + newCmd, err := decoder.Decode(&mr) + require.NoError(t, err, "case %d, buf: %s", i, buf.String()) + require.True(t, cmd.Equal(newCmd), "case %d, buf: %s", i, buf.String()) + } +} + +func TestDecodeError(t *testing.T) { + tests := []string{ + `select 1`, + `select 1 +`, + `# Time:2024-08-28T18:51:20.477067+08:00 +`, + `# Time: 100 +# Conn_ID: 100 +# Payload_len: 8 +select 1`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: abc +# Payload_len: 8 +select 1`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +# Type: abc +# Payload_len: 8 +select 1`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +# Payload_len: 8 +select 1`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Payload_len: 8 +select 1 +`, + `# Conn_ID: 100 +# Payload_len: 8 +select 1 +`, + `# Conn_ID: 100 +# Payload_len: 100 +select 1 +`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +# Payload_len: 100 +select 1 +`, + `# Time: +# Conn_ID: 100 +# Payload_len: 8 +select 1 +`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +# Payload_len: 0 +`, + `# Time: 2024-08-28T18:51:20.477067+08:00 +# Conn_ID: 100 +# Payload_len: abc +`, + } + + for _, test := range tests { + mr := mockReader{data: []byte(test)} + decoder := NewCmdDecoder(FormatNative) + _, err := decoder.Decode(&mr) + require.Error(t, err, test) + } +} diff --git a/pkg/sqlreplay/conn/conn.go b/pkg/sqlreplay/conn/conn.go index 8ae9e8126..5faa92665 100644 --- a/pkg/sqlreplay/conn/conn.go +++ b/pkg/sqlreplay/conn/conn.go @@ -39,6 +39,7 @@ func (s *ReplayStats) Reset() { type Conn interface { Run(ctx context.Context) ExecuteCmd(command *cmd.Command) + LastCmd() *cmd.Command Stop() } @@ -56,6 +57,7 @@ type conn struct { backendConn BackendConn connID uint64 // capture ID, not replay ID replayStats *ReplayStats + lastCmd *cmd.Command lastPendingCmds int // last pending cmds reported to the stats readonly bool } @@ -188,6 +190,7 @@ func (c *conn) updateCmdForExecuteStmt(command *cmd.Command) bool { func (c *conn) ExecuteCmd(command *cmd.Command) { c.cmdLock.Lock() c.cmdList.PushFront(command) + c.lastCmd = command pendingCmds := c.cmdList.Len() c.updatePendingCmds(pendingCmds) c.cmdLock.Unlock() @@ -197,6 +200,13 @@ func (c *conn) ExecuteCmd(command *cmd.Command) { } } +// Used for deduplicate commands in audit logs. +func (c *conn) LastCmd() *cmd.Command { + c.cmdLock.Lock() + defer c.cmdLock.Unlock() + return c.lastCmd +} + func (c *conn) Stop() { close(c.cmdCh) } diff --git a/pkg/sqlreplay/manager/job.go b/pkg/sqlreplay/manager/job.go index afa5bb6c9..4fcb99b43 100644 --- a/pkg/sqlreplay/manager/job.go +++ b/pkg/sqlreplay/manager/job.go @@ -129,6 +129,7 @@ type replayJob4Marshal struct { job4Marshal Input string `json:"input,omitempty"` Username string `json:"username,omitempty"` + Format string `json:"format,omitempty"` Speed float64 `json:"speed,omitempty"` ReadOnly bool `json:"readonly,omitempty"` } @@ -146,6 +147,7 @@ func (job *replayJob) MarshalJSON() ([]byte, error) { Username: job.cfg.Username, Speed: job.cfg.Speed, ReadOnly: job.cfg.ReadOnly, + Format: job.cfg.Format, } return json.Marshal(r) } diff --git a/pkg/sqlreplay/replay/mock_test.go b/pkg/sqlreplay/replay/mock_test.go index 5502f45df..a2e3de573 100644 --- a/pkg/sqlreplay/replay/mock_test.go +++ b/pkg/sqlreplay/replay/mock_test.go @@ -36,6 +36,10 @@ func (c *mockConn) Run(ctx context.Context) { c.closeCh <- c.connID } +func (c *mockConn) LastCmd() *cmd.Command { + return nil +} + func (c *mockConn) Stop() { c.closed <- struct{}{} } @@ -59,6 +63,10 @@ func (c *mockPendingConn) Run(ctx context.Context) { c.closeCh <- c.connID } +func (c *mockPendingConn) LastCmd() *cmd.Command { + return nil +} + func (c *mockPendingConn) Stop() { c.closed <- struct{}{} } @@ -81,6 +89,7 @@ func (m *mockChLoader) writeCommand(cmd *cmd.Command) { } func (m *mockChLoader) Read(data []byte) (string, int, error) { + encoder := cmd.NewCmdEncoder(cmd.FormatNative) for { _, err := m.buf.Read(data) if errors.Is(err, io.EOF) { @@ -88,7 +97,7 @@ func (m *mockChLoader) Read(data []byte) (string, int, error) { if !ok { return "", 0, io.EOF } - _ = command.Encode(&m.buf) + _ = encoder.Encode(command, &m.buf) } else { return "", 0, err } @@ -96,6 +105,7 @@ func (m *mockChLoader) Read(data []byte) (string, int, error) { } func (m *mockChLoader) ReadLine() ([]byte, string, int, error) { + encoder := cmd.NewCmdEncoder(cmd.FormatNative) for { line, err := m.buf.ReadBytes('\n') if errors.Is(err, io.EOF) { @@ -103,7 +113,7 @@ func (m *mockChLoader) ReadLine() ([]byte, string, int, error) { if !ok { return nil, "", 0, io.EOF } - _ = command.Encode(&m.buf) + _ = encoder.Encode(command, &m.buf) } else { return line[:len(line)-1], "", 0, err } @@ -128,8 +138,9 @@ func newMockNormalLoader() *mockNormalLoader { return &mockNormalLoader{} } -func (m *mockNormalLoader) writeCommand(cmd *cmd.Command) { - _ = cmd.Encode(&m.buf) +func (m *mockNormalLoader) writeCommand(command *cmd.Command) { + encoder := cmd.NewCmdEncoder(cmd.FormatNative) + _ = encoder.Encode(command, &m.buf) } func (m *mockNormalLoader) Read(data []byte) (string, int, error) { diff --git a/pkg/sqlreplay/replay/replay.go b/pkg/sqlreplay/replay/replay.go index 9d808edfc..251426288 100644 --- a/pkg/sqlreplay/replay/replay.go +++ b/pkg/sqlreplay/replay/replay.go @@ -50,6 +50,7 @@ type Replay interface { } type ReplayConfig struct { + Format string Input string Username string Password string @@ -85,6 +86,11 @@ func (cfg *ReplayConfig) Validate() (storage.ExternalStorage, error) { } else if cfg.Speed < minSpeed || cfg.Speed > maxSpeed { return storage, errors.Errorf("speed should be between %f and %f", minSpeed, maxSpeed) } + switch cfg.Format { + case cmd.FormatAuditLogPlugin, cmd.FormatNative, "": + default: + return storage, errors.Errorf("invalid traffic file format %s", cfg.Format) + } // Maybe there's a time bias between TiDB and TiProxy, so add one minute. now := time.Now() if cfg.StartTime.IsZero() { @@ -214,6 +220,7 @@ func (r *replay) readCommands(ctx context.Context) { var err error maxPendingCmds := int64(0) totalWaitTime := time.Duration(0) + decoder := cmd.NewCmdDecoder(r.cfg.Format) for ctx.Err() == nil { for hasCloseEvent := true; hasCloseEvent; { select { @@ -224,8 +231,8 @@ func (r *replay) readCommands(ctx context.Context) { } } - command := &cmd.Command{} - if err = command.Decode(reader); err != nil { + var command *cmd.Command + if command, err = decoder.Decode(reader); err != nil { if errors.Is(err, io.EOF) { r.lg.Info("replay reads EOF", zap.String("reader", reader.String())) err = nil @@ -304,7 +311,10 @@ func (r *replay) executeCmd(ctx context.Context, command *cmd.Command, conns map }, nil, r.lg) } if conn != nil && !reflect.ValueOf(conn).IsNil() { - conn.ExecuteCmd(command) + // Deduplicate commands in audit logs. + if r.cfg.Format != cmd.FormatAuditLogPlugin || !command.Equal(conn.LastCmd()) { + conn.ExecuteCmd(command) + } } r.decodedCmds.Add(1) } @@ -330,8 +340,10 @@ func (r *replay) Progress() (float64, time.Time, bool, error) { func (r *replay) readMeta() *store.Meta { m := new(store.Meta) - if err := m.Read(r.storage); err != nil { - r.lg.Error("read meta failed", zap.Error(err)) + if r.cfg.Format == cmd.FormatNative || r.cfg.Format == "" { + if err := m.Read(r.storage); err != nil { + r.lg.Error("read meta failed", zap.Error(err)) + } } return m } diff --git a/pkg/sqlreplay/replay/replay_test.go b/pkg/sqlreplay/replay/replay_test.go index fb3fd4a73..7a44ea119 100644 --- a/pkg/sqlreplay/replay/replay_test.go +++ b/pkg/sqlreplay/replay/replay_test.go @@ -312,13 +312,13 @@ func TestLoadEncryptionKey(t *testing.T) { StartTime: now, reader: loader, } - for _, test := range tests { + for i, test := range tests { cfg.KeyFile = test.keyFile replay := NewReplay(zap.NewNop(), id.NewIDManager()) cfg.report = newMockReport(replay.exceptionCh) err = replay.Start(cfg, nil, nil, &backend.BCConfig{}) if len(test.err) > 0 { - require.ErrorContains(t, err, test.err) + require.ErrorContains(t, err, test.err, "test %d", i) } else { require.NoError(t, err) replay.Lock()