From 474029bdb12d08f3bdd9123cc955903dc0f02d74 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Tue, 3 Jun 2025 18:12:30 +0200 Subject: [PATCH 01/17] Implementation of environment configuration for Python SDK --- .DS_Store | Bin 0 -> 8196 bytes scripts/.DS_Store | Bin 0 -> 6148 bytes temporalio/.DS_Store | Bin 0 -> 6148 bytes temporalio/bridge/.DS_Store | Bin 0 -> 6148 bytes temporalio/bridge/Cargo.lock | 195 +++++++++++++- temporalio/bridge/Cargo.toml | 2 +- temporalio/bridge/src/envconfig.rs | 238 +++++++++++++++++ temporalio/bridge/src/lib.rs | 28 ++ temporalio/envconfig.py | 332 +++++++++++++++++++++++ tests/.DS_Store | Bin 0 -> 6148 bytes tests/test_envconfig.py | 406 +++++++++++++++++++++++++++++ 11 files changed, 1186 insertions(+), 15 deletions(-) create mode 100644 .DS_Store create mode 100644 scripts/.DS_Store create mode 100644 temporalio/.DS_Store create mode 100644 temporalio/bridge/.DS_Store create mode 100644 temporalio/bridge/src/envconfig.rs create mode 100644 temporalio/envconfig.py create mode 100644 tests/.DS_Store create mode 100644 tests/test_envconfig.py diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..5f1d0dc2347d41a3f7a9e7deeec8af1d9465e842 GIT binary patch literal 8196 zcmeHM&ubGw6n>K*CKyt3l2R{=3W|r?f=7|Hrr=czUfZUrjgoZJq$w6JiynIs50ayJ zQUnqI2df?gFIwtFJqUuAqDb+i-~4d*rMsK;Bo$}E%-hcU-n?(W{V}{P5s9f@eUfOD zh>GaaR!SIRn)J&aX&Ezf0#+cNnoE;23tp=oLD~tcfK|XMU=^?mSOxZr0@$-wDe>l*883lEymp|Cnsm=Z(i zaLhXf=Njv4tHVi{;zO9s!c-_iWykuCf|GEywWC$QDiBpb*6ukPrzN^dtGs>}k1k4d z;A?r+YOU$jFjp*m&Get{Zk`+T*N6Jsjquxm_^Cb-6n^KW2KA`MvlETK!RmVuoWs_s z!qJJo;t)ZFOp;NY50T8U6W!^7^AI%WedPQ51DCfI@1YSP_dko^{E0q*kW#UO zD}M#QMDCAoCL=lTk^F__8y^F`W7chXA?I}T0ADd7`R1Yb8nj-fHGWT&Pw(Cx^m1Nv z=V!;xep9?yJm(UR;Q2?6srxfWbv&iG(mcD^n7)MUfsg0ju2AW^)`a;(72^Cnuh{Uh zG@imPo`eo^w{e$sd3F+?ccKW+;ZtYx)q|Jn{P(QGHAqHrK9q5hj1H_`+Dp4SQ3tNdw9RR(%0{?*flxBG%VwpeS7PP6llg743Jg;qtDU2={+~|%{y%I9>$3`21^#gbM6Nnhox(|&+B$MNvetIc z=g?)7IA2>8f|09MDO?dg2iO~bPN*Y5x7%GzN+ GEAR_@5va=m literal 0 HcmV?d00001 diff --git a/scripts/.DS_Store b/scripts/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b37cc5bc6e4994e03ca13b3f9a1572a75c235cdf GIT binary patch literal 6148 zcmeHKOHRWu5S=Lz1RgKh$j#8G0~8SN~qvqAHz2ye$k!`l0}>>dn^~@>2;`^UZfd* zD+BWG3c93Qs%eqEKVH3eeRZg+*)q&vD<7ZsPEMa+&vXBLtG~JF{5JRnskDo_4c$?o zdkR@?IvjoVck36Uv*XoUte@>shtD%?ipkDFUd~}+3>X8(z<*-^JzJzaP_)q)Fb0f) zH3RZ}a8N;?uvLto4h#b<#a{@DU@mtFL6^`cY!%^wu%-evmD3f2H63=B;(WqZQPYWI zW$cqGbGo5$tPZ=2;ly6iMq|Jj$Qjtt;b3~*6Rs|g+{_150Q uNw0O#E2xOXwTkN$OkyiWthC}IXcX9;On^RNs|X9k{s=@GY%m6Xlz|T=3Po!G literal 0 HcmV?d00001 diff --git a/temporalio/.DS_Store b/temporalio/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..57d0adee2e038f3b6e33c26797bc59866ad16b7c GIT binary patch literal 6148 zcmeHLze@u#6n?R#7Eht0h{6pnii7?KXQ|GDvsBwkEn2VEU+ChkZcZXD{sRs=xjKlr z2o55S`X2}m4laH#i8kqZf{Uo+LGtC2m-oK=jwDTqNaWWmV?;S3%Azo)yHI^%JkBL% z9lEC%C}fN>%~6>aXr78OZ`a`za0>i31$gaxSgl2B&=%HiYyB3Ns$`0(iQa{Mdr$~U|l3nq&PgryW%HmGLt8{XJ%jhK7}?z&bEUL+qq zuO7DY$49MvqvxD2LnUfbjVe^7c_VA;W_M_^eca5w)4l1Ps@zr#lR=V?!6S7y^29tk z@qI#CrZvf?WH7=0+xHsxFuC0Jmq*U-r4DUzOg7101|MXbQhe61)31V04Xmm*pTR@P z$MP`w%y>QT-=z}~j5sErWG|b~Ec40GCiH32BBC`dHs1AB`?$&FdGCXZ6~knbvgaC9eR$G`gjso#O`3*BEGw6k-G> z)Ks9FD)bdYsOjkUwO^nyQmCer(3cOPGYfq~5i&dC`_i36pwQJ$0jGeiKuSI4`23$p zzW>{k+>=wlDe$ip5b45nA&*zmXKU=`_^kC&W>MJKFH$I{pwh>&Z17Rc|0_ttn9mJh VpfOU29+>+fAZ2isQ{YDxcmo@u*X;lR literal 0 HcmV?d00001 diff --git a/temporalio/bridge/.DS_Store b/temporalio/bridge/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6c18b60901085091c24d1b5c6f05be0ce6cedb1a GIT binary patch literal 6148 zcmeHKu};H44E41Y3b1r!yoHe=e-Ns`#sC!(BebcCM5|P#>={`22L6DFZ(v{mi9g^g z7+85eTWQm@5)(p|E&1N%Y~MLAQ5+MI>)d5AQJaV&C}U$A!wO+LYfT!u!bTVN=u=Kp z8jOZt z(R?_2<~&S{0b{@z*fa)Evst2DK^u($W55_F8Q|}OhcYILg<$w}U PyResult { + let dict = PyDict::new(py); + match ds { + DataSource::Path(p) => dict.set_item("path", p)?, + DataSource::Data(d) => dict.set_item("data", PyBytes::new(py, d))?, + }; + Ok(dict.to_object(py)) +} + +fn tls_to_dict(py: Python, tls: &CoreClientConfigTLS) -> PyResult { + let dict = PyDict::new(py); + dict.set_item("disabled", tls.disabled)?; + if let Some(v) = &tls.client_cert { + dict.set_item("client_cert", data_source_to_dict(py, v)?)?; + } + if let Some(v) = &tls.client_key { + dict.set_item("client_key", data_source_to_dict(py, v)?)?; + } + if let Some(v) = &tls.server_ca_cert { + dict.set_item("server_ca_cert", data_source_to_dict(py, v)?)?; + } + if let Some(v) = &tls.server_name { + dict.set_item("server_name", v)?; + } + dict.set_item("disable_host_verification", tls.disable_host_verification)?; + Ok(dict.to_object(py)) +} + +fn codec_to_dict(py: Python, codec: &ClientConfigCodec) -> PyResult { + let dict = PyDict::new(py); + if let Some(v) = &codec.endpoint { + dict.set_item("endpoint", v)?; + } + if let Some(v) = &codec.auth { + dict.set_item("auth", v)?; + } + Ok(dict.to_object(py)) +} + +fn profile_to_dict(py: Python, profile: &CoreClientConfigProfile) -> PyResult { + let dict = PyDict::new(py); + if let Some(v) = &profile.address { + dict.set_item("address", v)?; + } + if let Some(v) = &profile.namespace { + dict.set_item("namespace", v)?; + } + if let Some(v) = &profile.api_key { + dict.set_item("api_key", v)?; + } + if let Some(tls) = &profile.tls { + dict.set_item("tls", tls_to_dict(py, tls)?)?; + } + if let Some(codec) = &profile.codec { + dict.set_item("codec", codec_to_dict(py, codec)?)?; + } + if !profile.grpc_meta.is_empty() { + dict.set_item("grpc_meta", profile.grpc_meta.to_object(py))?; + } + Ok(dict.to_object(py)) +} + +fn core_config_to_dict(py: Python, core_config: &CoreClientConfig) -> PyResult { + let profiles_dict = PyDict::new(py); + for (name, profile) in &core_config.profiles { + let connect_dict = profile_to_dict(py, profile)?; + profiles_dict.set_item(name, connect_dict)?; + } + Ok(profiles_dict.to_object(py)) +} + +fn load_client_config_inner( + py: Python, + config_source: Option, + config_file_strict: bool, + disable_file: bool, + env_vars: Option>, +) -> PyResult { + let core_config = if disable_file { + CoreClientConfig::default() + } else { + let options = LoadClientConfigOptions { + config_source, + config_file_strict, + }; + core_load_client_config(options, env_vars.as_ref()) + .map_err(|e| ConfigError::new_err(format!("{}", e)))? + }; + + core_config_to_dict(py, &core_config) +} + +fn load_client_connect_config_inner( + py: Python, + config_source: Option, + profile: Option, + disable_file: bool, + disable_env: bool, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + let options = LoadClientConfigProfileOptions { + config_source, + config_file_profile: profile, + config_file_strict, + disable_file, + disable_env, + }; + + let profile = core_load_client_config_profile(options, env_vars.as_ref()) + .map_err(|e| ConfigError::new_err(format!("{}", e)))?; + + profile_to_dict(py, &profile) +} + +#[pyfunction] +#[pyo3(signature = (disable_file, config_file_strict, env_vars = None))] +pub fn load_client_config( + py: Python, + disable_file: bool, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_config_inner(py, None, config_file_strict, disable_file, env_vars) +} + +#[pyfunction] +#[pyo3(signature = (path, config_file_strict, env_vars = None))] +pub fn load_client_config_from_file( + py: Python, + path: String, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_config_inner( + py, + Some(DataSource::Path(path)), + config_file_strict, + false, + env_vars, + ) +} + +#[pyfunction] +#[pyo3(signature = (data, config_file_strict, env_vars = None))] +pub fn load_client_config_from_data( + py: Python, + data: Vec, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_config_inner( + py, + Some(DataSource::Data(data)), + config_file_strict, + false, + env_vars, + ) +} + +#[pyfunction] +#[pyo3(signature = (profile, disable_file, disable_env, config_file_strict, env_vars = None))] +pub fn load_client_connect_config( + py: Python, + profile: Option, + disable_file: bool, + disable_env: bool, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_connect_config_inner( + py, + None, + profile, + disable_file, + disable_env, + config_file_strict, + env_vars, + ) +} + +#[pyfunction] +#[pyo3(signature = (path, profile, disable_env, config_file_strict, env_vars = None))] +pub fn load_client_connect_config_from_file( + py: Python, + path: String, + profile: Option, + disable_env: bool, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_connect_config_inner( + py, + Some(DataSource::Path(path)), + profile, + false, + disable_env, + config_file_strict, + env_vars, + ) +} + +#[pyfunction] +#[pyo3(signature = (data, profile, disable_env, config_file_strict, env_vars = None))] +pub fn load_client_connect_config_from_data( + py: Python, + data: Vec, + profile: Option, + disable_env: bool, + config_file_strict: bool, + env_vars: Option>, +) -> PyResult { + load_client_connect_config_inner( + py, + Some(DataSource::Data(data)), + profile, + false, + disable_env, + config_file_strict, + env_vars, + ) +} diff --git a/temporalio/bridge/src/lib.rs b/temporalio/bridge/src/lib.rs index bac2ff6e5..73f6a435e 100644 --- a/temporalio/bridge/src/lib.rs +++ b/temporalio/bridge/src/lib.rs @@ -2,6 +2,7 @@ use pyo3::prelude::*; use pyo3::types::PyTuple; mod client; +mod envconfig; mod metric; mod runtime; mod testing; @@ -54,6 +55,33 @@ fn temporal_sdk_bridge(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_function(wrap_pyfunction!(new_worker, m)?)?; m.add_function(wrap_pyfunction!(new_replay_worker, m)?)?; + + // envconfig + let envconfig_module = PyModule::new(py, "envconfig")?; + envconfig_module.add("ConfigError", py.get_type::())?; + envconfig_module.add_function(wrap_pyfunction!(envconfig::load_client_config, m)?)?; + envconfig_module.add_function(wrap_pyfunction!( + envconfig::load_client_config_from_data, + m + )?)?; + envconfig_module.add_function(wrap_pyfunction!( + envconfig::load_client_config_from_file, + m + )?)?; + envconfig_module.add_function(wrap_pyfunction!( + envconfig::load_client_connect_config, + m + )?)?; + envconfig_module.add_function(wrap_pyfunction!( + envconfig::load_client_connect_config_from_data, + m + )?)?; + envconfig_module.add_function(wrap_pyfunction!( + envconfig::load_client_connect_config_from_file, + m + )?)?; + m.add_submodule(envconfig_module)?; + Ok(()) } diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py new file mode 100644 index 000000000..def350c98 --- /dev/null +++ b/temporalio/envconfig.py @@ -0,0 +1,332 @@ +"""Environment and file-based configuration for Temporal clients. + +This module provides utilities to load Temporal client configuration from TOML files +and environment variables, following the same patterns as the Go SDK. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional, Union + +from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig +from temporalio.service import ConnectConfig, TLSConfig + + +@dataclass(frozen=True) +class _DataSource: + path: Optional[str] = None + data: Optional[bytes] = None + + @staticmethod + def from_dict(d: Optional[Mapping[str, Any]]) -> Optional[_DataSource]: + if not d: + return None + return _DataSource(path=d.get("path"), data=d.get("data")) + + def read(self) -> Optional[bytes]: + if self.data: + return self.data + if self.path: + with open(self.path, "rb") as f: + return f.read() + return None + + +@dataclass(frozen=True) +class ClientConfigTls: + """TLS configuration as specified as part of client configuration""" + + disabled: bool = False + """If true, TLS is explicitly disabled.""" + server_name: Optional[str] = None + """SNI override.""" + server_root_ca_cert: Optional[_DataSource] = None + """Server CA certificate source.""" + client_cert: Optional[_DataSource] = None + """Client certificate source.""" + client_private_key: Optional[_DataSource] = None + """Client key source.""" + + def to_connect_tls_config(self) -> Union[bool, TLSConfig]: + """Create a `temporalio.service.TLSConfig` from this profile.""" + if self.disabled: + return False + + def _read(ds: Optional[_DataSource]) -> Optional[bytes]: + return ds.read() if ds else None + + return TLSConfig( + domain=self.server_name, + server_root_ca_cert=_read(self.server_root_ca_cert), + client_cert=_read(self.client_cert), + client_private_key=_read(self.client_private_key), + ) + + @staticmethod + def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTls]: + if not d: + return None + return ClientConfigTls( + disabled=d.get("disabled", False), + server_name=d.get("server_name"), + # Note: Bridge uses snake_case, but TOML uses kebab-case which is + # converted to snake_case. Core has server_ca_cert, client_key. + server_root_ca_cert=_DataSource.from_dict(d.get("server_ca_cert")), + client_cert=_DataSource.from_dict(d.get("client_cert")), + client_private_key=_DataSource.from_dict(d.get("client_key")), + ) + + +@dataclass(frozen=True) +class ClientConfigProfile: + """Represents a client configuration profile. + + This class holds the configuration as loaded from a file or environment. + See `to_connect_config` to transform the profile to `temporalio.service.ConnectConfig`, + which can be used to create a client. + """ + + address: Optional[str] = None + """Client address.""" + namespace: Optional[str] = None + """Client namespace.""" + api_key: Optional[str] = None + """Client API key.""" + tls: Optional[ClientConfigTls] = None + """TLS configuration.""" + grpc_meta: Mapping[str, str] = field(default_factory=dict) + """gRPC metadata.""" + + @staticmethod + def from_dict(d: Mapping[str, Any]) -> ClientConfigProfile: + """Create a ClientConfigProfile from a dictionary from the bridge.""" + return ClientConfigProfile( + address=d.get("address"), + namespace=d.get("namespace"), + api_key=d.get("api_key"), + tls=ClientConfigTls._from_dict(d.get("tls")), + grpc_meta=d.get("grpc_meta") or {}, + ) + + def to_connect_config(self) -> ConnectConfig: + """Create a `temporalio.service.ConnectConfig` from this profile.""" + # Create a dictionary of kwargs for ConnectConfig + kwargs: dict[str, Any] = {"api_key": self.api_key} + + # Target host + if self.address: + kwargs["target_host"] = self.address + + # Metadata + rpc_metadata = dict(self.grpc_meta) + if self.namespace: + rpc_metadata["namespace"] = self.namespace + if rpc_metadata: + kwargs["rpc_metadata"] = rpc_metadata + + # TLS + if self.tls: + kwargs["tls"] = self.tls.to_connect_tls_config() + + return ConnectConfig(**{k: v for k, v in kwargs.items() if v is not None}) + + +@dataclass +class ClientConfig: + """Client configuration loaded from TOML and environment variables. + + This contains a mapping of profile names to client profiles. Use + `ClientConfigProfile.to_connect_config` to create a `temporalio.service.ConnectConfig` + from a profile. See `load_profile` to load an individual profile. + """ + + profiles: Mapping[str, ClientConfigProfile] + """Map of profile name to its corresponding ClientConfigProfile.""" + + @staticmethod + def _from_bridge_profiles( + bridge_profiles: Mapping[str, Mapping[str, Any]], + ) -> ClientConfig: + return ClientConfig( + profiles={ + k: ClientConfigProfile.from_dict(v) for k, v in bridge_profiles.items() + } + ) + + @staticmethod + def load_profiles( + *, + disable_file: bool = False, + config_file_strict: bool = False, + env_vars: Optional[Mapping[str, str]] = None, + ) -> ClientConfig: + """Load all client profiles from default file locations and environment variables. + + This does not apply environment variable overrides to the profiles, it + only uses an environment variable to find the default config file path + (`TEMPORAL_CONFIG_FILE`). To get a single profile with environment variables + applied, use `load_profile`. + + Args: + disable_file: If true, file loading is disabled. Will create a default + configuration. + config_file_strict: If true, will TOML file parsing will error on + unrecognized keys. + env_vars: The environment variables to use for locating the default config + file. If not provided, `TEMPORAL_CONFIG_FILE` is not checked + and only the default path is used (./temporal/temporal.toml). To use + the current process's environment, `os.environ` can be passed explicitly. + """ + loaded_profiles = _bridge_envconfig.load_client_config( + disable_file=disable_file, + config_file_strict=config_file_strict, + env_vars=env_vars, + ) + return ClientConfig._from_bridge_profiles(loaded_profiles) + + @staticmethod + def load_profiles_from_file( + config_file: str, + *, + config_file_strict: bool = False, + ) -> ClientConfig: + """Load all client profiles from a specific file.""" + loaded_profiles = _bridge_envconfig.load_client_config_from_file( + path=config_file, + config_file_strict=config_file_strict, + ) + return ClientConfig._from_bridge_profiles(loaded_profiles) + + @staticmethod + def load_profiles_from_data( + config_file_data: Union[str, bytes], + *, + config_file_strict: bool = False, + ) -> ClientConfig: + """Load all client profiles from specific data.""" + data_bytes = ( + config_file_data.encode("utf-8") + if isinstance(config_file_data, str) + else config_file_data + ) + loaded_profiles = _bridge_envconfig.load_client_config_from_data( + data=data_bytes, + config_file_strict=config_file_strict, + ) + return ClientConfig._from_bridge_profiles(loaded_profiles) + + @staticmethod + def load_profile( + profile: str = "default", + *, + disable_file: bool = False, + disable_env: bool = False, + config_file_strict: bool = False, + env_vars: Optional[Mapping[str, str]] = None, + ) -> ClientConfigProfile: + """Load a single client profile from default sources, applying env + overrides. + + To get a `temporalio.service.ConnectConfig`, use the + `ClientConfigProfile.to_connect_config` method on the returned profile. + + Args: + profile: Profile to load from the config. + disable_file: If true, file loading is disabled. + disable_env: If true, environment variable loading and overriding + is disabled. This takes precedence over the ``env_vars`` + parameter. + config_file_strict: If true, will error on unrecognized keys. + env_vars: The environment to use for loading and overrides. If not + provided, environment variables are not used for overrides. To + use the current process's environment, `os.environ` can be + passed explicitly. + + Returns: + The client configuration profile. + """ + if disable_file and disable_env: + raise ValueError("Cannot disable both file and environment loading") + + raw_profile = _bridge_envconfig.load_client_connect_config( + profile=profile, + disable_file=disable_file, + disable_env=disable_env, + config_file_strict=config_file_strict, + env_vars=env_vars, + ) + return ClientConfigProfile.from_dict(raw_profile) + + @staticmethod + def load_profile_from_file( + config_file: str, + profile: str = "default", + *, + disable_env: bool = False, + config_file_strict: bool = False, + env_vars: Optional[Mapping[str, str]] = None, + ) -> ClientConfigProfile: + """Load a single client profile from a file, applying env overrides. + + To get a `temporalio.service.ConnectConfig`, use the + `ClientConfigProfile.to_connect_config` method on the returned profile. + + Args: + config_file: Path to the TOML config file. + profile: Profile to load from the config. + disable_env: If true, environment variable overriding is disabled. + This takes precedence over the `env_vars` parameter. + config_file_strict: If true, will error on unrecognized keys. + env_vars: The environment to use for overrides. If not provided, + environment variables are not used for overrides. To use the + current process's environment, `os.environ` can be + passed explicitly. + """ + raw_profile = _bridge_envconfig.load_client_connect_config_from_file( + profile=profile, + path=config_file, + disable_env=disable_env, + config_file_strict=config_file_strict, + env_vars=env_vars, + ) + return ClientConfigProfile.from_dict(raw_profile) + + @staticmethod + def load_profile_from_data( + config_file_data: Union[str, bytes], + profile: str = "default", + *, + disable_env: bool = False, + config_file_strict: bool = False, + env_vars: Optional[Mapping[str, str]] = None, + ) -> ClientConfigProfile: + """Load a single client profile from data, applying env overrides. + + To get a `temporalio.service.ConnectConfig`, use the + `ClientConfigProfile.to_connect_config` method on the returned profile. + + Args: + config_file_data: Raw string TOML config. + profile: Profile to load from the config. + disable_env: If true, environment variable overriding is disabled. + This takes precedence over the ``env_vars`` parameter. + config_file_strict: If true, will error on unrecognized keys. + env_vars: The environment to use for overrides. If not provided, + environment variables are not used for overrides. To use the + current process's environment, `os.environ` can be + passed explicitly. + """ + data_bytes = ( + config_file_data.encode("utf-8") + if isinstance(config_file_data, str) + else config_file_data + ) + raw_profile = _bridge_envconfig.load_client_connect_config_from_data( + profile=profile, + data=data_bytes, + disable_env=disable_env, + config_file_strict=config_file_strict, + env_vars=env_vars, + ) + return ClientConfigProfile.from_dict(raw_profile) diff --git a/tests/.DS_Store b/tests/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..459664b7720af5da0727ba65beceab9b6682ce16 GIT binary patch literal 6148 zcmeHKyGkTM6umW$?w}}Qp{i_HV7*F0VDBUSCTd z++50M_M%Fq-i_*Dlvc0js%Mvfo@F;*!0zw3D!#Fp0jTU#N9dzxI^1 z$MvFQ{cNvT56d8is7?pep(eGdC0G5pknVTaUTl^hybyD^3i3EoN60!p;QK_hM@L#s zT_`DhS;;b7Ek7qVHs`Oj52<-vHTf)|r|`?@;gh}#lUh)8DD=?pA6k!KxO!&KNB6(q zSv@IM@_q7ILXYD71%Hm9%*?Sbd4ZhGQ5rGmYwv(}z&r5h0G|&53S($7HmHvd zRQd`4ETUNlp7mJ_47dRpT8s^%1)LI9XI@S-IiYHNYpfBJDU}!Nmh!%wZBcN^Ym3QD(9k>S=kMR8f literal 0 HcmV?d00001 diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py new file mode 100644 index 000000000..2ce158f3b --- /dev/null +++ b/tests/test_envconfig.py @@ -0,0 +1,406 @@ +import os +import textwrap +from pathlib import Path + +import pytest + +from temporalio.envconfig import ClientConfig +from temporalio.service import TLSConfig + +# A base TOML config with a default and a custom profile +TOML_CONFIG_BASE = textwrap.dedent( + """ + [profile.default] + address = "default-address" + namespace = "default-namespace" + + [profile.custom] + address = "custom-address" + namespace = "custom-namespace" + api_key = "custom-api-key" + [profile.custom.tls] + server_name = "custom-server-name" + [profile.custom.grpc_meta] + custom-header = "custom-value" + """ +) + +# A TOML config with an unrecognized key for strict testing +TOML_CONFIG_STRICT_FAIL = textwrap.dedent( + """ + [profile.default] + address = "default-address" + unrecognized = "should-fail" + """ +) + +# Malformed TOML +TOML_CONFIG_MALFORMED = "this is not valid toml" + +# A TOML config for testing detailed TLS options +TOML_CONFIG_TLS_DETAILED = textwrap.dedent( + """ + [profile.tls_disabled] + address = "localhost:1234" + [profile.tls_disabled.tls] + disabled = true + server_name = "should-be-ignored" + + [profile.tls_with_certs] + address = "localhost:5678" + [profile.tls_with_certs.tls] + server_name = "custom-server" + server_ca_cert_data = "ca-pem-data" + client_cert_data = "client-crt-data" + client_key_data = "client-key-data" + """ +) + + +@pytest.fixture +def base_config_file(tmp_path: Path) -> Path: + """Fixture to create a temporary config file with base content.""" + config_file = tmp_path / "config.toml" + config_file.write_text(TOML_CONFIG_BASE) + return config_file + + +def test_load_profile_from_file_default(base_config_file: Path): + """Test loading the default profile from a file.""" + profile = ClientConfig.load_profile_from_file(str(base_config_file)) + assert profile.address == "default-address" + assert profile.namespace == "default-namespace" + assert profile.tls is None + assert "custom-header" not in profile.grpc_meta + + config = profile.to_connect_config() + assert config.target_host == "default-address" + assert config.rpc_metadata["namespace"] == "default-namespace" + assert not config.tls + assert not config.rpc_metadata or "custom-header" not in config.rpc_metadata + + +def test_load_profile_from_file_custom(base_config_file: Path): + """Test loading a specific profile from a file.""" + profile = ClientConfig.load_profile_from_file( + str(base_config_file), profile="custom" + ) + assert profile.address == "custom-address" + assert profile.namespace == "custom-namespace" + assert profile.tls is not None + assert profile.tls.server_name == "custom-server-name" + assert profile.grpc_meta["custom-header"] == "custom-value" + + config = profile.to_connect_config() + assert config.target_host == "custom-address" + assert config.rpc_metadata["namespace"] == "custom-namespace" + assert isinstance(config.tls, TLSConfig) + assert config.tls.domain == "custom-server-name" + assert config.rpc_metadata["custom-header"] == "custom-value" + + +def test_load_profile_from_data_default(): + """Test loading the default profile from raw TOML data.""" + profile = ClientConfig.load_profile_from_data(TOML_CONFIG_BASE) + assert profile.address == "default-address" + assert profile.namespace == "default-namespace" + assert profile.tls is None + + config = profile.to_connect_config() + assert config.target_host == "default-address" + assert config.rpc_metadata["namespace"] == "default-namespace" + assert not config.tls + + +def test_load_profile_from_data_custom(): + """Test loading a custom profile from raw TOML data.""" + profile = ClientConfig.load_profile_from_data(TOML_CONFIG_BASE, profile="custom") + assert profile.address == "custom-address" + assert profile.namespace == "custom-namespace" + assert profile.tls is not None + assert profile.tls.server_name == "custom-server-name" + assert profile.grpc_meta["custom-header"] == "custom-value" + + config = profile.to_connect_config() + assert config.target_host == "custom-address" + assert config.rpc_metadata["namespace"] == "custom-namespace" + assert isinstance(config.tls, TLSConfig) + assert config.tls.domain == "custom-server-name" + assert config.rpc_metadata["custom-header"] == "custom-value" + + +def test_load_profile_from_data_env_overrides(): + """Test that environment variables correctly override data settings.""" + env = { + "TEMPORAL_ADDRESS": "env-address", + "TEMPORAL_NAMESPACE": "env-namespace", + } + profile = ClientConfig.load_profile_from_data( + TOML_CONFIG_BASE, profile="custom", env_vars=env + ) + assert profile.address == "env-address" + assert profile.namespace == "env-namespace" + + config = profile.to_connect_config() + assert config.target_host == "env-address" + assert config.rpc_metadata["namespace"] == "env-namespace" + + +def test_load_profile_env_overrides(base_config_file: Path): + """Test that environment variables correctly override file settings.""" + env = { + "TEMPORAL_ADDRESS": "env-address", + "TEMPORAL_NAMESPACE": "env-namespace", + "TEMPORAL_API_KEY": "env-api-key", + "TEMPORAL_TLS_SERVER_NAME": "env-server-name", + } + profile = ClientConfig.load_profile_from_file( + str(base_config_file), profile="custom", env_vars=env + ) + assert profile.address == "env-address" + assert profile.namespace == "env-namespace" + assert profile.api_key == "env-api-key" + assert profile.tls is not None + assert profile.tls.server_name == "env-server-name" + + config = profile.to_connect_config() + assert config.target_host == "env-address" + assert config.rpc_metadata["namespace"] == "env-namespace" + assert isinstance(config.tls, TLSConfig) + assert config.api_key == "env-api-key" + assert config.tls.domain == "env-server-name" + + +def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): + """Test gRPC metadata overrides from environment variables.""" + env = { + # This should override the value in the file + "TEMPORAL_GRPC_META_CUSTOM_HEADER": "env-value", + # This should add a new header + "TEMPORAL_GRPC_META_ANOTHER_HEADER": "another-value", + } + profile = ClientConfig.load_profile_from_file( + str(base_config_file), profile="custom", env_vars=env + ) + assert profile.grpc_meta["custom-header"] == "env-value" + assert profile.grpc_meta["another-header"] == "another-value" + + config = profile.to_connect_config() + assert config.rpc_metadata["custom-header"] == "env-value" + assert config.rpc_metadata["another-header"] == "another-value" + + +def test_load_profile_disable_env(base_config_file: Path): + """Test that `disable_env` prevents environment variable overrides.""" + env = {"TEMPORAL_ADDRESS": "env-address"} + profile = ClientConfig.load_profile_from_file( + str(base_config_file), env_vars=env, disable_env=True + ) + assert profile.address == "default-address" + + config = profile.to_connect_config() + assert config.target_host == "default-address" + + +def test_load_profile_disable_file(monkeypatch): + """Test that `disable_file` loads configuration only from environment.""" + monkeypatch.setattr("pathlib.Path.exists", lambda _: False) + env = {"TEMPORAL_ADDRESS": "env-address"} + profile = ClientConfig.load_profile(disable_file=True, env_vars=env) + assert profile.address == "env-address" + + config = profile.to_connect_config() + assert config.target_host == "env-address" + + +def test_load_profile_api_key_enables_tls(tmp_path: Path): + """Test that the presence of an API key enables TLS by default.""" + config_toml = "[profile.default]\naddress = 'some-host:1234'\napi_key = 'my-key'" + config_file = tmp_path / "config.toml" + config_file.write_text(config_toml) + profile = ClientConfig.load_profile_from_file(str(config_file)) + assert profile.api_key == "my-key" + assert profile.tls is not None + + config = profile.to_connect_config() + assert config.tls + assert config.api_key == "my-key" + + +def test_load_profile_not_found(base_config_file: Path): + """Test that requesting a non-existent profile raises an error.""" + with pytest.raises(RuntimeError, match="Profile 'nonexistent' not found"): + ClientConfig.load_profile_from_file( + str(base_config_file), profile="nonexistent" + ) + + +def test_load_profiles_from_file_all(base_config_file: Path): + """Test loading all profiles from a file.""" + client_config = ClientConfig.load_profiles_from_file(str(base_config_file)) + assert len(client_config.profiles) == 2 + assert "default" in client_config.profiles + assert "custom" in client_config.profiles + # Check that we can convert to a connect config + connect_config = client_config.profiles["default"].to_connect_config() + assert connect_config.target_host == "default-address" + + +def test_load_profiles_from_data_all(): + """Test loading all profiles from raw data.""" + client_config = ClientConfig.load_profiles_from_data(TOML_CONFIG_BASE) + assert len(client_config.profiles) == 2 + connect_config = client_config.profiles["custom"].to_connect_config() + assert connect_config.target_host == "custom-address" + + +def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): + """Confirm that load_profiles does not apply env overrides.""" + config_file = tmp_path / "config.toml" + config_file.write_text(TOML_CONFIG_BASE) + env = { + "TEMPORAL_CONFIG_FILE": str(config_file), + "TEMPORAL_ADDRESS": "env-address", # This should be ignored + } + client_config = ClientConfig.load_profiles(env_vars=env) + connect_config = client_config.profiles["default"].to_connect_config() + assert connect_config.target_host == "default-address" + + +def test_load_profiles_no_config_file(monkeypatch): + """Test that load_profiles works when no config file is found.""" + monkeypatch.setattr("pathlib.Path.exists", lambda _: False) + monkeypatch.setattr(os, "environ", {}) + client_config = ClientConfig.load_profiles(env_vars={}) + assert not client_config.profiles + + +def test_load_profiles_discovery(tmp_path: Path, monkeypatch): + """Test file discovery via environment variables.""" + config_file = tmp_path / "config.toml" + config_file.write_text(TOML_CONFIG_BASE) + env = {"TEMPORAL_CONFIG_FILE": str(config_file)} + client_config = ClientConfig.load_profiles(env_vars=env) + assert "default" in client_config.profiles + + +def test_load_profiles_disable_file(): + """Test load_profiles with file loading disabled.""" + # With no env vars, should be empty + client_config = ClientConfig.load_profiles(disable_file=True, env_vars={}) + assert not client_config.profiles + + +def test_load_profiles_strict_mode_fail(tmp_path: Path): + """Test that strict mode fails on unrecognized keys.""" + config_file = tmp_path / "config.toml" + config_file.write_text(TOML_CONFIG_STRICT_FAIL) + with pytest.raises(RuntimeError, match="unknown field `unrecognized`"): + ClientConfig.load_profiles_from_file(str(config_file), config_file_strict=True) + + +def test_load_profile_strict_mode_fail(tmp_path: Path): + """Test that strict mode fails on unrecognized keys for load_profile.""" + config_file = tmp_path / "config.toml" + config_file.write_text(TOML_CONFIG_STRICT_FAIL) + with pytest.raises(RuntimeError, match="unknown field `unrecognized`"): + ClientConfig.load_profile_from_file(str(config_file), config_file_strict=True) + + +def test_load_profiles_from_data_malformed(): + """Test that loading malformed TOML data raises an error.""" + with pytest.raises(RuntimeError, match="TOML parse error"): + ClientConfig.load_profiles_from_data(TOML_CONFIG_MALFORMED) + + +def test_load_profile_tls_options(): + """Test parsing of detailed TLS options from data.""" + # Test with TLS disabled + profile_disabled = ClientConfig.load_profile_from_data( + TOML_CONFIG_TLS_DETAILED, profile="tls_disabled" + ) + assert profile_disabled.tls is not None + assert profile_disabled.tls.disabled is True + + config_disabled = profile_disabled.to_connect_config() + assert not config_disabled.tls + + # Test with TLS certs + profile_certs = ClientConfig.load_profile_from_data( + TOML_CONFIG_TLS_DETAILED, profile="tls_with_certs" + ) + assert profile_certs.tls is not None + assert profile_certs.tls.server_name == "custom-server" + assert profile_certs.tls.server_root_ca_cert is not None + assert profile_certs.tls.server_root_ca_cert.data == b"ca-pem-data" + assert profile_certs.tls.client_cert is not None + assert profile_certs.tls.client_cert.data == b"client-crt-data" + assert profile_certs.tls.client_private_key is not None + assert profile_certs.tls.client_private_key.data == b"client-key-data" + + config_certs = profile_certs.to_connect_config() + assert isinstance(config_certs.tls, TLSConfig) + assert config_certs.tls.domain == "custom-server" + assert config_certs.tls.server_root_ca_cert == b"ca-pem-data" + assert config_certs.tls.client_cert == b"client-crt-data" + assert config_certs.tls.client_private_key == b"client-key-data" + + +def test_load_profile_tls_from_paths(tmp_path: Path): + """Test parsing of TLS options from file paths.""" + # Create dummy cert files + (tmp_path / "ca.pem").write_text("ca-pem-data") + (tmp_path / "client.crt").write_text("client-crt-data") + (tmp_path / "client.key").write_text("client-key-data") + + toml_config = textwrap.dedent( + f""" + [profile.default] + address = "localhost:5678" + [profile.default.tls] + server_name = "custom-server" + server_ca_cert_path = "{tmp_path / "ca.pem"}" + client_cert_path = "{tmp_path / "client.crt"}" + client_key_path = "{tmp_path / "client.key"}" + """ + ) + + profile = ClientConfig.load_profile_from_data(toml_config) + assert profile.tls is not None + assert profile.tls.server_name == "custom-server" + assert profile.tls.server_root_ca_cert is not None + assert profile.tls.server_root_ca_cert.path == str(tmp_path / "ca.pem") + assert profile.tls.client_cert is not None + assert profile.tls.client_cert.path == str(tmp_path / "client.crt") + assert profile.tls.client_private_key is not None + assert profile.tls.client_private_key.path == str(tmp_path / "client.key") + + config = profile.to_connect_config() + assert isinstance(config.tls, TLSConfig) + assert config.tls.domain == "custom-server" + assert config.tls.server_root_ca_cert == b"ca-pem-data" + assert config.tls.client_cert == b"client-crt-data" + assert config.tls.client_private_key == b"client-key-data" + + +def test_load_profile_conflicting_cert_source_fails(): + """Test that providing both path and data for a cert fails.""" + toml_config = textwrap.dedent( + """ + [profile.default] + address = "localhost:5678" + [profile.default.tls] + client_cert_path = "/path/to/cert" + client_cert_data = "cert-data" + """ + ) + with pytest.raises( + RuntimeError, match="Cannot specify both client_cert_path and client_cert_data" + ): + ClientConfig.load_profile_from_data(toml_config) + + +def test_disables_raise_error(): + """Test that providing both disable_file and disable_env raises an error.""" + with pytest.raises(ValueError, match="Cannot disable both"): + ClientConfig.load_profile(disable_file=True, disable_env=True) From 9735ddaf5ae113ae7b60dab0ca99fcde422000b5 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 12 Jun 2025 00:19:03 +0200 Subject: [PATCH 02/17] pyright lint exclusion --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index d5a9f71e1..391ea0abc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,6 +186,7 @@ exclude = [ "temporalio/bridge/metric.py", "temporalio/bridge/runtime.py", "temporalio/bridge/testing.py", + "temporalio/envconfig.py", ] [tool.ruff] From 246786912c22606ec161ae1ece399cdc9b786a55 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 13 Jun 2025 10:00:28 +0200 Subject: [PATCH 03/17] simplify DataSource to union type, read file using Path, add experimental to class docstrings, misc --- .gitignore | 1 + temporalio/envconfig.py | 76 +++++++++++++++++++++++------------------ tests/test_envconfig.py | 12 +++---- 3 files changed, 50 insertions(+), 39 deletions(-) diff --git a/.gitignore b/.gitignore index f94200d1b..c31f84940 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,4 @@ temporalio/bridge/temporal_sdk_bridge* /.idea /sdk-python.iml /.zed +*.DS_Store diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index def350c98..c7b6a36a1 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -1,51 +1,61 @@ """Environment and file-based configuration for Temporal clients. This module provides utilities to load Temporal client configuration from TOML files -and environment variables, following the same patterns as the Go SDK. +and environment variables. """ from __future__ import annotations from dataclasses import dataclass, field +from pathlib import Path from typing import Any, Mapping, Optional, Union +from typing_extensions import TypeAlias + from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig from temporalio.service import ConnectConfig, TLSConfig +DataSource: TypeAlias = Union[ + str, bytes +] # str represents a file path, bytes represents raw data -@dataclass(frozen=True) -class _DataSource: - path: Optional[str] = None - data: Optional[bytes] = None - @staticmethod - def from_dict(d: Optional[Mapping[str, Any]]) -> Optional[_DataSource]: - if not d: - return None - return _DataSource(path=d.get("path"), data=d.get("data")) - - def read(self) -> Optional[bytes]: - if self.data: - return self.data - if self.path: - with open(self.path, "rb") as f: - return f.read() +def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource]: + if not d: return None + if "data" in d: + return d["data"] + if "path" in d: + return d["path"] + return None + + +def _read_source(source: Optional[DataSource]) -> Optional[bytes]: + if not source: + return None + if isinstance(source, str): + with open(Path(source), "rb") as f: + return f.read() + return source @dataclass(frozen=True) class ClientConfigTls: - """TLS configuration as specified as part of client configuration""" + """TLS configuration as specified as part of client configuration + + .. warning:: + Experimental API. + """ disabled: bool = False """If true, TLS is explicitly disabled.""" server_name: Optional[str] = None """SNI override.""" - server_root_ca_cert: Optional[_DataSource] = None + server_root_ca_cert: Optional[DataSource] = None """Server CA certificate source.""" - client_cert: Optional[_DataSource] = None + client_cert: Optional[DataSource] = None """Client certificate source.""" - client_private_key: Optional[_DataSource] = None + client_private_key: Optional[DataSource] = None """Client key source.""" def to_connect_tls_config(self) -> Union[bool, TLSConfig]: @@ -53,14 +63,11 @@ def to_connect_tls_config(self) -> Union[bool, TLSConfig]: if self.disabled: return False - def _read(ds: Optional[_DataSource]) -> Optional[bytes]: - return ds.read() if ds else None - return TLSConfig( domain=self.server_name, - server_root_ca_cert=_read(self.server_root_ca_cert), - client_cert=_read(self.client_cert), - client_private_key=_read(self.client_private_key), + server_root_ca_cert=_read_source(self.server_root_ca_cert), + client_cert=_read_source(self.client_cert), + client_private_key=_read_source(self.client_private_key), ) @staticmethod @@ -72,9 +79,9 @@ def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTls]: server_name=d.get("server_name"), # Note: Bridge uses snake_case, but TOML uses kebab-case which is # converted to snake_case. Core has server_ca_cert, client_key. - server_root_ca_cert=_DataSource.from_dict(d.get("server_ca_cert")), - client_cert=_DataSource.from_dict(d.get("client_cert")), - client_private_key=_DataSource.from_dict(d.get("client_key")), + server_root_ca_cert=_from_dict_to_source(d.get("server_ca_cert")), + client_cert=_from_dict_to_source(d.get("client_cert")), + client_private_key=_from_dict_to_source(d.get("client_key")), ) @@ -85,6 +92,9 @@ class ClientConfigProfile: This class holds the configuration as loaded from a file or environment. See `to_connect_config` to transform the profile to `temporalio.service.ConnectConfig`, which can be used to create a client. + + .. warning:: + Experimental API. """ address: Optional[str] = None @@ -114,18 +124,15 @@ def to_connect_config(self) -> ConnectConfig: # Create a dictionary of kwargs for ConnectConfig kwargs: dict[str, Any] = {"api_key": self.api_key} - # Target host if self.address: kwargs["target_host"] = self.address - # Metadata rpc_metadata = dict(self.grpc_meta) if self.namespace: rpc_metadata["namespace"] = self.namespace if rpc_metadata: kwargs["rpc_metadata"] = rpc_metadata - # TLS if self.tls: kwargs["tls"] = self.tls.to_connect_tls_config() @@ -139,6 +146,9 @@ class ClientConfig: This contains a mapping of profile names to client profiles. Use `ClientConfigProfile.to_connect_config` to create a `temporalio.service.ConnectConfig` from a profile. See `load_profile` to load an individual profile. + + .. warning:: + Experimental API. """ profiles: Mapping[str, ClientConfigProfile] diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 2ce158f3b..49ecc2a5e 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -332,11 +332,11 @@ def test_load_profile_tls_options(): assert profile_certs.tls is not None assert profile_certs.tls.server_name == "custom-server" assert profile_certs.tls.server_root_ca_cert is not None - assert profile_certs.tls.server_root_ca_cert.data == b"ca-pem-data" + assert profile_certs.tls.server_root_ca_cert == b"ca-pem-data" assert profile_certs.tls.client_cert is not None - assert profile_certs.tls.client_cert.data == b"client-crt-data" + assert profile_certs.tls.client_cert == b"client-crt-data" assert profile_certs.tls.client_private_key is not None - assert profile_certs.tls.client_private_key.data == b"client-key-data" + assert profile_certs.tls.client_private_key == b"client-key-data" config_certs = profile_certs.to_connect_config() assert isinstance(config_certs.tls, TLSConfig) @@ -369,11 +369,11 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert profile.tls is not None assert profile.tls.server_name == "custom-server" assert profile.tls.server_root_ca_cert is not None - assert profile.tls.server_root_ca_cert.path == str(tmp_path / "ca.pem") + assert profile.tls.server_root_ca_cert == str(tmp_path / "ca.pem") assert profile.tls.client_cert is not None - assert profile.tls.client_cert.path == str(tmp_path / "client.crt") + assert profile.tls.client_cert == str(tmp_path / "client.crt") assert profile.tls.client_private_key is not None - assert profile.tls.client_private_key.path == str(tmp_path / "client.key") + assert profile.tls.client_private_key == str(tmp_path / "client.key") config = profile.to_connect_config() assert isinstance(config.tls, TLSConfig) From 77a6537cec78383081e56442d15392be0a445a3d Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 13 Jun 2025 10:02:54 +0200 Subject: [PATCH 04/17] Remove .DS_Store files from Git tracking --- .DS_Store | Bin 8196 -> 0 bytes scripts/.DS_Store | Bin 6148 -> 0 bytes temporalio/.DS_Store | Bin 6148 -> 0 bytes temporalio/bridge/.DS_Store | Bin 6148 -> 0 bytes tests/.DS_Store | Bin 6148 -> 0 bytes 5 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 .DS_Store delete mode 100644 scripts/.DS_Store delete mode 100644 temporalio/.DS_Store delete mode 100644 temporalio/bridge/.DS_Store delete mode 100644 tests/.DS_Store diff --git a/.DS_Store b/.DS_Store deleted file mode 100644 index 5f1d0dc2347d41a3f7a9e7deeec8af1d9465e842..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 8196 zcmeHM&ubGw6n>K*CKyt3l2R{=3W|r?f=7|Hrr=czUfZUrjgoZJq$w6JiynIs50ayJ zQUnqI2df?gFIwtFJqUuAqDb+i-~4d*rMsK;Bo$}E%-hcU-n?(W{V}{P5s9f@eUfOD zh>GaaR!SIRn)J&aX&Ezf0#+cNnoE;23tp=oLD~tcfK|XMU=^?mSOxZr0@$-wDe>l*883lEymp|Cnsm=Z(i zaLhXf=Njv4tHVi{;zO9s!c-_iWykuCf|GEywWC$QDiBpb*6ukPrzN^dtGs>}k1k4d z;A?r+YOU$jFjp*m&Get{Zk`+T*N6Jsjquxm_^Cb-6n^KW2KA`MvlETK!RmVuoWs_s z!qJJo;t)ZFOp;NY50T8U6W!^7^AI%WedPQ51DCfI@1YSP_dko^{E0q*kW#UO zD}M#QMDCAoCL=lTk^F__8y^F`W7chXA?I}T0ADd7`R1Yb8nj-fHGWT&Pw(Cx^m1Nv z=V!;xep9?yJm(UR;Q2?6srxfWbv&iG(mcD^n7)MUfsg0ju2AW^)`a;(72^Cnuh{Uh zG@imPo`eo^w{e$sd3F+?ccKW+;ZtYx)q|Jn{P(QGHAqHrK9q5hj1H_`+Dp4SQ3tNdw9RR(%0{?*flxBG%VwpeS7PP6llg743Jg;qtDU2={+~|%{y%I9>$3`21^#gbM6Nnhox(|&+B$MNvetIc z=g?)7IA2>8f|09MDO?dg2iO~bPN*Y5x7%GzN+ GEAR_@5va=m diff --git a/scripts/.DS_Store b/scripts/.DS_Store deleted file mode 100644 index b37cc5bc6e4994e03ca13b3f9a1572a75c235cdf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKOHRWu5S=Lz1RgKh$j#8G0~8SN~qvqAHz2ye$k!`l0}>>dn^~@>2;`^UZfd* zD+BWG3c93Qs%eqEKVH3eeRZg+*)q&vD<7ZsPEMa+&vXBLtG~JF{5JRnskDo_4c$?o zdkR@?IvjoVck36Uv*XoUte@>shtD%?ipkDFUd~}+3>X8(z<*-^JzJzaP_)q)Fb0f) zH3RZ}a8N;?uvLto4h#b<#a{@DU@mtFL6^`cY!%^wu%-evmD3f2H63=B;(WqZQPYWI zW$cqGbGo5$tPZ=2;ly6iMq|Jj$Qjtt;b3~*6Rs|g+{_150Q uNw0O#E2xOXwTkN$OkyiWthC}IXcX9;On^RNs|X9k{s=@GY%m6Xlz|T=3Po!G diff --git a/temporalio/.DS_Store b/temporalio/.DS_Store deleted file mode 100644 index 57d0adee2e038f3b6e33c26797bc59866ad16b7c..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHLze@u#6n?R#7Eht0h{6pnii7?KXQ|GDvsBwkEn2VEU+ChkZcZXD{sRs=xjKlr z2o55S`X2}m4laH#i8kqZf{Uo+LGtC2m-oK=jwDTqNaWWmV?;S3%Azo)yHI^%JkBL% z9lEC%C}fN>%~6>aXr78OZ`a`za0>i31$gaxSgl2B&=%HiYyB3Ns$`0(iQa{Mdr$~U|l3nq&PgryW%HmGLt8{XJ%jhK7}?z&bEUL+qq zuO7DY$49MvqvxD2LnUfbjVe^7c_VA;W_M_^eca5w)4l1Ps@zr#lR=V?!6S7y^29tk z@qI#CrZvf?WH7=0+xHsxFuC0Jmq*U-r4DUzOg7101|MXbQhe61)31V04Xmm*pTR@P z$MP`w%y>QT-=z}~j5sErWG|b~Ec40GCiH32BBC`dHs1AB`?$&FdGCXZ6~knbvgaC9eR$G`gjso#O`3*BEGw6k-G> z)Ks9FD)bdYsOjkUwO^nyQmCer(3cOPGYfq~5i&dC`_i36pwQJ$0jGeiKuSI4`23$p zzW>{k+>=wlDe$ip5b45nA&*zmXKU=`_^kC&W>MJKFH$I{pwh>&Z17Rc|0_ttn9mJh VpfOU29+>+fAZ2isQ{YDxcmo@u*X;lR diff --git a/temporalio/bridge/.DS_Store b/temporalio/bridge/.DS_Store deleted file mode 100644 index 6c18b60901085091c24d1b5c6f05be0ce6cedb1a..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 6148 zcmeHKu};H44E41Y3b1r!yoHe=e-Ns`#sC!(BebcCM5|P#>={`22L6DFZ(v{mi9g^g z7+85eTWQm@5)(p|E&1N%Y~MLAQ5+MI>)d5AQJaV&C}U$A!wO+LYfT!u!bTVN=u=Kp z8jOZt z(R?_2<~&S{0b{@z*fa)Evst2DK^u($W55_F8Q|}OhcYILg<$w}Up{i_HV7*F0VDBUSCTd z++50M_M%Fq-i_*Dlvc0js%Mvfo@F;*!0zw3D!#Fp0jTU#N9dzxI^1 z$MvFQ{cNvT56d8is7?pep(eGdC0G5pknVTaUTl^hybyD^3i3EoN60!p;QK_hM@L#s zT_`DhS;;b7Ek7qVHs`Oj52<-vHTf)|r|`?@;gh}#lUh)8DD=?pA6k!KxO!&KNB6(q zSv@IM@_q7ILXYD71%Hm9%*?Sbd4ZhGQ5rGmYwv(}z&r5h0G|&53S($7HmHvd zRQd`4ETUNlp7mJ_47dRpT8s^%1)LI9XI@S-IiYHNYpfBJDU}!Nmh!%wZBcN^Ym3QD(9k>S=kMR8f From 036bd5dc848bcd38fe19946e8ee2837d726d7cd8 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Fri, 13 Jun 2025 10:17:27 +0200 Subject: [PATCH 05/17] fix bridge after pyo3 upgrade --- temporalio/bridge/src/envconfig.rs | 12 ++++++------ temporalio/bridge/src/lib.rs | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/temporalio/bridge/src/envconfig.rs b/temporalio/bridge/src/envconfig.rs index 8b6e314ed..3933a6a2b 100644 --- a/temporalio/bridge/src/envconfig.rs +++ b/temporalio/bridge/src/envconfig.rs @@ -20,7 +20,7 @@ fn data_source_to_dict(py: Python, ds: &DataSource) -> PyResult { DataSource::Path(p) => dict.set_item("path", p)?, DataSource::Data(d) => dict.set_item("data", PyBytes::new(py, d))?, }; - Ok(dict.to_object(py)) + Ok(dict.into()) } fn tls_to_dict(py: Python, tls: &CoreClientConfigTLS) -> PyResult { @@ -39,7 +39,7 @@ fn tls_to_dict(py: Python, tls: &CoreClientConfigTLS) -> PyResult { dict.set_item("server_name", v)?; } dict.set_item("disable_host_verification", tls.disable_host_verification)?; - Ok(dict.to_object(py)) + Ok(dict.into()) } fn codec_to_dict(py: Python, codec: &ClientConfigCodec) -> PyResult { @@ -50,7 +50,7 @@ fn codec_to_dict(py: Python, codec: &ClientConfigCodec) -> PyResult { if let Some(v) = &codec.auth { dict.set_item("auth", v)?; } - Ok(dict.to_object(py)) + Ok(dict.into()) } fn profile_to_dict(py: Python, profile: &CoreClientConfigProfile) -> PyResult { @@ -71,9 +71,9 @@ fn profile_to_dict(py: Python, profile: &CoreClientConfigProfile) -> PyResult PyResult { @@ -82,7 +82,7 @@ fn core_config_to_dict(py: Python, core_config: &CoreClientConfig) -> PyResult

) -> PyResult<()> { envconfig::load_client_connect_config_from_file, m )?)?; - m.add_submodule(envconfig_module)?; + m.add_submodule(&envconfig_module)?; Ok(()) } From 2bb09fc4030ca72f67575c916f4b09be0a216403 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Mon, 16 Jun 2025 13:23:09 +0200 Subject: [PATCH 06/17] handle windows file paths --- tests/test_envconfig.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 49ecc2a5e..3391fefbd 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -353,15 +353,19 @@ def test_load_profile_tls_from_paths(tmp_path: Path): (tmp_path / "client.crt").write_text("client-crt-data") (tmp_path / "client.key").write_text("client-key-data") + ca_pem_path = (tmp_path / "ca.pem").as_posix() + client_crt_path = (tmp_path / "client.crt").as_posix() + client_key_path = (tmp_path / "client.key").as_posix() + toml_config = textwrap.dedent( f""" [profile.default] address = "localhost:5678" [profile.default.tls] server_name = "custom-server" - server_ca_cert_path = "{tmp_path / "ca.pem"}" - client_cert_path = "{tmp_path / "client.crt"}" - client_key_path = "{tmp_path / "client.key"}" + server_ca_cert_path = "{ca_pem_path}" + client_cert_path = "{client_crt_path}" + client_key_path = "{client_key_path}" """ ) @@ -369,11 +373,11 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert profile.tls is not None assert profile.tls.server_name == "custom-server" assert profile.tls.server_root_ca_cert is not None - assert profile.tls.server_root_ca_cert == str(tmp_path / "ca.pem") + assert profile.tls.server_root_ca_cert == ca_pem_path assert profile.tls.client_cert is not None - assert profile.tls.client_cert == str(tmp_path / "client.crt") + assert profile.tls.client_cert == client_crt_path assert profile.tls.client_private_key is not None - assert profile.tls.client_private_key == str(tmp_path / "client.key") + assert profile.tls.client_private_key == client_key_path config = profile.to_connect_config() assert isinstance(config.tls, TLSConfig) From 262ae97595b23e6bc81e06e6701e6d65309d29d3 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 21 Jun 2025 16:05:28 +0200 Subject: [PATCH 07/17] quick pr suggestions --- temporalio/envconfig.py | 27 ++++++++++++++------------- tests/test_envconfig.py | 6 ------ 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index c7b6a36a1..29c4002a9 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -12,8 +12,9 @@ from typing_extensions import TypeAlias +# from temporalio.service import ConnectConfig, TLSConfig +import temporalio.service from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig -from temporalio.service import ConnectConfig, TLSConfig DataSource: TypeAlias = Union[ str, bytes @@ -40,7 +41,7 @@ def _read_source(source: Optional[DataSource]) -> Optional[bytes]: @dataclass(frozen=True) -class ClientConfigTls: +class ClientConfigTLS: """TLS configuration as specified as part of client configuration .. warning:: @@ -58,12 +59,12 @@ class ClientConfigTls: client_private_key: Optional[DataSource] = None """Client key source.""" - def to_connect_tls_config(self) -> Union[bool, TLSConfig]: + def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: """Create a `temporalio.service.TLSConfig` from this profile.""" if self.disabled: return False - return TLSConfig( + return temporalio.service.TLSConfig( domain=self.server_name, server_root_ca_cert=_read_source(self.server_root_ca_cert), client_cert=_read_source(self.client_cert), @@ -71,10 +72,10 @@ def to_connect_tls_config(self) -> Union[bool, TLSConfig]: ) @staticmethod - def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTls]: + def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTLS]: if not d: return None - return ClientConfigTls( + return ClientConfigTLS( disabled=d.get("disabled", False), server_name=d.get("server_name"), # Note: Bridge uses snake_case, but TOML uses kebab-case which is @@ -103,23 +104,23 @@ class ClientConfigProfile: """Client namespace.""" api_key: Optional[str] = None """Client API key.""" - tls: Optional[ClientConfigTls] = None + tls: Optional[ClientConfigTLS] = None """TLS configuration.""" grpc_meta: Mapping[str, str] = field(default_factory=dict) """gRPC metadata.""" @staticmethod def from_dict(d: Mapping[str, Any]) -> ClientConfigProfile: - """Create a ClientConfigProfile from a dictionary from the bridge.""" + """Create a ClientConfigProfile from a dictionary.""" return ClientConfigProfile( address=d.get("address"), namespace=d.get("namespace"), api_key=d.get("api_key"), - tls=ClientConfigTls._from_dict(d.get("tls")), + tls=ClientConfigTLS._from_dict(d.get("tls")), grpc_meta=d.get("grpc_meta") or {}, ) - def to_connect_config(self) -> ConnectConfig: + def to_connect_config(self) -> temporalio.service.ConnectConfig: """Create a `temporalio.service.ConnectConfig` from this profile.""" # Create a dictionary of kwargs for ConnectConfig kwargs: dict[str, Any] = {"api_key": self.api_key} @@ -128,15 +129,15 @@ def to_connect_config(self) -> ConnectConfig: kwargs["target_host"] = self.address rpc_metadata = dict(self.grpc_meta) - if self.namespace: - rpc_metadata["namespace"] = self.namespace if rpc_metadata: kwargs["rpc_metadata"] = rpc_metadata if self.tls: kwargs["tls"] = self.tls.to_connect_tls_config() - return ConnectConfig(**{k: v for k, v in kwargs.items() if v is not None}) + return temporalio.service.ConnectConfig( + **{k: v for k, v in kwargs.items() if v is not None} + ) @dataclass diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 3391fefbd..fc254acd0 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -75,7 +75,6 @@ def test_load_profile_from_file_default(base_config_file: Path): config = profile.to_connect_config() assert config.target_host == "default-address" - assert config.rpc_metadata["namespace"] == "default-namespace" assert not config.tls assert not config.rpc_metadata or "custom-header" not in config.rpc_metadata @@ -93,7 +92,6 @@ def test_load_profile_from_file_custom(base_config_file: Path): config = profile.to_connect_config() assert config.target_host == "custom-address" - assert config.rpc_metadata["namespace"] == "custom-namespace" assert isinstance(config.tls, TLSConfig) assert config.tls.domain == "custom-server-name" assert config.rpc_metadata["custom-header"] == "custom-value" @@ -108,7 +106,6 @@ def test_load_profile_from_data_default(): config = profile.to_connect_config() assert config.target_host == "default-address" - assert config.rpc_metadata["namespace"] == "default-namespace" assert not config.tls @@ -123,7 +120,6 @@ def test_load_profile_from_data_custom(): config = profile.to_connect_config() assert config.target_host == "custom-address" - assert config.rpc_metadata["namespace"] == "custom-namespace" assert isinstance(config.tls, TLSConfig) assert config.tls.domain == "custom-server-name" assert config.rpc_metadata["custom-header"] == "custom-value" @@ -143,7 +139,6 @@ def test_load_profile_from_data_env_overrides(): config = profile.to_connect_config() assert config.target_host == "env-address" - assert config.rpc_metadata["namespace"] == "env-namespace" def test_load_profile_env_overrides(base_config_file: Path): @@ -165,7 +160,6 @@ def test_load_profile_env_overrides(base_config_file: Path): config = profile.to_connect_config() assert config.target_host == "env-address" - assert config.rpc_metadata["namespace"] == "env-namespace" assert isinstance(config.tls, TLSConfig) assert config.api_key == "env-api-key" assert config.tls.domain == "env-server-name" From bee65cd579cc223ebf2bf4548266cf9ad55beeee Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 21 Jun 2025 17:47:45 +0200 Subject: [PATCH 08/17] add Path to DataSource to read file paths, read file contents via string --- temporalio/envconfig.py | 19 ++++++++++--------- tests/test_envconfig.py | 23 +++++++++++++++++++---- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index 29c4002a9..a1b8c5546 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -12,13 +12,12 @@ from typing_extensions import TypeAlias -# from temporalio.service import ConnectConfig, TLSConfig import temporalio.service from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig DataSource: TypeAlias = Union[ - str, bytes -] # str represents a file path, bytes represents raw data + Path, str, bytes +] # str represents a file contents, bytes represents raw data def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource]: @@ -27,17 +26,19 @@ def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource] if "data" in d: return d["data"] if "path" in d: - return d["path"] + return Path(d["path"]) return None def _read_source(source: Optional[DataSource]) -> Optional[bytes]: - if not source: - return None - if isinstance(source, str): - with open(Path(source), "rb") as f: + if isinstance(source, Path): + with open(source, "rb") as f: return f.read() - return source + if isinstance(source, str): + return source.encode("utf-8") + if isinstance(source, bytes): + return source + return None @dataclass(frozen=True) diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index fc254acd0..cf2baa90e 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -4,7 +4,7 @@ import pytest -from temporalio.envconfig import ClientConfig +from temporalio.envconfig import ClientConfig, ClientConfigProfile, ClientConfigTLS from temporalio.service import TLSConfig # A base TOML config with a default and a custom profile @@ -367,11 +367,11 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert profile.tls is not None assert profile.tls.server_name == "custom-server" assert profile.tls.server_root_ca_cert is not None - assert profile.tls.server_root_ca_cert == ca_pem_path + assert profile.tls.server_root_ca_cert == Path(ca_pem_path) assert profile.tls.client_cert is not None - assert profile.tls.client_cert == client_crt_path + assert profile.tls.client_cert == Path(client_crt_path) assert profile.tls.client_private_key is not None - assert profile.tls.client_private_key == client_key_path + assert profile.tls.client_private_key == Path(client_key_path) config = profile.to_connect_config() assert isinstance(config.tls, TLSConfig) @@ -380,6 +380,21 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert config.tls.client_cert == b"client-crt-data" assert config.tls.client_private_key == b"client-key-data" +def test_read_source_from_string_content(): + """Test that _read_source correctly encodes string content.""" + # Check the behavior of providing a string as a data + # source, ensuring it's treated as content and encoded to bytes. + # Note that string content can only be provided programmatically, as + # the TOML parser in core currently only supports reading file paths + # and file data as bytes in the config file. + profile = ClientConfigProfile( + address="localhost:1234", + tls=ClientConfigTLS(client_cert="string-as-cert-content"), + ) + config = profile.to_connect_config() + assert isinstance(config.tls, TLSConfig) + assert config.tls.client_cert == b"string-as-cert-content" + def test_load_profile_conflicting_cert_source_fails(): """Test that providing both path and data for a cert fails.""" From 58d7378f00dddf347e995bc92a6053e2bc2211ba Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 21 Jun 2025 17:59:35 +0200 Subject: [PATCH 09/17] add ClientConnectConfig and to_client_connect_config --- temporalio/envconfig.py | 53 ++++++++++-------- tests/test_envconfig.py | 120 +++++++++++++++++++++------------------- 2 files changed, 94 insertions(+), 79 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index a1b8c5546..568512241 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import Any, Mapping, Optional, Union -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypedDict import temporalio.service from temporalio.bridge.temporal_sdk_bridge import envconfig as _bridge_envconfig @@ -86,13 +86,26 @@ def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTLS]: client_private_key=_from_dict_to_source(d.get("client_key")), ) +class ClientConnectConfig(TypedDict, total=False): + """Arguments for `temporalio.client.Client.connect` that are configurable via + environment configuration. + + .. warning:: + Experimental API. + """ + + target_host: str + namespace: str + api_key: Optional[str] + tls: Union[bool, temporalio.service.TLSConfig] + rpc_metadata: Mapping[str, str] @dataclass(frozen=True) class ClientConfigProfile: """Represents a client configuration profile. This class holds the configuration as loaded from a file or environment. - See `to_connect_config` to transform the profile to `temporalio.service.ConnectConfig`, + See `to_connect_config` to transform the profile to `ClientConnectConfig`, which can be used to create a client. .. warning:: @@ -121,24 +134,20 @@ def from_dict(d: Mapping[str, Any]) -> ClientConfigProfile: grpc_meta=d.get("grpc_meta") or {}, ) - def to_connect_config(self) -> temporalio.service.ConnectConfig: - """Create a `temporalio.service.ConnectConfig` from this profile.""" - # Create a dictionary of kwargs for ConnectConfig - kwargs: dict[str, Any] = {"api_key": self.api_key} - + def to_client_connect_config(self) -> ClientConnectConfig: + """Create a `ClientConnectConfig` from this profile.""" + config: ClientConnectConfig = {} if self.address: - kwargs["target_host"] = self.address - - rpc_metadata = dict(self.grpc_meta) - if rpc_metadata: - kwargs["rpc_metadata"] = rpc_metadata - + config["target_host"] = self.address + if self.namespace: + config["namespace"] = self.namespace + if self.api_key: + config["api_key"] = self.api_key if self.tls: - kwargs["tls"] = self.tls.to_connect_tls_config() - - return temporalio.service.ConnectConfig( - **{k: v for k, v in kwargs.items() if v is not None} - ) + config["tls"] = self.tls.to_connect_tls_config() + if self.grpc_meta: + config["rpc_metadata"] = self.grpc_meta + return config @dataclass @@ -146,7 +155,7 @@ class ClientConfig: """Client configuration loaded from TOML and environment variables. This contains a mapping of profile names to client profiles. Use - `ClientConfigProfile.to_connect_config` to create a `temporalio.service.ConnectConfig` + `ClientConfigProfile.to_connect_config` to create a `ClientConnectConfig` from a profile. See `load_profile` to load an individual profile. .. warning:: @@ -240,7 +249,7 @@ def load_profile( """Load a single client profile from default sources, applying env overrides. - To get a `temporalio.service.ConnectConfig`, use the + To get a `ClientConnectConfig`, use the `ClientConfigProfile.to_connect_config` method on the returned profile. Args: @@ -281,7 +290,7 @@ def load_profile_from_file( ) -> ClientConfigProfile: """Load a single client profile from a file, applying env overrides. - To get a `temporalio.service.ConnectConfig`, use the + To get a `ClientConnectConfig`, use the `ClientConfigProfile.to_connect_config` method on the returned profile. Args: @@ -315,7 +324,7 @@ def load_profile_from_data( ) -> ClientConfigProfile: """Load a single client profile from data, applying env overrides. - To get a `temporalio.service.ConnectConfig`, use the + To get a `ClientConnectConfig`, use the `ClientConfigProfile.to_connect_config` method on the returned profile. Args: diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index cf2baa90e..f1c669e13 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -73,10 +73,10 @@ def test_load_profile_from_file_default(base_config_file: Path): assert profile.tls is None assert "custom-header" not in profile.grpc_meta - config = profile.to_connect_config() - assert config.target_host == "default-address" - assert not config.tls - assert not config.rpc_metadata or "custom-header" not in config.rpc_metadata + config = profile.to_client_connect_config() + assert config["target_host"] == "default-address" + assert "tls" not in config + assert "rpc_metadata" not in config or "custom-header" not in config["rpc_metadata"] def test_load_profile_from_file_custom(base_config_file: Path): @@ -90,11 +90,12 @@ def test_load_profile_from_file_custom(base_config_file: Path): assert profile.tls.server_name == "custom-server-name" assert profile.grpc_meta["custom-header"] == "custom-value" - config = profile.to_connect_config() - assert config.target_host == "custom-address" - assert isinstance(config.tls, TLSConfig) - assert config.tls.domain == "custom-server-name" - assert config.rpc_metadata["custom-header"] == "custom-value" + config = profile.to_client_connect_config() + assert config["target_host"] == "custom-address" + tls_config = config["tls"] + assert isinstance(tls_config, TLSConfig) + assert tls_config.domain == "custom-server-name" + assert config["rpc_metadata"]["custom-header"] == "custom-value" def test_load_profile_from_data_default(): @@ -104,9 +105,9 @@ def test_load_profile_from_data_default(): assert profile.namespace == "default-namespace" assert profile.tls is None - config = profile.to_connect_config() - assert config.target_host == "default-address" - assert not config.tls + config = profile.to_client_connect_config() + assert config["target_host"] == "default-address" + assert "tls" not in config def test_load_profile_from_data_custom(): @@ -118,11 +119,12 @@ def test_load_profile_from_data_custom(): assert profile.tls.server_name == "custom-server-name" assert profile.grpc_meta["custom-header"] == "custom-value" - config = profile.to_connect_config() - assert config.target_host == "custom-address" - assert isinstance(config.tls, TLSConfig) - assert config.tls.domain == "custom-server-name" - assert config.rpc_metadata["custom-header"] == "custom-value" + config = profile.to_client_connect_config() + assert config["target_host"] == "custom-address" + tls_config = config["tls"] + assert isinstance(tls_config, TLSConfig) + assert tls_config.domain == "custom-server-name" + assert config["rpc_metadata"]["custom-header"] == "custom-value" def test_load_profile_from_data_env_overrides(): @@ -137,8 +139,8 @@ def test_load_profile_from_data_env_overrides(): assert profile.address == "env-address" assert profile.namespace == "env-namespace" - config = profile.to_connect_config() - assert config.target_host == "env-address" + config = profile.to_client_connect_config() + assert config["target_host"] == "env-address" def test_load_profile_env_overrides(base_config_file: Path): @@ -158,11 +160,12 @@ def test_load_profile_env_overrides(base_config_file: Path): assert profile.tls is not None assert profile.tls.server_name == "env-server-name" - config = profile.to_connect_config() - assert config.target_host == "env-address" - assert isinstance(config.tls, TLSConfig) - assert config.api_key == "env-api-key" - assert config.tls.domain == "env-server-name" + config = profile.to_client_connect_config() + assert config["target_host"] == "env-address" + assert config["api_key"] == "env-api-key" + tls_config = config["tls"] + assert isinstance(tls_config, TLSConfig) + assert tls_config.domain == "env-server-name" def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): @@ -179,9 +182,9 @@ def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): assert profile.grpc_meta["custom-header"] == "env-value" assert profile.grpc_meta["another-header"] == "another-value" - config = profile.to_connect_config() - assert config.rpc_metadata["custom-header"] == "env-value" - assert config.rpc_metadata["another-header"] == "another-value" + config = profile.to_client_connect_config() + assert config["rpc_metadata"]["custom-header"] == "env-value" + assert config["rpc_metadata"]["another-header"] == "another-value" def test_load_profile_disable_env(base_config_file: Path): @@ -192,8 +195,8 @@ def test_load_profile_disable_env(base_config_file: Path): ) assert profile.address == "default-address" - config = profile.to_connect_config() - assert config.target_host == "default-address" + config = profile.to_client_connect_config() + assert config["target_host"] == "default-address" def test_load_profile_disable_file(monkeypatch): @@ -203,8 +206,8 @@ def test_load_profile_disable_file(monkeypatch): profile = ClientConfig.load_profile(disable_file=True, env_vars=env) assert profile.address == "env-address" - config = profile.to_connect_config() - assert config.target_host == "env-address" + config = profile.to_client_connect_config() + assert config["target_host"] == "env-address" def test_load_profile_api_key_enables_tls(tmp_path: Path): @@ -216,9 +219,9 @@ def test_load_profile_api_key_enables_tls(tmp_path: Path): assert profile.api_key == "my-key" assert profile.tls is not None - config = profile.to_connect_config() - assert config.tls - assert config.api_key == "my-key" + config = profile.to_client_connect_config() + assert config["tls"] + assert config["api_key"] == "my-key" def test_load_profile_not_found(base_config_file: Path): @@ -236,16 +239,16 @@ def test_load_profiles_from_file_all(base_config_file: Path): assert "default" in client_config.profiles assert "custom" in client_config.profiles # Check that we can convert to a connect config - connect_config = client_config.profiles["default"].to_connect_config() - assert connect_config.target_host == "default-address" + connect_config = client_config.profiles["default"].to_client_connect_config() + assert connect_config["target_host"] == "default-address" def test_load_profiles_from_data_all(): """Test loading all profiles from raw data.""" client_config = ClientConfig.load_profiles_from_data(TOML_CONFIG_BASE) assert len(client_config.profiles) == 2 - connect_config = client_config.profiles["custom"].to_connect_config() - assert connect_config.target_host == "custom-address" + connect_config = client_config.profiles["custom"].to_client_connect_config() + assert connect_config["target_host"] == "custom-address" def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): @@ -257,8 +260,8 @@ def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): "TEMPORAL_ADDRESS": "env-address", # This should be ignored } client_config = ClientConfig.load_profiles(env_vars=env) - connect_config = client_config.profiles["default"].to_connect_config() - assert connect_config.target_host == "default-address" + connect_config = client_config.profiles["default"].to_client_connect_config() + assert connect_config["target_host"] == "default-address" def test_load_profiles_no_config_file(monkeypatch): @@ -316,8 +319,8 @@ def test_load_profile_tls_options(): assert profile_disabled.tls is not None assert profile_disabled.tls.disabled is True - config_disabled = profile_disabled.to_connect_config() - assert not config_disabled.tls + config_disabled = profile_disabled.to_client_connect_config() + assert not config_disabled["tls"] # Test with TLS certs profile_certs = ClientConfig.load_profile_from_data( @@ -332,12 +335,13 @@ def test_load_profile_tls_options(): assert profile_certs.tls.client_private_key is not None assert profile_certs.tls.client_private_key == b"client-key-data" - config_certs = profile_certs.to_connect_config() - assert isinstance(config_certs.tls, TLSConfig) - assert config_certs.tls.domain == "custom-server" - assert config_certs.tls.server_root_ca_cert == b"ca-pem-data" - assert config_certs.tls.client_cert == b"client-crt-data" - assert config_certs.tls.client_private_key == b"client-key-data" + config_certs = profile_certs.to_client_connect_config() + tls_config_certs = config_certs["tls"] + assert isinstance(tls_config_certs, TLSConfig) + assert tls_config_certs.domain == "custom-server" + assert tls_config_certs.server_root_ca_cert == b"ca-pem-data" + assert tls_config_certs.client_cert == b"client-crt-data" + assert tls_config_certs.client_private_key == b"client-key-data" def test_load_profile_tls_from_paths(tmp_path: Path): @@ -373,12 +377,13 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert profile.tls.client_private_key is not None assert profile.tls.client_private_key == Path(client_key_path) - config = profile.to_connect_config() - assert isinstance(config.tls, TLSConfig) - assert config.tls.domain == "custom-server" - assert config.tls.server_root_ca_cert == b"ca-pem-data" - assert config.tls.client_cert == b"client-crt-data" - assert config.tls.client_private_key == b"client-key-data" + config = profile.to_client_connect_config() + tls_config = config["tls"] + assert isinstance(tls_config, TLSConfig) + assert tls_config.domain == "custom-server" + assert tls_config.server_root_ca_cert == b"ca-pem-data" + assert tls_config.client_cert == b"client-crt-data" + assert tls_config.client_private_key == b"client-key-data" def test_read_source_from_string_content(): """Test that _read_source correctly encodes string content.""" @@ -391,9 +396,10 @@ def test_read_source_from_string_content(): address="localhost:1234", tls=ClientConfigTLS(client_cert="string-as-cert-content"), ) - config = profile.to_connect_config() - assert isinstance(config.tls, TLSConfig) - assert config.tls.client_cert == b"string-as-cert-content" + config = profile.to_client_connect_config() + tls_config = config["tls"] + assert isinstance(tls_config, TLSConfig) + assert tls_config.client_cert == b"string-as-cert-content" def test_load_profile_conflicting_cert_source_fails(): From 240ff0e327df7a2dd8871a05c5d76c475e725df6 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Sat, 21 Jun 2025 18:41:26 +0200 Subject: [PATCH 10/17] added load_client_connect_config --- temporalio/envconfig.py | 63 +++++++++++++++++++++++++-- tests/test_envconfig.py | 96 ++++++++++++++++++++++++++++++----------- 2 files changed, 129 insertions(+), 30 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index 568512241..a6f7395ef 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -86,6 +86,7 @@ def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTLS]: client_private_key=_from_dict_to_source(d.get("client_key")), ) + class ClientConnectConfig(TypedDict, total=False): """Arguments for `temporalio.client.Client.connect` that are configurable via environment configuration. @@ -94,11 +95,12 @@ class ClientConnectConfig(TypedDict, total=False): Experimental API. """ - target_host: str - namespace: str + target_host: Optional[str] + namespace: Optional[str] api_key: Optional[str] - tls: Union[bool, temporalio.service.TLSConfig] - rpc_metadata: Mapping[str, str] + tls: Optional[Union[bool, temporalio.service.TLSConfig]] + rpc_metadata: Optional[Mapping[str, str]] + @dataclass(frozen=True) class ClientConfigProfile: @@ -351,3 +353,56 @@ def load_profile_from_data( env_vars=env_vars, ) return ClientConfigProfile.from_dict(raw_profile) + + @staticmethod + def load_client_connect_config( + profile: str = "default", + *, + env_vars: Optional[Mapping[str, str]] = None, + config_file: Optional[str] = None, + disable_file: bool = False, + disable_env: bool = False, + config_file_strict: bool = False, + ) -> ClientConnectConfig: + """Load a single client profile and convert to connect config. + + This is a convenience function that combines loading a profile and + converting it to a connect config dictionary. This will use the current + process's environment for overrides unless disabled. + + Args: + profile: The profile to load from the config. Defaults to "default". + env_vars: Environment variables to use. Defaults to ``os.environ``. + config_file: Path to a specific TOML config file. If not provided, + default file locations are used. This is ignored if + ``disable_file`` is true. + disable_file: If true, file loading is disabled. + disable_env: If true, environment variable loading and overriding + is disabled. + config_file_strict: If true, will error on unrecognized keys in the + TOML file. + + Returns: + TypedDict of keyword arguments for + :py:meth:`temporalio.client.Client.connect`. + """ + prof: ClientConfigProfile + if config_file and not disable_file: + # If file loading is enabled and provided, use it. + prof = ClientConfig.load_profile_from_file( + config_file, + profile=profile, + env_vars=env_vars, + disable_env=disable_env, + config_file_strict=config_file_strict, + ) + else: + # Otherwise, use default file discovery + prof = ClientConfig.load_profile( + profile=profile, + env_vars=env_vars, + disable_file=disable_file, + disable_env=disable_env, + config_file_strict=config_file_strict, + ) + return prof.to_client_connect_config() diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index f1c669e13..591dffd9f 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -74,9 +74,10 @@ def test_load_profile_from_file_default(base_config_file: Path): assert "custom-header" not in profile.grpc_meta config = profile.to_client_connect_config() - assert config["target_host"] == "default-address" + assert config.get("target_host") == "default-address" assert "tls" not in config - assert "rpc_metadata" not in config or "custom-header" not in config["rpc_metadata"] + rpc_meta = config.get("rpc_metadata") + assert not rpc_meta or "custom-header" not in rpc_meta def test_load_profile_from_file_custom(base_config_file: Path): @@ -91,11 +92,13 @@ def test_load_profile_from_file_custom(base_config_file: Path): assert profile.grpc_meta["custom-header"] == "custom-value" config = profile.to_client_connect_config() - assert config["target_host"] == "custom-address" - tls_config = config["tls"] + assert config.get("target_host") == "custom-address" + tls_config = config.get("tls") assert isinstance(tls_config, TLSConfig) assert tls_config.domain == "custom-server-name" - assert config["rpc_metadata"]["custom-header"] == "custom-value" + rpc_metadata = config.get("rpc_metadata") + assert rpc_metadata + assert rpc_metadata["custom-header"] == "custom-value" def test_load_profile_from_data_default(): @@ -106,7 +109,7 @@ def test_load_profile_from_data_default(): assert profile.tls is None config = profile.to_client_connect_config() - assert config["target_host"] == "default-address" + assert config.get("target_host") == "default-address" assert "tls" not in config @@ -120,11 +123,13 @@ def test_load_profile_from_data_custom(): assert profile.grpc_meta["custom-header"] == "custom-value" config = profile.to_client_connect_config() - assert config["target_host"] == "custom-address" - tls_config = config["tls"] + assert config.get("target_host") == "custom-address" + tls_config = config.get("tls") assert isinstance(tls_config, TLSConfig) assert tls_config.domain == "custom-server-name" - assert config["rpc_metadata"]["custom-header"] == "custom-value" + rpc_metadata = config.get("rpc_metadata") + assert rpc_metadata + assert rpc_metadata["custom-header"] == "custom-value" def test_load_profile_from_data_env_overrides(): @@ -140,7 +145,7 @@ def test_load_profile_from_data_env_overrides(): assert profile.namespace == "env-namespace" config = profile.to_client_connect_config() - assert config["target_host"] == "env-address" + assert config.get("target_host") == "env-address" def test_load_profile_env_overrides(base_config_file: Path): @@ -161,9 +166,9 @@ def test_load_profile_env_overrides(base_config_file: Path): assert profile.tls.server_name == "env-server-name" config = profile.to_client_connect_config() - assert config["target_host"] == "env-address" - assert config["api_key"] == "env-api-key" - tls_config = config["tls"] + assert config.get("target_host") == "env-address" + assert config.get("api_key") == "env-api-key" + tls_config = config.get("tls") assert isinstance(tls_config, TLSConfig) assert tls_config.domain == "env-server-name" @@ -183,8 +188,10 @@ def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): assert profile.grpc_meta["another-header"] == "another-value" config = profile.to_client_connect_config() - assert config["rpc_metadata"]["custom-header"] == "env-value" - assert config["rpc_metadata"]["another-header"] == "another-value" + rpc_metadata = config.get("rpc_metadata") + assert rpc_metadata + assert rpc_metadata["custom-header"] == "env-value" + assert rpc_metadata["another-header"] == "another-value" def test_load_profile_disable_env(base_config_file: Path): @@ -196,7 +203,7 @@ def test_load_profile_disable_env(base_config_file: Path): assert profile.address == "default-address" config = profile.to_client_connect_config() - assert config["target_host"] == "default-address" + assert config.get("target_host") == "default-address" def test_load_profile_disable_file(monkeypatch): @@ -207,7 +214,7 @@ def test_load_profile_disable_file(monkeypatch): assert profile.address == "env-address" config = profile.to_client_connect_config() - assert config["target_host"] == "env-address" + assert config.get("target_host") == "env-address" def test_load_profile_api_key_enables_tls(tmp_path: Path): @@ -220,8 +227,8 @@ def test_load_profile_api_key_enables_tls(tmp_path: Path): assert profile.tls is not None config = profile.to_client_connect_config() - assert config["tls"] - assert config["api_key"] == "my-key" + assert config.get("tls") + assert config.get("api_key") == "my-key" def test_load_profile_not_found(base_config_file: Path): @@ -240,7 +247,7 @@ def test_load_profiles_from_file_all(base_config_file: Path): assert "custom" in client_config.profiles # Check that we can convert to a connect config connect_config = client_config.profiles["default"].to_client_connect_config() - assert connect_config["target_host"] == "default-address" + assert connect_config.get("target_host") == "default-address" def test_load_profiles_from_data_all(): @@ -248,7 +255,7 @@ def test_load_profiles_from_data_all(): client_config = ClientConfig.load_profiles_from_data(TOML_CONFIG_BASE) assert len(client_config.profiles) == 2 connect_config = client_config.profiles["custom"].to_client_connect_config() - assert connect_config["target_host"] == "custom-address" + assert connect_config.get("target_host") == "custom-address" def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): @@ -261,7 +268,7 @@ def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): } client_config = ClientConfig.load_profiles(env_vars=env) connect_config = client_config.profiles["default"].to_client_connect_config() - assert connect_config["target_host"] == "default-address" + assert connect_config.get("target_host") == "default-address" def test_load_profiles_no_config_file(monkeypatch): @@ -320,7 +327,7 @@ def test_load_profile_tls_options(): assert profile_disabled.tls.disabled is True config_disabled = profile_disabled.to_client_connect_config() - assert not config_disabled["tls"] + assert not config_disabled.get("tls") # Test with TLS certs profile_certs = ClientConfig.load_profile_from_data( @@ -336,7 +343,7 @@ def test_load_profile_tls_options(): assert profile_certs.tls.client_private_key == b"client-key-data" config_certs = profile_certs.to_client_connect_config() - tls_config_certs = config_certs["tls"] + tls_config_certs = config_certs.get("tls") assert isinstance(tls_config_certs, TLSConfig) assert tls_config_certs.domain == "custom-server" assert tls_config_certs.server_root_ca_cert == b"ca-pem-data" @@ -378,13 +385,14 @@ def test_load_profile_tls_from_paths(tmp_path: Path): assert profile.tls.client_private_key == Path(client_key_path) config = profile.to_client_connect_config() - tls_config = config["tls"] + tls_config = config.get("tls") assert isinstance(tls_config, TLSConfig) assert tls_config.domain == "custom-server" assert tls_config.server_root_ca_cert == b"ca-pem-data" assert tls_config.client_cert == b"client-crt-data" assert tls_config.client_private_key == b"client-key-data" + def test_read_source_from_string_content(): """Test that _read_source correctly encodes string content.""" # Check the behavior of providing a string as a data @@ -397,7 +405,7 @@ def test_read_source_from_string_content(): tls=ClientConfigTLS(client_cert="string-as-cert-content"), ) config = profile.to_client_connect_config() - tls_config = config["tls"] + tls_config = config.get("tls") assert isinstance(tls_config, TLSConfig) assert tls_config.client_cert == b"string-as-cert-content" @@ -419,6 +427,42 @@ def test_load_profile_conflicting_cert_source_fails(): ClientConfig.load_profile_from_data(toml_config) +def test_load_client_connect_config(base_config_file: Path): + """Test the load_client_connect_config.""" + # Test with explicit file path, default profile + config = ClientConfig.load_client_connect_config(config_file=str(base_config_file)) + assert config.get("target_host") == "default-address" + assert config.get("namespace") == "default-namespace" + + # Test with explicit file path, custom profile + config = ClientConfig.load_client_connect_config( + config_file=str(base_config_file), profile="custom" + ) + assert config.get("target_host") == "custom-address" + assert config.get("namespace") == "custom-namespace" + rpc_metadata = config.get("rpc_metadata") + assert rpc_metadata + assert "custom-header" in rpc_metadata + + # Test with env overrides + env = {"TEMPORAL_ADDRESS": "env-address"} + config = ClientConfig.load_client_connect_config( + config_file=str(base_config_file), env_vars=env + ) + assert config.get("target_host") == "env-address" + + # Test with env overrides disabled + config = ClientConfig.load_client_connect_config( + config_file=str(base_config_file), env_vars=env, disable_env=True + ) + assert config.get("target_host") == "default-address" + + # Test with file loading disabled + config = ClientConfig.load_client_connect_config(disable_file=True, env_vars=env) + assert config.get("target_host") == "env-address" + assert "namespace" not in config + + def test_disables_raise_error(): """Test that providing both disable_file and disable_env raises an error.""" with pytest.raises(ValueError, match="Cannot disable both"): From 652f610a14eeac2b2d84cc18d324d7bebed5a5ef Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Wed, 25 Jun 2025 22:07:50 +0200 Subject: [PATCH 11/17] raise TypeError when reading invalid data source --- temporalio/envconfig.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index a6f7395ef..e338311d7 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -31,6 +31,8 @@ def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource] def _read_source(source: Optional[DataSource]) -> Optional[bytes]: + if source is None: + return None if isinstance(source, Path): with open(source, "rb") as f: return f.read() @@ -38,7 +40,9 @@ def _read_source(source: Optional[DataSource]) -> Optional[bytes]: return source.encode("utf-8") if isinstance(source, bytes): return source - return None + raise TypeError( + f"Source must be one of pathlib.Path, str, or bytes, but got {type(source).__name__}" + ) @dataclass(frozen=True) From f517dabce4ba2bf16312c5dd957c1f735527a579 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Wed, 25 Jun 2025 22:15:16 +0200 Subject: [PATCH 12/17] load_client_connect_config rename env_vars to override_env_vars --- temporalio/envconfig.py | 9 +++++---- tests/test_envconfig.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index e338311d7..3e298af8a 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -362,11 +362,11 @@ def load_profile_from_data( def load_client_connect_config( profile: str = "default", *, - env_vars: Optional[Mapping[str, str]] = None, config_file: Optional[str] = None, disable_file: bool = False, disable_env: bool = False, config_file_strict: bool = False, + override_env_vars: Optional[Mapping[str, str]] = None, ) -> ClientConnectConfig: """Load a single client profile and convert to connect config. @@ -376,7 +376,6 @@ def load_client_connect_config( Args: profile: The profile to load from the config. Defaults to "default". - env_vars: Environment variables to use. Defaults to ``os.environ``. config_file: Path to a specific TOML config file. If not provided, default file locations are used. This is ignored if ``disable_file`` is true. @@ -385,6 +384,8 @@ def load_client_connect_config( is disabled. config_file_strict: If true, will error on unrecognized keys in the TOML file. + override_env_vars: A dictionary of environment variables to use for + loading and overrides. Returns: TypedDict of keyword arguments for @@ -396,7 +397,7 @@ def load_client_connect_config( prof = ClientConfig.load_profile_from_file( config_file, profile=profile, - env_vars=env_vars, + env_vars=override_env_vars, disable_env=disable_env, config_file_strict=config_file_strict, ) @@ -404,7 +405,7 @@ def load_client_connect_config( # Otherwise, use default file discovery prof = ClientConfig.load_profile( profile=profile, - env_vars=env_vars, + env_vars=override_env_vars, disable_file=disable_file, disable_env=disable_env, config_file_strict=config_file_strict, diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 591dffd9f..7b86c336f 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -447,18 +447,18 @@ def test_load_client_connect_config(base_config_file: Path): # Test with env overrides env = {"TEMPORAL_ADDRESS": "env-address"} config = ClientConfig.load_client_connect_config( - config_file=str(base_config_file), env_vars=env + config_file=str(base_config_file), override_env_vars=env ) assert config.get("target_host") == "env-address" # Test with env overrides disabled config = ClientConfig.load_client_connect_config( - config_file=str(base_config_file), env_vars=env, disable_env=True + config_file=str(base_config_file), override_env_vars=env, disable_env=True ) assert config.get("target_host") == "default-address" # Test with file loading disabled - config = ClientConfig.load_client_connect_config(disable_file=True, env_vars=env) + config = ClientConfig.load_client_connect_config(disable_file=True, override_env_vars=env) assert config.get("target_host") == "env-address" assert "namespace" not in config From 82809d9a6f88ba5b7ba5d299dc361e1ccb98a5d8 Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Wed, 25 Jun 2025 22:57:08 +0200 Subject: [PATCH 13/17] add .connect calls to test_load_client_connect_config to ensure load_client_connect_config actually matches connect arguments --- tests/test_envconfig.py | 80 ++++++++++++++++++++++++++++++++--------- 1 file changed, 63 insertions(+), 17 deletions(-) diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 7b86c336f..9d218a03b 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -4,6 +4,7 @@ import pytest +from temporalio.client import Client from temporalio.envconfig import ClientConfig, ClientConfigProfile, ClientConfigTLS from temporalio.service import TLSConfig @@ -427,41 +428,86 @@ def test_load_profile_conflicting_cert_source_fails(): ClientConfig.load_profile_from_data(toml_config) -def test_load_client_connect_config(base_config_file: Path): - """Test the load_client_connect_config.""" +async def test_load_client_connect_config(client: Client, tmp_path: Path): + """Test the load_client_connect_config for various scenarios.""" + # Get connection details from the fixture client + target_host = client.service_client.config.target_host + namespace = client.namespace + + # Create a TOML file with profiles pointing to the test server + config_content = f""" +[profile.default] +address = "{target_host}" +namespace = "{namespace}" + +[profile.custom] +address = "{target_host}" +namespace = "custom-namespace" +[profile.custom.grpc_meta] +custom-header = "custom-value" + """ + config_file = tmp_path / "temporal.toml" + config_file.write_text(config_content) + # Test with explicit file path, default profile - config = ClientConfig.load_client_connect_config(config_file=str(base_config_file)) - assert config.get("target_host") == "default-address" - assert config.get("namespace") == "default-namespace" + config = ClientConfig.load_client_connect_config(config_file=str(config_file)) + assert config.get("target_host") == target_host + assert config.get("namespace") == namespace + new_client = await Client.connect(**config) + assert new_client.service_client.config.target_host == target_host + assert new_client.namespace == namespace # Test with explicit file path, custom profile config = ClientConfig.load_client_connect_config( - config_file=str(base_config_file), profile="custom" + config_file=str(config_file), profile="custom" ) - assert config.get("target_host") == "custom-address" + assert config.get("target_host") == target_host assert config.get("namespace") == "custom-namespace" rpc_metadata = config.get("rpc_metadata") assert rpc_metadata assert "custom-header" in rpc_metadata + new_client = await Client.connect(**config) + assert new_client.service_client.config.target_host == target_host + assert new_client.namespace == "custom-namespace" + assert ( + new_client.service_client.config.rpc_metadata["custom-header"] + == "custom-value" + ) # Test with env overrides - env = {"TEMPORAL_ADDRESS": "env-address"} + env = {"TEMPORAL_NAMESPACE": "env-namespace-override"} config = ClientConfig.load_client_connect_config( - config_file=str(base_config_file), override_env_vars=env + config_file=str(config_file), override_env_vars=env ) - assert config.get("target_host") == "env-address" + assert config.get("target_host") == target_host + assert config.get("namespace") == "env-namespace-override" + new_client = await Client.connect(**config) + assert new_client.namespace == "env-namespace-override" # Test with env overrides disabled config = ClientConfig.load_client_connect_config( - config_file=str(base_config_file), override_env_vars=env, disable_env=True + config_file=str(config_file), + override_env_vars={"TEMPORAL_NAMESPACE": "ignored"}, + disable_env=True, ) - assert config.get("target_host") == "default-address" - - # Test with file loading disabled - config = ClientConfig.load_client_connect_config(disable_file=True, override_env_vars=env) - assert config.get("target_host") == "env-address" - assert "namespace" not in config + assert config.get("target_host") == target_host + assert config.get("namespace") == namespace + new_client = await Client.connect(**config) + assert new_client.namespace == namespace + # Test with file loading disabled (so only env is used) + env = { + "TEMPORAL_ADDRESS": target_host, + "TEMPORAL_NAMESPACE": "env-only-namespace", + } + config = ClientConfig.load_client_connect_config( + disable_file=True, override_env_vars=env + ) + assert config.get("target_host") == target_host + assert config.get("namespace") == "env-only-namespace" + new_client = await Client.connect(**config) + assert new_client.service_client.config.target_host == target_host + assert new_client.namespace == "env-only-namespace" def test_disables_raise_error(): """Test that providing both disable_file and disable_env raises an error.""" From 0e65f2262d1d131eac8b3c654f121edb9df60e4d Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 26 Jun 2025 00:02:15 +0200 Subject: [PATCH 14/17] refactor to general 'load' methods --- temporalio/bridge/src/envconfig.rs | 106 ++++-------- temporalio/bridge/src/lib.rs | 16 -- temporalio/envconfig.py | 265 ++++++++++------------------- tests/test_envconfig.py | 92 +++++----- 4 files changed, 168 insertions(+), 311 deletions(-) diff --git a/temporalio/bridge/src/envconfig.rs b/temporalio/bridge/src/envconfig.rs index 3933a6a2b..30999314e 100644 --- a/temporalio/bridge/src/envconfig.rs +++ b/temporalio/bridge/src/envconfig.rs @@ -130,109 +130,63 @@ fn load_client_connect_config_inner( } #[pyfunction] -#[pyo3(signature = (disable_file, config_file_strict, env_vars = None))] +#[pyo3(signature = (path, data, disable_file, config_file_strict, env_vars = None))] pub fn load_client_config( py: Python, + path: Option, + data: Option>, disable_file: bool, config_file_strict: bool, env_vars: Option>, ) -> PyResult { - load_client_config_inner(py, None, config_file_strict, disable_file, env_vars) -} - -#[pyfunction] -#[pyo3(signature = (path, config_file_strict, env_vars = None))] -pub fn load_client_config_from_file( - py: Python, - path: String, - config_file_strict: bool, - env_vars: Option>, -) -> PyResult { - load_client_config_inner( - py, - Some(DataSource::Path(path)), - config_file_strict, - false, - env_vars, - ) -} - -#[pyfunction] -#[pyo3(signature = (data, config_file_strict, env_vars = None))] -pub fn load_client_config_from_data( - py: Python, - data: Vec, - config_file_strict: bool, - env_vars: Option>, -) -> PyResult { + let config_source = match (path, data) { + (Some(p), None) => Some(DataSource::Path(p)), + (None, Some(d)) => Some(DataSource::Data(d)), + (None, None) => None, + (Some(_), Some(_)) => { + return Err(ConfigError::new_err( + "Cannot specify both path and data for config source", + )) + } + }; load_client_config_inner( py, - Some(DataSource::Data(data)), + config_source, config_file_strict, - false, + disable_file, env_vars, ) } #[pyfunction] -#[pyo3(signature = (profile, disable_file, disable_env, config_file_strict, env_vars = None))] +#[pyo3(signature = (profile, path, data, disable_file, disable_env, config_file_strict, env_vars = None))] pub fn load_client_connect_config( py: Python, profile: Option, + path: Option, + data: Option>, disable_file: bool, disable_env: bool, config_file_strict: bool, env_vars: Option>, ) -> PyResult { + let config_source = match (path, data) { + (Some(p), None) => Some(DataSource::Path(p)), + (None, Some(d)) => Some(DataSource::Data(d)), + (None, None) => None, + (Some(_), Some(_)) => { + return Err(ConfigError::new_err( + "Cannot specify both path and data for config source", + )) + } + }; load_client_connect_config_inner( py, - None, + config_source, profile, disable_file, disable_env, config_file_strict, env_vars, ) -} - -#[pyfunction] -#[pyo3(signature = (path, profile, disable_env, config_file_strict, env_vars = None))] -pub fn load_client_connect_config_from_file( - py: Python, - path: String, - profile: Option, - disable_env: bool, - config_file_strict: bool, - env_vars: Option>, -) -> PyResult { - load_client_connect_config_inner( - py, - Some(DataSource::Path(path)), - profile, - false, - disable_env, - config_file_strict, - env_vars, - ) -} - -#[pyfunction] -#[pyo3(signature = (data, profile, disable_env, config_file_strict, env_vars = None))] -pub fn load_client_connect_config_from_data( - py: Python, - data: Vec, - profile: Option, - disable_env: bool, - config_file_strict: bool, - env_vars: Option>, -) -> PyResult { - load_client_connect_config_inner( - py, - Some(DataSource::Data(data)), - profile, - false, - disable_env, - config_file_strict, - env_vars, - ) -} +} \ No newline at end of file diff --git a/temporalio/bridge/src/lib.rs b/temporalio/bridge/src/lib.rs index 3413524ee..0281e9210 100644 --- a/temporalio/bridge/src/lib.rs +++ b/temporalio/bridge/src/lib.rs @@ -60,26 +60,10 @@ fn temporal_sdk_bridge(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> { let envconfig_module = PyModule::new(py, "envconfig")?; envconfig_module.add("ConfigError", py.get_type::())?; envconfig_module.add_function(wrap_pyfunction!(envconfig::load_client_config, m)?)?; - envconfig_module.add_function(wrap_pyfunction!( - envconfig::load_client_config_from_data, - m - )?)?; - envconfig_module.add_function(wrap_pyfunction!( - envconfig::load_client_config_from_file, - m - )?)?; envconfig_module.add_function(wrap_pyfunction!( envconfig::load_client_connect_config, m )?)?; - envconfig_module.add_function(wrap_pyfunction!( - envconfig::load_client_connect_config_from_data, - m - )?)?; - envconfig_module.add_function(wrap_pyfunction!( - envconfig::load_client_connect_config_from_file, - m - )?)?; m.add_submodule(&envconfig_module)?; Ok(()) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index 3e298af8a..f2e020559 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -30,6 +30,25 @@ def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource] return None +def _source_to_path_and_data( + source: Optional[DataSource], +) -> tuple[Optional[str], Optional[bytes]]: + path: Optional[str] = None + data: Optional[bytes] = None + if isinstance(source, Path): + path = str(source) + elif isinstance(source, str): + data = source.encode("utf-8") + elif isinstance(source, bytes): + data = source + elif source is not None: + raise TypeError( + "config_source must be one of pathlib.Path, str, bytes, or None, " + f"but got {type(source).__name__}" + ) + return path, data + + def _read_source(source: Optional[DataSource]) -> Optional[bytes]: if source is None: return None @@ -155,129 +174,47 @@ def to_client_connect_config(self) -> ClientConnectConfig: config["rpc_metadata"] = self.grpc_meta return config - -@dataclass -class ClientConfig: - """Client configuration loaded from TOML and environment variables. - - This contains a mapping of profile names to client profiles. Use - `ClientConfigProfile.to_connect_config` to create a `ClientConnectConfig` - from a profile. See `load_profile` to load an individual profile. - - .. warning:: - Experimental API. - """ - - profiles: Mapping[str, ClientConfigProfile] - """Map of profile name to its corresponding ClientConfigProfile.""" - - @staticmethod - def _from_bridge_profiles( - bridge_profiles: Mapping[str, Mapping[str, Any]], - ) -> ClientConfig: - return ClientConfig( - profiles={ - k: ClientConfigProfile.from_dict(v) for k, v in bridge_profiles.items() - } - ) - - @staticmethod - def load_profiles( - *, - disable_file: bool = False, - config_file_strict: bool = False, - env_vars: Optional[Mapping[str, str]] = None, - ) -> ClientConfig: - """Load all client profiles from default file locations and environment variables. - - This does not apply environment variable overrides to the profiles, it - only uses an environment variable to find the default config file path - (`TEMPORAL_CONFIG_FILE`). To get a single profile with environment variables - applied, use `load_profile`. - - Args: - disable_file: If true, file loading is disabled. Will create a default - configuration. - config_file_strict: If true, will TOML file parsing will error on - unrecognized keys. - env_vars: The environment variables to use for locating the default config - file. If not provided, `TEMPORAL_CONFIG_FILE` is not checked - and only the default path is used (./temporal/temporal.toml). To use - the current process's environment, `os.environ` can be passed explicitly. - """ - loaded_profiles = _bridge_envconfig.load_client_config( - disable_file=disable_file, - config_file_strict=config_file_strict, - env_vars=env_vars, - ) - return ClientConfig._from_bridge_profiles(loaded_profiles) - - @staticmethod - def load_profiles_from_file( - config_file: str, - *, - config_file_strict: bool = False, - ) -> ClientConfig: - """Load all client profiles from a specific file.""" - loaded_profiles = _bridge_envconfig.load_client_config_from_file( - path=config_file, - config_file_strict=config_file_strict, - ) - return ClientConfig._from_bridge_profiles(loaded_profiles) - - @staticmethod - def load_profiles_from_data( - config_file_data: Union[str, bytes], - *, - config_file_strict: bool = False, - ) -> ClientConfig: - """Load all client profiles from specific data.""" - data_bytes = ( - config_file_data.encode("utf-8") - if isinstance(config_file_data, str) - else config_file_data - ) - loaded_profiles = _bridge_envconfig.load_client_config_from_data( - data=data_bytes, - config_file_strict=config_file_strict, - ) - return ClientConfig._from_bridge_profiles(loaded_profiles) - @staticmethod - def load_profile( + def load( profile: str = "default", *, + config_source: Optional[DataSource] = None, disable_file: bool = False, disable_env: bool = False, config_file_strict: bool = False, env_vars: Optional[Mapping[str, str]] = None, ) -> ClientConfigProfile: - """Load a single client profile from default sources, applying env + """Load a single client profile from given sources, applying env overrides. - To get a `ClientConnectConfig`, use the - `ClientConfigProfile.to_connect_config` method on the returned profile. + To get a :py:class:`ClientConnectConfig`, use the + :py:meth:`to_client_connect_config` method on the returned profile. Args: profile: Profile to load from the config. - disable_file: If true, file loading is disabled. + config_source: If present, this is used as the configuration source + instead of default file locations. This can be a path to the file + or the string/byte contents of the file. + disable_file: If true, file loading is disabled. This is only used + when ``config_source`` is not present. disable_env: If true, environment variable loading and overriding is disabled. This takes precedence over the ``env_vars`` parameter. config_file_strict: If true, will error on unrecognized keys. env_vars: The environment to use for loading and overrides. If not provided, environment variables are not used for overrides. To - use the current process's environment, `os.environ` can be + use the current process's environment, :py:attr:`os.environ` can be passed explicitly. Returns: The client configuration profile. """ - if disable_file and disable_env: - raise ValueError("Cannot disable both file and environment loading") + path, data = _source_to_path_and_data(config_source) raw_profile = _bridge_envconfig.load_client_connect_config( profile=profile, + path=path, + data=data, disable_file=disable_file, disable_env=disable_env, config_file_strict=config_file_strict, @@ -285,78 +222,71 @@ def load_profile( ) return ClientConfigProfile.from_dict(raw_profile) - @staticmethod - def load_profile_from_file( - config_file: str, - profile: str = "default", - *, - disable_env: bool = False, - config_file_strict: bool = False, - env_vars: Optional[Mapping[str, str]] = None, - ) -> ClientConfigProfile: - """Load a single client profile from a file, applying env overrides. - To get a `ClientConnectConfig`, use the - `ClientConfigProfile.to_connect_config` method on the returned profile. +@dataclass +class ClientConfig: + """Client configuration loaded from TOML and environment variables. - Args: - config_file: Path to the TOML config file. - profile: Profile to load from the config. - disable_env: If true, environment variable overriding is disabled. - This takes precedence over the `env_vars` parameter. - config_file_strict: If true, will error on unrecognized keys. - env_vars: The environment to use for overrides. If not provided, - environment variables are not used for overrides. To use the - current process's environment, `os.environ` can be - passed explicitly. - """ - raw_profile = _bridge_envconfig.load_client_connect_config_from_file( - profile=profile, - path=config_file, - disable_env=disable_env, - config_file_strict=config_file_strict, - env_vars=env_vars, + This contains a mapping of profile names to client profiles. Use + `ClientConfigProfile.to_connect_config` to create a `ClientConnectConfig` + from a profile. See `load_profile` to load an individual profile. + + .. warning:: + Experimental API. + """ + + profiles: Mapping[str, ClientConfigProfile] + """Map of profile name to its corresponding ClientConfigProfile.""" + + @staticmethod + def _from_bridge_profiles( + bridge_profiles: Mapping[str, Mapping[str, Any]], + ) -> ClientConfig: + return ClientConfig( + profiles={ + k: ClientConfigProfile.from_dict(v) for k, v in bridge_profiles.items() + } ) - return ClientConfigProfile.from_dict(raw_profile) @staticmethod - def load_profile_from_data( - config_file_data: Union[str, bytes], - profile: str = "default", + def load( *, - disable_env: bool = False, + config_source: Optional[DataSource] = None, + disable_file: bool = False, config_file_strict: bool = False, env_vars: Optional[Mapping[str, str]] = None, - ) -> ClientConfigProfile: - """Load a single client profile from data, applying env overrides. + ) -> ClientConfig: + """Load all client profiles from given sources. - To get a `ClientConnectConfig`, use the - `ClientConfigProfile.to_connect_config` method on the returned profile. + This does not apply environment variable overrides to the profiles, it + only uses an environment variable to find the default config file path + (``TEMPORAL_CONFIG_FILE``). To get a single profile with environment variables + applied, use :py:meth:`ClientConfigProfile.load`. Args: - config_file_data: Raw string TOML config. - profile: Profile to load from the config. - disable_env: If true, environment variable overriding is disabled. - This takes precedence over the ``env_vars`` parameter. - config_file_strict: If true, will error on unrecognized keys. - env_vars: The environment to use for overrides. If not provided, - environment variables are not used for overrides. To use the - current process's environment, `os.environ` can be - passed explicitly. + config_source: If present, this is used as the configuration source + instead of default file locations. This can be a path to the file + or the string/byte contents of the file. + disable_file: If true, file loading is disabled. This is only used + when ``config_source`` is not present. + config_file_strict: If true, will TOML file parsing will error on + unrecognized keys. + env_vars: The environment variables to use for locating the default config + file. If not provided, ``TEMPORAL_CONFIG_FILE`` is not checked + and only the default path is used (e.g. ``~/.config/temporalio/temporal.toml``). + To use the current process's environment, :py:attr:`os.environ` can be passed + explicitly. """ - data_bytes = ( - config_file_data.encode("utf-8") - if isinstance(config_file_data, str) - else config_file_data - ) - raw_profile = _bridge_envconfig.load_client_connect_config_from_data( - profile=profile, - data=data_bytes, - disable_env=disable_env, + path, data = _source_to_path_and_data(config_source) + + loaded_profiles = _bridge_envconfig.load_client_config( + path=path, + data=data, + disable_file=disable_file, config_file_strict=config_file_strict, env_vars=env_vars, ) - return ClientConfigProfile.from_dict(raw_profile) + return ClientConfig._from_bridge_profiles(loaded_profiles) @staticmethod def load_client_connect_config( @@ -391,23 +321,16 @@ def load_client_connect_config( TypedDict of keyword arguments for :py:meth:`temporalio.client.Client.connect`. """ - prof: ClientConfigProfile + config_source: Optional[DataSource] = None if config_file and not disable_file: - # If file loading is enabled and provided, use it. - prof = ClientConfig.load_profile_from_file( - config_file, - profile=profile, - env_vars=override_env_vars, - disable_env=disable_env, - config_file_strict=config_file_strict, - ) - else: - # Otherwise, use default file discovery - prof = ClientConfig.load_profile( - profile=profile, - env_vars=override_env_vars, - disable_file=disable_file, - disable_env=disable_env, - config_file_strict=config_file_strict, - ) + config_source = Path(config_file) + + prof = ClientConfigProfile.load( + profile=profile, + config_source=config_source, + disable_file=disable_file, + disable_env=disable_env, + config_file_strict=config_file_strict, + env_vars=override_env_vars, + ) return prof.to_client_connect_config() diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 9d218a03b..3a8727cbb 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -68,7 +68,7 @@ def base_config_file(tmp_path: Path) -> Path: def test_load_profile_from_file_default(base_config_file: Path): """Test loading the default profile from a file.""" - profile = ClientConfig.load_profile_from_file(str(base_config_file)) + profile = ClientConfigProfile.load(config_source=base_config_file) assert profile.address == "default-address" assert profile.namespace == "default-namespace" assert profile.tls is None @@ -83,9 +83,7 @@ def test_load_profile_from_file_default(base_config_file: Path): def test_load_profile_from_file_custom(base_config_file: Path): """Test loading a specific profile from a file.""" - profile = ClientConfig.load_profile_from_file( - str(base_config_file), profile="custom" - ) + profile = ClientConfigProfile.load(config_source=base_config_file, profile="custom") assert profile.address == "custom-address" assert profile.namespace == "custom-namespace" assert profile.tls is not None @@ -104,7 +102,7 @@ def test_load_profile_from_file_custom(base_config_file: Path): def test_load_profile_from_data_default(): """Test loading the default profile from raw TOML data.""" - profile = ClientConfig.load_profile_from_data(TOML_CONFIG_BASE) + profile = ClientConfigProfile.load(config_source=TOML_CONFIG_BASE) assert profile.address == "default-address" assert profile.namespace == "default-namespace" assert profile.tls is None @@ -116,7 +114,7 @@ def test_load_profile_from_data_default(): def test_load_profile_from_data_custom(): """Test loading a custom profile from raw TOML data.""" - profile = ClientConfig.load_profile_from_data(TOML_CONFIG_BASE, profile="custom") + profile = ClientConfigProfile.load(config_source=TOML_CONFIG_BASE, profile="custom") assert profile.address == "custom-address" assert profile.namespace == "custom-namespace" assert profile.tls is not None @@ -139,8 +137,8 @@ def test_load_profile_from_data_env_overrides(): "TEMPORAL_ADDRESS": "env-address", "TEMPORAL_NAMESPACE": "env-namespace", } - profile = ClientConfig.load_profile_from_data( - TOML_CONFIG_BASE, profile="custom", env_vars=env + profile = ClientConfigProfile.load( + config_source=TOML_CONFIG_BASE, profile="custom", env_vars=env ) assert profile.address == "env-address" assert profile.namespace == "env-namespace" @@ -157,8 +155,8 @@ def test_load_profile_env_overrides(base_config_file: Path): "TEMPORAL_API_KEY": "env-api-key", "TEMPORAL_TLS_SERVER_NAME": "env-server-name", } - profile = ClientConfig.load_profile_from_file( - str(base_config_file), profile="custom", env_vars=env + profile = ClientConfigProfile.load( + config_source=base_config_file, profile="custom", env_vars=env ) assert profile.address == "env-address" assert profile.namespace == "env-namespace" @@ -182,8 +180,8 @@ def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): # This should add a new header "TEMPORAL_GRPC_META_ANOTHER_HEADER": "another-value", } - profile = ClientConfig.load_profile_from_file( - str(base_config_file), profile="custom", env_vars=env + profile = ClientConfigProfile.load( + config_source=base_config_file, profile="custom", env_vars=env ) assert profile.grpc_meta["custom-header"] == "env-value" assert profile.grpc_meta["another-header"] == "another-value" @@ -198,8 +196,8 @@ def test_load_profile_grpc_meta_env_overrides(base_config_file: Path): def test_load_profile_disable_env(base_config_file: Path): """Test that `disable_env` prevents environment variable overrides.""" env = {"TEMPORAL_ADDRESS": "env-address"} - profile = ClientConfig.load_profile_from_file( - str(base_config_file), env_vars=env, disable_env=True + profile = ClientConfigProfile.load( + config_source=base_config_file, env_vars=env, disable_env=True ) assert profile.address == "default-address" @@ -211,7 +209,7 @@ def test_load_profile_disable_file(monkeypatch): """Test that `disable_file` loads configuration only from environment.""" monkeypatch.setattr("pathlib.Path.exists", lambda _: False) env = {"TEMPORAL_ADDRESS": "env-address"} - profile = ClientConfig.load_profile(disable_file=True, env_vars=env) + profile = ClientConfigProfile.load(disable_file=True, env_vars=env) assert profile.address == "env-address" config = profile.to_client_connect_config() @@ -223,7 +221,7 @@ def test_load_profile_api_key_enables_tls(tmp_path: Path): config_toml = "[profile.default]\naddress = 'some-host:1234'\napi_key = 'my-key'" config_file = tmp_path / "config.toml" config_file.write_text(config_toml) - profile = ClientConfig.load_profile_from_file(str(config_file)) + profile = ClientConfigProfile.load(config_source=config_file) assert profile.api_key == "my-key" assert profile.tls is not None @@ -235,14 +233,12 @@ def test_load_profile_api_key_enables_tls(tmp_path: Path): def test_load_profile_not_found(base_config_file: Path): """Test that requesting a non-existent profile raises an error.""" with pytest.raises(RuntimeError, match="Profile 'nonexistent' not found"): - ClientConfig.load_profile_from_file( - str(base_config_file), profile="nonexistent" - ) + ClientConfigProfile.load(config_source=base_config_file, profile="nonexistent") def test_load_profiles_from_file_all(base_config_file: Path): """Test loading all profiles from a file.""" - client_config = ClientConfig.load_profiles_from_file(str(base_config_file)) + client_config = ClientConfig.load(config_source=base_config_file) assert len(client_config.profiles) == 2 assert "default" in client_config.profiles assert "custom" in client_config.profiles @@ -253,7 +249,7 @@ def test_load_profiles_from_file_all(base_config_file: Path): def test_load_profiles_from_data_all(): """Test loading all profiles from raw data.""" - client_config = ClientConfig.load_profiles_from_data(TOML_CONFIG_BASE) + client_config = ClientConfig.load(config_source=TOML_CONFIG_BASE) assert len(client_config.profiles) == 2 connect_config = client_config.profiles["custom"].to_client_connect_config() assert connect_config.get("target_host") == "custom-address" @@ -267,7 +263,7 @@ def test_load_profiles_no_env_override(tmp_path: Path, monkeypatch): "TEMPORAL_CONFIG_FILE": str(config_file), "TEMPORAL_ADDRESS": "env-address", # This should be ignored } - client_config = ClientConfig.load_profiles(env_vars=env) + client_config = ClientConfig.load(env_vars=env) connect_config = client_config.profiles["default"].to_client_connect_config() assert connect_config.get("target_host") == "default-address" @@ -276,7 +272,7 @@ def test_load_profiles_no_config_file(monkeypatch): """Test that load_profiles works when no config file is found.""" monkeypatch.setattr("pathlib.Path.exists", lambda _: False) monkeypatch.setattr(os, "environ", {}) - client_config = ClientConfig.load_profiles(env_vars={}) + client_config = ClientConfig.load(env_vars={}) assert not client_config.profiles @@ -285,14 +281,14 @@ def test_load_profiles_discovery(tmp_path: Path, monkeypatch): config_file = tmp_path / "config.toml" config_file.write_text(TOML_CONFIG_BASE) env = {"TEMPORAL_CONFIG_FILE": str(config_file)} - client_config = ClientConfig.load_profiles(env_vars=env) + client_config = ClientConfig.load(env_vars=env) assert "default" in client_config.profiles def test_load_profiles_disable_file(): """Test load_profiles with file loading disabled.""" # With no env vars, should be empty - client_config = ClientConfig.load_profiles(disable_file=True, env_vars={}) + client_config = ClientConfig.load(disable_file=True, env_vars={}) assert not client_config.profiles @@ -301,7 +297,7 @@ def test_load_profiles_strict_mode_fail(tmp_path: Path): config_file = tmp_path / "config.toml" config_file.write_text(TOML_CONFIG_STRICT_FAIL) with pytest.raises(RuntimeError, match="unknown field `unrecognized`"): - ClientConfig.load_profiles_from_file(str(config_file), config_file_strict=True) + ClientConfig.load(config_source=config_file, config_file_strict=True) def test_load_profile_strict_mode_fail(tmp_path: Path): @@ -309,20 +305,20 @@ def test_load_profile_strict_mode_fail(tmp_path: Path): config_file = tmp_path / "config.toml" config_file.write_text(TOML_CONFIG_STRICT_FAIL) with pytest.raises(RuntimeError, match="unknown field `unrecognized`"): - ClientConfig.load_profile_from_file(str(config_file), config_file_strict=True) + ClientConfigProfile.load(config_source=config_file, config_file_strict=True) def test_load_profiles_from_data_malformed(): """Test that loading malformed TOML data raises an error.""" with pytest.raises(RuntimeError, match="TOML parse error"): - ClientConfig.load_profiles_from_data(TOML_CONFIG_MALFORMED) + ClientConfig.load(config_source=TOML_CONFIG_MALFORMED) def test_load_profile_tls_options(): """Test parsing of detailed TLS options from data.""" # Test with TLS disabled - profile_disabled = ClientConfig.load_profile_from_data( - TOML_CONFIG_TLS_DETAILED, profile="tls_disabled" + profile_disabled = ClientConfigProfile.load( + config_source=TOML_CONFIG_TLS_DETAILED, profile="tls_disabled" ) assert profile_disabled.tls is not None assert profile_disabled.tls.disabled is True @@ -331,8 +327,8 @@ def test_load_profile_tls_options(): assert not config_disabled.get("tls") # Test with TLS certs - profile_certs = ClientConfig.load_profile_from_data( - TOML_CONFIG_TLS_DETAILED, profile="tls_with_certs" + profile_certs = ClientConfigProfile.load( + config_source=TOML_CONFIG_TLS_DETAILED, profile="tls_with_certs" ) assert profile_certs.tls is not None assert profile_certs.tls.server_name == "custom-server" @@ -375,7 +371,7 @@ def test_load_profile_tls_from_paths(tmp_path: Path): """ ) - profile = ClientConfig.load_profile_from_data(toml_config) + profile = ClientConfigProfile.load(config_source=toml_config) assert profile.tls is not None assert profile.tls.server_name == "custom-server" assert profile.tls.server_root_ca_cert is not None @@ -425,7 +421,7 @@ def test_load_profile_conflicting_cert_source_fails(): with pytest.raises( RuntimeError, match="Cannot specify both client_cert_path and client_cert_data" ): - ClientConfig.load_profile_from_data(toml_config) + ClientConfigProfile.load(config_source=toml_config) async def test_load_client_connect_config(client: Client, tmp_path: Path): @@ -451,8 +447,8 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): # Test with explicit file path, default profile config = ClientConfig.load_client_connect_config(config_file=str(config_file)) - assert config.get("target_host") == target_host - assert config.get("namespace") == namespace + assert config["target_host"] == target_host + assert config["namespace"] == namespace new_client = await Client.connect(**config) assert new_client.service_client.config.target_host == target_host assert new_client.namespace == namespace @@ -461,8 +457,8 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( config_file=str(config_file), profile="custom" ) - assert config.get("target_host") == target_host - assert config.get("namespace") == "custom-namespace" + assert config["target_host"] == target_host + assert config["namespace"] == "custom-namespace" rpc_metadata = config.get("rpc_metadata") assert rpc_metadata assert "custom-header" in rpc_metadata @@ -470,8 +466,7 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "custom-namespace" assert ( - new_client.service_client.config.rpc_metadata["custom-header"] - == "custom-value" + new_client.service_client.config.rpc_metadata["custom-header"] == "custom-value" ) # Test with env overrides @@ -479,8 +474,8 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( config_file=str(config_file), override_env_vars=env ) - assert config.get("target_host") == target_host - assert config.get("namespace") == "env-namespace-override" + assert config["target_host"] == target_host + assert config["namespace"] == "env-namespace-override" new_client = await Client.connect(**config) assert new_client.namespace == "env-namespace-override" @@ -490,8 +485,8 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): override_env_vars={"TEMPORAL_NAMESPACE": "ignored"}, disable_env=True, ) - assert config.get("target_host") == target_host - assert config.get("namespace") == namespace + assert config["target_host"] == target_host + assert config["namespace"] == namespace new_client = await Client.connect(**config) assert new_client.namespace == namespace @@ -503,13 +498,14 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( disable_file=True, override_env_vars=env ) - assert config.get("target_host") == target_host - assert config.get("namespace") == "env-only-namespace" + assert config["target_host"] == target_host + assert config["namespace"] == "env-only-namespace" new_client = await Client.connect(**config) assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "env-only-namespace" + def test_disables_raise_error(): """Test that providing both disable_file and disable_env raises an error.""" - with pytest.raises(ValueError, match="Cannot disable both"): - ClientConfig.load_profile(disable_file=True, disable_env=True) + with pytest.raises(RuntimeError, match="Cannot disable both"): + ClientConfigProfile.load(disable_file=True, disable_env=True) From 68efb13e545267f4dab5d85d752a48b88d5b829a Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 26 Jun 2025 00:40:10 +0200 Subject: [PATCH 15/17] add to_dict methods --- temporalio/envconfig.py | 92 +++++++++++++++++++++++++++++++++-- tests/test_envconfig.py | 105 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 193 insertions(+), 4 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index f2e020559..fa6519de9 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -20,6 +20,27 @@ ] # str represents a file contents, bytes represents raw data +# We define typed dictionaries for what these configs look like as TOML. +class ClientTLSConfigDict(TypedDict, total=False): + """Dictionary representation of TLS config for TOML.""" + + disabled: bool + server_name: str + server_ca_cert: Mapping[str, str] + client_cert: Mapping[str, str] + client_key: Mapping[str, str] + + +class ClientConfigProfileDict(TypedDict, total=False): + """Dictionary representation of a client config profile for TOML.""" + + address: str + namespace: str + api_key: str + tls: ClientTLSConfigDict + grpc_meta: Mapping[str, str] + + def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource]: if not d: return None @@ -30,6 +51,18 @@ def _from_dict_to_source(d: Optional[Mapping[str, Any]]) -> Optional[DataSource] return None +def _source_to_dict( + source: Optional[DataSource], +) -> Optional[Mapping[str, str]]: + if isinstance(source, Path): + return {"path": str(source)} + if isinstance(source, str): + return {"data": source} + if isinstance(source, bytes): + return {"data": source.decode("utf-8")} + return None + + def _source_to_path_and_data( source: Optional[DataSource], ) -> tuple[Optional[str], Optional[bytes]]: @@ -83,6 +116,24 @@ class ClientConfigTLS: client_private_key: Optional[DataSource] = None """Client key source.""" + def to_dict(self) -> ClientTLSConfigDict: + """Convert to a dictionary that can be used for TOML serialization.""" + d: dict[str, Any] = {} + if self.disabled: + d["disabled"] = self.disabled + if self.server_name is not None: + d["server_name"] = self.server_name + + if self.server_root_ca_cert is not None: + d["server_ca_cert"] = _source_to_dict(self.server_root_ca_cert) + if self.client_cert is not None: + d["client_cert"] = _source_to_dict(self.client_cert) + if self.client_private_key is not None: + d["client_key"] = _source_to_dict(self.client_private_key) + # To please the type checker, we have to cast. This is because + # ClientTLSConfigDict is a TypedDict and d is a regular dict. + return d # type: ignore + def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: """Create a `temporalio.service.TLSConfig` from this profile.""" if self.disabled: @@ -96,7 +147,8 @@ def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: ) @staticmethod - def _from_dict(d: Optional[Mapping[str, Any]]) -> Optional[ClientConfigTLS]: + def from_dict(d: Optional[ClientTLSConfigDict]) -> Optional[ClientConfigTLS]: + """Create a ClientConfigTLS from a dictionary.""" if not d: return None return ClientConfigTLS( @@ -149,16 +201,35 @@ class ClientConfigProfile: """gRPC metadata.""" @staticmethod - def from_dict(d: Mapping[str, Any]) -> ClientConfigProfile: + def from_dict(d: ClientConfigProfileDict) -> ClientConfigProfile: """Create a ClientConfigProfile from a dictionary.""" return ClientConfigProfile( address=d.get("address"), namespace=d.get("namespace"), api_key=d.get("api_key"), - tls=ClientConfigTLS._from_dict(d.get("tls")), + tls=ClientConfigTLS.from_dict(d.get("tls")), grpc_meta=d.get("grpc_meta") or {}, ) + def to_dict(self) -> ClientConfigProfileDict: + """Convert to a dictionary that can be used for TOML serialization.""" + d: dict[str, Any] = {} + if self.address is not None: + d["address"] = self.address + if self.namespace is not None: + d["namespace"] = self.namespace + if self.api_key is not None: + d["api_key"] = self.api_key + if self.tls is not None: + tls_dict = self.tls.to_dict() + if tls_dict: + d["tls"] = tls_dict + if self.grpc_meta: + d["grpc_meta"] = self.grpc_meta + # To please the type checker, we have to cast. This is because + # ClientConfigProfileDict is a TypedDict and d is a regular dict. + return d # type: ignore + def to_client_connect_config(self) -> ClientConnectConfig: """Create a `ClientConnectConfig` from this profile.""" config: ClientConnectConfig = {} @@ -238,6 +309,10 @@ class ClientConfig: profiles: Mapping[str, ClientConfigProfile] """Map of profile name to its corresponding ClientConfigProfile.""" + def to_dict(self) -> Mapping[str, ClientConfigProfileDict]: + """Convert to a dictionary that can be used for TOML serialization.""" + return {k: v.to_dict() for k, v in self.profiles.items()} + @staticmethod def _from_bridge_profiles( bridge_profiles: Mapping[str, Mapping[str, Any]], @@ -286,7 +361,7 @@ def load( config_file_strict=config_file_strict, env_vars=env_vars, ) - return ClientConfig._from_bridge_profiles(loaded_profiles) + return ClientConfig.from_dict(loaded_profiles) @staticmethod def load_client_connect_config( @@ -334,3 +409,12 @@ def load_client_connect_config( env_vars=override_env_vars, ) return prof.to_client_connect_config() + + @staticmethod + def from_dict( + d: Mapping[str, ClientConfigProfileDict], + ) -> ClientConfig: + """Create a ClientConfig from a dictionary.""" + return ClientConfig( + profiles={k: ClientConfigProfile.from_dict(v) for k, v in d.items()} + ) diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index 3a8727cbb..f56219f68 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -509,3 +509,108 @@ def test_disables_raise_error(): """Test that providing both disable_file and disable_env raises an error.""" with pytest.raises(RuntimeError, match="Cannot disable both"): ClientConfigProfile.load(disable_file=True, disable_env=True) + + +def test_client_config_profile_to_from_dict(): + """Test round-trip ClientConfigProfile to and from a dictionary.""" + # Profile with all fields + profile = ClientConfigProfile( + address="some-address", + namespace="some-namespace", + api_key="some-api-key", + tls=ClientConfigTLS( + disabled=False, + server_name="some-server-name", + server_root_ca_cert=b"ca-cert-data", + client_cert=Path("/path/to/client.crt"), + client_private_key="client-key-data", + ), + grpc_meta={"some-header": "some-value"}, + ) + + profile_dict = profile.to_dict() + + # Check dict representation. Note that disabled=False is not in the dict. + expected_dict = { + "address": "some-address", + "namespace": "some-namespace", + "api_key": "some-api-key", + "tls": { + "server_name": "some-server-name", + "server_ca_cert": {"data": "ca-cert-data"}, + "client_cert": {"path": str(Path("/path/to/client.crt"))}, + "client_key": {"data": "client-key-data"}, + }, + "grpc_meta": {"some-header": "some-value"}, + } + assert profile_dict == expected_dict + + # Convert back to profile + new_profile = ClientConfigProfile.from_dict(profile_dict) + + # We expect the new profile to be the same, but with bytes-based data + # sources converted to strings. This is because to_dict converts + # bytes-based data to a string, suitable for TOML. So we only have + # a string representation to work with. + expected_new_profile = ClientConfigProfile( + address="some-address", + namespace="some-namespace", + api_key="some-api-key", + tls=ClientConfigTLS( + disabled=False, + server_name="some-server-name", + server_root_ca_cert="ca-cert-data", # Was bytes, now str + client_cert=Path("/path/to/client.crt"), + client_private_key="client-key-data", + ), + grpc_meta={"some-header": "some-value"}, + ) + assert new_profile == expected_new_profile + + # Test with minimal profile + profile_minimal = ClientConfigProfile() + profile_minimal_dict = profile_minimal.to_dict() + assert profile_minimal_dict == {} + new_profile_minimal = ClientConfigProfile.from_dict(profile_minimal_dict) + assert profile_minimal == new_profile_minimal + + +def test_client_config_to_from_dict(): + """Test round-trip ClientConfig to and from a dictionary.""" + # Config with multiple profiles + profile1 = ClientConfigProfile( + address="some-address", + namespace="some-namespace", + ) + profile2 = ClientConfigProfile( + address="another-address", + tls=ClientConfigTLS(server_name="some-server-name"), + grpc_meta={"some-header": "some-value"}, + ) + config = ClientConfig(profiles={"default": profile1, "custom": profile2}) + + config_dict = config.to_dict() + + expected_dict = { + "default": { + "address": "some-address", + "namespace": "some-namespace", + }, + "custom": { + "address": "another-address", + "tls": {"server_name": "some-server-name"}, + "grpc_meta": {"some-header": "some-value"}, + }, + } + assert config_dict == expected_dict + + # Convert back to config + new_config = ClientConfig.from_dict(config_dict) + assert config == new_config + + # Test empty config + empty_config = ClientConfig(profiles={}) + empty_config_dict = empty_config.to_dict() + assert empty_config_dict == {} + new_empty_config = ClientConfig.from_dict(empty_config_dict) + assert empty_config == new_empty_config From f62727aafc192c1443d78bef81e4461f46e995ff Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 26 Jun 2025 00:50:23 +0200 Subject: [PATCH 16/17] formatting + linting --- temporalio/bridge/src/envconfig.rs | 1 + temporalio/envconfig.py | 21 ++++++++------------- tests/test_envconfig.py | 30 +++++++++++++++--------------- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/temporalio/bridge/src/envconfig.rs b/temporalio/bridge/src/envconfig.rs index 30999314e..d79a25510 100644 --- a/temporalio/bridge/src/envconfig.rs +++ b/temporalio/bridge/src/envconfig.rs @@ -160,6 +160,7 @@ pub fn load_client_config( #[pyfunction] #[pyo3(signature = (profile, path, data, disable_file, disable_env, config_file_strict, env_vars = None))] +#[allow(clippy::too_many_arguments)] pub fn load_client_connect_config( py: Python, profile: Option, diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index fa6519de9..f6556a0e8 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Mapping, Optional, Union +from typing import Any, Mapping, Optional, Union, cast from typing_extensions import TypeAlias, TypedDict @@ -314,12 +314,16 @@ def to_dict(self) -> Mapping[str, ClientConfigProfileDict]: return {k: v.to_dict() for k, v in self.profiles.items()} @staticmethod - def _from_bridge_profiles( - bridge_profiles: Mapping[str, Mapping[str, Any]], + def from_dict( + d: Mapping[str, Mapping[str, Any]], ) -> ClientConfig: + """Create a ClientConfig from a dictionary.""" + # We must cast the inner dictionary because the source is often a plain + # Mapping[str, Any] from the bridge or other sources. return ClientConfig( profiles={ - k: ClientConfigProfile.from_dict(v) for k, v in bridge_profiles.items() + k: ClientConfigProfile.from_dict(cast(ClientConfigProfileDict, v)) + for k, v in d.items() } ) @@ -409,12 +413,3 @@ def load_client_connect_config( env_vars=override_env_vars, ) return prof.to_client_connect_config() - - @staticmethod - def from_dict( - d: Mapping[str, ClientConfigProfileDict], - ) -> ClientConfig: - """Create a ClientConfig from a dictionary.""" - return ClientConfig( - profiles={k: ClientConfigProfile.from_dict(v) for k, v in d.items()} - ) diff --git a/tests/test_envconfig.py b/tests/test_envconfig.py index f56219f68..aa62782d4 100644 --- a/tests/test_envconfig.py +++ b/tests/test_envconfig.py @@ -447,9 +447,9 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): # Test with explicit file path, default profile config = ClientConfig.load_client_connect_config(config_file=str(config_file)) - assert config["target_host"] == target_host - assert config["namespace"] == namespace - new_client = await Client.connect(**config) + assert config.get("target_host") == target_host + assert config.get("namespace") == namespace + new_client = await Client.connect(**config) # type: ignore assert new_client.service_client.config.target_host == target_host assert new_client.namespace == namespace @@ -457,12 +457,12 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( config_file=str(config_file), profile="custom" ) - assert config["target_host"] == target_host - assert config["namespace"] == "custom-namespace" + assert config.get("target_host") == target_host + assert config.get("namespace") == "custom-namespace" rpc_metadata = config.get("rpc_metadata") assert rpc_metadata assert "custom-header" in rpc_metadata - new_client = await Client.connect(**config) + new_client = await Client.connect(**config) # type: ignore assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "custom-namespace" assert ( @@ -474,9 +474,9 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( config_file=str(config_file), override_env_vars=env ) - assert config["target_host"] == target_host - assert config["namespace"] == "env-namespace-override" - new_client = await Client.connect(**config) + assert config.get("target_host") == target_host + assert config.get("namespace") == "env-namespace-override" + new_client = await Client.connect(**config) # type: ignore assert new_client.namespace == "env-namespace-override" # Test with env overrides disabled @@ -485,9 +485,9 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): override_env_vars={"TEMPORAL_NAMESPACE": "ignored"}, disable_env=True, ) - assert config["target_host"] == target_host - assert config["namespace"] == namespace - new_client = await Client.connect(**config) + assert config.get("target_host") == target_host + assert config.get("namespace") == namespace + new_client = await Client.connect(**config) # type: ignore assert new_client.namespace == namespace # Test with file loading disabled (so only env is used) @@ -498,9 +498,9 @@ async def test_load_client_connect_config(client: Client, tmp_path: Path): config = ClientConfig.load_client_connect_config( disable_file=True, override_env_vars=env ) - assert config["target_host"] == target_host - assert config["namespace"] == "env-only-namespace" - new_client = await Client.connect(**config) + assert config.get("target_host") == target_host + assert config.get("namespace") == "env-only-namespace" + new_client = await Client.connect(**config) # type: ignore assert new_client.service_client.config.target_host == target_host assert new_client.namespace == "env-only-namespace" From c4f134584733c69cb466353cb661ffb23c51469a Mon Sep 17 00:00:00 2001 From: Thomas Hardy Date: Thu, 26 Jun 2025 16:42:32 +0200 Subject: [PATCH 17/17] class rename, some cleanup with type casting --- temporalio/envconfig.py | 44 ++++++++++++++++++++--------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/temporalio/envconfig.py b/temporalio/envconfig.py index f6556a0e8..8eac61a37 100644 --- a/temporalio/envconfig.py +++ b/temporalio/envconfig.py @@ -8,7 +8,7 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Mapping, Optional, Union, cast +from typing import Any, Literal, Mapping, Optional, Union, cast from typing_extensions import TypeAlias, TypedDict @@ -21,7 +21,7 @@ # We define typed dictionaries for what these configs look like as TOML. -class ClientTLSConfigDict(TypedDict, total=False): +class ClientConfigTLSDict(TypedDict, total=False): """Dictionary representation of TLS config for TOML.""" disabled: bool @@ -37,7 +37,7 @@ class ClientConfigProfileDict(TypedDict, total=False): address: str namespace: str api_key: str - tls: ClientTLSConfigDict + tls: ClientConfigTLSDict grpc_meta: Mapping[str, str] @@ -116,23 +116,25 @@ class ClientConfigTLS: client_private_key: Optional[DataSource] = None """Client key source.""" - def to_dict(self) -> ClientTLSConfigDict: + def to_dict(self) -> ClientConfigTLSDict: """Convert to a dictionary that can be used for TOML serialization.""" - d: dict[str, Any] = {} + d: ClientConfigTLSDict = {} if self.disabled: d["disabled"] = self.disabled if self.server_name is not None: d["server_name"] = self.server_name - if self.server_root_ca_cert is not None: - d["server_ca_cert"] = _source_to_dict(self.server_root_ca_cert) - if self.client_cert is not None: - d["client_cert"] = _source_to_dict(self.client_cert) - if self.client_private_key is not None: - d["client_key"] = _source_to_dict(self.client_private_key) - # To please the type checker, we have to cast. This is because - # ClientTLSConfigDict is a TypedDict and d is a regular dict. - return d # type: ignore + def set_source( + key: Literal["server_ca_cert", "client_cert", "client_key"], + source: Optional[DataSource], + ): + if source is not None and (val := _source_to_dict(source)): + d[key] = val + + set_source("server_ca_cert", self.server_root_ca_cert) + set_source("client_cert", self.client_cert) + set_source("client_key", self.client_private_key) + return d def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: """Create a `temporalio.service.TLSConfig` from this profile.""" @@ -147,7 +149,7 @@ def to_connect_tls_config(self) -> Union[bool, temporalio.service.TLSConfig]: ) @staticmethod - def from_dict(d: Optional[ClientTLSConfigDict]) -> Optional[ClientConfigTLS]: + def from_dict(d: Optional[ClientConfigTLSDict]) -> Optional[ClientConfigTLS]: """Create a ClientConfigTLS from a dictionary.""" if not d: return None @@ -213,22 +215,18 @@ def from_dict(d: ClientConfigProfileDict) -> ClientConfigProfile: def to_dict(self) -> ClientConfigProfileDict: """Convert to a dictionary that can be used for TOML serialization.""" - d: dict[str, Any] = {} + d: ClientConfigProfileDict = {} if self.address is not None: d["address"] = self.address if self.namespace is not None: d["namespace"] = self.namespace if self.api_key is not None: d["api_key"] = self.api_key - if self.tls is not None: - tls_dict = self.tls.to_dict() - if tls_dict: - d["tls"] = tls_dict + if self.tls and (tls_dict := self.tls.to_dict()): + d["tls"] = tls_dict if self.grpc_meta: d["grpc_meta"] = self.grpc_meta - # To please the type checker, we have to cast. This is because - # ClientConfigProfileDict is a TypedDict and d is a regular dict. - return d # type: ignore + return d def to_client_connect_config(self) -> ClientConnectConfig: """Create a `ClientConnectConfig` from this profile."""