diff --git a/README.md b/README.md index 28521a2..386b04f 100644 --- a/README.md +++ b/README.md @@ -191,7 +191,7 @@ if response.my_field is None: ### Accessing raw response data (e.g. headers) -The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call. +The "raw" Response object can be accessed by prefixing `.with_raw_response.` to any HTTP method call, e.g., ```py from dataherald import Dataherald @@ -206,6 +206,24 @@ print(database_connection.id) These methods return an [`APIResponse`](https://github.com/Dataherald/dataherald-python/tree/main/src/dataherald/_response.py) object. +The async client returns an [`AsyncAPIResponse`](https://github.com/Dataherald/dataherald-python/tree/main/src/dataherald/_response.py) with the same structure, the only difference being `await`able methods for reading the response content. + +#### `.with_streaming_response` + +The above interface eagerly reads the full response body when you make the request, which may not always be what you want. + +To stream the response body, use `.with_streaming_response` instead, which requires a context manager and only reads the response body once you call `.read()`, `.text()`, `.json()`, `.iter_bytes()`, `.iter_text()`, `.iter_lines()` or `.parse()`. In the async client, these are async methods. + +```python +with client.database_connections.with_streaming_response.create() as response: + print(response.headers.get("X-My-Header")) + + for line in response.iter_lines(): + print(line) +``` + +The context manager is required so that the response will reliably be closed. + ### Configuring the HTTP client You can directly override the [httpx client](https://www.python-httpx.org/api/#client) to customize it for your use case, including: diff --git a/src/dataherald/__init__.py b/src/dataherald/__init__.py index f21fc74..270671c 100644 --- a/src/dataherald/__init__.py +++ b/src/dataherald/__init__.py @@ -16,6 +16,7 @@ AsyncDataherald, ) from ._version import __title__, __version__ +from ._response import APIResponse as APIResponse, AsyncAPIResponse as AsyncAPIResponse from ._exceptions import ( APIError, ConflictError, diff --git a/src/dataherald/_base_client.py b/src/dataherald/_base_client.py index c2c2db5..2a630de 100644 --- a/src/dataherald/_base_client.py +++ b/src/dataherald/_base_client.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import json import time import uuid @@ -31,7 +30,7 @@ overload, ) from functools import lru_cache -from typing_extensions import Literal, override +from typing_extensions import Literal, override, get_origin import anyio import httpx @@ -61,18 +60,22 @@ AsyncTransport, RequestOptions, ModelBuilderProtocol, - BinaryResponseContent, ) from ._utils import is_dict, is_given, is_mapping from ._compat import model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type -from ._response import APIResponse +from ._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + extract_response_type, +) from ._constants import ( DEFAULT_LIMITS, DEFAULT_TIMEOUT, DEFAULT_MAX_RETRIES, RAW_RESPONSE_HEADER, - STREAMED_RAW_RESPONSE_HEADER, + OVERRIDE_CAST_TO_HEADER, ) from ._streaming import Stream, AsyncStream from ._exceptions import ( @@ -493,28 +496,25 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o serialized[key] = value return serialized - def _process_response( - self, - *, - cast_to: Type[ResponseT], - options: FinalRequestOptions, - response: httpx.Response, - stream: bool, - stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, - ) -> ResponseT: - api_response = APIResponse( - raw=response, - client=self, - cast_to=cast_to, - stream=stream, - stream_cls=stream_cls, - options=options, - ) + def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]: + if not is_given(options.headers): + return cast_to - if response.request.headers.get(RAW_RESPONSE_HEADER) == "true": - return cast(ResponseT, api_response) + # make a copy of the headers so we don't mutate user-input + headers = dict(options.headers) - return api_response.parse() + # we internally support defining a temporary header to override the + # default `cast_to` type for use with `.with_raw_response` and `.with_streaming_response` + # see _response.py for implementation details + override_cast_to = headers.pop(OVERRIDE_CAST_TO_HEADER, NOT_GIVEN) + if is_given(override_cast_to): + options.headers = headers + return cast(Type[ResponseT], override_cast_to) + + return cast_to + + def _should_stream_response_body(self, request: httpx.Request) -> bool: + return request.headers.get(RAW_RESPONSE_HEADER) == "stream" # type: ignore[no-any-return] def _process_response_data( self, @@ -540,12 +540,6 @@ def _process_response_data( except pydantic.ValidationError as err: raise APIResponseValidationError(response=response, body=data) from err - def _should_stream_response_body(self, *, request: httpx.Request) -> bool: - if request.headers.get(STREAMED_RAW_RESPONSE_HEADER) == "true": - return True - - return False - @property def qs(self) -> Querystring: return Querystring() @@ -610,6 +604,8 @@ def _calculate_retry_timeout( if response_headers is not None: retry_header = response_headers.get("retry-after") try: + # note: the spec indicates that this should only ever be an integer + # but if someone sends a float there's no reason for us to not respect it retry_after = float(retry_header) except Exception: retry_date_tuple = email.utils.parsedate_tz(retry_header) @@ -873,6 +869,7 @@ def _request( stream: bool, stream_cls: type[_StreamT] | None, ) -> ResponseT | _StreamT: + cast_to = self._maybe_override_cast_to(cast_to, options) self._prepare_options(options) retries = self._remaining_retries(remaining_retries, options) @@ -987,6 +984,50 @@ def _retry_request( stream_cls=stream_cls, ) + def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + origin = get_origin(cast_to) or cast_to + + if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if not issubclass(origin, APIResponse): + raise TypeError(f"API Response types must subclass {APIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + ResponseT, + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + if cast_to == httpx.Response: + return cast(ResponseT, response) + + api_response = APIResponse( + raw=response, + client=self, + cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return api_response.parse() + def _request_api_list( self, model: Type[object], @@ -1353,6 +1394,7 @@ async def _request( stream_cls: type[_AsyncStreamT] | None, remaining_retries: int | None, ) -> ResponseT | _AsyncStreamT: + cast_to = self._maybe_override_cast_to(cast_to, options) await self._prepare_options(options) retries = self._remaining_retries(remaining_retries, options) @@ -1428,7 +1470,7 @@ async def _request( log.debug("Re-raising status error") raise self._make_status_error_from_response(err.response) from None - return self._process_response( + return await self._process_response( cast_to=cast_to, options=options, response=response, @@ -1465,6 +1507,50 @@ async def _retry_request( stream_cls=stream_cls, ) + async def _process_response( + self, + *, + cast_to: Type[ResponseT], + options: FinalRequestOptions, + response: httpx.Response, + stream: bool, + stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None, + ) -> ResponseT: + origin = get_origin(cast_to) or cast_to + + if inspect.isclass(origin) and issubclass(origin, BaseAPIResponse): + if not issubclass(origin, AsyncAPIResponse): + raise TypeError(f"API Response types must subclass {AsyncAPIResponse}; Received {origin}") + + response_cls = cast("type[BaseAPIResponse[Any]]", cast_to) + return cast( + "ResponseT", + response_cls( + raw=response, + client=self, + cast_to=extract_response_type(response_cls), + stream=stream, + stream_cls=stream_cls, + options=options, + ), + ) + + if cast_to == httpx.Response: + return cast(ResponseT, response) + + api_response = AsyncAPIResponse( + raw=response, + client=self, + cast_to=cast("type[ResponseT]", cast_to), # pyright: ignore[reportUnnecessaryCast] + stream=stream, + stream_cls=stream_cls, + options=options, + ) + if bool(response.request.headers.get(RAW_RESPONSE_HEADER)): + return cast(ResponseT, api_response) + + return await api_response.parse() + def _request_api_list( self, model: Type[_T], @@ -1783,105 +1869,3 @@ def _merge_mappings( """ merged = {**obj1, **obj2} return {key: value for key, value in merged.items() if not isinstance(value, Omit)} - - -class HttpxBinaryResponseContent(BinaryResponseContent): - response: httpx.Response - - def __init__(self, response: httpx.Response) -> None: - self.response = response - - @property - @override - def content(self) -> bytes: - return self.response.content - - @property - @override - def text(self) -> str: - return self.response.text - - @property - @override - def encoding(self) -> Optional[str]: - return self.response.encoding - - @property - @override - def charset_encoding(self) -> Optional[str]: - return self.response.charset_encoding - - @override - def json(self, **kwargs: Any) -> Any: - return self.response.json(**kwargs) - - @override - def read(self) -> bytes: - return self.response.read() - - @override - def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - return self.response.iter_bytes(chunk_size) - - @override - def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: - return self.response.iter_text(chunk_size) - - @override - def iter_lines(self) -> Iterator[str]: - return self.response.iter_lines() - - @override - def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - return self.response.iter_raw(chunk_size) - - @override - def stream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - with open(file, mode="wb") as f: - for data in self.response.iter_bytes(chunk_size): - f.write(data) - - @override - def close(self) -> None: - return self.response.close() - - @override - async def aread(self) -> bytes: - return await self.response.aread() - - @override - async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - return self.response.aiter_bytes(chunk_size) - - @override - async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: - return self.response.aiter_text(chunk_size) - - @override - async def aiter_lines(self) -> AsyncIterator[str]: - return self.response.aiter_lines() - - @override - async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - return self.response.aiter_raw(chunk_size) - - @override - async def astream_to_file( - self, - file: str | os.PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - path = anyio.Path(file) - async with await path.open(mode="wb") as f: - async for data in self.response.aiter_bytes(chunk_size): - await f.write(data) - - @override - async def aclose(self) -> None: - return await self.response.aclose() diff --git a/src/dataherald/_client.py b/src/dataherald/_client.py index b6f6536..2c44dbd 100644 --- a/src/dataherald/_client.py +++ b/src/dataherald/_client.py @@ -64,6 +64,7 @@ class Dataherald(SyncAPIClient): heartbeat: resources.Heartbeat engine: resources.Engine with_raw_response: DataheraldWithRawResponse + with_streaming_response: DataheraldWithStreamedResponse # client options api_key: str @@ -153,6 +154,7 @@ def __init__( self.heartbeat = resources.Heartbeat(self) self.engine = resources.Engine(self) self.with_raw_response = DataheraldWithRawResponse(self) + self.with_streaming_response = DataheraldWithStreamedResponse(self) @property @override @@ -274,6 +276,7 @@ class AsyncDataherald(AsyncAPIClient): heartbeat: resources.AsyncHeartbeat engine: resources.AsyncEngine with_raw_response: AsyncDataheraldWithRawResponse + with_streaming_response: AsyncDataheraldWithStreamedResponse # client options api_key: str @@ -363,6 +366,7 @@ def __init__( self.heartbeat = resources.AsyncHeartbeat(self) self.engine = resources.AsyncEngine(self) self.with_raw_response = AsyncDataheraldWithRawResponse(self) + self.with_streaming_response = AsyncDataheraldWithStreamedResponse(self) @property @override @@ -501,6 +505,36 @@ def __init__(self, client: AsyncDataherald) -> None: self.engine = resources.AsyncEngineWithRawResponse(client.engine) +class DataheraldWithStreamedResponse: + def __init__(self, client: Dataherald) -> None: + self.database_connections = resources.DatabaseConnectionsWithStreamingResponse(client.database_connections) + self.finetunings = resources.FinetuningsWithStreamingResponse(client.finetunings) + self.golden_sqls = resources.GoldenSqlsWithStreamingResponse(client.golden_sqls) + self.instructions = resources.InstructionsWithStreamingResponse(client.instructions) + self.generations = resources.GenerationsWithStreamingResponse(client.generations) + self.prompts = resources.PromptsWithStreamingResponse(client.prompts) + self.sql_generations = resources.SqlGenerationsWithStreamingResponse(client.sql_generations) + self.nl_generations = resources.NlGenerationsWithStreamingResponse(client.nl_generations) + self.table_descriptions = resources.TableDescriptionsWithStreamingResponse(client.table_descriptions) + self.heartbeat = resources.HeartbeatWithStreamingResponse(client.heartbeat) + self.engine = resources.EngineWithStreamingResponse(client.engine) + + +class AsyncDataheraldWithStreamedResponse: + def __init__(self, client: AsyncDataherald) -> None: + self.database_connections = resources.AsyncDatabaseConnectionsWithStreamingResponse(client.database_connections) + self.finetunings = resources.AsyncFinetuningsWithStreamingResponse(client.finetunings) + self.golden_sqls = resources.AsyncGoldenSqlsWithStreamingResponse(client.golden_sqls) + self.instructions = resources.AsyncInstructionsWithStreamingResponse(client.instructions) + self.generations = resources.AsyncGenerationsWithStreamingResponse(client.generations) + self.prompts = resources.AsyncPromptsWithStreamingResponse(client.prompts) + self.sql_generations = resources.AsyncSqlGenerationsWithStreamingResponse(client.sql_generations) + self.nl_generations = resources.AsyncNlGenerationsWithStreamingResponse(client.nl_generations) + self.table_descriptions = resources.AsyncTableDescriptionsWithStreamingResponse(client.table_descriptions) + self.heartbeat = resources.AsyncHeartbeatWithStreamingResponse(client.heartbeat) + self.engine = resources.AsyncEngineWithStreamingResponse(client.engine) + + Client = Dataherald AsyncClient = AsyncDataherald diff --git a/src/dataherald/_constants.py b/src/dataherald/_constants.py index 39b46eb..76b21f0 100644 --- a/src/dataherald/_constants.py +++ b/src/dataherald/_constants.py @@ -3,7 +3,7 @@ import httpx RAW_RESPONSE_HEADER = "X-Stainless-Raw-Response" -STREAMED_RAW_RESPONSE_HEADER = "X-Stainless-Streamed-Raw-Response" +OVERRIDE_CAST_TO_HEADER = "____stainless_override_cast_to" # default timeout is 1 minute DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0) diff --git a/src/dataherald/_response.py b/src/dataherald/_response.py index a947273..9ef757c 100644 --- a/src/dataherald/_response.py +++ b/src/dataherald/_response.py @@ -1,19 +1,32 @@ from __future__ import annotations +import os import inspect import logging import datetime import functools -from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + Union, + Generic, + TypeVar, + Callable, + Iterator, + AsyncIterator, + cast, +) from typing_extensions import Awaitable, ParamSpec, override, get_origin +import anyio import httpx -from ._types import NoneType, BinaryResponseContent +from ._types import NoneType from ._utils import is_given, extract_type_var_from_base from ._models import BaseModel, is_basemodel -from ._constants import RAW_RESPONSE_HEADER -from ._exceptions import APIResponseValidationError +from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER +from ._exceptions import DataheraldError, APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions @@ -22,15 +35,17 @@ P = ParamSpec("P") R = TypeVar("R") +_APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") +_AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]") log: logging.Logger = logging.getLogger(__name__) -class APIResponse(Generic[R]): +class BaseAPIResponse(Generic[R]): _cast_to: type[R] _client: BaseClient[Any, Any] _parsed: R | None - _stream: bool + _is_sse_stream: bool _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None _options: FinalRequestOptions @@ -49,28 +64,18 @@ def __init__( self._cast_to = cast_to self._client = client self._parsed = None - self._stream = stream + self._is_sse_stream = stream self._stream_cls = stream_cls self._options = options self.http_response = raw - def parse(self) -> R: - if self._parsed is not None: - return self._parsed - - parsed = self._parse() - if is_given(self._options.post_parser): - parsed = self._options.post_parser(parsed) - - self._parsed = parsed - return parsed - @property def headers(self) -> httpx.Headers: return self.http_response.headers @property def http_request(self) -> httpx.Request: + """Returns the httpx Request instance associated with the current response.""" return self.http_response.request @property @@ -79,20 +84,13 @@ def status_code(self) -> int: @property def url(self) -> httpx.URL: + """Returns the URL for which the request was made.""" return self.http_response.url @property def method(self) -> str: return self.http_request.method - @property - def content(self) -> bytes: - return self.http_response.content - - @property - def text(self) -> str: - return self.http_response.text - @property def http_version(self) -> str: return self.http_response.http_version @@ -102,13 +100,29 @@ def elapsed(self) -> datetime.timedelta: """The time taken for the complete request/response cycle to complete.""" return self.http_response.elapsed + @property + def is_closed(self) -> bool: + """Whether or not the response body has been closed. + + If this is False then there is response data that has not been read yet. + You must either fully consume the response body or call `.close()` + before discarding the response to prevent resource leaks. + """ + return self.http_response.is_closed + + @override + def __repr__(self) -> str: + return ( + f"<{self.__class__.__name__} [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_to}>" + ) + def _parse(self) -> R: - if self._stream: + if self._is_sse_stream: if self._stream_cls: return cast( R, self._stream_cls( - cast_to=_extract_stream_chunk_type(self._stream_cls), + cast_to=extract_stream_chunk_type(self._stream_cls), response=self.http_response, client=cast(Any, self._client), ), @@ -135,10 +149,10 @@ def _parse(self) -> R: if cast_to == str: return cast(R, response.text) - origin = get_origin(cast_to) or cast_to + if cast_to == bytes: + return cast(R, response.content) - if inspect.isclass(origin) and issubclass(origin, BinaryResponseContent): - return cast(R, cast_to(response)) # type: ignore + origin = get_origin(cast_to) or cast_to if origin == APIResponse: raise RuntimeError("Unexpected state - cast_to is `APIResponse`") @@ -208,9 +222,227 @@ def _parse(self) -> R: response=response, ) - @override - def __repr__(self) -> str: - return f"" + +class APIResponse(BaseAPIResponse[R]): + def parse(self) -> R: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + """ + if self._parsed is not None: + return self._parsed + + if not self._is_sse_stream: + self.read() + + parsed = self._parse() + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed = parsed + return parsed + + def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return self.http_response.read() + except httpx.StreamConsumed as exc: + # The default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message. + raise StreamAlreadyConsumed() from exc + + def text(self) -> str: + """Read and decode the response content into a string.""" + self.read() + return self.http_response.text + + def json(self) -> object: + """Read and decode the JSON response content.""" + self.read() + return self.http_response.json() + + def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + self.http_response.close() + + def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + for chunk in self.http_response.iter_bytes(chunk_size): + yield chunk + + def iter_text(self, chunk_size: int | None = None) -> Iterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + for chunk in self.http_response.iter_text(chunk_size): + yield chunk + + def iter_lines(self) -> Iterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + for chunk in self.http_response.iter_lines(): + yield chunk + + +class AsyncAPIResponse(BaseAPIResponse[R]): + async def parse(self) -> R: + """Returns the rich python representation of this response's data. + + For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + """ + if self._parsed is not None: + return self._parsed + + if not self._is_sse_stream: + await self.read() + + parsed = self._parse() + if is_given(self._options.post_parser): + parsed = self._options.post_parser(parsed) + + self._parsed = parsed + return parsed + + async def read(self) -> bytes: + """Read and return the binary response content.""" + try: + return await self.http_response.aread() + except httpx.StreamConsumed as exc: + # the default error raised by httpx isn't very + # helpful in our case so we re-raise it with + # a different error message + raise StreamAlreadyConsumed() from exc + + async def text(self) -> str: + """Read and decode the response content into a string.""" + await self.read() + return self.http_response.text + + async def json(self) -> object: + """Read and decode the JSON response content.""" + await self.read() + return self.http_response.json() + + async def close(self) -> None: + """Close the response and release the connection. + + Automatically called if the response body is read to completion. + """ + await self.http_response.aclose() + + async def iter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]: + """ + A byte-iterator over the decoded response content. + + This automatically handles gzip, deflate and brotli encoded responses. + """ + async for chunk in self.http_response.aiter_bytes(chunk_size): + yield chunk + + async def iter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]: + """A str-iterator over the decoded response content + that handles both gzip, deflate, etc but also detects the content's + string encoding. + """ + async for chunk in self.http_response.aiter_text(chunk_size): + yield chunk + + async def iter_lines(self) -> AsyncIterator[str]: + """Like `iter_text()` but will only yield chunks for each line""" + async for chunk in self.http_response.aiter_lines(): + yield chunk + + +class BinaryAPIResponse(APIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + with open(file, mode="wb") as f: + for data in self.iter_bytes(): + f.write(data) + + +class AsyncBinaryAPIResponse(AsyncAPIResponse[bytes]): + """Subclass of APIResponse providing helpers for dealing with binary data. + + Note: If you want to stream the response data instead of eagerly reading it + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + + async def write_to_file( + self, + file: str | os.PathLike[str], + ) -> None: + """Write the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + + Note: if you want to stream the data to the file instead of writing + all at once then you should use `.with_streaming_response` when making + the API request, e.g. `.with_streaming_response.get_binary_response()` + """ + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.iter_bytes(): + await f.write(data) + + +class StreamedBinaryAPIResponse(APIResponse[bytes]): + def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + """Streams the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + """ + with open(file, mode="wb") as f: + for data in self.iter_bytes(chunk_size): + f.write(data) + + +class AsyncStreamedBinaryAPIResponse(AsyncAPIResponse[bytes]): + async def stream_to_file( + self, + file: str | os.PathLike[str], + *, + chunk_size: int | None = None, + ) -> None: + """Streams the output to the given file. + + Accepts a filename or any path-like object, e.g. pathlib.Path + """ + path = anyio.Path(file) + async with await path.open(mode="wb") as f: + async for data in self.iter_bytes(chunk_size): + await f.write(data) class MissingStreamClassError(TypeError): @@ -220,14 +452,176 @@ def __init__(self) -> None: ) -def _extract_stream_chunk_type(stream_cls: type) -> type: - from ._base_client import Stream, AsyncStream +class StreamAlreadyConsumed(DataheraldError): + """ + Attempted to read or stream content, but the content has already + been streamed. - return extract_type_var_from_base( - stream_cls, - index=0, - generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), - ) + This can happen if you use a method like `.iter_lines()` and then attempt + to read th entire response body afterwards, e.g. + + ```py + response = await client.post(...) + async for line in response.iter_lines(): + ... # do something with `line` + + content = await response.read() + # ^ error + ``` + + If you want this behaviour you'll need to either manually accumulate the response + content or call `await response.read()` before iterating over the stream. + """ + + def __init__(self) -> None: + message = ( + "Attempted to read or stream some content, but the content has " + "already been streamed. " + "This could be due to attempting to stream the response " + "content more than once." + "\n\n" + "You can fix this by manually accumulating the response content while streaming " + "or by calling `.read()` before starting to stream." + ) + super().__init__(message) + + +class ResponseContextManager(Generic[_APIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, request_func: Callable[[], _APIResponseT]) -> None: + self._request_func = request_func + self.__response: _APIResponseT | None = None + + def __enter__(self) -> _APIResponseT: + self.__response = self._request_func() + return self.__response + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__response is not None: + self.__response.close() + + +class AsyncResponseContextManager(Generic[_AsyncAPIResponseT]): + """Context manager for ensuring that a request is not made + until it is entered and that the response will always be closed + when the context manager exits + """ + + def __init__(self, api_request: Awaitable[_AsyncAPIResponseT]) -> None: + self._api_request = api_request + self.__response: _AsyncAPIResponseT | None = None + + async def __aenter__(self) -> _AsyncAPIResponseT: + self.__response = await self._api_request + return self.__response + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + if self.__response is not None: + await self.__response.close() + + +def to_streamed_response_wrapper(func: Callable[P, R]) -> Callable[P, ResponseContextManager[APIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[APIResponse[R]]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], APIResponse[R]], make_request)) + + return wrapped + + +def async_to_streamed_response_wrapper( + func: Callable[P, Awaitable[R]], +) -> Callable[P, AsyncResponseContextManager[AsyncAPIResponse[R]]]: + """Higher order function that takes one of our bound API methods and wraps it + to support streaming and returning the raw `APIResponse` object directly. + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[AsyncAPIResponse[R]]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[AsyncAPIResponse[R]], make_request)) + + return wrapped + + +def to_custom_streamed_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, ResponseContextManager[_APIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> ResponseContextManager[_APIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = functools.partial(func, *args, **kwargs) + + return ResponseContextManager(cast(Callable[[], _APIResponseT], make_request)) + + return wrapped + + +def async_to_custom_streamed_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, AsyncResponseContextManager[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support streaming and returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncResponseContextManager[_AsyncAPIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "stream" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + make_request = func(*args, **kwargs) + + return AsyncResponseContextManager(cast(Awaitable[_AsyncAPIResponseT], make_request)) + + return wrapped def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]]: @@ -238,7 +632,7 @@ def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, APIResponse[R]] @functools.wraps(func) def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} - extra_headers[RAW_RESPONSE_HEADER] = "true" + extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers @@ -247,18 +641,102 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: return wrapped -def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[APIResponse[R]]]: +def async_to_raw_response_wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[AsyncAPIResponse[R]]]: """Higher order function that takes one of our bound API methods and wraps it to support returning the raw `APIResponse` object directly. """ @functools.wraps(func) - async def wrapped(*args: P.args, **kwargs: P.kwargs) -> APIResponse[R]: + async def wrapped(*args: P.args, **kwargs: P.kwargs) -> AsyncAPIResponse[R]: extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} - extra_headers[RAW_RESPONSE_HEADER] = "true" + extra_headers[RAW_RESPONSE_HEADER] = "raw" kwargs["extra_headers"] = extra_headers - return cast(APIResponse[R], await func(*args, **kwargs)) + return cast(AsyncAPIResponse[R], await func(*args, **kwargs)) return wrapped + + +def to_custom_raw_response_wrapper( + func: Callable[P, object], + response_cls: type[_APIResponseT], +) -> Callable[P, _APIResponseT]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> _APIResponseT: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(_APIResponseT, func(*args, **kwargs)) + + return wrapped + + +def async_to_custom_raw_response_wrapper( + func: Callable[P, Awaitable[object]], + response_cls: type[_AsyncAPIResponseT], +) -> Callable[P, Awaitable[_AsyncAPIResponseT]]: + """Higher order function that takes one of our bound API methods and an `APIResponse` class + and wraps the method to support returning the given response class directly. + + Note: the given `response_cls` *must* be concrete, e.g. `class BinaryAPIResponse(APIResponse[bytes])` + """ + + @functools.wraps(func) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: + extra_headers = {**(cast(Any, kwargs.get("extra_headers")) or {})} + extra_headers[RAW_RESPONSE_HEADER] = "raw" + extra_headers[OVERRIDE_CAST_TO_HEADER] = response_cls + + kwargs["extra_headers"] = extra_headers + + return cast(Awaitable[_AsyncAPIResponseT], func(*args, **kwargs)) + + return wrapped + + +def extract_stream_chunk_type(stream_cls: type) -> type: + """Given a type like `Stream[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(Stream[bytes]): + ... + + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + ) + + +def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: + """Given a type like `APIResponse[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyResponse(APIResponse[bytes]): + ... + + extract_response_type(MyResponse) -> bytes + ``` + """ + return extract_type_var_from_base( + typ, + generic_bases=cast("tuple[type, ...]", (BaseAPIResponse, APIResponse, AsyncAPIResponse)), + index=0, + ) diff --git a/src/dataherald/_types.py b/src/dataherald/_types.py index eb2e4b5..13aa351 100644 --- a/src/dataherald/_types.py +++ b/src/dataherald/_types.py @@ -1,7 +1,6 @@ from __future__ import annotations from os import PathLike -from abc import ABC, abstractmethod from typing import ( IO, TYPE_CHECKING, @@ -14,10 +13,8 @@ Mapping, TypeVar, Callable, - Iterator, Optional, Sequence, - AsyncIterator, ) from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable @@ -27,6 +24,7 @@ if TYPE_CHECKING: from ._models import BaseModel + from ._response import APIResponse, AsyncAPIResponse Transport = BaseTransport AsyncTransport = AsyncBaseTransport @@ -37,162 +35,6 @@ _T = TypeVar("_T") -class BinaryResponseContent(ABC): - @abstractmethod - def __init__( - self, - response: Any, - ) -> None: - ... - - @property - @abstractmethod - def content(self) -> bytes: - pass - - @property - @abstractmethod - def text(self) -> str: - pass - - @property - @abstractmethod - def encoding(self) -> Optional[str]: - """ - Return an encoding to use for decoding the byte content into text. - The priority for determining this is given by... - - * `.encoding = <>` has been set explicitly. - * The encoding as specified by the charset parameter in the Content-Type header. - * The encoding as determined by `default_encoding`, which may either be - a string like "utf-8" indicating the encoding to use, or may be a callable - which enables charset autodetection. - """ - pass - - @property - @abstractmethod - def charset_encoding(self) -> Optional[str]: - """ - Return the encoding, as specified by the Content-Type header. - """ - pass - - @abstractmethod - def json(self, **kwargs: Any) -> Any: - pass - - @abstractmethod - def read(self) -> bytes: - """ - Read and return the response content. - """ - pass - - @abstractmethod - def iter_bytes(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - """ - A byte-iterator over the decoded response content. - This allows us to handle gzip, deflate, and brotli encoded responses. - """ - pass - - @abstractmethod - def iter_text(self, chunk_size: Optional[int] = None) -> Iterator[str]: - """ - A str-iterator over the decoded response content - that handles both gzip, deflate, etc but also detects the content's - string encoding. - """ - pass - - @abstractmethod - def iter_lines(self) -> Iterator[str]: - pass - - @abstractmethod - def iter_raw(self, chunk_size: Optional[int] = None) -> Iterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - pass - - @abstractmethod - def stream_to_file( - self, - file: str | PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - """ - Stream the output to the given file. - """ - pass - - @abstractmethod - def close(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - pass - - @abstractmethod - async def aread(self) -> bytes: - """ - Read and return the response content. - """ - pass - - @abstractmethod - async def aiter_bytes(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - """ - A byte-iterator over the decoded response content. - This allows us to handle gzip, deflate, and brotli encoded responses. - """ - pass - - @abstractmethod - async def aiter_text(self, chunk_size: Optional[int] = None) -> AsyncIterator[str]: - """ - A str-iterator over the decoded response content - that handles both gzip, deflate, etc but also detects the content's - string encoding. - """ - pass - - @abstractmethod - async def aiter_lines(self) -> AsyncIterator[str]: - pass - - @abstractmethod - async def aiter_raw(self, chunk_size: Optional[int] = None) -> AsyncIterator[bytes]: - """ - A byte-iterator over the raw response content. - """ - pass - - @abstractmethod - async def astream_to_file( - self, - file: str | PathLike[str], - *, - chunk_size: int | None = None, - ) -> None: - """ - Stream the output to the given file. - """ - pass - - @abstractmethod - async def aclose(self) -> None: - """ - Close the response and release the connection. - Automatically called if the response body is read to completion. - """ - pass - - # Approximates httpx internal ProxiesTypes and RequestFiles types # while adding support for `PathLike` instances ProxiesDict = Dict["str | URL", Union[None, str, URL, Proxy]] @@ -343,7 +185,8 @@ def get(self, __key: str) -> str | None: Dict[str, Any], Response, ModelBuilderProtocol, - BinaryResponseContent, + "APIResponse[Any]", + "AsyncAPIResponse[Any]", ], ) @@ -359,6 +202,7 @@ def get(self, __key: str) -> str | None: @runtime_checkable class InheritsGeneric(Protocol): """Represents a type that has inherited from `Generic` + The `__orig_bases__` property can be used to determine the resolved type variable for a given base class. """ diff --git a/src/dataherald/resources/__init__.py b/src/dataherald/resources/__init__.py index 3bf9bea..60ede1b 100644 --- a/src/dataherald/resources/__init__.py +++ b/src/dataherald/resources/__init__.py @@ -1,35 +1,92 @@ # File generated from our OpenAPI spec by Stainless. -from .engine import Engine, AsyncEngine, EngineWithRawResponse, AsyncEngineWithRawResponse -from .prompts import Prompts, AsyncPrompts, PromptsWithRawResponse, AsyncPromptsWithRawResponse -from .heartbeat import Heartbeat, AsyncHeartbeat, HeartbeatWithRawResponse, AsyncHeartbeatWithRawResponse -from .finetunings import Finetunings, AsyncFinetunings, FinetuningsWithRawResponse, AsyncFinetuningsWithRawResponse -from .generations import Generations, AsyncGenerations, GenerationsWithRawResponse, AsyncGenerationsWithRawResponse -from .golden_sqls import GoldenSqls, AsyncGoldenSqls, GoldenSqlsWithRawResponse, AsyncGoldenSqlsWithRawResponse -from .instructions import Instructions, AsyncInstructions, InstructionsWithRawResponse, AsyncInstructionsWithRawResponse +from .engine import ( + Engine, + AsyncEngine, + EngineWithRawResponse, + AsyncEngineWithRawResponse, + EngineWithStreamingResponse, + AsyncEngineWithStreamingResponse, +) +from .prompts import ( + Prompts, + AsyncPrompts, + PromptsWithRawResponse, + AsyncPromptsWithRawResponse, + PromptsWithStreamingResponse, + AsyncPromptsWithStreamingResponse, +) +from .heartbeat import ( + Heartbeat, + AsyncHeartbeat, + HeartbeatWithRawResponse, + AsyncHeartbeatWithRawResponse, + HeartbeatWithStreamingResponse, + AsyncHeartbeatWithStreamingResponse, +) +from .finetunings import ( + Finetunings, + AsyncFinetunings, + FinetuningsWithRawResponse, + AsyncFinetuningsWithRawResponse, + FinetuningsWithStreamingResponse, + AsyncFinetuningsWithStreamingResponse, +) +from .generations import ( + Generations, + AsyncGenerations, + GenerationsWithRawResponse, + AsyncGenerationsWithRawResponse, + GenerationsWithStreamingResponse, + AsyncGenerationsWithStreamingResponse, +) +from .golden_sqls import ( + GoldenSqls, + AsyncGoldenSqls, + GoldenSqlsWithRawResponse, + AsyncGoldenSqlsWithRawResponse, + GoldenSqlsWithStreamingResponse, + AsyncGoldenSqlsWithStreamingResponse, +) +from .instructions import ( + Instructions, + AsyncInstructions, + InstructionsWithRawResponse, + AsyncInstructionsWithRawResponse, + InstructionsWithStreamingResponse, + AsyncInstructionsWithStreamingResponse, +) from .nl_generations import ( NlGenerations, AsyncNlGenerations, NlGenerationsWithRawResponse, AsyncNlGenerationsWithRawResponse, + NlGenerationsWithStreamingResponse, + AsyncNlGenerationsWithStreamingResponse, ) from .sql_generations import ( SqlGenerations, AsyncSqlGenerations, SqlGenerationsWithRawResponse, AsyncSqlGenerationsWithRawResponse, + SqlGenerationsWithStreamingResponse, + AsyncSqlGenerationsWithStreamingResponse, ) from .table_descriptions import ( TableDescriptions, AsyncTableDescriptions, TableDescriptionsWithRawResponse, AsyncTableDescriptionsWithRawResponse, + TableDescriptionsWithStreamingResponse, + AsyncTableDescriptionsWithStreamingResponse, ) from .database_connections import ( DatabaseConnections, AsyncDatabaseConnections, DatabaseConnectionsWithRawResponse, AsyncDatabaseConnectionsWithRawResponse, + DatabaseConnectionsWithStreamingResponse, + AsyncDatabaseConnectionsWithStreamingResponse, ) __all__ = [ @@ -37,44 +94,66 @@ "AsyncDatabaseConnections", "DatabaseConnectionsWithRawResponse", "AsyncDatabaseConnectionsWithRawResponse", + "DatabaseConnectionsWithStreamingResponse", + "AsyncDatabaseConnectionsWithStreamingResponse", "Finetunings", "AsyncFinetunings", "FinetuningsWithRawResponse", "AsyncFinetuningsWithRawResponse", + "FinetuningsWithStreamingResponse", + "AsyncFinetuningsWithStreamingResponse", "GoldenSqls", "AsyncGoldenSqls", "GoldenSqlsWithRawResponse", "AsyncGoldenSqlsWithRawResponse", + "GoldenSqlsWithStreamingResponse", + "AsyncGoldenSqlsWithStreamingResponse", "Instructions", "AsyncInstructions", "InstructionsWithRawResponse", "AsyncInstructionsWithRawResponse", + "InstructionsWithStreamingResponse", + "AsyncInstructionsWithStreamingResponse", "Generations", "AsyncGenerations", "GenerationsWithRawResponse", "AsyncGenerationsWithRawResponse", + "GenerationsWithStreamingResponse", + "AsyncGenerationsWithStreamingResponse", "Prompts", "AsyncPrompts", "PromptsWithRawResponse", "AsyncPromptsWithRawResponse", + "PromptsWithStreamingResponse", + "AsyncPromptsWithStreamingResponse", "SqlGenerations", "AsyncSqlGenerations", "SqlGenerationsWithRawResponse", "AsyncSqlGenerationsWithRawResponse", + "SqlGenerationsWithStreamingResponse", + "AsyncSqlGenerationsWithStreamingResponse", "NlGenerations", "AsyncNlGenerations", "NlGenerationsWithRawResponse", "AsyncNlGenerationsWithRawResponse", + "NlGenerationsWithStreamingResponse", + "AsyncNlGenerationsWithStreamingResponse", "TableDescriptions", "AsyncTableDescriptions", "TableDescriptionsWithRawResponse", "AsyncTableDescriptionsWithRawResponse", + "TableDescriptionsWithStreamingResponse", + "AsyncTableDescriptionsWithStreamingResponse", "Heartbeat", "AsyncHeartbeat", "HeartbeatWithRawResponse", "AsyncHeartbeatWithRawResponse", + "HeartbeatWithStreamingResponse", + "AsyncHeartbeatWithStreamingResponse", "Engine", "AsyncEngine", "EngineWithRawResponse", "AsyncEngineWithRawResponse", + "EngineWithStreamingResponse", + "AsyncEngineWithStreamingResponse", ] diff --git a/src/dataherald/resources/database_connections/__init__.py b/src/dataherald/resources/database_connections/__init__.py index 17de552..cb4760b 100644 --- a/src/dataherald/resources/database_connections/__init__.py +++ b/src/dataherald/resources/database_connections/__init__.py @@ -1,11 +1,20 @@ # File generated from our OpenAPI spec by Stainless. -from .drivers import Drivers, AsyncDrivers, DriversWithRawResponse, AsyncDriversWithRawResponse +from .drivers import ( + Drivers, + AsyncDrivers, + DriversWithRawResponse, + AsyncDriversWithRawResponse, + DriversWithStreamingResponse, + AsyncDriversWithStreamingResponse, +) from .database_connections import ( DatabaseConnections, AsyncDatabaseConnections, DatabaseConnectionsWithRawResponse, AsyncDatabaseConnectionsWithRawResponse, + DatabaseConnectionsWithStreamingResponse, + AsyncDatabaseConnectionsWithStreamingResponse, ) __all__ = [ @@ -13,8 +22,12 @@ "AsyncDrivers", "DriversWithRawResponse", "AsyncDriversWithRawResponse", + "DriversWithStreamingResponse", + "AsyncDriversWithStreamingResponse", "DatabaseConnections", "AsyncDatabaseConnections", "DatabaseConnectionsWithRawResponse", "AsyncDatabaseConnectionsWithRawResponse", + "DatabaseConnectionsWithStreamingResponse", + "AsyncDatabaseConnectionsWithStreamingResponse", ] diff --git a/src/dataherald/resources/database_connections/database_connections.py b/src/dataherald/resources/database_connections/database_connections.py index 1931e73..61b2c37 100644 --- a/src/dataherald/resources/database_connections/database_connections.py +++ b/src/dataherald/resources/database_connections/database_connections.py @@ -12,12 +12,24 @@ database_connection_create_params, database_connection_update_params, ) -from .drivers import Drivers, AsyncDrivers, DriversWithRawResponse, AsyncDriversWithRawResponse +from .drivers import ( + Drivers, + AsyncDrivers, + DriversWithRawResponse, + AsyncDriversWithRawResponse, + DriversWithStreamingResponse, + AsyncDriversWithStreamingResponse, +) from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -34,6 +46,10 @@ def drivers(self) -> Drivers: def with_raw_response(self) -> DatabaseConnectionsWithRawResponse: return DatabaseConnectionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> DatabaseConnectionsWithStreamingResponse: + return DatabaseConnectionsWithStreamingResponse(self) + def create( self, *, @@ -42,7 +58,7 @@ def create( credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, - ssh_settings: database_connection_create_params.SshSettings | NotGiven = NOT_GIVEN, + ssh_settings: database_connection_create_params.SSHSettings | NotGiven = NOT_GIVEN, use_ssh: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -106,6 +122,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/database-connections/{id}", options=make_request_options( @@ -123,7 +141,7 @@ def update( credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, - ssh_settings: database_connection_update_params.SshSettings | NotGiven = NOT_GIVEN, + ssh_settings: database_connection_update_params.SSHSettings | NotGiven = NOT_GIVEN, use_ssh: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -144,6 +162,8 @@ def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._put( f"/api/database-connections/{id}", body=maybe_transform( @@ -193,6 +213,10 @@ def drivers(self) -> AsyncDrivers: def with_raw_response(self) -> AsyncDatabaseConnectionsWithRawResponse: return AsyncDatabaseConnectionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncDatabaseConnectionsWithStreamingResponse: + return AsyncDatabaseConnectionsWithStreamingResponse(self) + async def create( self, *, @@ -201,7 +225,7 @@ async def create( credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, - ssh_settings: database_connection_create_params.SshSettings | NotGiven = NOT_GIVEN, + ssh_settings: database_connection_create_params.SSHSettings | NotGiven = NOT_GIVEN, use_ssh: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -265,6 +289,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/database-connections/{id}", options=make_request_options( @@ -282,7 +308,7 @@ async def update( credential_file_content: Union[object, str] | NotGiven = NOT_GIVEN, llm_api_key: str | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, - ssh_settings: database_connection_update_params.SshSettings | NotGiven = NOT_GIVEN, + ssh_settings: database_connection_update_params.SSHSettings | NotGiven = NOT_GIVEN, use_ssh: bool | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -303,6 +329,8 @@ async def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._put( f"/api/database-connections/{id}", body=maybe_transform( @@ -377,3 +405,39 @@ def __init__(self, database_connections: AsyncDatabaseConnections) -> None: self.list = async_to_raw_response_wrapper( database_connections.list, ) + + +class DatabaseConnectionsWithStreamingResponse: + def __init__(self, database_connections: DatabaseConnections) -> None: + self.drivers = DriversWithStreamingResponse(database_connections.drivers) + + self.create = to_streamed_response_wrapper( + database_connections.create, + ) + self.retrieve = to_streamed_response_wrapper( + database_connections.retrieve, + ) + self.update = to_streamed_response_wrapper( + database_connections.update, + ) + self.list = to_streamed_response_wrapper( + database_connections.list, + ) + + +class AsyncDatabaseConnectionsWithStreamingResponse: + def __init__(self, database_connections: AsyncDatabaseConnections) -> None: + self.drivers = AsyncDriversWithStreamingResponse(database_connections.drivers) + + self.create = async_to_streamed_response_wrapper( + database_connections.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + database_connections.retrieve, + ) + self.update = async_to_streamed_response_wrapper( + database_connections.update, + ) + self.list = async_to_streamed_response_wrapper( + database_connections.list, + ) diff --git a/src/dataherald/resources/database_connections/drivers.py b/src/dataherald/resources/database_connections/drivers.py index 279c2c9..c9e5e1d 100644 --- a/src/dataherald/resources/database_connections/drivers.py +++ b/src/dataherald/resources/database_connections/drivers.py @@ -7,7 +7,12 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -21,6 +26,10 @@ class Drivers(SyncAPIResource): def with_raw_response(self) -> DriversWithRawResponse: return DriversWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> DriversWithStreamingResponse: + return DriversWithStreamingResponse(self) + def list( self, *, @@ -46,6 +55,10 @@ class AsyncDrivers(AsyncAPIResource): def with_raw_response(self) -> AsyncDriversWithRawResponse: return AsyncDriversWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncDriversWithStreamingResponse: + return AsyncDriversWithStreamingResponse(self) + async def list( self, *, @@ -78,3 +91,17 @@ def __init__(self, drivers: AsyncDrivers) -> None: self.list = async_to_raw_response_wrapper( drivers.list, ) + + +class DriversWithStreamingResponse: + def __init__(self, drivers: Drivers) -> None: + self.list = to_streamed_response_wrapper( + drivers.list, + ) + + +class AsyncDriversWithStreamingResponse: + def __init__(self, drivers: AsyncDrivers) -> None: + self.list = async_to_streamed_response_wrapper( + drivers.list, + ) diff --git a/src/dataherald/resources/engine.py b/src/dataherald/resources/engine.py index 9e99778..443134d 100644 --- a/src/dataherald/resources/engine.py +++ b/src/dataherald/resources/engine.py @@ -7,7 +7,12 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -20,6 +25,10 @@ class Engine(SyncAPIResource): def with_raw_response(self) -> EngineWithRawResponse: return EngineWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> EngineWithStreamingResponse: + return EngineWithStreamingResponse(self) + def heartbeat( self, *, @@ -45,6 +54,10 @@ class AsyncEngine(AsyncAPIResource): def with_raw_response(self) -> AsyncEngineWithRawResponse: return AsyncEngineWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncEngineWithStreamingResponse: + return AsyncEngineWithStreamingResponse(self) + async def heartbeat( self, *, @@ -77,3 +90,17 @@ def __init__(self, engine: AsyncEngine) -> None: self.heartbeat = async_to_raw_response_wrapper( engine.heartbeat, ) + + +class EngineWithStreamingResponse: + def __init__(self, engine: Engine) -> None: + self.heartbeat = to_streamed_response_wrapper( + engine.heartbeat, + ) + + +class AsyncEngineWithStreamingResponse: + def __init__(self, engine: AsyncEngine) -> None: + self.heartbeat = async_to_streamed_response_wrapper( + engine.heartbeat, + ) diff --git a/src/dataherald/resources/finetunings.py b/src/dataherald/resources/finetunings.py index e969c36..5a51b96 100644 --- a/src/dataherald/resources/finetunings.py +++ b/src/dataherald/resources/finetunings.py @@ -16,7 +16,12 @@ from .._utils import maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -29,6 +34,10 @@ class Finetunings(SyncAPIResource): def with_raw_response(self) -> FinetuningsWithRawResponse: return FinetuningsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> FinetuningsWithStreamingResponse: + return FinetuningsWithStreamingResponse(self) + def create( self, *, @@ -97,6 +106,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/finetunings/{id}", options=make_request_options( @@ -165,6 +176,8 @@ def cancel( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/finetunings/{id}/cancel", options=make_request_options( @@ -179,6 +192,10 @@ class AsyncFinetunings(AsyncAPIResource): def with_raw_response(self) -> AsyncFinetuningsWithRawResponse: return AsyncFinetuningsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncFinetuningsWithStreamingResponse: + return AsyncFinetuningsWithStreamingResponse(self) + async def create( self, *, @@ -247,6 +264,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/finetunings/{id}", options=make_request_options( @@ -315,6 +334,8 @@ async def cancel( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/finetunings/{id}/cancel", options=make_request_options( @@ -354,3 +375,35 @@ def __init__(self, finetunings: AsyncFinetunings) -> None: self.cancel = async_to_raw_response_wrapper( finetunings.cancel, ) + + +class FinetuningsWithStreamingResponse: + def __init__(self, finetunings: Finetunings) -> None: + self.create = to_streamed_response_wrapper( + finetunings.create, + ) + self.retrieve = to_streamed_response_wrapper( + finetunings.retrieve, + ) + self.list = to_streamed_response_wrapper( + finetunings.list, + ) + self.cancel = to_streamed_response_wrapper( + finetunings.cancel, + ) + + +class AsyncFinetuningsWithStreamingResponse: + def __init__(self, finetunings: AsyncFinetunings) -> None: + self.create = async_to_streamed_response_wrapper( + finetunings.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + finetunings.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + finetunings.list, + ) + self.cancel = async_to_streamed_response_wrapper( + finetunings.cancel, + ) diff --git a/src/dataherald/resources/generations.py b/src/dataherald/resources/generations.py index 14625c8..bd6e592 100644 --- a/src/dataherald/resources/generations.py +++ b/src/dataherald/resources/generations.py @@ -17,7 +17,12 @@ from .._utils import maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -31,6 +36,10 @@ class Generations(SyncAPIResource): def with_raw_response(self) -> GenerationsWithRawResponse: return GenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> GenerationsWithStreamingResponse: + return GenerationsWithStreamingResponse(self) + def create( self, id: str, @@ -54,6 +63,8 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/generations/{id}", options=make_request_options( @@ -85,6 +96,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/generations/{id}", options=make_request_options( @@ -121,6 +134,8 @@ def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._put( f"/api/generations/{id}", body=maybe_transform( @@ -205,6 +220,8 @@ def nl_generation( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/generations/{id}/nl-generation", options=make_request_options( @@ -237,6 +254,8 @@ def sql_generation( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/generations/{id}/sql-generation", body=maybe_transform({"sql": sql}, generation_sql_generation_params.GenerationSqlGenerationParams), @@ -252,6 +271,10 @@ class AsyncGenerations(AsyncAPIResource): def with_raw_response(self) -> AsyncGenerationsWithRawResponse: return AsyncGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncGenerationsWithStreamingResponse: + return AsyncGenerationsWithStreamingResponse(self) + async def create( self, id: str, @@ -275,6 +298,8 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/generations/{id}", options=make_request_options( @@ -306,6 +331,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/generations/{id}", options=make_request_options( @@ -342,6 +369,8 @@ async def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._put( f"/api/generations/{id}", body=maybe_transform( @@ -426,6 +455,8 @@ async def nl_generation( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/generations/{id}/nl-generation", options=make_request_options( @@ -458,6 +489,8 @@ async def sql_generation( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/generations/{id}/sql-generation", body=maybe_transform({"sql": sql}, generation_sql_generation_params.GenerationSqlGenerationParams), @@ -510,3 +543,47 @@ def __init__(self, generations: AsyncGenerations) -> None: self.sql_generation = async_to_raw_response_wrapper( generations.sql_generation, ) + + +class GenerationsWithStreamingResponse: + def __init__(self, generations: Generations) -> None: + self.create = to_streamed_response_wrapper( + generations.create, + ) + self.retrieve = to_streamed_response_wrapper( + generations.retrieve, + ) + self.update = to_streamed_response_wrapper( + generations.update, + ) + self.list = to_streamed_response_wrapper( + generations.list, + ) + self.nl_generation = to_streamed_response_wrapper( + generations.nl_generation, + ) + self.sql_generation = to_streamed_response_wrapper( + generations.sql_generation, + ) + + +class AsyncGenerationsWithStreamingResponse: + def __init__(self, generations: AsyncGenerations) -> None: + self.create = async_to_streamed_response_wrapper( + generations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + generations.retrieve, + ) + self.update = async_to_streamed_response_wrapper( + generations.update, + ) + self.list = async_to_streamed_response_wrapper( + generations.list, + ) + self.nl_generation = async_to_streamed_response_wrapper( + generations.nl_generation, + ) + self.sql_generation = async_to_streamed_response_wrapper( + generations.sql_generation, + ) diff --git a/src/dataherald/resources/golden_sqls.py b/src/dataherald/resources/golden_sqls.py index 2082d4b..52fe144 100644 --- a/src/dataherald/resources/golden_sqls.py +++ b/src/dataherald/resources/golden_sqls.py @@ -16,7 +16,12 @@ from .._utils import maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -30,6 +35,10 @@ class GoldenSqls(SyncAPIResource): def with_raw_response(self) -> GoldenSqlsWithRawResponse: return GoldenSqlsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> GoldenSqlsWithStreamingResponse: + return GoldenSqlsWithStreamingResponse(self) + def retrieve( self, id: str, @@ -53,6 +62,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/golden-sqls/{id}", options=make_request_options( @@ -130,6 +141,8 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( f"/api/golden-sqls/{id}", options=make_request_options( @@ -176,6 +189,10 @@ class AsyncGoldenSqls(AsyncAPIResource): def with_raw_response(self) -> AsyncGoldenSqlsWithRawResponse: return AsyncGoldenSqlsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncGoldenSqlsWithStreamingResponse: + return AsyncGoldenSqlsWithStreamingResponse(self) + async def retrieve( self, id: str, @@ -199,6 +216,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/golden-sqls/{id}", options=make_request_options( @@ -276,6 +295,8 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( f"/api/golden-sqls/{id}", options=make_request_options( @@ -347,3 +368,35 @@ def __init__(self, golden_sqls: AsyncGoldenSqls) -> None: self.upload = async_to_raw_response_wrapper( golden_sqls.upload, ) + + +class GoldenSqlsWithStreamingResponse: + def __init__(self, golden_sqls: GoldenSqls) -> None: + self.retrieve = to_streamed_response_wrapper( + golden_sqls.retrieve, + ) + self.list = to_streamed_response_wrapper( + golden_sqls.list, + ) + self.delete = to_streamed_response_wrapper( + golden_sqls.delete, + ) + self.upload = to_streamed_response_wrapper( + golden_sqls.upload, + ) + + +class AsyncGoldenSqlsWithStreamingResponse: + def __init__(self, golden_sqls: AsyncGoldenSqls) -> None: + self.retrieve = async_to_streamed_response_wrapper( + golden_sqls.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + golden_sqls.list, + ) + self.delete = async_to_streamed_response_wrapper( + golden_sqls.delete, + ) + self.upload = async_to_streamed_response_wrapper( + golden_sqls.upload, + ) diff --git a/src/dataherald/resources/heartbeat.py b/src/dataherald/resources/heartbeat.py index 1a80edc..e6abaed 100644 --- a/src/dataherald/resources/heartbeat.py +++ b/src/dataherald/resources/heartbeat.py @@ -7,7 +7,12 @@ from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -20,6 +25,10 @@ class Heartbeat(SyncAPIResource): def with_raw_response(self) -> HeartbeatWithRawResponse: return HeartbeatWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> HeartbeatWithStreamingResponse: + return HeartbeatWithStreamingResponse(self) + def retrieve( self, *, @@ -45,6 +54,10 @@ class AsyncHeartbeat(AsyncAPIResource): def with_raw_response(self) -> AsyncHeartbeatWithRawResponse: return AsyncHeartbeatWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncHeartbeatWithStreamingResponse: + return AsyncHeartbeatWithStreamingResponse(self) + async def retrieve( self, *, @@ -77,3 +90,17 @@ def __init__(self, heartbeat: AsyncHeartbeat) -> None: self.retrieve = async_to_raw_response_wrapper( heartbeat.retrieve, ) + + +class HeartbeatWithStreamingResponse: + def __init__(self, heartbeat: Heartbeat) -> None: + self.retrieve = to_streamed_response_wrapper( + heartbeat.retrieve, + ) + + +class AsyncHeartbeatWithStreamingResponse: + def __init__(self, heartbeat: AsyncHeartbeat) -> None: + self.retrieve = async_to_streamed_response_wrapper( + heartbeat.retrieve, + ) diff --git a/src/dataherald/resources/instructions/__init__.py b/src/dataherald/resources/instructions/__init__.py index c8fd692..57b1d94 100644 --- a/src/dataherald/resources/instructions/__init__.py +++ b/src/dataherald/resources/instructions/__init__.py @@ -1,15 +1,33 @@ # File generated from our OpenAPI spec by Stainless. -from .first import First, AsyncFirst, FirstWithRawResponse, AsyncFirstWithRawResponse -from .instructions import Instructions, AsyncInstructions, InstructionsWithRawResponse, AsyncInstructionsWithRawResponse +from .first import ( + First, + AsyncFirst, + FirstWithRawResponse, + AsyncFirstWithRawResponse, + FirstWithStreamingResponse, + AsyncFirstWithStreamingResponse, +) +from .instructions import ( + Instructions, + AsyncInstructions, + InstructionsWithRawResponse, + AsyncInstructionsWithRawResponse, + InstructionsWithStreamingResponse, + AsyncInstructionsWithStreamingResponse, +) __all__ = [ "First", "AsyncFirst", "FirstWithRawResponse", "AsyncFirstWithRawResponse", + "FirstWithStreamingResponse", + "AsyncFirstWithStreamingResponse", "Instructions", "AsyncInstructions", "InstructionsWithRawResponse", "AsyncInstructionsWithRawResponse", + "InstructionsWithStreamingResponse", + "AsyncInstructionsWithStreamingResponse", ] diff --git a/src/dataherald/resources/instructions/first.py b/src/dataherald/resources/instructions/first.py index a1c8686..26921f7 100644 --- a/src/dataherald/resources/instructions/first.py +++ b/src/dataherald/resources/instructions/first.py @@ -7,7 +7,12 @@ from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -21,6 +26,10 @@ class First(SyncAPIResource): def with_raw_response(self) -> FirstWithRawResponse: return FirstWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> FirstWithStreamingResponse: + return FirstWithStreamingResponse(self) + def retrieve( self, *, @@ -46,6 +55,10 @@ class AsyncFirst(AsyncAPIResource): def with_raw_response(self) -> AsyncFirstWithRawResponse: return AsyncFirstWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncFirstWithStreamingResponse: + return AsyncFirstWithStreamingResponse(self) + async def retrieve( self, *, @@ -78,3 +91,17 @@ def __init__(self, first: AsyncFirst) -> None: self.retrieve = async_to_raw_response_wrapper( first.retrieve, ) + + +class FirstWithStreamingResponse: + def __init__(self, first: First) -> None: + self.retrieve = to_streamed_response_wrapper( + first.retrieve, + ) + + +class AsyncFirstWithStreamingResponse: + def __init__(self, first: AsyncFirst) -> None: + self.retrieve = async_to_streamed_response_wrapper( + first.retrieve, + ) diff --git a/src/dataherald/resources/instructions/instructions.py b/src/dataherald/resources/instructions/instructions.py index b397d3e..6aad586 100644 --- a/src/dataherald/resources/instructions/instructions.py +++ b/src/dataherald/resources/instructions/instructions.py @@ -4,7 +4,14 @@ import httpx -from .first import First, AsyncFirst, FirstWithRawResponse, AsyncFirstWithRawResponse +from .first import ( + First, + AsyncFirst, + FirstWithRawResponse, + AsyncFirstWithRawResponse, + FirstWithStreamingResponse, + AsyncFirstWithStreamingResponse, +) from ...types import ( InstructionListResponse, instruction_list_params, @@ -15,7 +22,12 @@ from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -33,6 +45,10 @@ def first(self) -> First: def with_raw_response(self) -> InstructionsWithRawResponse: return InstructionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> InstructionsWithStreamingResponse: + return InstructionsWithStreamingResponse(self) + def create( self, *, @@ -100,6 +116,8 @@ def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._put( f"/api/instructions/{id}", body=maybe_transform( @@ -176,6 +194,8 @@ def delete( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._delete( f"/api/instructions/{id}", options=make_request_options( @@ -194,6 +214,10 @@ def first(self) -> AsyncFirst: def with_raw_response(self) -> AsyncInstructionsWithRawResponse: return AsyncInstructionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncInstructionsWithStreamingResponse: + return AsyncInstructionsWithStreamingResponse(self) + async def create( self, *, @@ -261,6 +285,8 @@ async def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._put( f"/api/instructions/{id}", body=maybe_transform( @@ -337,6 +363,8 @@ async def delete( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._delete( f"/api/instructions/{id}", options=make_request_options( @@ -380,3 +408,39 @@ def __init__(self, instructions: AsyncInstructions) -> None: self.delete = async_to_raw_response_wrapper( instructions.delete, ) + + +class InstructionsWithStreamingResponse: + def __init__(self, instructions: Instructions) -> None: + self.first = FirstWithStreamingResponse(instructions.first) + + self.create = to_streamed_response_wrapper( + instructions.create, + ) + self.update = to_streamed_response_wrapper( + instructions.update, + ) + self.list = to_streamed_response_wrapper( + instructions.list, + ) + self.delete = to_streamed_response_wrapper( + instructions.delete, + ) + + +class AsyncInstructionsWithStreamingResponse: + def __init__(self, instructions: AsyncInstructions) -> None: + self.first = AsyncFirstWithStreamingResponse(instructions.first) + + self.create = async_to_streamed_response_wrapper( + instructions.create, + ) + self.update = async_to_streamed_response_wrapper( + instructions.update, + ) + self.list = async_to_streamed_response_wrapper( + instructions.list, + ) + self.delete = async_to_streamed_response_wrapper( + instructions.delete, + ) diff --git a/src/dataherald/resources/nl_generations.py b/src/dataherald/resources/nl_generations.py index 9f86199..9563a9a 100644 --- a/src/dataherald/resources/nl_generations.py +++ b/src/dataherald/resources/nl_generations.py @@ -9,7 +9,12 @@ from .._utils import maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -23,6 +28,10 @@ class NlGenerations(SyncAPIResource): def with_raw_response(self) -> NlGenerationsWithRawResponse: return NlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> NlGenerationsWithStreamingResponse: + return NlGenerationsWithStreamingResponse(self) + def create( self, *, @@ -87,6 +96,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/nl-generations/{id}", options=make_request_options( @@ -147,6 +158,10 @@ class AsyncNlGenerations(AsyncAPIResource): def with_raw_response(self) -> AsyncNlGenerationsWithRawResponse: return AsyncNlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncNlGenerationsWithStreamingResponse: + return AsyncNlGenerationsWithStreamingResponse(self) + async def create( self, *, @@ -211,6 +226,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/nl-generations/{id}", options=make_request_options( @@ -290,3 +307,29 @@ def __init__(self, nl_generations: AsyncNlGenerations) -> None: self.list = async_to_raw_response_wrapper( nl_generations.list, ) + + +class NlGenerationsWithStreamingResponse: + def __init__(self, nl_generations: NlGenerations) -> None: + self.create = to_streamed_response_wrapper( + nl_generations.create, + ) + self.retrieve = to_streamed_response_wrapper( + nl_generations.retrieve, + ) + self.list = to_streamed_response_wrapper( + nl_generations.list, + ) + + +class AsyncNlGenerationsWithStreamingResponse: + def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self.create = async_to_streamed_response_wrapper( + nl_generations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + nl_generations.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + nl_generations.list, + ) diff --git a/src/dataherald/resources/prompts/__init__.py b/src/dataherald/resources/prompts/__init__.py index a1a79a1..a1577d7 100644 --- a/src/dataherald/resources/prompts/__init__.py +++ b/src/dataherald/resources/prompts/__init__.py @@ -1,11 +1,20 @@ # File generated from our OpenAPI spec by Stainless. -from .prompts import Prompts, AsyncPrompts, PromptsWithRawResponse, AsyncPromptsWithRawResponse +from .prompts import ( + Prompts, + AsyncPrompts, + PromptsWithRawResponse, + AsyncPromptsWithRawResponse, + PromptsWithStreamingResponse, + AsyncPromptsWithStreamingResponse, +) from .sql_generations import ( SqlGenerations, AsyncSqlGenerations, SqlGenerationsWithRawResponse, AsyncSqlGenerationsWithRawResponse, + SqlGenerationsWithStreamingResponse, + AsyncSqlGenerationsWithStreamingResponse, ) __all__ = [ @@ -13,8 +22,12 @@ "AsyncSqlGenerations", "SqlGenerationsWithRawResponse", "AsyncSqlGenerationsWithRawResponse", + "SqlGenerationsWithStreamingResponse", + "AsyncSqlGenerationsWithStreamingResponse", "Prompts", "AsyncPrompts", "PromptsWithRawResponse", "AsyncPromptsWithRawResponse", + "PromptsWithStreamingResponse", + "AsyncPromptsWithStreamingResponse", ] diff --git a/src/dataherald/resources/prompts/prompts.py b/src/dataherald/resources/prompts/prompts.py index b3b5087..eea29db 100644 --- a/src/dataherald/resources/prompts/prompts.py +++ b/src/dataherald/resources/prompts/prompts.py @@ -9,7 +9,12 @@ from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -18,6 +23,8 @@ AsyncSqlGenerations, SqlGenerationsWithRawResponse, AsyncSqlGenerationsWithRawResponse, + SqlGenerationsWithStreamingResponse, + AsyncSqlGenerationsWithStreamingResponse, ) __all__ = ["Prompts", "AsyncPrompts"] @@ -32,6 +39,10 @@ def sql_generations(self) -> SqlGenerations: def with_raw_response(self) -> PromptsWithRawResponse: return PromptsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> PromptsWithStreamingResponse: + return PromptsWithStreamingResponse(self) + def create( self, *, @@ -96,6 +107,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/prompts/{id}", options=make_request_options( @@ -160,6 +173,10 @@ def sql_generations(self) -> AsyncSqlGenerations: def with_raw_response(self) -> AsyncPromptsWithRawResponse: return AsyncPromptsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncPromptsWithStreamingResponse: + return AsyncPromptsWithStreamingResponse(self) + async def create( self, *, @@ -224,6 +241,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/prompts/{id}", options=make_request_options( @@ -307,3 +326,33 @@ def __init__(self, prompts: AsyncPrompts) -> None: self.list = async_to_raw_response_wrapper( prompts.list, ) + + +class PromptsWithStreamingResponse: + def __init__(self, prompts: Prompts) -> None: + self.sql_generations = SqlGenerationsWithStreamingResponse(prompts.sql_generations) + + self.create = to_streamed_response_wrapper( + prompts.create, + ) + self.retrieve = to_streamed_response_wrapper( + prompts.retrieve, + ) + self.list = to_streamed_response_wrapper( + prompts.list, + ) + + +class AsyncPromptsWithStreamingResponse: + def __init__(self, prompts: AsyncPrompts) -> None: + self.sql_generations = AsyncSqlGenerationsWithStreamingResponse(prompts.sql_generations) + + self.create = async_to_streamed_response_wrapper( + prompts.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + prompts.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + prompts.list, + ) diff --git a/src/dataherald/resources/prompts/sql_generations.py b/src/dataherald/resources/prompts/sql_generations.py index 71955ba..ab6b467 100644 --- a/src/dataherald/resources/prompts/sql_generations.py +++ b/src/dataherald/resources/prompts/sql_generations.py @@ -8,7 +8,12 @@ from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -27,6 +32,10 @@ class SqlGenerations(SyncAPIResource): def with_raw_response(self) -> SqlGenerationsWithRawResponse: return SqlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> SqlGenerationsWithStreamingResponse: + return SqlGenerationsWithStreamingResponse(self) + def create( self, id: str, @@ -54,6 +63,8 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/prompts/{id}/sql-generations", body=maybe_transform( @@ -98,6 +109,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/prompts/{id}/sql-generations", options=make_request_options( @@ -144,6 +157,8 @@ def nl_generations( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/prompts/{id}/sql-generations/nl-generations", body=maybe_transform( @@ -166,6 +181,10 @@ class AsyncSqlGenerations(AsyncAPIResource): def with_raw_response(self) -> AsyncSqlGenerationsWithRawResponse: return AsyncSqlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncSqlGenerationsWithStreamingResponse: + return AsyncSqlGenerationsWithStreamingResponse(self) + async def create( self, id: str, @@ -193,6 +212,8 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/prompts/{id}/sql-generations", body=maybe_transform( @@ -237,6 +258,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/prompts/{id}/sql-generations", options=make_request_options( @@ -283,6 +306,8 @@ async def nl_generations( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/prompts/{id}/sql-generations/nl-generations", body=maybe_transform( @@ -324,3 +349,29 @@ def __init__(self, sql_generations: AsyncSqlGenerations) -> None: self.nl_generations = async_to_raw_response_wrapper( sql_generations.nl_generations, ) + + +class SqlGenerationsWithStreamingResponse: + def __init__(self, sql_generations: SqlGenerations) -> None: + self.create = to_streamed_response_wrapper( + sql_generations.create, + ) + self.retrieve = to_streamed_response_wrapper( + sql_generations.retrieve, + ) + self.nl_generations = to_streamed_response_wrapper( + sql_generations.nl_generations, + ) + + +class AsyncSqlGenerationsWithStreamingResponse: + def __init__(self, sql_generations: AsyncSqlGenerations) -> None: + self.create = async_to_streamed_response_wrapper( + sql_generations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + sql_generations.retrieve, + ) + self.nl_generations = async_to_streamed_response_wrapper( + sql_generations.nl_generations, + ) diff --git a/src/dataherald/resources/sql_generations/__init__.py b/src/dataherald/resources/sql_generations/__init__.py index ffc66f6..834f84f 100644 --- a/src/dataherald/resources/sql_generations/__init__.py +++ b/src/dataherald/resources/sql_generations/__init__.py @@ -5,12 +5,16 @@ AsyncNlGenerations, NlGenerationsWithRawResponse, AsyncNlGenerationsWithRawResponse, + NlGenerationsWithStreamingResponse, + AsyncNlGenerationsWithStreamingResponse, ) from .sql_generations import ( SqlGenerations, AsyncSqlGenerations, SqlGenerationsWithRawResponse, AsyncSqlGenerationsWithRawResponse, + SqlGenerationsWithStreamingResponse, + AsyncSqlGenerationsWithStreamingResponse, ) __all__ = [ @@ -18,8 +22,12 @@ "AsyncNlGenerations", "NlGenerationsWithRawResponse", "AsyncNlGenerationsWithRawResponse", + "NlGenerationsWithStreamingResponse", + "AsyncNlGenerationsWithStreamingResponse", "SqlGenerations", "AsyncSqlGenerations", "SqlGenerationsWithRawResponse", "AsyncSqlGenerationsWithRawResponse", + "SqlGenerationsWithStreamingResponse", + "AsyncSqlGenerationsWithStreamingResponse", ] diff --git a/src/dataherald/resources/sql_generations/nl_generations.py b/src/dataherald/resources/sql_generations/nl_generations.py index 4faa932..8ed3183 100644 --- a/src/dataherald/resources/sql_generations/nl_generations.py +++ b/src/dataherald/resources/sql_generations/nl_generations.py @@ -8,7 +8,12 @@ from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -23,6 +28,10 @@ class NlGenerations(SyncAPIResource): def with_raw_response(self) -> NlGenerationsWithRawResponse: return NlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> NlGenerationsWithStreamingResponse: + return NlGenerationsWithStreamingResponse(self) + def create( self, id: str, @@ -48,6 +57,8 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._post( f"/api/sql-generations/{id}/nl-generations", body=maybe_transform( @@ -90,6 +101,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/sql-generations/{id}/nl-generations", options=make_request_options( @@ -116,6 +129,10 @@ class AsyncNlGenerations(AsyncAPIResource): def with_raw_response(self) -> AsyncNlGenerationsWithRawResponse: return AsyncNlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncNlGenerationsWithStreamingResponse: + return AsyncNlGenerationsWithStreamingResponse(self) + async def create( self, id: str, @@ -141,6 +158,8 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._post( f"/api/sql-generations/{id}/nl-generations", body=maybe_transform( @@ -183,6 +202,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/sql-generations/{id}/nl-generations", options=make_request_options( @@ -222,3 +243,23 @@ def __init__(self, nl_generations: AsyncNlGenerations) -> None: self.retrieve = async_to_raw_response_wrapper( nl_generations.retrieve, ) + + +class NlGenerationsWithStreamingResponse: + def __init__(self, nl_generations: NlGenerations) -> None: + self.create = to_streamed_response_wrapper( + nl_generations.create, + ) + self.retrieve = to_streamed_response_wrapper( + nl_generations.retrieve, + ) + + +class AsyncNlGenerationsWithStreamingResponse: + def __init__(self, nl_generations: AsyncNlGenerations) -> None: + self.create = async_to_streamed_response_wrapper( + nl_generations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + nl_generations.retrieve, + ) diff --git a/src/dataherald/resources/sql_generations/sql_generations.py b/src/dataherald/resources/sql_generations/sql_generations.py index ba2d8b9..f2bb54e 100644 --- a/src/dataherald/resources/sql_generations/sql_generations.py +++ b/src/dataherald/resources/sql_generations/sql_generations.py @@ -15,7 +15,12 @@ from ..._utils import maybe_transform from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource -from ..._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from ..._base_client import ( make_request_options, ) @@ -25,6 +30,8 @@ AsyncNlGenerations, NlGenerationsWithRawResponse, AsyncNlGenerationsWithRawResponse, + NlGenerationsWithStreamingResponse, + AsyncNlGenerationsWithStreamingResponse, ) __all__ = ["SqlGenerations", "AsyncSqlGenerations"] @@ -39,6 +46,10 @@ def nl_generations(self) -> NlGenerations: def with_raw_response(self) -> SqlGenerationsWithRawResponse: return SqlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> SqlGenerationsWithStreamingResponse: + return SqlGenerationsWithStreamingResponse(self) + def create( self, *, @@ -107,6 +118,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/sql-generations/{id}", options=make_request_options( @@ -185,6 +198,8 @@ def execute( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/sql-generations/{id}/execute", options=make_request_options( @@ -207,6 +222,10 @@ def nl_generations(self) -> AsyncNlGenerations: def with_raw_response(self) -> AsyncSqlGenerationsWithRawResponse: return AsyncSqlGenerationsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncSqlGenerationsWithStreamingResponse: + return AsyncSqlGenerationsWithStreamingResponse(self) + async def create( self, *, @@ -275,6 +294,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/sql-generations/{id}", options=make_request_options( @@ -353,6 +374,8 @@ async def execute( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/sql-generations/{id}/execute", options=make_request_options( @@ -400,3 +423,39 @@ def __init__(self, sql_generations: AsyncSqlGenerations) -> None: self.execute = async_to_raw_response_wrapper( sql_generations.execute, ) + + +class SqlGenerationsWithStreamingResponse: + def __init__(self, sql_generations: SqlGenerations) -> None: + self.nl_generations = NlGenerationsWithStreamingResponse(sql_generations.nl_generations) + + self.create = to_streamed_response_wrapper( + sql_generations.create, + ) + self.retrieve = to_streamed_response_wrapper( + sql_generations.retrieve, + ) + self.list = to_streamed_response_wrapper( + sql_generations.list, + ) + self.execute = to_streamed_response_wrapper( + sql_generations.execute, + ) + + +class AsyncSqlGenerationsWithStreamingResponse: + def __init__(self, sql_generations: AsyncSqlGenerations) -> None: + self.nl_generations = AsyncNlGenerationsWithStreamingResponse(sql_generations.nl_generations) + + self.create = async_to_streamed_response_wrapper( + sql_generations.create, + ) + self.retrieve = async_to_streamed_response_wrapper( + sql_generations.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + sql_generations.list, + ) + self.execute = async_to_streamed_response_wrapper( + sql_generations.execute, + ) diff --git a/src/dataherald/resources/table_descriptions.py b/src/dataherald/resources/table_descriptions.py index fc7ce7f..d19efc7 100644 --- a/src/dataherald/resources/table_descriptions.py +++ b/src/dataherald/resources/table_descriptions.py @@ -18,7 +18,12 @@ from .._utils import maybe_transform from .._compat import cached_property from .._resource import SyncAPIResource, AsyncAPIResource -from .._response import to_raw_response_wrapper, async_to_raw_response_wrapper +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) from .._base_client import ( make_request_options, ) @@ -31,6 +36,10 @@ class TableDescriptions(SyncAPIResource): def with_raw_response(self) -> TableDescriptionsWithRawResponse: return TableDescriptionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> TableDescriptionsWithStreamingResponse: + return TableDescriptionsWithStreamingResponse(self) + def retrieve( self, id: str, @@ -54,6 +63,8 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._get( f"/api/table-descriptions/{id}", options=make_request_options( @@ -89,6 +100,8 @@ def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return self._put( f"/api/table-descriptions/{id}", body=maybe_transform( @@ -193,6 +206,10 @@ class AsyncTableDescriptions(AsyncAPIResource): def with_raw_response(self) -> AsyncTableDescriptionsWithRawResponse: return AsyncTableDescriptionsWithRawResponse(self) + @cached_property + def with_streaming_response(self) -> AsyncTableDescriptionsWithStreamingResponse: + return AsyncTableDescriptionsWithStreamingResponse(self) + async def retrieve( self, id: str, @@ -216,6 +233,8 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._get( f"/api/table-descriptions/{id}", options=make_request_options( @@ -251,6 +270,8 @@ async def update( timeout: Override the client-level default timeout for this request, in seconds """ + if not id: + raise ValueError(f"Expected a non-empty value for `id` but received {id!r}") return await self._put( f"/api/table-descriptions/{id}", body=maybe_transform( @@ -380,3 +401,35 @@ def __init__(self, table_descriptions: AsyncTableDescriptions) -> None: self.sync_schemas = async_to_raw_response_wrapper( table_descriptions.sync_schemas, ) + + +class TableDescriptionsWithStreamingResponse: + def __init__(self, table_descriptions: TableDescriptions) -> None: + self.retrieve = to_streamed_response_wrapper( + table_descriptions.retrieve, + ) + self.update = to_streamed_response_wrapper( + table_descriptions.update, + ) + self.list = to_streamed_response_wrapper( + table_descriptions.list, + ) + self.sync_schemas = to_streamed_response_wrapper( + table_descriptions.sync_schemas, + ) + + +class AsyncTableDescriptionsWithStreamingResponse: + def __init__(self, table_descriptions: AsyncTableDescriptions) -> None: + self.retrieve = async_to_streamed_response_wrapper( + table_descriptions.retrieve, + ) + self.update = async_to_streamed_response_wrapper( + table_descriptions.update, + ) + self.list = async_to_streamed_response_wrapper( + table_descriptions.list, + ) + self.sync_schemas = async_to_streamed_response_wrapper( + table_descriptions.sync_schemas, + ) diff --git a/src/dataherald/types/database_connection_create_params.py b/src/dataherald/types/database_connection_create_params.py index c7698ca..dba52af 100644 --- a/src/dataherald/types/database_connection_create_params.py +++ b/src/dataherald/types/database_connection_create_params.py @@ -5,7 +5,7 @@ from typing import Union from typing_extensions import TypedDict -__all__ = ["DatabaseConnectionCreateParams", "SshSettings"] +__all__ = ["DatabaseConnectionCreateParams", "SSHSettings"] class DatabaseConnectionCreateParams(TypedDict, total=False): @@ -19,12 +19,12 @@ class DatabaseConnectionCreateParams(TypedDict, total=False): metadata: object - ssh_settings: SshSettings + ssh_settings: SSHSettings use_ssh: bool -class SshSettings(TypedDict, total=False): +class SSHSettings(TypedDict, total=False): db_driver: str db_name: str diff --git a/src/dataherald/types/database_connection_update_params.py b/src/dataherald/types/database_connection_update_params.py index db8a4c4..53eb8ec 100644 --- a/src/dataherald/types/database_connection_update_params.py +++ b/src/dataherald/types/database_connection_update_params.py @@ -5,7 +5,7 @@ from typing import Union from typing_extensions import TypedDict -__all__ = ["DatabaseConnectionUpdateParams", "SshSettings"] +__all__ = ["DatabaseConnectionUpdateParams", "SSHSettings"] class DatabaseConnectionUpdateParams(TypedDict, total=False): @@ -19,12 +19,12 @@ class DatabaseConnectionUpdateParams(TypedDict, total=False): metadata: object - ssh_settings: SshSettings + ssh_settings: SSHSettings use_ssh: bool -class SshSettings(TypedDict, total=False): +class SSHSettings(TypedDict, total=False): db_driver: str db_name: str diff --git a/src/dataherald/types/db_connection_response.py b/src/dataherald/types/db_connection_response.py index 1d98f87..872930e 100644 --- a/src/dataherald/types/db_connection_response.py +++ b/src/dataherald/types/db_connection_response.py @@ -5,7 +5,7 @@ from .._models import BaseModel -__all__ = ["DBConnectionResponse", "Metadata", "MetadataDataheraldInternal", "SshSettings"] +__all__ = ["DBConnectionResponse", "Metadata", "MetadataDataheraldInternal", "SSHSettings"] class MetadataDataheraldInternal(BaseModel): @@ -16,7 +16,7 @@ class Metadata(BaseModel): dataherald_internal: Optional[MetadataDataheraldInternal] = None -class SshSettings(BaseModel): +class SSHSettings(BaseModel): db_driver: Optional[str] = None db_name: Optional[str] = None @@ -49,7 +49,7 @@ class DBConnectionResponse(BaseModel): path_to_credentials_file: Optional[str] = None - ssh_settings: Optional[SshSettings] = None + ssh_settings: Optional[SSHSettings] = None uri: Optional[str] = None diff --git a/src/dataherald/types/finetuning_create_params.py b/src/dataherald/types/finetuning_create_params.py index d01fc31..8bba060 100644 --- a/src/dataherald/types/finetuning_create_params.py +++ b/src/dataherald/types/finetuning_create_params.py @@ -21,8 +21,8 @@ class FinetuningCreateParams(TypedDict, total=False): class BaseLlm(TypedDict, total=False): - _model_name: str + model_name: str - _model_parameters: Dict[str, str] + model_parameters: Dict[str, str] - _model_provider: str + model_provider: str diff --git a/src/dataherald/types/finetuning_response.py b/src/dataherald/types/finetuning_response.py index e5cb51a..a48cbf3 100644 --- a/src/dataherald/types/finetuning_response.py +++ b/src/dataherald/types/finetuning_response.py @@ -3,17 +3,19 @@ from typing import Dict, List, Optional from datetime import datetime +from pydantic import Field as FieldInfo + from .._models import BaseModel __all__ = ["FinetuningResponse", "BaseLlm", "Metadata", "MetadataDataheraldInternal"] class BaseLlm(BaseModel): - _model_name: Optional[str] = None + api_model_name: Optional[str] = FieldInfo(alias="model_name", default=None) - _model_parameters: Optional[Dict[str, str]] = None + api_model_parameters: Optional[Dict[str, str]] = FieldInfo(alias="model_parameters", default=None) - _model_provider: Optional[str] = None + api_model_provider: Optional[str] = FieldInfo(alias="model_provider", default=None) class MetadataDataheraldInternal(BaseModel): diff --git a/tests/api_resources/database_connections/test_drivers.py b/tests/api_resources/database_connections/test_drivers.py index 8118581..5790065 100644 --- a/tests/api_resources/database_connections/test_drivers.py +++ b/tests/api_resources/database_connections/test_drivers.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -28,10 +29,23 @@ def test_method_list(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.database_connections.drivers.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" driver = response.parse() assert_matches_type(DriverListResponse, driver, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.database_connections.drivers.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + driver = response.parse() + assert_matches_type(DriverListResponse, driver, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncDrivers: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -46,6 +60,19 @@ async def test_method_list(self, client: AsyncDataherald) -> None: @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.database_connections.drivers.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - driver = response.parse() + driver = await response.parse() assert_matches_type(DriverListResponse, driver, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.database_connections.drivers.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + driver = await response.parse() + assert_matches_type(DriverListResponse, driver, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/instructions/test_first.py b/tests/api_resources/instructions/test_first.py index 6459158..52ab76c 100644 --- a/tests/api_resources/instructions/test_first.py +++ b/tests/api_resources/instructions/test_first.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -28,10 +29,23 @@ def test_method_retrieve(self, client: Dataherald) -> None: @parametrize def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.instructions.first.with_raw_response.retrieve() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" first = response.parse() assert_matches_type(InstructionResponse, first, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.instructions.first.with_streaming_response.retrieve() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + first = response.parse() + assert_matches_type(InstructionResponse, first, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncFirst: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -46,6 +60,19 @@ async def test_method_retrieve(self, client: AsyncDataherald) -> None: @parametrize async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.instructions.first.with_raw_response.retrieve() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - first = response.parse() + first = await response.parse() assert_matches_type(InstructionResponse, first, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.instructions.first.with_streaming_response.retrieve() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + first = await response.parse() + assert_matches_type(InstructionResponse, first, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/prompts/test_sql_generations.py b/tests/api_resources/prompts/test_sql_generations.py index 90c6192..1f1cce6 100644 --- a/tests/api_resources/prompts/test_sql_generations.py +++ b/tests/api_resources/prompts/test_sql_generations.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -43,10 +44,32 @@ def test_raw_response_create(self, client: Dataherald) -> None: response = client.prompts.sql_generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.prompts.sql_generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_create(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.prompts.sql_generations.with_raw_response.create( + "", + ) + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: sql_generation = client.prompts.sql_generations.retrieve( @@ -70,10 +93,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.prompts.sql_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(object, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.prompts.sql_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(object, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.prompts.sql_generations.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_nl_generations(self, client: Dataherald) -> None: sql_generation = client.prompts.sql_generations.nl_generations( @@ -103,10 +148,34 @@ def test_raw_response_nl_generations(self, client: Dataherald) -> None: "string", sql_generation={}, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_nl_generations(self, client: Dataherald) -> None: + with client.prompts.sql_generations.with_streaming_response.nl_generations( + "string", + sql_generation={}, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_nl_generations(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.prompts.sql_generations.with_raw_response.nl_generations( + "", + sql_generation={}, + ) + class TestAsyncSqlGenerations: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -136,10 +205,32 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.prompts.sql_generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.prompts.sql_generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_create(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.prompts.sql_generations.with_raw_response.create( + "", + ) + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: sql_generation = await client.prompts.sql_generations.retrieve( @@ -163,10 +254,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.prompts.sql_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(object, sql_generation, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.prompts.sql_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(object, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.prompts.sql_generations.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_nl_generations(self, client: AsyncDataherald) -> None: sql_generation = await client.prompts.sql_generations.nl_generations( @@ -196,6 +309,30 @@ async def test_raw_response_nl_generations(self, client: AsyncDataherald) -> Non "string", sql_generation={}, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) + + @parametrize + async def test_streaming_response_nl_generations(self, client: AsyncDataherald) -> None: + async with client.prompts.sql_generations.with_streaming_response.nl_generations( + "string", + sql_generation={}, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(NlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_nl_generations(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.prompts.sql_generations.with_raw_response.nl_generations( + "", + sql_generation={}, + ) diff --git a/tests/api_resources/sql_generations/test_nl_generations.py b/tests/api_resources/sql_generations/test_nl_generations.py index b5363b5..7f3004e 100644 --- a/tests/api_resources/sql_generations/test_nl_generations.py +++ b/tests/api_resources/sql_generations/test_nl_generations.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -41,10 +42,32 @@ def test_raw_response_create(self, client: Dataherald) -> None: response = client.sql_generations.nl_generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" nl_generation = response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.sql_generations.nl_generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_create(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.sql_generations.nl_generations.with_raw_response.create( + "", + ) + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: nl_generation = client.sql_generations.nl_generations.retrieve( @@ -68,10 +91,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.sql_generations.nl_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" nl_generation = response.parse() assert_matches_type(object, nl_generation, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.sql_generations.nl_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = response.parse() + assert_matches_type(object, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.sql_generations.nl_generations.with_raw_response.retrieve( + "", + ) + class TestAsyncNlGenerations: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -99,10 +144,32 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.sql_generations.nl_generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - nl_generation = response.parse() + nl_generation = await response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.sql_generations.nl_generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = await response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_create(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.sql_generations.nl_generations.with_raw_response.create( + "", + ) + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: nl_generation = await client.sql_generations.nl_generations.retrieve( @@ -126,6 +193,28 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.sql_generations.nl_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - nl_generation = response.parse() + nl_generation = await response.parse() assert_matches_type(object, nl_generation, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.sql_generations.nl_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = await response.parse() + assert_matches_type(object, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.sql_generations.nl_generations.with_raw_response.retrieve( + "", + ) diff --git a/tests/api_resources/test_database_connections.py b/tests/api_resources/test_database_connections.py index f737863..0efd89b 100644 --- a/tests/api_resources/test_database_connections.py +++ b/tests/api_resources/test_database_connections.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -54,10 +55,23 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_create(self, client: Dataherald) -> None: response = client.database_connections.with_raw_response.create() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" database_connection = response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.database_connections.with_streaming_response.create() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: database_connection = client.database_connections.retrieve( @@ -70,10 +84,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.database_connections.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" database_connection = response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.database_connections.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.database_connections.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_update(self, client: Dataherald) -> None: database_connection = client.database_connections.update( @@ -110,10 +146,32 @@ def test_raw_response_update(self, client: Dataherald) -> None: response = client.database_connections.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" database_connection = response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + def test_streaming_response_update(self, client: Dataherald) -> None: + with client.database_connections.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.database_connections.with_raw_response.update( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: database_connection = client.database_connections.list() @@ -122,10 +180,23 @@ def test_method_list(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.database_connections.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" database_connection = response.parse() assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.database_connections.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = response.parse() + assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncDatabaseConnections: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -163,10 +234,23 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N @parametrize async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.database_connections.with_raw_response.create() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - database_connection = response.parse() + database_connection = await response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.database_connections.with_streaming_response.create() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = await response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: database_connection = await client.database_connections.retrieve( @@ -179,10 +263,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.database_connections.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - database_connection = response.parse() + database_connection = await response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.database_connections.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = await response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.database_connections.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_update(self, client: AsyncDataherald) -> None: database_connection = await client.database_connections.update( @@ -219,10 +325,32 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: response = await client.database_connections.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - database_connection = response.parse() + database_connection = await response.parse() assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + @parametrize + async def test_streaming_response_update(self, client: AsyncDataherald) -> None: + async with client.database_connections.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = await response.parse() + assert_matches_type(DBConnectionResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.database_connections.with_raw_response.update( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: database_connection = await client.database_connections.list() @@ -231,6 +359,19 @@ async def test_method_list(self, client: AsyncDataherald) -> None: @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.database_connections.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - database_connection = response.parse() + database_connection = await response.parse() assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.database_connections.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + database_connection = await response.parse() + assert_matches_type(DatabaseConnectionListResponse, database_connection, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_engine.py b/tests/api_resources/test_engine.py index b73b4be..0c920f7 100644 --- a/tests/api_resources/test_engine.py +++ b/tests/api_resources/test_engine.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -27,10 +28,23 @@ def test_method_heartbeat(self, client: Dataherald) -> None: @parametrize def test_raw_response_heartbeat(self, client: Dataherald) -> None: response = client.engine.with_raw_response.heartbeat() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" engine = response.parse() assert_matches_type(object, engine, path=["response"]) + @parametrize + def test_streaming_response_heartbeat(self, client: Dataherald) -> None: + with client.engine.with_streaming_response.heartbeat() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + engine = response.parse() + assert_matches_type(object, engine, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncEngine: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -45,6 +59,19 @@ async def test_method_heartbeat(self, client: AsyncDataherald) -> None: @parametrize async def test_raw_response_heartbeat(self, client: AsyncDataherald) -> None: response = await client.engine.with_raw_response.heartbeat() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - engine = response.parse() + engine = await response.parse() assert_matches_type(object, engine, path=["response"]) + + @parametrize + async def test_streaming_response_heartbeat(self, client: AsyncDataherald) -> None: + async with client.engine.with_streaming_response.heartbeat() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + engine = await response.parse() + assert_matches_type(object, engine, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_finetunings.py b/tests/api_resources/test_finetunings.py index f235c94..02a910e 100644 --- a/tests/api_resources/test_finetunings.py +++ b/tests/api_resources/test_finetunings.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -36,9 +37,9 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: db_connection_id="string", alias="string", base_llm={ - "_model_provider": "string", - "_model_name": "string", - "_model_parameters": {"foo": "string"}, + "model_provider": "string", + "model_name": "string", + "model_parameters": {"foo": "string"}, }, golden_records=["string", "string", "string"], metadata={}, @@ -50,10 +51,25 @@ def test_raw_response_create(self, client: Dataherald) -> None: response = client.finetunings.with_raw_response.create( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" finetuning = response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.finetunings.with_streaming_response.create( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: finetuning = client.finetunings.retrieve( @@ -66,10 +82,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.finetunings.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" finetuning = response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.finetunings.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.finetunings.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: finetuning = client.finetunings.list( @@ -82,10 +120,25 @@ def test_raw_response_list(self, client: Dataherald) -> None: response = client.finetunings.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" finetuning = response.parse() assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.finetunings.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = response.parse() + assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_cancel(self, client: Dataherald) -> None: finetuning = client.finetunings.cancel( @@ -98,10 +151,32 @@ def test_raw_response_cancel(self, client: Dataherald) -> None: response = client.finetunings.with_raw_response.cancel( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" finetuning = response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + @parametrize + def test_streaming_response_cancel(self, client: Dataherald) -> None: + with client.finetunings.with_streaming_response.cancel( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_cancel(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.finetunings.with_raw_response.cancel( + "", + ) + class TestAsyncFinetunings: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -121,9 +196,9 @@ async def test_method_create_with_all_params(self, client: AsyncDataherald) -> N db_connection_id="string", alias="string", base_llm={ - "_model_provider": "string", - "_model_name": "string", - "_model_parameters": {"foo": "string"}, + "model_provider": "string", + "model_name": "string", + "model_parameters": {"foo": "string"}, }, golden_records=["string", "string", "string"], metadata={}, @@ -135,10 +210,25 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.finetunings.with_raw_response.create( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - finetuning = response.parse() + finetuning = await response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.finetunings.with_streaming_response.create( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = await response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: finetuning = await client.finetunings.retrieve( @@ -151,10 +241,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.finetunings.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - finetuning = response.parse() + finetuning = await response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.finetunings.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = await response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.finetunings.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: finetuning = await client.finetunings.list( @@ -167,10 +279,25 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.finetunings.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - finetuning = response.parse() + finetuning = await response.parse() assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.finetunings.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = await response.parse() + assert_matches_type(FinetuningListResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_cancel(self, client: AsyncDataherald) -> None: finetuning = await client.finetunings.cancel( @@ -183,6 +310,28 @@ async def test_raw_response_cancel(self, client: AsyncDataherald) -> None: response = await client.finetunings.with_raw_response.cancel( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - finetuning = response.parse() + finetuning = await response.parse() assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + @parametrize + async def test_streaming_response_cancel(self, client: AsyncDataherald) -> None: + async with client.finetunings.with_streaming_response.cancel( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + finetuning = await response.parse() + assert_matches_type(FinetuningResponse, finetuning, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_cancel(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.finetunings.with_raw_response.cancel( + "", + ) diff --git a/tests/api_resources/test_generations.py b/tests/api_resources/test_generations.py index 1667f05..9687df2 100644 --- a/tests/api_resources/test_generations.py +++ b/tests/api_resources/test_generations.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -36,10 +37,32 @@ def test_raw_response_create(self, client: Dataherald) -> None: response = client.generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_create(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.generations.with_raw_response.create( + "", + ) + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: generation = client.generations.retrieve( @@ -52,10 +75,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.generations.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_update(self, client: Dataherald) -> None: generation = client.generations.update( @@ -77,10 +122,32 @@ def test_raw_response_update(self, client: Dataherald) -> None: response = client.generations.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_update(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.generations.with_raw_response.update( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: generation = client.generations.list() @@ -99,10 +166,23 @@ def test_method_list_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(GenerationListResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(GenerationListResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_nl_generation(self, client: Dataherald) -> None: generation = client.generations.nl_generation( @@ -115,10 +195,32 @@ def test_raw_response_nl_generation(self, client: Dataherald) -> None: response = client.generations.with_raw_response.nl_generation( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(NlGenerationResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_nl_generation(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.nl_generation( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(NlGenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_nl_generation(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.generations.with_raw_response.nl_generation( + "", + ) + @parametrize def test_method_sql_generation(self, client: Dataherald) -> None: generation = client.generations.sql_generation( @@ -133,10 +235,34 @@ def test_raw_response_sql_generation(self, client: Dataherald) -> None: "string", sql="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" generation = response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + def test_streaming_response_sql_generation(self, client: Dataherald) -> None: + with client.generations.with_streaming_response.sql_generation( + "string", + sql="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_sql_generation(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.generations.with_raw_response.sql_generation( + "", + sql="string", + ) + class TestAsyncGenerations: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -155,10 +281,32 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.generations.with_raw_response.create( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.create( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_create(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.generations.with_raw_response.create( + "", + ) + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: generation = await client.generations.retrieve( @@ -171,10 +319,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.generations.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_update(self, client: AsyncDataherald) -> None: generation = await client.generations.update( @@ -196,10 +366,32 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: response = await client.generations.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + @parametrize + async def test_streaming_response_update(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.generations.with_raw_response.update( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: generation = await client.generations.list() @@ -218,10 +410,23 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(GenerationListResponse, generation, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(GenerationListResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_nl_generation(self, client: AsyncDataherald) -> None: generation = await client.generations.nl_generation( @@ -234,10 +439,32 @@ async def test_raw_response_nl_generation(self, client: AsyncDataherald) -> None response = await client.generations.with_raw_response.nl_generation( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(NlGenerationResponse, generation, path=["response"]) + @parametrize + async def test_streaming_response_nl_generation(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.nl_generation( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(NlGenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_nl_generation(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.generations.with_raw_response.nl_generation( + "", + ) + @parametrize async def test_method_sql_generation(self, client: AsyncDataherald) -> None: generation = await client.generations.sql_generation( @@ -252,6 +479,30 @@ async def test_raw_response_sql_generation(self, client: AsyncDataherald) -> Non "string", sql="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - generation = response.parse() + generation = await response.parse() assert_matches_type(GenerationResponse, generation, path=["response"]) + + @parametrize + async def test_streaming_response_sql_generation(self, client: AsyncDataherald) -> None: + async with client.generations.with_streaming_response.sql_generation( + "string", + sql="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + generation = await response.parse() + assert_matches_type(GenerationResponse, generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_sql_generation(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.generations.with_raw_response.sql_generation( + "", + sql="string", + ) diff --git a/tests/api_resources/test_golden_sqls.py b/tests/api_resources/test_golden_sqls.py index ccf0660..22d8386 100644 --- a/tests/api_resources/test_golden_sqls.py +++ b/tests/api_resources/test_golden_sqls.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -36,10 +37,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.golden_sqls.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" golden_sql = response.parse() assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.golden_sqls.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = response.parse() + assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.golden_sqls.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: golden_sql = client.golden_sqls.list() @@ -58,10 +81,23 @@ def test_method_list_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.golden_sqls.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" golden_sql = response.parse() assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.golden_sqls.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = response.parse() + assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_delete(self, client: Dataherald) -> None: golden_sql = client.golden_sqls.delete( @@ -74,10 +110,32 @@ def test_raw_response_delete(self, client: Dataherald) -> None: response = client.golden_sqls.with_raw_response.delete( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" golden_sql = response.parse() assert_matches_type(object, golden_sql, path=["response"]) + @parametrize + def test_streaming_response_delete(self, client: Dataherald) -> None: + with client.golden_sqls.with_streaming_response.delete( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = response.parse() + assert_matches_type(object, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_delete(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.golden_sqls.with_raw_response.delete( + "", + ) + @parametrize def test_method_upload(self, client: Dataherald) -> None: golden_sql = client.golden_sqls.upload( @@ -122,10 +180,41 @@ def test_raw_response_upload(self, client: Dataherald) -> None: }, ], ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" golden_sql = response.parse() assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) + @parametrize + def test_streaming_response_upload(self, client: Dataherald) -> None: + with client.golden_sqls.with_streaming_response.upload( + body=[ + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = response.parse() + assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncGoldenSqls: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -144,10 +233,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.golden_sqls.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - golden_sql = response.parse() + golden_sql = await response.parse() assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.golden_sqls.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = await response.parse() + assert_matches_type(GoldenSqlResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.golden_sqls.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: golden_sql = await client.golden_sqls.list() @@ -166,10 +277,23 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.golden_sqls.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - golden_sql = response.parse() + golden_sql = await response.parse() assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.golden_sqls.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = await response.parse() + assert_matches_type(GoldenSqlListResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_delete(self, client: AsyncDataherald) -> None: golden_sql = await client.golden_sqls.delete( @@ -182,10 +306,32 @@ async def test_raw_response_delete(self, client: AsyncDataherald) -> None: response = await client.golden_sqls.with_raw_response.delete( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - golden_sql = response.parse() + golden_sql = await response.parse() assert_matches_type(object, golden_sql, path=["response"]) + @parametrize + async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: + async with client.golden_sqls.with_streaming_response.delete( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = await response.parse() + assert_matches_type(object, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_delete(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.golden_sqls.with_raw_response.delete( + "", + ) + @parametrize async def test_method_upload(self, client: AsyncDataherald) -> None: golden_sql = await client.golden_sqls.upload( @@ -230,6 +376,37 @@ async def test_raw_response_upload(self, client: AsyncDataherald) -> None: }, ], ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - golden_sql = response.parse() + golden_sql = await response.parse() assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) + + @parametrize + async def test_streaming_response_upload(self, client: AsyncDataherald) -> None: + async with client.golden_sqls.with_streaming_response.upload( + body=[ + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + { + "db_connection_id": "string", + "prompt_text": "string", + "sql": "string", + }, + ], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + golden_sql = await response.parse() + assert_matches_type(GoldenSqlUploadResponse, golden_sql, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_heartbeat.py b/tests/api_resources/test_heartbeat.py index 56f83d0..366718f 100644 --- a/tests/api_resources/test_heartbeat.py +++ b/tests/api_resources/test_heartbeat.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -27,10 +28,23 @@ def test_method_retrieve(self, client: Dataherald) -> None: @parametrize def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.heartbeat.with_raw_response.retrieve() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" heartbeat = response.parse() assert_matches_type(object, heartbeat, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.heartbeat.with_streaming_response.retrieve() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + heartbeat = response.parse() + assert_matches_type(object, heartbeat, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncHeartbeat: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -45,6 +59,19 @@ async def test_method_retrieve(self, client: AsyncDataherald) -> None: @parametrize async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.heartbeat.with_raw_response.retrieve() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - heartbeat = response.parse() + heartbeat = await response.parse() assert_matches_type(object, heartbeat, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.heartbeat.with_streaming_response.retrieve() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + heartbeat = await response.parse() + assert_matches_type(object, heartbeat, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_instructions.py b/tests/api_resources/test_instructions.py index 3e49ff6..bb55143 100644 --- a/tests/api_resources/test_instructions.py +++ b/tests/api_resources/test_instructions.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -44,10 +45,25 @@ def test_raw_response_create(self, client: Dataherald) -> None: response = client.instructions.with_raw_response.create( instruction="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" instruction = response.parse() assert_matches_type(InstructionResponse, instruction, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.instructions.with_streaming_response.create( + instruction="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_update(self, client: Dataherald) -> None: instruction = client.instructions.update( @@ -72,10 +88,34 @@ def test_raw_response_update(self, client: Dataherald) -> None: "string", instruction="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" instruction = response.parse() assert_matches_type(InstructionResponse, instruction, path=["response"]) + @parametrize + def test_streaming_response_update(self, client: Dataherald) -> None: + with client.instructions.with_streaming_response.update( + "string", + instruction="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.instructions.with_raw_response.update( + "", + instruction="string", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: instruction = client.instructions.list( @@ -88,10 +128,25 @@ def test_raw_response_list(self, client: Dataherald) -> None: response = client.instructions.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" instruction = response.parse() assert_matches_type(InstructionListResponse, instruction, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.instructions.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = response.parse() + assert_matches_type(InstructionListResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_delete(self, client: Dataherald) -> None: instruction = client.instructions.delete( @@ -104,10 +159,32 @@ def test_raw_response_delete(self, client: Dataherald) -> None: response = client.instructions.with_raw_response.delete( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" instruction = response.parse() assert_matches_type(object, instruction, path=["response"]) + @parametrize + def test_streaming_response_delete(self, client: Dataherald) -> None: + with client.instructions.with_streaming_response.delete( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = response.parse() + assert_matches_type(object, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_delete(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.instructions.with_raw_response.delete( + "", + ) + class TestAsyncInstructions: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -135,10 +212,25 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: response = await client.instructions.with_raw_response.create( instruction="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - instruction = response.parse() + instruction = await response.parse() assert_matches_type(InstructionResponse, instruction, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.instructions.with_streaming_response.create( + instruction="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = await response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_update(self, client: AsyncDataherald) -> None: instruction = await client.instructions.update( @@ -163,10 +255,34 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: "string", instruction="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - instruction = response.parse() + instruction = await response.parse() assert_matches_type(InstructionResponse, instruction, path=["response"]) + @parametrize + async def test_streaming_response_update(self, client: AsyncDataherald) -> None: + async with client.instructions.with_streaming_response.update( + "string", + instruction="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = await response.parse() + assert_matches_type(InstructionResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.instructions.with_raw_response.update( + "", + instruction="string", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: instruction = await client.instructions.list( @@ -179,10 +295,25 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.instructions.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - instruction = response.parse() + instruction = await response.parse() assert_matches_type(InstructionListResponse, instruction, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.instructions.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = await response.parse() + assert_matches_type(InstructionListResponse, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_delete(self, client: AsyncDataherald) -> None: instruction = await client.instructions.delete( @@ -195,6 +326,28 @@ async def test_raw_response_delete(self, client: AsyncDataherald) -> None: response = await client.instructions.with_raw_response.delete( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - instruction = response.parse() + instruction = await response.parse() assert_matches_type(object, instruction, path=["response"]) + + @parametrize + async def test_streaming_response_delete(self, client: AsyncDataherald) -> None: + async with client.instructions.with_streaming_response.delete( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + instruction = await response.parse() + assert_matches_type(object, instruction, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_delete(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.instructions.with_raw_response.delete( + "", + ) diff --git a/tests/api_resources/test_nl_generations.py b/tests/api_resources/test_nl_generations.py index 90b17f8..aeceef7 100644 --- a/tests/api_resources/test_nl_generations.py +++ b/tests/api_resources/test_nl_generations.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -62,10 +63,30 @@ def test_raw_response_create(self, client: Dataherald) -> None: } }, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" nl_generation = response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.nl_generations.with_streaming_response.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: nl_generation = client.nl_generations.retrieve( @@ -78,10 +99,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.nl_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" nl_generation = response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.nl_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.nl_generations.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: nl_generation = client.nl_generations.list() @@ -100,10 +143,23 @@ def test_method_list_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.nl_generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" nl_generation = response.parse() assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.nl_generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = response.parse() + assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncNlGenerations: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -151,10 +207,30 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: } }, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - nl_generation = response.parse() + nl_generation = await response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.nl_generations.with_streaming_response.create( + sql_generation={ + "prompt": { + "text": "string", + "db_connection_id": "string", + } + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = await response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: nl_generation = await client.nl_generations.retrieve( @@ -167,10 +243,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.nl_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - nl_generation = response.parse() + nl_generation = await response.parse() assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.nl_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = await response.parse() + assert_matches_type(NlGenerationResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.nl_generations.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: nl_generation = await client.nl_generations.list() @@ -189,6 +287,19 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.nl_generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - nl_generation = response.parse() + nl_generation = await response.parse() assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.nl_generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + nl_generation = await response.parse() + assert_matches_type(NlGenerationListResponse, nl_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_prompts.py b/tests/api_resources/test_prompts.py index 16fc57c..697d0ef 100644 --- a/tests/api_resources/test_prompts.py +++ b/tests/api_resources/test_prompts.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -43,10 +44,26 @@ def test_raw_response_create(self, client: Dataherald) -> None: db_connection_id="string", text="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" prompt = response.parse() assert_matches_type(PromptResponse, prompt, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.prompts.with_streaming_response.create( + db_connection_id="string", + text="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = response.parse() + assert_matches_type(PromptResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: prompt = client.prompts.retrieve( @@ -59,10 +76,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.prompts.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" prompt = response.parse() assert_matches_type(PromptResponse, prompt, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.prompts.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = response.parse() + assert_matches_type(PromptResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.prompts.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: prompt = client.prompts.list() @@ -81,10 +120,23 @@ def test_method_list_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.prompts.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" prompt = response.parse() assert_matches_type(PromptListResponse, prompt, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.prompts.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = response.parse() + assert_matches_type(PromptListResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncPrompts: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -114,10 +166,26 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: db_connection_id="string", text="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prompt = response.parse() + prompt = await response.parse() assert_matches_type(PromptResponse, prompt, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.prompts.with_streaming_response.create( + db_connection_id="string", + text="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = await response.parse() + assert_matches_type(PromptResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: prompt = await client.prompts.retrieve( @@ -130,10 +198,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.prompts.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prompt = response.parse() + prompt = await response.parse() assert_matches_type(PromptResponse, prompt, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.prompts.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = await response.parse() + assert_matches_type(PromptResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.prompts.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: prompt = await client.prompts.list() @@ -152,6 +242,19 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.prompts.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - prompt = response.parse() + prompt = await response.parse() assert_matches_type(PromptListResponse, prompt, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.prompts.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + prompt = await response.parse() + assert_matches_type(PromptListResponse, prompt, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_sql_generations.py b/tests/api_resources/test_sql_generations.py index 036f74b..aba6e06 100644 --- a/tests/api_resources/test_sql_generations.py +++ b/tests/api_resources/test_sql_generations.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -57,10 +58,28 @@ def test_raw_response_create(self, client: Dataherald) -> None: "db_connection_id": "string", }, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_create(self, client: Dataherald) -> None: + with client.sql_generations.with_streaming_response.create( + prompt={ + "text": "string", + "db_connection_id": "string", + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_retrieve(self, client: Dataherald) -> None: sql_generation = client.sql_generations.retrieve( @@ -73,10 +92,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.sql_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.sql_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.sql_generations.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: sql_generation = client.sql_generations.list() @@ -95,10 +136,23 @@ def test_method_list_with_all_params(self, client: Dataherald) -> None: @parametrize def test_raw_response_list(self, client: Dataherald) -> None: response = client.sql_generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.sql_generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_execute(self, client: Dataherald) -> None: sql_generation = client.sql_generations.execute( @@ -119,10 +173,32 @@ def test_raw_response_execute(self, client: Dataherald) -> None: response = client.sql_generations.with_raw_response.execute( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" sql_generation = response.parse() assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) + @parametrize + def test_streaming_response_execute(self, client: Dataherald) -> None: + with client.sql_generations.with_streaming_response.execute( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = response.parse() + assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_execute(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.sql_generations.with_raw_response.execute( + "", + ) + class TestAsyncSqlGenerations: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -162,10 +238,28 @@ async def test_raw_response_create(self, client: AsyncDataherald) -> None: "db_connection_id": "string", }, ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + async def test_streaming_response_create(self, client: AsyncDataherald) -> None: + async with client.sql_generations.with_streaming_response.create( + prompt={ + "text": "string", + "db_connection_id": "string", + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_retrieve(self, client: AsyncDataherald) -> None: sql_generation = await client.sql_generations.retrieve( @@ -178,10 +272,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.sql_generations.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.sql_generations.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(SqlGenerationResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.sql_generations.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: sql_generation = await client.sql_generations.list() @@ -200,10 +316,23 @@ async def test_method_list_with_all_params(self, client: AsyncDataherald) -> Non @parametrize async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.sql_generations.with_raw_response.list() + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.sql_generations.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(SqlGenerationListResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_execute(self, client: AsyncDataherald) -> None: sql_generation = await client.sql_generations.execute( @@ -224,6 +353,28 @@ async def test_raw_response_execute(self, client: AsyncDataherald) -> None: response = await client.sql_generations.with_raw_response.execute( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - sql_generation = response.parse() + sql_generation = await response.parse() assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) + + @parametrize + async def test_streaming_response_execute(self, client: AsyncDataherald) -> None: + async with client.sql_generations.with_streaming_response.execute( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + sql_generation = await response.parse() + assert_matches_type(SqlGenerationExecuteResponse, sql_generation, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_execute(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.sql_generations.with_raw_response.execute( + "", + ) diff --git a/tests/api_resources/test_table_descriptions.py b/tests/api_resources/test_table_descriptions.py index d260e2f..a281362 100644 --- a/tests/api_resources/test_table_descriptions.py +++ b/tests/api_resources/test_table_descriptions.py @@ -3,6 +3,7 @@ from __future__ import annotations import os +from typing import Any, cast import pytest @@ -36,10 +37,32 @@ def test_raw_response_retrieve(self, client: Dataherald) -> None: response = client.table_descriptions.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" table_description = response.parse() assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + @parametrize + def test_streaming_response_retrieve(self, client: Dataherald) -> None: + with client.table_descriptions.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = response.parse() + assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_retrieve(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.table_descriptions.with_raw_response.retrieve( + "", + ) + @parametrize def test_method_update(self, client: Dataherald) -> None: table_description = client.table_descriptions.update( @@ -91,10 +114,32 @@ def test_raw_response_update(self, client: Dataherald) -> None: response = client.table_descriptions.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" table_description = response.parse() assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + @parametrize + def test_streaming_response_update(self, client: Dataherald) -> None: + with client.table_descriptions.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = response.parse() + assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_path_params_update(self, client: Dataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + client.table_descriptions.with_raw_response.update( + "", + ) + @parametrize def test_method_list(self, client: Dataherald) -> None: table_description = client.table_descriptions.list( @@ -115,10 +160,25 @@ def test_raw_response_list(self, client: Dataherald) -> None: response = client.table_descriptions.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" table_description = response.parse() assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) + @parametrize + def test_streaming_response_list(self, client: Dataherald) -> None: + with client.table_descriptions.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = response.parse() + assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize def test_method_sync_schemas(self, client: Dataherald) -> None: table_description = client.table_descriptions.sync_schemas( @@ -139,10 +199,25 @@ def test_raw_response_sync_schemas(self, client: Dataherald) -> None: response = client.table_descriptions.with_raw_response.sync_schemas( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" table_description = response.parse() assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) + @parametrize + def test_streaming_response_sync_schemas(self, client: Dataherald) -> None: + with client.table_descriptions.with_streaming_response.sync_schemas( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = response.parse() + assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + class TestAsyncTableDescriptions: strict_client = AsyncDataherald(base_url=base_url, api_key=api_key, _strict_response_validation=True) @@ -161,10 +236,32 @@ async def test_raw_response_retrieve(self, client: AsyncDataherald) -> None: response = await client.table_descriptions.with_raw_response.retrieve( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - table_description = response.parse() + table_description = await response.parse() assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + @parametrize + async def test_streaming_response_retrieve(self, client: AsyncDataherald) -> None: + async with client.table_descriptions.with_streaming_response.retrieve( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = await response.parse() + assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_retrieve(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.table_descriptions.with_raw_response.retrieve( + "", + ) + @parametrize async def test_method_update(self, client: AsyncDataherald) -> None: table_description = await client.table_descriptions.update( @@ -216,10 +313,32 @@ async def test_raw_response_update(self, client: AsyncDataherald) -> None: response = await client.table_descriptions.with_raw_response.update( "string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - table_description = response.parse() + table_description = await response.parse() assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + @parametrize + async def test_streaming_response_update(self, client: AsyncDataherald) -> None: + async with client.table_descriptions.with_streaming_response.update( + "string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = await response.parse() + assert_matches_type(TableDescriptionResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_path_params_update(self, client: AsyncDataherald) -> None: + with pytest.raises(ValueError, match=r"Expected a non-empty value for `id` but received ''"): + await client.table_descriptions.with_raw_response.update( + "", + ) + @parametrize async def test_method_list(self, client: AsyncDataherald) -> None: table_description = await client.table_descriptions.list( @@ -240,10 +359,25 @@ async def test_raw_response_list(self, client: AsyncDataherald) -> None: response = await client.table_descriptions.with_raw_response.list( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - table_description = response.parse() + table_description = await response.parse() assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) + @parametrize + async def test_streaming_response_list(self, client: AsyncDataherald) -> None: + async with client.table_descriptions.with_streaming_response.list( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = await response.parse() + assert_matches_type(TableDescriptionListResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True + @parametrize async def test_method_sync_schemas(self, client: AsyncDataherald) -> None: table_description = await client.table_descriptions.sync_schemas( @@ -264,6 +398,21 @@ async def test_raw_response_sync_schemas(self, client: AsyncDataherald) -> None: response = await client.table_descriptions.with_raw_response.sync_schemas( db_connection_id="string", ) + + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - table_description = response.parse() + table_description = await response.parse() assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) + + @parametrize + async def test_streaming_response_sync_schemas(self, client: AsyncDataherald) -> None: + async with client.table_descriptions.with_streaming_response.sync_schemas( + db_connection_id="string", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + table_description = await response.parse() + assert_matches_type(TableDescriptionSyncSchemasResponse, table_description, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/test_client.py b/tests/test_client.py index 6513bef..dac2eca 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -19,6 +19,8 @@ from dataherald import Dataherald, AsyncDataherald, APIResponseValidationError from dataherald._client import Dataherald, AsyncDataherald from dataherald._models import BaseModel, FinalRequestOptions +from dataherald._response import APIResponse, AsyncAPIResponse +from dataherald._constants import RAW_RESPONSE_HEADER from dataherald._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from dataherald._base_client import ( DEFAULT_TIMEOUT, @@ -224,6 +226,7 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic # to_raw_response_wrapper leaks through the @functools.wraps() decorator. # # removing the decorator fixes the leak for reasons we don't understand. + "dataherald/_legacy_response.py", "dataherald/_response.py", # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. "dataherald/_compat.py", @@ -664,6 +667,25 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_streaming_response(self) -> None: + response = self.client.post( + "/api/database-connections", + body=dict(), + cast_to=APIResponse[bytes], + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, + ) + + assert not cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 1 + + for _ in response.iter_bytes(): + ... + + assert cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 0 + @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -674,7 +696,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No "/api/database-connections", body=dict(), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -689,7 +711,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non "/api/database-connections", body=dict(), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -870,6 +892,7 @@ def add_leak(leaks: list[tracemalloc.StatisticDiff], diff: tracemalloc.Statistic # to_raw_response_wrapper leaks through the @functools.wraps() decorator. # # removing the decorator fixes the leak for reasons we don't understand. + "dataherald/_legacy_response.py", "dataherald/_response.py", # pydantic.BaseModel.model_dump || pydantic.BaseModel.dict leak memory for some reason. "dataherald/_compat.py", @@ -1316,6 +1339,25 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte calculated = client._calculate_retry_timeout(remaining_retries, options, headers) assert calculated == pytest.approx(timeout, 0.5 * 0.875) # pyright: ignore[reportUnknownMemberType] + @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + async def test_streaming_response(self) -> None: + response = await self.client.post( + "/api/database-connections", + body=dict(), + cast_to=AsyncAPIResponse[bytes], + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, + ) + + assert not cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 1 + + async for _ in response.iter_bytes(): + ... + + assert cast(Any, response.is_closed) + assert _get_open_connections(self.client) == 0 + @mock.patch("dataherald._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: @@ -1326,7 +1368,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) "/api/database-connections", body=dict(), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 @@ -1341,7 +1383,7 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) "/api/database-connections", body=dict(), cast_to=httpx.Response, - options={"headers": {"X-Stainless-Streamed-Raw-Response": "true"}}, + options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) assert _get_open_connections(self.client) == 0 diff --git a/tests/test_response.py b/tests/test_response.py new file mode 100644 index 0000000..1359a2f --- /dev/null +++ b/tests/test_response.py @@ -0,0 +1,50 @@ +from typing import List + +import httpx +import pytest + +from dataherald._response import ( + APIResponse, + BaseAPIResponse, + AsyncAPIResponse, + BinaryAPIResponse, + AsyncBinaryAPIResponse, + extract_response_type, +) + + +class ConcreteBaseAPIResponse(APIResponse[bytes]): + ... + + +class ConcreteAPIResponse(APIResponse[List[str]]): + ... + + +class ConcreteAsyncAPIResponse(APIResponse[httpx.Response]): + ... + + +def test_extract_response_type_direct_classes() -> None: + assert extract_response_type(BaseAPIResponse[str]) == str + assert extract_response_type(APIResponse[str]) == str + assert extract_response_type(AsyncAPIResponse[str]) == str + + +def test_extract_response_type_direct_class_missing_type_arg() -> None: + with pytest.raises( + RuntimeError, + match="Expected type to have a type argument at index 0 but it did not", + ): + extract_response_type(AsyncAPIResponse) + + +def test_extract_response_type_concrete_subclasses() -> None: + assert extract_response_type(ConcreteBaseAPIResponse) == bytes + assert extract_response_type(ConcreteAPIResponse) == List[str] + assert extract_response_type(ConcreteAsyncAPIResponse) == httpx.Response + + +def test_extract_response_type_binary_response() -> None: + assert extract_response_type(BinaryAPIResponse) == bytes + assert extract_response_type(AsyncBinaryAPIResponse) == bytes diff --git a/tests/utils.py b/tests/utils.py index 39ffb87..e699732 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import inspect import traceback import contextlib from typing import Any, TypeVar, Iterator, cast @@ -68,6 +69,8 @@ def assert_matches_type( assert isinstance(value, bool) elif origin == float: assert isinstance(value, float) + elif origin == bytes: + assert isinstance(value, bytes) elif origin == datetime: assert isinstance(value, datetime) elif origin == date: @@ -100,6 +103,8 @@ def assert_matches_type( elif issubclass(origin, BaseModel): assert isinstance(value, type_) assert assert_matches_model(type_, cast(Any, value), path=path) + elif inspect.isclass(origin) and origin.__name__ == "HttpxBinaryResponseContent": + assert value.__class__.__name__ == "HttpxBinaryResponseContent" else: assert None, f"Unhandled field type: {type_}"