Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions src/dataherald/_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@
from ._constants import (
DEFAULT_LIMITS,
DEFAULT_TIMEOUT,
MAX_RETRY_DELAY,
DEFAULT_MAX_RETRIES,
INITIAL_RETRY_DELAY,
RAW_RESPONSE_HEADER,
OVERRIDE_CAST_TO_HEADER,
)
Expand Down Expand Up @@ -589,47 +591,57 @@ def base_url(self, url: URL | str) -> None:
def platform_headers(self) -> Dict[str, str]:
return platform_headers(self._version)

def _parse_retry_after_header(self, response_headers: Optional[httpx.Headers] = None) -> float | None:
"""Returns a float of the number of seconds (not milliseconds) to wait after retrying, or None if unspecified.

About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
See also https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax
"""
if response_headers is None:
return None

# First, try the non-standard `retry-after-ms` header for milliseconds,
# which is more precise than integer-seconds `retry-after`
try:
retry_ms_header = response_headers.get("retry-after-ms", None)
return float(retry_ms_header) / 1000
except (TypeError, ValueError):
pass

# Next, try parsing `retry-after` header as seconds (allowing nonstandard floats).
retry_header = response_headers.get("retry-after")
try:
# note: the spec indicates that this should only ever be an integer
# but if someone sends a float there's no reason for us to not respect it
return float(retry_header)
except (TypeError, ValueError):
pass

# Last, try parsing `retry-after` as a date.
retry_date_tuple = email.utils.parsedate_tz(retry_header)
if retry_date_tuple is None:
return None

retry_date = email.utils.mktime_tz(retry_date_tuple)
return float(retry_date - time.time())

def _calculate_retry_timeout(
self,
remaining_retries: int,
options: FinalRequestOptions,
response_headers: Optional[httpx.Headers] = None,
) -> float:
max_retries = options.get_max_retries(self.max_retries)
try:
# About the Retry-After header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After
#
# <http-date>". See https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After#syntax for
# details.
if response_headers is not None:
retry_header = response_headers.get("retry-after")
try:
# note: the spec indicates that this should only ever be an integer
# but if someone sends a float there's no reason for us to not respect it
retry_after = float(retry_header)
except Exception:
retry_date_tuple = email.utils.parsedate_tz(retry_header)
if retry_date_tuple is None:
retry_after = -1
else:
retry_date = email.utils.mktime_tz(retry_date_tuple)
retry_after = int(retry_date - time.time())
else:
retry_after = -1

except Exception:
retry_after = -1

# If the API asks us to wait a certain amount of time (and it's a reasonable amount), just do what it says.
if 0 < retry_after <= 60:
retry_after = self._parse_retry_after_header(response_headers)
if retry_after is not None and 0 < retry_after <= 60:
return retry_after

initial_retry_delay = 0.5
max_retry_delay = 8.0
nb_retries = max_retries - remaining_retries

# Apply exponential backoff, but not more than the max.
sleep_seconds = min(initial_retry_delay * pow(2.0, nb_retries), max_retry_delay)
sleep_seconds = min(INITIAL_RETRY_DELAY * pow(2.0, nb_retries), MAX_RETRY_DELAY)

# Apply some jitter, plus-or-minus half a second.
jitter = 1 - 0.25 * random()
Expand Down
39 changes: 38 additions & 1 deletion src/dataherald/_compat.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Union, TypeVar, cast
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
from datetime import date, datetime
from typing_extensions import Self

import pydantic
from pydantic.fields import FieldInfo

from ._types import StrBytesIntFloat

_T = TypeVar("_T")
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)

# --------------- Pydantic v2 compatibility ---------------
Expand Down Expand Up @@ -178,8 +180,43 @@ class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel):
# cached properties
if TYPE_CHECKING:
cached_property = property

# we define a separate type (copied from typeshed)
# that represents that `cached_property` is `set`able
# at runtime, which differs from `@property`.
#
# this is a separate type as editors likely special case
# `@property` and we don't want to cause issues just to have
# more helpful internal types.

class typed_cached_property(Generic[_T]):
func: Callable[[Any], _T]
attrname: str | None

def __init__(self, func: Callable[[Any], _T]) -> None:
...

@overload
def __get__(self, instance: None, owner: type[Any] | None = None) -> Self:
...

@overload
def __get__(self, instance: object, owner: type[Any] | None = None) -> _T:
...

def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
raise NotImplementedError()

def __set_name__(self, owner: type[Any], name: str) -> None:
...

# __set__ is not defined at runtime, but @cached_property is designed to be settable
def __set__(self, instance: object, value: _T) -> None:
...
else:
try:
from functools import cached_property as cached_property
except ImportError:
from cached_property import cached_property as cached_property

typed_cached_property = cached_property
3 changes: 3 additions & 0 deletions src/dataherald/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@
DEFAULT_TIMEOUT = httpx.Timeout(timeout=60.0, connect=5.0)
DEFAULT_MAX_RETRIES = 2
DEFAULT_LIMITS = httpx.Limits(max_connections=100, max_keepalive_connections=20)

INITIAL_RETRY_DELAY = 0.5
MAX_RETRY_DELAY = 8.0
1 change: 1 addition & 0 deletions src/dataherald/_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._sync import asyncify as asyncify
from ._proxy import LazyProxy as LazyProxy
from ._utils import (
flatten as flatten,
Expand Down
20 changes: 2 additions & 18 deletions src/dataherald/_utils/_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from abc import ABC, abstractmethod
from typing import Generic, TypeVar, Iterable, cast
from typing_extensions import ClassVar, override
from typing_extensions import override

T = TypeVar("T")

Expand All @@ -13,11 +13,6 @@ class LazyProxy(Generic[T], ABC):
This includes forwarding attribute access and othe methods.
"""

should_cache: ClassVar[bool] = False

def __init__(self) -> None:
self.__proxied: T | None = None

# Note: we have to special case proxies that themselves return proxies
# to support using a proxy as a catch-all for any random access, e.g. `proxy.foo.bar.baz`

Expand Down Expand Up @@ -57,18 +52,7 @@ def __class__(self) -> type:
return proxied.__class__

def __get_proxied__(self) -> T:
if not self.should_cache:
return self.__load__()

proxied = self.__proxied
if proxied is not None:
return proxied

self.__proxied = proxied = self.__load__()
return proxied

def __set_proxied__(self, value: T) -> None:
self.__proxied = value
return self.__load__()

def __as_proxied__(self) -> T:
"""Helper method that returns the current proxy, typed as the loaded object"""
Expand Down
64 changes: 64 additions & 0 deletions src/dataherald/_utils/_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from __future__ import annotations

import functools
from typing import TypeVar, Callable, Awaitable
from typing_extensions import ParamSpec

import anyio
import anyio.to_thread

T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")


# copied from `asyncer`, https://github.com/tiangolo/asyncer
def asyncify(
function: Callable[T_ParamSpec, T_Retval],
*,
cancellable: bool = False,
limiter: anyio.CapacityLimiter | None = None,
) -> Callable[T_ParamSpec, Awaitable[T_Retval]]:
"""
Take a blocking function and create an async one that receives the same
positional and keyword arguments, and that when called, calls the original function
in a worker thread using `anyio.to_thread.run_sync()`. Internally,
`asyncer.asyncify()` uses the same `anyio.to_thread.run_sync()`, but it supports
keyword arguments additional to positional arguments and it adds better support for
autocompletion and inline errors for the arguments of the function called and the
return value.

If the `cancellable` option is enabled and the task waiting for its completion is
cancelled, the thread will still run its course but its return value (or any raised
exception) will be ignored.

Use it like this:

```Python
def do_work(arg1, arg2, kwarg1="", kwarg2="") -> str:
# Do work
return "Some result"


result = await to_thread.asyncify(do_work)("spam", "ham", kwarg1="a", kwarg2="b")
print(result)
```

## Arguments

`function`: a blocking regular callable (e.g. a function)
`cancellable`: `True` to allow cancellation of the operation
`limiter`: capacity limiter to use to limit the total amount of threads running
(if omitted, the default limiter is used)

## Return

An async function that takes the same positional and keyword arguments as the
original one, that when called runs the same original function in a thread worker
and returns the result.
"""

async def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval:
partial_f = functools.partial(function, *args, **kwargs)
return await anyio.to_thread.run_sync(partial_f, cancellable=cancellable, limiter=limiter)

return wrapper
31 changes: 29 additions & 2 deletions src/dataherald/_utils/_typing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, cast
from typing import Any, TypeVar, cast
from typing_extensions import Required, Annotated, get_args, get_origin

from .._types import InheritsGeneric
Expand All @@ -23,6 +23,12 @@ def is_required_type(typ: type) -> bool:
return get_origin(typ) == Required


def is_typevar(typ: type) -> bool:
# type ignore is required because type checkers
# think this expression will always return False
return type(typ) == TypeVar # type: ignore


# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
if is_required_type(typ) or is_annotated_type(typ):
Expand All @@ -49,6 +55,15 @@ class MyResponse(Foo[bytes]):

extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
```

And where a generic subclass is given:
```py
_T = TypeVar('_T')
class MyResponse(Foo[_T]):
...

extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
```
"""
cls = cast(object, get_origin(typ) or typ)
if cls in generic_bases:
Expand All @@ -75,6 +90,18 @@ class MyResponse(Foo[bytes]):
f"Does {cls} inherit from one of {generic_bases} ?"
)

return extract_type_arg(target_base_class, index)
extracted = extract_type_arg(target_base_class, index)
if is_typevar(extracted):
# If the extracted type argument is itself a type variable
# then that means the subclass itself is generic, so we have
# to resolve the type argument from the class itself, not
# the base class.
#
# Note: if there is more than 1 type argument, the subclass could
# change the ordering of the type arguments, this is not currently
# supported.
return extract_type_arg(typ, index)

return extracted

raise RuntimeError(f"Could not resolve inner type variable at index {index} for {typ}")
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ async def list(

class DatabaseConnectionsWithRawResponse:
def __init__(self, database_connections: DatabaseConnections) -> None:
self.drivers = DriversWithRawResponse(database_connections.drivers)
self._database_connections = database_connections

self.create = to_raw_response_wrapper(
database_connections.create,
Expand All @@ -388,10 +388,14 @@ def __init__(self, database_connections: DatabaseConnections) -> None:
database_connections.list,
)

@cached_property
def drivers(self) -> DriversWithRawResponse:
return DriversWithRawResponse(self._database_connections.drivers)


class AsyncDatabaseConnectionsWithRawResponse:
def __init__(self, database_connections: AsyncDatabaseConnections) -> None:
self.drivers = AsyncDriversWithRawResponse(database_connections.drivers)
self._database_connections = database_connections

self.create = async_to_raw_response_wrapper(
database_connections.create,
Expand All @@ -406,10 +410,14 @@ def __init__(self, database_connections: AsyncDatabaseConnections) -> None:
database_connections.list,
)

@cached_property
def drivers(self) -> AsyncDriversWithRawResponse:
return AsyncDriversWithRawResponse(self._database_connections.drivers)


class DatabaseConnectionsWithStreamingResponse:
def __init__(self, database_connections: DatabaseConnections) -> None:
self.drivers = DriversWithStreamingResponse(database_connections.drivers)
self._database_connections = database_connections

self.create = to_streamed_response_wrapper(
database_connections.create,
Expand All @@ -424,10 +432,14 @@ def __init__(self, database_connections: DatabaseConnections) -> None:
database_connections.list,
)

@cached_property
def drivers(self) -> DriversWithStreamingResponse:
return DriversWithStreamingResponse(self._database_connections.drivers)


class AsyncDatabaseConnectionsWithStreamingResponse:
def __init__(self, database_connections: AsyncDatabaseConnections) -> None:
self.drivers = AsyncDriversWithStreamingResponse(database_connections.drivers)
self._database_connections = database_connections

self.create = async_to_streamed_response_wrapper(
database_connections.create,
Expand All @@ -441,3 +453,7 @@ def __init__(self, database_connections: AsyncDatabaseConnections) -> None:
self.list = async_to_streamed_response_wrapper(
database_connections.list,
)

@cached_property
def drivers(self) -> AsyncDriversWithStreamingResponse:
return AsyncDriversWithStreamingResponse(self._database_connections.drivers)
Loading