Skip to content

Commit 2c05b46

Browse files
authored
Move payload codec, encode, and decode calls to DataConverter helper methods (#1305)
1 parent 0a3a885 commit 2c05b46

File tree

6 files changed

+206
-155
lines changed

6 files changed

+206
-155
lines changed

temporalio/bridge/worker.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,21 +299,23 @@ async def visit_payloads(self, payloads: MutableSequence[Payload]) -> None:
299299

300300
async def decode_activation(
301301
activation: temporalio.bridge.proto.workflow_activation.WorkflowActivation,
302-
codec: temporalio.converter.PayloadCodec,
302+
data_converter: temporalio.converter.DataConverter,
303303
decode_headers: bool,
304304
) -> None:
305305
"""Decode all payloads in the activation."""
306-
await CommandAwarePayloadVisitor(
307-
skip_search_attributes=True, skip_headers=not decode_headers
308-
).visit(_Visitor(codec.decode), activation)
306+
if data_converter._decode_payload_has_effect:
307+
await CommandAwarePayloadVisitor(
308+
skip_search_attributes=True, skip_headers=not decode_headers
309+
).visit(_Visitor(data_converter._decode_payload_sequence), activation)
309310

310311

311312
async def encode_completion(
312313
completion: temporalio.bridge.proto.workflow_completion.WorkflowActivationCompletion,
313-
codec: temporalio.converter.PayloadCodec,
314+
data_converter: temporalio.converter.DataConverter,
314315
encode_headers: bool,
315316
) -> None:
316317
"""Encode all payloads in the completion."""
317-
await CommandAwarePayloadVisitor(
318-
skip_search_attributes=True, skip_headers=not encode_headers
319-
).visit(_Visitor(codec.encode), completion)
318+
if data_converter._encode_payload_has_effect:
319+
await CommandAwarePayloadVisitor(
320+
skip_search_attributes=True, skip_headers=not encode_headers
321+
).visit(_Visitor(data_converter._encode_payload_sequence), completion)

temporalio/client.py

Lines changed: 24 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -2977,10 +2977,7 @@ async def memo(self) -> Mapping[str, Any]:
29772977
Returns:
29782978
Mapping of all memo keys and they values without type hints.
29792979
"""
2980-
return {
2981-
k: (await self.data_converter.decode([v]))[0]
2982-
for k, v in self.raw_info.memo.fields.items()
2983-
}
2980+
return await self.data_converter._decode_memo(self.raw_info.memo)
29842981

29852982
@overload
29862983
async def memo_value(
@@ -3019,16 +3016,9 @@ async def memo_value(
30193016
Raises:
30203017
KeyError: Key not present and default not set.
30213018
"""
3022-
payload = self.raw_info.memo.fields.get(key)
3023-
if not payload:
3024-
if default is temporalio.common._arg_unset:
3025-
raise KeyError(f"Memo does not have a value for key {key}")
3026-
return default
3027-
return (
3028-
await self.data_converter.decode(
3029-
[payload], [type_hint] if type_hint else None
3030-
)
3031-
)[0]
3019+
return await self.data_converter._decode_memo_field(
3020+
self.raw_info.memo, key, default, type_hint
3021+
)
30323022

30333023

30343024
@dataclass
@@ -4209,18 +4199,9 @@ async def _to_proto(
42094199
workflow_run_timeout=run_timeout,
42104200
workflow_task_timeout=task_timeout,
42114201
retry_policy=retry_policy,
4212-
memo=(
4213-
temporalio.api.common.v1.Memo(
4214-
fields={
4215-
k: v
4216-
if isinstance(v, temporalio.api.common.v1.Payload)
4217-
else (await data_converter.encode([v]))[0]
4218-
for k, v in self.memo.items()
4219-
},
4220-
)
4221-
if self.memo
4222-
else None
4223-
),
4202+
memo=await data_converter._encode_memo(self.memo)
4203+
if self.memo
4204+
else None,
42244205
user_metadata=await _encode_user_metadata(
42254206
data_converter, self.static_summary, self.static_details
42264207
),
@@ -4249,7 +4230,7 @@ async def _to_proto(
42494230
client.config(active_config=True)["header_codec_behavior"]
42504231
== HeaderCodecBehavior.CODEC
42514232
and not self._from_raw,
4252-
client.data_converter.payload_codec,
4233+
client.data_converter,
42534234
)
42544235
return action
42554236

@@ -4521,10 +4502,7 @@ async def memo(self) -> Mapping[str, Any]:
45214502
Returns:
45224503
Mapping of all memo keys and they values without type hints.
45234504
"""
4524-
return {
4525-
k: (await self.data_converter.decode([v]))[0]
4526-
for k, v in self.raw_description.memo.fields.items()
4527-
}
4505+
return await self.data_converter._decode_memo(self.raw_description.memo)
45284506

45294507
@overload
45304508
async def memo_value(
@@ -4563,16 +4541,9 @@ async def memo_value(
45634541
Raises:
45644542
KeyError: Key not present and default not set.
45654543
"""
4566-
payload = self.raw_description.memo.fields.get(key)
4567-
if not payload:
4568-
if default is temporalio.common._arg_unset:
4569-
raise KeyError(f"Memo does not have a value for key {key}")
4570-
return default
4571-
return (
4572-
await self.data_converter.decode(
4573-
[payload], [type_hint] if type_hint else None
4574-
)
4575-
)[0]
4544+
return await self.data_converter._decode_memo_field(
4545+
self.raw_description.memo, key, default, type_hint
4546+
)
45764547

45774548

45784549
@dataclass
@@ -4770,10 +4741,7 @@ async def memo(self) -> Mapping[str, Any]:
47704741
Returns:
47714742
Mapping of all memo keys and they values without type hints.
47724743
"""
4773-
return {
4774-
k: (await self.data_converter.decode([v]))[0]
4775-
for k, v in self.raw_entry.memo.fields.items()
4776-
}
4744+
return await self.data_converter._decode_memo(self.raw_entry.memo)
47774745

47784746
@overload
47794747
async def memo_value(
@@ -4812,16 +4780,9 @@ async def memo_value(
48124780
Raises:
48134781
KeyError: Key not present and default not set.
48144782
"""
4815-
payload = self.raw_entry.memo.fields.get(key)
4816-
if not payload:
4817-
if default is temporalio.common._arg_unset:
4818-
raise KeyError(f"Memo does not have a value for key {key}")
4819-
return default
4820-
return (
4821-
await self.data_converter.decode(
4822-
[payload], [type_hint] if type_hint else None
4823-
)
4824-
)[0]
4783+
return await self.data_converter._decode_memo_field(
4784+
self.raw_entry.memo, key, default, type_hint
4785+
)
48254786

48264787

48274788
@dataclass
@@ -6014,8 +5975,7 @@ async def _populate_start_workflow_execution_request(
60145975
input.retry_policy.apply_to_proto(req.retry_policy)
60155976
req.cron_schedule = input.cron_schedule
60165977
if input.memo is not None:
6017-
for k, v in input.memo.items():
6018-
req.memo.fields[k].CopyFrom((await data_converter.encode([v]))[0])
5978+
await data_converter._encode_memo_existing(input.memo, req.memo)
60195979
if input.search_attributes is not None:
60205980
temporalio.converter.encode_search_attributes(
60215981
input.search_attributes, req.search_attributes
@@ -6641,14 +6601,9 @@ async def create_schedule(self, input: CreateScheduleInput) -> ScheduleHandle:
66416601
initial_patch=initial_patch,
66426602
identity=self._client.identity,
66436603
request_id=str(uuid.uuid4()),
6644-
memo=None
6645-
if not input.memo
6646-
else temporalio.api.common.v1.Memo(
6647-
fields={
6648-
k: (await self._client.data_converter.encode([v]))[0]
6649-
for k, v in input.memo.items()
6650-
},
6651-
),
6604+
memo=await self._client.data_converter._encode_memo(input.memo)
6605+
if input.memo
6606+
else None,
66526607
)
66536608
if input.search_attributes:
66546609
temporalio.converter.encode_search_attributes(
@@ -6870,22 +6825,21 @@ async def _apply_headers(
68706825
dest,
68716826
self._client.config(active_config=True)["header_codec_behavior"]
68726827
== HeaderCodecBehavior.CODEC,
6873-
self._client.data_converter.payload_codec,
6828+
self._client.data_converter,
68746829
)
68756830

68766831

68776832
async def _apply_headers(
68786833
source: Mapping[str, temporalio.api.common.v1.Payload] | None,
68796834
dest: MessageMap[str, temporalio.api.common.v1.Payload],
68806835
encode_headers: bool,
6881-
codec: temporalio.converter.PayloadCodec | None,
6836+
data_converter: DataConverter,
68826837
) -> None:
68836838
if source is None:
68846839
return
6885-
if encode_headers and codec is not None:
6840+
if encode_headers and data_converter._encode_payload_has_effect:
68866841
for payload in source.values():
6887-
new_payload = (await codec.encode([payload]))[0]
6888-
payload.CopyFrom(new_payload)
6842+
payload.CopyFrom(await data_converter._encode_payload(payload))
68896843
temporalio.common._apply_headers(source, dest)
68906844

68916845

0 commit comments

Comments
 (0)