diff --git a/src/butil/third_party/rapidjson/reader.h b/src/butil/third_party/rapidjson/reader.h index fd6398b580..26e53c2a63 100644 --- a/src/butil/third_party/rapidjson/reader.h +++ b/src/butil/third_party/rapidjson/reader.h @@ -1536,9 +1536,15 @@ class GenericReader { state = d; // Do not further consume streams if a root JSON has been parsed. - if ((parseFlags & kParseStopWhenDoneFlag) && state == IterativeParsingFinishState) + if ((parseFlags & kParseStopWhenDoneFlag) && state == IterativeParsingFinishState) { + // wwb: Update parseResult_.Offset() when kParseStopWhenDoneFlag + // is set which means the user needs to know where to resume + // parsing in next calls to JsonToProtoMessage() + if (is.Peek() != '\0') { + SetParseError(kParseErrorNone, is.Tell()); + } break; - + } SkipWhitespace(is); } diff --git a/src/json2pb/json_to_pb.cpp b/src/json2pb/json_to_pb.cpp index 82327cd6d3..e758bdb3ab 100644 --- a/src/json2pb/json_to_pb.cpp +++ b/src/json2pb/json_to_pb.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include "butil/strings/string_number_conversions.h" #include "butil/third_party/rapidjson/error/error.h" #include "butil/third_party/rapidjson/rapidjson.h" @@ -59,6 +60,12 @@ namespace json2pb { +// Use iterative parsing to avoid stack overflow. +const int RAPIDJSON_PARSE_FLAG_DEFAULT = BUTIL_RAPIDJSON_NAMESPACE::kParseIterativeFlag; +const int RAPIDJSON_PARSE_FLAG_STOP_WHEN_DONE = BUTIL_RAPIDJSON_NAMESPACE::kParseStopWhenDoneFlag | RAPIDJSON_PARSE_FLAG_DEFAULT; + +DEFINE_int32(json2pb_max_recursion_depth, 100, "Maximum recursion depth of JSON parser"); + Json2PbOptions::Json2PbOptions() #ifdef BAIDU_INTERNAL : base64_to_bytes(false) @@ -284,8 +291,7 @@ bool JsonValueToProtoMessage(const BUTIL_RAPIDJSON_NAMESPACE::Value& json_value, google::protobuf::Message* message, const Json2PbOptions& options, std::string* err, - bool root_val = false); - + int depth); //Json value to protobuf convert rules for type: //Json value type Protobuf type convert rules //int int uint int64 uint64 valid convert is available @@ -314,7 +320,8 @@ static bool JsonValueToProtoField(const BUTIL_RAPIDJSON_NAMESPACE::Value& value, const google::protobuf::FieldDescriptor* field, google::protobuf::Message* message, const Json2PbOptions& options, - std::string* err) { + std::string* err, + int depth) { if (value.IsNull()) { if (field->is_required()) { J2PERROR(err, "Missing required field: %s", field->full_name().c_str()); @@ -477,13 +484,13 @@ static bool JsonValueToProtoField(const BUTIL_RAPIDJSON_NAMESPACE::Value& value, const BUTIL_RAPIDJSON_NAMESPACE::Value& item = value[index]; if (TYPE_MATCH == J2PCHECKTYPE(item, message, Object)) { if (!JsonValueToProtoMessage( - item, reflection->AddMessage(message, field), options, err)) { + item, reflection->AddMessage(message, field), options, err, depth + 1)) { return false; } } } } else if (!JsonValueToProtoMessage( - value, reflection->MutableMessage(message, field), options, err)) { + value, reflection->MutableMessage(message, field), options, err, depth + 1)) { return false; } break; @@ -495,7 +502,8 @@ bool JsonMapToProtoMap(const BUTIL_RAPIDJSON_NAMESPACE::Value& value, const google::protobuf::FieldDescriptor* map_desc, google::protobuf::Message* message, const Json2PbOptions& options, - std::string* err) { + std::string* err, + int depth) { if (!value.IsObject()) { J2PERROR(err, "Non-object value for map field: %s", map_desc->full_name().c_str()); @@ -515,7 +523,7 @@ bool JsonMapToProtoMap(const BUTIL_RAPIDJSON_NAMESPACE::Value& value, entry_reflection->SetString( entry, key_desc, std::string(it->name.GetString(), it->name.GetStringLength())); - if (!JsonValueToProtoField(it->value, value_desc, entry, options, err)) { + if (!JsonValueToProtoField(it->value, value_desc, entry, options, err, depth + 1)) { return false; } } @@ -526,10 +534,14 @@ bool JsonValueToProtoMessage(const BUTIL_RAPIDJSON_NAMESPACE::Value& json_value, google::protobuf::Message* message, const Json2PbOptions& options, std::string* err, - bool root_val) { + int depth) { + if (depth > FLAGS_json2pb_max_recursion_depth) { + J2PERROR_WITH_PB(message, err, "Exceeded maximum recursion depth"); + return false; + } const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); if (!json_value.IsObject() && - !(json_value.IsArray() && options.array_to_single_repeated && root_val)) { + !(json_value.IsArray() && options.array_to_single_repeated && depth == 0)) { J2PERROR_WITH_PB(message, err, "The input is not a json object"); return false; } @@ -560,7 +572,7 @@ bool JsonValueToProtoMessage(const BUTIL_RAPIDJSON_NAMESPACE::Value& json_value, if (json_value.IsArray()) { if (fields.size() == 1 && fields.front()->is_repeated()) { - return JsonValueToProtoField(json_value, fields.front(), message, options, err); + return JsonValueToProtoField(json_value, fields.front(), message, options, err, depth); } J2PERROR_WITH_PB(message, err, "the input json can't be array here"); @@ -602,11 +614,11 @@ bool JsonValueToProtoMessage(const BUTIL_RAPIDJSON_NAMESPACE::Value& json_value, if (IsProtobufMap(field) && value_ptr->IsObject()) { // Try to parse json like {"key":value, ...} into protobuf map - if (!JsonMapToProtoMap(*value_ptr, field, message, options, err)) { + if (!JsonMapToProtoMap(*value_ptr, field, message, options, err, depth)) { return false; } } else { - if (!JsonValueToProtoField(*value_ptr, field, message, options, err)) { + if (!JsonValueToProtoField(*value_ptr, field, message, options, err, depth)) { return false; } } @@ -624,12 +636,12 @@ inline bool JsonToProtoMessageInline(const std::string& json_string, } BUTIL_RAPIDJSON_NAMESPACE::Document d; if (options.allow_remaining_bytes_after_parsing) { - d.Parse(json_string.c_str()); + d.Parse(json_string.c_str()); if (parsed_offset != nullptr) { *parsed_offset = d.GetErrorOffset(); } } else { - d.Parse<0>(json_string.c_str()); + d.Parse(json_string.c_str()); } if (d.HasParseError()) { if (options.allow_remaining_bytes_after_parsing) { @@ -642,7 +654,7 @@ inline bool JsonToProtoMessageInline(const std::string& json_string, J2PERROR_WITH_PB(message, error, "Invalid json: %s", BUTIL_RAPIDJSON_NAMESPACE::GetParseError_En(d.GetParseError())); return false; } - return JsonValueToProtoMessage(d, message, options, error, true); + return JsonValueToProtoMessage(d, message, options, error, 0); } bool JsonToProtoMessage(const std::string& json_string, @@ -672,12 +684,12 @@ bool JsonToProtoMessage(ZeroCopyStreamReader* reader, } BUTIL_RAPIDJSON_NAMESPACE::Document d; if (options.allow_remaining_bytes_after_parsing) { - d.ParseStream>(*reader); + d.ParseStream>(*reader); if (parsed_offset != nullptr) { *parsed_offset = d.GetErrorOffset(); } } else { - d.ParseStream<0, BUTIL_RAPIDJSON_NAMESPACE::UTF8<>>(*reader); + d.ParseStream>(*reader); } if (d.HasParseError()) { if (options.allow_remaining_bytes_after_parsing) { @@ -690,7 +702,7 @@ bool JsonToProtoMessage(ZeroCopyStreamReader* reader, J2PERROR_WITH_PB(message, error, "Invalid json: %s", BUTIL_RAPIDJSON_NAMESPACE::GetParseError_En(d.GetParseError())); return false; } - return JsonValueToProtoMessage(d, message, options, error, true); + return JsonValueToProtoMessage(d, message, options, error, 0); } bool JsonToProtoMessage(const std::string& json_string, diff --git a/src/json2pb/pb_to_json.cpp b/src/json2pb/pb_to_json.cpp index 9671979cc0..f232226785 100644 --- a/src/json2pb/pb_to_json.cpp +++ b/src/json2pb/pb_to_json.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include "json2pb/zero_copy_stream_writer.h" #include "json2pb/encode_decode.h" #include "json2pb/protobuf_map.h" @@ -33,6 +34,57 @@ #include "butil/base64.h" namespace json2pb { + +DECLARE_int32(json2pb_max_recursion_depth); + +// Helper function to check if the maximum depth of a message is exceeded. +bool ExceedMaxDepth(const google::protobuf::Message& message, int current_depth) { + if (current_depth > FLAGS_json2pb_max_recursion_depth) { + return true; + } + + const google::protobuf::Descriptor* descriptor = message.GetDescriptor(); + const google::protobuf::Reflection* reflection = message.GetReflection(); + + std::vector fields; + // Collect declared fields. + for (int i = 0; i < descriptor->field_count(); ++i) { + fields.push_back(descriptor->field(i)); + } + // Collect extension fields (if any). + { + std::vector ext_fields; + descriptor->file()->pool()->FindAllExtensions(descriptor, &ext_fields); + fields.insert(fields.end(), ext_fields.begin(), ext_fields.end()); + } + + for (const auto* field : fields) { + if (field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + continue; + } + + if (field->is_repeated()) { + const int count = reflection->FieldSize(message, field); + for (int j = 0; j < count; ++j) { + const google::protobuf::Message& sub_message = + reflection->GetRepeatedMessage(message, field, j); + if (ExceedMaxDepth(sub_message, current_depth + 1)) { + return true; + } + } + } else { + if (reflection->HasField(message, field)) { + const google::protobuf::Message& sub_message = + reflection->GetMessage(message, field); + if (ExceedMaxDepth(sub_message, current_depth + 1)) { + return true; + } + } + } + } + return false; +} + Pb2JsonOptions::Pb2JsonOptions() : enum_option(OUTPUT_ENUM_BY_NAME) , pretty_json(false) @@ -52,7 +104,7 @@ class PbToJsonConverter { explicit PbToJsonConverter(const Pb2JsonOptions& opt) : _option(opt) {} template - bool Convert(const google::protobuf::Message& message, Handler& handler, bool root_msg = false); + bool Convert(const google::protobuf::Message& message, Handler& handler, int depth = 0); const std::string& ErrorText() const { return _error; } @@ -60,14 +112,18 @@ class PbToJsonConverter { template bool _PbFieldToJson(const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* field, - Handler& handler); + Handler& handler, int depth); std::string _error; Pb2JsonOptions _option; }; template -bool PbToJsonConverter::Convert(const google::protobuf::Message& message, Handler& handler, bool root_msg) { +bool PbToJsonConverter::Convert(const google::protobuf::Message& message, Handler& handler, int depth) { + if (depth > FLAGS_json2pb_max_recursion_depth) { + _error = "Exceeded maximum recursion depth"; + return false; + } const google::protobuf::Reflection* reflection = message.GetReflection(); const google::protobuf::Descriptor* descriptor = message.GetDescriptor(); @@ -101,9 +157,9 @@ bool PbToJsonConverter::Convert(const google::protobuf::Message& message, Handle } } - if (root_msg && _option.single_repeated_to_array) { + if (depth == 0 && _option.single_repeated_to_array) { if (map_fields.empty() && fields.size() == 1 && fields.front()->is_repeated()) { - return _PbFieldToJson(message, fields.front(), handler); + return _PbFieldToJson(message, fields.front(), handler, depth); } } @@ -134,7 +190,7 @@ bool PbToJsonConverter::Convert(const google::protobuf::Message& message, Handle bool decoded = decode_name(orig_name, field_name_str); const std::string& name = decoded ? field_name_str : orig_name; handler.Key(name.data(), name.size(), false); - if (!_PbFieldToJson(message, field, handler)) { + if (!_PbFieldToJson(message, field, handler, depth)) { return false; } } @@ -164,7 +220,7 @@ bool PbToJsonConverter::Convert(const google::protobuf::Message& message, Handle handler.Key(entry_name.data(), entry_name.size(), false); // Fill in entries into this json object - if (!_PbFieldToJson(entry, value_desc, handler)) { + if (!_PbFieldToJson(entry, value_desc, handler, depth)) { return false; } } @@ -180,7 +236,7 @@ template bool PbToJsonConverter::_PbFieldToJson( const google::protobuf::Message& message, const google::protobuf::FieldDescriptor* field, - Handler& handler) { + Handler& handler, int depth) { const google::protobuf::Reflection* reflection = message.GetReflection(); switch (field->cpp_type()) { #define CASE_FIELD_TYPE(cpptype, method, valuetype, handle) \ @@ -280,14 +336,14 @@ bool PbToJsonConverter::_PbFieldToJson( handler.StartArray(); for (int index = 0; index < field_size; ++index) { if (!Convert(reflection->GetRepeatedMessage( - message, field, index), handler)) { + message, field, index), handler, depth + 1)) { return false; } } handler.EndArray(field_size); } else { - if (!Convert(reflection->GetMessage(message, field), handler)) { + if (!Convert(reflection->GetMessage(message, field), handler, depth + 1)) { return false; } } @@ -305,10 +361,10 @@ bool ProtoMessageToJsonStream(const google::protobuf::Message& message, bool succ = false; if (options.pretty_json) { BUTIL_RAPIDJSON_NAMESPACE::PrettyWriter writer(os); - succ = converter.Convert(message, writer, true); + succ = converter.Convert(message, writer); } else { BUTIL_RAPIDJSON_NAMESPACE::OptimizedWriter writer(os); - succ = converter.Convert(message, writer, true); + succ = converter.Convert(message, writer); } if (!succ && error) { error->clear(); @@ -352,6 +408,12 @@ bool ProtoMessageToJson(const google::protobuf::Message& message, bool ProtoMessageToProtoJson(const google::protobuf::Message& message, google::protobuf::io::ZeroCopyOutputStream* json, const Pb2ProtoJsonOptions& options, std::string* error) { + if (ExceedMaxDepth(message, 0)) { + if (error) { + *error = "Exceeded maximum recursion depth"; + } + return false; + } butil::IOBuf buf; butil::IOBufAsZeroCopyOutputStream output_stream(&buf); if (!message.SerializeToZeroCopyStream(&output_stream)) { diff --git a/test/brpc_protobuf_json_unittest.cpp b/test/brpc_protobuf_json_unittest.cpp index e8435d2cf3..b9289b20d5 100644 --- a/test/brpc_protobuf_json_unittest.cpp +++ b/test/brpc_protobuf_json_unittest.cpp @@ -26,6 +26,7 @@ #include "butil/strings/string_util.h" #include "butil/third_party/rapidjson/rapidjson.h" #include "butil/time.h" +#include "butil/memory/scope_guard.h" #include "gperftools_helper.h" #include "json2pb/pb_to_json.h" #include "json2pb/json_to_pb.h" @@ -36,6 +37,7 @@ #include "addressbook.pb.h" #include "addressbook_encode_decode.pb.h" #include "addressbook_map.pb.h" +#include "echo.pb.h" namespace { // just for coding-style check @@ -497,6 +499,151 @@ TEST_F(ProtobufJsonTest, json_to_pb_expected_failed_case) { ASSERT_STREQ("Invalid value `23' for optional field `Content.uid' which SHOULD be string, Missing required field: Ext.databyte", error.data()); } +const int DEEP_RECURSION_TEST_DEPTH = 140000; + +TEST_F(ProtobufJsonTest, json_to_pb_unbounded_recursion) { + test::RecursiveMessage msg; + + // Generate a deeply nested JSON string to trigger unbounded recursion. + const int recursion_depth = DEEP_RECURSION_TEST_DEPTH; + std::string nested_json = ""; + for (int i = 0; i < recursion_depth; ++i) { + nested_json += "{\"child\":"; + } + nested_json += "{\"data\":\"leaf\"}"; + for (int i = 0; i < recursion_depth; ++i) { + nested_json += "}"; + } + + { + std::string error; + bool ret = json2pb::JsonToProtoMessage(nested_json, &msg, &error); + ASSERT_FALSE(ret); + ASSERT_EQ("Exceeded maximum recursion depth [RecursiveMessage]", error); + } + { + json2pb::ProtoJson2PbOptions options; + std::string error; + bool ret = json2pb::ProtoJsonToProtoMessage(nested_json, &msg, options, &error); + ASSERT_FALSE(ret); + ASSERT_EQ("INVALID_ARGUMENT:Message too deep. Max recursion depth reached for key 'child'", error); + } +} + +TEST_F(ProtobufJsonTest, pb_to_json_unbounded_recursion) { + test::RecursiveMessage msg; + + // Create a deeply nested protobuf message. + const int recursion_depth = DEEP_RECURSION_TEST_DEPTH; + test::RecursiveMessage* current = &msg; + std::vector nodes; + nodes.reserve(recursion_depth); + for (int i = 0; i < recursion_depth; ++i) { + nodes.push_back(current); + current = current->mutable_child(); + } + current->set_data("leaf"); + + BRPC_SCOPE_EXIT { + // Release msg memory from end to start to avoid stack overflow. + for (size_t i = nodes.size() - 1; i > 0; --i) { + delete nodes[i]->release_child(); + } + }; + + { + std::string json_output; + std::string error; + bool ret = json2pb::ProtoMessageToJson(msg, &json_output, &error); + ASSERT_FALSE(ret); + ASSERT_EQ("Exceeded maximum recursion depth", error); + } + { + std::string json_output; + std::string error; + json2pb::Pb2ProtoJsonOptions options; + bool ret = json2pb::ProtoMessageToProtoJson(msg, &json_output, options, &error); + ASSERT_FALSE(ret); + ASSERT_EQ("Exceeded maximum recursion depth", error); + } +} + +TEST_F(ProtobufJsonTest, pb_parse_unbounded_recursion) { + auto generate_binary = [](int recursion_depth) { + // Innermost message: { data: "leaf" } + // data field: tag = (2<<3)|2 = 0x12, len=4, bytes "leaf" + const char kLeafRaw[] = "\x12\x04" "leaf"; + const std::string leaf_msg(kLeafRaw, sizeof(kLeafRaw) - 1); + + // Precompute sizes: + // S[0] = leaf size + // S[i] = 1 (tag 0x0A) + varint_len(S[i-1]) + S[i-1] + auto varint_len = [](size_t v) { + int n = 1; + while (v >= 128) { v >>= 7; ++n; } + return n; + }; + + std::vector sizes; + sizes.reserve(recursion_depth + 1); + sizes.push_back(leaf_msg.size()); + for (int i = 1; i <= recursion_depth; ++i) { + size_t inner = sizes[i-1]; + sizes.push_back(1 + varint_len(inner) + inner); + } + const size_t final_size = sizes.back(); + + std::string out; + out.resize(final_size); + size_t off = 0; + + // Emit outermost -> innermost wrappers: tag(0x0A) + varint(len(inner)) + for (int depth = recursion_depth; depth >= 1; --depth) { + out[off++] = static_cast(0x0A); // tag for child + size_t len = sizes[depth - 1]; + while (true) { + uint8_t byte = static_cast(len & 0x7F); + len >>= 7; + if (len) byte |= 0x80; + out[off++] = static_cast(byte); + if (!len) break; + } + } + + // Copy leaf payload + memcpy(&out[off], leaf_msg.data(), leaf_msg.size()); + off += leaf_msg.size(); + return out; + }; + + // Test protobuf max depth limit (100). + { + test::RecursiveMessage msg; + std::string binary_data = generate_binary(100); + bool ret = msg.ParseFromString(binary_data); + ASSERT_TRUE(ret); + ASSERT_TRUE(msg.IsInitialized()); + + std::string error; + std::string json_output; + ret = json2pb::ProtoMessageToJson(msg, &json_output, &error); + ASSERT_TRUE(ret); + ASSERT_EQ("", error); + } + { + test::RecursiveMessage msg; + std::string binary_data = generate_binary(101); + bool ret = msg.ParseFromString(binary_data); + ASSERT_FALSE(ret); + } + { + test::RecursiveMessage msg; + std::string binary_data = generate_binary(DEEP_RECURSION_TEST_DEPTH); + bool ret = msg.ParseFromString(binary_data); + ASSERT_FALSE(ret); + } +} + TEST_F(ProtobufJsonTest, json_to_pb_perf_case) { std::string info3 = "{\"content\":[{\"distance\":1.0,\ diff --git a/test/echo.proto b/test/echo.proto index d7573fc6c4..970ef1dbb1 100644 --- a/test/echo.proto +++ b/test/echo.proto @@ -110,3 +110,8 @@ message Message1 { message Message2 { required State1 stat = 1; }; + +message RecursiveMessage { + optional RecursiveMessage child = 1; + optional string data = 2; +}