diff --git a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs index 990c59fa9c2..54b5138eec2 100644 --- a/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs +++ b/src/Microsoft.Data.Sqlite.Core/SqliteDataReader.cs @@ -154,6 +154,9 @@ public override bool NextResult() { stmt = _stmtEnumerator.Current; + var connectionHandle = _command.Connection!.Handle; + var totalChangesBefore = sqlite3_total_changes(connectionHandle); + var timer = SharedStopwatch.StartNew(); while (IsBusy(rc = sqlite3_step(stmt))) @@ -172,7 +175,7 @@ public override bool NextResult() _totalElapsedTime += timer.Elapsed; - SqliteException.ThrowExceptionForRC(rc, _command.Connection!.Handle); + SqliteException.ThrowExceptionForRC(rc, connectionHandle); // It's a SELECT statement if (sqlite3_column_count(stmt) != 0) @@ -185,13 +188,26 @@ public override bool NextResult() while (rc != SQLITE_DONE) { rc = sqlite3_step(stmt); - SqliteException.ThrowExceptionForRC(rc, _command.Connection.Handle); + SqliteException.ThrowExceptionForRC(rc, connectionHandle); } sqlite3_reset(stmt); - var changes = sqlite3_changes(_command.Connection.Handle); - AddChanges(changes); + // sqlite3_changes() returns the row count from the most recent INSERT, UPDATE, or DELETE + // and incorrectly persists across DDL statements. Use sqlite3_total_changes() before and after + // to calculate the actual delta for this statement, ensuring DDL statements don't add stale counts. + var totalChangesAfter = sqlite3_total_changes(connectionHandle); + var changes = totalChangesAfter - totalChangesBefore; + // sqlite3_total_changes, unfortunately, counts also changes from triggers, etc. which is not what we want. + // So we use it only to detect changes and if so, use sqlite3_changes. + if (changes > 0) + { + AddChanges(sqlite3_changes(connectionHandle)); + } + else + { + AddChanges(0); + } } catch { diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs index cbc49267a45..e635608043b 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteConnectionTest.cs @@ -812,7 +812,7 @@ public void CreateFunction_deterministic_param_works() connection.ExecuteNonQuery("CREATE TABLE Data (Value); INSERT INTO Data VALUES (0);"); connection.CreateFunction("test", (double x) => x, true); - Assert.Equal(1, connection.ExecuteNonQuery("CREATE INDEX InvalidIndex ON Data (Value) WHERE test(Value) = 0;")); + Assert.Equal(0, connection.ExecuteNonQuery("CREATE INDEX InvalidIndex ON Data (Value) WHERE test(Value) = 0;")); } [Fact] diff --git a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs index 8cfe3ba73df..f86a34c7dfc 100644 --- a/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs +++ b/test/Microsoft.Data.Sqlite.Tests/SqliteDataReaderTest.cs @@ -1970,6 +1970,45 @@ public void RecordsAffected_works_with_returning_multiple() } } + [Fact] + public void RecordsAffected_not_affected_by_DDL_statements() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + using (var reader = connection.ExecuteReader( + @"CREATE TABLE foo(bar TEXT NOT NULL); + CREATE TABLE xyz(aaa TEXT NOT NULL); + INSERT INTO foo(bar) VALUES('baz'); + INSERT INTO foo(bar) VALUES('baz2'); + DROP TABLE xyz;")) + { + Assert.Equal(2, reader.RecordsAffected); + } + } + } + + [Fact] + public void RecordsAffected_not_affected_by_DDL_statements_with_drop_and_create() + { + using (var connection = new SqliteConnection("Data Source=:memory:")) + { + connection.Open(); + + using (var reader = connection.ExecuteReader( + @"CREATE TABLE foo(bar TEXT NOT NULL); + CREATE TABLE xyz(aaa TEXT NOT NULL); + INSERT INTO foo(bar) VALUES('baz'); + INSERT INTO foo(bar) VALUES('baz2'); + DROP TABLE xyz; + CREATE TABLE xyz(aaa TEXT NOT NULL);")) + { + Assert.Equal(2, reader.RecordsAffected); + } + } + } + [Fact] public void GetSchemaTable_works() {