diff --git a/mssql_python/connection.py b/mssql_python/connection.py index 2a81009b..59eb3306 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -121,7 +121,7 @@ class Connection: ProgrammingError = ProgrammingError NotSupportedError = NotSupportedError - def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None: + def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, timeout: int = 0, **kwargs) -> None: """ Initialize the connection object with the specified connection string and parameters. @@ -180,6 +180,7 @@ def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_bef self._attrs_before.update(connection_result[1]) self._closed = False + self._timeout = timeout # Using WeakSet which automatically removes cursors when they are no longer in use # It is a set that holds weak references to its elements. @@ -236,6 +237,39 @@ def _construct_connection_string(self, connection_str: str = "", **kwargs) -> st return conn_str + @property + def timeout(self) -> int: + """ + Get the current query timeout setting in seconds. + + Returns: + int: The timeout value in seconds. Zero means no timeout (wait indefinitely). + """ + return self._timeout + + @timeout.setter + def timeout(self, value: int) -> None: + """ + Set the query timeout for all operations performed by this connection. + + Args: + value (int): The timeout value in seconds. Zero means no timeout. + + Returns: + None + + Note: + This timeout applies to all cursors created from this connection. + It cannot be changed for individual cursors or SQL statements. + If a query timeout occurs, an OperationalError exception will be raised. + """ + if not isinstance(value, int): + raise TypeError("Timeout must be an integer") + if value < 0: + raise ValueError("Timeout cannot be negative") + self._timeout = value + log('info', f"Query timeout set to {value} seconds") + @property def autocommit(self) -> bool: """ @@ -533,7 +567,7 @@ def cursor(self) -> Cursor: ddbc_error="Cannot create cursor on closed connection", ) - cursor = Cursor(self) + cursor = Cursor(self, timeout=self._timeout) self._cursors.add(cursor) # Track the cursor return cursor diff --git a/mssql_python/constants.py b/mssql_python/constants.py index 20c8f663..3d6b4732 100644 --- a/mssql_python/constants.py +++ b/mssql_python/constants.py @@ -115,6 +115,7 @@ class ConstantsDDBC(Enum): SQL_C_WCHAR = -8 SQL_NULLABLE = 1 SQL_MAX_NUMERIC_LEN = 16 + SQL_ATTR_QUERY_TIMEOUT = 2 SQL_FETCH_NEXT = 1 SQL_FETCH_FIRST = 2 diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index f72f8192..4a1e6a91 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -54,7 +54,7 @@ class Cursor: setoutputsize(size, column=None) -> None. """ - def __init__(self, connection) -> None: + def __init__(self, connection, timeout: int = 0) -> None: """ Initialize the cursor with a database connection. @@ -62,6 +62,7 @@ def __init__(self, connection) -> None: connection: Database connection object. """ self._connection = connection # Store as private attribute + self._timeout = timeout # self.connection.autocommit = False self.hstmt = None self._initialize_cursor() @@ -778,6 +779,20 @@ def execute( # Clear any previous messages self.messages = [] + # Apply timeout if set (non-zero) + if self._timeout > 0: + try: + timeout_value = int(self._timeout) + ret = ddbc_bindings.DDBCSQLSetStmtAttr( + self.hstmt, + ddbc_sql_const.SQL_ATTR_QUERY_TIMEOUT.value, + timeout_value + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + log('debug', f"Set query timeout to {timeout_value} seconds") + except Exception as e: + log('warning', f"Failed to set query timeout: {e}") + param_info = ddbc_bindings.ParamInfo parameters_type = [] @@ -932,6 +947,20 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None: if not seq_of_parameters: self.rowcount = 0 return + + # Apply timeout if set (non-zero) + if self._timeout > 0: + try: + timeout_value = int(self._timeout) + ret = ddbc_bindings.DDBCSQLSetStmtAttr( + self.hstmt, + ddbc_sql_const.SQL_ATTR_QUERY_TIMEOUT.value, + timeout_value + ) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + log('debug', f"Set query timeout to {self._timeout} seconds") + except Exception as e: + log('warning', f"Failed to set query timeout: {e}") param_info = ddbc_bindings.ParamInfo param_count = len(seq_of_parameters[0]) diff --git a/mssql_python/db_connection.py b/mssql_python/db_connection.py index 5e98056e..48f3f966 100644 --- a/mssql_python/db_connection.py +++ b/mssql_python/db_connection.py @@ -5,7 +5,7 @@ """ from mssql_python.connection import Connection -def connect(connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> Connection: +def connect(connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, timeout: int = 0, **kwargs) -> Connection: """ Constructor for creating a connection to the database. @@ -33,5 +33,5 @@ def connect(connection_str: str = "", autocommit: bool = False, attrs_before: di be used to perform database operations such as executing queries, committing transactions, and closing the connection. """ - conn = Connection(connection_str, autocommit=autocommit, attrs_before=attrs_before, **kwargs) + conn = Connection(connection_str, autocommit=autocommit, attrs_before=attrs_before, timeout=timeout, **kwargs) return conn diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 507b3a34..bac9c664 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3151,6 +3151,9 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchScroll", &SQLFetchScroll_wrap, "Scroll to a specific position in the result set and optionally fetch data"); m.def("DDBCSetDecimalSeparator", &DDBCSetDecimalSeparator, "Set the decimal separator character"); + m.def("DDBCSQLSetStmtAttr", [](SqlHandlePtr stmt, SQLINTEGER attr, SQLPOINTER value) { + return SQLSetStmtAttr_ptr(stmt->get(), attr, value, 0); + }, "Set statement attributes"); // Add a version attribute diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 902963a2..9c6fd35e 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -21,13 +21,50 @@ - test_context_manager_connection_closes: Test that context manager closes the connection. """ -from mssql_python.exceptions import InterfaceError, ProgrammingError import mssql_python import pytest import time from mssql_python import connect, Connection, pooling, SQL_CHAR, SQL_WCHAR -from contextlib import closing import threading +# Import all exception classes for testing +from mssql_python.exceptions import ( + Warning, + Error, + InterfaceError, + DatabaseError, + DataError, + OperationalError, + IntegrityError, + InternalError, + ProgrammingError, + NotSupportedError, +) +import struct +from datetime import datetime, timedelta, timezone +from mssql_python.constants import ConstantsDDBC + +@pytest.fixture(autouse=True) +def clean_connection_state(db_connection): + """Ensure connection is in a clean state before each test""" + # Create a cursor and clear any active results + try: + cleanup_cursor = db_connection.cursor() + cleanup_cursor.execute("SELECT 1") # Simple query to reset state + cleanup_cursor.fetchall() # Consume all results + cleanup_cursor.close() + except Exception: + pass # Ignore errors during cleanup + + yield # Run the test + + # Clean up after the test + try: + cleanup_cursor = db_connection.cursor() + cleanup_cursor.execute("SELECT 1") # Simple query to reset state + cleanup_cursor.fetchall() # Consume all results + cleanup_cursor.close() + except Exception: + pass # Ignore errors during cleanup # Import all exception classes for testing from mssql_python.exceptions import ( @@ -4075,4 +4112,152 @@ def faulty_converter(value): finally: # Clean up - db_connection.clear_output_converters() \ No newline at end of file + db_connection.clear_output_converters() + +def test_timeout_default(db_connection): + """Test that the default timeout value is 0 (no timeout)""" + assert hasattr(db_connection, 'timeout'), "Connection should have a timeout attribute" + assert db_connection.timeout == 0, "Default timeout should be 0" + +def test_timeout_setter(db_connection): + """Test setting and getting the timeout value""" + # Set a non-zero timeout + db_connection.timeout = 30 + assert db_connection.timeout == 30, "Timeout should be set to 30" + + # Test that timeout can be reset to zero + db_connection.timeout = 0 + assert db_connection.timeout == 0, "Timeout should be reset to 0" + + # Test setting invalid timeout values + with pytest.raises(ValueError): + db_connection.timeout = -1 + + with pytest.raises(TypeError): + db_connection.timeout = "30" + + # Reset timeout to default for other tests + db_connection.timeout = 0 + +def test_timeout_from_constructor(conn_str): + """Test setting timeout in the connection constructor""" + # Create a connection with timeout set + conn = connect(conn_str, timeout=45) + try: + assert conn.timeout == 45, "Timeout should be set to 45 from constructor" + + # Create a cursor and verify it inherits the timeout + cursor = conn.cursor() + # Execute a quick query to ensure the timeout doesn't interfere + cursor.execute("SELECT 1") + result = cursor.fetchone() + assert result[0] == 1, "Query execution should succeed with timeout set" + finally: + # Clean up + conn.close() + +def test_timeout_long_query(db_connection): + """Test that a query exceeding the timeout raises an exception if supported by driver""" + import time + import pytest + + cursor = db_connection.cursor() + + try: + # First execute a simple query to check if we can run tests + cursor.execute("SELECT 1") + cursor.fetchall() + except Exception as e: + pytest.skip(f"Skipping timeout test due to connection issue: {e}") + + # Set a short timeout + original_timeout = db_connection.timeout + db_connection.timeout = 2 # 2 seconds + + try: + # Try several different approaches to test timeout + start_time = time.perf_counter() + try: + # Method 1: CPU-intensive query with REPLICATE and large result set + cpu_intensive_query = """ + WITH numbers AS ( + SELECT TOP 1000000 ROW_NUMBER() OVER (ORDER BY (SELECT NULL)) AS n + FROM sys.objects a CROSS JOIN sys.objects b + ) + SELECT COUNT(*) FROM numbers WHERE n % 2 = 0 + """ + cursor.execute(cpu_intensive_query) + cursor.fetchall() + + elapsed_time = time.perf_counter() - start_time + + # If we get here without an exception, try a different approach + if elapsed_time < 4.5: + + # Method 2: Try with WAITFOR + start_time = time.perf_counter() + cursor.execute("WAITFOR DELAY '00:00:05'") + cursor.fetchall() + elapsed_time = time.perf_counter() - start_time + + # If we still get here, try one more approach + if elapsed_time < 4.5: + + # Method 3: Try with a join that generates many rows + start_time = time.perf_counter() + cursor.execute(""" + SELECT COUNT(*) FROM sys.objects a, sys.objects b, sys.objects c + WHERE a.object_id = b.object_id * c.object_id + """) + cursor.fetchall() + elapsed_time = time.perf_counter() - start_time + + # If we still get here without an exception + if elapsed_time < 4.5: + pytest.skip("Timeout feature not enforced by database driver") + + except Exception as e: + # Verify this is a timeout exception + elapsed_time = time.perf_counter() - start_time + assert elapsed_time < 4.5, "Exception occurred but after expected timeout" + error_text = str(e).lower() + + # Check for various error messages that might indicate timeout + timeout_indicators = [ + "timeout", "timed out", "hyt00", "hyt01", "cancel", + "operation canceled", "execution terminated", "query limit" + ] + + assert any(indicator in error_text for indicator in timeout_indicators), \ + f"Exception occurred but doesn't appear to be a timeout error: {e}" + finally: + # Reset timeout for other tests + db_connection.timeout = original_timeout + +def test_timeout_affects_all_cursors(db_connection): + """Test that changing timeout on connection affects all new cursors""" + # Create a cursor with default timeout + cursor1 = db_connection.cursor() + + # Change the connection timeout + original_timeout = db_connection.timeout + db_connection.timeout = 10 + + # Create a new cursor + cursor2 = db_connection.cursor() + + try: + # Execute quick queries to ensure both cursors work + cursor1.execute("SELECT 1") + result1 = cursor1.fetchone() + assert result1[0] == 1, "Query with first cursor failed" + + cursor2.execute("SELECT 2") + result2 = cursor2.fetchone() + assert result2[0] == 2, "Query with second cursor failed" + + # No direct way to check cursor timeout, but both should succeed + # with the current timeout setting + finally: + # Reset timeout + db_connection.timeout = original_timeout \ No newline at end of file diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 5c0c5f31..072a5ec6 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -7113,6 +7113,248 @@ def test_decimal_separator_basic_functionality(): # Restore original separator mssql_python.setDecimalSeparator(original_separator) +def test_lowercase_attribute(cursor, db_connection): + """Test that the lowercase attribute properly converts column names to lowercase""" + + # Store original value to restore after test + original_lowercase = mssql_python.lowercase + drop_cursor = None + + try: + # Create a test table with mixed-case column names + cursor.execute(""" + CREATE TABLE #pytest_lowercase_test ( + ID INT PRIMARY KEY, + UserName VARCHAR(50), + EMAIL_ADDRESS VARCHAR(100), + PhoneNumber VARCHAR(20) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_lowercase_test (ID, UserName, EMAIL_ADDRESS, PhoneNumber) + VALUES (1, 'JohnDoe', 'john@example.com', '555-1234') + """) + db_connection.commit() + + # First test with lowercase=False (default) + mssql_python.lowercase = False + cursor1 = db_connection.cursor() + cursor1.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should preserve original case + column_names1 = [desc[0] for desc in cursor1.description] + assert "ID" in column_names1, "Column 'ID' should be present with original case" + assert "UserName" in column_names1, "Column 'UserName' should be present with original case" + + # Make sure to consume all results and close the cursor + cursor1.fetchall() + cursor1.close() + + # Now test with lowercase=True + mssql_python.lowercase = True + cursor2 = db_connection.cursor() + cursor2.execute("SELECT * FROM #pytest_lowercase_test") + + # Description column names should be lowercase + column_names2 = [desc[0] for desc in cursor2.description] + assert "id" in column_names2, "Column names should be lowercase when lowercase=True" + assert "username" in column_names2, "Column names should be lowercase when lowercase=True" + + # Make sure to consume all results and close the cursor + cursor2.fetchall() + cursor2.close() + + # Create a fresh cursor for cleanup + drop_cursor = db_connection.cursor() + + finally: + # Restore original value + mssql_python.lowercase = original_lowercase + + try: + # Use a separate cursor for cleanup + if drop_cursor: + drop_cursor.execute("DROP TABLE IF EXISTS #pytest_lowercase_test") + db_connection.commit() + drop_cursor.close() + except Exception as e: + print(f"Warning: Failed to drop test table: {e}") + +def test_decimal_separator_function(cursor, db_connection): + """Test decimal separator functionality with database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_separator_test ( + id INT PRIMARY KEY, + decimal_value DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test values with default separator (.) + test_value = decimal.Decimal('123.45') + cursor.execute(""" + INSERT INTO #pytest_decimal_separator_test (id, decimal_value) + VALUES (1, ?) + """, [test_value]) + db_connection.commit() + + # First test with default decimal separator (.) + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default separator not found in string representation" + + # Now change to comma separator and test string representation + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT id, decimal_value FROM #pytest_decimal_separator_test") + row = cursor.fetchone() + + # This should format the decimal with a comma in the string representation + comma_str = str(row) + assert '123,45' in comma_str, f"Expected comma in string representation but got: {comma_str}" + + finally: + # Restore original decimal separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_separator_test") + db_connection.commit() + +def test_decimal_separator_basic_functionality(): + """Test basic decimal separator functionality without database operations""" + # Store original value to restore after test + original_separator = mssql_python.getDecimalSeparator() + + try: + # Test default value + assert mssql_python.getDecimalSeparator() == '.', "Default decimal separator should be '.'" + + # Test setting to comma + mssql_python.setDecimalSeparator(',') + assert mssql_python.getDecimalSeparator() == ',', "Decimal separator should be ',' after setting" + + # Test setting to other valid separators + mssql_python.setDecimalSeparator(':') + assert mssql_python.getDecimalSeparator() == ':', "Decimal separator should be ':' after setting" + + # Test invalid inputs + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('') # Empty string + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator('too_long') # More than one character + + with pytest.raises(ValueError): + mssql_python.setDecimalSeparator(123) # Not a string + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + +def test_decimal_separator_with_multiple_values(cursor, db_connection): + """Test decimal separator with multiple different decimal values""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_multi_test ( + id INT PRIMARY KEY, + positive_value DECIMAL(10, 2), + negative_value DECIMAL(10, 2), + zero_value DECIMAL(10, 2), + small_value DECIMAL(10, 4) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_multi_test VALUES (1, 123.45, -67.89, 0.00, 0.0001) + """) + db_connection.commit() + + # Test with default separator first + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + default_str = str(row) + assert '123.45' in default_str, "Default positive value formatting incorrect" + assert '-67.89' in default_str, "Default negative value formatting incorrect" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + cursor.execute("SELECT * FROM #pytest_decimal_multi_test") + row = cursor.fetchone() + comma_str = str(row) + + # Verify comma is used in all decimal values + assert '123,45' in comma_str, "Positive value not formatted with comma" + assert '-67,89' in comma_str, "Negative value not formatted with comma" + assert '0,00' in comma_str, "Zero value not formatted with comma" + assert '0,0001' in comma_str, "Small value not formatted with comma" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_multi_test") + db_connection.commit() + +def test_decimal_separator_calculations(cursor, db_connection): + """Test that decimal separator doesn't affect calculations""" + original_separator = mssql_python.getDecimalSeparator() + + try: + # Create test table + cursor.execute(""" + CREATE TABLE #pytest_decimal_calc_test ( + id INT PRIMARY KEY, + value1 DECIMAL(10, 2), + value2 DECIMAL(10, 2) + ) + """) + db_connection.commit() + + # Insert test data + cursor.execute(""" + INSERT INTO #pytest_decimal_calc_test VALUES (1, 10.25, 5.75) + """) + db_connection.commit() + + # Test with default separator + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation incorrect with default separator" + + # Change to comma separator + mssql_python.setDecimalSeparator(',') + + # Calculations should still work correctly + cursor.execute("SELECT value1 + value2 AS sum_result FROM #pytest_decimal_calc_test") + row = cursor.fetchone() + assert row.sum_result == decimal.Decimal('16.00'), "Sum calculation affected by separator change" + + # But string representation should use comma + assert '16,00' in str(row), "Sum result not formatted with comma in string representation" + + finally: + # Restore original separator + mssql_python.setDecimalSeparator(original_separator) + + # Cleanup + cursor.execute("DROP TABLE IF EXISTS #pytest_decimal_calc_test") + db_connection.commit() + def test_close(db_connection): """Test closing the cursor""" try: