diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index bd57340ca6..cce0e64b5a 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -13,8 +13,8 @@ from sys import version_info import pydantic +from pydantic import Field, computed_field from packaging import version -from pydantic import Field from pydantic_core import from_json from sqlglot import exp from sqlglot.errors import ParseError @@ -110,7 +110,14 @@ class ConnectionConfig(abc.ABC, BaseConfig): catalog_type_overrides: t.Optional[t.Dict[str, str]] = None # Whether to share a single connection across threads or create a new connection per thread. - shared_connection: t.ClassVar[bool] = False + # + # MyPy throws a "Decorators on top of @property are not supported" error despite this being a + # valid decoration, and Pydantic recommend disabling the MyPy hint for this reason - see: + # https://pydantic.dev/docs/validation/2.0/usage/computed_fields/ + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + return False @property @abc.abstractmethod @@ -311,7 +318,10 @@ class BaseDuckDBConnectionConfig(ConnectionConfig): token: t.Optional[str] = None - shared_connection: t.ClassVar[bool] = True + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + return True _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} @@ -820,11 +830,15 @@ class DatabricksConnectionConfig(ConnectionConfig): DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 - shared_connection: t.ClassVar[bool] = True - _concurrent_tasks_validator = concurrent_tasks_validator _http_headers_validator = http_headers_validator + @computed_field # type: ignore[prop-decorator] + @property + def shared_connection(self) -> bool: + """The connection should only be shared if U2M OAuth is being used""" + return self.auth_type is not 
None and self.oauth_client_id is None + @model_validator(mode="before") def _databricks_connect_validator(cls, data: t.Any) -> t.Any: # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 7af556d6a3..4bb9580783 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,6 +1,7 @@ import os import pathlib import re +import textwrap from pathlib import Path from unittest import mock import typing as t @@ -44,16 +45,18 @@ def yaml_config_path(tmp_path_factory) -> Path: config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - another_gateway: - connection: - type: duckdb - database: test_db - -model_defaults: - dialect: '' - """ + textwrap.dedent( + """ + gateways: + another_gateway: + connection: + type: duckdb + database: test_db + + model_defaults: + dialect: '' + """ + ) ) return config_path @@ -63,9 +66,16 @@ def python_config_path(tmp_path_factory) -> Path: config_path = tmp_path_factory.mktemp("python_config") / "config.py" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig -config = Config(gateways=GatewayConfig(connection=DuckDBConnectionConfig()), model_defaults=ModelDefaultsConfig(dialect='')) - """ + textwrap.dedent( + """ + from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig + + config = Config( + gateways=GatewayConfig(connection=DuckDBConnectionConfig()), + model_defaults=ModelDefaultsConfig(dialect='') + ) + """ + ) ) return config_path @@ -197,23 +207,27 @@ def test_load_config_no_dialect(tmp_path): create_temp_file( tmp_path, pathlib.Path("config.yaml"), - """ -gateways: - local: - connection: - type: duckdb - database: db.db -""", + textwrap.dedent( + """ + gateways: 
+ local: + connection: + type: duckdb + database: db.db + """ + ), ) create_temp_file( tmp_path, pathlib.Path("config.py"), - """ -from sqlmesh.core.config import Config, DuckDBConnectionConfig + textwrap.dedent( + """ + from sqlmesh.core.config import Config, DuckDBConnectionConfig -config = Config(default_connection=DuckDBConnectionConfig()) -""", + config = Config(default_connection=DuckDBConnectionConfig()) + """ + ), ) with pytest.raises( @@ -231,17 +245,19 @@ def test_load_config_bad_model_default_key(tmp_path): create_temp_file( tmp_path, pathlib.Path("config.yaml"), - """ -gateways: - local: - connection: - type: duckdb - database: db.db - -model_defaults: - dialect: '' - test: 1 -""", + textwrap.dedent( + """ + gateways: + local: + connection: + type: duckdb + database: db.db + + model_defaults: + dialect: '' + test: 1 + """ + ), ) with pytest.raises( @@ -262,23 +278,27 @@ def test_load_python_config_with_personal_config(tmp_path): create_temp_file( tmp_path / "personal", pathlib.Path("config.yaml"), - """ -gateways: - local: - connection: - type: duckdb - database: db.db -""", + textwrap.dedent( + """ + gateways: + local: + connection: + type: duckdb + database: db.db + """ + ), ) create_temp_file( tmp_path, pathlib.Path("config.py"), - """ -from sqlmesh.core.config import Config, DuckDBConnectionConfig, ModelDefaultsConfig + textwrap.dedent( + """ + from sqlmesh.core.config import Config, DuckDBConnectionConfig, ModelDefaultsConfig -custom_config = Config(default_connection=DuckDBConnectionConfig(), model_defaults=ModelDefaultsConfig(dialect="duckdb")) -""", + custom_config = Config(default_connection=DuckDBConnectionConfig(), model_defaults=ModelDefaultsConfig(dialect="duckdb")) + """ + ), ) config = load_config_from_paths( Config, @@ -341,15 +361,17 @@ def test_load_yaml_config_env_var_gateway_override(tmp_path_factory): config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - 
""" -gateways: - testing: - connection: - type: motherduck - database: blah -model_defaults: - dialect: bigquery - """ + textwrap.dedent( + """ + gateways: + testing: + connection: + type: motherduck + database: blah + model_defaults: + dialect: bigquery + """ + ) ) with mock.patch.dict( os.environ, @@ -378,9 +400,16 @@ def test_load_py_config_env_var_gateway_override(tmp_path_factory): config_path = tmp_path_factory.mktemp("python_config") / "config.py" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig -config = Config(gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, model_defaults=ModelDefaultsConfig(dialect='')) - """ + textwrap.dedent( + """ + from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig + + config = Config( + gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, + model_defaults=ModelDefaultsConfig(dialect='') + ) + """ + ) ) with mock.patch.dict( os.environ, @@ -465,18 +494,20 @@ def test_environment_catalog_mapping(tmp_path_factory, mapping, expected, dialec config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - f""" -gateways: - local: - connection: - type: duckdb - -model_defaults: - dialect: {dialect} - -environment_catalog_mapping: - {mapping} - """ + textwrap.dedent( + f""" + gateways: + local: + connection: + type: duckdb + + model_defaults: + dialect: {dialect} + + environment_catalog_mapping: + {mapping} + """ + ) ) if raise_error: with pytest.raises(ConfigError, match=raise_error): @@ -530,6 +561,7 @@ def test_connection_config_serialization(): "register_comments": True, "type": "duckdb", "extensions": [], + "shared_connection": True, "pre_ping": False, "pretty_sql": False, "connector_config": {}, @@ -542,6 +574,7 @@ def 
test_connection_config_serialization(): "register_comments": True, "type": "duckdb", "extensions": [], + "shared_connection": True, "pre_ping": False, "pretty_sql": False, "connector_config": {}, @@ -580,23 +613,25 @@ def test_load_duckdb_attach_config(tmp_path): config_path = tmp_path / "config_duckdb_attach.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - another_gateway: - connection: - type: duckdb - catalogs: - memory: ':memory:' - sqlite: - type: 'sqlite' - path: 'test.db' - postgres: - type: 'postgres' - path: 'dbname=postgres user=postgres host=127.0.0.1' - read_only: true -model_defaults: - dialect: '' - """ + textwrap.dedent( + """ + gateways: + another_gateway: + connection: + type: duckdb + catalogs: + memory: ':memory:' + sqlite: + type: 'sqlite' + path: 'test.db' + postgres: + type: 'postgres' + path: 'dbname=postgres user=postgres host=127.0.0.1' + read_only: true + model_defaults: + dialect: '' + """ + ) ) config = load_config_from_paths( @@ -625,13 +660,15 @@ def test_load_model_defaults_audits(tmp_path): config_path = tmp_path / "config_model_defaults_audits.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -model_defaults: - dialect: '' - audits: - - assert_positive_order_ids - - does_not_exceed_threshold(column := id, threshold := 1000) - """ + textwrap.dedent( + """ + model_defaults: + dialect: '' + audits: + - assert_positive_order_ids + - does_not_exceed_threshold(column := id, threshold := 1000) + """ + ) ) config = load_config_from_paths( @@ -652,19 +689,21 @@ def test_load_model_defaults_statements(tmp_path): config_path = tmp_path / "config_model_defaults_statements.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -model_defaults: - dialect: duckdb - pre_statements: - - SET memory_limit = '10GB' - - CREATE TEMP TABLE temp_data AS SELECT 1 as id - post_statements: - - DROP TABLE IF EXISTS temp_data - - ANALYZE @this_model - - SET memory_limit = '5GB' 
- on_virtual_update: - - UPDATE stats_table SET last_update = CURRENT_TIMESTAMP - """ + textwrap.dedent( + """ + model_defaults: + dialect: duckdb + pre_statements: + - SET memory_limit = '10GB' + - CREATE TEMP TABLE temp_data AS SELECT 1 as id + post_statements: + - DROP TABLE IF EXISTS temp_data + - ANALYZE @this_model + - SET memory_limit = '5GB' + on_virtual_update: + - UPDATE stats_table SET last_update = CURRENT_TIMESTAMP + """ + ) ) config = load_config_from_paths( @@ -692,12 +731,14 @@ def test_load_model_defaults_validation_statements(tmp_path): config_path = tmp_path / "config_model_defaults_statements_wrong.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -model_defaults: - dialect: duckdb - pre_statements: - - 313 - """ + textwrap.dedent( + """ + model_defaults: + dialect: duckdb + pre_statements: + - 313 + """ + ) ) with pytest.raises(TypeError, match=r"expected str instance, int found"): @@ -711,18 +752,20 @@ def test_scheduler_config(tmp_path_factory): config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - builtin_gateway: - scheduler: - type: builtin - -default_scheduler: - type: builtin - -model_defaults: - dialect: bigquery - """ + textwrap.dedent( + """ + gateways: + builtin_gateway: + scheduler: + type: builtin + + default_scheduler: + type: builtin + + model_defaults: + dialect: bigquery + """ + ) ) config = load_config_from_paths( @@ -737,38 +780,40 @@ def test_multi_gateway_config(tmp_path, mocker: MockerFixture): config_path = tmp_path / "config_athena_redshift.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - redshift: - connection: - type: redshift - user: user - password: '1234' - host: host - database: db - test_connection: - type: redshift - database: test_db - state_connection: - type: duckdb - database: state.db - athena: - connection: - type: athena - aws_access_key_id: '1234' 
- aws_secret_access_key: accesskey - work_group: group - s3_warehouse_location: s3://location - duckdb: - connection: - type: duckdb - database: db.db - -default_gateway: redshift - -model_defaults: - dialect: redshift - """ + textwrap.dedent( + """ + gateways: + redshift: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + test_connection: + type: redshift + database: test_db + state_connection: + type: duckdb + database: state.db + athena: + connection: + type: athena + aws_access_key_id: '1234' + aws_secret_access_key: accesskey + work_group: group + s3_warehouse_location: s3://location + duckdb: + connection: + type: duckdb + database: db.db + + default_gateway: redshift + + model_defaults: + dialect: redshift + """ + ) ) config = load_config_from_paths( @@ -793,23 +838,25 @@ def test_multi_gateway_single_threaded_config(tmp_path): config_path = tmp_path / "config_duck_athena.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - duckdb: - connection: - type: duckdb - database: db.db - athena: - connection: - type: athena - aws_access_key_id: '1234' - aws_secret_access_key: accesskey - work_group: group - s3_warehouse_location: s3://location -default_gateway: duckdb -model_defaults: - dialect: duckdb - """ + textwrap.dedent( + """ + gateways: + duckdb: + connection: + type: duckdb + database: db.db + athena: + connection: + type: athena + aws_access_key_id: '1234' + aws_secret_access_key: accesskey + work_group: group + s3_warehouse_location: s3://location + default_gateway: duckdb + model_defaults: + dialect: duckdb + """ + ) ) config = load_config_from_paths( @@ -831,23 +878,25 @@ def test_trino_schema_location_mapping_syntax(tmp_path): config_path = tmp_path / "config_trino.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ - gateways: - trino: - connection: - type: trino - user: trino - host: trino - catalog: trino - schema_location_mapping: - '^utils$': 
's3://utils-bucket/@{schema_name}' - '^landing\\..*$': 's3://raw-data/@{catalog_name}/@{schema_name}' - - default_gateway: trino - - model_defaults: - dialect: trino - """ + textwrap.dedent( + """ + gateways: + trino: + connection: + type: trino + user: trino + host: trino + catalog: trino + schema_location_mapping: + '^utils$': 's3://utils-bucket/@{schema_name}' + '^landing\\..*$': 's3://raw-data/@{catalog_name}/@{schema_name}' + + default_gateway: trino + + model_defaults: + dialect: trino + """ + ) ) config = load_config_from_paths( @@ -867,21 +916,23 @@ def test_trino_source_option(tmp_path): config_path = tmp_path / "config_trino_source.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ - gateways: - trino: - connection: - type: trino - user: trino - host: trino - catalog: trino - source: my_sqlmesh_source - - default_gateway: trino - - model_defaults: - dialect: trino - """ + textwrap.dedent( + """ + gateways: + trino: + connection: + type: trino + user: trino + host: trino + catalog: trino + source: my_sqlmesh_source + + default_gateway: trino + + model_defaults: + dialect: trino + """ + ) ) config = load_config_from_paths( @@ -900,26 +951,28 @@ def test_gcp_postgres_ip_and_scopes(tmp_path): config_path = tmp_path / "config_gcp_postgres.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ - gateways: - gcp_postgres: - connection: - type: gcp_postgres - check_import: false - instance_connection_string: something - user: user - password: password - db: db - ip_type: private - scopes: - - https://www.googleapis.com/auth/cloud-platform - - https://www.googleapis.com/auth/sqlservice.admin - - default_gateway: gcp_postgres - - model_defaults: - dialect: postgres - """ + textwrap.dedent( + """ + gateways: + gcp_postgres: + connection: + type: gcp_postgres + check_import: false + instance_connection_string: something + user: user + password: password + db: db + ip_type: private + scopes: + - 
https://www.googleapis.com/auth/cloud-platform + - https://www.googleapis.com/auth/sqlservice.admin + + default_gateway: gcp_postgres + + model_defaults: + dialect: postgres + """ + ) ) config = load_config_from_paths( @@ -971,12 +1024,14 @@ def test_model_defaults_cron_tz(tmp_path): config_path = tmp_path / "config_model_defaults_cron_tz.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -model_defaults: - dialect: duckdb - cron: '@daily' - cron_tz: 'America/Los_Angeles' - """ + textwrap.dedent( + """ + model_defaults: + dialect: duckdb + cron: '@daily' + cron_tz: 'America/Los_Angeles' + """ + ) ) config = load_config_from_paths( @@ -1025,29 +1080,31 @@ def test_redshift_merge_flag(tmp_path, mocker: MockerFixture): config_path = tmp_path / "config_redshift_merge.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ -gateways: - redshift: - connection: - type: redshift - user: user - password: '1234' - host: host - database: db - enable_merge: true - default: - connection: - type: redshift - user: user - password: '1234' - host: host - database: db - -default_gateway: redshift - -model_defaults: - dialect: redshift - """ + textwrap.dedent( + """ + gateways: + redshift: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + enable_merge: true + default: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + + default_gateway: redshift + + model_defaults: + dialect: redshift + """ + ) ) config = load_config_from_paths( @@ -1070,28 +1127,30 @@ def test_environment_statements_config(tmp_path): config_path = tmp_path / "config_before_after_all.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """ - gateways: - postgres: - connection: - type: postgres - database: db - user: postgres - password: postgres - host: localhost - port: 5432 - - default_gateway: postgres - - before_all: - - CREATE TABLE IF NOT EXISTS custom_analytics 
(physical_table VARCHAR, evaluation_time VARCHAR); - after_all: - - "@grant_schema_privileges()" - - "GRANT REFERENCES ON FUTURE VIEWS IN DATABASE db TO ROLE admin_role;" - - model_defaults: - dialect: postgres - """ + textwrap.dedent( + """ + gateways: + postgres: + connection: + type: postgres + database: db + user: postgres + password: postgres + host: localhost + port: 5432 + + default_gateway: postgres + + before_all: + - CREATE TABLE IF NOT EXISTS custom_analytics (physical_table VARCHAR, evaluation_time VARCHAR); + after_all: + - "@grant_schema_privileges()" + - "GRANT REFERENCES ON FUTURE VIEWS IN DATABASE db TO ROLE admin_role;" + + model_defaults: + dialect: postgres + """ + ) ) config = load_config_from_paths( @@ -1124,18 +1183,22 @@ class ConfigSubclass(Config): ... def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) -> None: config_path = tmp_path / "config.yaml" - config_path.write_text(""" - gateways: - bigquery: - connection: - type: bigquery - project: unit-test - - default_gateway: bigquery - - model_defaults: - dialect: bigquery -""") + config_path.write_text( + textwrap.dedent( + """ + gateways: + bigquery: + connection: + type: bigquery + project: unit-test + + default_gateway: bigquery + + model_defaults: + dialect: bigquery + """ + ) + ) with mock.patch.dict( os.environ, { @@ -1158,20 +1221,24 @@ def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) def test_config_user_macro_function(tmp_path: Path) -> None: config_path = tmp_path / "config.yaml" - config_path.write_text(""" - gateways: - bigquery: - connection: - type: bigquery - project: unit-test + config_path.write_text( + textwrap.dedent( + """ + gateways: + bigquery: + connection: + type: bigquery + project: unit-test - default_gateway: bigquery + default_gateway: bigquery - model_defaults: - dialect: bigquery + model_defaults: + dialect: bigquery - default_target_environment: dev_{{ user() }} -""") + default_target_environment: 
dev_{{ user() }} + """ + ) + ) with mock.patch("getpass.getuser", return_value="test_user"): config = load_config_from_paths( @@ -1184,19 +1251,23 @@ def test_config_user_macro_function(tmp_path: Path) -> None: def test_environment_suffix_target_catalog(tmp_path: Path) -> None: config_path = tmp_path / "config.yaml" - config_path.write_text(""" - gateways: - warehouse: - connection: - type: duckdb + config_path.write_text( + textwrap.dedent( + """ + gateways: + warehouse: + connection: + type: duckdb - default_gateway: warehouse + default_gateway: warehouse - model_defaults: - dialect: duckdb + model_defaults: + dialect: duckdb - environment_suffix_target: catalog -""") + environment_suffix_target: catalog + """ + ) + ) config = load_config_from_paths( Config, @@ -1206,22 +1277,26 @@ def test_environment_suffix_target_catalog(tmp_path: Path) -> None: assert config.environment_suffix_target == EnvironmentSuffixTarget.CATALOG assert not config.environment_catalog_mapping - config_path.write_text(""" - gateways: - warehouse: - connection: - type: duckdb + config_path.write_text( + textwrap.dedent( + """ + gateways: + warehouse: + connection: + type: duckdb - default_gateway: warehouse + default_gateway: warehouse - model_defaults: - dialect: duckdb + model_defaults: + dialect: duckdb - environment_suffix_target: catalog + environment_suffix_target: catalog - environment_catalog_mapping: - '.*': "foo" -""") + environment_catalog_mapping: + '.*': "foo" + """ + ) + ) with pytest.raises(ConfigError, match=r"mutually exclusive"): config = load_config_from_paths( @@ -1235,9 +1310,16 @@ def test_load_python_config_dot_env_vars(tmp_path_factory): config_path = main_dir / "config.py" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig -config = Config(gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, 
model_defaults=ModelDefaultsConfig(dialect='')) - """ + textwrap.dedent( + """ + from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig + + config = Config( + gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, + model_defaults=ModelDefaultsConfig(dialect='') + ) + """ + ) ) # The environment variable value from the dot env file should be set @@ -1245,10 +1327,13 @@ def test_load_python_config_dot_env_vars(tmp_path_factory): dot_path = main_dir / ".env" with open(dot_path, "w", encoding="utf-8") as fd: fd.write( - """SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE="bigquery" -SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__CHECK_IMPORT="false" -SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" - """ + textwrap.dedent( + """ + SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE="bigquery" + SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__CHECK_IMPORT="false" + SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" + """ + ) ) # Use mock.patch.dict to isolate environment variables between the tests @@ -1276,22 +1361,25 @@ def test_load_yaml_config_dot_env_vars(tmp_path_factory): config_path = main_dir / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """gateways: - duckdb_gateway: - connection: - type: duckdb - catalogs: - local: local.db - cloud_sales: {{ env_var('S3_BUCKET') }} - extensions: - - name: httpfs - secrets: - - type: "s3" - key_id: {{ env_var('S3_KEY') }} - secret: {{ env_var('S3_SECRET') }} -model_defaults: - dialect: "" -""" + textwrap.dedent( + """ + gateways: + duckdb_gateway: + connection: + type: duckdb + catalogs: + local: local.db + cloud_sales: {{ env_var('S3_BUCKET') }} + extensions: + - name: httpfs + secrets: + - type: "s3" + key_id: {{ env_var('S3_KEY') }} + secret: {{ env_var('S3_SECRET') }} + model_defaults: + dialect: "" + """ + ) ) # This test checks both using SQLMESH__ prefixed environment variables with underscores @@ -1299,12 
+1387,15 @@ def test_load_yaml_config_dot_env_vars(tmp_path_factory): dot_path = main_dir / ".env" with open(dot_path, "w", encoding="utf-8") as fd: fd.write( - """S3_BUCKET="s3://metrics_bucket/sales.db" -S3_KEY="S3_KEY_ID" -S3_SECRET="XXX_S3_SECRET_XXX" -SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" -SQLMESH__MODEL_DEFAULTS__DIALECT="athena" -""" + textwrap.dedent( + """ + S3_BUCKET="s3://metrics_bucket/sales.db" + S3_KEY="S3_KEY_ID" + S3_SECRET="XXX_S3_SECRET_XXX" + SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" + SQLMESH__MODEL_DEFAULTS__DIALECT="athena" + """ + ) ) # Use mock.patch.dict to isolate environment variables between the tests @@ -1338,14 +1429,17 @@ def test_load_config_dotenv_directory_not_loaded(tmp_path_factory): config_path = main_dir / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """gateways: - test_gateway: - connection: - type: duckdb - database: test.db -model_defaults: - dialect: duckdb -""" + textwrap.dedent( + """ + gateways: + test_gateway: + connection: + type: duckdb + database: test.db + model_defaults: + dialect: duckdb + """ + ) ) # Create a .env directory instead of a file to simulate a Python virtual environment @@ -1358,14 +1452,17 @@ def test_load_config_dotenv_directory_not_loaded(tmp_path_factory): other_config_path = other_dir / "config.yaml" with open(other_config_path, "w", encoding="utf-8") as fd: fd.write( - """gateways: - test_gateway: - connection: - type: duckdb - database: test.db -model_defaults: - dialect: duckdb -""" + textwrap.dedent( + """ + gateways: + test_gateway: + connection: + type: duckdb + database: test.db + model_defaults: + dialect: duckdb + """ + ) ) env_file = other_dir / ".env" @@ -1398,12 +1495,15 @@ def test_load_yaml_config_custom_dotenv_path(tmp_path_factory): config_path = main_dir / "config.yaml" with open(config_path, "w", encoding="utf-8") as fd: fd.write( - """gateways: - test_gateway: - connection: - type: duckdb - database: {{ env_var('DB_NAME') }} -""" + 
textwrap.dedent( + """ + gateways: + test_gateway: + connection: + type: duckdb + database: {{ env_var('DB_NAME') }} + """ + ) ) # Create a custom dot env file in a different location @@ -1411,10 +1511,13 @@ def test_load_yaml_config_custom_dotenv_path(tmp_path_factory): custom_env_path = custom_env_dir / ".my_env" with open(custom_env_path, "w", encoding="utf-8") as fd: fd.write( - """DB_NAME="custom_database.db" -SQLMESH__DEFAULT_GATEWAY="test_gateway" -SQLMESH__MODEL_DEFAULTS__DIALECT="postgres" -""" + textwrap.dedent( + """ + DB_NAME="custom_database.db" + SQLMESH__DEFAULT_GATEWAY="test_gateway" + SQLMESH__MODEL_DEFAULTS__DIALECT="postgres" + """ + ) ) # Test that without custom dotenv path, env vars are not loaded @@ -1463,15 +1566,19 @@ def test_physical_table_naming_convention( convention_str: t.Optional[str], expected: t.Optional[TableNamingConvention], tmp_path: Path ): config_part = f"physical_table_naming_convention: {convention_str}" if convention_str else "" - (tmp_path / "config.yaml").write_text(f""" -gateways: - test_gateway: - connection: - type: duckdb -model_defaults: - dialect: duckdb -{config_part} - """) + (tmp_path / "config.yaml").write_text( + textwrap.dedent( + f""" + gateways: + test_gateway: + connection: + type: duckdb + model_defaults: + dialect: duckdb + {config_part} + """ + ) + ) config = load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) assert config.physical_table_naming_convention == expected @@ -1480,10 +1587,15 @@ def test_physical_table_naming_convention( def test_load_configs_includes_sqlmesh_yaml(tmp_path: Path): for extension in ("yaml", "yml"): config_file = tmp_path / f"sqlmesh.{extension}" - config_file.write_text(""" -model_defaults: - start: '2023-04-05' - dialect: bigquery""") + config_file.write_text( + textwrap.dedent( + """ + model_defaults: + start: '2023-04-05' + dialect: bigquery + """ + ) + ) configs = load_configs(config=None, config_type=Config, paths=[tmp_path]) assert len(configs) == 1 @@ 
-1525,29 +1637,41 @@ def test_load_configs_in_dbt_project_without_config_py(tmp_path: Path): # - uses the sqlmesh_dbt cli for the first time, which runs init if the config doesnt exist, which creates a config # when in pure yaml mode, sqlmesh should be able to auto-detect the presence of DBT and select the DbtLoader instead # of the main loader - (tmp_path / "dbt_project.yml").write_text(""" -name: jaffle_shop - """) - - (tmp_path / "profiles.yml").write_text(""" -jaffle_shop: - - target: dev - outputs: - dev: - type: duckdb - path: 'jaffle_shop.duckdb' - """) - - (tmp_path / "sqlmesh.yaml").write_text(""" -gateways: - dev: - state_connection: - type: duckdb - database: state.db -model_defaults: - start: '2020-01-01' -""") + (tmp_path / "dbt_project.yml").write_text( + textwrap.dedent( + """ + name: jaffle_shop + """ + ) + ) + + (tmp_path / "profiles.yml").write_text( + textwrap.dedent( + """ + jaffle_shop: + + target: dev + outputs: + dev: + type: duckdb + path: 'jaffle_shop.duckdb' + """ + ) + ) + + (tmp_path / "sqlmesh.yaml").write_text( + textwrap.dedent( + """ + gateways: + dev: + state_connection: + type: duckdb + database: state.db + model_defaults: + start: '2020-01-01' + """ + ) + ) configs = list(load_configs(config=None, config_type=Config, paths=[tmp_path]).values()) assert len(configs) == 1 diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index e1813df9b9..9c080eb560 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1426,18 +1426,19 @@ def test_databricks(make_config): ) -def test_databricks_shared_connection(make_config): - """Databricks should use a shared connection pool to prevent OAuth CSRF races. +def test_databricks__u2m_oauth__shared_connection_pool(make_config): + """Databricks should use a shared connection pool when using OAuth to prevent CSRF races. When concurrent_tasks > 1, ThreadLocalConnectionPool creates one connection per thread. 
For U2M OAuth, each thread triggers its own browser-based OAuth flow; these race on the CSRF state parameter and cause MismatchingStateError. - Setting shared_connection = True causes ThreadLocalSharedConnectionPool to be - used instead: a single connection is created (behind a lock) and each thread - gets its own cursor, so only one OAuth flow is ever initiated. + For non-U2M OAuth authentication types (e.g. access_token and M2M OAuth) then + ThreadLocalConnectionPool should still be used. - See: https://github.com/tobymao/sqlmesh/issues/5646 + See: + https://github.com/tobymao/sqlmesh/issues/5646 + https://github.com/SQLMesh/sqlmesh/issues/5858 """ from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool @@ -1445,7 +1446,7 @@ def test_databricks_shared_connection(make_config): type="databricks", server_hostname="dbc-test.cloud.databricks.com", http_path="sql/test/foo", - access_token="test-token", + auth_type="databricks-oauth", concurrent_tasks=4, ) assert isinstance(config, DatabricksConnectionConfig) @@ -1455,6 +1456,41 @@ def test_databricks_shared_connection(make_config): assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) +def test_databricks__m2m_oauth__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + auth_type="databricks-oauth", + oauth_client_id="oauth_client_id", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + +def test_databricks__access_token__connection_pool(make_config): + from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool + + config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", 
+ http_path="sql/test/foo", + access_token="any-token", + concurrent_tasks=4, + ) + assert isinstance(config, DatabricksConnectionConfig) + assert config.shared_connection is False + + adapter = config.create_engine_adapter() + assert isinstance(adapter._connection_pool, ThreadLocalConnectionPool) + + def test_engine_import_validator(): with pytest.raises( ConfigError,