diff --git a/barretenberg/cpp/src/barretenberg/vm2/common/tagged_value.hpp b/barretenberg/cpp/src/barretenberg/vm2/common/tagged_value.hpp index ca2061949c4f..531797b2f983 100644 --- a/barretenberg/cpp/src/barretenberg/vm2/common/tagged_value.hpp +++ b/barretenberg/cpp/src/barretenberg/vm2/common/tagged_value.hpp @@ -53,7 +53,7 @@ class CastException : public TaggedValueException { {} }; -enum class ValueTag { +enum class ValueTag : uint8_t { FF = MEM_TAG_FF, U1 = MEM_TAG_U1, U8 = MEM_TAG_U8, diff --git a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.cpp b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.cpp index 480f2ff73438..12c5c8343185 100644 --- a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.cpp +++ b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.cpp @@ -2,10 +2,12 @@ #include #include +#include #include "barretenberg/common/bb_bench.hpp" #include "barretenberg/common/log.hpp" #include "barretenberg/vm2/common/aztec_constants.hpp" +#include "barretenberg/vm2/common/tagged_value.hpp" #include "barretenberg/vm2/common/to_radix.hpp" #include "barretenberg/vm2/common/uint1.hpp" #include "barretenberg/vm2/simulation/events/addressing_event.hpp" @@ -417,7 +419,7 @@ void Execution::shr(ContextInterface& context, MemoryAddress a_addr, MemoryAddre * * @throws OutOfGasException if the gas limit is exceeded. */ -void Execution::cast(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr, uint8_t dst_tag) +void Execution::cast(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr, MemoryTag dst_tag) { BB_BENCH_NAME("Execution::cast"); constexpr auto opcode = ExecutionOpCode::CAST; @@ -426,7 +428,7 @@ void Execution::cast(ContextInterface& context, MemoryAddress src_addr, MemoryAd set_and_validate_inputs(opcode, { val }); get_gas_tracker().consume_gas(); - MemoryValue truncated = alu.truncate(val.as_ff(), static_cast(dst_tag)); + MemoryValue truncated = alu.truncate(val.as_ff(), dst_tag); memory.set(dst_addr, truncated); set_output(opcode, truncated); } @@ -436,12 +438,13 @@ void Execution::cast(ContextInterface& context, MemoryAddress src_addr, MemoryAd * * @param context The context. * @param dst_addr The resolved address of the output value. - * @param var_enum The enum value of the environment variable to get. + * @param env_var The enum value of the environment variable to get (as an uint8_t). + * We need to use uint8_t here to manually perform validation. * * @throws OutOfGasException if the gas limit is exceeded. * @throws OpcodeExecutionException if the enum value is invalid. */ -void Execution::get_env_var(ContextInterface& context, MemoryAddress dst_addr, uint8_t var_enum) +void Execution::get_env_var(ContextInterface& context, MemoryAddress dst_addr, uint8_t env_var_value) { BB_BENCH_NAME("Execution::get_env_var"); constexpr auto opcode = ExecutionOpCode::GETENVVAR; @@ -449,10 +452,14 @@ void Execution::get_env_var(ContextInterface& context, MemoryAddress dst_addr, u get_gas_tracker().consume_gas(); + // If env_var_value is not a valid EnvironmentVariable enum value, throw an OpcodeExecutionException. + if (env_var_value > static_cast(EnvironmentVariable::MAX)) { + throw OpcodeExecutionException("Invalid environment variable enum value"); + } + MemoryValue result; - EnvironmentVariable env_var = static_cast(var_enum); - switch (env_var) { + switch (static_cast(env_var_value)) { case EnvironmentVariable::ADDRESS: result = MemoryValue::from(context.get_address()); break; @@ -490,6 +497,7 @@ void Execution::get_env_var(ContextInterface& context, MemoryAddress dst_addr, u result = MemoryValue::from(context.gas_left().da_gas); break; default: + // We leave this here defensively. throw OpcodeExecutionException("Invalid environment variable enum value"); } @@ -503,19 +511,18 @@ void Execution::get_env_var(ContextInterface& context, MemoryAddress dst_addr, u * * @param context The context. * @param dst_addr The resolved address of the output memory value. - * @param dst_tag The destination tag of the value to set. (as an uint8_t) + * @param dst_tag The destination tag of the value to set. * @param value The source value to set. (might get truncated) * * @throws OutOfGasException if the gas limit is exceeded. */ -// TODO: My dispatch system makes me have a uint8_t tag. Rethink. -void Execution::set(ContextInterface& context, MemoryAddress dst_addr, uint8_t dst_tag, const FF& value) +void Execution::set(ContextInterface& context, MemoryAddress dst_addr, MemoryTag dst_tag, const FF& value) { BB_BENCH_NAME("Execution::set"); get_gas_tracker().consume_gas(); constexpr auto opcode = ExecutionOpCode::SET; - MemoryValue truncated = alu.truncate(value, static_cast(dst_tag)); + MemoryValue truncated = alu.truncate(value, dst_tag); context.get_memory().set(dst_addr, truncated); set_output(opcode, truncated); } @@ -1781,9 +1788,7 @@ EnqueuedCallResult Execution::execute(std::unique_ptr enqueued gas_tracker = execution_components.make_gas_tracker(gas_event, instruction, context); dispatch_opcode(instruction.get_exec_opcode(), context, resolved_operands); - } - // TODO(fcarreiro): handle this in a better way. - catch (const BytecodeRetrievalError& e) { + } catch (const BytecodeRetrievalError& e) { vinfo("Bytecode retrieval error:: ", e.what()); error = ExecutionError::BYTECODE_RETRIEVAL; handle_exceptional_halt(context, e.what()); @@ -1819,8 +1824,6 @@ EnqueuedCallResult Execution::execute(std::unique_ptr enqueued context.set_pc(context.get_next_pc()); execution_id_manager.increment_execution_id(); - // TODO: We set the inputs and outputs here and into the execution event, - // but maybe there's a better way to do this. events.emit({ .error = error, .wire_instruction = instruction, @@ -1932,14 +1935,9 @@ void Execution::handle_exit_call() parent_context.set_gas_used(result.gas_used + parent_context.get_gas_used()); parent_context.set_child_context(std::move(child_context)); - // TODO(fcarreiro): move somewhere else. - if (parent_context.get_checkpoint_id_at_creation() != merkle_db.get_checkpoint_id()) { - throw std::runtime_error(format("Checkpoint id mismatch: ", - parent_context.get_checkpoint_id_at_creation(), - " != ", - merkle_db.get_checkpoint_id(), - " (gone back to the wrong db/context)")); - } + BB_ASSERT_EQ(parent_context.get_checkpoint_id_at_creation(), + merkle_db.get_checkpoint_id(), + "Checkpoint id mismatch: gone back to the wrong db/context"); } // Else: was top level. ExecutionResult is already set and that will be returned. } @@ -2144,7 +2142,15 @@ inline void Execution::call_with_operands(void (Execution::*f)(ContextInterface& BB_ASSERT_DEBUG(resolved_operands.size() == sizeof...(Ts), "Resolved operands size mismatch"); auto operand_indices = std::make_index_sequence{}; [f, this, &context, &resolved_operands](std::index_sequence) { - (this->*f)(context, resolved_operands.at(Is).to>()...); + // This helper handles operand conversion. In particular it converts enums to their underlying type first. + auto convert_operand = [](const Operand& op) -> T { + if constexpr (std::is_enum_v) { + return static_cast(op.to>()); + } else { + return op.to(); + } + }; + (this->*f)(context, convert_operand.template operator()>(resolved_operands.at(Is))...); }(operand_indices); } diff --git a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.hpp b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.hpp index 63be94a9ece2..0e6fcfa514c7 100644 --- a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.hpp +++ b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.hpp @@ -9,6 +9,7 @@ #include "barretenberg/vm2/common/field.hpp" #include "barretenberg/vm2/common/memory_types.hpp" #include "barretenberg/vm2/common/opcodes.hpp" +#include "barretenberg/vm2/common/tagged_value.hpp" #include "barretenberg/vm2/simulation/events/context_events.hpp" #include "barretenberg/vm2/simulation/events/event_emitter.hpp" #include "barretenberg/vm2/simulation/events/execution_event.hpp" @@ -117,9 +118,9 @@ class Execution : public ExecutionInterface { void lt(ContextInterface& context, MemoryAddress a_addr, MemoryAddress b_addr, MemoryAddress dst_addr); void lte(ContextInterface& context, MemoryAddress a_addr, MemoryAddress b_addr, MemoryAddress dst_addr); void op_not(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr); - void cast(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr, uint8_t dst_tag); - void get_env_var(ContextInterface& context, MemoryAddress dst_addr, uint8_t var_enum); - void set(ContextInterface& context, MemoryAddress dst_addr, uint8_t tag, const FF& value); + void cast(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr, MemoryTag dst_tag); + void get_env_var(ContextInterface& context, MemoryAddress dst_addr, uint8_t env_var_value); + void set(ContextInterface& context, MemoryAddress dst_addr, MemoryTag tag, const FF& value); void mov(ContextInterface& context, MemoryAddress src_addr, MemoryAddress dst_addr); void jump(ContextInterface& context, uint32_t loc); void jumpi(ContextInterface& context, MemoryAddress cond_addr, uint32_t loc); diff --git a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.test.cpp b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.test.cpp index e7ef09e4952b..5b7ad0ec7fe3 100644 --- a/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm2/simulation/gadgets/execution.test.cpp @@ -982,7 +982,7 @@ TEST_F(ExecutionSimulationTest, EmitNullifierCollision) TEST_F(ExecutionSimulationTest, Set) { MemoryAddress dst_addr = 10; - uint8_t dst_tag = static_cast(MemoryTag::U8); + MemoryTag dst_tag = MemoryTag::U8; FF value = 7; EXPECT_CALL(context, get_memory()); @@ -997,7 +997,7 @@ TEST_F(ExecutionSimulationTest, Cast) { MemoryAddress src_addr = 9; MemoryAddress dst_addr = 10; - uint8_t dst_tag = static_cast(MemoryTag::U1); + MemoryTag dst_tag = MemoryTag::U1; MemoryValue value = MemoryValue::from(7); EXPECT_CALL(context, get_memory()).WillOnce(ReturnRef(memory));