diff --git a/docs/api/sagemaker_mlops.rst b/docs/api/sagemaker_mlops.rst
index f67879111d..d9f911068e 100644
--- a/docs/api/sagemaker_mlops.rst
+++ b/docs/api/sagemaker_mlops.rst
@@ -21,6 +21,14 @@ Workflow Management
    :undoc-members:
    :show-inheritance:
 
+Feature Store
+-------------
+
+.. automodule:: sagemaker.mlops.feature_store
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
 Local Development
 -----------------
 
diff --git a/sagemaker-mlops/pyproject.toml b/sagemaker-mlops/pyproject.toml
index 16b60746be..ffc66b473e 100644
--- a/sagemaker-mlops/pyproject.toml
+++ b/sagemaker-mlops/pyproject.toml
@@ -28,6 +28,9 @@ dependencies = [
     "sagemaker-serve>=1.5.0",
     "boto3>=1.42.2,<2.0",
     "botocore>=1.42.2,<2.0",
+    "pyiceberg[glue]>=0.8.0",
+    "s3fs",
+    "tenacity>=8.0.0",
 ]
 
 [project.optional-dependencies]
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py
index f15d6d3845..ad75442185 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/__init__.py
@@ -2,8 +2,9 @@
 # Licensed under the Apache License, Version 2.0
 """SageMaker FeatureStore V3 - powered by sagemaker-core."""
 
-# Resources from core
+# FeatureGroup with additional operational support
 from sagemaker.core.resources import FeatureGroup, FeatureMetadata
+from sagemaker.mlops.feature_store.feature_group_manager import FeatureGroupManager
 
 # Shapes from core (Pydantic - no to_dict() needed)
 from sagemaker.core.shapes import (
@@ -73,6 +74,7 @@ __all__ = [
     # Resources
     "FeatureGroup",
+    "FeatureGroupManager",
     "FeatureMetadata",
     # Shapes
     "DataCatalogConfig",
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group_manager.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group_manager.py
new file mode 100644
index 0000000000..371edcc1b3
--- /dev/null
+++ 
b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_group_manager.py @@ -0,0 +1,447 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 +"""FeatureGroup with Lake Formation support.""" + +import json +import logging +from typing import Dict, List, Optional + +from pydantic import model_validator + +from sagemaker.core.resources import FeatureGroup +from sagemaker.core.resources import Base +from sagemaker.core.shapes import ( + AddOnlineStoreReplicaAction, + FeatureDefinition, + OfflineStoreConfig, + OnlineStoreConfig, + OnlineStoreConfigUpdate, + Tag, + ThroughputConfig, + ThroughputConfigUpdate, +) +from sagemaker.core.shapes import Unassigned +from sagemaker.core.helper.pipeline_variable import StrPipeVar +from sagemaker.core.s3.utils import parse_s3_url +from sagemaker.core.common_utils import aws_partition +from boto3 import Session +from pyiceberg.catalog import load_catalog +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type +from sagemaker.mlops.feature_store.feature_utils import _APPROVED_ICEBERG_PROPERTIES + + +logger = logging.getLogger(__name__) + + +class IcebergProperties(Base): + """Configuration for Iceberg table properties in a Feature Group offline store. + + Use this to customize Iceberg table behavior such as compaction settings, + snapshot retention, and other Iceberg-specific configurations. + + Attributes: + properties: A dictionary mapping Iceberg property names to their values. 
+ Common properties include: + - 'write.target-file-size-bytes': Target size for data files + - 'commit.manifest.min-count-to-merge': Min manifests before merging + - 'history.expire.max-snapshot-age-ms': Max age for snapshot expiration + """ + + properties: Optional[Dict[str, str]] = None + + @model_validator(mode="after") + def validate_property_keys(self): + if self.properties is None: + return self + invalid_keys = set(self.properties.keys()) - _APPROVED_ICEBERG_PROPERTIES + if invalid_keys: + raise ValueError( + f"Invalid iceberg properties: {invalid_keys}. " + f"Approved properties are: {_APPROVED_ICEBERG_PROPERTIES}" + ) + # Check for no duplicate keys + if len(set(self.properties.keys())) != len(self.properties.keys()): + raise ValueError( + f"Invalid duplicate properties. Please only have 1 of each property." + ) + return self + + +class FeatureGroupManager(FeatureGroup): + + # Attribute for Iceberg table properties (populated by get() when include_iceberg_properties=True) + iceberg_properties: Optional[IcebergProperties] = None + + # Inherit parent docstring and append our additions + if FeatureGroup.__doc__ and __doc__: + __doc__ = FeatureGroup.__doc__ + + def _validate_table_ownership(self, table, database_name: str, table_name: str): + """Validate that the Iceberg table belongs to this feature group by checking S3 location.""" + table_location = table.metadata.location if table.metadata else None + s3_config = self.offline_store_config.s3_storage_config + if s3_config and s3_config.s3_uri: + expected_prefix = str(s3_config.s3_uri).rstrip("/") + if table_location and not table_location.startswith(expected_prefix): + logger.error( + f"Table ownership validation failed for feature group " + f"'{self.feature_group_name}'. The Glue table " + f"'{database_name}.{table_name}' has location '{table_location}' " + f"but the feature group's offline store is configured with " + f"S3 URI '{expected_prefix}'. 
This may indicate that the " + f"data_catalog_config is pointing to a table that does not belong " + f"to this feature group. To fix this, verify that the " + f"data_catalog_config.database and data_catalog_config.table_name " + f"in your feature group's offline_store_config match the correct " + f"Glue table for this feature group." + ) + raise ValueError( + f"Table '{database_name}.{table_name}' location '{table_location}' " + f"does not match the feature group's S3 URI '{expected_prefix}'. " + f"The table may not belong to feature group '{self.feature_group_name}'." + ) + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_exception_type(RuntimeError), + reraise=True, + ) + def _get_iceberg_properties( + self, + session: Optional[Session] = None, + region: Optional[StrPipeVar] = None, + ) -> Dict[str, any]: + """ + Fetch the current Iceberg catalog table definition for the Feature Group's Iceberg offline store. + + Validates that the Feature Group has an Iceberg-formatted offline store + and retrieves the table via the Iceberg catalog. + + Parameters: + session: Optional boto3 session. If not provided, uses default credentials. + region: Optional AWS region. If not provided, uses default region. + + Returns: + Dict with keys: + - 'database_name': The Iceberg catalog database name + - 'table_name': The Iceberg catalog table name + - 'table': The pyiceberg Table object + - 'properties': The table properties dict + + Raises: + ValueError: If offline_store_config is not configured or table_format is not Iceberg. + RuntimeError: If the Iceberg catalog table load fails. 
+ """ + # Validate offline store is configured + if self.offline_store_config is None or self.offline_store_config == Unassigned(): + raise ValueError( + "Cannot update Iceberg properties: offline_store_config is not configured" + ) + + # Validate table format is Iceberg + if ( + self.offline_store_config.table_format is None + or str(self.offline_store_config.table_format) != "Iceberg" + ): + raise ValueError( + "Cannot update Iceberg properties: table_format must be 'Iceberg'" + ) + + # Get database and table name from data_catalog_config + data_catalog_config = self.offline_store_config.data_catalog_config + if data_catalog_config is None: + raise ValueError( + "Cannot update Iceberg properties: data_catalog_config is not available" + ) + + database_name = str(data_catalog_config.database) + table_name = str(data_catalog_config.table_name) + + if session is None: + session = Session() + region_str = str(region) if region else session.region_name + catalog = load_catalog("glue", **{"type": "glue", "client.region": region_str}) + + try: + table = catalog.load_table(f"{database_name}.{table_name}") + self._validate_table_ownership(table, database_name, table_name) + + return { + "database_name": database_name, + "table_name": table_name, + "table": table, + "properties": dict(table.properties), + } + + except Exception as e: + raise RuntimeError( + f"Failed to get Iceberg properties for {self.feature_group_name}: {e}" + ) from e + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=1, max=10), + retry=retry_if_exception_type(RuntimeError), + reraise=True, + ) + def _update_iceberg_properties( + self, + iceberg_properties: IcebergProperties, + session: Optional[Session] = None, + region: Optional[StrPipeVar] = None, + ) -> Dict[str, any]: + """ + Update Iceberg table properties for the Feature Group's offline store. + + This method updates the Glue table properties for an Iceberg-formatted + offline store. 
The Feature Group must have an offline store configured + with table_format='Iceberg'. + + Parameters: + iceberg_properties: IcebergProperties object containing the properties to set. + session: Optional boto3 session. If not provided, uses default credentials. + region: Optional AWS region. If not provided, uses default region. + + Returns: + Dict containing the update results with keys: + - 'database': The Glue database name + - 'table': The Glue table name + - 'properties_updated': The properties that were updated + + Raises: + ValueError: If offline_store_config is not configured or table_format is not Iceberg. + RuntimeError: If the Glue table update fails. + """ + # Validate iceberg_properties has properties to update + if iceberg_properties is None or not iceberg_properties.properties: + raise ValueError( + "iceberg_properties must contain at least one property to update" + ) + + invalid_keys = set(iceberg_properties.properties.keys()) - _APPROVED_ICEBERG_PROPERTIES + if invalid_keys: + raise ValueError( + f"Invalid iceberg properties: {invalid_keys}. " + f"Approved properties are: {_APPROVED_ICEBERG_PROPERTIES}" + ) + + # Check for no duplicate keys + if len(set(iceberg_properties.properties.keys())) != len(iceberg_properties.properties.keys()): + raise ValueError( + f"Invalid duplicate properties. Please only have 1 of each property." 
+ ) + + result = self._get_iceberg_properties(session=session, region=region) + database_name = result["database_name"] + table_name = result["table_name"] + table = result["table"] + current_properties = result["properties"] + + self._validate_table_ownership(table, database_name, table_name) + + # Compute before/after diff for audit logging + changed = {} + for key, new_value in iceberg_properties.properties.items(): + old_value = current_properties.get(key) + if old_value != new_value: + changed[key] = {"old": old_value, "new": new_value} + + logger.info( + f"Updating Iceberg properties for feature group '{self.feature_group_name}' " + f"(database={database_name}, table={table_name}). " + f"Property changes: {changed}" + ) + + try: + with table.transaction() as txn: + txn.set_properties(iceberg_properties.properties) + + logger.info( + f"Successfully updated Iceberg properties for feature group " + f"'{self.feature_group_name}'. Properties applied: {changed}" + ) + + return { + "database": database_name, + "table": table_name, + "properties_updated": iceberg_properties.properties, + } + + except Exception as e: + logger.error( + f"Failed to update Iceberg properties for feature group " + f"'{self.feature_group_name}'. Attempted changes: {changed}. Error: {e}" + ) + raise RuntimeError( + f"Failed to update Iceberg properties for {self.feature_group_name}: {e}" + ) from e + + @classmethod + def get( + cls, + *args, + include_iceberg_properties: bool = False, + **kwargs, + ) -> Optional["FeatureGroup"]: + """ + Get a FeatureGroup resource with optional Iceberg property retrieval. + + Accepts all parameters from FeatureGroup.get(), plus: + + Parameters: + include_iceberg_properties: If True, fetches Iceberg table properties + from Glue and stores them in the iceberg_properties attribute. + Only applies to Feature Groups with table_format='Iceberg'. + + Returns: + The FeatureGroup resource. 
+ """ + session = kwargs.get("session") + region = kwargs.get("region") + + feature_group = super().get(*args, **kwargs) + + if include_iceberg_properties: + result = feature_group._get_iceberg_properties(session=session, region=region) + all_properties = result["properties"] + approved_properties = { + k: v for k, v in all_properties.items() + if k in _APPROVED_ICEBERG_PROPERTIES + } + feature_group.iceberg_properties = IcebergProperties( + properties=approved_properties + ) + + return feature_group + + @classmethod + def create( + cls, + *args, + iceberg_properties: Optional[IcebergProperties] = None, + **kwargs, + ) -> Optional["FeatureGroup"]: + """ + Create a FeatureGroup resource with optional Lake Formation governance and Iceberg properties. + + Accepts all parameters from FeatureGroup.create(), plus: + + Parameters: + lake_formation_config: Optional LakeFormationConfig to configure Lake Formation + governance. When enabled=True, requires offline_store_config and role_arn. + iceberg_properties: Optional IcebergProperties to configure Iceberg table + properties for the offline store. Requires offline_store_config with + table_format='Iceberg'. + + Returns: + The FeatureGroup resource. 
+ """ + offline_store_config = kwargs.get("offline_store_config") + role_arn = kwargs.get("role_arn") + session = kwargs.get("session") + region = kwargs.get("region") + + # Validation for Iceberg properties + if iceberg_properties is not None and iceberg_properties.properties: + if offline_store_config is None: + raise ValueError( + "iceberg_properties requires offline_store_config to be configured" + ) + if ( + offline_store_config.table_format is None + or str(offline_store_config.table_format) != "Iceberg" + ): + raise ValueError( + "iceberg_properties requires offline_store_config.table_format to be 'Iceberg'" + ) + + feature_group = super().create(*args, **kwargs) + + # Update Iceberg properties if requested + if iceberg_properties is not None and iceberg_properties.properties: + # Wait for feature group to be created before updating Iceberg properties + feature_group.wait_for_status(target_status="Created") + try: + feature_group._update_iceberg_properties( + iceberg_properties=iceberg_properties, + session=session, + region=region, + ) + except Exception as e: + logger.error( + f"Feature group '{feature_group.feature_group_name}' was created " + f"successfully but failed to update Iceberg properties: {e}." + f"Please now run update on the created Feature Group with the" + f"Iceberg Properties to avoid recreating your Feature Group again." + ) + raise + + return feature_group + + def update( + self, + *args, + iceberg_properties: Optional[IcebergProperties] = None, + session: Optional[Session] = None, + region: Optional[StrPipeVar] = None, + **kwargs, + ) -> Optional["FeatureGroup"]: + """ + Update a FeatureGroup resource with optional Iceberg property updates. + + Accepts all parameters from FeatureGroup.update(), plus: + + Parameters: + iceberg_properties: Optional IcebergProperties to update Iceberg table + properties for the offline store. Requires offline_store_config with + table_format='Iceberg'. + session: Boto3 session for Iceberg property updates. 
+            region: Region name for Iceberg property updates.
+
+        Returns:
+            The FeatureGroup resource.
+        """
+
+        offline_store_config = self.offline_store_config
+
+        # Validation for Iceberg properties
+        if iceberg_properties is not None and iceberg_properties.properties:
+            if offline_store_config is None or offline_store_config == Unassigned():
+                raise ValueError(
+                    "iceberg_properties requires offline_store_config to be configured"
+                )
+            if (
+                offline_store_config.table_format is None
+                or str(offline_store_config.table_format) != "Iceberg"
+            ):
+                raise ValueError(
+                    "iceberg_properties requires offline_store_config.table_format to be 'Iceberg'"
+                )
+
+        # Only call parent update if there are non-iceberg args to pass
+        result = None
+        if args or kwargs:
+            try:
+                result = super().update(*args, **kwargs)
+            except Exception as e:
+                logger.error(
+                    f"Feature group '{self.feature_group_name}' was not updated successfully: {e}")
+                raise
+
+        # Update Iceberg properties if requested
+        if iceberg_properties is not None and iceberg_properties.properties:
+            try:
+                self._update_iceberg_properties(
+                    iceberg_properties=iceberg_properties,
+                    session=session,
+                    region=region,
+                )
+            except Exception as e:
+                logger.error(
+                    f"Failed to update Iceberg properties for feature group "
+                    f"'{self.feature_group_name}': {e}"
+                )
+                raise
+
+        return result if result is not None else self
diff --git a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py
index 3e3e7813df..a809b0a8f3 100644
--- a/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py
+++ b/sagemaker-mlops/src/sagemaker/mlops/feature_store/feature_utils.py
@@ -49,6 +49,26 @@ _INTEGER_TYPES = {"int_", "int8", "int16", "int32", "int64", "uint8", "uint16",
 "uint32", "uint64"}
 _FLOAT_TYPES = {"float_", "float16", "float32", "float64"}
 
+_APPROVED_ICEBERG_PROPERTIES = {
+    "write.metadata.delete-after-commit.enabled",
+    
"write.metadata.previous-versions-max", + "history.expire.max-snapshot-age-ms", + "history.expire.min-snapshots-to-keep", + "write.parquet.row-group-size-bytes", + "read.split.target-size", + "read.split.metadata-target-size", + "write.delete.target-file-size-bytes", + "write.delete.mode", + "write.update.mode", + "write.delete.granularity", + "write.delete.isolation-level", + "write.update.isolation-level", + "write.merge.isolation-level", + "history.expire.max-ref-age-ms", + "read.split.open-file-cost", + "write.target-file-size-bytes" +} + def _get_athena_client(session: Session): """Get Athena client from session.""" @@ -731,4 +751,4 @@ def _cast_object_to_string(data_frame: pandas.DataFrame) -> pandas.DataFrame: """ for label in data_frame.select_dtypes(["object", "O"]).columns.tolist(): data_frame[label] = data_frame[label].astype("str").astype("string") - return data_frame \ No newline at end of file + return data_frame diff --git a/sagemaker-mlops/tests/integ/test_feature_store_iceberg_properties.py b/sagemaker-mlops/tests/integ/test_feature_store_iceberg_properties.py new file mode 100644 index 0000000000..452c5347f1 --- /dev/null +++ b/sagemaker-mlops/tests/integ/test_feature_store_iceberg_properties.py @@ -0,0 +1,331 @@ +"""Integration tests for FeatureGroupManager iceberg property handling.""" +import time + +import boto3 + +import pandas as pd +import pytest + +from sagemaker.core.helper.session_helper import Session, get_execution_role +from sagemaker.core.utils import unique_name_from_base +from sagemaker.mlops.feature_store import ( + OfflineStoreConfig, + S3StorageConfig, +) +from sagemaker.mlops.feature_store import FeatureGroupManager +from sagemaker.mlops.feature_store.feature_group_manager import IcebergProperties +from sagemaker.mlops.feature_store.feature_utils import ( + load_feature_definitions_from_dataframe, +) + + +@pytest.fixture(scope="module") +def sagemaker_session(): + return Session() + + +@pytest.fixture(scope="module") +def 
role(): + return get_execution_role() + + +@pytest.fixture(scope="module") +def region(): + return boto3.Session().region_name + + +@pytest.fixture(scope="module") +def bucket(sagemaker_session): + return sagemaker_session.default_bucket() + + +@pytest.fixture +def feature_group_name(): + return unique_name_from_base("integ-test-iceberg-fg") + + +@pytest.fixture +def sample_dataframe(): + from datetime import datetime, timezone, timedelta + base_time = datetime.now(timezone.utc) + return pd.DataFrame({ + "record_id": [f"id-{i}" for i in range(10)], + "feature_1": [i * 1.5 for i in range(10)], + "feature_2": [i * 2 for i in range(10)], + "event_time": [(base_time + timedelta(seconds=i)).strftime("%Y-%m-%dT%H:%M:%SZ") for i in range(10)], + }) + + +def cleanup_feature_group(feature_group_name): + try: + fg = FeatureGroupManager.get(feature_group_name=feature_group_name) + fg.delete() + time.sleep(2) + except Exception: + pass + + +def test_create_with_iceberg_properties( + feature_group_name, sample_dataframe, bucket, role +): + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + iceberg_props = IcebergProperties(properties={ + "write.metadata.delete-after-commit.enabled": "true", + "write.metadata.previous-versions-max": "5", + }) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + iceberg_properties=iceberg_props, + ) + + fg.wait_for_status("Created") + + retrieved = FeatureGroupManager.get( + feature_group_name=feature_group_name, + include_iceberg_properties=True, + ) + assert retrieved.iceberg_properties is not None + assert 
retrieved.iceberg_properties.properties["write.metadata.delete-after-commit.enabled"] == "true" + assert retrieved.iceberg_properties.properties["write.metadata.previous-versions-max"] == "5" + finally: + cleanup_feature_group(feature_group_name) + + +def test_update_iceberg_properties( + feature_group_name, sample_dataframe, bucket, role +): + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + ) + + fg.wait_for_status("Created") + + fg.update(iceberg_properties=IcebergProperties(properties={ + "write.metadata.delete-after-commit.enabled": "true", + "write.metadata.previous-versions-max": "5", + })) + + retrieved = FeatureGroupManager.get( + feature_group_name=feature_group_name, + include_iceberg_properties=True, + ) + assert retrieved.iceberg_properties.properties["write.metadata.delete-after-commit.enabled"] == "true" + assert retrieved.iceberg_properties.properties["write.metadata.previous-versions-max"] == "5" + finally: + cleanup_feature_group(feature_group_name) + + +def test_get_with_include_iceberg_properties( + feature_group_name, sample_dataframe, bucket, role +): + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + 
iceberg_properties=IcebergProperties(properties={ + "write.metadata.delete-after-commit.enabled": "true", + }), + ) + + fg.wait_for_status("Created") + + retrieved = FeatureGroupManager.get( + feature_group_name=feature_group_name, + include_iceberg_properties=True, + ) + assert retrieved.iceberg_properties is not None + assert isinstance(retrieved.iceberg_properties.properties, dict) + assert retrieved.iceberg_properties.properties["write.metadata.delete-after-commit.enabled"] == "true" + finally: + cleanup_feature_group(feature_group_name) + + +def test_create_with_iceberg_properties_none( + feature_group_name, sample_dataframe, bucket, role +): + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + iceberg_properties=None, + ) + + fg.wait_for_status("Created") + + assert fg.iceberg_properties is None + finally: + cleanup_feature_group(feature_group_name) + + +def test_update_only_iceberg_properties_skips_parent_update( + feature_group_name, sample_dataframe, bucket, role +): + """Test that update with only iceberg_properties skips the SageMaker UpdateFeatureGroup call.""" + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + ) + + 
fg.wait_for_status("Created") + + # Update with ONLY iceberg properties — no description or other parent args + fg.update(iceberg_properties=IcebergProperties(properties={ + "write.metadata.delete-after-commit.enabled": "true", + })) + + retrieved = FeatureGroupManager.get( + feature_group_name=feature_group_name, + include_iceberg_properties=True, + ) + assert retrieved.iceberg_properties.properties["write.metadata.delete-after-commit.enabled"] == "true" + finally: + cleanup_feature_group(feature_group_name) + + +def test_get_without_include_flag_has_no_iceberg_properties( + feature_group_name, sample_dataframe, bucket, role +): + """Test that get without include_iceberg_properties leaves iceberg_properties as None.""" + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + iceberg_properties=IcebergProperties(properties={ + "write.metadata.delete-after-commit.enabled": "true", + }), + ) + + fg.wait_for_status("Created") + + retrieved = FeatureGroupManager.get(feature_group_name=feature_group_name) + assert retrieved.iceberg_properties is None + finally: + cleanup_feature_group(feature_group_name) + + +def test_update_iceberg_properties_overwrites_previous_values( + feature_group_name, sample_dataframe, bucket, role +): + """Test that updating an iceberg property overwrites its previous value.""" + try: + feature_definitions = load_feature_definitions_from_dataframe(sample_dataframe) + + fg = FeatureGroupManager.create( + feature_group_name=feature_group_name, + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + 
feature_definitions=feature_definitions, + role_arn=role, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri=f"s3://{bucket}/feature-store"), + table_format="Iceberg", + ), + iceberg_properties=IcebergProperties(properties={ + "write.metadata.previous-versions-max": "5", + }), + ) + + fg.wait_for_status("Created") + + # Overwrite with a new value + fg.update(iceberg_properties=IcebergProperties(properties={ + "write.metadata.previous-versions-max": "10", + })) + + retrieved = FeatureGroupManager.get( + feature_group_name=feature_group_name, + include_iceberg_properties=True, + ) + assert retrieved.iceberg_properties.properties["write.metadata.previous-versions-max"] == "10" + finally: + cleanup_feature_group(feature_group_name) + + +def test_create_iceberg_properties_without_offline_store_raises(): + with pytest.raises(ValueError, match="iceberg_properties requires offline_store_config"): + FeatureGroupManager.create( + feature_group_name="dummy-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[], + role_arn="arn:aws:iam::000000000000:role/dummy", + iceberg_properties=IcebergProperties(properties={ + "write.target-file-size-bytes": "536870912", + }), + ) + + +def test_create_iceberg_properties_with_non_iceberg_table_format_raises(): + with pytest.raises(ValueError, match="table_format to be 'Iceberg'"): + FeatureGroupManager.create( + feature_group_name="dummy-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[], + role_arn="arn:aws:iam::000000000000:role/dummy", + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/prefix"), + table_format="Glue", + ), + iceberg_properties=IcebergProperties(properties={ + "write.target-file-size-bytes": "536870912", + }), + ) diff --git a/sagemaker-mlops/tests/unit/sagemaker/mlops/feature_store/test_iceberg_properties.py 
"""Unit tests for Iceberg properties in FeatureGroupManager."""
from unittest.mock import MagicMock, patch

import pytest

from boto3 import Session
from sagemaker.mlops.feature_store import FeatureGroupManager
from sagemaker.mlops.feature_store.feature_group_manager import IcebergProperties


class TestIcebergPropertiesConfig:
    """Validation behavior of the IcebergProperties config object."""

    def test_properties_defaults_to_none(self):
        """A bare IcebergProperties carries no properties."""
        cfg = IcebergProperties()
        assert cfg.properties is None

    def test_properties_can_be_set(self):
        """A dict of approved properties round-trips unchanged."""
        given = {"write.target-file-size-bytes": "536870912"}
        cfg = IcebergProperties(properties=given)
        assert cfg.properties == given

    def test_valid_approved_keys_accepted(self):
        """All approved keys are accepted together."""
        given = {
            "write.target-file-size-bytes": "536870912",
            "write.metadata.delete-after-commit.enabled": "true",
            "history.expire.max-snapshot-age-ms": "432000000",
        }
        cfg = IcebergProperties(properties=given)
        assert cfg.properties == given

    def test_single_invalid_key_raises_error(self):
        """A single unapproved key is rejected."""
        with pytest.raises(ValueError, match="Invalid iceberg properties"):
            IcebergProperties(properties={"not.a.valid.key": "value"})

    def test_multiple_invalid_keys_raises_error(self):
        """Several unapproved keys are rejected."""
        with pytest.raises(ValueError, match="Invalid iceberg properties"):
            IcebergProperties(properties={"bad.key.one": "1", "bad.key.two": "2"})

    def test_mix_valid_and_invalid_keys_raises_error(self):
        """An unapproved key is rejected even alongside approved ones."""
        mixed = {
            "write.target-file-size-bytes": "536870912",
            "invalid.key": "value",
        }
        with pytest.raises(ValueError, match="Invalid iceberg properties"):
            IcebergProperties(properties=mixed)

    def test_error_message_contains_invalid_key_names(self):
        """The rejection message names the offending key."""
        with pytest.raises(ValueError, match="fake.property"):
            IcebergProperties(properties={"fake.property": "value"})

    def test_duplicate_keys_raises_error(self):
        """Duplicate property keys fail explicit re-validation."""
        cfg = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"})
        dup_props = MagicMock()
        dup_props.keys.return_value = [
            "write.target-file-size-bytes",
            "write.target-file-size-bytes",
        ]
        # Bypass pydantic validation to inject the duplicate-key mock.
        object.__setattr__(cfg, "properties", dup_props)
        with pytest.raises(ValueError, match="Invalid duplicate properties"):
            cfg.validate_property_keys()

    def test_no_duplicate_keys_passes(self):
        """Unique approved keys pass duplicate validation and return self."""
        cfg = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"})
        assert cfg.validate_property_keys() is cfg


class TestValidateTableOwnership:
    """Behavior of FeatureGroupManager._validate_table_ownership."""

    def setup_method(self):
        """Bind the real method onto a spec'd mock with an Iceberg offline store."""
        from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig

        self.fg = MagicMock(spec=FeatureGroupManager)
        self.fg._validate_table_ownership = (
            FeatureGroupManager._validate_table_ownership.__get__(self.fg)
        )
        self.fg.feature_group_name = "test-fg"
        self.fg.offline_store_config = OfflineStoreConfig(
            s3_storage_config=S3StorageConfig(s3_uri="s3://my-bucket/feature-store"),
            data_catalog_config=DataCatalogConfig(
                catalog="AwsDataCatalog", database="test_db", table_name="test_table"
            ),
            table_format="Iceberg",
        )

    @staticmethod
    def _table_at(location):
        """Return a mock Iceberg table whose metadata reports *location*."""
        tbl = MagicMock()
        tbl.metadata.location = location
        return tbl

    def test_passes_when_location_matches(self):
        """No error when the table location lives under the FG's S3 URI."""
        tbl = self._table_at("s3://my-bucket/feature-store/test_db/test_table")
        self.fg._validate_table_ownership(tbl, "test_db", "test_table")

    def test_raises_when_location_does_not_match(self):
        """ValueError when the table location is outside the FG's S3 URI."""
        tbl = self._table_at("s3://other-bucket/other-path/table")
        with pytest.raises(ValueError, match="does not match the feature group's S3 URI"):
            self.fg._validate_table_ownership(tbl, "test_db", "test_table")

    def test_passes_when_no_metadata(self):
        """No error when the table carries no metadata at all."""
        tbl = MagicMock()
        tbl.metadata = None
        self.fg._validate_table_ownership(tbl, "test_db", "test_table")

    def test_passes_when_no_s3_config(self):
        """No error when s3_storage_config is absent."""
        object.__setattr__(self.fg.offline_store_config, "s3_storage_config", None)
        tbl = self._table_at("s3://other-bucket/path")
        self.fg._validate_table_ownership(tbl, "test_db", "test_table")

    def test_passes_when_s3_uri_is_none(self):
        """No error when the configured s3_uri is None."""
        object.__setattr__(self.fg.offline_store_config.s3_storage_config, "s3_uri", None)
        tbl = self._table_at("s3://other-bucket/path")
        self.fg._validate_table_ownership(tbl, "test_db", "test_table")

    def test_error_message_contains_details(self):
        """The error names the table, both locations, and the feature group."""
        tbl = self._table_at("s3://wrong-bucket/wrong-path")
        with pytest.raises(ValueError, match="test_db.test_table") as exc_info:
            self.fg._validate_table_ownership(tbl, "test_db", "test_table")
        message = str(exc_info.value)
        assert "s3://wrong-bucket/wrong-path" in message
        assert "s3://my-bucket/feature-store" in message
        assert "test-fg" in message
class TestGetIcebergProperties:
    """Behavior of FeatureGroupManager._get_iceberg_properties."""

    def setup_method(self):
        """Bind the real method onto a spec'd mock with an Iceberg offline store."""
        from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig

        self.fg = MagicMock(spec=FeatureGroupManager)
        self.fg._get_iceberg_properties = (
            FeatureGroupManager._get_iceberg_properties.__get__(self.fg)
        )
        self.fg.feature_group_name = "test-fg"
        self.fg.offline_store_config = OfflineStoreConfig(
            s3_storage_config=S3StorageConfig(s3_uri="s3://test-bucket/path"),
            data_catalog_config=DataCatalogConfig(
                catalog="AwsDataCatalog", database="test_db", table_name="test_table"
            ),
            table_format="Iceberg",
        )

    @staticmethod
    def _wire_catalog(mock_load_catalog, properties):
        """Install a fake catalog whose load_table yields a table with *properties*."""
        table = MagicMock()
        table.properties = properties
        catalog = MagicMock()
        catalog.load_table.return_value = table
        mock_load_catalog.return_value = catalog
        return catalog, table

    def test_raises_error_when_no_offline_store_config(self):
        """ValueError when offline_store_config is None."""
        self.fg.offline_store_config = None
        with pytest.raises(ValueError, match="offline_store_config is not configured"):
            self.fg._get_iceberg_properties()

    def test_raises_error_when_offline_store_config_is_unassigned(self):
        """ValueError when offline_store_config is Unassigned()."""
        from sagemaker.core.shapes import Unassigned

        self.fg.offline_store_config = Unassigned()
        with pytest.raises(ValueError, match="offline_store_config is not configured"):
            self.fg._get_iceberg_properties()

    def test_raises_error_when_table_format_not_iceberg(self):
        """ValueError when table_format is not Iceberg."""
        self.fg.offline_store_config.table_format = None
        with pytest.raises(ValueError, match="table_format must be 'Iceberg'"):
            self.fg._get_iceberg_properties()

    def test_raises_error_when_no_data_catalog_config(self):
        """ValueError when data_catalog_config is missing."""
        self.fg.offline_store_config.data_catalog_config = None
        with pytest.raises(ValueError, match="data_catalog_config is not available"):
            self.fg._get_iceberg_properties()

    @patch("sagemaker.mlops.feature_store.feature_group_manager.load_catalog")
    def test_successful_load_table(self, mock_load_catalog):
        """A successful load returns database, table name, table, and properties."""
        catalog, table = self._wire_catalog(mock_load_catalog, {"table_type": "ICEBERG"})
        session = MagicMock()
        session.region_name = "us-west-2"

        result = self.fg._get_iceberg_properties(session=session)

        assert result["database_name"] == "test_db"
        assert result["table_name"] == "test_table"
        assert result["table"] == table
        assert result["properties"] == {"table_type": "ICEBERG"}
        catalog.load_table.assert_called_once_with("test_db.test_table")

    @patch("sagemaker.mlops.feature_store.feature_group_manager.load_catalog")
    def test_uses_provided_session_and_region(self, mock_load_catalog):
        """An explicit region argument wins over the session's region."""
        self._wire_catalog(mock_load_catalog, {})
        self.fg._get_iceberg_properties(session=MagicMock(), region="eu-west-1")
        mock_load_catalog.assert_called_once_with(
            "glue", **{"type": "glue", "client.region": "eu-west-1"}
        )

    @patch("sagemaker.mlops.feature_store.feature_group_manager.load_catalog")
    def test_uses_session_region_when_region_not_provided(self, mock_load_catalog):
        """session.region_name is the fallback when region is omitted."""
        self._wire_catalog(mock_load_catalog, {})
        session = MagicMock()
        session.region_name = "ap-southeast-1"
        self.fg._get_iceberg_properties(session=session)
        mock_load_catalog.assert_called_once_with(
            "glue", **{"type": "glue", "client.region": "ap-southeast-1"}
        )

    @patch("sagemaker.mlops.feature_store.feature_group_manager.load_catalog")
    def test_raises_runtime_error_on_client_error(self, mock_load_catalog):
        """Errors raised by pyiceberg are wrapped in RuntimeError."""
        catalog = MagicMock()
        catalog.load_table.side_effect = Exception("Table not found")
        mock_load_catalog.return_value = catalog
        session = MagicMock()
        session.region_name = "us-west-2"
        with pytest.raises(RuntimeError, match="Failed to get Iceberg properties"):
            self.fg._get_iceberg_properties(session=session)


class TestUpdateIcebergProperties:
    """Behavior of FeatureGroupManager._update_iceberg_properties."""

    def setup_method(self):
        """Bind the real update method onto a spec'd mock instance."""
        self.fg = MagicMock(spec=FeatureGroupManager)
        self.fg._update_iceberg_properties = (
            FeatureGroupManager._update_iceberg_properties.__get__(self.fg)
        )
        self.fg.feature_group_name = "test-fg"

    def _stub_table(self, existing_properties):
        """Stub _get_iceberg_properties with a table whose transaction yields txn."""
        table = MagicMock()
        txn = MagicMock()
        table.transaction.return_value.__enter__ = MagicMock(return_value=txn)
        table.transaction.return_value.__exit__ = MagicMock(return_value=False)
        self.fg._get_iceberg_properties.return_value = {
            "database_name": "test_db",
            "table_name": "test_table",
            "table": table,
            "properties": existing_properties,
        }
        return table, txn

    def test_raises_error_when_iceberg_properties_is_none(self):
        """ValueError when iceberg_properties is None."""
        with pytest.raises(ValueError, match="must contain at least one property"):
            self.fg._update_iceberg_properties(iceberg_properties=None)

    def test_raises_error_when_properties_dict_is_empty(self):
        """ValueError when the properties dict is empty."""
        empty = IcebergProperties(properties={})
        with pytest.raises(ValueError, match="must contain at least one property"):
            self.fg._update_iceberg_properties(iceberg_properties=empty)

    def test_raises_error_when_properties_is_none_on_object(self):
        """ValueError when IcebergProperties.properties is None."""
        with pytest.raises(ValueError, match="must contain at least one property"):
            self.fg._update_iceberg_properties(iceberg_properties=IcebergProperties())

    def test_successful_update_merges_properties(self):
        """A successful update sets the new properties in one transaction."""
        _, txn = self._stub_table({"table_type": "ICEBERG", "existing_key": "existing_value"})
        props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"})

        result = self.fg._update_iceberg_properties(iceberg_properties=props)

        txn.set_properties.assert_called_once_with(props.properties)
        assert result["database"] == "test_db"
        assert result["table"] == "test_table"
        assert result["properties_updated"] == props.properties

    def test_update_with_no_existing_properties(self):
        """The update works when the table starts with no properties at all."""
        _, txn = self._stub_table({})
        props = IcebergProperties(properties={"write.target-file-size-bytes": "value"})
        self.fg._update_iceberg_properties(iceberg_properties=props)
        txn.set_properties.assert_called_once_with(props.properties)

    def test_raises_runtime_error_on_update_table_client_error(self):
        """Errors from the pyiceberg transaction are wrapped in RuntimeError."""
        table = MagicMock()
        table.transaction().__enter__().set_properties.side_effect = Exception("Access denied")
        self.fg._get_iceberg_properties.return_value = {
            "database_name": "test_db",
            "table_name": "test_table",
            "table": table,
            "properties": {},
        }
        props = IcebergProperties(properties={"write.target-file-size-bytes": "value"})
        with pytest.raises(RuntimeError, match="Failed to update Iceberg properties"):
            self.fg._update_iceberg_properties(iceberg_properties=props)

    def test_raises_error_on_duplicate_keys(self):
        """ValueError when the supplied properties contain duplicate keys."""
        props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"})
        dup = MagicMock()
        dup.keys.return_value = [
            "write.target-file-size-bytes",
            "write.target-file-size-bytes",
        ]
        dup.__bool__ = lambda self: True
        # Bypass pydantic validation to inject the duplicate-key mock.
        object.__setattr__(props, "properties", dup)
        with pytest.raises(ValueError, match="Invalid duplicate properties"):
            self.fg._update_iceberg_properties(iceberg_properties=props)

    def test_logs_before_after_property_changes(self, caplog):
        """The update logs an old/new diff of changed properties at INFO level."""
        import logging

        self._stub_table({"write.target-file-size-bytes": "268435456"})
        props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"})

        logger_name = "sagemaker.mlops.feature_store.feature_group_manager"
        with caplog.at_level(logging.INFO, logger=logger_name):
            self.fg._update_iceberg_properties(iceberg_properties=props)

        assert "test-fg" in caplog.text
        assert "'old': '268435456'" in caplog.text
        assert "'new': '536870912'" in caplog.text
        assert "Successfully updated" in caplog.text
mock_get_client): + """Test no iceberg operations when iceberg_properties is None.""" + from sagemaker.core.shapes import FeatureDefinition + + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + mock_fg = MagicMock(spec=FeatureGroupManager) + mock_get.return_value = mock_fg + + feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + ) + + mock_fg.wait_for_status.assert_not_called() + mock_fg._update_iceberg_properties.assert_not_called() + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroupManager, "get") + def test_no_iceberg_operations_when_properties_empty(self, mock_get, mock_get_client): + """Test no iceberg operations when iceberg_properties.properties is empty.""" + from sagemaker.core.shapes import FeatureDefinition + + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + mock_fg = MagicMock(spec=FeatureGroupManager) + mock_get.return_value = mock_fg + + feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + iceberg_properties=IcebergProperties(), + ) + + 
mock_fg._update_iceberg_properties.assert_not_called() + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_without_offline_store_config(self, mock_get_client): + """Test ValueError when iceberg_properties provided without offline_store_config.""" + from sagemaker.core.shapes import FeatureDefinition + + mock_get_client.return_value = MagicMock() + + feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + with pytest.raises(ValueError, match="iceberg_properties requires offline_store_config"): + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "value"}), + ) + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_table_format_not_iceberg(self, mock_get_client): + """Test ValueError when table_format is not Iceberg.""" + from sagemaker.core.shapes import FeatureDefinition, OfflineStoreConfig, S3StorageConfig + + mock_get_client.return_value = MagicMock() + + feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + with pytest.raises(ValueError, match="table_format to be 'Iceberg'"): + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + ), + iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "value"}), + ) + + 
@patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_table_format_is_glue(self, mock_get_client): + """Test ValueError when table_format is explicitly Glue.""" + from sagemaker.core.shapes import FeatureDefinition, OfflineStoreConfig, S3StorageConfig + + mock_get_client.return_value = MagicMock() + + feature_definitions = [ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + with pytest.raises(ValueError, match="table_format to be 'Iceberg'"): + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + table_format="Glue", + ), + iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "value"}), + ) + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroupManager, "get") + @patch.object(FeatureGroupManager, "wait_for_status") + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + def test_update_called_after_create_with_properties( + self, mock_update, mock_wait, mock_get, mock_get_client + ): + """Test update_iceberg_properties called after create when properties provided.""" + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + mock_fg = MagicMock(spec=FeatureGroupManager) + mock_fg.wait_for_status = mock_wait + mock_fg._update_iceberg_properties = mock_update + mock_get.return_value = mock_fg + + feature_definitions = [ + 
FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ] + + iceberg_props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"}) + + result = FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=feature_definitions, + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="test_db", table_name="test_table" + ), + table_format="Iceberg", + ), + iceberg_properties=iceberg_props, + ) + + # Verify wait_for_status called before update + mock_wait.assert_called_once_with(target_status="Created") + mock_update.assert_called_once_with( + iceberg_properties=iceberg_props, + session=None, + region=None, + ) + assert result == mock_fg + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroupManager, "get") + @patch.object(FeatureGroupManager, "wait_for_status") + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + def test_create_passes_session_and_region_to_update( + self, mock_update, mock_wait, mock_get, mock_get_client + ): + """Test that session and region are forwarded to _update_iceberg_properties.""" + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + mock_fg = MagicMock(spec=FeatureGroupManager) + mock_fg.wait_for_status = mock_wait + mock_fg._update_iceberg_properties = mock_update + mock_get.return_value = mock_fg + + mock_session = MagicMock(spec=Session) + iceberg_props = 
IcebergProperties(properties={"write.target-file-size-bytes": "val"}) + + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ], + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ), + iceberg_properties=iceberg_props, + session=mock_session, + region="eu-west-1", + ) + + mock_update.assert_called_once_with( + iceberg_properties=iceberg_props, + session=mock_session, + region="eu-west-1", + ) + + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + @patch.object(FeatureGroupManager, "get") + @patch.object(FeatureGroupManager, "wait_for_status") + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + def test_create_logs_and_reraises_when_iceberg_update_fails( + self, mock_update, mock_wait, mock_get, mock_get_client + ): + """Test that create logs error and re-raises when iceberg update fails after FG creation.""" + from sagemaker.core.shapes import ( + FeatureDefinition, + OfflineStoreConfig, + S3StorageConfig, + DataCatalogConfig, + ) + + mock_client = MagicMock() + mock_client.create_feature_group.return_value = { + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test" + } + mock_get_client.return_value = mock_client + + mock_fg = MagicMock(spec=FeatureGroupManager) + mock_fg.feature_group_name = "test-fg" + mock_fg.wait_for_status = mock_wait + mock_fg._update_iceberg_properties = mock_update + mock_get.return_value = mock_fg + + mock_update.side_effect = RuntimeError("Iceberg catalog error") + + iceberg_props = 
IcebergProperties(properties={"write.target-file-size-bytes": "536870912"}) + + with pytest.raises(RuntimeError, match="Iceberg catalog error"): + FeatureGroupManager.create( + feature_group_name="test-fg", + record_identifier_feature_name="record_id", + event_time_feature_name="event_time", + feature_definitions=[ + FeatureDefinition(feature_name="record_id", feature_type="String"), + FeatureDefinition(feature_name="event_time", feature_type="String"), + ], + offline_store_config=OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ), + iceberg_properties=iceberg_props, + ) + + # FG was created (super().create called) and waited for status + mock_client.create_feature_group.assert_called_once() + mock_wait.assert_called_once_with(target_status="Created") + + +class TestUpdateWithIcebergProperties: + """Tests for update() method with iceberg_properties parameter.""" + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_no_iceberg_operations_when_none(self, mock_get_client, mock_refresh, mock_update_iceberg): + """Test no iceberg operations when iceberg_properties is None.""" + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.update(description="new description") + + mock_update_iceberg.assert_not_called() + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_no_iceberg_operations_when_properties_empty(self, mock_get_client, mock_refresh, mock_update_iceberg): + """Test no iceberg operations when 
iceberg_properties.properties is None.""" + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.update(iceberg_properties=IcebergProperties()) + + mock_update_iceberg.assert_not_called() + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_iceberg_update_called_with_properties(self, mock_get_client, mock_refresh, mock_update_iceberg): + """Test _update_iceberg_properties called when properties provided.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ) + iceberg_props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"}) + fg.update(description="new description", iceberg_properties=iceberg_props, session=None, region=None) + + mock_update_iceberg.assert_called_once_with( + iceberg_properties=iceberg_props, + session=None, + region=None, + ) + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_skips_parent_update_when_only_iceberg_properties(self, mock_get_client, mock_update_iceberg): + """Test that super().update() is not called when only iceberg_properties are passed.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + mock_client = MagicMock() + 
mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ) + iceberg_props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"}) + result = fg.update(iceberg_properties=iceberg_props, session=None, region=None) + + mock_client.update_feature_group.assert_not_called() + mock_update_iceberg.assert_called_once() + assert result == fg + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_parent_update_receives_only_standard_params(self, mock_get_client, mock_refresh, mock_update_iceberg): + """Test that iceberg_properties is not passed to the parent update().""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ) + fg.update( + description="new desc", + iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "val"}), + ) + + # Verify the SageMaker API call does NOT contain iceberg_properties + call_args = mock_client.update_feature_group.call_args + assert "IcebergProperties" not in call_args[1] + assert "Description" in call_args[1] + + @patch.object(FeatureGroupManager, "_update_iceberg_properties") + 
@patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_update_logs_and_reraises_when_iceberg_update_fails(self, mock_get_client, mock_refresh, mock_update_iceberg): + """Test that update logs error and re-raises when iceberg update fails after FG update.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig, DataCatalogConfig + + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + mock_update_iceberg.side_effect = RuntimeError("Iceberg catalog error") + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + data_catalog_config=DataCatalogConfig( + catalog="AwsDataCatalog", database="db", table_name="tbl" + ), + table_format="Iceberg", + ) + iceberg_props = IcebergProperties(properties={"write.target-file-size-bytes": "536870912"}) + + with pytest.raises(RuntimeError, match="Iceberg catalog error"): + fg.update(description="new desc", iceberg_properties=iceberg_props, session=None, region=None) + + # Parent update was called successfully before iceberg update failed + mock_client.update_feature_group.assert_called_once() + + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_no_offline_store(self, mock_get_client, mock_refresh): + """Test ValueError when iceberg_properties provided without offline_store_config.""" + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = None + + with pytest.raises(ValueError, match="iceberg_properties requires offline_store_config"): + 
fg.update(iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "val"})) + + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_offline_store_is_unassigned(self, mock_get_client, mock_refresh): + """Test ValueError when offline_store_config is Unassigned().""" + from sagemaker.core.shapes import Unassigned + + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + object.__setattr__(fg, "offline_store_config", Unassigned()) + + with pytest.raises(ValueError, match="iceberg_properties requires offline_store_config"): + fg.update(iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "val"})) + + @patch.object(FeatureGroupManager, "refresh") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_validation_error_when_table_format_not_iceberg(self, mock_get_client, mock_refresh): + """Test ValueError when table_format is not Iceberg.""" + from sagemaker.core.shapes import OfflineStoreConfig, S3StorageConfig + + mock_client = MagicMock() + mock_client.update_feature_group.return_value = {} + mock_get_client.return_value = mock_client + + fg = FeatureGroupManager(feature_group_name="test-fg") + fg.offline_store_config = OfflineStoreConfig( + s3_storage_config=S3StorageConfig(s3_uri="s3://bucket/path"), + table_format="Glue", + ) + + with pytest.raises(ValueError, match="table_format to be 'Iceberg'"): + fg.update(iceberg_properties=IcebergProperties(properties={"write.target-file-size-bytes": "val"})) + + +class TestGetWithIcebergProperties: + """Tests for get() method with include_iceberg_properties flag.""" + + @patch.object(FeatureGroupManager, "_get_iceberg_properties") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_no_iceberg_fetch_by_default(self, 
mock_get_client, mock_get_iceberg): + """Test that Iceberg properties are not fetched when flag is False (default).""" + from sagemaker.core.shapes import FeatureDefinition + + mock_client = MagicMock() + mock_client.describe_feature_group.return_value = { + "FeatureGroupName": "test-fg", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg", + "RecordIdentifierFeatureName": "record_id", + "EventTimeFeatureName": "event_time", + "FeatureDefinitions": [ + {"FeatureName": "record_id", "FeatureType": "String"}, + ], + "CreationTime": "2024-01-01T00:00:00Z", + } + mock_get_client.return_value = mock_client + + result = FeatureGroupManager.get(feature_group_name="test-fg") + + mock_get_iceberg.assert_not_called() + + @patch.object(FeatureGroupManager, "_get_iceberg_properties") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_iceberg_properties_fetched_when_flag_true(self, mock_get_client, mock_get_iceberg): + """Test that Iceberg properties are fetched and stored when flag is True.""" + mock_client = MagicMock() + mock_client.describe_feature_group.return_value = { + "FeatureGroupName": "test-fg", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg", + "RecordIdentifierFeatureName": "record_id", + "EventTimeFeatureName": "event_time", + "FeatureDefinitions": [ + {"FeatureName": "record_id", "FeatureType": "String"}, + ], + "CreationTime": "2024-01-01T00:00:00Z", + } + mock_get_client.return_value = mock_client + + mock_get_iceberg.return_value = { + "database_name": "test_db", + "table_name": "test_table", + "table": MagicMock(), + "properties": { + "write.target-file-size-bytes": "536870912", + }, + } + + result = FeatureGroupManager.get( + feature_group_name="test-fg", + include_iceberg_properties=True, + ) + + mock_get_iceberg.assert_called_once_with(session=None, region=None) + assert result.iceberg_properties.properties == { + "write.target-file-size-bytes": "536870912", + } 
+ + @patch.object(FeatureGroupManager, "_get_iceberg_properties") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_iceberg_properties_empty_parameters(self, mock_get_client, mock_get_iceberg): + """Test that empty Parameters dict results in empty properties.""" + mock_client = MagicMock() + mock_client.describe_feature_group.return_value = { + "FeatureGroupName": "test-fg", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg", + "RecordIdentifierFeatureName": "record_id", + "EventTimeFeatureName": "event_time", + "FeatureDefinitions": [ + {"FeatureName": "record_id", "FeatureType": "String"}, + ], + "CreationTime": "2024-01-01T00:00:00Z", + } + mock_get_client.return_value = mock_client + + mock_get_iceberg.return_value = { + "database_name": "test_db", + "table_name": "test_table", + "table": MagicMock(), + "properties": {}, + } + + result = FeatureGroupManager.get( + feature_group_name="test-fg", + include_iceberg_properties=True, + ) + + assert result.iceberg_properties.properties == {} + + @patch.object(FeatureGroupManager, "_get_iceberg_properties") + @patch("sagemaker.core.resources.Base.get_sagemaker_client") + def test_passes_session_and_region_to_get_iceberg_properties(self, mock_get_client, mock_get_iceberg): + """Test that session and region kwargs are forwarded to _get_iceberg_properties.""" + mock_client = MagicMock() + mock_client.describe_feature_group.return_value = { + "FeatureGroupName": "test-fg", + "FeatureGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:feature-group/test-fg", + "RecordIdentifierFeatureName": "record_id", + "EventTimeFeatureName": "event_time", + "FeatureDefinitions": [ + {"FeatureName": "record_id", "FeatureType": "String"}, + ], + "CreationTime": "2024-01-01T00:00:00Z", + } + mock_get_client.return_value = mock_client + + mock_session = MagicMock(spec=Session) + mock_get_iceberg.return_value = { + "database_name": "test_db", + "table_name": "test_table", + "table": 
MagicMock(), + "properties": {"write.target-file-size-bytes": "val"}, + } + + FeatureGroupManager.get( + feature_group_name="test-fg", + include_iceberg_properties=True, + session=mock_session, + region="us-east-1", + ) + + mock_get_iceberg.assert_called_once_with(session=mock_session, region="us-east-1")