This repository was archived by the owner on Nov 1, 2022. It is now read-only.
Merged
20 changes: 18 additions & 2 deletions layer/__init__.py
@@ -1,3 +1,7 @@
+from typing import TYPE_CHECKING
+
+import lazy_loader  # noqa
+
 from .context import Context  # noqa
 from .contracts.datasets import Dataset  # noqa
 from .contracts.logged_data import Image, Markdown, Video  # noqa
@@ -21,8 +25,6 @@
 )
 from .flavors.custom import CustomModel  # noqa
 from .global_context import current_project_full_name  # noqa
-from .logged_data.callbacks import KerasCallback, XGBoostCallback  # noqa
-from .logged_data.loggers.pytorch_lightning import PytorchLightningLogger  # noqa
 from .main.asset import get_dataset, get_model, save_model  # noqa
 from .main.auth import (  # noqa
     login,
@@ -39,6 +41,20 @@
 from .pandas_extensions import Arrays, Images, _register_type_extensions  # noqa
 
+
+# keep the existing type definitions only for autocompletions in the editors and type checks
+if TYPE_CHECKING:
+    from .logged_data.callbacks import KerasCallback, XGBoostCallback  # noqa
+    from .logged_data.loggers.pytorch_lightning import PytorchLightningLogger  # noqa
+
+# patch __getattr__, __dir__ and __all__ to lazy load the symbols only when they're required
+__getattr__, __dir__, __all__ = lazy_loader.attach(
+    __name__,
+    submod_attrs={
+        "logged_data.callbacks": ["KerasCallback", "XGBoostCallback"],
+        "logged_data.loggers.pytorch_lightning": ["PytorchLightningLogger"],
+    },
+)
 
 _register_type_extensions()
 
 __version__ = get_version()
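
Note on the pattern above: lazy_loader.attach (from the scientific-python lazy_loader package) returns a module-level __getattr__ hook (PEP 562), so accessing layer.KerasCallback imports layer.logged_data.callbacks only on first use, which keeps "import layer" fast. A minimal sketch of the same mechanism without the library, for a hypothetical package mypkg whose heavy submodule mypkg.heavy exposes Thing:

# mypkg/__init__.py -- hand-rolled version of what lazy_loader.attach generates
import importlib
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers and IDEs; never executed at runtime.
    from .heavy import Thing  # noqa

_LAZY_ATTRS = {"Thing": "mypkg.heavy"}  # exported symbol -> providing submodule


def __getattr__(name: str):
    # Called only when normal module attribute lookup fails, i.e. on first access.
    if name in _LAZY_ATTRS:
        module = importlib.import_module(_LAZY_ATTRS[name])
        value = getattr(module, name)
        globals()[name] = value  # cache so the hook runs at most once per symbol
        return value
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__():
    return sorted(list(globals()) + list(_LAZY_ATTRS))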
19 changes: 13 additions & 6 deletions layer/clients/data_catalog.py
@@ -3,10 +3,8 @@
 import uuid
 from logging import Logger
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Generator, List, Optional, Tuple
 
-import pandas
-import pyarrow
 from layerapi.api.entity.dataset_build_pb2 import DatasetBuild as PBDatasetBuild
 from layerapi.api.entity.dataset_pb2 import Dataset as PBDataset
 from layerapi.api.entity.dataset_version_pb2 import DatasetVersion as PBDatasetVersion
@@ -46,7 +44,6 @@
 from layer.contracts.datasets import Dataset, DatasetBuild, DatasetBuildStatus
 from layer.contracts.project_full_name import ProjectFullName
 from layer.exceptions.exceptions import LayerClientException
-from layer.pandas_extensions import _infer_custom_types
 from layer.utils.file_utils import tar_directory
 from layer.utils.grpc import generate_client_error_from_grpc_error
 from layer.utils.grpc.channel import get_grpc_channel
@@ -55,6 +52,11 @@
 from .dataset_service import DatasetClient, DatasetClientError
 
 
+if TYPE_CHECKING:
+    import pandas
+    import pyarrow
+
+
 class DataCatalogClient:
     _service: DataCatalogAPIStub
 
@@ -95,6 +97,8 @@ def _get_python_dataset_access_credentials(
     def fetch_dataset(
         self, asset_path: AssetPath, no_cache: bool = False
     ) -> "pandas.DataFrame":
+        import pandas
+
         data_ticket = DataTicket(
             dataset_path_ticket=DatasetPathTicket(path=asset_path.path()),
         )
@@ -139,6 +143,9 @@ def store_dataset(
         :param build_id: dataset build id
         :param progress_callback: progress callback
         """
+        import pyarrow
+
+        from layer.pandas_extensions import _infer_custom_types
 
         # Creates a Record batch from the pandas dataframe
         batch = pyarrow.RecordBatch.from_pandas(
@@ -399,8 +406,8 @@ def _language_version() -> Tuple[int, int, int]:
 
 
 def _get_batch_chunks(
-    batch: pyarrow.RecordBatch, max_chunk_size_bytes: int = 4_000_000
-) -> Generator[pyarrow.RecordBatch, None, None]:
+    batch: "pyarrow.RecordBatch", max_chunk_size_bytes: int = 4_000_000
+) -> Generator["pyarrow.RecordBatch", None, None]:
     """
     Slice the batch into chunks, based on average row size,
     but not exceeding the maximum chunk size.
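
The body of _get_batch_chunks is collapsed above; per its docstring it slices the RecordBatch by average row size without exceeding the maximum chunk size. A hedged sketch of that logic (the variable names are illustrative, not the actual implementation):

from typing import Generator

import pyarrow


def _get_batch_chunks(
    batch: "pyarrow.RecordBatch", max_chunk_size_bytes: int = 4_000_000
) -> Generator["pyarrow.RecordBatch", None, None]:
    # Estimate how many rows fit in one chunk from the batch's average row size.
    avg_row_size = max(batch.nbytes // max(batch.num_rows, 1), 1)
    rows_per_chunk = max(max_chunk_size_bytes // avg_row_size, 1)
    for offset in range(0, batch.num_rows, rows_per_chunk):
        # RecordBatch.slice is zero-copy; it only adjusts offset and length.
        yield batch.slice(offset, rows_per_chunk)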
20 changes: 14 additions & 6 deletions layer/clients/dataset_service.py
@@ -2,10 +2,8 @@
 import uuid
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple, Union
 
-import pandas
-import pyarrow
 from layerapi.api.service.dataset.dataset_api_pb2 import Command
 from layerapi.api.value.ticket_pb2 import PartitionTicket
 from pyarrow import flight as fl
@@ -15,6 +13,10 @@
 from layer.utils.grpc import create_grpc_ssl_config
 
 
+if TYPE_CHECKING:
+    import pandas
+
+
 class DatasetClientError(Exception):
     pass
 
@@ -47,13 +49,15 @@ class PartitionMetadata:
 class Partition:
     def __init__(
         self,
-        reader: Union[fl.FlightStreamReader, pandas.DataFrame],
+        reader: Union[fl.FlightStreamReader, "pandas.DataFrame"],
         from_cache: bool = False,
     ):
         self._reader = reader
         self._from_cache = from_cache
 
-    def to_pandas(self) -> pandas.DataFrame:
+    def to_pandas(self) -> "pandas.DataFrame":
+        import pandas
+
         if isinstance(self._reader, pandas.DataFrame):
             return self._reader
         return self._reader.read_pandas()
@@ -104,6 +108,8 @@ def __init__(self, address_and_port: str, access_token: str) -> None:
         )
 
     def health_check(self) -> str:
+        import pyarrow
+
         buf = pyarrow.allocate_buffer(0)
         action = fl.Action("HealthCheck", buf)
         result = next(self._flight.do_action(action))
@@ -162,5 +168,7 @@ def get_dataset_writer(self, command: Command, schema: Any) -> Any:
         return self._flight.do_put(descriptor, schema)
 
 
-def _read_parquet(path: Union[str, Path]) -> pandas.DataFrame:
+def _read_parquet(path: Union[str, Path]) -> "pandas.DataFrame":
+    import pandas
+
     return pandas.read_parquet(path, engine="pyarrow")
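
All of these files follow the same three-step recipe: move the heavy import under TYPE_CHECKING, quote the annotations that referenced it, and import for real inside the function body on first call. A distilled sketch of the trade-off (load_table is a hypothetical function, not from this PR):

from pathlib import Path
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    import pandas  # resolved by type checkers only; zero runtime cost


def load_table(path: Union[str, Path]) -> "pandas.DataFrame":
    # The annotation must be quoted: unquoted, Python would evaluate
    # pandas.DataFrame when the def statement runs, before any import.
    import pandas  # deferred: paid on the first call, not at module import

    return pandas.read_parquet(path, engine="pyarrow")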
23 changes: 19 additions & 4 deletions layer/contracts/datasets.py
@@ -1,17 +1,30 @@
 import enum
 import uuid
 from dataclasses import dataclass, field
-from typing import Any, Callable, List, Mapping, Optional, Sequence, Union
-
-import pandas
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Union,
+)
 
 from layer.contracts.logged_data import LoggedDataObject
 from layer.logged_data.log_data_runner import LogDataRunner
 
 from .asset import AssetPath, AssetType, BaseAsset
 
 
+if TYPE_CHECKING:
+    import pandas
+
+
 def _create_empty_data_frame() -> "pandas.DataFrame":
+    import pandas
+
     return pandas.DataFrame()
@@ -149,7 +162,9 @@ def to_pytorch(
         class PytorchDataset(torch.utils.data.Dataset[Any]):
             # TODO: Streaming data fetching for faster data access
 
-            def __init__(self, df: pandas.DataFrame, transformer: Callable[[Any], Any]):
+            def __init__(
+                self, df: "pandas.DataFrame", transformer: Callable[[Any], Any]
+            ):
                 self.df = df
                 self.transformer = transformer
 
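
The rest of PytorchDataset is collapsed above. A plausible completion under the standard map-style Dataset protocol (the __len__ and __getitem__ bodies are assumptions, not part of this diff):

from typing import TYPE_CHECKING, Any, Callable

import torch.utils.data

if TYPE_CHECKING:
    import pandas


class PytorchDataset(torch.utils.data.Dataset[Any]):
    def __init__(self, df: "pandas.DataFrame", transformer: Callable[[Any], Any]):
        self.df = df
        self.transformer = transformer

    def __len__(self) -> int:
        return len(self.df)

    def __getitem__(self, index: int) -> Any:
        # Hand one dataframe row to the user-supplied transformer.
        return self.transformer(self.df.iloc[index])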
9 changes: 6 additions & 3 deletions layer/contracts/models.py
@@ -1,8 +1,7 @@
 import uuid
 from dataclasses import dataclass
-from typing import Optional, Sequence, Union
+from typing import TYPE_CHECKING, Optional, Sequence, Union
 
-import pandas as pd
 from layerapi.api.ids_pb2 import ModelTrainId
 from layerapi.api.value.aws_credentials_pb2 import AwsCredentials
 from layerapi.api.value.s3_path_pb2 import S3Path
@@ -16,6 +15,10 @@
 from .asset import AssetPath, AssetType, BaseAsset
 
 
+if TYPE_CHECKING:
+    import pandas as pd
+
+
 @dataclass(frozen=True)
 class TrainStorageConfiguration:
     train_id: ModelTrainId
@@ -97,7 +100,7 @@ def get_train(self) -> ModelObject:
         """
         return self.model_object
 
-    def predict(self, input_df: pd.DataFrame) -> pd.DataFrame:
+    def predict(self, input_df: "pd.DataFrame") -> "pd.DataFrame":
         """
         Performs prediction on the input dataframe data.
         :return: the predictions as a pd.DataFrame
5 changes: 2 additions & 3 deletions layer/flavors/base.py
@@ -4,19 +4,18 @@
 from pathlib import Path
 from typing import TYPE_CHECKING, Callable, Optional
 
-import pandas as pd
-
 from layer.types import ModelObject
 
 
 if TYPE_CHECKING:
+    import pandas as pd
     from layerapi.api.value.model_flavor_pb2 import ModelFlavor as PBModelFlavor
 
 
 @dataclass(frozen=True)
 class ModelRuntimeObjects:
     model_object: ModelObject
-    prediction_function: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None
+    prediction_function: Optional[Callable[["pd.DataFrame"], "pd.DataFrame"]] = None
 
 
 class ModelFlavor(metaclass=ABCMeta):
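
The load_model_from_directory bodies are collapsed in the flavor diffs below, so the following is only an assumption about how a concrete flavor wires its private __predict into ModelRuntimeObjects; ExampleFlavor and _load_from are hypothetical names:

from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING

from layer.flavors.base import ModelFlavor, ModelRuntimeObjects

if TYPE_CHECKING:
    import pandas as pd


class ExampleFlavor(ModelFlavor):  # hypothetical flavor, for illustration only
    def load_model_from_directory(self, directory: Path) -> ModelRuntimeObjects:
        model = self._load_from(directory)  # hypothetical loading helper
        # Bind the loaded model into the prediction callable now; pandas is
        # only touched once the callable is invoked with a DataFrame.
        return ModelRuntimeObjects(
            model_object=model,
            prediction_function=partial(ExampleFlavor.__predict, model),
        )

    @staticmethod
    def __predict(model, input_df: "pd.DataFrame") -> "pd.DataFrame":
        return model.predict(input_df)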
10 changes: 8 additions & 2 deletions layer/flavors/catboost.py
@@ -1,13 +1,17 @@
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import pandas as pd
 from layerapi.api.value.model_flavor_pb2 import ModelFlavor as PbModelFlavor
 
 from layer.types import ModelObject
 
 from .base import ModelFlavor, ModelRuntimeObjects
 
 
+if TYPE_CHECKING:
+    import pandas as pd
+
+
 class CatBoostModelFlavor(ModelFlavor):
     """An ML Model flavor implementation which handles persistence of CatBoost Models."""
 
@@ -30,6 +34,8 @@ def load_model_from_directory(self, directory: Path) -> ModelRuntimeObjects:
         )
 
     @staticmethod
-    def __predict(model: ModelObject, input_df: pd.DataFrame) -> pd.DataFrame:
+    def __predict(model: ModelObject, input_df: "pd.DataFrame") -> "pd.DataFrame":
+        import pandas as pd
+
         prediction_np_array = model.predict(input_df)  # type: ignore
         return pd.DataFrame(prediction_np_array)
13 changes: 9 additions & 4 deletions layer/flavors/custom.py
@@ -1,16 +1,19 @@
 import pickle  # nosec
 from abc import abstractmethod
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-import pandas
 from layerapi.api.value.model_flavor_pb2 import ModelFlavor as PbModelFlavor
 
 from layer.types import ModelObject
 
 from .base import ModelFlavor, ModelRuntimeObjects
 
 
+if TYPE_CHECKING:
+    import pandas
+
+
 class CustomModel:
     """
     A generic model that evaluates inputs and produces outputs.
@@ -22,7 +25,7 @@ def __init__(self) -> None:
         """
 
     @abstractmethod
-    def predict(self, model_input: pandas.DataFrame) -> pandas.DataFrame:
+    def predict(self, model_input: "pandas.DataFrame") -> "pandas.DataFrame":
         """
         Evaluates an input for this model and produces an output.
 
@@ -124,5 +127,7 @@ def load_model_from_directory(self, directory: Path) -> ModelRuntimeObjects:
         )
 
     @staticmethod
-    def __predict(model: ModelObject, input_df: pandas.DataFrame) -> pandas.DataFrame:
+    def __predict(
+        model: ModelObject, input_df: "pandas.DataFrame"
+    ) -> "pandas.DataFrame":
         return model.predict(input_df)  # type: ignore
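
CustomModel is the user-facing extension point: subclass it and implement predict. A usage sketch (DoubleModel is a toy example, not from this PR):

import pandas

from layer import CustomModel  # re-exported in layer/__init__.py, as shown above


class DoubleModel(CustomModel):
    """Toy model that doubles every numeric input column."""

    def predict(self, model_input: "pandas.DataFrame") -> "pandas.DataFrame":
        return model_input * 2


model = DoubleModel()
print(model.predict(pandas.DataFrame({"x": [1, 2, 3]})))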
8 changes: 6 additions & 2 deletions layer/flavors/huggingface.py
@@ -1,13 +1,17 @@
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import pandas as pd
 from layerapi.api.value.model_flavor_pb2 import ModelFlavor as PbModelFlavor
 
 from layer.types import ModelObject
 
 from .base import ModelFlavor, ModelRuntimeObjects
 
 
+if TYPE_CHECKING:
+    import pandas as pd
+
+
 class HuggingFaceModelFlavor(ModelFlavor):
     """An ML Model flavor implementation which handles persistence of Hugging Face Transformer Models."""
 
@@ -40,5 +44,5 @@ def load_model_from_directory(self, directory: Path) -> ModelRuntimeObjects:
         )
 
     @staticmethod
-    def __predict(model: ModelObject, input_df: pd.DataFrame) -> pd.DataFrame:
+    def __predict(model: ModelObject, input_df: "pd.DataFrame") -> "pd.DataFrame":
         raise Exception("Not implemented")
10 changes: 8 additions & 2 deletions layer/flavors/keras.py
@@ -1,13 +1,17 @@
 from pathlib import Path
+from typing import TYPE_CHECKING
 
-import pandas as pd
 from layerapi.api.value.model_flavor_pb2 import ModelFlavor as PbModelFlavor
 
 from layer.types import ModelObject
 
 from .base import ModelFlavor, ModelRuntimeObjects
 
 
+if TYPE_CHECKING:
+    import pandas as pd
+
+
 class KerasModelFlavor(ModelFlavor):
     """An ML Model flavor implementation which handles persistence of Keras Models."""
 
@@ -81,7 +85,9 @@ def load_model_from_directory(self, directory: Path) -> ModelRuntimeObjects:
         )
 
     @staticmethod
-    def __predict(model: ModelObject, input_df: pd.DataFrame) -> pd.DataFrame:
+    def __predict(model: ModelObject, input_df: "pd.DataFrame") -> "pd.DataFrame":
+        import pandas as pd
+
         # https://www.tensorflow.org/api_docs/python/tf/keras/Model#predict
         predictions = model.predict(input_df)  # type: ignore
         return pd.DataFrame(predictions)