diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b4e9013..6db19b9 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.16.0" + ".": "0.17.0" } \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index 244e8fa..226ef59 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 0.17.0 (2024-02-09) + +Full Changelog: [v0.16.0...v0.17.0](https://github.com/Dataherald/dataherald-python/compare/v0.16.0...v0.17.0) + +### Features + +* OpenAPI spec update ([#38](https://github.com/Dataherald/dataherald-python/issues/38)) ([5852d74](https://github.com/Dataherald/dataherald-python/commit/5852d745b443ac13bf771c1bd4c507f66a91620c)) + ## 0.16.0 (2024-01-24) Full Changelog: [v0.15.0...v0.16.0](https://github.com/Dataherald/dataherald-python/compare/v0.15.0...v0.16.0) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..adcfe96 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,125 @@ +## Setting up the environment + +### With Rye + +We use [Rye](https://rye-up.com/) to manage dependencies so we highly recommend [installing it](https://rye-up.com/guide/installation/) as it will automatically provision a Python environment with the expected Python version. 
+ +After installing Rye, you'll just have to run this command: + +```sh +$ rye sync --all-features +``` + +You can then run scripts using `rye run python script.py` or by activating the virtual environment: + +```sh +$ rye shell +# or manually activate - https://docs.python.org/3/library/venv.html#how-venvs-work +$ source .venv/bin/activate + +# now you can omit the `rye run` prefix +$ python script.py +``` + +### Without Rye + +Alternatively, if you don't want to install `Rye`, you can stick with the standard `pip` setup: ensure you have the Python version specified in `.python-version`, create a virtual environment however you desire, and then install dependencies using this command: + +```sh +$ pip install -r requirements-dev.lock +``` + +## Modifying/Adding code + +Most of the SDK is generated code, and any modified code will be overwritten on the next generation. The +`src/dataherald/lib/` and `examples/` directories are exceptions and will never be overwritten. + +## Adding and running examples + +Files in the `examples/` directory are never modified by the Stainless generator and can be freely edited or +added to. + +```bash +# add an example to examples/<your-example>.py + +#!/usr/bin/env -S rye run python +… +``` + +``` +chmod +x examples/<your-example>.py +# run the example against your api +./examples/<your-example>.py +``` + +## Using the repository from source + +If you’d like to use the repository from source, you can either install from git or link to a cloned repository: + +To install via git: + +```bash +pip install git+ssh://git@github.com/Dataherald/dataherald-python.git +``` + +Alternatively, you can build from source and install the wheel file: + +Building this package will create two files in the `dist/` directory, a `.tar.gz` containing the source files and a `.whl` that can be used to install the package efficiently. 
+ +To create a distributable version of the library, all you have to do is run this command: + +```bash +rye build +# or +python -m build +``` + +Then to install: + +```sh +pip install ./path-to-wheel-file.whl +``` + +## Running tests + +Most tests will require you to [set up a mock server](https://github.com/stoplightio/prism) against the OpenAPI spec to run the tests. + +```bash +# you will need npm installed +npx prism mock path/to/your/openapi.yml +``` + +```bash +rye run pytest +``` + +## Linting and formatting + +This repository uses [ruff](https://github.com/astral-sh/ruff) and +[black](https://github.com/psf/black) to format the code in the repository. + +To lint: + +```bash +rye run lint +``` + +To format and fix all ruff issues automatically: + +```bash +rye run format +``` + +## Publishing and releases + +Changes made to this repository via the automated release PR pipeline should publish to PyPI automatically. If +the changes aren't made through the automated pipeline, you may want to make releases manually. + +### Publish with a GitHub workflow + +You can release to package managers by using [the `Publish PyPI` GitHub action](https://www.github.com/Dataherald/dataherald-python/actions/workflows/publish-pypi.yml). This requires an organization or repository secret to be set up. + +### Publish manually + +If you need to manually release a package, you can run the `bin/publish-pypi` script with a `PYPI_TOKEN` set in +the environment. diff --git a/README.md b/README.md index c640664..78c09e5 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ and offers both synchronous and asynchronous clients powered by [httpx](https:// ## Documentation -The REST API documentation can be found [on dataherald.readthedocs.io](https://dataherald.readthedocs.io/en/latest/). The full API of this library can be found in [api.md](https://www.github.com/Dataherald/dataherald-python/blob/main/api.md). 
+The REST API documentation can be found [on dataherald.readthedocs.io](https://dataherald.readthedocs.io/en/latest/). The full API of this library can be found in [api.md](api.md). ## Installation @@ -18,7 +18,7 @@ pip install dataherald ## Usage -The full API of this library can be found in [api.md](https://www.github.com/Dataherald/dataherald-python/blob/main/api.md). +The full API of this library can be found in [api.md](api.md). ```python import os diff --git a/pyproject.toml b/pyproject.toml index 896e58e..b029e57 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dataherald" -version = "0.16.0" +version = "0.17.0" description = "The official Python library for the Dataherald API" readme = "README.md" license = "Apache-2.0" @@ -72,6 +72,10 @@ format = { chain = [ "format:ruff" = "ruff format" "format:isort" = "isort ." +"lint" = { chain = [ + "check:ruff", + "typecheck", +]} "check:ruff" = "ruff ." "fix:ruff" = "ruff --fix ." @@ -144,6 +148,8 @@ select = [ # print statements "T201", "T203", + # misuse of typing.TYPE_CHECKING + "TCH004" ] ignore = [ # mutable defaults diff --git a/release-please-config.json b/release-please-config.json index 556e172..a85c35c 100644 --- a/release-please-config.json +++ b/release-please-config.json @@ -5,6 +5,8 @@ "$schema": "https://raw.githubusercontent.com/stainless-api/release-please/main/schemas/config.json", "include-v-in-tag": true, "include-component-in-tag": false, + "versioning": "prerelease", + "prerelease": true, "bump-minor-pre-major": true, "bump-patch-for-minor-pre-major": false, "pull-request-header": "Automated Release PR", diff --git a/src/dataherald/__init__.py b/src/dataherald/__init__.py index 270671c..35c80b9 100644 --- a/src/dataherald/__init__.py +++ b/src/dataherald/__init__.py @@ -15,6 +15,7 @@ RequestOptions, AsyncDataherald, ) +from ._models import BaseModel from ._version import __title__, __version__ from ._response import APIResponse as APIResponse, AsyncAPIResponse 
as AsyncAPIResponse from ._exceptions import ( @@ -66,6 +67,7 @@ "AsyncDataherald", "ENVIRONMENTS", "file_from_path", + "BaseModel", ] _setup_logging() diff --git a/src/dataherald/_base_client.py b/src/dataherald/_base_client.py index 5c1695e..0b5ece2 100644 --- a/src/dataherald/_base_client.py +++ b/src/dataherald/_base_client.py @@ -61,7 +61,7 @@ RequestOptions, ModelBuilderProtocol, ) -from ._utils import is_dict, is_given, is_mapping +from ._utils import is_dict, is_list, is_given, is_mapping from ._compat import model_copy, model_dump from ._models import GenericModel, FinalRequestOptions, validate_type, construct_type from ._response import ( @@ -450,14 +450,18 @@ def _build_request( headers = self._build_headers(options) params = _merge_mappings(self._custom_query, options.params) + content_type = headers.get("Content-Type") # If the given Content-Type header is multipart/form-data then it # has to be removed so that httpx can generate the header with # additional information for us as it has to be in this form # for the server to be able to correctly parse the request: # multipart/form-data; boundary=---abc-- - if headers.get("Content-Type") == "multipart/form-data": - headers.pop("Content-Type") + if content_type is not None and content_type.startswith("multipart/form-data"): + if "boundary" not in content_type: + # only remove the header if the boundary hasn't been explicitly set + # as the caller doesn't want httpx to come up with their own boundary + headers.pop("Content-Type") # As we are now sending multipart/form-data instead of application/json # we need to tell httpx to use it, https://www.python-httpx.org/advanced/#multipart-file-encoding @@ -493,9 +497,25 @@ def _serialize_multipartform(self, data: Mapping[object, object]) -> dict[str, o ) serialized: dict[str, object] = {} for key, value in items: - if key in serialized: - raise ValueError(f"Duplicate key encountered: {key}; This behaviour is not supported") - serialized[key] = value + existing 
= serialized.get(key) + + if not existing: + serialized[key] = value + continue + + # If a value has already been set for this key then that + # means we're sending data like `array[]=[1, 2, 3]` and we + # need to tell httpx that we want to send multiple values with + # the same key which is done by using a list or a tuple. + # + # Note: 2d arrays should never result in the same key at both + # levels so it's safe to assume that if the value is a list, + # it was because we changed it to be a list. + if is_list(existing): + existing.append(value) + else: + serialized[key] = [existing, value] + return serialized def _maybe_override_cast_to(self, cast_to: type[ResponseT], options: FinalRequestOptions) -> type[ResponseT]: @@ -1789,8 +1809,12 @@ def __str__(self) -> str: def get_platform() -> Platform: - system = platform.system().lower() - platform_name = platform.platform().lower() + try: + system = platform.system().lower() + platform_name = platform.platform().lower() + except Exception: + return "Unknown" + if "iphone" in platform_name or "ipad" in platform_name: # Tested using Python3IDE on an iPhone 11 and Pythonista on an iPad 7 # system is Darwin and platform_name is a string like: @@ -1833,8 +1857,8 @@ def platform_headers(version: str) -> Dict[str, str]: "X-Stainless-Package-Version": version, "X-Stainless-OS": str(get_platform()), "X-Stainless-Arch": str(get_architecture()), - "X-Stainless-Runtime": platform.python_implementation(), - "X-Stainless-Runtime-Version": platform.python_version(), + "X-Stainless-Runtime": get_python_runtime(), + "X-Stainless-Runtime-Version": get_python_version(), } @@ -1850,9 +1874,27 @@ def __str__(self) -> str: Arch = Union[OtherArch, Literal["x32", "x64", "arm", "arm64", "unknown"]] +def get_python_runtime() -> str: + try: + return platform.python_implementation() + except Exception: + return "unknown" + + +def get_python_version() -> str: + try: + return platform.python_version() + except Exception: + return "unknown" + + 
def get_architecture() -> Arch: - python_bitness, _ = platform.architecture() - machine = platform.machine().lower() + try: + python_bitness, _ = platform.architecture() + machine = platform.machine().lower() + except Exception: + return "unknown" + if machine in ("arm64", "aarch64"): return "arm64" diff --git a/src/dataherald/_response.py b/src/dataherald/_response.py index 9ef757c..d09fd45 100644 --- a/src/dataherald/_response.py +++ b/src/dataherald/_response.py @@ -16,25 +16,29 @@ Iterator, AsyncIterator, cast, + overload, ) from typing_extensions import Awaitable, ParamSpec, override, get_origin import anyio import httpx +import pydantic from ._types import NoneType from ._utils import is_given, extract_type_var_from_base from ._models import BaseModel, is_basemodel from ._constants import RAW_RESPONSE_HEADER, OVERRIDE_CAST_TO_HEADER +from ._streaming import Stream, AsyncStream, is_stream_class_type, extract_stream_chunk_type from ._exceptions import DataheraldError, APIResponseValidationError if TYPE_CHECKING: from ._models import FinalRequestOptions - from ._base_client import Stream, BaseClient, AsyncStream + from ._base_client import BaseClient P = ParamSpec("P") R = TypeVar("R") +_T = TypeVar("_T") _APIResponseT = TypeVar("_APIResponseT", bound="APIResponse[Any]") _AsyncAPIResponseT = TypeVar("_AsyncAPIResponseT", bound="AsyncAPIResponse[Any]") @@ -44,7 +48,7 @@ class BaseAPIResponse(Generic[R]): _cast_to: type[R] _client: BaseClient[Any, Any] - _parsed: R | None + _parsed_by_type: dict[type[Any], Any] _is_sse_stream: bool _stream_cls: type[Stream[Any]] | type[AsyncStream[Any]] | None _options: FinalRequestOptions @@ -63,7 +67,7 @@ def __init__( ) -> None: self._cast_to = cast_to self._client = client - self._parsed = None + self._parsed_by_type = {} self._is_sse_stream = stream self._stream_cls = stream_cls self._options = options @@ -116,8 +120,24 @@ def __repr__(self) -> str: f"<{self.__class__.__name__} [{self.status_code} 
{self.http_response.reason_phrase}] type={self._cast_to}>" ) - def _parse(self) -> R: + def _parse(self, *, to: type[_T] | None = None) -> R | _T: if self._is_sse_stream: + if to: + if not is_stream_class_type(to): + raise TypeError(f"Expected custom parse type to be a subclass of {Stream} or {AsyncStream}") + + return cast( + _T, + to( + cast_to=extract_stream_chunk_type( + to, + failure_message="Expected custom stream type to be passed with a type argument, e.g. Stream[ChunkType]", + ), + response=self.http_response, + client=cast(Any, self._client), + ), + ) + if self._stream_cls: return cast( R, @@ -141,7 +161,7 @@ def _parse(self) -> R: ), ) - cast_to = self._cast_to + cast_to = to if to is not None else self._cast_to if cast_to is NoneType: return cast(R, None) @@ -167,14 +187,11 @@ def _parse(self) -> R: raise ValueError(f"Subclasses of httpx.Response cannot be passed to `cast_to`") return cast(R, response) - # The check here is necessary as we are subverting the the type system - # with casts as the relationship between TypeVars and Types are very strict - # which means we must return *exactly* what was input or transform it in a - # way that retains the TypeVar state. As we cannot do that in this function - # then we have to resort to using `cast`. At the time of writing, we know this - # to be safe as we have handled all the types that could be bound to the - # `ResponseT` TypeVar, however if that TypeVar is ever updated in the future, then - # this function would become unsafe but a type checker would not report an error. + if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel): + raise TypeError( + "Pydantic models must subclass our base model type, e.g. 
`from dataherald import BaseModel`" + ) + if ( cast_to is not object and not origin is list @@ -183,12 +200,12 @@ def _parse(self) -> R: and not issubclass(origin, BaseModel) ): raise RuntimeError( - f"Invalid state, expected {cast_to} to be a subclass type of {BaseModel}, {dict}, {list} or {Union}." + f"Unsupported type, expected {cast_to} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}." ) # split is required to handle cases where additional information is included # in the response, e.g. application/json; charset=utf-8 - content_type, *_ = response.headers.get("content-type").split(";") + content_type, *_ = response.headers.get("content-type", "*").split(";") if content_type != "application/json": if is_basemodel(cast_to): try: @@ -224,22 +241,55 @@ def _parse(self) -> R: class APIResponse(BaseAPIResponse[R]): + @overload + def parse(self, *, to: type[_T]) -> _T: + ... + + @overload def parse(self) -> R: + ... + + def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. 
+ + ```py + from dataherald import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` """ - if self._parsed is not None: - return self._parsed + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] if not self._is_sse_stream: self.read() - parsed = self._parse() + parsed = self._parse(to=to) if is_given(self._options.post_parser): parsed = self._options.post_parser(parsed) - self._parsed = parsed + self._parsed_by_type[cache_key] = parsed return parsed def read(self) -> bytes: @@ -293,22 +343,55 @@ def iter_lines(self) -> Iterator[str]: class AsyncAPIResponse(BaseAPIResponse[R]): + @overload + async def parse(self, *, to: type[_T]) -> _T: + ... + + @overload async def parse(self) -> R: + ... + + async def parse(self, *, to: type[_T] | None = None) -> R | _T: """Returns the rich python representation of this response's data. For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`. + + You can customise the type that the response is parsed into through + the `to` argument, e.g. 
+ + ```py + from dataherald import BaseModel + + + class MyModel(BaseModel): + foo: str + + + obj = response.parse(to=MyModel) + print(obj.foo) + ``` + + We support parsing: + - `BaseModel` + - `dict` + - `list` + - `Union` + - `str` + - `httpx.Response` """ - if self._parsed is not None: - return self._parsed + cache_key = to if to is not None else self._cast_to + cached = self._parsed_by_type.get(cache_key) + if cached is not None: + return cached # type: ignore[no-any-return] if not self._is_sse_stream: await self.read() - parsed = self._parse() + parsed = self._parse(to=to) if is_given(self._options.post_parser): parsed = self._options.post_parser(parsed) - self._parsed = parsed + self._parsed_by_type[cache_key] = parsed return parsed async def read(self) -> bytes: @@ -704,26 +787,6 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> Awaitable[_AsyncAPIResponseT]: return wrapped -def extract_stream_chunk_type(stream_cls: type) -> type: - """Given a type like `Stream[T]`, returns the generic type variable `T`. - - This also handles the case where a concrete subclass is given, e.g. - ```py - class MyStream(Stream[bytes]): - ... - - extract_stream_chunk_type(MyStream) -> bytes - ``` - """ - from ._base_client import Stream, AsyncStream - - return extract_type_var_from_base( - stream_cls, - index=0, - generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), - ) - - def extract_response_type(typ: type[BaseAPIResponse[Any]]) -> type: """Given a type like `APIResponse[T]`, returns the generic type variable `T`. 
diff --git a/src/dataherald/_streaming.py b/src/dataherald/_streaming.py index e876a6e..d51d4cf 100644 --- a/src/dataherald/_streaming.py +++ b/src/dataherald/_streaming.py @@ -2,12 +2,15 @@ from __future__ import annotations import json +import inspect from types import TracebackType from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast -from typing_extensions import Self, override +from typing_extensions import Self, TypeGuard, override, get_origin import httpx +from ._utils import extract_type_var_from_base + if TYPE_CHECKING: from ._client import Dataherald, AsyncDataherald @@ -254,3 +257,34 @@ def decode(self, line: str) -> ServerSentEvent | None: pass # Field is ignored. return None + + +def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]: + """TypeGuard for determining whether or not the given type is a subclass of `Stream` / `AsyncStream`""" + origin = get_origin(typ) or typ + return inspect.isclass(origin) and issubclass(origin, (Stream, AsyncStream)) + + +def extract_stream_chunk_type( + stream_cls: type, + *, + failure_message: str | None = None, +) -> type: + """Given a type like `Stream[T]`, returns the generic type variable `T`. + + This also handles the case where a concrete subclass is given, e.g. + ```py + class MyStream(Stream[bytes]): + ... 
+ + extract_stream_chunk_type(MyStream) -> bytes + ``` + """ + from ._base_client import Stream, AsyncStream + + return extract_type_var_from_base( + stream_cls, + index=0, + generic_bases=cast("tuple[type, ...]", (Stream, AsyncStream)), + failure_message=failure_message, + ) diff --git a/src/dataherald/_utils/__init__.py b/src/dataherald/_utils/__init__.py index 0fb811a..b5790a8 100644 --- a/src/dataherald/_utils/__init__.py +++ b/src/dataherald/_utils/__init__.py @@ -9,6 +9,7 @@ is_mapping as is_mapping, is_tuple_t as is_tuple_t, parse_date as parse_date, + is_iterable as is_iterable, is_sequence as is_sequence, coerce_float as coerce_float, is_mapping_t as is_mapping_t, @@ -33,6 +34,7 @@ is_list_type as is_list_type, is_union_type as is_union_type, extract_type_arg as extract_type_arg, + is_iterable_type as is_iterable_type, is_required_type as is_required_type, is_annotated_type as is_annotated_type, strip_annotated_type as strip_annotated_type, diff --git a/src/dataherald/_utils/_transform.py b/src/dataherald/_utils/_transform.py index 3a1c149..2cb7726 100644 --- a/src/dataherald/_utils/_transform.py +++ b/src/dataherald/_utils/_transform.py @@ -9,11 +9,13 @@ from ._utils import ( is_list, is_mapping, + is_iterable, ) from ._typing import ( is_list_type, is_union_type, extract_type_arg, + is_iterable_type, is_required_type, is_annotated_type, strip_annotated_type, @@ -157,7 +159,12 @@ def _transform_recursive( if is_typeddict(stripped_type) and is_mapping(data): return _transform_typeddict(data, stripped_type) - if is_list_type(stripped_type) and is_list(data): + if ( + # List[T] + (is_list_type(stripped_type) and is_list(data)) + # Iterable[T] + or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str)) + ): inner_type = extract_type_arg(stripped_type, 0) return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data] diff --git a/src/dataherald/_utils/_typing.py b/src/dataherald/_utils/_typing.py 
index a020822..c036991 100644 --- a/src/dataherald/_utils/_typing.py +++ b/src/dataherald/_utils/_typing.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, TypeVar, cast +from typing import Any, TypeVar, Iterable, cast +from collections import abc as _c_abc from typing_extensions import Required, Annotated, get_args, get_origin from .._types import InheritsGeneric @@ -15,6 +16,12 @@ def is_list_type(typ: type) -> bool: return (get_origin(typ) or typ) == list +def is_iterable_type(typ: type) -> bool: + """If the given type is `typing.Iterable[T]`""" + origin = get_origin(typ) or typ + return origin == Iterable or origin == _c_abc.Iterable + + def is_union_type(typ: type) -> bool: return _is_union(get_origin(typ)) @@ -45,7 +52,13 @@ def extract_type_arg(typ: type, index: int) -> type: raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from err -def extract_type_var_from_base(typ: type, *, generic_bases: tuple[type, ...], index: int) -> type: +def extract_type_var_from_base( + typ: type, + *, + generic_bases: tuple[type, ...], + index: int, + failure_message: str | None = None, +) -> type: """Given a type like `Foo[T]`, returns the generic type variable `T`. This also handles the case where a concrete subclass is given, e.g. 
@@ -104,4 +117,4 @@ class MyResponse(Foo[_T]): return extracted - raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}") + raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}") diff --git a/src/dataherald/_utils/_utils.py b/src/dataherald/_utils/_utils.py index 1c5c21a..93c9551 100644 --- a/src/dataherald/_utils/_utils.py +++ b/src/dataherald/_utils/_utils.py @@ -164,6 +164,10 @@ def is_list(obj: object) -> TypeGuard[list[object]]: return isinstance(obj, list) +def is_iterable(obj: object) -> TypeGuard[Iterable[object]]: + return isinstance(obj, Iterable) + + def deepcopy_minimal(item: _T) -> _T: """Minimal reimplementation of copy.deepcopy() that will only copy certain object types: diff --git a/src/dataherald/_version.py b/src/dataherald/_version.py index c7dd448..4225dea 100644 --- a/src/dataherald/_version.py +++ b/src/dataherald/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. __title__ = "dataherald" -__version__ = "0.16.0" # x-release-please-version +__version__ = "0.17.0" # x-release-please-version diff --git a/src/dataherald/resources/finetunings.py b/src/dataherald/resources/finetunings.py index 7ada1ea..c561f3c 100644 --- a/src/dataherald/resources/finetunings.py +++ b/src/dataherald/resources/finetunings.py @@ -44,7 +44,7 @@ def create( db_connection_id: str, alias: str | NotGiven = NOT_GIVEN, base_llm: finetuning_create_params.BaseLlm | NotGiven = NOT_GIVEN, - golden_records: List[str] | NotGiven = NOT_GIVEN, + golden_sqls: List[str] | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -72,7 +72,7 @@ def create( "db_connection_id": db_connection_id, "alias": alias, "base_llm": base_llm, - "golden_records": golden_records, + "golden_sqls": golden_sqls, "metadata": metadata, }, finetuning_create_params.FinetuningCreateParams, @@ -202,7 +202,7 @@ async def create( db_connection_id: str, alias: str | NotGiven = NOT_GIVEN, base_llm: finetuning_create_params.BaseLlm | NotGiven = NOT_GIVEN, - golden_records: List[str] | NotGiven = NOT_GIVEN, + golden_sqls: List[str] | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -230,7 +230,7 @@ async def create( "db_connection_id": db_connection_id, "alias": alias, "base_llm": base_llm, - "golden_records": golden_records, + "golden_sqls": golden_sqls, "metadata": metadata, }, finetuning_create_params.FinetuningCreateParams, diff --git a/src/dataherald/resources/golden_sqls.py b/src/dataherald/resources/golden_sqls.py index 6a626ed..fb3b76c 100644 --- a/src/dataherald/resources/golden_sqls.py +++ b/src/dataherald/resources/golden_sqls.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List +from typing import Iterable import httpx @@ -154,7 +154,7 @@ def delete( def upload( self, *, - body: List[golden_sql_upload_params.Body], + body: Iterable[golden_sql_upload_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -308,7 +308,7 @@ async def delete( async def upload( self, *, - body: List[golden_sql_upload_params.Body], + body: Iterable[golden_sql_upload_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Headers | None = None, diff --git a/src/dataherald/resources/prompts/sql_generations.py b/src/dataherald/resources/prompts/sql_generations.py index 1092d98..c144752 100644 --- a/src/dataherald/resources/prompts/sql_generations.py +++ b/src/dataherald/resources/prompts/sql_generations.py @@ -42,6 +42,7 @@ def create( *, evaluate: bool | NotGiven = NOT_GIVEN, finetuning_id: str | NotGiven = NOT_GIVEN, + low_latency_mode: bool | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, sql: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -71,6 +72,7 @@ def create( { "evaluate": evaluate, "finetuning_id": finetuning_id, + "low_latency_mode": low_latency_mode, "metadata": metadata, "sql": sql, }, @@ -191,6 +193,7 @@ async def create( *, evaluate: bool | NotGiven = NOT_GIVEN, finetuning_id: str | NotGiven = NOT_GIVEN, + low_latency_mode: bool | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, sql: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
@@ -220,6 +223,7 @@ async def create( { "evaluate": evaluate, "finetuning_id": finetuning_id, + "low_latency_mode": low_latency_mode, "metadata": metadata, "sql": sql, }, diff --git a/src/dataherald/resources/table_descriptions.py b/src/dataherald/resources/table_descriptions.py index d7d015e..1d55aea 100644 --- a/src/dataherald/resources/table_descriptions.py +++ b/src/dataherald/resources/table_descriptions.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import List +from typing import Iterable import httpx @@ -77,9 +77,9 @@ def update( self, id: str, *, - columns: List[table_description_update_params.Column] | NotGiven = NOT_GIVEN, + columns: Iterable[table_description_update_params.Column] | NotGiven = NOT_GIVEN, description: str | NotGiven = NOT_GIVEN, - examples: List[object] | NotGiven = NOT_GIVEN, + examples: Iterable[object] | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -164,7 +164,7 @@ def list( def sync_schemas( self, *, - body: List[table_description_sync_schemas_params.Body], + body: Iterable[table_description_sync_schemas_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, @@ -240,9 +240,9 @@ async def update( self, id: str, *, - columns: List[table_description_update_params.Column] | NotGiven = NOT_GIVEN, + columns: Iterable[table_description_update_params.Column] | NotGiven = NOT_GIVEN, description: str | NotGiven = NOT_GIVEN, - examples: List[object] | NotGiven = NOT_GIVEN, + examples: Iterable[object] | NotGiven = NOT_GIVEN, metadata: object | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -327,7 +327,7 @@ async def list( async def sync_schemas( self, *, - body: List[table_description_sync_schemas_params.Body], + body: Iterable[table_description_sync_schemas_params.Body], # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
extra_headers: Headers | None = None, diff --git a/src/dataherald/types/finetuning_create_params.py b/src/dataherald/types/finetuning_create_params.py index 8bba060..1b07204 100644 --- a/src/dataherald/types/finetuning_create_params.py +++ b/src/dataherald/types/finetuning_create_params.py @@ -15,7 +15,7 @@ class FinetuningCreateParams(TypedDict, total=False): base_llm: BaseLlm - golden_records: List[str] + golden_sqls: List[str] metadata: object diff --git a/src/dataherald/types/golden_sql_upload_params.py b/src/dataherald/types/golden_sql_upload_params.py index 0280003..14a904c 100644 --- a/src/dataherald/types/golden_sql_upload_params.py +++ b/src/dataherald/types/golden_sql_upload_params.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import List +from typing import Iterable from typing_extensions import Required, TypedDict __all__ = ["GoldenSqlUploadParams", "Body"] class GoldenSqlUploadParams(TypedDict, total=False): - body: Required[List[Body]] + body: Required[Iterable[Body]] class Body(TypedDict, total=False): diff --git a/src/dataherald/types/prompts/sql_generation_create_params.py b/src/dataherald/types/prompts/sql_generation_create_params.py index e566cd5..2c0e7f8 100644 --- a/src/dataherald/types/prompts/sql_generation_create_params.py +++ b/src/dataherald/types/prompts/sql_generation_create_params.py @@ -12,6 +12,8 @@ class SqlGenerationCreateParams(TypedDict, total=False): finetuning_id: str + low_latency_mode: bool + metadata: object sql: str diff --git a/src/dataherald/types/prompts/sql_generation_nl_generations_params.py b/src/dataherald/types/prompts/sql_generation_nl_generations_params.py index ea1acad..754bafd 100644 --- a/src/dataherald/types/prompts/sql_generation_nl_generations_params.py +++ b/src/dataherald/types/prompts/sql_generation_nl_generations_params.py @@ -20,6 +20,8 @@ class SqlGeneration(TypedDict, total=False): finetuning_id: str + low_latency_mode: bool + metadata: object sql: str diff --git 
a/src/dataherald/types/table_description_sync_schemas_params.py b/src/dataherald/types/table_description_sync_schemas_params.py index 880570c..70f7846 100644 --- a/src/dataherald/types/table_description_sync_schemas_params.py +++ b/src/dataherald/types/table_description_sync_schemas_params.py @@ -2,14 +2,14 @@ from __future__ import annotations -from typing import List +from typing import List, Iterable from typing_extensions import Required, TypedDict __all__ = ["TableDescriptionSyncSchemasParams", "Body"] class TableDescriptionSyncSchemasParams(TypedDict, total=False): - body: Required[List[Body]] + body: Required[Iterable[Body]] class Body(TypedDict, total=False): diff --git a/src/dataherald/types/table_description_update_params.py b/src/dataherald/types/table_description_update_params.py index a6ec4ce..917b2f4 100644 --- a/src/dataherald/types/table_description_update_params.py +++ b/src/dataherald/types/table_description_update_params.py @@ -2,18 +2,18 @@ from __future__ import annotations -from typing import List +from typing import List, Iterable from typing_extensions import TypedDict __all__ = ["TableDescriptionUpdateParams", "Column"] class TableDescriptionUpdateParams(TypedDict, total=False): - columns: List[Column] + columns: Iterable[Column] description: str - examples: List[object] + examples: Iterable[object] metadata: object diff --git a/tests/api_resources/prompts/test_sql_generations.py b/tests/api_resources/prompts/test_sql_generations.py index c59c0b9..b82e35d 100644 --- a/tests/api_resources/prompts/test_sql_generations.py +++ b/tests/api_resources/prompts/test_sql_generations.py @@ -30,6 +30,7 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: "string", evaluate=True, finetuning_id="string", + low_latency_mode=True, metadata={}, sql="string", ) @@ -129,6 +130,7 @@ def test_method_nl_generations_with_all_params(self, client: Dataherald) -> None "string", sql_generation={ "finetuning_id": "string", + "low_latency_mode": 
True, "evaluate": True, "sql": "string", "metadata": {}, @@ -189,6 +191,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncDataherald "string", evaluate=True, finetuning_id="string", + low_latency_mode=True, metadata={}, sql="string", ) @@ -288,6 +291,7 @@ async def test_method_nl_generations_with_all_params(self, async_client: AsyncDa "string", sql_generation={ "finetuning_id": "string", + "low_latency_mode": True, "evaluate": True, "sql": "string", "metadata": {}, diff --git a/tests/api_resources/test_finetunings.py b/tests/api_resources/test_finetunings.py index 013371f..84e6b16 100644 --- a/tests/api_resources/test_finetunings.py +++ b/tests/api_resources/test_finetunings.py @@ -37,7 +37,7 @@ def test_method_create_with_all_params(self, client: Dataherald) -> None: "model_name": "string", "model_parameters": {"foo": "string"}, }, - golden_records=["string", "string", "string"], + golden_sqls=["string", "string", "string"], metadata={}, ) assert_matches_type(FinetuningResponse, finetuning, path=["response"]) @@ -194,7 +194,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncDataherald "model_name": "string", "model_parameters": {"foo": "string"}, }, - golden_records=["string", "string", "string"], + golden_sqls=["string", "string", "string"], metadata={}, ) assert_matches_type(FinetuningResponse, finetuning, path=["response"]) diff --git a/tests/test_client.py b/tests/test_client.py index 9038b17..a839e4d 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -434,6 +434,35 @@ def test_request_extra_query(self) -> None: params = dict(request.url.params) assert params == {"foo": "2"} + def test_multipart_repeating_array(self, client: Dataherald) -> None: + request = client._build_request( + FinalRequestOptions.construct( + method="get", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", 
b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + @pytest.mark.respx(base_url=base_url) def test_basic_union_response(self, respx_mock: MockRouter) -> None: class Model1(BaseModel): @@ -674,7 +703,7 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No with pytest.raises(APITimeoutError): self.client.post( "/api/database-connections", - body=dict(alias="string", connection_uri="string"), + body=cast(object, dict(alias="string", connection_uri="string")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -689,7 +718,7 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non with pytest.raises(APIStatusError): self.client.post( "/api/database-connections", - body=dict(alias="string", connection_uri="string"), + body=cast(object, dict(alias="string", connection_uri="string")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1081,6 +1110,35 @@ def test_request_extra_query(self) -> None: params = dict(request.url.params) assert params == {"foo": "2"} + def test_multipart_repeating_array(self, async_client: AsyncDataherald) -> None: + request = async_client._build_request( + FinalRequestOptions.construct( + method="get", + url="/foo", + headers={"Content-Type": "multipart/form-data; boundary=6b7ba517decee4a450543ea6ae821c82"}, + json_data={"array": ["foo", "bar"]}, + files=[("foo.txt", b"hello world")], + ) + ) + + assert request.read().split(b"\r\n") == [ + 
b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"foo", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="array[]"', + b"", + b"bar", + b"--6b7ba517decee4a450543ea6ae821c82", + b'Content-Disposition: form-data; name="foo.txt"; filename="upload"', + b"Content-Type: application/octet-stream", + b"", + b"hello world", + b"--6b7ba517decee4a450543ea6ae821c82--", + b"", + ] + @pytest.mark.respx(base_url=base_url) async def test_basic_union_response(self, respx_mock: MockRouter) -> None: class Model1(BaseModel): @@ -1327,7 +1385,7 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) with pytest.raises(APITimeoutError): await self.client.post( "/api/database-connections", - body=dict(alias="string", connection_uri="string"), + body=cast(object, dict(alias="string", connection_uri="string")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1342,7 +1400,7 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) with pytest.raises(APIStatusError): await self.client.post( "/api/database-connections", - body=dict(alias="string", connection_uri="string"), + body=cast(object, dict(alias="string", connection_uri="string")), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) diff --git a/tests/test_response.py b/tests/test_response.py index 1359a2f..29b4a69 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -1,8 +1,11 @@ +import json from typing import List import httpx import pytest +import pydantic +from dataherald import BaseModel, Dataherald, AsyncDataherald from dataherald._response import ( APIResponse, BaseAPIResponse, @@ -11,6 +14,8 @@ AsyncBinaryAPIResponse, extract_response_type, ) +from dataherald._streaming import Stream +from dataherald._base_client import FinalRequestOptions class ConcreteBaseAPIResponse(APIResponse[bytes]): @@ -48,3 +53,107 @@ 
def test_extract_response_type_concrete_subclasses() -> None: def test_extract_response_type_binary_response() -> None: assert extract_response_type(BinaryAPIResponse) == bytes assert extract_response_type(AsyncBinaryAPIResponse) == bytes + + +class PydanticModel(pydantic.BaseModel): + ... + + +def test_response_parse_mismatched_basemodel(client: Dataherald) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. `from dataherald import BaseModel`", + ): + response.parse(to=PydanticModel) + + +@pytest.mark.asyncio +async def test_async_response_parse_mismatched_basemodel(async_client: AsyncDataherald) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + with pytest.raises( + TypeError, + match="Pydantic models must subclass our base model type, e.g. 
`from dataherald import BaseModel`", + ): + await response.parse(to=PydanticModel) + + +def test_response_parse_custom_stream(client: Dataherald) -> None: + response = APIResponse( + raw=httpx.Response(200, content=b"foo"), + client=client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = response.parse(to=Stream[int]) + assert stream._cast_to == int + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_stream(async_client: AsyncDataherald) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=b"foo"), + client=async_client, + stream=True, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + stream = await response.parse(to=Stream[int]) + assert stream._cast_to == int + + +class CustomModel(BaseModel): + foo: str + bar: int + + +def test_response_parse_custom_model(client: Dataherald) -> None: + response = APIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = response.parse(to=CustomModel) + assert obj.foo == "hello!" + assert obj.bar == 2 + + +@pytest.mark.asyncio +async def test_async_response_parse_custom_model(async_client: AsyncDataherald) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=json.dumps({"foo": "hello!", "bar": 2})), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + obj = await response.parse(to=CustomModel) + assert obj.foo == "hello!" 
+ assert obj.bar == 2 diff --git a/tests/test_transform.py b/tests/test_transform.py index 5ebb55f..c67f48f 100644 --- a/tests/test_transform.py +++ b/tests/test_transform.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, List, Union, Optional +from typing import Any, List, Union, Iterable, Optional, cast from datetime import date, datetime from typing_extensions import Required, Annotated, TypedDict @@ -265,3 +265,35 @@ def test_pydantic_default_field() -> None: assert model.with_none_default == "bar" assert model.with_str_default == "baz" assert transform(model, Any) == {"with_none_default": "bar", "with_str_default": "baz"} + + +class TypedDictIterableUnion(TypedDict): + foo: Annotated[Union[Bar8, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +class Bar8(TypedDict): + foo_bar: Annotated[str, PropertyInfo(alias="fooBar")] + + +class Baz8(TypedDict): + foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")] + + +def test_iterable_of_dictionaries() -> None: + assert transform({"foo": [{"foo_baz": "bar"}]}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "bar"}]} + assert cast(Any, transform({"foo": ({"foo_baz": "bar"},)}, TypedDictIterableUnion)) == {"FOO": [{"fooBaz": "bar"}]} + + def my_iter() -> Iterable[Baz8]: + yield {"foo_baz": "hello"} + yield {"foo_baz": "world"} + + assert transform({"foo": my_iter()}, TypedDictIterableUnion) == {"FOO": [{"fooBaz": "hello"}, {"fooBaz": "world"}]} + + +class TypedDictIterableUnionStr(TypedDict): + foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")] + + +def test_iterable_union_str() -> None: + assert transform({"foo": "bar"}, TypedDictIterableUnionStr) == {"FOO": "bar"} + assert cast(Any, transform(iter([{"foo_baz": "bar"}]), Union[str, Iterable[Baz8]])) == [{"fooBaz": "bar"}]