Skip to content

Commit d0131ea

Browse files
jnthntatumcopybara-github
authored andcommitted
Move arithmetic operators to runtime/standard.
PiperOrigin-RevId: 549769613
1 parent dfc8d87 commit d0131ea

6 files changed

Lines changed: 400 additions & 212 deletions

File tree

eval/public/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,7 @@ cc_library(
306306
"//internal:utf8",
307307
"//runtime:function_registry",
308308
"//runtime:runtime_options",
309+
"//runtime/standard:arithmetic_functions",
309310
"//runtime/standard:comparison_functions",
310311
"//runtime/standard:container_functions",
311312
"//runtime/standard:logical_functions",

eval/public/builtin_func_registrar.cc

Lines changed: 3 additions & 212 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#include <array>
1818
#include <cstdint>
1919
#include <functional>
20-
#include <limits>
2120
#include <string>
2221

2322
#include "absl/status/status.h"
@@ -49,6 +48,7 @@
4948
#include "internal/utf8.h"
5049
#include "runtime/function_registry.h"
5150
#include "runtime/runtime_options.h"
51+
#include "runtime/standard/arithmetic_functions.h"
5252
#include "runtime/standard/comparison_functions.h"
5353
#include "runtime/standard/container_functions.h"
5454
#include "runtime/standard/logical_functions.h"
@@ -68,216 +68,6 @@ using ::cel::Value;
6868
using ::cel::ValueFactory;
6969
using ::google::protobuf::Arena;
7070

71-
// Template functions providing arithmetic operations
72-
template <class Type>
73-
Handle<Value> Add(ValueFactory&, Type v0, Type v1);
74-
75-
template <>
76-
Handle<Value> Add<int64_t>(ValueFactory& value_factory, int64_t v0,
77-
int64_t v1) {
78-
auto sum = cel::internal::CheckedAdd(v0, v1);
79-
if (!sum.ok()) {
80-
return value_factory.CreateErrorValue(sum.status());
81-
}
82-
return value_factory.CreateIntValue(*sum);
83-
}
84-
85-
template <>
86-
Handle<Value> Add<uint64_t>(ValueFactory& value_factory, uint64_t v0,
87-
uint64_t v1) {
88-
auto sum = cel::internal::CheckedAdd(v0, v1);
89-
if (!sum.ok()) {
90-
return value_factory.CreateErrorValue(sum.status());
91-
}
92-
return value_factory.CreateUintValue(*sum);
93-
}
94-
95-
template <>
96-
Handle<Value> Add<double>(ValueFactory& value_factory, double v0, double v1) {
97-
return value_factory.CreateDoubleValue(v0 + v1);
98-
}
99-
100-
template <class Type>
101-
Handle<Value> Sub(ValueFactory&, Type v0, Type v1);
102-
103-
template <>
104-
Handle<Value> Sub<int64_t>(ValueFactory& value_factory, int64_t v0,
105-
int64_t v1) {
106-
auto diff = cel::internal::CheckedSub(v0, v1);
107-
if (!diff.ok()) {
108-
return value_factory.CreateErrorValue(diff.status());
109-
}
110-
return value_factory.CreateIntValue(*diff);
111-
}
112-
113-
template <>
114-
Handle<Value> Sub<uint64_t>(ValueFactory& value_factory, uint64_t v0,
115-
uint64_t v1) {
116-
auto diff = cel::internal::CheckedSub(v0, v1);
117-
if (!diff.ok()) {
118-
return value_factory.CreateErrorValue(diff.status());
119-
}
120-
return value_factory.CreateUintValue(*diff);
121-
}
122-
123-
template <>
124-
Handle<Value> Sub<double>(ValueFactory& value_factory, double v0, double v1) {
125-
return value_factory.CreateDoubleValue(v0 - v1);
126-
}
127-
128-
template <class Type>
129-
Handle<Value> Mul(ValueFactory&, Type v0, Type v1);
130-
131-
template <>
132-
Handle<Value> Mul<int64_t>(ValueFactory& value_factory, int64_t v0,
133-
int64_t v1) {
134-
auto prod = cel::internal::CheckedMul(v0, v1);
135-
if (!prod.ok()) {
136-
return value_factory.CreateErrorValue(prod.status());
137-
}
138-
return value_factory.CreateIntValue(*prod);
139-
}
140-
141-
template <>
142-
Handle<Value> Mul<uint64_t>(ValueFactory& value_factory, uint64_t v0,
143-
uint64_t v1) {
144-
auto prod = cel::internal::CheckedMul(v0, v1);
145-
if (!prod.ok()) {
146-
return value_factory.CreateErrorValue(prod.status());
147-
}
148-
return value_factory.CreateUintValue(*prod);
149-
}
150-
151-
template <>
152-
Handle<Value> Mul<double>(ValueFactory& value_factory, double v0, double v1) {
153-
return value_factory.CreateDoubleValue(v0 * v1);
154-
}
155-
156-
template <class Type>
157-
Handle<Value> Div(ValueFactory&, Type v0, Type v1);
158-
159-
// Division operations for integer types should check for
160-
// division by 0
161-
template <>
162-
Handle<Value> Div<int64_t>(ValueFactory& value_factory, int64_t v0,
163-
int64_t v1) {
164-
auto quot = cel::internal::CheckedDiv(v0, v1);
165-
if (!quot.ok()) {
166-
return value_factory.CreateErrorValue(quot.status());
167-
}
168-
return value_factory.CreateIntValue(*quot);
169-
}
170-
171-
// Division operations for integer types should check for
172-
// division by 0
173-
template <>
174-
Handle<Value> Div<uint64_t>(ValueFactory& value_factory, uint64_t v0,
175-
uint64_t v1) {
176-
auto quot = cel::internal::CheckedDiv(v0, v1);
177-
if (!quot.ok()) {
178-
return value_factory.CreateErrorValue(quot.status());
179-
}
180-
return value_factory.CreateUintValue(*quot);
181-
}
182-
183-
template <>
184-
Handle<Value> Div<double>(ValueFactory& value_factory, double v0, double v1) {
185-
static_assert(std::numeric_limits<double>::is_iec559,
186-
"Division by zero for doubles must be supported");
187-
188-
// For double, division will result in +/- inf
189-
return value_factory.CreateDoubleValue(v0 / v1);
190-
}
191-
192-
// Modulo operation
193-
template <class Type>
194-
Handle<Value> Modulo(ValueFactory& value_factory, Type v0, Type v1);
195-
196-
// Modulo operations for integer types should check for
197-
// division by 0
198-
template <>
199-
Handle<Value> Modulo<int64_t>(ValueFactory& value_factory, int64_t v0,
200-
int64_t v1) {
201-
auto mod = cel::internal::CheckedMod(v0, v1);
202-
if (!mod.ok()) {
203-
return value_factory.CreateErrorValue(mod.status());
204-
}
205-
return value_factory.CreateIntValue(*mod);
206-
}
207-
208-
template <>
209-
Handle<Value> Modulo<uint64_t>(ValueFactory& value_factory, uint64_t v0,
210-
uint64_t v1) {
211-
auto mod = cel::internal::CheckedMod(v0, v1);
212-
if (!mod.ok()) {
213-
return value_factory.CreateErrorValue(mod.status());
214-
}
215-
return value_factory.CreateUintValue(*mod);
216-
}
217-
218-
// Helper method
219-
// Registers all arithmetic functions for template parameter type.
220-
template <class Type>
221-
absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) {
222-
using FunctionAdapter = cel::BinaryFunctionAdapter<Handle<Value>, Type, Type>;
223-
CEL_RETURN_IF_ERROR(registry->Register(
224-
FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false),
225-
FunctionAdapter::WrapFunction(&Add<Type>)));
226-
227-
CEL_RETURN_IF_ERROR(registry->Register(
228-
FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false),
229-
FunctionAdapter::WrapFunction(&Sub<Type>)));
230-
231-
CEL_RETURN_IF_ERROR(registry->Register(
232-
FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false),
233-
FunctionAdapter::WrapFunction(&Mul<Type>)));
234-
235-
return registry->Register(
236-
FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false),
237-
FunctionAdapter::WrapFunction(&Div<Type>));
238-
}
239-
240-
// Register basic Arithmetic functions for numeric types.
241-
absl::Status RegisterNumericArithmeticFunctions(
242-
CelFunctionRegistry* registry, const InterpreterOptions& options) {
243-
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<int64_t>(registry));
244-
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<uint64_t>(registry));
245-
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<double>(registry));
246-
247-
// Modulo
248-
CEL_RETURN_IF_ERROR(registry->Register(
249-
BinaryFunctionAdapter<Handle<Value>, int64_t, int64_t>::CreateDescriptor(
250-
cel::builtin::kModulo, false),
251-
BinaryFunctionAdapter<Handle<Value>, int64_t, int64_t>::WrapFunction(
252-
&Modulo<int64_t>)));
253-
254-
CEL_RETURN_IF_ERROR(registry->Register(
255-
BinaryFunctionAdapter<Handle<Value>, uint64_t,
256-
uint64_t>::CreateDescriptor(cel::builtin::kModulo,
257-
false),
258-
BinaryFunctionAdapter<Handle<Value>, uint64_t, uint64_t>::WrapFunction(
259-
&Modulo<uint64_t>)));
260-
261-
// Negation group
262-
CEL_RETURN_IF_ERROR(registry->Register(
263-
UnaryFunctionAdapter<Handle<Value>, int64_t>::CreateDescriptor(
264-
cel::builtin::kNeg, false),
265-
UnaryFunctionAdapter<Handle<Value>, int64_t>::WrapFunction(
266-
[](ValueFactory& value_factory, int64_t value) -> Handle<Value> {
267-
auto inv = cel::internal::CheckedNegation(value);
268-
if (!inv.ok()) {
269-
return value_factory.CreateErrorValue(inv.status());
270-
}
271-
return value_factory.CreateIntValue(*inv);
272-
})));
273-
274-
return registry->Register(
275-
UnaryFunctionAdapter<double, double>::CreateDescriptor(cel::builtin::kNeg,
276-
false),
277-
UnaryFunctionAdapter<double, double>::WrapFunction(
278-
[](ValueFactory&, double value) -> double { return -value; }));
279-
}
280-
28171
template <class T>
28272
bool ValueEquals(const CelValue& value, T other);
28373

@@ -1197,11 +987,12 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry,
1197987
cel::RegisterContainerFunctions(modern_registry, runtime_options));
1198988
CEL_RETURN_IF_ERROR(
1199989
cel::RegisterTypeConversionFunctions(modern_registry, runtime_options));
990+
CEL_RETURN_IF_ERROR(
991+
cel::RegisterArithmeticFunctions(modern_registry, runtime_options));
1200992

1201993
return registry->RegisterAll(
1202994
{
1203995
&RegisterEqualityFunctions,
1204-
&RegisterNumericArithmeticFunctions,
1205996
&RegisterTimeFunctions,
1206997
&RegisterStringFunctions,
1207998
&RegisterRegexFunctions,

runtime/standard/BUILD

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,35 @@ cc_test(
167167
"//internal:testing",
168168
],
169169
)
170+
171+
cc_library(
172+
name = "arithmetic_functions",
173+
srcs = ["arithmetic_functions.cc"],
174+
hdrs = ["arithmetic_functions.h"],
175+
deps = [
176+
"//base:builtins",
177+
"//base:data",
178+
"//base:function_adapter",
179+
"//base:handle",
180+
"//internal:overflow",
181+
"//internal:status_macros",
182+
"//runtime:function_registry",
183+
"//runtime:runtime_options",
184+
"@com_google_absl//absl/status",
185+
"@com_google_absl//absl/strings",
186+
],
187+
)
188+
189+
cc_test(
190+
name = "arithmetic_functions_test",
191+
size = "small",
192+
srcs = [
193+
"arithmetic_functions_test.cc",
194+
],
195+
deps = [
196+
":arithmetic_functions",
197+
"//base:builtins",
198+
"//base:function_descriptor",
199+
"//internal:testing",
200+
],
201+
)

0 commit comments

Comments
 (0)