Skip to content

Commit ef94b82

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Add transformMapEntry macros and a map-merging overload for cel.@mapInsert.
PiperOrigin-RevId: 794350506
1 parent d76b840 commit ef94b82

11 files changed

Lines changed: 614 additions & 116 deletions

eval/compiler/flat_expr_builder.cc

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,11 @@ const cel::Expr* GetOptimizableListAppendOperand(
357357
// Returns whether this comprehension appears to be a macro implementation for
358358
// map transformations. It is not exhaustive, so it is unsafe to use with custom
359359
// comprehensions outside of the standard macros or hand crafted ASTs.
360-
bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension) {
360+
bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension,
361+
bool enable_comprehension_mutable_map) {
362+
if (!enable_comprehension_mutable_map) {
363+
return false;
364+
}
361365
if (comprehension->iter_var().empty() || comprehension->iter_var2().empty()) {
362366
return false;
363367
}
@@ -383,7 +387,7 @@ bool IsOptimizableMapInsert(const cel::ComprehensionExpr* comprehension) {
383387
call_expr = &(call_expr->args()[1].call_expr());
384388
}
385389
return call_expr->function() == "cel.@mapInsert" &&
386-
call_expr->args().size() == 3 &&
390+
(call_expr->args().size() == 2 || call_expr->args().size() == 3) &&
387391
call_expr->args()[0].has_ident_expr() &&
388392
call_expr->args()[0].ident_expr().name() == accu_var;
389393
}
@@ -1407,7 +1411,9 @@ class FlatExprVisitor : public cel::AstVisitor {
14071411
/*.is_optimizable_list_append=*/
14081412
IsOptimizableListAppend(&comprehension,
14091413
options_.enable_comprehension_list_append),
1410-
/*.is_optimizable_map_insert=*/IsOptimizableMapInsert(&comprehension),
1414+
/*.is_optimizable_map_insert=*/
1415+
IsOptimizableMapInsert(&comprehension,
1416+
options_.enable_comprehension_mutable_map),
14111417
/*.is_optimizable_bind=*/is_bind,
14121418
/*.iter_var_in_scope=*/false,
14131419
/*.iter_var2_in_scope=*/false,
@@ -1587,21 +1593,6 @@ class FlatExprVisitor : public cel::AstVisitor {
15871593
return;
15881594
}
15891595

1590-
if (!comprehension_stack_.empty()) {
1591-
const ComprehensionStackRecord& comprehension =
1592-
comprehension_stack_.back();
1593-
if (comprehension.is_optimizable_map_insert) {
1594-
if (&(comprehension.comprehension->accu_init()) == &expr) {
1595-
if (options_.max_recursion_depth != 0) {
1596-
SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1);
1597-
return;
1598-
}
1599-
AddStep(CreateMutableMapStep(expr.id()));
1600-
return;
1601-
}
1602-
}
1603-
}
1604-
16051596
auto status_or_resolved_fields =
16061597
ResolveCreateStructFields(struct_expr, expr.id());
16071598
if (!status_or_resolved_fields.ok()) {
@@ -1639,6 +1630,22 @@ class FlatExprVisitor : public cel::AstVisitor {
16391630
ValidateOrError(entry.has_key(), "Map entry missing key");
16401631
ValidateOrError(entry.has_value(), "Map entry missing value");
16411632
}
1633+
1634+
if (!comprehension_stack_.empty()) {
1635+
const ComprehensionStackRecord& comprehension =
1636+
comprehension_stack_.back();
1637+
if (comprehension.is_optimizable_map_insert) {
1638+
if (&(comprehension.comprehension->accu_init()) == &expr) {
1639+
if (options_.max_recursion_depth != 0) {
1640+
SetRecursiveStep(CreateDirectMutableMapStep(expr.id()), 1);
1641+
return;
1642+
}
1643+
AddStep(CreateMutableMapStep(expr.id()));
1644+
return;
1645+
}
1646+
}
1647+
}
1648+
16421649
auto depth = RecursionEligible();
16431650
if (depth.has_value()) {
16441651
auto deps = ExtractRecursiveDependencies();

eval/public/cel_options.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ cel::RuntimeOptions ConvertToRuntimeOptions(const InterpreterOptions& options) {
2727
options.enable_comprehension,
2828
options.comprehension_max_iterations,
2929
options.enable_comprehension_list_append,
30+
options.enable_comprehension_mutable_map,
3031
options.enable_regex,
3132
options.regex_max_program_size,
3233
options.enable_string_conversion,

eval/public/cel_options.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@
1717
#ifndef THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_
1818
#define THIRD_PARTY_CEL_CPP_EVAL_PUBLIC_CEL_OPTIONS_H_
1919

20-
#include <string>
21-
2220
#include "absl/base/attributes.h"
2321
#include "runtime/runtime_options.h"
2422
#include "google/protobuf/arena.h"
@@ -74,6 +72,10 @@ struct InterpreterOptions {
7472
// with hand-rolled ASTs.
7573
bool enable_comprehension_list_append = false;
7674

75+
// Enable mutable map construction within comprehensions. Note, this option is
76+
// not safe with hand-rolled ASTs.
77+
bool enable_comprehension_mutable_map = false;
78+
7779
// Enable RE2 match() overload.
7880
bool enable_regex = true;
7981

eval/tests/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,17 @@ cc_test(
6161
"//common:memory",
6262
"//common:native_type",
6363
"//common:value",
64+
"//extensions:comprehensions_v2_functions",
65+
"//extensions:comprehensions_v2_macros",
6466
"//extensions/protobuf:runtime_adapter",
6567
"//extensions/protobuf:value",
6668
"//internal:benchmark",
6769
"//internal:testing",
6870
"//internal:testing_descriptor_pool",
6971
"//internal:testing_message_factory",
7072
"//parser",
73+
"//parser:macro",
74+
"//parser:macro_registry",
7175
"//runtime",
7276
"//runtime:activation",
7377
"//runtime:constant_folding",

eval/tests/modern_benchmark_test.cc

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,15 @@
4141
#include "common/native_type.h"
4242
#include "common/value.h"
4343
#include "eval/tests/request_context.pb.h"
44+
#include "extensions/comprehensions_v2_functions.h"
45+
#include "extensions/comprehensions_v2_macros.h"
4446
#include "extensions/protobuf/runtime_adapter.h"
4547
#include "extensions/protobuf/value.h"
4648
#include "internal/benchmark.h"
4749
#include "internal/testing.h"
4850
#include "internal/testing_descriptor_pool.h"
4951
#include "internal/testing_message_factory.h"
52+
#include "parser/macro_registry.h"
5053
#include "parser/parser.h"
5154
#include "runtime/activation.h"
5255
#include "runtime/constant_folding.h"
@@ -70,6 +73,7 @@ using ::cel::extensions::ProtobufRuntimeAdapter;
7073
using ::cel::expr::Expr;
7174
using ::cel::expr::ParsedExpr;
7275
using ::cel::expr::SourceInfo;
76+
using ::google::api::expr::parser::EnrichedParse;
7377
using ::google::api::expr::parser::Parse;
7478
using ::google::api::expr::runtime::RequestContext;
7579
using ::google::rpc::context::AttributeContext;
@@ -1270,6 +1274,62 @@ void BM_ComprehensionCpp(benchmark::State& state) {
12701274

12711275
BENCHMARK(BM_ComprehensionCpp)->Range(1, 1 << 20);
12721276

1277+
void BM_MapTransformComprehension(benchmark::State& state) {
1278+
ASSERT_OK_AND_ASSIGN(auto source,
1279+
NewSource("map_var.transformMapEntry(k, v, {v:k})"));
1280+
1281+
MacroRegistry registry;
1282+
ASSERT_THAT(
1283+
extensions::RegisterComprehensionsV2Macros(registry, ParserOptions()),
1284+
IsOk());
1285+
1286+
ASSERT_OK_AND_ASSIGN(auto parsed_expr,
1287+
EnrichedParse(*source, registry, ParserOptions()));
1288+
1289+
RuntimeOptions options = GetOptions();
1290+
options.comprehension_max_iterations = 10000000;
1291+
1292+
// This is a critical optimization: it allows the comprehension to accumulate
1293+
// results in a mutable map instead of cloning and augmenting an unmodifiable
1294+
// map on every iteration.
1295+
options.enable_comprehension_mutable_map = true;
1296+
1297+
ASSERT_OK_AND_ASSIGN(auto builder,
1298+
CreateStandardRuntimeBuilder(
1299+
internal::GetTestingDescriptorPool(), options));
1300+
1301+
ASSERT_THAT(extensions::RegisterComprehensionsV2Functions(
1302+
builder.function_registry(), options),
1303+
IsOk());
1304+
1305+
ASSERT_OK_AND_ASSIGN(auto runtime, std::move(builder).Build());
1306+
1307+
google::protobuf::Arena arena;
1308+
Activation activation;
1309+
1310+
auto map_builder = cel::NewMapValueBuilder(&arena);
1311+
1312+
int len = state.range(0);
1313+
map_builder->Reserve(len);
1314+
for (int i = 0; i < len; i++) {
1315+
ASSERT_THAT(map_builder->Put(IntValue(i), IntValue(i)), IsOk());
1316+
}
1317+
1318+
activation.InsertOrAssignValue("map_var", std::move(*map_builder).Build());
1319+
1320+
ASSERT_OK_AND_ASSIGN(auto cel_expr, ProtobufRuntimeAdapter::CreateProgram(
1321+
*runtime, parsed_expr.parsed_expr()));
1322+
1323+
for (auto _ : state) {
1324+
ASSERT_OK_AND_ASSIGN(cel::Value result,
1325+
cel_expr->Evaluate(&arena, activation));
1326+
ASSERT_TRUE(InstanceOf<MapValue>(result));
1327+
ASSERT_THAT(Cast<MapValue>(result).Size(), IsOkAndHolds(len));
1328+
}
1329+
}
1330+
1331+
BENCHMARK(BM_MapTransformComprehension)->Range(1, 1 << 16);
1332+
12731333
} // namespace
12741334

12751335
} // namespace cel

extensions/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,7 @@ cc_test(
581581
":comprehensions_v2_macros",
582582
":strings",
583583
"//common:source",
584+
"//common:value",
584585
"//common:value_testing",
585586
"//extensions/protobuf:runtime_adapter",
586587
"//internal:status_macros",
@@ -596,6 +597,7 @@ cc_test(
596597
"//runtime:reference_resolver",
597598
"//runtime:runtime_options",
598599
"//runtime:standard_runtime_builder_factory",
600+
"@com_google_absl//absl/status",
599601
"@com_google_absl//absl/status:status_matchers",
600602
"@com_google_absl//absl/status:statusor",
601603
"@com_google_absl//absl/strings:string_view",

extensions/comprehensions_v2_functions.cc

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ namespace cel::extensions {
3535

3636
namespace {
3737

38-
absl::StatusOr<Value> MapInsert(
38+
absl::StatusOr<Value> MapInsertKeyValue(
3939
const MapValue& map, const Value& key, const Value& value,
4040
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
4141
google::protobuf::MessageFactory* absl_nonnull message_factory,
@@ -68,6 +68,54 @@ absl::StatusOr<Value> MapInsert(
6868
return std::move(*builder).Build();
6969
}
7070

71+
absl::StatusOr<Value> MapInsertMap(
72+
const MapValue& map, const MapValue& value,
73+
const google::protobuf::DescriptorPool* absl_nonnull descriptor_pool,
74+
google::protobuf::MessageFactory* absl_nonnull message_factory,
75+
google::protobuf::Arena* absl_nonnull arena) {
76+
if (auto mutable_map_value = common_internal::AsMutableMapValue(map);
77+
mutable_map_value) {
78+
// Fast path, runtime has given us a mutable map. We can mutate it directly
79+
// and return it.
80+
CEL_RETURN_IF_ERROR(
81+
value.ForEach(
82+
[&mutable_map_value](const Value& key,
83+
const Value& value) -> absl::StatusOr<bool> {
84+
CEL_RETURN_IF_ERROR(mutable_map_value->Put(key, value));
85+
return true;
86+
},
87+
descriptor_pool, message_factory, arena))
88+
.With(ErrorValueReturn());
89+
return map;
90+
}
91+
// Slow path, we have to make a copy.
92+
auto builder = NewMapValueBuilder(arena);
93+
if (auto size = map.Size(); size.ok()) {
94+
builder->Reserve(*size + 1);
95+
} else {
96+
size.IgnoreError();
97+
}
98+
CEL_RETURN_IF_ERROR(
99+
map.ForEach(
100+
[&builder](const Value& key,
101+
const Value& value) -> absl::StatusOr<bool> {
102+
CEL_RETURN_IF_ERROR(builder->Put(key, value));
103+
return true;
104+
},
105+
descriptor_pool, message_factory, arena))
106+
.With(ErrorValueReturn());
107+
CEL_RETURN_IF_ERROR(
108+
value.ForEach(
109+
[&builder](const Value& key,
110+
const Value& value) -> absl::StatusOr<bool> {
111+
CEL_RETURN_IF_ERROR(builder->Put(key, value));
112+
return true;
113+
},
114+
descriptor_pool, message_factory, arena))
115+
.With(ErrorValueReturn());
116+
return std::move(*builder).Build();
117+
}
118+
71119
} // namespace
72120

73121
absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry,
@@ -77,7 +125,15 @@ absl::Status RegisterComprehensionsV2Functions(FunctionRegistry& registry,
77125
Value>::CreateDescriptor("cel.@mapInsert",
78126
/*receiver_style=*/false),
79127
TernaryFunctionAdapter<absl::StatusOr<Value>, MapValue, Value,
80-
Value>::WrapFunction(&MapInsert)));
128+
Value>::WrapFunction(&MapInsertKeyValue)));
129+
130+
CEL_RETURN_IF_ERROR(registry.Register(
131+
BinaryFunctionAdapter<absl::StatusOr<Value>, MapValue, MapValue>::
132+
CreateDescriptor("cel.@mapInsert",
133+
/*receiver_style=*/false),
134+
BinaryFunctionAdapter<absl::StatusOr<Value>, MapValue,
135+
MapValue>::WrapFunction(&MapInsertMap)));
136+
81137
return absl::OkStatus();
82138
}
83139

0 commit comments

Comments
 (0)