diff --git a/.release-please-manifest.json b/.release-please-manifest.json index a26ebfc..8f3e0a4 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.14.0" + ".": "0.15.0" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index c7c071b..364b612 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1 +1 @@ -configured_endpoints: 45 +configured_endpoints: 46 diff --git a/CHANGELOG.md b/CHANGELOG.md index fceec81..21f5612 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,17 @@ # Changelog +## 0.15.0 (2024-01-24) + +Full Changelog: [v0.14.0...v0.15.0](https://github.com/Dataherald/dataherald-python/compare/v0.14.0...v0.15.0) + +### Features + +* **api:** OpenAPI spec update ([#30](https://github.com/Dataherald/dataherald-python/issues/30)) ([d5328bf](https://github.com/Dataherald/dataherald-python/commit/d5328bfab08dc97760ce01661c3f83665bafd389)) +* **api:** OpenAPI spec update ([#32](https://github.com/Dataherald/dataherald-python/issues/32)) ([e443463](https://github.com/Dataherald/dataherald-python/commit/e443463f508753125e7b48582a4432156959b898)) +* **api:** OpenAPI spec update ([#33](https://github.com/Dataherald/dataherald-python/issues/33)) ([1c8a887](https://github.com/Dataherald/dataherald-python/commit/1c8a88761d37f060a0835735fabe8e75980ae2d7)) +* **api:** OpenAPI spec update ([#34](https://github.com/Dataherald/dataherald-python/issues/34)) ([6e1ac9c](https://github.com/Dataherald/dataherald-python/commit/6e1ac9cf083f2382260f44f5f2627cbbc6b3d8f2)) +* **api:** OpenAPI spec update ([#35](https://github.com/Dataherald/dataherald-python/issues/35)) ([317b743](https://github.com/Dataherald/dataherald-python/commit/317b74340bebc295024fcadb0a4ff3aeebefe06e)) + ## 0.14.0 (2024-01-16) Full Changelog: [v0.13.0...v0.14.0](https://github.com/Dataherald/dataherald-python/compare/v0.13.0...v0.14.0) diff --git a/README.md b/README.md index 386b04f..c640664 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,10 @@ client = Dataherald( environment="staging", ) -db_connection_response = client.database_connections.create() +db_connection_response = client.database_connections.create( + alias="string", + connection_uri="string", +) print(db_connection_response.id) ``` @@ -58,7 +61,10 @@ client = AsyncDataherald( async def main() -> None: - db_connection_response = await client.database_connections.create() + db_connection_response = await client.database_connections.create( + alias="string", + connection_uri="string", + ) print(db_connection_response.id) @@ -92,7 +98,10 @@ from dataherald import Dataherald client = Dataherald() try: - client.database_connections.create() + client.database_connections.create( + alias="string", + connection_uri="string", + ) except dataherald.APIConnectionError as e: print("The server could not be reached") print(e.__cause__) # an underlying Exception, likely raised within httpx. @@ -135,7 +144,10 @@ client = Dataherald( ) # Or, configure per-request: -client.with_options(max_retries=5).database_connections.create() +client.with_options(max_retries=5).database_connections.create( + alias="string", + connection_uri="string", +) ``` ### Timeouts @@ -158,7 +170,10 @@ client = Dataherald( ) # Override per-request: -client.with_options(timeout=5 * 1000).database_connections.create() +client.with_options(timeout=5 * 1000).database_connections.create( + alias="string", + connection_uri="string", +) ``` On timeout, an `APITimeoutError` is thrown. @@ -197,7 +212,10 @@ The "raw" Response object can be accessed by prefixing `.with_raw_response.` to from dataherald import Dataherald client = Dataherald() -response = client.database_connections.with_raw_response.create() +response = client.database_connections.with_raw_response.create( + alias="string", + connection_uri="string", +) print(response.headers.get('X-My-Header')) database_connection = response.parse() # get the object that `database_connections.create()` would have returned @@ -215,7 +233,10 @@ The above interface eagerly reads the full response body when you make the reque To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. ```python -with client.database_connections.with_streaming_response.create() as response: +with client.database_connections.with_streaming_response.create( + alias="string", + connection_uri="string", +) as response: print(response.headers.get("X-My-Header")) for line in response.iter_lines(): diff --git a/api.md b/api.md index 0f35e86..cd89b93 100644 --- a/api.md +++ b/api.md @@ -82,6 +82,7 @@ from dataherald.types import InstructionListResponse, InstructionDeleteResponse Methods: - client.instructions.create(\*\*params) -> InstructionResponse +- client.instructions.retrieve(id) -> InstructionResponse - client.instructions.update(id, \*\*params) -> InstructionResponse - client.instructions.list(\*\*params) -> InstructionListResponse - client.instructions.delete(id) -> object diff --git a/pyproject.toml b/pyproject.toml index ee5c2bd..e80ac0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dataherald" -version = "0.14.0" +version = "0.15.0" description = "The official Python library for the Dataherald API" readme = "README.md" license = "Apache-2.0" diff --git a/src/dataherald/_base_client.py b/src/dataherald/_base_client.py index 2a630de..5c1695e 100644 --- a/src/dataherald/_base_client.py +++ b/src/dataherald/_base_client.py @@ -73,7 +73,9 @@ from ._constants import ( DEFAULT_LIMITS, DEFAULT_TIMEOUT, + MAX_RETRY_DELAY, DEFAULT_MAX_RETRIES, + INITIAL_RETRY_DELAY, RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER, ) @@ -589,6 +591,40 @@ def base_url(self, url: URL | str) -> None: def platform_headers(self) -> Dict[str, str]: return platform_headers(self._version) + def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None: + """Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified. + + About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After + See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax + """ + if response_headers is None: + return None + + # First, try the non-standard `retry-after-ms` header for milliseconds, + # which is more precise than integer-seconds `retry-after` + try: + retry_ms_header = response_headers.get("retry-after-ms", None) + return float(retry_ms_header) / 1000 + except (TypeError, ValueError): + pass + + # Next, try parsing `retry-after` header as seconds (allowing nonstandard floats). + retry_header = response_headers.get("retry-after") + try: + # note: the spec indicates that this should only ever be an integer + # but if someone sends a float there's no reason for us to not respect it + return float(retry_header) + except (TypeError, ValueError): + pass + + # Last, try parsing `retry-after` as a date. + retry_date_tuple = email.utils.parsedate_tz(retry_header) + if retry_date_tuple is None: + return None + + retry_date = email.utils.mktime_tz(retry_date_tuple) + return float(retry_date - time.time()) + def _calculate_retry_timeout( self, remaining_retries: int, @@ -596,40 +632,16 @@ def _calculate_retry_timeout( response_headers: Optional[httpx.Headers] = None, ) -> float: max_retries = options.get_max_retries(self.max_retries) - try: - # About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After - # - # ". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for - # details. - if response_headers is not None: - retry_header = response_headers.get("retry-after") - try: - # note: the spec indicates that this should only ever be an integer - # but if someone sends a float there's no reason for us to not respect it - retry_after = float(retry_header) - except Exception: - retry_date_tuple = email.utils.parsedate_tz(retry_header) - if retry_date_tuple is None: - retry_after = -1 - else: - retry_date = email.utils.mktime_tz(retry_date_tuple) - retry_after = int(retry_date - time.time()) - else: - retry_after = -1 - - except Exception: - retry_after = -1 # If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says. - if 0 < retry_after <= 60: + retry_after = self._parse_retry_after_header(response_headers) + if retry_after is not None and 0 < retry_after <= 60: return retry_after - initial_retry_delay = 0.5 - max_retry_delay = 8.0 nb_retries = max_retries - remaining_retries # Apply exponential backoff, but not more than the max. - sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay) + sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY) # Apply some jitter, plus-or-minus half a second. jitter = 1 - 0.25 * random() @@ -764,6 +776,7 @@ def __init__( proxies=proxies, transport=transport, limits=limits, + follow_redirects=True, ) def is_closed(self) -> bool: @@ -1292,6 +1305,7 @@ def __init__( proxies=proxies, transport=transport, limits=limits, + follow_redirects=True, ) def is_closed(self) -> bool: diff --git a/src/dataherald/_compat.py b/src/dataherald/_compat.py index 3cda399..74c7639 100644 --- a/src/dataherald/_compat.py +++ b/src/dataherald/_compat.py @@ -1,13 +1,15 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Union, TypeVar, cast +from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload from datetime import date, datetime +from typing_extensions import Self import pydantic from pydantic.fields import FieldInfo from ._types import StrBytesIntFloat +_T = TypeVar("_T") _ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel) # --------------- Pydantic v2 compatibility --------------- @@ -178,8 +180,43 @@ class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): # cached properties if TYPE_CHECKING: cached_property = property + + # we define a separate type (copied from typeshed) + # that represents that `cached_property` is `set`able + # at runtime, which differs from `@property`. + # + # this is a separate type as editors likely special case + # `@property` and we don't want to cause issues just to have + # more helpful internal types. + + class typed_cached_property(Generic[_T]): + func: Callable[[Any], _T] + attrname: str | None + + def __init__(self, func: Callable[[Any], _T]) -> None: + ... + + @overload + def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: + ... + + @overload + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: + ... + + def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self: + raise NotImplementedError() + + def __set_name__(self, owner: type[Any], name: str) -> None: + ... + + # __set__ is not defined at runtime, but @cached_property is designed to be settable + def __set__(self, instance: object, value: _T) -> None: + ... else: try: from functools import cached_property as cached_property except ImportError: from cached_property import cached_property as cached_property + + typed_cached_property = cached_property diff --git a/src/dataherald/_constants.py b/src/dataherald/_constants.py index 76b21f0..bf15141 100644 --- a/src/dataherald/_constants.py +++ b/src/dataherald/_constants.py @@ -9,3 +9,6 @@ DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0) DEFAULT_MAX_RETRIES = 2 DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20) + +INITIAL_RETRY_DELAY = 0.5 +MAX_RETRY_DELAY = 8.0 diff --git a/src/dataherald/_utils/__init__.py b/src/dataherald/_utils/__init__.py index 2dcfc12..0fb811a 100644 --- a/src/dataherald/_utils/__init__.py +++ b/src/dataherald/_utils/__init__.py @@ -1,3 +1,4 @@ +from ._sync import asyncify as asyncify from ._proxy import LazyProxy as LazyProxy from ._utils import ( flatten as flatten, diff --git a/src/dataherald/_utils/_proxy.py b/src/dataherald/_utils/_proxy.py index 3c9e790..6f05efc 100644 --- a/src/dataherald/_utils/_proxy.py +++ b/src/dataherald/_utils/_proxy.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar, Iterable, cast -from typing_extensions import ClassVar, override +from typing_extensions import override T = TypeVar("T") @@ -13,11 +13,6 @@ class LazyProxy(Generic[T], ABC): This includes forwarding attribute access and othe methods. """ - should_cache: ClassVar[bool] = False - - def __init__(self) -> None: - self.__proxied: T | None = None - # Note: we have to special case proxies that themselves return proxies # to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz` @@ -57,18 +52,7 @@ def __class__(self) -> type: return proxied.__class__ def __get_proxied__(self) -> T: - if not self.should_cache: - return self.__load__() - - proxied = self.__proxied - if proxied is not None: - return proxied - - self.__proxied = proxied = self.__load__() - return proxied - - def __set_proxied__(self, value: T) -> None: - self.__proxied = value + return self.__load__() def __as_proxied__(self) -> T: """Helper method that returns the current proxy, typed as the loaded object""" diff --git a/src/dataherald/_utils/_sync.py b/src/dataherald/_utils/_sync.py new file mode 100644 index 0000000..595924e --- /dev/null +++ b/src/dataherald/_utils/_sync.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +import functools +from typing import TypeVar, Callable, Awaitable +from typing_extensions import ParamSpec + +import anyio +import anyio.to_thread + +T_Retval = TypeVar("T_Retval") +T_ParamSpec = ParamSpec("T_ParamSpec") + + +# copied from `asyncer`, https://github.com/tiangolo/asyncer +def asyncify( + function: Callable[T_ParamSpec, T_Retval], + *, + cancellable: bool = False, + limiter: anyio.CapacityLimiter | None = None, +) -> Callable[T_ParamSpec, Awaitable[T_Retval]]: + """ + Take a blocking function and create an async one that receives the same + positional and keyword arguments, and that when called, calls the original function + in a worker thread using `anyio.to_thread.run_sync()`. Internally, + `asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports + keyword arguments additional to positional arguments and it adds better support for + autocompletion and inline errors for the arguments of the function called and the + return value. + + If the `cancellable` option is enabled and the task waiting for its completion is + cancelled, the thread will still run its course but its return value (or any raised + exception) will be ignored. + + Use it like this: + + ```Python + def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str: + # Do work + return "Some result" + + + result = await to_thread.asyncify(do_work)("spam", "ham", kwarg1="a", kwarg2="b") + print(result) + ``` + + ## Arguments + + `function`: a blocking regular callable (e.g. a function) + `cancellable`: `True` to allow cancellation of the operation + `limiter`: capacity limiter to use to limit the total amount of threads running + (if omitted, the default limiter is used) + + ## Return + + An async function that takes the same positional and keyword arguments as the + original one, that when called runs the same original function in a thread worker + and returns the result. + """ + + async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: + partial_f = functools.partial(function, *args, **kwargs) + return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter) + + return wrapper diff --git a/src/dataherald/_utils/_typing.py b/src/dataherald/_utils/_typing.py index b5e2c2e..a020822 100644 --- a/src/dataherald/_utils/_typing.py +++ b/src/dataherald/_utils/_typing.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, cast +from typing import Any, TypeVar, cast from typing_extensions import Required, Annotated, get_args, get_origin from .._types import InheritsGeneric @@ -23,6 +23,12 @@ def is_required_type(typ: type) -> bool: return get_origin(typ) == Required +def is_typevar(typ: type) -> bool: + # type ignore is required because type checkers + # think this expression will always return False + return type(typ) == TypeVar # type: ignore + + # Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]] def strip_annotated_type(typ: type) -> type: if is_required_type(typ) or is_annotated_type(typ): @@ -49,6 +55,15 @@ class MyResponse(Foo[bytes]): extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes ``` + + And where a generic subclass is given: + ```py + _T = TypeVar('_T') + class MyResponse(Foo[_T]): + ... + + extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes + ``` """ cls = cast(object, get_origin(typ) or typ) if cls in generic_bases: @@ -75,6 +90,18 @@ class MyResponse(Foo[bytes]): f"Does {cls} inherit from one of {generic_bases} ?" ) - return extract_type_arg(target_base_class, index) + extracted = extract_type_arg(target_base_class, index) + if is_typevar(extracted): + # If the extracted type argument is itself a type variable + # then that means the subclass itself is generic, so we have + # to resolve the type argument from the class itself, not + # the base class. + # + # Note: if there is more than 1 type argument, the subclass could + # change the ordering of the type arguments, this is not currently + # supported. + return extract_type_arg(typ, index) + + return extracted raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/src/dataherald/_version.py b/src/dataherald/_version.py index 4a05a1e..893cb09 100644 --- a/src/dataherald/_version.py +++ b/src/dataherald/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. __title__ = "dataherald" -__version__ = "0.14.0" # x-release-please-version +__version__ = "0.15.0" # x-release-please-version diff --git a/src/dataherald/resources/database_connections/database_connections.py b/src/dataherald/resources/database_connections/database_connections.py index 61b2c37..3a9fb39 100644 --- a/src/dataherald/resources/database_connections/database_connections.py +++ b/src/dataherald/resources/database_connections/database_connections.py @@ -53,9 +53,9 @@ def with_streaming_response(self) -> DatabaseConnectionsWithStreamingResponse: def create( self, *, - alias: str | NotGiven = NOT_GIVEN, - connection_uri: str | NotGiven = NOT_GIVEN, - credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, + alias: str, + connection_uri: str, + bigquery_credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, ssh_settings: database_connection_create_params.SSHSettings | NotGiven = NOT_GIVEN, @@ -85,7 +85,7 @@ def create( { "alias": alias, "connection_uri": connection_uri, - "credential_file_content": credential_file_content, + "bigquery_credential_file_content": bigquery_credential_file_content, "llm_api_key": llm_api_key, "metadata": metadata, "ssh_settings": ssh_settings, @@ -136,9 +136,9 @@ def update( self, id: str, *, - alias: str | NotGiven = NOT_GIVEN, - connection_uri: str | NotGiven = NOT_GIVEN, - credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, + alias: str, + connection_uri: str, + bigquery_credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, ssh_settings: database_connection_update_params.SSHSettings | NotGiven = NOT_GIVEN, @@ -170,7 +170,7 @@ def update( { "alias": alias, "connection_uri": connection_uri, - "credential_file_content": credential_file_content, + "bigquery_credential_file_content": bigquery_credential_file_content, "llm_api_key": llm_api_key, "metadata": metadata, "ssh_settings": ssh_settings, @@ -220,9 +220,9 @@ def with_streaming_response(self) -> AsyncDatabaseConnectionsWithStreamingRespon async def create( self, *, - alias: str | NotGiven = NOT_GIVEN, - connection_uri: str | NotGiven = NOT_GIVEN, - credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, + alias: str, + connection_uri: str, + bigquery_credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, ssh_settings: database_connection_create_params.SSHSettings | NotGiven = NOT_GIVEN, @@ -252,7 +252,7 @@ async def create( { "alias": alias, "connection_uri": connection_uri, - "credential_file_content": credential_file_content, + "bigquery_credential_file_content": bigquery_credential_file_content, "llm_api_key": llm_api_key, "metadata": metadata, "ssh_settings": ssh_settings, @@ -303,9 +303,9 @@ async def update( self, id: str, *, - alias: str | NotGiven = NOT_GIVEN, - connection_uri: str | NotGiven = NOT_GIVEN, - credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, + alias: str, + connection_uri: str, + bigquery_credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, ssh_settings: database_connection_update_params.SSHSettings | NotGiven = NOT_GIVEN, @@ -337,7 +337,7 @@ async def update( { "alias": alias, "connection_uri": connection_uri, - "credential_file_content": credential_file_content, + "bigquery_credential_file_content": bigquery_credential_file_content, "llm_api_key": llm_api_key, "metadata": metadata, "ssh_settings": ssh_settings, @@ -373,7 +373,7 @@ async def list( class DatabaseConnectionsWithRawResponse: def __init__(self, database_connections: DatabaseConnections) -> None: - self.drivers = DriversWithRawResponse(database_connections.drivers) + self._database_connections = database_connections self.create = to_raw_response_wrapper( database_connections.create, @@ -388,10 +388,14 @@ def __init__(self, database_connections: DatabaseConnections) -> None: database_connections.list, ) + @cached_property + def drivers(self) -> DriversWithRawResponse: + return DriversWithRawResponse(self._database_connections.drivers) + class AsyncDatabaseConnectionsWithRawResponse: def __init__(self, database_connections: AsyncDatabaseConnections) -> None: - self.drivers = AsyncDriversWithRawResponse(database_connections.drivers) + self._database_connections = database_connections self.create = async_to_raw_response_wrapper( database_connections.create, @@ -406,10 +410,14 @@ def __init__(self, database_connections: AsyncDatabaseConnections) -> None: database_connections.list, ) + @cached_property + def drivers(self) -> AsyncDriversWithRawResponse: + return AsyncDriversWithRawResponse(self._database_connections.drivers) + class DatabaseConnectionsWithStreamingResponse: def __init__(self, database_connections: DatabaseConnections) -> None: - self.drivers = DriversWithStreamingResponse(database_connections.drivers) + self._database_connections = database_connections self.create = to_streamed_response_wrapper( database_connections.create, @@ -424,10 +432,14 @@ def __init__(self, database_connections: DatabaseConnections) -> None: database_connections.list, ) + @cached_property + def drivers(self) -> DriversWithStreamingResponse: + return DriversWithStreamingResponse(self._database_connections.drivers) + class AsyncDatabaseConnectionsWithStreamingResponse: def __init__(self, database_connections: AsyncDatabaseConnections) -> None: - self.drivers = AsyncDriversWithStreamingResponse(database_connections.drivers) + self._database_connections = database_connections self.create = async_to_streamed_response_wrapper( database_connections.create, @@ -441,3 +453,7 @@ def __init__(self, database_connections: AsyncDatabaseConnections) -> None: self.list = async_to_streamed_response_wrapper( database_connections.list, ) + + @cached_property + def drivers(self) -> AsyncDriversWithStreamingResponse: + return AsyncDriversWithStreamingResponse(self._database_connections.drivers) diff --git a/src/dataherald/resources/database_connections/drivers.py b/src/dataherald/resources/database_connections/drivers.py index c9e5e1d..809c174 100644 --- a/src/dataherald/resources/database_connections/drivers.py +++ b/src/dataherald/resources/database_connections/drivers.py @@ -81,6 +81,8 @@ async def list( class DriversWithRawResponse: def __init__(self, drivers: Drivers) -> None: + self._drivers = drivers + self.list = to_raw_response_wrapper( drivers.list, ) @@ -88,6 +90,8 @@ def __init__(self, drivers: Drivers) -> None: class AsyncDriversWithRawResponse: def __init__(self, drivers: AsyncDrivers) -> None: + self._drivers = drivers + self.list = async_to_raw_response_wrapper( drivers.list, ) @@ -95,6 +99,8 @@ def __init__(self, drivers: AsyncDrivers) -> None: class DriversWithStreamingResponse: def __init__(self, drivers: Drivers) -> None: + self._drivers = drivers + self.list = to_streamed_response_wrapper( drivers.list, ) @@ -102,6 +108,8 @@ def __init__(self, drivers: Drivers) -> None: class AsyncDriversWithStreamingResponse: def __init__(self, drivers: AsyncDrivers) -> None: + self._drivers = drivers + self.list = async_to_streamed_response_wrapper( drivers.list, ) diff --git a/src/dataherald/resources/engine.py b/src/dataherald/resources/engine.py index 443134d..7233f66 100644 --- a/src/dataherald/resources/engine.py +++ b/src/dataherald/resources/engine.py @@ -80,6 +80,8 @@ async def heartbeat( class EngineWithRawResponse: def __init__(self, engine: Engine) -> None: + self._engine = engine + self.heartbeat = to_raw_response_wrapper( engine.heartbeat, ) @@ -87,6 +89,8 @@ def __init__(self, engine: Engine) -> None: class AsyncEngineWithRawResponse: def __init__(self, engine: AsyncEngine) -> None: + self._engine = engine + self.heartbeat = async_to_raw_response_wrapper( engine.heartbeat, ) @@ -94,6 +98,8 @@ def __init__(self, engine: AsyncEngine) -> None: class EngineWithStreamingResponse: def __init__(self, engine: Engine) -> None: + self._engine = engine + self.heartbeat = to_streamed_response_wrapper( engine.heartbeat, ) @@ -101,6 +107,8 @@ def __init__(self, engine: Engine) -> None: class AsyncEngineWithStreamingResponse: def __init__(self, engine: AsyncEngine) -> None: + self._engine = engine + self.heartbeat = async_to_streamed_response_wrapper( engine.heartbeat, ) diff --git a/src/dataherald/resources/finetunings.py b/src/dataherald/resources/finetunings.py index 5a51b96..7ada1ea 100644 --- a/src/dataherald/resources/finetunings.py +++ b/src/dataherald/resources/finetunings.py @@ -347,6 +347,8 @@ async def cancel( class FinetuningsWithRawResponse: def __init__(self, finetunings: Finetunings) -> None: + self._finetunings = finetunings + self.create = to_raw_response_wrapper( finetunings.create, ) @@ -363,6 +365,8 @@ def __init__(self, finetunings: Finetunings) -> None: class AsyncFinetuningsWithRawResponse: def __init__(self, finetunings: AsyncFinetunings) -> None: + self._finetunings = finetunings + self.create = async_to_raw_response_wrapper( finetunings.create, ) @@ -379,6 +383,8 @@ def __init__(self, finetunings: AsyncFinetunings) -> None: class FinetuningsWithStreamingResponse: def __init__(self, finetunings: Finetunings) -> None: + self._finetunings = finetunings + self.create = to_streamed_response_wrapper( finetunings.create, ) @@ -395,6 +401,8 @@ def __init__(self, finetunings: Finetunings) -> None: class AsyncFinetuningsWithStreamingResponse: def __init__(self, finetunings: AsyncFinetunings) -> None: + self._finetunings = finetunings + self.create = async_to_streamed_response_wrapper( finetunings.create, ) diff --git a/src/dataherald/resources/generations.py b/src/dataherald/resources/generations.py index bd6e592..dfe2330 100644 --- a/src/dataherald/resources/generations.py +++ b/src/dataherald/resources/generations.py @@ -503,6 +503,8 @@ async def sql_generation( class GenerationsWithRawResponse: def __init__(self, generations: Generations) -> None: + self._generations = generations + self.create = to_raw_response_wrapper( generations.create, ) @@ -525,6 +527,8 @@ def __init__(self, generations: Generations) -> None: class AsyncGenerationsWithRawResponse: def __init__(self, generations: AsyncGenerations) -> None: + self._generations = generations + self.create = async_to_raw_response_wrapper( generations.create, ) @@ -547,6 +551,8 @@ def __init__(self, generations: AsyncGenerations) -> None: class GenerationsWithStreamingResponse: def __init__(self, generations: Generations) -> None: + self._generations = generations + self.create = to_streamed_response_wrapper( generations.create, ) @@ -569,6 +575,8 @@ def __init__(self, generations: Generations) -> None: class AsyncGenerationsWithStreamingResponse: def __init__(self, generations: AsyncGenerations) -> None: + self._generations = generations + self.create = async_to_streamed_response_wrapper( generations.create, ) diff --git a/src/dataherald/resources/golden_sqls.py b/src/dataherald/resources/golden_sqls.py index 52fe144..6a626ed 100644 --- a/src/dataherald/resources/golden_sqls.py +++ b/src/dataherald/resources/golden_sqls.py @@ -340,6 +340,8 @@ async def upload( class GoldenSqlsWithRawResponse: def __init__(self, golden_sqls: GoldenSqls) -> None: + self._golden_sqls = golden_sqls + self.retrieve = to_raw_response_wrapper( golden_sqls.retrieve, ) @@ -356,6 +358,8 @@ def __init__(self, golden_sqls: GoldenSqls) -> None: class AsyncGoldenSqlsWithRawResponse: def __init__(self, golden_sqls: AsyncGoldenSqls) -> None: + self._golden_sqls = golden_sqls + self.retrieve = async_to_raw_response_wrapper( golden_sqls.retrieve, ) @@ -372,6 +376,8 @@ def __init__(self, golden_sqls: AsyncGoldenSqls) -> None: class GoldenSqlsWithStreamingResponse: def __init__(self, golden_sqls: GoldenSqls) -> None: + self._golden_sqls = golden_sqls + self.retrieve = to_streamed_response_wrapper( golden_sqls.retrieve, ) @@ -388,6 +394,8 @@ def __init__(self, golden_sqls: GoldenSqls) -> None: class AsyncGoldenSqlsWithStreamingResponse: def __init__(self, golden_sqls: AsyncGoldenSqls) -> None: + self._golden_sqls = golden_sqls + self.retrieve = async_to_streamed_response_wrapper( golden_sqls.retrieve, ) diff --git a/src/dataherald/resources/heartbeat.py b/src/dataherald/resources/heartbeat.py index e6abaed..93c2d7a 100644 --- a/src/dataherald/resources/heartbeat.py +++ b/src/dataherald/resources/heartbeat.py @@ -80,6 +80,8 @@ async def retrieve( class HeartbeatWithRawResponse: def __init__(self, heartbeat: Heartbeat) -> None: + self._heartbeat = heartbeat + self.retrieve = to_raw_response_wrapper( heartbeat.retrieve, ) @@ -87,6 +89,8 @@ def __init__(self, heartbeat: Heartbeat) -> None: class AsyncHeartbeatWithRawResponse: def __init__(self, heartbeat: AsyncHeartbeat) -> None: + self._heartbeat = heartbeat + self.retrieve = async_to_raw_response_wrapper( heartbeat.retrieve, ) @@ -94,6 +98,8 @@ def __init__(self, heartbeat: AsyncHeartbeat) -> None: class HeartbeatWithStreamingResponse: def __init__(self, heartbeat: Heartbeat) -> None: + self._heartbeat = heartbeat + self.retrieve = to_streamed_response_wrapper( heartbeat.retrieve, ) @@ -101,6 +107,8 @@ def __init__(self, heartbeat: Heartbeat) -> None: class AsyncHeartbeatWithStreamingResponse: def __init__(self, heartbeat: AsyncHeartbeat) -> None: + self._heartbeat = heartbeat + self.retrieve = async_to_streamed_response_wrapper( heartbeat.retrieve, ) diff --git a/src/dataherald/resources/instructions/first.py b/src/dataherald/resources/instructions/first.py index 26921f7..fd5d04b 100644 --- a/src/dataherald/resources/instructions/first.py +++ b/src/dataherald/resources/instructions/first.py @@ -81,6 +81,8 @@ async def retrieve( class FirstWithRawResponse: def __init__(self, first: First) -> None: + self._first = first + self.retrieve = to_raw_response_wrapper( first.retrieve, ) @@ -88,6 +90,8 @@ def __init__(self, first: First) -> None: class AsyncFirstWithRawResponse: def __init__(self, first: AsyncFirst) -> None: + self._first = first + self.retrieve = async_to_raw_response_wrapper( first.retrieve, ) @@ -95,6 +99,8 @@ def __init__(self, first: AsyncFirst) -> None: class FirstWithStreamingResponse: def __init__(self, first: First) -> None: + self._first = first + self.retrieve = to_streamed_response_wrapper( first.retrieve, ) @@ -102,6 +108,8 @@ def __init__(self, first: First) -> None: class AsyncFirstWithStreamingResponse: def __init__(self, first: AsyncFirst) -> None: + self._first = first + self.retrieve = async_to_streamed_response_wrapper( first.retrieve, ) diff --git a/src/dataherald/resources/instructions/instructions.py b/src/dataherald/resources/instructions/instructions.py index 6aad586..b9b03cc 100644 --- a/src/dataherald/resources/instructions/instructions.py +++ b/src/dataherald/resources/instructions/instructions.py @@ -90,6 +90,39 @@ def create( cast_to=InstructionResponse, ) + def retrieve( + self, + id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InstructionResponse: + """ + Get Instruction + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") + return self._get( + f"/api/instructions/{id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=InstructionResponse, + ) + def update( self, id: str, @@ -259,6 +292,39 @@ async def create( cast_to=InstructionResponse, ) + async def retrieve( + self, + id: str, + *, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> InstructionResponse: + """ + Get Instruction + + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") + return await self._get( + f"/api/instructions/{id}", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=InstructionResponse, + ) + async def update( self, id: str, @@ -376,11 +442,14 @@ async def delete( class InstructionsWithRawResponse: def __init__(self, instructions: Instructions) -> None: - self.first = FirstWithRawResponse(instructions.first) + self._instructions = instructions self.create = to_raw_response_wrapper( instructions.create, ) + self.retrieve = to_raw_response_wrapper( + instructions.retrieve, + ) self.update = to_raw_response_wrapper( instructions.update, ) @@ -391,14 +460,21 @@ def __init__(self, instructions: Instructions) -> None: instructions.delete, ) + @cached_property + def first(self) -> FirstWithRawResponse: + return FirstWithRawResponse(self._instructions.first) + class AsyncInstructionsWithRawResponse: def __init__(self, instructions: AsyncInstructions) -> None: - self.first = AsyncFirstWithRawResponse(instructions.first) + self._instructions = instructions self.create = async_to_raw_response_wrapper( instructions.create, ) + self.retrieve = async_to_raw_response_wrapper( + instructions.retrieve, + ) self.update = async_to_raw_response_wrapper( instructions.update, ) @@ -409,14 +485,21 @@ def __init__(self, instructions: AsyncInstructions) -> None: instructions.delete, ) + @cached_property + def first(self) -> AsyncFirstWithRawResponse: + return AsyncFirstWithRawResponse(self._instructions.first) + class InstructionsWithStreamingResponse: def __init__(self, instructions: Instructions) -> None: - self.first = FirstWithStreamingResponse(instructions.first) + self._instructions = instructions self.create = to_streamed_response_wrapper( instructions.create, ) + self.retrieve = to_streamed_response_wrapper( + instructions.retrieve, + ) self.update = to_streamed_response_wrapper( instructions.update, ) @@ -427,14 +510,21 @@ def __init__(self, instructions: Instructions) -> None: instructions.delete, ) + @cached_property + def first(self) -> FirstWithStreamingResponse: + return FirstWithStreamingResponse(self._instructions.first) + class AsyncInstructionsWithStreamingResponse: def __init__(self, instructions: AsyncInstructions) -> None: - self.first = AsyncFirstWithStreamingResponse(instructions.first) + self._instructions = instructions self.create = async_to_streamed_response_wrapper( instructions.create, ) + self.retrieve = async_to_streamed_response_wrapper( + instructions.retrieve, + ) self.update = async_to_streamed_response_wrapper( instructions.update, ) @@ -444,3 +534,7 @@ def __init__(self, instructions: AsyncInstructions) -> None: self.delete = async_to_streamed_response_wrapper( instructions.delete, ) + + @cached_property + def first(self) -> AsyncFirstWithStreamingResponse: + return AsyncFirstWithStreamingResponse(self._instructions.first) diff --git a/src/dataherald/resources/nl_generations.py b/src/dataherald/resources/nl_generations.py index 9563a9a..6e1a72c 100644 --- a/src/dataherald/resources/nl_generations.py +++ b/src/dataherald/resources/nl_generations.py @@ -285,6 +285,8 @@ async def list( class NlGenerationsWithRawResponse: def __init__(self, nl_generations: NlGenerations) -> None: + self._nl_generations = nl_generations + self.create = to_raw_response_wrapper( nl_generations.create, ) @@ -298,6 +300,8 @@ def __init__(self, nl_generations: NlGenerations) -> None: class AsyncNlGenerationsWithRawResponse: def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self._nl_generations = nl_generations + self.create = async_to_raw_response_wrapper( nl_generations.create, ) @@ -311,6 +315,8 @@ def __init__(self, nl_generations: AsyncNlGenerations) -> None: class NlGenerationsWithStreamingResponse: def __init__(self, nl_generations: NlGenerations) -> None: + self._nl_generations = nl_generations + self.create = to_streamed_response_wrapper( nl_generations.create, ) @@ -324,6 +330,8 @@ def __init__(self, nl_generations: NlGenerations) -> None: class AsyncNlGenerationsWithStreamingResponse: def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self._nl_generations = nl_generations + self.create = async_to_streamed_response_wrapper( nl_generations.create, ) diff --git a/src/dataherald/resources/prompts/prompts.py b/src/dataherald/resources/prompts/prompts.py index eea29db..3df86c4 100644 --- a/src/dataherald/resources/prompts/prompts.py +++ b/src/dataherald/resources/prompts/prompts.py @@ -300,7 +300,7 @@ async def list( class PromptsWithRawResponse: def __init__(self, prompts: Prompts) -> None: - self.sql_generations = SqlGenerationsWithRawResponse(prompts.sql_generations) + self._prompts = prompts self.create = to_raw_response_wrapper( prompts.create, @@ -312,10 +312,14 @@ def __init__(self, prompts: Prompts) -> None: prompts.list, ) + @cached_property + def sql_generations(self) -> SqlGenerationsWithRawResponse: + return SqlGenerationsWithRawResponse(self._prompts.sql_generations) + class AsyncPromptsWithRawResponse: def __init__(self, prompts: AsyncPrompts) -> None: - self.sql_generations = AsyncSqlGenerationsWithRawResponse(prompts.sql_generations) + self._prompts = prompts self.create = async_to_raw_response_wrapper( prompts.create, @@ -327,10 +331,14 @@ def __init__(self, prompts: AsyncPrompts) -> None: prompts.list, ) + @cached_property + def sql_generations(self) -> AsyncSqlGenerationsWithRawResponse: + return AsyncSqlGenerationsWithRawResponse(self._prompts.sql_generations) + class PromptsWithStreamingResponse: def __init__(self, prompts: Prompts) -> None: - self.sql_generations = SqlGenerationsWithStreamingResponse(prompts.sql_generations) + self._prompts = prompts self.create = to_streamed_response_wrapper( prompts.create, @@ -342,10 +350,14 @@ def __init__(self, prompts: Prompts) -> None: prompts.list, ) + @cached_property + def sql_generations(self) -> SqlGenerationsWithStreamingResponse: + return SqlGenerationsWithStreamingResponse(self._prompts.sql_generations) + class AsyncPromptsWithStreamingResponse: def __init__(self, prompts: AsyncPrompts) -> None: - self.sql_generations = AsyncSqlGenerationsWithStreamingResponse(prompts.sql_generations) + self._prompts = prompts self.create = async_to_streamed_response_wrapper( prompts.create, @@ -356,3 +368,7 @@ def __init__(self, prompts: AsyncPrompts) -> None: self.list = async_to_streamed_response_wrapper( prompts.list, ) + + @cached_property + def sql_generations(self) -> AsyncSqlGenerationsWithStreamingResponse: + return AsyncSqlGenerationsWithStreamingResponse(self._prompts.sql_generations) diff --git a/src/dataherald/resources/prompts/sql_generations.py b/src/dataherald/resources/prompts/sql_generations.py index ab6b467..1092d98 100644 --- a/src/dataherald/resources/prompts/sql_generations.py +++ b/src/dataherald/resources/prompts/sql_generations.py @@ -327,6 +327,8 @@ async def nl_generations( class SqlGenerationsWithRawResponse: def __init__(self, sql_generations: SqlGenerations) -> None: + self._sql_generations = sql_generations + self.create = to_raw_response_wrapper( sql_generations.create, ) @@ -340,6 +342,8 @@ def __init__(self, sql_generations: SqlGenerations) -> None: class AsyncSqlGenerationsWithRawResponse: def __init__(self, sql_generations: AsyncSqlGenerations) -> None: + self._sql_generations = sql_generations + self.create = async_to_raw_response_wrapper( sql_generations.create, ) @@ -353,6 +357,8 @@ def __init__(self, sql_generations: AsyncSqlGenerations) -> None: class SqlGenerationsWithStreamingResponse: def __init__(self, sql_generations: SqlGenerations) -> None: + self._sql_generations = sql_generations + self.create = to_streamed_response_wrapper( sql_generations.create, ) @@ -366,6 +372,8 @@ def __init__(self, sql_generations: SqlGenerations) -> None: class AsyncSqlGenerationsWithStreamingResponse: def __init__(self, sql_generations: AsyncSqlGenerations) -> None: + self._sql_generations = sql_generations + self.create = async_to_streamed_response_wrapper( sql_generations.create, ) diff --git a/src/dataherald/resources/sql_generations/nl_generations.py b/src/dataherald/resources/sql_generations/nl_generations.py index 8ed3183..3101d26 100644 --- a/src/dataherald/resources/sql_generations/nl_generations.py +++ b/src/dataherald/resources/sql_generations/nl_generations.py @@ -227,6 +227,8 @@ async def retrieve( class NlGenerationsWithRawResponse: def __init__(self, nl_generations: NlGenerations) -> None: + self._nl_generations = nl_generations + self.create = to_raw_response_wrapper( nl_generations.create, ) @@ -237,6 +239,8 @@ def __init__(self, nl_generations: NlGenerations) -> None: class AsyncNlGenerationsWithRawResponse: def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self._nl_generations = nl_generations + self.create = async_to_raw_response_wrapper( nl_generations.create, ) @@ -247,6 +251,8 @@ def __init__(self, nl_generations: AsyncNlGenerations) -> None: class NlGenerationsWithStreamingResponse: def __init__(self, nl_generations: NlGenerations) -> None: + self._nl_generations = nl_generations + self.create = to_streamed_response_wrapper( nl_generations.create, ) @@ -257,6 +263,8 @@ def __init__(self, nl_generations: NlGenerations) -> None: class AsyncNlGenerationsWithStreamingResponse: def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self._nl_generations = nl_generations + self.create = async_to_streamed_response_wrapper( nl_generations.create, ) diff --git a/src/dataherald/resources/sql_generations/sql_generations.py b/src/dataherald/resources/sql_generations/sql_generations.py index f2bb54e..4870601 100644 --- a/src/dataherald/resources/sql_generations/sql_generations.py +++ b/src/dataherald/resources/sql_generations/sql_generations.py @@ -391,7 +391,7 @@ async def execute( class SqlGenerationsWithRawResponse: def __init__(self, sql_generations: SqlGenerations) -> None: - self.nl_generations = NlGenerationsWithRawResponse(sql_generations.nl_generations) + self._sql_generations = sql_generations self.create = to_raw_response_wrapper( sql_generations.create, @@ -406,10 +406,14 @@ def __init__(self, sql_generations: SqlGenerations) -> None: sql_generations.execute, ) + @cached_property + def nl_generations(self) -> NlGenerationsWithRawResponse: + return NlGenerationsWithRawResponse(self._sql_generations.nl_generations) + class AsyncSqlGenerationsWithRawResponse: def __init__(self, sql_generations: AsyncSqlGenerations) -> None: - self.nl_generations = AsyncNlGenerationsWithRawResponse(sql_generations.nl_generations) + self._sql_generations = sql_generations self.create = async_to_raw_response_wrapper( sql_generations.create, @@ -424,10 +428,14 @@ def __init__(self, sql_generations: AsyncSqlGenerations) -> None: sql_generations.execute, ) + @cached_property + def nl_generations(self) -> AsyncNlGenerationsWithRawResponse: + return AsyncNlGenerationsWithRawResponse(self._sql_generations.nl_generations) + class SqlGenerationsWithStreamingResponse: def __init__(self, sql_generations: SqlGenerations) -> None: - self.nl_generations = NlGenerationsWithStreamingResponse(sql_generations.nl_generations) + self._sql_generations = sql_generations self.create = to_streamed_response_wrapper( sql_generations.create, @@ -442,10 +450,14 @@ def __init__(self, sql_generations: SqlGenerations) -> None: sql_generations.execute, ) + @cached_property + def nl_generations(self) -> NlGenerationsWithStreamingResponse: + return NlGenerationsWithStreamingResponse(self._sql_generations.nl_generations) + class AsyncSqlGenerationsWithStreamingResponse: def __init__(self, sql_generations: AsyncSqlGenerations) -> None: - self.nl_generations = AsyncNlGenerationsWithStreamingResponse(sql_generations.nl_generations) + self._sql_generations = sql_generations self.create = async_to_streamed_response_wrapper( sql_generations.create, @@ -459,3 +471,7 @@ def __init__(self, sql_generations: AsyncSqlGenerations) -> None: self.execute = async_to_streamed_response_wrapper( sql_generations.execute, ) + + @cached_property + def nl_generations(self) -> AsyncNlGenerationsWithStreamingResponse: + return AsyncNlGenerationsWithStreamingResponse(self._sql_generations.nl_generations) diff --git a/src/dataherald/resources/table_descriptions.py b/src/dataherald/resources/table_descriptions.py index d19efc7..d7d015e 100644 --- a/src/dataherald/resources/table_descriptions.py +++ b/src/dataherald/resources/table_descriptions.py @@ -164,8 +164,7 @@ def list( def sync_schemas( self, *, - db_connection_id: str, - table_names: List[str] | NotGiven = NOT_GIVEN, + body: List[table_description_sync_schemas_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -187,13 +186,7 @@ def sync_schemas( """ return self._post( "/api/table-descriptions/sync-schemas", - body=maybe_transform( - { - "db_connection_id": db_connection_id, - "table_names": table_names, - }, - table_description_sync_schemas_params.TableDescriptionSyncSchemasParams, - ), + body=maybe_transform(body, table_description_sync_schemas_params.TableDescriptionSyncSchemasParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -334,8 +327,7 @@ async def list( async def sync_schemas( self, *, - db_connection_id: str, - table_names: List[str] | NotGiven = NOT_GIVEN, + body: List[table_description_sync_schemas_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, @@ -357,13 +349,7 @@ async def sync_schemas( """ return await self._post( "/api/table-descriptions/sync-schemas", - body=maybe_transform( - { - "db_connection_id": db_connection_id, - "table_names": table_names, - }, - table_description_sync_schemas_params.TableDescriptionSyncSchemasParams, - ), + body=maybe_transform(body, table_description_sync_schemas_params.TableDescriptionSyncSchemasParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -373,6 +359,8 @@ async def sync_schemas( class TableDescriptionsWithRawResponse: def __init__(self, table_descriptions: TableDescriptions) -> None: + self._table_descriptions = table_descriptions + self.retrieve = to_raw_response_wrapper( table_descriptions.retrieve, ) @@ -389,6 +377,8 @@ def __init__(self, table_descriptions: TableDescriptions) -> None: class AsyncTableDescriptionsWithRawResponse: def __init__(self, table_descriptions: AsyncTableDescriptions) -> None: + self._table_descriptions = table_descriptions + self.retrieve = async_to_raw_response_wrapper( table_descriptions.retrieve, ) @@ -405,6 +395,8 @@ def __init__(self, table_descriptions: AsyncTableDescriptions) -> None: class TableDescriptionsWithStreamingResponse: def __init__(self, table_descriptions: TableDescriptions) -> None: + self._table_descriptions = table_descriptions + self.retrieve = to_streamed_response_wrapper( table_descriptions.retrieve, ) @@ -421,6 +413,8 @@ def __init__(self, table_descriptions: TableDescriptions) -> None: class AsyncTableDescriptionsWithStreamingResponse: def __init__(self, table_descriptions: AsyncTableDescriptions) -> None: + self._table_descriptions = table_descriptions + self.retrieve = async_to_streamed_response_wrapper( table_descriptions.retrieve, ) diff --git a/src/dataherald/types/database_connection_create_params.py b/src/dataherald/types/database_connection_create_params.py index dba52af..1fa71aa 100644 --- a/src/dataherald/types/database_connection_create_params.py +++ b/src/dataherald/types/database_connection_create_params.py @@ -3,17 +3,17 @@ from __future__ import annotations from typing import Union -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict __all__ = ["DatabaseConnectionCreateParams", "SSHSettings"] class DatabaseConnectionCreateParams(TypedDict, total=False): - alias: str + alias: Required[str] - connection_uri: str + connection_uri: Required[str] - credential_file_content: Union[object, str] + bigquery_credential_file_content: Union[object, str] llm_api_key: str @@ -25,20 +25,8 @@ class DatabaseConnectionCreateParams(TypedDict, total=False): class SSHSettings(TypedDict, total=False): - db_driver: str - - db_name: str - host: str password: str - private_key_password: str - - remote_db_name: str - - remote_db_password: str - - remote_host: str - username: str diff --git a/src/dataherald/types/database_connection_update_params.py b/src/dataherald/types/database_connection_update_params.py index 53eb8ec..2f51d0f 100644 --- a/src/dataherald/types/database_connection_update_params.py +++ b/src/dataherald/types/database_connection_update_params.py @@ -3,17 +3,17 @@ from __future__ import annotations from typing import Union -from typing_extensions import TypedDict +from typing_extensions import Required, TypedDict __all__ = ["DatabaseConnectionUpdateParams", "SSHSettings"] class DatabaseConnectionUpdateParams(TypedDict, total=False): - alias: str + alias: Required[str] - connection_uri: str + connection_uri: Required[str] - credential_file_content: Union[object, str] + bigquery_credential_file_content: Union[object, str] llm_api_key: str @@ -25,20 +25,8 @@ class DatabaseConnectionUpdateParams(TypedDict, total=False): class SSHSettings(TypedDict, total=False): - db_driver: str - - db_name: str - host: str password: str - private_key_password: str - - remote_db_name: str - - remote_db_password: str - - remote_host: str - username: str diff --git a/src/dataherald/types/db_connection_response.py b/src/dataherald/types/db_connection_response.py index 872930e..c18f52b 100644 --- a/src/dataherald/types/db_connection_response.py +++ b/src/dataherald/types/db_connection_response.py @@ -17,29 +17,21 @@ class Metadata(BaseModel): class SSHSettings(BaseModel): - db_driver: Optional[str] = None - - db_name: Optional[str] = None - host: Optional[str] = None password: Optional[str] = None private_key_password: Optional[str] = None - remote_db_name: Optional[str] = None - - remote_db_password: Optional[str] = None - - remote_host: Optional[str] = None - username: Optional[str] = None class DBConnectionResponse(BaseModel): - id: str + alias: str + + connection_uri: str - alias: Optional[str] = None + id: Optional[str] = None created_at: Optional[datetime] = None @@ -51,6 +43,4 @@ class DBConnectionResponse(BaseModel): ssh_settings: Optional[SSHSettings] = None - uri: Optional[str] = None - use_ssh: Optional[bool] = None diff --git a/src/dataherald/types/finetuning_response.py b/src/dataherald/types/finetuning_response.py index a48cbf3..9f432f0 100644 --- a/src/dataherald/types/finetuning_response.py +++ b/src/dataherald/types/finetuning_response.py @@ -29,7 +29,7 @@ class Metadata(BaseModel): class FinetuningResponse(BaseModel): id: str - _model_id: Optional[str] = None + api_model_id: Optional[str] = FieldInfo(alias="_model_id", default=None) base_llm: Optional[BaseLlm] = None diff --git a/src/dataherald/types/table_description_sync_schemas_params.py b/src/dataherald/types/table_description_sync_schemas_params.py index d968da4..880570c 100644 --- a/src/dataherald/types/table_description_sync_schemas_params.py +++ b/src/dataherald/types/table_description_sync_schemas_params.py @@ -5,10 +5,14 @@ from typing import List from typing_extensions import Required, TypedDict -__all__ = ["TableDescriptionSyncSchemasParams"] +__all__ = ["TableDescriptionSyncSchemasParams", "Body"] class TableDescriptionSyncSchemasParams(TypedDict, total=False): + body: Required[List[Body]] + + +class Body(TypedDict, total=False): db_connection_id: Required[str] table_names: List[str] diff --git a/tests/api_resources/database_connections/test_drivers.py b/tests/api_resources/database_connections/test_drivers.py index 5790065..e403bea 100644 --- a/tests/api_resources/database_connections/test_drivers.py +++ b/tests/api_resources/database_connections/test_drivers.py @@ -9,17 +9,13 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.database_connections import DriverListResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestDrivers: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_list(self, client: Dataherald) -> None: @@ -48,18 +44,16 @@ def test_streaming_response_list(self, client: Dataherald) -> None: class TestAsyncDrivers: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - driver = await client.database_connections.drivers.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + driver = await async_client.database_connections.drivers.list() assert_matches_type(DriverListResponse, driver, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.database_connections.drivers.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.database_connections.drivers.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -67,8 +61,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(DriverListResponse, driver, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.database_connections.drivers.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.database_connections.drivers.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/instructions/test_first.py b/tests/api_resources/instructions/test_first.py index 52ab76c..0bf1856 100644 --- a/tests/api_resources/instructions/test_first.py +++ b/tests/api_resources/instructions/test_first.py @@ -9,17 +9,13 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import InstructionResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestFirst: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_retrieve(self, client: Dataherald) -> None: @@ -48,18 +44,16 @@ def test_streaming_response_retrieve(self, client: Dataherald) -> None: class TestAsyncFirst: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - first = await client.instructions.first.retrieve() + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + first = await async_client.instructions.first.retrieve() assert_matches_type(InstructionResponse, first, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.instructions.first.with_raw_response.retrieve() + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.first.with_raw_response.retrieve() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -67,8 +61,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(InstructionResponse, first, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.instructions.first.with_streaming_response.retrieve() as response: + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.first.with_streaming_response.retrieve() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/prompts/test_sql_generations.py b/tests/api_resources/prompts/test_sql_generations.py index 1f1cce6..c59c0b9 100644 --- a/tests/api_resources/prompts/test_sql_generations.py +++ b/tests/api_resources/prompts/test_sql_generations.py @@ -9,17 +9,13 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import NlGenerationResponse, SqlGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestSqlGenerations: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -178,20 +174,18 @@ def test_path_params_nl_generations(self, client: Dataherald) -> None: class TestAsyncSqlGenerations: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.create( "string", ) assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.create( "string", evaluate=True, finetuning_id="string", @@ -201,8 +195,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.prompts.sql_generations.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.sql_generations.with_raw_response.create( "string", ) @@ -212,8 +206,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.prompts.sql_generations.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.sql_generations.with_streaming_response.create( "string", ) as response: assert not response.is_closed @@ -225,22 +219,22 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_create(self, client: AsyncDataherald) -> None: + async def test_path_params_create(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.prompts.sql_generations.with_raw_response.create( + await async_client.prompts.sql_generations.with_raw_response.create( "", ) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.retrieve( "string", ) assert_matches_type(object, sql_generation, path=["response"]) @parametrize - async def test_method_retrieve_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.retrieve( + async def test_method_retrieve_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.retrieve( "string", ascend=True, order="string", @@ -250,8 +244,8 @@ async def test_method_retrieve_with_all_params(self, client: AsyncDataherald) -> assert_matches_type(object, sql_generation, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.prompts.sql_generations.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.sql_generations.with_raw_response.retrieve( "string", ) @@ -261,8 +255,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(object, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.prompts.sql_generations.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.sql_generations.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -274,23 +268,23 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.prompts.sql_generations.with_raw_response.retrieve( + await async_client.prompts.sql_generations.with_raw_response.retrieve( "", ) @parametrize - async def test_method_nl_generations(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.nl_generations( + async def test_method_nl_generations(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.nl_generations( "string", sql_generation={}, ) assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_method_nl_generations_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.prompts.sql_generations.nl_generations( + async def test_method_nl_generations_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.prompts.sql_generations.nl_generations( "string", sql_generation={ "finetuning_id": "string", @@ -304,8 +298,8 @@ async def test_method_nl_generations_with_all_params(self, client: AsyncDatahera assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_nl_generations(self, client: AsyncDataherald) -> None: - response = await client.prompts.sql_generations.with_raw_response.nl_generations( + async def test_raw_response_nl_generations(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.sql_generations.with_raw_response.nl_generations( "string", sql_generation={}, ) @@ -316,8 +310,8 @@ async def test_raw_response_nl_generations(self, client: AsyncDataherald) -> Non assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_nl_generations(self, client: AsyncDataherald) -> None: - async with client.prompts.sql_generations.with_streaming_response.nl_generations( + async def test_streaming_response_nl_generations(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.sql_generations.with_streaming_response.nl_generations( "string", sql_generation={}, ) as response: @@ -330,9 +324,9 @@ async def test_streaming_response_nl_generations(self, client: AsyncDataherald) assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_nl_generations(self, client: AsyncDataherald) -> None: + async def test_path_params_nl_generations(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.prompts.sql_generations.with_raw_response.nl_generations( + await async_client.prompts.sql_generations.with_raw_response.nl_generations( "", sql_generation={}, ) diff --git a/tests/api_resources/sql_generations/test_nl_generations.py b/tests/api_resources/sql_generations/test_nl_generations.py index 7f3004e..f75bfa6 100644 --- a/tests/api_resources/sql_generations/test_nl_generations.py +++ b/tests/api_resources/sql_generations/test_nl_generations.py @@ -9,17 +9,13 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import NlGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestNlGenerations: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -119,20 +115,18 @@ def test_path_params_retrieve(self, client: Dataherald) -> None: class TestAsyncNlGenerations: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - nl_generation = await client.sql_generations.nl_generations.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.sql_generations.nl_generations.create( "string", ) assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - nl_generation = await client.sql_generations.nl_generations.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.sql_generations.nl_generations.create( "string", max_rows=0, metadata={}, @@ -140,8 +134,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.nl_generations.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.nl_generations.with_raw_response.create( "string", ) @@ -151,8 +145,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.sql_generations.nl_generations.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.nl_generations.with_streaming_response.create( "string", ) as response: assert not response.is_closed @@ -164,22 +158,22 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_create(self, client: AsyncDataherald) -> None: + async def test_path_params_create(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.sql_generations.nl_generations.with_raw_response.create( + await async_client.sql_generations.nl_generations.with_raw_response.create( "", ) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - nl_generation = await client.sql_generations.nl_generations.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.sql_generations.nl_generations.retrieve( "string", ) assert_matches_type(object, nl_generation, path=["response"]) @parametrize - async def test_method_retrieve_with_all_params(self, client: AsyncDataherald) -> None: - nl_generation = await client.sql_generations.nl_generations.retrieve( + async def test_method_retrieve_with_all_params(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.sql_generations.nl_generations.retrieve( "string", ascend=True, order="string", @@ -189,8 +183,8 @@ async def test_method_retrieve_with_all_params(self, client: AsyncDataherald) -> assert_matches_type(object, nl_generation, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.nl_generations.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.nl_generations.with_raw_response.retrieve( "string", ) @@ -200,8 +194,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(object, nl_generation, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.sql_generations.nl_generations.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.nl_generations.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -213,8 +207,8 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.sql_generations.nl_generations.with_raw_response.retrieve( + await async_client.sql_generations.nl_generations.with_raw_response.retrieve( "", ) diff --git a/tests/api_resources/test_database_connections.py b/tests/api_resources/test_database_connections.py index 0efd89b..a11a6fb 100644 --- a/tests/api_resources/test_database_connections.py +++ b/tests/api_resources/test_database_connections.py @@ -13,20 +13,19 @@ DBConnectionResponse, DatabaseConnectionListResponse, ) -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestDatabaseConnections: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: - database_connection = client.database_connections.create() + database_connection = client.database_connections.create( + alias="string", + connection_uri="string", + ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize @@ -34,19 +33,13 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: database_connection = client.database_connections.create( alias="string", connection_uri="string", - credential_file_content={}, + bigquery_credential_file_content={}, llm_api_key="string", metadata={}, ssh_settings={ - "db_name": "string", "host": "string", "username": "string", "password": "string", - "remote_host": "string", - "remote_db_name": "string", - "remote_db_password": "string", - "private_key_password": "string", - "db_driver": "string", }, use_ssh=True, ) @@ -54,7 +47,10 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_create(self, client: Dataherald) -> None: - response = client.database_connections.with_raw_response.create() + response = client.database_connections.with_raw_response.create( + alias="string", + connection_uri="string", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -63,7 +59,10 @@ def test_raw_response_create(self, client: Dataherald) -> None: @parametrize def test_streaming_response_create(self, client: Dataherald) -> None: - with client.database_connections.with_streaming_response.create() as response: + with client.database_connections.with_streaming_response.create( + alias="string", + connection_uri="string", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -114,6 +113,8 @@ def test_path_params_retrieve(self, client: Dataherald) -> None: def test_method_update(self, client: Dataherald) -> None: database_connection = client.database_connections.update( "string", + alias="string", + connection_uri="string", ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @@ -123,19 +124,13 @@ def test_method_update_with_all_params(self, client: Dataherald) -> None: "string", alias="string", connection_uri="string", - credential_file_content={}, + bigquery_credential_file_content={}, llm_api_key="string", metadata={}, ssh_settings={ - "db_name": "string", "host": "string", "username": "string", "password": "string", - "remote_host": "string", - "remote_db_name": "string", - "remote_db_password": "string", - "private_key_password": "string", - "db_driver": "string", }, use_ssh=True, ) @@ -145,6 +140,8 @@ def test_method_update_with_all_params(self, client: Dataherald) -> None: def test_raw_response_update(self, client: Dataherald) -> None: response = client.database_connections.with_raw_response.update( "string", + alias="string", + connection_uri="string", ) assert response.is_closed is True @@ -156,6 +153,8 @@ def test_raw_response_update(self, client: Dataherald) -> None: def test_streaming_response_update(self, client: Dataherald) -> None: with client.database_connections.with_streaming_response.update( "string", + alias="string", + connection_uri="string", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -170,6 +169,8 @@ def test_path_params_update(self, client: Dataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): client.database_connections.with_raw_response.update( "", + alias="string", + connection_uri="string", ) @parametrize @@ -199,41 +200,39 @@ def test_streaming_response_list(self, client: Dataherald) -> None: class TestAsyncDatabaseConnections: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.create() + async def test_method_create(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.create( + alias="string", + connection_uri="string", + ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.create( alias="string", connection_uri="string", - credential_file_content={}, + bigquery_credential_file_content={}, llm_api_key="string", metadata={}, ssh_settings={ - "db_name": "string", "host": "string", "username": "string", "password": "string", - "remote_host": "string", - "remote_db_name": "string", - "remote_db_password": "string", - "private_key_password": "string", - "db_driver": "string", }, use_ssh=True, ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.database_connections.with_raw_response.create() + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.database_connections.with_raw_response.create( + alias="string", + connection_uri="string", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -241,8 +240,11 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.database_connections.with_streaming_response.create() as response: + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.database_connections.with_streaming_response.create( + alias="string", + connection_uri="string", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -252,15 +254,15 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.retrieve( "string", ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.database_connections.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.database_connections.with_raw_response.retrieve( "string", ) @@ -270,8 +272,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.database_connections.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.database_connections.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -283,47 +285,45 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.database_connections.with_raw_response.retrieve( + await async_client.database_connections.with_raw_response.retrieve( "", ) @parametrize - async def test_method_update(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.update( + async def test_method_update(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.update( "string", + alias="string", + connection_uri="string", ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_method_update_with_all_params(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.update( + async def test_method_update_with_all_params(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.update( "string", alias="string", connection_uri="string", - credential_file_content={}, + bigquery_credential_file_content={}, llm_api_key="string", metadata={}, ssh_settings={ - "db_name": "string", "host": "string", "username": "string", "password": "string", - "remote_host": "string", - "remote_db_name": "string", - "remote_db_password": "string", - "private_key_password": "string", - "db_driver": "string", }, use_ssh=True, ) assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_raw_response_update(self, client: AsyncDataherald) -> None: - response = await client.database_connections.with_raw_response.update( + async def test_raw_response_update(self, async_client: AsyncDataherald) -> None: + response = await async_client.database_connections.with_raw_response.update( "string", + alias="string", + connection_uri="string", ) assert response.is_closed is True @@ -332,9 +332,11 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) @parametrize - async def test_streaming_response_update(self, client: AsyncDataherald) -> None: - async with client.database_connections.with_streaming_response.update( + async def test_streaming_response_update(self, async_client: AsyncDataherald) -> None: + async with async_client.database_connections.with_streaming_response.update( "string", + alias="string", + connection_uri="string", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -345,20 +347,22 @@ async def test_streaming_response_update(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_update(self, client: AsyncDataherald) -> None: + async def test_path_params_update(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.database_connections.with_raw_response.update( + await async_client.database_connections.with_raw_response.update( "", + alias="string", + connection_uri="string", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - database_connection = await client.database_connections.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + database_connection = await async_client.database_connections.list() assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.database_connections.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.database_connections.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -366,8 +370,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.database_connections.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.database_connections.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_engine.py b/tests/api_resources/test_engine.py index 0c920f7..77bd41e 100644 --- a/tests/api_resources/test_engine.py +++ b/tests/api_resources/test_engine.py @@ -9,16 +9,12 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestEngine: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_heartbeat(self, client: Dataherald) -> None: @@ -47,18 +43,16 @@ def test_streaming_response_heartbeat(self, client: Dataherald) -> None: class TestAsyncEngine: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_heartbeat(self, client: AsyncDataherald) -> None: - engine = await client.engine.heartbeat() + async def test_method_heartbeat(self, async_client: AsyncDataherald) -> None: + engine = await async_client.engine.heartbeat() assert_matches_type(object, engine, path=["response"]) @parametrize - async def test_raw_response_heartbeat(self, client: AsyncDataherald) -> None: - response = await client.engine.with_raw_response.heartbeat() + async def test_raw_response_heartbeat(self, async_client: AsyncDataherald) -> None: + response = await async_client.engine.with_raw_response.heartbeat() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -66,8 +60,8 @@ async def test_raw_response_heartbeat(self, client: AsyncDataherald) -> None: assert_matches_type(object, engine, path=["response"]) @parametrize - async def test_streaming_response_heartbeat(self, client: AsyncDataherald) -> None: - async with client.engine.with_streaming_response.heartbeat() as response: + async def test_streaming_response_heartbeat(self, async_client: AsyncDataherald) -> None: + async with async_client.engine.with_streaming_response.heartbeat() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_finetunings.py b/tests/api_resources/test_finetunings.py index 02a910e..013371f 100644 --- a/tests/api_resources/test_finetunings.py +++ b/tests/api_resources/test_finetunings.py @@ -13,16 +13,12 @@ FinetuningResponse, FinetuningListResponse, ) -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestFinetunings: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -179,20 +175,18 @@ def test_path_params_cancel(self, client: Dataherald) -> None: class TestAsyncFinetunings: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - finetuning = await client.finetunings.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + finetuning = await async_client.finetunings.create( db_connection_id="string", ) assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - finetuning = await client.finetunings.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + finetuning = await async_client.finetunings.create( db_connection_id="string", alias="string", base_llm={ @@ -206,8 +200,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.finetunings.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.finetunings.with_raw_response.create( db_connection_id="string", ) @@ -217,8 +211,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.finetunings.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.finetunings.with_streaming_response.create( db_connection_id="string", ) as response: assert not response.is_closed @@ -230,15 +224,15 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - finetuning = await client.finetunings.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + finetuning = await async_client.finetunings.retrieve( "string", ) assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.finetunings.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.finetunings.with_raw_response.retrieve( "string", ) @@ -248,8 +242,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.finetunings.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.finetunings.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -261,22 +255,22 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.finetunings.with_raw_response.retrieve( + await async_client.finetunings.with_raw_response.retrieve( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - finetuning = await client.finetunings.list( + async def test_method_list(self, async_client: AsyncDataherald) -> None: + finetuning = await async_client.finetunings.list( db_connection_id="string", ) assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.finetunings.with_raw_response.list( + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.finetunings.with_raw_response.list( db_connection_id="string", ) @@ -286,8 +280,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.finetunings.with_streaming_response.list( + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.finetunings.with_streaming_response.list( db_connection_id="string", ) as response: assert not response.is_closed @@ -299,15 +293,15 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_cancel(self, client: AsyncDataherald) -> None: - finetuning = await client.finetunings.cancel( + async def test_method_cancel(self, async_client: AsyncDataherald) -> None: + finetuning = await async_client.finetunings.cancel( "string", ) assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_raw_response_cancel(self, client: AsyncDataherald) -> None: - response = await client.finetunings.with_raw_response.cancel( + async def test_raw_response_cancel(self, async_client: AsyncDataherald) -> None: + response = await async_client.finetunings.with_raw_response.cancel( "string", ) @@ -317,8 +311,8 @@ async def test_raw_response_cancel(self, client: AsyncDataherald) -> None: assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @parametrize - async def test_streaming_response_cancel(self, client: AsyncDataherald) -> None: - async with client.finetunings.with_streaming_response.cancel( + async def test_streaming_response_cancel(self, async_client: AsyncDataherald) -> None: + async with async_client.finetunings.with_streaming_response.cancel( "string", ) as response: assert not response.is_closed @@ -330,8 +324,8 @@ async def test_streaming_response_cancel(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_cancel(self, client: AsyncDataherald) -> None: + async def test_path_params_cancel(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.finetunings.with_raw_response.cancel( + await async_client.finetunings.with_raw_response.cancel( "", ) diff --git a/tests/api_resources/test_generations.py b/tests/api_resources/test_generations.py index 9687df2..7b90c0e 100644 --- a/tests/api_resources/test_generations.py +++ b/tests/api_resources/test_generations.py @@ -13,17 +13,13 @@ GenerationResponse, GenerationListResponse, ) -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import NlGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestGenerations: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -265,20 +261,18 @@ def test_path_params_sql_generation(self, client: Dataherald) -> None: class TestAsyncGenerations: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - generation = await client.generations.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.create( "string", ) assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.create( "string", ) @@ -288,8 +282,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.create( "string", ) as response: assert not response.is_closed @@ -301,22 +295,22 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_create(self, client: AsyncDataherald) -> None: + async def test_path_params_create(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.generations.with_raw_response.create( + await async_client.generations.with_raw_response.create( "", ) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - generation = await client.generations.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.retrieve( "string", ) assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.retrieve( "string", ) @@ -326,8 +320,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -339,22 +333,22 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.generations.with_raw_response.retrieve( + await async_client.generations.with_raw_response.retrieve( "", ) @parametrize - async def test_method_update(self, client: AsyncDataherald) -> None: - generation = await client.generations.update( + async def test_method_update(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.update( "string", ) assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_method_update_with_all_params(self, client: AsyncDataherald) -> None: - generation = await client.generations.update( + async def test_method_update_with_all_params(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.update( "string", generation_status="INITIALIZED", message="string", @@ -362,8 +356,8 @@ async def test_method_update_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_raw_response_update(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.update( + async def test_raw_response_update(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.update( "string", ) @@ -373,8 +367,8 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_update(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.update( + async def test_streaming_response_update(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.update( "string", ) as response: assert not response.is_closed @@ -386,20 +380,20 @@ async def test_streaming_response_update(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_update(self, client: AsyncDataherald) -> None: + async def test_path_params_update(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.generations.with_raw_response.update( + await async_client.generations.with_raw_response.update( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - generation = await client.generations.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.list() assert_matches_type(GenerationListResponse, generation, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - generation = await client.generations.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.list( ascend=True, order="string", page=0, @@ -408,8 +402,8 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non assert_matches_type(GenerationListResponse, generation, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -417,8 +411,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(GenerationListResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -428,15 +422,15 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_nl_generation(self, client: AsyncDataherald) -> None: - generation = await client.generations.nl_generation( + async def test_method_nl_generation(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.nl_generation( "string", ) assert_matches_type(NlGenerationResponse, generation, path=["response"]) @parametrize - async def test_raw_response_nl_generation(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.nl_generation( + async def test_raw_response_nl_generation(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.nl_generation( "string", ) @@ -446,8 +440,8 @@ async def test_raw_response_nl_generation(self, client: AsyncDataherald) -> None assert_matches_type(NlGenerationResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_nl_generation(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.nl_generation( + async def test_streaming_response_nl_generation(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.nl_generation( "string", ) as response: assert not response.is_closed @@ -459,23 +453,23 @@ async def test_streaming_response_nl_generation(self, client: AsyncDataherald) - assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_nl_generation(self, client: AsyncDataherald) -> None: + async def test_path_params_nl_generation(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.generations.with_raw_response.nl_generation( + await async_client.generations.with_raw_response.nl_generation( "", ) @parametrize - async def test_method_sql_generation(self, client: AsyncDataherald) -> None: - generation = await client.generations.sql_generation( + async def test_method_sql_generation(self, async_client: AsyncDataherald) -> None: + generation = await async_client.generations.sql_generation( "string", sql="string", ) assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_raw_response_sql_generation(self, client: AsyncDataherald) -> None: - response = await client.generations.with_raw_response.sql_generation( + async def test_raw_response_sql_generation(self, async_client: AsyncDataherald) -> None: + response = await async_client.generations.with_raw_response.sql_generation( "string", sql="string", ) @@ -486,8 +480,8 @@ async def test_raw_response_sql_generation(self, client: AsyncDataherald) -> Non assert_matches_type(GenerationResponse, generation, path=["response"]) @parametrize - async def test_streaming_response_sql_generation(self, client: AsyncDataherald) -> None: - async with client.generations.with_streaming_response.sql_generation( + async def test_streaming_response_sql_generation(self, async_client: AsyncDataherald) -> None: + async with async_client.generations.with_streaming_response.sql_generation( "string", sql="string", ) as response: @@ -500,9 +494,9 @@ async def test_streaming_response_sql_generation(self, client: AsyncDataherald) assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_sql_generation(self, client: AsyncDataherald) -> None: + async def test_path_params_sql_generation(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.generations.with_raw_response.sql_generation( + await async_client.generations.with_raw_response.sql_generation( "", sql="string", ) diff --git a/tests/api_resources/test_golden_sqls.py b/tests/api_resources/test_golden_sqls.py index 22d8386..17f4bb1 100644 --- a/tests/api_resources/test_golden_sqls.py +++ b/tests/api_resources/test_golden_sqls.py @@ -13,17 +13,13 @@ GoldenSqlListResponse, GoldenSqlUploadResponse, ) -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import GoldenSqlResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestGoldenSqls: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_retrieve(self, client: Dataherald) -> None: @@ -217,20 +213,18 @@ def test_streaming_response_upload(self, client: Dataherald) -> None: class TestAsyncGoldenSqls: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - golden_sql = await client.golden_sqls.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + golden_sql = await async_client.golden_sqls.retrieve( "string", ) assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.golden_sqls.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.golden_sqls.with_raw_response.retrieve( "string", ) @@ -240,8 +234,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.golden_sqls.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.golden_sqls.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -253,20 +247,20 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.golden_sqls.with_raw_response.retrieve( + await async_client.golden_sqls.with_raw_response.retrieve( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - golden_sql = await client.golden_sqls.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + golden_sql = await async_client.golden_sqls.list() assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - golden_sql = await client.golden_sqls.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + golden_sql = await async_client.golden_sqls.list( ascend=True, order="string", page=0, @@ -275,8 +269,8 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.golden_sqls.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.golden_sqls.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -284,8 +278,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.golden_sqls.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.golden_sqls.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -295,15 +289,15 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_delete(self, client: AsyncDataherald) -> None: - golden_sql = await client.golden_sqls.delete( + async def test_method_delete(self, async_client: AsyncDataherald) -> None: + golden_sql = await async_client.golden_sqls.delete( "string", ) assert_matches_type(object, golden_sql, path=["response"]) @parametrize - async def test_raw_response_delete(self, client: AsyncDataherald) -> None: - response = await client.golden_sqls.with_raw_response.delete( + async def test_raw_response_delete(self, async_client: AsyncDataherald) -> None: + response = await async_client.golden_sqls.with_raw_response.delete( "string", ) @@ -313,8 +307,8 @@ async def test_raw_response_delete(self, client: AsyncDataherald) -> None: assert_matches_type(object, golden_sql, path=["response"]) @parametrize - async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: - async with client.golden_sqls.with_streaming_response.delete( + async def test_streaming_response_delete(self, async_client: AsyncDataherald) -> None: + async with async_client.golden_sqls.with_streaming_response.delete( "string", ) as response: assert not response.is_closed @@ -326,15 +320,15 @@ async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_delete(self, client: AsyncDataherald) -> None: + async def test_path_params_delete(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.golden_sqls.with_raw_response.delete( + await async_client.golden_sqls.with_raw_response.delete( "", ) @parametrize - async def test_method_upload(self, client: AsyncDataherald) -> None: - golden_sql = await client.golden_sqls.upload( + async def test_method_upload(self, async_client: AsyncDataherald) -> None: + golden_sql = await async_client.golden_sqls.upload( body=[ { "db_connection_id": "string", @@ -356,8 +350,8 @@ async def test_method_upload(self, client: AsyncDataherald) -> None: assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) @parametrize - async def test_raw_response_upload(self, client: AsyncDataherald) -> None: - response = await client.golden_sqls.with_raw_response.upload( + async def test_raw_response_upload(self, async_client: AsyncDataherald) -> None: + response = await async_client.golden_sqls.with_raw_response.upload( body=[ { "db_connection_id": "string", @@ -383,8 +377,8 @@ async def test_raw_response_upload(self, client: AsyncDataherald) -> None: assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) @parametrize - async def test_streaming_response_upload(self, client: AsyncDataherald) -> None: - async with client.golden_sqls.with_streaming_response.upload( + async def test_streaming_response_upload(self, async_client: AsyncDataherald) -> None: + async with async_client.golden_sqls.with_streaming_response.upload( body=[ { "db_connection_id": "string", diff --git a/tests/api_resources/test_heartbeat.py b/tests/api_resources/test_heartbeat.py index 366718f..b2d99a5 100644 --- a/tests/api_resources/test_heartbeat.py +++ b/tests/api_resources/test_heartbeat.py @@ -9,16 +9,12 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestHeartbeat: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_retrieve(self, client: Dataherald) -> None: @@ -47,18 +43,16 @@ def test_streaming_response_retrieve(self, client: Dataherald) -> None: class TestAsyncHeartbeat: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - heartbeat = await client.heartbeat.retrieve() + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + heartbeat = await async_client.heartbeat.retrieve() assert_matches_type(object, heartbeat, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.heartbeat.with_raw_response.retrieve() + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.heartbeat.with_raw_response.retrieve() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -66,8 +60,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(object, heartbeat, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.heartbeat.with_streaming_response.retrieve() as response: + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.heartbeat.with_streaming_response.retrieve() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_instructions.py b/tests/api_resources/test_instructions.py index bb55143..18bb20e 100644 --- a/tests/api_resources/test_instructions.py +++ b/tests/api_resources/test_instructions.py @@ -12,17 +12,13 @@ from dataherald.types import ( InstructionListResponse, ) -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import InstructionResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestInstructions: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -64,6 +60,44 @@ def test_streaming_response_create(self, client: Dataherald) -> None: assert cast(Any, response.is_closed) is True + @parametrize + def test_method_retrieve(self, client: Dataherald) -> None: + instruction = client.instructions.retrieve( + "string", + ) + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: Dataherald) -> None: + response = client.instructions.with_raw_response.retrieve( + "string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + instruction = response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.instructions.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.instructions.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_update(self, client: Dataherald) -> None: instruction = client.instructions.update( @@ -187,20 +221,18 @@ def test_path_params_delete(self, client: Dataherald) -> None: class TestAsyncInstructions: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.create( instruction="string", ) assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.create( instruction="string", db_connection_id="string", metadata={}, @@ -208,8 +240,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.instructions.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.with_raw_response.create( instruction="string", ) @@ -219,8 +251,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.instructions.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.with_streaming_response.create( instruction="string", ) as response: assert not response.is_closed @@ -232,16 +264,54 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_update(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.update( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.retrieve( + "string", + ) + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.with_raw_response.retrieve( + "string", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + instruction = await response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = await response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await async_client.instructions.with_raw_response.retrieve( + "", + ) + + @parametrize + async def test_method_update(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.update( "string", instruction="string", ) assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_method_update_with_all_params(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.update( + async def test_method_update_with_all_params(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.update( "string", instruction="string", db_connection_id="string", @@ -250,8 +320,8 @@ async def test_method_update_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_raw_response_update(self, client: AsyncDataherald) -> None: - response = await client.instructions.with_raw_response.update( + async def test_raw_response_update(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.with_raw_response.update( "string", instruction="string", ) @@ -262,8 +332,8 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: assert_matches_type(InstructionResponse, instruction, path=["response"]) @parametrize - async def test_streaming_response_update(self, client: AsyncDataherald) -> None: - async with client.instructions.with_streaming_response.update( + async def test_streaming_response_update(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.with_streaming_response.update( "string", instruction="string", ) as response: @@ -276,23 +346,23 @@ async def test_streaming_response_update(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_update(self, client: AsyncDataherald) -> None: + async def test_path_params_update(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.instructions.with_raw_response.update( + await async_client.instructions.with_raw_response.update( "", instruction="string", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.list( + async def test_method_list(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.list( db_connection_id="string", ) assert_matches_type(InstructionListResponse, instruction, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.instructions.with_raw_response.list( + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.with_raw_response.list( db_connection_id="string", ) @@ -302,8 +372,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(InstructionListResponse, instruction, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.instructions.with_streaming_response.list( + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.with_streaming_response.list( db_connection_id="string", ) as response: assert not response.is_closed @@ -315,15 +385,15 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_delete(self, client: AsyncDataherald) -> None: - instruction = await client.instructions.delete( + async def test_method_delete(self, async_client: AsyncDataherald) -> None: + instruction = await async_client.instructions.delete( "string", ) assert_matches_type(object, instruction, path=["response"]) @parametrize - async def test_raw_response_delete(self, client: AsyncDataherald) -> None: - response = await client.instructions.with_raw_response.delete( + async def test_raw_response_delete(self, async_client: AsyncDataherald) -> None: + response = await async_client.instructions.with_raw_response.delete( "string", ) @@ -333,8 +403,8 @@ async def test_raw_response_delete(self, client: AsyncDataherald) -> None: assert_matches_type(object, instruction, path=["response"]) @parametrize - async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: - async with client.instructions.with_streaming_response.delete( + async def test_streaming_response_delete(self, async_client: AsyncDataherald) -> None: + async with async_client.instructions.with_streaming_response.delete( "string", ) as response: assert not response.is_closed @@ -346,8 +416,8 @@ async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_delete(self, client: AsyncDataherald) -> None: + async def test_path_params_delete(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.instructions.with_raw_response.delete( + await async_client.instructions.with_raw_response.delete( "", ) diff --git a/tests/api_resources/test_nl_generations.py b/tests/api_resources/test_nl_generations.py index aeceef7..d37f1e0 100644 --- a/tests/api_resources/test_nl_generations.py +++ b/tests/api_resources/test_nl_generations.py @@ -10,17 +10,13 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type from dataherald.types import NlGenerationListResponse -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import NlGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestNlGenerations: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -162,13 +158,11 @@ def test_streaming_response_list(self, client: Dataherald) -> None: class TestAsyncNlGenerations: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - nl_generation = await client.nl_generations.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.nl_generations.create( sql_generation={ "prompt": { "text": "string", @@ -179,8 +173,8 @@ async def test_method_create(self, client: AsyncDataherald) -> None: assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - nl_generation = await client.nl_generations.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.nl_generations.create( sql_generation={ "finetuning_id": "string", "evaluate": True, @@ -198,8 +192,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.nl_generations.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.nl_generations.with_raw_response.create( sql_generation={ "prompt": { "text": "string", @@ -214,8 +208,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.nl_generations.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.nl_generations.with_streaming_response.create( sql_generation={ "prompt": { "text": "string", @@ -232,15 +226,15 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - nl_generation = await client.nl_generations.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.nl_generations.retrieve( "string", ) assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.nl_generations.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.nl_generations.with_raw_response.retrieve( "string", ) @@ -250,8 +244,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.nl_generations.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.nl_generations.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -263,20 +257,20 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.nl_generations.with_raw_response.retrieve( + await async_client.nl_generations.with_raw_response.retrieve( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - nl_generation = await client.nl_generations.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.nl_generations.list() assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - nl_generation = await client.nl_generations.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + nl_generation = await async_client.nl_generations.list( ascend=True, order="string", page=0, @@ -285,8 +279,8 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.nl_generations.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.nl_generations.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -294,8 +288,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.nl_generations.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.nl_generations.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_prompts.py b/tests/api_resources/test_prompts.py index 697d0ef..924ff14 100644 --- a/tests/api_resources/test_prompts.py +++ b/tests/api_resources/test_prompts.py @@ -10,16 +10,12 @@ from dataherald import Dataherald, AsyncDataherald from tests.utils import assert_matches_type from dataherald.types import PromptResponse, PromptListResponse -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestPrompts: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -139,21 +135,19 @@ def test_streaming_response_list(self, client: Dataherald) -> None: class TestAsyncPrompts: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - prompt = await client.prompts.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + prompt = await async_client.prompts.create( db_connection_id="string", text="string", ) assert_matches_type(PromptResponse, prompt, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - prompt = await client.prompts.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + prompt = await async_client.prompts.create( db_connection_id="string", text="string", metadata={}, @@ -161,8 +155,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(PromptResponse, prompt, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.prompts.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.with_raw_response.create( db_connection_id="string", text="string", ) @@ -173,8 +167,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(PromptResponse, prompt, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.prompts.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.with_streaming_response.create( db_connection_id="string", text="string", ) as response: @@ -187,15 +181,15 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - prompt = await client.prompts.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + prompt = await async_client.prompts.retrieve( "string", ) assert_matches_type(PromptResponse, prompt, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.prompts.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.with_raw_response.retrieve( "string", ) @@ -205,8 +199,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(PromptResponse, prompt, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.prompts.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -218,20 +212,20 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.prompts.with_raw_response.retrieve( + await async_client.prompts.with_raw_response.retrieve( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - prompt = await client.prompts.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + prompt = await async_client.prompts.list() assert_matches_type(PromptListResponse, prompt, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - prompt = await client.prompts.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + prompt = await async_client.prompts.list( ascend=True, order="string", page=0, @@ -240,8 +234,8 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non assert_matches_type(PromptListResponse, prompt, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.prompts.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.prompts.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -249,8 +243,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(PromptListResponse, prompt, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.prompts.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.prompts.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_sql_generations.py b/tests/api_resources/test_sql_generations.py index aba6e06..23116fc 100644 --- a/tests/api_resources/test_sql_generations.py +++ b/tests/api_resources/test_sql_generations.py @@ -13,17 +13,13 @@ SqlGenerationListResponse, SqlGenerationExecuteResponse, ) -from dataherald._client import Dataherald, AsyncDataherald from dataherald.types.shared import SqlGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestSqlGenerations: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_create(self, client: Dataherald) -> None: @@ -201,13 +197,11 @@ def test_path_params_execute(self, client: Dataherald) -> None: class TestAsyncSqlGenerations: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.create( + async def test_method_create(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.create( prompt={ "text": "string", "db_connection_id": "string", @@ -216,8 +210,8 @@ async def test_method_create(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.create( + async def test_method_create_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.create( prompt={ "text": "string", "db_connection_id": "string", @@ -231,8 +225,8 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_create(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.with_raw_response.create( + async def test_raw_response_create(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.with_raw_response.create( prompt={ "text": "string", "db_connection_id": "string", @@ -245,8 +239,8 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_create(self, client: AsyncDataherald) -> None: - async with client.sql_generations.with_streaming_response.create( + async def test_streaming_response_create(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.with_streaming_response.create( prompt={ "text": "string", "db_connection_id": "string", @@ -261,15 +255,15 @@ async def test_streaming_response_create(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.retrieve( "string", ) assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.with_raw_response.retrieve( "string", ) @@ -279,8 +273,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.sql_generations.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -292,20 +286,20 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.sql_generations.with_raw_response.retrieve( + await async_client.sql_generations.with_raw_response.retrieve( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.list() + async def test_method_list(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.list() assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.list( ascend=True, order="string", page=0, @@ -314,8 +308,8 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.with_raw_response.list() + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -323,8 +317,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.sql_generations.with_streaming_response.list() as response: + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -334,23 +328,23 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_execute(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.execute( + async def test_method_execute(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.execute( "string", ) assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) @parametrize - async def test_method_execute_with_all_params(self, client: AsyncDataherald) -> None: - sql_generation = await client.sql_generations.execute( + async def test_method_execute_with_all_params(self, async_client: AsyncDataherald) -> None: + sql_generation = await async_client.sql_generations.execute( "string", max_rows=0, ) assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) @parametrize - async def test_raw_response_execute(self, client: AsyncDataherald) -> None: - response = await client.sql_generations.with_raw_response.execute( + async def test_raw_response_execute(self, async_client: AsyncDataherald) -> None: + response = await async_client.sql_generations.with_raw_response.execute( "string", ) @@ -360,8 +354,8 @@ async def test_raw_response_execute(self, client: AsyncDataherald) -> None: assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) @parametrize - async def test_streaming_response_execute(self, client: AsyncDataherald) -> None: - async with client.sql_generations.with_streaming_response.execute( + async def test_streaming_response_execute(self, async_client: AsyncDataherald) -> None: + async with async_client.sql_generations.with_streaming_response.execute( "string", ) as response: assert not response.is_closed @@ -373,8 +367,8 @@ async def test_streaming_response_execute(self, client: AsyncDataherald) -> None assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_execute(self, client: AsyncDataherald) -> None: + async def test_path_params_execute(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.sql_generations.with_raw_response.execute( + await async_client.sql_generations.with_raw_response.execute( "", ) diff --git a/tests/api_resources/test_table_descriptions.py b/tests/api_resources/test_table_descriptions.py index a281362..f3b9daa 100644 --- a/tests/api_resources/test_table_descriptions.py +++ b/tests/api_resources/test_table_descriptions.py @@ -14,16 +14,12 @@ TableDescriptionListResponse, TableDescriptionSyncSchemasResponse, ) -from dataherald._client import Dataherald, AsyncDataherald base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -api_key = "My API Key" class TestTableDescriptions: - strict_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize def test_method_retrieve(self, client: Dataherald) -> None: @@ -182,22 +178,14 @@ def test_streaming_response_list(self, client: Dataherald) -> None: @parametrize def test_method_sync_schemas(self, client: Dataherald) -> None: table_description = client.table_descriptions.sync_schemas( - db_connection_id="string", - ) - assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) - - @parametrize - def test_method_sync_schemas_with_all_params(self, client: Dataherald) -> None: - table_description = client.table_descriptions.sync_schemas( - db_connection_id="string", - table_names=["string", "string", "string"], + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) @parametrize def test_raw_response_sync_schemas(self, client: Dataherald) -> None: response = client.table_descriptions.with_raw_response.sync_schemas( - db_connection_id="string", + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) assert response.is_closed is True @@ -208,7 +196,7 @@ def test_raw_response_sync_schemas(self, client: Dataherald) -> None: @parametrize def test_streaming_response_sync_schemas(self, client: Dataherald) -> None: with client.table_descriptions.with_streaming_response.sync_schemas( - db_connection_id="string", + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -220,20 +208,18 @@ def test_streaming_response_sync_schemas(self, client: Dataherald) -> None: class TestAsyncTableDescriptions: - strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) - loose_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=False) - parametrize = pytest.mark.parametrize("client", [strict_client, loose_client], ids=["strict", "loose"]) + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_retrieve(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.retrieve( + async def test_method_retrieve(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.retrieve( "string", ) assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) @parametrize - async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: - response = await client.table_descriptions.with_raw_response.retrieve( + async def test_raw_response_retrieve(self, async_client: AsyncDataherald) -> None: + response = await async_client.table_descriptions.with_raw_response.retrieve( "string", ) @@ -243,8 +229,8 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) @parametrize - async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: - async with client.table_descriptions.with_streaming_response.retrieve( + async def test_streaming_response_retrieve(self, async_client: AsyncDataherald) -> None: + async with async_client.table_descriptions.with_streaming_response.retrieve( "string", ) as response: assert not response.is_closed @@ -256,22 +242,22 @@ async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> Non assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + async def test_path_params_retrieve(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.table_descriptions.with_raw_response.retrieve( + await async_client.table_descriptions.with_raw_response.retrieve( "", ) @parametrize - async def test_method_update(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.update( + async def test_method_update(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.update( "string", ) assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) @parametrize - async def test_method_update_with_all_params(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.update( + async def test_method_update_with_all_params(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.update( "string", columns=[ { @@ -309,8 +295,8 @@ async def test_method_update_with_all_params(self, client: AsyncDataherald) -> N assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) @parametrize - async def test_raw_response_update(self, client: AsyncDataherald) -> None: - response = await client.table_descriptions.with_raw_response.update( + async def test_raw_response_update(self, async_client: AsyncDataherald) -> None: + response = await async_client.table_descriptions.with_raw_response.update( "string", ) @@ -320,8 +306,8 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) @parametrize - async def test_streaming_response_update(self, client: AsyncDataherald) -> None: - async with client.table_descriptions.with_streaming_response.update( + async def test_streaming_response_update(self, async_client: AsyncDataherald) -> None: + async with async_client.table_descriptions.with_streaming_response.update( "string", ) as response: assert not response.is_closed @@ -333,30 +319,30 @@ async def test_streaming_response_update(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_path_params_update(self, client: AsyncDataherald) -> None: + async def test_path_params_update(self, async_client: AsyncDataherald) -> None: with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): - await client.table_descriptions.with_raw_response.update( + await async_client.table_descriptions.with_raw_response.update( "", ) @parametrize - async def test_method_list(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.list( + async def test_method_list(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.list( db_connection_id="string", ) assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.list( + async def test_method_list_with_all_params(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.list( db_connection_id="string", table_name="string", ) assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) @parametrize - async def test_raw_response_list(self, client: AsyncDataherald) -> None: - response = await client.table_descriptions.with_raw_response.list( + async def test_raw_response_list(self, async_client: AsyncDataherald) -> None: + response = await async_client.table_descriptions.with_raw_response.list( db_connection_id="string", ) @@ -366,8 +352,8 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) @parametrize - async def test_streaming_response_list(self, client: AsyncDataherald) -> None: - async with client.table_descriptions.with_streaming_response.list( + async def test_streaming_response_list(self, async_client: AsyncDataherald) -> None: + async with async_client.table_descriptions.with_streaming_response.list( db_connection_id="string", ) as response: assert not response.is_closed @@ -379,24 +365,16 @@ async def test_streaming_response_list(self, client: AsyncDataherald) -> None: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_sync_schemas(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.sync_schemas( - db_connection_id="string", - ) - assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) - - @parametrize - async def test_method_sync_schemas_with_all_params(self, client: AsyncDataherald) -> None: - table_description = await client.table_descriptions.sync_schemas( - db_connection_id="string", - table_names=["string", "string", "string"], + async def test_method_sync_schemas(self, async_client: AsyncDataherald) -> None: + table_description = await async_client.table_descriptions.sync_schemas( + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) @parametrize - async def test_raw_response_sync_schemas(self, client: AsyncDataherald) -> None: - response = await client.table_descriptions.with_raw_response.sync_schemas( - db_connection_id="string", + async def test_raw_response_sync_schemas(self, async_client: AsyncDataherald) -> None: + response = await async_client.table_descriptions.with_raw_response.sync_schemas( + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) assert response.is_closed is True @@ -405,9 +383,9 @@ async def test_raw_response_sync_schemas(self, client: AsyncDataherald) -> None: assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) @parametrize - async def test_streaming_response_sync_schemas(self, client: AsyncDataherald) -> None: - async with client.table_descriptions.with_streaming_response.sync_schemas( - db_connection_id="string", + async def test_streaming_response_sync_schemas(self, async_client: AsyncDataherald) -> None: + async with async_client.table_descriptions.with_streaming_response.sync_schemas( + body=[{"db_connection_id": "string"}, {"db_connection_id": "string"}, {"db_connection_id": "string"}], ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/conftest.py b/tests/conftest.py index b8ca9e9..1455033 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,17 @@ +from __future__ import annotations + +import os import asyncio import logging -from typing import Iterator +from typing import TYPE_CHECKING, Iterator, AsyncIterator import pytest +from dataherald import Dataherald, AsyncDataherald + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest + pytest.register_assert_rewrite("tests.utils") logging.getLogger("dataherald").setLevel(logging.DEBUG) @@ -14,3 +22,28 @@ def event_loop() -> Iterator[asyncio.AbstractEventLoop]: loop = asyncio.new_event_loop() yield loop loop.close() + + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + +api_key = "My API Key" + + +@pytest.fixture(scope="session") +def client(request: FixtureRequest) -> Iterator[Dataherald]: + strict = getattr(request, "param", True) + if not isinstance(strict, bool): + raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") + + with Dataherald(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: + yield client + + +@pytest.fixture(scope="session") +async def async_client(request: FixtureRequest) -> AsyncIterator[AsyncDataherald]: + strict = getattr(request, "param", True) + if not isinstance(strict, bool): + raise TypeError(f"Unexpected fixture parameter type {type(strict)}, expected {bool}") + + async with AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=strict) as client: + yield client diff --git a/tests/test_client.py b/tests/test_client.py index dac2eca..9038b17 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -19,7 +19,6 @@ from dataherald import Dataherald, AsyncDataherald, APIResponseValidationError from dataherald._client import Dataherald, AsyncDataherald from dataherald._models import BaseModel, FinalRequestOptions -from dataherald._response import APIResponse, AsyncAPIResponse from dataherald._constants import RAW_RESPONSE_HEADER from dataherald._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from dataherald._base_client import ( @@ -667,25 +666,6 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] - @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) - @pytest.mark.respx(base_url=base_url) - def test_streaming_response(self) -> None: - response = self.client.post( - "/api/database-connections", - body=dict(), - cast_to=APIResponse[bytes], - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) - - assert not cast(Any, response.is_closed) - assert _get_open_connections(self.client) == 1 - - for _ in response.iter_bytes(): - ... - - assert cast(Any, response.is_closed) - assert _get_open_connections(self.client) == 0 - @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -694,7 +674,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No with pytest.raises(APITimeoutError): self.client.post( "/api/database-connections", - body=dict(), + body=dict(alias="string", connection_uri="string"), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -709,7 +689,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non with pytest.raises(APIStatusError): self.client.post( "/api/database-connections", - body=dict(), + body=dict(alias="string", connection_uri="string"), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1339,25 +1319,6 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] - @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) - @pytest.mark.respx(base_url=base_url) - async def test_streaming_response(self) -> None: - response = await self.client.post( - "/api/database-connections", - body=dict(), - cast_to=AsyncAPIResponse[bytes], - options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, - ) - - assert not cast(Any, response.is_closed) - assert _get_open_connections(self.client) == 1 - - async for _ in response.iter_bytes(): - ... - - assert cast(Any, response.is_closed) - assert _get_open_connections(self.client) == 0 - @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -1366,7 +1327,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) with pytest.raises(APITimeoutError): await self.client.post( "/api/database-connections", - body=dict(), + body=dict(alias="string", connection_uri="string"), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1381,7 +1342,7 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) with pytest.raises(APIStatusError): await self.client.post( "/api/database-connections", - body=dict(), + body=dict(alias="string", connection_uri="string"), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) diff --git a/tests/test_utils/test_typing.py b/tests/test_utils/test_typing.py new file mode 100644 index 0000000..e330962 --- /dev/null +++ b/tests/test_utils/test_typing.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Generic, TypeVar, cast + +from dataherald._utils import extract_type_var_from_base + +_T = TypeVar("_T") +_T2 = TypeVar("_T2") +_T3 = TypeVar("_T3") + + +class BaseGeneric(Generic[_T]): + ... + + +class SubclassGeneric(BaseGeneric[_T]): + ... + + +class BaseGenericMultipleTypeArgs(Generic[_T, _T2, _T3]): + ... + + +class SubclassGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T, _T2, _T3]): + ... + + +class SubclassDifferentOrderGenericMultipleTypeArgs(BaseGenericMultipleTypeArgs[_T2, _T, _T3]): + ... + + +def test_extract_type_var() -> None: + assert ( + extract_type_var_from_base( + BaseGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_generic_subclass() -> None: + assert ( + extract_type_var_from_base( + SubclassGeneric[int], + index=0, + generic_bases=cast("tuple[type, ...]", (BaseGeneric,)), + ) + == int + ) + + +def test_extract_type_var_multiple() -> None: + typ = BaseGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_multiple() -> None: + typ = SubclassGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None) + + +def test_extract_type_var_generic_subclass_different_ordering_multiple() -> None: + typ = SubclassDifferentOrderGenericMultipleTypeArgs[int, str, None] + + generic_bases = cast("tuple[type, ...]", (BaseGenericMultipleTypeArgs,)) + assert extract_type_var_from_base(typ, index=0, generic_bases=generic_bases) == int + assert extract_type_var_from_base(typ, index=1, generic_bases=generic_bases) == str + assert extract_type_var_from_base(typ, index=2, generic_bases=generic_bases) == type(None)