Skip to content
Next Next commit
feat: add snapshot expiration methods with retention strategies
  • Loading branch information
ForeverAngry committed Aug 22, 2025
commit 3da1528a41ddd34e99cdf395ade390059db1de32
196 changes: 195 additions & 1 deletion pyiceberg/table/update/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -953,7 +953,7 @@ def _get_protected_snapshot_ids(self) -> Set[int]:
if ref.snapshot_ref_type in [SnapshotRefType.TAG, SnapshotRefType.BRANCH]
}

def by_id(self, snapshot_id: int) -> ExpireSnapshots:
def by_id(self, snapshot_id: int) -> "ExpireSnapshots":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fwiw since we have from __future__ import annotations at the top of the file I think its cleaner to make things consistent to not have quotes. Probably outside of the scope of this PR

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thats a great point. I can log an issue to address these all at one.

"""
Expire a snapshot by its ID.

Expand Down Expand Up @@ -1005,3 +1005,197 @@ def older_than(self, dt: datetime) -> "ExpireSnapshots":
if snapshot.timestamp_ms < expire_from and snapshot.snapshot_id not in protected_ids:
self._snapshot_ids_to_expire.add(snapshot.snapshot_id)
return self

def older_than_with_retention(
self, timestamp_ms: int, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
) -> "ExpireSnapshots":
"""Expire all unprotected snapshots with a timestamp older than a given value, with retention strategies.

Args:
timestamp_ms: Only snapshots with timestamp_ms < this value will be expired.
retain_last_n: Always keep the last N snapshots regardless of age.
min_snapshots_to_keep: Minimum number of snapshots to keep in total.

Returns:
This for method chaining.
"""
snapshots_to_expire = self._get_snapshots_to_expire_with_retention(
timestamp_ms=timestamp_ms, retain_last_n=retain_last_n, min_snapshots_to_keep=min_snapshots_to_keep
)
self._snapshot_ids_to_expire.update(snapshots_to_expire)
return self

def with_retention_policy(
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
) -> "ExpireSnapshots":
"""Comprehensive snapshot expiration with multiple retention strategies.

This method provides a unified interface for snapshot expiration with various
retention policies to ensure operational resilience while allowing space reclamation.

The method will use table properties as defaults if they are set:
- history.expire.max-snapshot-age-ms: Default for timestamp_ms if not provided
- history.expire.min-snapshots-to-keep: Default for min_snapshots_to_keep if not provided
- history.expire.max-ref-age-ms: Used for ref expiration (branches/tags)

Args:
timestamp_ms: Only snapshots with timestamp_ms < this value will be considered for expiration.
If None, will use history.expire.max-snapshot-age-ms table property if set.
retain_last_n: Always keep the last N snapshots regardless of age.
Useful when regular snapshot creation occurs and users want to keep
the last few for rollback purposes.
min_snapshots_to_keep: Minimum number of snapshots to keep in total.
Acts as a guardrail to prevent aggressive expiration logic.
If None, will use history.expire.min-snapshots-to-keep table property if set.

Returns:
This for method chaining.

Raises:
ValueError: If retain_last_n or min_snapshots_to_keep is less than 1.

Examples:
# Use table property defaults
table.expire_snapshots().with_retention_policy().commit()

# Override defaults with explicit values
table.expire_snapshots().with_retention_policy(
timestamp_ms=1234567890000,
retain_last_n=10,
min_snapshots_to_keep=5
).commit()
"""
# Get default values from table properties
default_max_age, default_min_snapshots, _ = self._get_expiration_properties()

# Use defaults from table properties if not explicitly provided
if timestamp_ms is None:
timestamp_ms = default_max_age

if min_snapshots_to_keep is None:
min_snapshots_to_keep = default_min_snapshots

# If no expiration criteria are provided, don't expire anything
if timestamp_ms is None and retain_last_n is None and min_snapshots_to_keep is None:
return self

if retain_last_n is not None and retain_last_n < 1:
raise ValueError("retain_last_n must be at least 1")

if min_snapshots_to_keep is not None and min_snapshots_to_keep < 1:
raise ValueError("min_snapshots_to_keep must be at least 1")

snapshots_to_expire = self._get_snapshots_to_expire_with_retention(
timestamp_ms=timestamp_ms, retain_last_n=retain_last_n, min_snapshots_to_keep=min_snapshots_to_keep
)
self._snapshot_ids_to_expire.update(snapshots_to_expire)
return self

def retain_last_n(self, n: int) -> "ExpireSnapshots":
"""Keep only the last N snapshots, expiring all others.

Args:
n: Number of most recent snapshots to keep.

Returns:
This for method chaining.

Raises:
ValueError: If n is less than 1.
"""
if n < 1:
raise ValueError("Number of snapshots to retain must be at least 1")

protected_ids = self._get_protected_snapshot_ids()

# Sort snapshots by timestamp (most recent first)
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)

# Keep the last N snapshots and all protected ones
snapshots_to_keep = set()
snapshots_to_keep.update(protected_ids)

# Add the N most recent snapshots
for i, snapshot in enumerate(sorted_snapshots):
if i < n:
snapshots_to_keep.add(snapshot.snapshot_id)

# Find snapshots to expire
snapshots_to_expire = []
for snapshot in self._transaction.table_metadata.snapshots:
if snapshot.snapshot_id not in snapshots_to_keep:
snapshots_to_expire.append(snapshot.snapshot_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
protected_ids = self._get_protected_snapshot_ids()
# Sort snapshots by timestamp (most recent first)
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
# Keep the last N snapshots and all protected ones
snapshots_to_keep = set()
snapshots_to_keep.update(protected_ids)
# Add the N most recent snapshots
for i, snapshot in enumerate(sorted_snapshots):
if i < n:
snapshots_to_keep.add(snapshot.snapshot_id)
# Find snapshots to expire
snapshots_to_expire = []
for snapshot in self._transaction.table_metadata.snapshots:
if snapshot.snapshot_id not in snapshots_to_keep:
snapshots_to_expire.append(snapshot.snapshot_id)
snapshots_to_keep = self._get_protected_snapshot_ids()
# Sort snapshots by timestamp (most recent first), and get most recent N
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)
snapshots_to_keep.update(snapshot.snapshot_id for snapshot in sorted_snapshots[:n])
snapshots_to_expire = [id for snapshot in self._transaction.table_metadata.snapshots if (id := snapshot.snapshot_id) not in snapshots_to_keep]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

small syntax change to make more pythonic :)


self._snapshot_ids_to_expire.update(snapshots_to_expire)
return self

def _get_snapshots_to_expire_with_retention(
self, timestamp_ms: Optional[int] = None, retain_last_n: Optional[int] = None, min_snapshots_to_keep: Optional[int] = None
) -> List[int]:
"""Get snapshots to expire considering retention strategies.

Args:
timestamp_ms: Only snapshots with timestamp_ms < this value will be considered for expiration.
retain_last_n: Always keep the last N snapshots regardless of age.
min_snapshots_to_keep: Minimum number of snapshots to keep in total.

Returns:
List of snapshot IDs to expire.
"""
protected_ids = self._get_protected_snapshot_ids()

# Sort snapshots by timestamp (most recent first)
sorted_snapshots = sorted(self._transaction.table_metadata.snapshots, key=lambda s: s.timestamp_ms, reverse=True)

# Start with all snapshots that could be expired
candidates_for_expiration = []
snapshots_to_keep = set(protected_ids)

# Apply retain_last_n constraint
if retain_last_n is not None:
for i, snapshot in enumerate(sorted_snapshots):
if i < retain_last_n:
snapshots_to_keep.add(snapshot.snapshot_id)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code is the same as in retain_last_n, can we refactor to its own function? I think we also need to handle branches and take the last n of each branch


# Apply timestamp constraint
for snapshot in self._transaction.table_metadata.snapshots:
if snapshot.snapshot_id not in snapshots_to_keep and (timestamp_ms is None or snapshot.timestamp_ms < timestamp_ms):
candidates_for_expiration.append(snapshot)
Comment on lines +1156 to +1158
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make more pythonic with comprehension?


# Sort candidates by timestamp (oldest first) for potential expiration
candidates_for_expiration.sort(key=lambda s: s.timestamp_ms)

# Apply min_snapshots_to_keep constraint
total_snapshots = len(self._transaction.table_metadata.snapshots)
snapshots_to_expire: List[int] = []

for candidate in candidates_for_expiration:
# Check if expiring this snapshot would violate min_snapshots_to_keep
remaining_after_expiration = total_snapshots - len(snapshots_to_expire) - 1

if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep:
snapshots_to_expire.append(candidate.snapshot_id)
else:
# Stop expiring to maintain minimum count
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Sort candidates by timestamp (oldest first) for potential expiration
candidates_for_expiration.sort(key=lambda s: s.timestamp_ms)
# Apply min_snapshots_to_keep constraint
total_snapshots = len(self._transaction.table_metadata.snapshots)
snapshots_to_expire: List[int] = []
for candidate in candidates_for_expiration:
# Check if expiring this snapshot would violate min_snapshots_to_keep
remaining_after_expiration = total_snapshots - len(snapshots_to_expire) - 1
if min_snapshots_to_keep is None or remaining_after_expiration >= min_snapshots_to_keep:
snapshots_to_expire.append(candidate.snapshot_id)
else:
# Stop expiring to maintain minimum count
break
# Sort candidates by timestamp (newest first) for potential expiration
candidates_for_expiration.sort(key=lambda s: s.timestamp_ms, reverse=True)
snapshots_to_expire = candidates_for_expiration[min_snapshots_to_keep:]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

double check that I didn't make an off-by-one error here but I believe this is a more concise way to express things :)


return snapshots_to_expire

def _get_expiration_properties(self) -> tuple[Optional[int], Optional[int], Optional[int]]:
"""Get the default expiration properties from table properties.

Returns:
Tuple of (max_snapshot_age_ms, min_snapshots_to_keep, max_ref_age_ms)
"""
properties = self._transaction.table_metadata.properties

max_snapshot_age_ms = properties.get("history.expire.max-snapshot-age-ms")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this string and the default value be a named constant somewhere? What do you think about using property_as_int from properties.py to be consistent with how properties are handled elsewhere?

max_snapshot_age = int(max_snapshot_age_ms) if max_snapshot_age_ms is not None else None

min_snapshots = properties.get("history.expire.min-snapshots-to-keep")
min_snapshots_to_keep = int(min_snapshots) if min_snapshots is not None else None

max_ref_age = properties.get("history.expire.max-ref-age-ms")
max_ref_age_ms = int(max_ref_age) if max_ref_age is not None else None

return max_snapshot_age, min_snapshots_to_keep, max_ref_age_ms
171 changes: 171 additions & 0 deletions tests/table/test_expire_snapshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest

from pyiceberg.table import CommitTableResponse, Table
from pyiceberg.table.update.snapshot import ExpireSnapshots


def test_cannot_expire_protected_head_snapshot(table_v2: Table) -> None:
Expand Down Expand Up @@ -223,3 +224,173 @@ def test_expire_snapshots_by_ids(table_v2: Table) -> None:
assert EXPIRE_SNAPSHOT_1 not in remaining_snapshots
assert EXPIRE_SNAPSHOT_2 not in remaining_snapshots
assert len(table_v2.metadata.snapshots) == 1


def test_retain_last_n_with_protection(table_v2: Table) -> None:
"""Test retain_last_n keeps most recent snapshots plus protected ones."""
from types import SimpleNamespace

# Clear shared state set on the class between tests
ExpireSnapshots._snapshot_ids_to_expire.clear()

S1 = 101 # oldest (also protected)
S2 = 102
S3 = 103
S4 = 104 # newest

# Protected S1 as branch head
table_v2.metadata = table_v2.metadata.model_copy(
update={
"refs": {
"main": MagicMock(snapshot_id=S1, snapshot_ref_type="branch"),
},
"snapshots": [
SimpleNamespace(snapshot_id=S1, timestamp_ms=1, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S2, timestamp_ms=2, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S3, timestamp_ms=3, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S4, timestamp_ms=4, parent_snapshot_id=None),
],
}
)

table_v2.catalog = MagicMock()
kept_ids = {S1, S3, S4} # retain_last_n=2 keeps S4,S3 plus protected S1
mock_response = CommitTableResponse(
metadata=table_v2.metadata.model_copy(update={"snapshots": list(kept_ids)}),
metadata_location="mock://metadata/location",
uuid=uuid4(),
)
table_v2.catalog.commit_table.return_value = mock_response

table_v2.maintenance.expire_snapshots().retain_last_n(2).commit()
table_v2.metadata = mock_response.metadata

args, kwargs = table_v2.catalog.commit_table.call_args
updates = args[2] if len(args) > 2 else ()
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
assert remove_update is not None
# Only S2 should be expired
assert set(remove_update.snapshot_ids) == {S2}
assert S2 not in table_v2.metadata.snapshots


def test_older_than_with_retention_combination(table_v2: Table) -> None:
"""Test older_than_with_retention combining timestamp, retain_last_n and min_snapshots_to_keep."""
from types import SimpleNamespace

ExpireSnapshots._snapshot_ids_to_expire.clear()

# Create 5 snapshots with increasing timestamps
S1, S2, S3, S4, S5 = 201, 202, 203, 204, 205
snapshots = [
SimpleNamespace(snapshot_id=S1, timestamp_ms=100, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S2, timestamp_ms=200, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S3, timestamp_ms=300, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S4, timestamp_ms=400, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S5, timestamp_ms=500, parent_snapshot_id=None),
]
table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {}, "snapshots": snapshots})
table_v2.catalog = MagicMock()

# Expect to expire S1,S2,S3 ; keep S4 (due to min snapshots) and S5 (retain_last_n=1)
mock_response = CommitTableResponse(
metadata=table_v2.metadata.model_copy(update={"snapshots": [S4, S5]}),
metadata_location="mock://metadata/location",
uuid=uuid4(),
)
table_v2.catalog.commit_table.return_value = mock_response

table_v2.maintenance.expire_snapshots().older_than_with_retention(
timestamp_ms=450, retain_last_n=1, min_snapshots_to_keep=2
).commit()
table_v2.metadata = mock_response.metadata

args, kwargs = table_v2.catalog.commit_table.call_args
updates = args[2] if len(args) > 2 else ()
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
assert remove_update is not None
assert set(remove_update.snapshot_ids) == {S1, S2, S3}
assert set(table_v2.metadata.snapshots) == {S4, S5}


def test_with_retention_policy_defaults(table_v2: Table) -> None:
"""Test with_retention_policy uses table property defaults when arguments omitted."""
from types import SimpleNamespace

ExpireSnapshots._snapshot_ids_to_expire.clear()

# Properties: expire snapshots older than 350ms, keep at least 3 snapshots
properties = {
"history.expire.max-snapshot-age-ms": "350",
"history.expire.min-snapshots-to-keep": "3",
}
S1, S2, S3, S4, S5 = 301, 302, 303, 304, 305
snapshots = [
SimpleNamespace(snapshot_id=S1, timestamp_ms=100, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S2, timestamp_ms=200, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S3, timestamp_ms=300, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S4, timestamp_ms=400, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=S5, timestamp_ms=500, parent_snapshot_id=None),
]
table_v2.metadata = table_v2.metadata.model_copy(update={"refs": {}, "snapshots": snapshots, "properties": properties})
table_v2.catalog = MagicMock()

# Expect S1,S2 expired; S3 kept due to min_snapshots_to_keep
mock_response = CommitTableResponse(
metadata=table_v2.metadata.model_copy(update={"snapshots": [S3, S4, S5]}),
metadata_location="mock://metadata/location",
uuid=uuid4(),
)
table_v2.catalog.commit_table.return_value = mock_response

table_v2.maintenance.expire_snapshots().with_retention_policy().commit()
table_v2.metadata = mock_response.metadata

args, kwargs = table_v2.catalog.commit_table.call_args
updates = args[2] if len(args) > 2 else ()
remove_update = next((u for u in updates if getattr(u, "action", None) == "remove-snapshots"), None)
assert remove_update is not None
assert set(remove_update.snapshot_ids) == {S1, S2}
assert set(table_v2.metadata.snapshots) == {S3, S4, S5}


def test_get_expiration_properties(table_v2: Table) -> None:
"""Test retrieval of expiration properties from table metadata."""
ExpireSnapshots._snapshot_ids_to_expire.clear()
properties = {
"history.expire.max-snapshot-age-ms": "60000",
"history.expire.min-snapshots-to-keep": "5",
"history.expire.max-ref-age-ms": "120000",
}
table_v2.metadata = table_v2.metadata.model_copy(update={"properties": properties})
expire = table_v2.maintenance.expire_snapshots()
max_age, min_snaps, max_ref_age = expire._get_expiration_properties()
assert max_age == 60000
assert min_snaps == 5
assert max_ref_age == 120000


def test_get_snapshots_to_expire_with_retention_respects_protection(table_v2: Table) -> None:
"""Internal helper should not select protected snapshots for expiration."""
from types import SimpleNamespace

ExpireSnapshots._snapshot_ids_to_expire.clear()

P = 401 # protected
A = 402
B = 403
table_v2.metadata = table_v2.metadata.model_copy(
update={
"refs": {"main": MagicMock(snapshot_id=P, snapshot_ref_type="branch")},
"snapshots": [
SimpleNamespace(snapshot_id=P, timestamp_ms=10, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=A, timestamp_ms=20, parent_snapshot_id=None),
SimpleNamespace(snapshot_id=B, timestamp_ms=30, parent_snapshot_id=None),
],
}
)
expire = table_v2.maintenance.expire_snapshots()
to_expire = expire._get_snapshots_to_expire_with_retention(timestamp_ms=100, retain_last_n=None, min_snapshots_to_keep=1)
# Protected snapshot P should not be in list; both A and B can expire respecting min keep
assert P not in to_expire
assert set(to_expire) == {A, B}