Skip to content
10 changes: 6 additions & 4 deletions aws_lambda_powertools/utilities/idempotency/idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def handle(self) -> Any:
try:
# We call save_inprogress first as an optimization for the most common case where no idempotent record
# already exists. If it succeeds, there's no need to call get_record.
self.persistence_store.save_inprogress(event=self.event)
self.persistence_store.save_inprogress(event=self.event, context=self.context)
except IdempotencyItemAlreadyExistsError:
# Now we know the item already exists, we can retrieve it
record = self._get_idempotency_record()
Expand All @@ -151,7 +151,7 @@ def _get_idempotency_record(self) -> DataRecord:

"""
try:
event_record = self.persistence_store.get_record(self.event)
event_record = self.persistence_store.get_record(event=self.event, context=self.context)
except IdempotencyItemNotFoundError:
# This code path will only be triggered if the record is removed between save_inprogress and get_record.
logger.debug(
Expand Down Expand Up @@ -219,7 +219,9 @@ def _call_lambda_handler(self) -> Any:
# We need these nested blocks to preserve lambda handler exception in case the persistence store operation
# also raises an exception
try:
self.persistence_store.delete_record(event=self.event, exception=handler_exception)
self.persistence_store.delete_record(
event=self.event, context=self.context, exception=handler_exception
)
except Exception as delete_exception:
raise IdempotencyPersistenceLayerError(
"Failed to delete record from idempotency store"
Expand All @@ -228,7 +230,7 @@ def _call_lambda_handler(self) -> Any:

else:
try:
self.persistence_store.save_success(event=self.event, result=handler_response)
self.persistence_store.save_success(event=self.event, context=self.context, result=handler_response)
except Exception as save_exception:
raise IdempotencyPersistenceLayerError(
"Failed to update record state to success in idempotency store"
Expand Down
41 changes: 26 additions & 15 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IdempotencyKeyError,
IdempotencyValidationError,
)
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -152,34 +153,35 @@ def configure(self, config: IdempotencyConfig) -> None:
self._cache = LRUDict(max_items=config.local_cache_max_items)
self.hash_function = getattr(hashlib, config.hash_function)

def _get_hashed_idempotency_key(self, lambda_event: Dict[str, Any]) -> str:
def _get_hashed_idempotency_key(self, event: Dict[str, Any], context: LambdaContext) -> str:
"""
Extract data from lambda event using event key jmespath, and return a hashed representation

Parameters
----------
lambda_event: Dict[str, Any]
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
str
Hashed representation of the data extracted by the jmespath expression

"""
data = lambda_event
data = event

if self.event_key_jmespath:
data = self.event_key_compiled_jmespath.search(
lambda_event, options=jmespath.Options(**self.jmespath_options)
)
data = self.event_key_compiled_jmespath.search(event, options=jmespath.Options(**self.jmespath_options))

if self.is_missing_idempotency_key(data):
if self.raise_on_no_idempotency_key:
raise IdempotencyKeyError("No data found to create a hashed idempotency_key")
warnings.warn(f"No value found for idempotency_key. jmespath: {self.event_key_jmespath}")

return self._generate_hash(data)
generated_hash = self._generate_hash(data)
return f"{context.function_name}#{generated_hash}"

@staticmethod
def is_missing_idempotency_key(data) -> bool:
Expand Down Expand Up @@ -298,21 +300,23 @@ def _delete_from_cache(self, idempotency_key: str):
if idempotency_key in self._cache:
del self._cache[idempotency_key]

def save_success(self, event: Dict[str, Any], result: dict) -> None:
def save_success(self, event: Dict[str, Any], context: LambdaContext, result: dict) -> None:
"""
Save record of function's execution completing successfully

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
result: dict
The response from lambda handler
"""
response_data = json.dumps(result, cls=Encoder)

data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["COMPLETED"],
expiry_timestamp=self._get_expiry_timestamp(),
response_data=response_data,
Expand All @@ -326,17 +330,19 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None:

self._save_to_cache(data_record)

def save_inprogress(self, event: Dict[str, Any]) -> None:
def save_inprogress(self, event: Dict[str, Any], context: LambdaContext) -> None:
"""
Save record of function's execution being in progress

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
"""
data_record = DataRecord(
idempotency_key=self._get_hashed_idempotency_key(event),
idempotency_key=self._get_hashed_idempotency_key(event, context),
status=STATUS_CONSTANTS["INPROGRESS"],
expiry_timestamp=self._get_expiry_timestamp(),
payload_hash=self._get_hashed_payload(event),
Expand All @@ -349,18 +355,20 @@ def save_inprogress(self, event: Dict[str, Any]) -> None:

self._put_record(data_record)

def delete_record(self, event: Dict[str, Any], exception: Exception):
def delete_record(self, event: Dict[str, Any], context: LambdaContext, exception: Exception):
"""
Delete record from the persistence store

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context
exception
The exception raised by the lambda handler
"""
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event))
data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event, context))

logger.debug(
f"Lambda raised an exception ({type(exception).__name__}). Clearing in progress record in persistence "
Expand All @@ -370,14 +378,17 @@ def delete_record(self, event: Dict[str, Any], exception: Exception):

self._delete_from_cache(data_record.idempotency_key)

def get_record(self, event: Dict[str, Any]) -> DataRecord:
def get_record(self, event: Dict[str, Any], context: LambdaContext) -> DataRecord:
"""
Calculate idempotency key for lambda_event, then retrieve item from persistence store using idempotency key
and return it as a DataRecord instance.and return it as a DataRecord instance.

Parameters
----------
event: Dict[str, Any]
Lambda event
context: LambdaContext
Lambda context

Returns
-------
Expand All @@ -392,7 +403,7 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord:
Event payload doesn't match the stored record for the given idempotency key
"""

idempotency_key = self._get_hashed_idempotency_key(event)
idempotency_key = self._get_hashed_idempotency_key(event, context)

cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
Expand Down
19 changes: 16 additions & 3 deletions tests/functional/idempotency/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import hashlib
import json
import os
from collections import namedtuple
from decimal import Decimal
from unittest import mock

Expand Down Expand Up @@ -34,6 +35,18 @@ def lambda_apigw_event():
return event


@pytest.fixture
def lambda_context():
lambda_context = {
"function_name": "test-func",
"memory_limit_in_mb": 128,
"invoked_function_arn": "arn:aws:lambda:eu-west-1:809313241234:function:test-func",
"aws_request_id": "52fdfc07-2182-154f-163f-5f0f9a621d72",
}

return namedtuple("LambdaContext", lambda_context.keys())(*lambda_context.values())


@pytest.fixture
def timestamp_future():
return str(int((datetime.datetime.now() + datetime.timedelta(seconds=3600)).timestamp()))
Expand Down Expand Up @@ -132,18 +145,18 @@ def expected_params_put_item_with_validation(hashed_idempotency_key, hashed_vali


@pytest.fixture
def hashed_idempotency_key(lambda_apigw_event, default_jmespath):
def hashed_idempotency_key(lambda_apigw_event, default_jmespath, lambda_context):
compiled_jmespath = jmespath.compile(default_jmespath)
data = compiled_jmespath.search(lambda_apigw_event)
return hashlib.md5(json.dumps(data).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(data).encode()).hexdigest()


@pytest.fixture
def hashed_idempotency_key_with_envelope(lambda_apigw_event):
event = unwrap_event_from_envelope(
data=lambda_apigw_event, envelope=envelopes.API_GATEWAY_HTTP, jmespath_options={}
)
return hashlib.md5(json.dumps(event).encode()).hexdigest()
return "test-func#" + hashlib.md5(json.dumps(event).encode()).hexdigest()


@pytest.fixture
Expand Down
Loading