From fbe3fe96e9191149a3e2440390125b9afca0d368 Mon Sep 17 00:00:00 2001 From: lintanghui Date: Fri, 28 Feb 2025 14:45:40 +0800 Subject: [PATCH 1/4] keep session info in RedisConnContext --- example/BUILD.bazel | 11 ++++ example/redis_c++/redis_server.cpp | 85 +++++++++++++++++++++++++----- src/brpc/policy/redis_protocol.cpp | 32 +---------- src/brpc/redis.cpp | 6 ++- src/brpc/redis.h | 38 +++++++++++-- 5 files changed, 124 insertions(+), 48 deletions(-) diff --git a/example/BUILD.bazel b/example/BUILD.bazel index b38cd5a156..df2722a4f6 100644 --- a/example/BUILD.bazel +++ b/example/BUILD.bazel @@ -123,3 +123,14 @@ cc_binary( "//:brpc", ], ) + +cc_binary( + name = "redis_c++_server", + srcs = [ + "redis_c++/redis_server.cpp", + ], + copts = COPTS, + deps = [ + "//:brpc", + ], +) \ No newline at end of file diff --git a/example/redis_c++/redis_server.cpp b/example/redis_c++/redis_server.cpp index 6ebc385315..674b855bae 100644 --- a/example/redis_c++/redis_server.cpp +++ b/example/redis_c++/redis_server.cpp @@ -30,22 +30,38 @@ #include DEFINE_int32(port, 6379, "TCP Port of this server"); - class RedisServiceImpl : public brpc::RedisService { public: - bool Set(const std::string& key, const std::string& value) { + RedisServiceImpl() { + _user_password["db1"] = "123456"; + _user_password["db2"] = "123456"; + _db_map["db1"].resize(kHashSlotNum); + _db_map["db2"].resize(kHashSlotNum); + } + bool Set(const std::string& db_name, const std::string& key, const std::string& value) { int slot = butil::crc32c::Value(key.c_str(), key.size()) % kHashSlotNum; _mutex[slot].lock(); - _map[slot][key] = value; + auto& kv = _db_map[db_name]; + kv[slot][key] = value; _mutex[slot].unlock(); return true; } - - bool Get(const std::string& key, std::string* value) { + bool Auth(const std::string& db_name, const std::string& password) { + if (_user_password.find(db_name) == _user_password.end()) { + return false; + } else { + if (_user_password[db_name] != password) { + return false; + } + } + return true; + } + bool Get(const std::string& db_name, const std::string& key, std::string* value) { int slot = butil::crc32c::Value(key.c_str(), key.size()) % kHashSlotNum; _mutex[slot].lock(); - auto it = _map[slot].find(key); - if (it == _map[slot].end()) { + auto& kv = _db_map[db_name]; + auto it = kv[slot].find(key); + if (it == kv[slot].end()) { _mutex[slot].unlock(); return false; } @@ -56,7 +72,9 @@ class RedisServiceImpl : public brpc::RedisService { private: const static int kHashSlotNum = 32; - std::unordered_map _map[kHashSlotNum]; + typedef std::unordered_map KVStore; + std::unordered_map> _db_map; + std::unordered_map _user_password; butil::Mutex _mutex[kHashSlotNum]; }; @@ -65,16 +83,20 @@ class GetCommandHandler : public brpc::RedisCommandHandler { explicit GetCommandHandler(RedisServiceImpl* rsimpl) : _rsimpl(rsimpl) {} - brpc::RedisCommandHandlerResult Run(const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { + if (ctx->user_name.empty()) { + output->FormatError("No user name"); + return brpc::REDIS_CMD_HANDLED; + } if (args.size() != 2ul) { output->FormatError("Expect 1 arg for 'get', actually %lu", args.size()-1); return brpc::REDIS_CMD_HANDLED; } const std::string key(args[1].data(), args[1].size()); std::string value; - if (_rsimpl->Get(key, &value)) { + if (_rsimpl->Get(ctx->user_name, key, &value)) { output->SetString(value); } else { output->SetNullString(); @@ -91,22 +113,55 @@ class SetCommandHandler : public brpc::RedisCommandHandler { explicit SetCommandHandler(RedisServiceImpl* rsimpl) : _rsimpl(rsimpl) {} - brpc::RedisCommandHandlerResult Run(const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { + if (ctx->user_name.empty()) { + output->FormatError("No user name"); + return brpc::REDIS_CMD_HANDLED; + } if (args.size() != 3ul) { output->FormatError("Expect 2 args for 'set', actually %lu", args.size()-1); return brpc::REDIS_CMD_HANDLED; } const std::string key(args[1].data(), args[1].size()); const std::string value(args[2].data(), args[2].size()); - _rsimpl->Set(key, value); + _rsimpl->Set(ctx->user_name, key, value); output->SetStatus("OK"); return brpc::REDIS_CMD_HANDLED; } private: - RedisServiceImpl* _rsimpl; + RedisServiceImpl* _rsimpl; +}; + +class AuthCommandHandler : public brpc::RedisCommandHandler { +public: + explicit AuthCommandHandler(RedisServiceImpl* rsimpl) + : _rsimpl(rsimpl) {} + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, + brpc::RedisReply* output, + bool /*flush_batched*/) override { + if (args.size() != 3ul) { + output->FormatError("Expect 2 args for 'auth', actually %lu", args.size()-1); + return brpc::REDIS_CMD_HANDLED; + } + + const std::string db_name(args[1].data(), args[1].size()); + const std::string password(args[2].data(), args[2].size()); + + if (_rsimpl->Auth(db_name, password)) { + output->SetStatus("OK"); + ctx->user_name = db_name; + ctx->password = password; + } else { + output->FormatError("Invalid password for database '%s'", db_name.c_str()); + } + return brpc::REDIS_CMD_HANDLED; + } + +private: + RedisServiceImpl* _rsimpl; }; int main(int argc, char* argv[]) { @@ -114,9 +169,11 @@ int main(int argc, char* argv[]) { RedisServiceImpl *rsimpl = new RedisServiceImpl; auto get_handler =std::unique_ptr(new GetCommandHandler(rsimpl)); auto set_handler =std::unique_ptr( new SetCommandHandler(rsimpl)); + auto auth_handler = std::unique_ptr(new AuthCommandHandler(rsimpl)); rsimpl->AddCommandHandler("get", get_handler.get()); rsimpl->AddCommandHandler("set", set_handler.get()); - + rsimpl->AddCommandHandler("auth", auth_handler.get()); + brpc::Server server; brpc::ServerOptions server_options; server_options.redis_service = rsimpl; diff --git a/src/brpc/policy/redis_protocol.cpp b/src/brpc/policy/redis_protocol.cpp index 94524e8b75..4d61d24595 100644 --- a/src/brpc/policy/redis_protocol.cpp +++ b/src/brpc/policy/redis_protocol.cpp @@ -54,27 +54,6 @@ struct InputResponse : public InputMessageBase { } }; -// This class is as parsing_context in socket. -class RedisConnContext : public Destroyable { -public: - explicit RedisConnContext(const RedisService* rs) - : redis_service(rs) - , batched_size(0) {} - - ~RedisConnContext(); - // @Destroyable - void Destroy() override; - - const RedisService* redis_service; - // If user starts a transaction, transaction_handler indicates the - // handler pointer that runs the transaction command. - std::unique_ptr transaction_handler; - // >0 if command handler is run in batched mode. - int batched_size; - - RedisCommandParser parser; - butil::Arena arena; -}; int ConsumeCommand(RedisConnContext* ctx, const std::vector& args, @@ -83,7 +62,7 @@ int ConsumeCommand(RedisConnContext* ctx, RedisReply output(&ctx->arena); RedisCommandHandlerResult result = REDIS_CMD_HANDLED; if (ctx->transaction_handler) { - result = ctx->transaction_handler->Run(args, &output, flush_batched); + result = ctx->transaction_handler->Run(ctx, args, &output, flush_batched); if (result == REDIS_CMD_HANDLED) { ctx->transaction_handler.reset(NULL); } else if (result == REDIS_CMD_BATCHED) { @@ -97,7 +76,7 @@ int ConsumeCommand(RedisConnContext* ctx, snprintf(buf, sizeof(buf), "ERR unknown command `%s`", args[0].as_string().c_str()); output.SetError(buf); } else { - result = ch->Run(args, &output, flush_batched); + result = ch->Run(ctx, args, &output, flush_batched); if (result == REDIS_CMD_CONTINUE) { if (ctx->batched_size != 0) { LOG(ERROR) << "CONTINUE should not be returned in a batched process."; @@ -134,13 +113,6 @@ int ConsumeCommand(RedisConnContext* ctx, return 0; } -// ========== impl of RedisConnContext ========== - -RedisConnContext::~RedisConnContext() { } - -void RedisConnContext::Destroy() { - delete this; -} // ========== impl of RedisConnContext ========== diff --git a/src/brpc/redis.cpp b/src/brpc/redis.cpp index 777e199986..7f14384a6b 100644 --- a/src/brpc/redis.cpp +++ b/src/brpc/redis.cpp @@ -369,5 +369,9 @@ RedisCommandHandler* RedisCommandHandler::NewTransactionHandler() { LOG(ERROR) << "NewTransactionHandler is not implemented"; return NULL; } - +// ========== impl of RedisConnContext ========== +RedisConnContext::~RedisConnContext() { } +void RedisConnContext::Destroy() { + delete this; +} } // namespace brpc diff --git a/src/brpc/redis.h b/src/brpc/redis.h index c6b0ea21f2..e05a459e63 100644 --- a/src/brpc/redis.h +++ b/src/brpc/redis.h @@ -21,8 +21,10 @@ #include +#include "brpc/destroyable.h" #include "brpc/nonreflectable_message.h" #include "brpc/parse_result.h" +#include "brpc/redis_command.h" #include "brpc/pb_compat.h" #include "brpc/redis_reply.h" #include "butil/arena.h" @@ -209,7 +211,31 @@ enum RedisCommandHandlerResult { REDIS_CMD_CONTINUE = 1, REDIS_CMD_BATCHED = 2, }; - +class RedisCommandParser; +// This class is as parsing_context in socket. +class RedisConnContext : public Destroyable { +public: + explicit RedisConnContext(const RedisService* rs) + : redis_service(rs) + , batched_size(0) {} + + ~RedisConnContext(); + // @Destroyable + void Destroy() override; + + const RedisService* redis_service; + // If user starts a transaction, transaction_handler indicates the + // handler pointer that runs the transaction command. + std::unique_ptr transaction_handler; + // >0 if command handler is run in batched mode. + int batched_size; + // If user is authenticated, user_name and password are set. + // Keep auth info in RedisConnContext to distinguish diffrent users( or diffrent db). + std::string user_name; + std::string password; + RedisCommandParser parser; + butil::Arena arena; +}; // The Command handler for a redis request. User should impletement Run(). class RedisCommandHandler { public: @@ -235,8 +261,14 @@ class RedisCommandHandler { // it returns REDIS_CMD_HANDLED. Read the comment below. virtual RedisCommandHandlerResult Run(const std::vector& args, brpc::RedisReply* output, - bool flush_batched) = 0; - + bool flush_batched) { + return REDIS_CMD_HANDLED; + }; + virtual RedisCommandHandlerResult Run(RedisConnContext* ctx, const std::vector& args, + brpc::RedisReply* output, + bool flush_batched) { + return Run(args, output, flush_batched); + } // The Run() returns CONTINUE for "multi", which makes brpc call this method to // create a transaction_handler to process following commands until transaction_handler // returns OK. For example, for command "multi; set k1 v1; set k2 v2; set k3 v3; From 8faa610457ffc3d0d0a2023211d3ee2ee81d2eea Mon Sep 17 00:00:00 2001 From: lintanghui Date: Mon, 3 Mar 2025 12:37:40 +0800 Subject: [PATCH 2/4] fix comment --- example/redis_c++/redis_server.cpp | 42 +++++++++++++++++++++++------- src/brpc/policy/redis_protocol.cpp | 2 -- src/brpc/redis.cpp | 5 ++++ src/brpc/redis.h | 13 +++++---- 4 files changed, 46 insertions(+), 16 deletions(-) diff --git a/example/redis_c++/redis_server.cpp b/example/redis_c++/redis_server.cpp index 674b855bae..7812b24471 100644 --- a/example/redis_c++/redis_server.cpp +++ b/example/redis_c++/redis_server.cpp @@ -30,6 +30,20 @@ #include DEFINE_int32(port, 6379, "TCP Port of this server"); + +class AuthSession : public brpc::Destroyable { +public: + explicit AuthSession(const std::string& user_name, const std::string& password) + : _user_name(user_name), _password(password) {} + + void Destroy() override { + delete this; + } + + const std::string _user_name; + const std::string _password; +}; + class RedisServiceImpl : public brpc::RedisService { public: RedisServiceImpl() { @@ -38,6 +52,7 @@ class RedisServiceImpl : public brpc::RedisService { _db_map["db1"].resize(kHashSlotNum); _db_map["db2"].resize(kHashSlotNum); } + bool Set(const std::string& db_name, const std::string& key, const std::string& value) { int slot = butil::crc32c::Value(key.c_str(), key.size()) % kHashSlotNum; _mutex[slot].lock(); @@ -46,6 +61,7 @@ class RedisServiceImpl : public brpc::RedisService { _mutex[slot].unlock(); return true; } + bool Auth(const std::string& db_name, const std::string& password) { if (_user_password.find(db_name) == _user_password.end()) { return false; @@ -56,6 +72,7 @@ class RedisServiceImpl : public brpc::RedisService { } return true; } + bool Get(const std::string& db_name, const std::string& key, std::string* value) { int slot = butil::crc32c::Value(key.c_str(), key.size()) % kHashSlotNum; _mutex[slot].lock(); @@ -83,10 +100,12 @@ class GetCommandHandler : public brpc::RedisCommandHandler { explicit GetCommandHandler(RedisServiceImpl* rsimpl) : _rsimpl(rsimpl) {} - brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { - if (ctx->user_name.empty()) { + AuthSession* session = static_cast(ctx->session); + if (session->_user_name.empty()) { output->FormatError("No user name"); return brpc::REDIS_CMD_HANDLED; } @@ -96,7 +115,7 @@ class GetCommandHandler : public brpc::RedisCommandHandler { } const std::string key(args[1].data(), args[1].size()); std::string value; - if (_rsimpl->Get(ctx->user_name, key, &value)) { + if (_rsimpl->Get(session->_user_name, key, &value)) { output->SetString(value); } else { output->SetNullString(); @@ -113,10 +132,12 @@ class SetCommandHandler : public brpc::RedisCommandHandler { explicit SetCommandHandler(RedisServiceImpl* rsimpl) : _rsimpl(rsimpl) {} - brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { - if (ctx->user_name.empty()) { + AuthSession* session = static_cast(ctx->session); + if (session->_user_name.empty()) { output->FormatError("No user name"); return brpc::REDIS_CMD_HANDLED; } @@ -126,7 +147,7 @@ class SetCommandHandler : public brpc::RedisCommandHandler { } const std::string key(args[1].data(), args[1].size()); const std::string value(args[2].data(), args[2].size()); - _rsimpl->Set(ctx->user_name, key, value); + _rsimpl->Set(session->_user_name, key, value); output->SetStatus("OK"); return brpc::REDIS_CMD_HANDLED; } @@ -135,11 +156,14 @@ class SetCommandHandler : public brpc::RedisCommandHandler { RedisServiceImpl* _rsimpl; }; + + class AuthCommandHandler : public brpc::RedisCommandHandler { public: explicit AuthCommandHandler(RedisServiceImpl* rsimpl) : _rsimpl(rsimpl) {} - brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { if (args.size() != 3ul) { @@ -152,8 +176,8 @@ class AuthCommandHandler : public brpc::RedisCommandHandler { if (_rsimpl->Auth(db_name, password)) { output->SetStatus("OK"); - ctx->user_name = db_name; - ctx->password = password; + auto auth_session = new AuthSession(db_name, password); + ctx->session = auth_session; } else { output->FormatError("Invalid password for database '%s'", db_name.c_str()); } diff --git a/src/brpc/policy/redis_protocol.cpp b/src/brpc/policy/redis_protocol.cpp index 4d61d24595..f8acf49d6a 100644 --- a/src/brpc/policy/redis_protocol.cpp +++ b/src/brpc/policy/redis_protocol.cpp @@ -114,8 +114,6 @@ int ConsumeCommand(RedisConnContext* ctx, } -// ========== impl of RedisConnContext ========== - ParseResult ParseRedisMessage(butil::IOBuf* source, Socket* socket, bool read_eof, const void* arg) { if (read_eof || source->empty()) { diff --git a/src/brpc/redis.cpp b/src/brpc/redis.cpp index 7f14384a6b..f23895c757 100644 --- a/src/brpc/redis.cpp +++ b/src/brpc/redis.cpp @@ -369,9 +369,14 @@ RedisCommandHandler* RedisCommandHandler::NewTransactionHandler() { LOG(ERROR) << "NewTransactionHandler is not implemented"; return NULL; } + // ========== impl of RedisConnContext ========== RedisConnContext::~RedisConnContext() { } + void RedisConnContext::Destroy() { + if (session) { + session->Destroy(); + } delete this; } } // namespace brpc diff --git a/src/brpc/redis.h b/src/brpc/redis.h index e05a459e63..8880f53776 100644 --- a/src/brpc/redis.h +++ b/src/brpc/redis.h @@ -211,7 +211,9 @@ enum RedisCommandHandlerResult { REDIS_CMD_CONTINUE = 1, REDIS_CMD_BATCHED = 2, }; + class RedisCommandParser; + // This class is as parsing_context in socket. class RedisConnContext : public Destroyable { public: @@ -229,13 +231,13 @@ class RedisConnContext : public Destroyable { std::unique_ptr transaction_handler; // >0 if command handler is run in batched mode. int batched_size; - // If user is authenticated, user_name and password are set. - // Keep auth info in RedisConnContext to distinguish diffrent users( or diffrent db). - std::string user_name; - std::string password; + // If user is authenticated, session is set. + // Keep auth session info in RedisConnContext to distinguish diffrent users( or diffrent db). + Destroyable* session; RedisCommandParser parser; butil::Arena arena; }; + // The Command handler for a redis request. User should impletement Run(). class RedisCommandHandler { public: @@ -264,7 +266,8 @@ class RedisCommandHandler { bool flush_batched) { return REDIS_CMD_HANDLED; }; - virtual RedisCommandHandlerResult Run(RedisConnContext* ctx, const std::vector& args, + virtual RedisCommandHandlerResult Run(RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool flush_batched) { return Run(args, output, flush_batched); From 335399164a10137e2b06ed5d6dc060f2f96f889c Mon Sep 17 00:00:00 2001 From: lintanghui Date: Mon, 3 Mar 2025 19:14:15 +0800 Subject: [PATCH 3/4] add uint test for redis ctx --- example/redis_c++/redis_server.cpp | 15 +++- src/brpc/redis.cpp | 8 ++ src/brpc/redis.h | 13 +++- test/BUILD.bazel | 15 ++++ test/brpc_redis_unittest.cpp | 113 ++++++++++++++++++++++++++--- 5 files changed, 149 insertions(+), 15 deletions(-) diff --git a/example/redis_c++/redis_server.cpp b/example/redis_c++/redis_server.cpp index 7812b24471..a53ee26ed7 100644 --- a/example/redis_c++/redis_server.cpp +++ b/example/redis_c++/redis_server.cpp @@ -104,7 +104,12 @@ class GetCommandHandler : public brpc::RedisCommandHandler { const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { - AuthSession* session = static_cast(ctx->session); + + AuthSession* session = static_cast(ctx->get_session()); + if (session == nullptr) { + output->FormatError("No auth session"); + return brpc::REDIS_CMD_HANDLED; + } if (session->_user_name.empty()) { output->FormatError("No user name"); return brpc::REDIS_CMD_HANDLED; @@ -136,7 +141,11 @@ class SetCommandHandler : public brpc::RedisCommandHandler { const std::vector& args, brpc::RedisReply* output, bool /*flush_batched*/) override { - AuthSession* session = static_cast(ctx->session); + AuthSession* session = static_cast(ctx->get_session()); + if (session == nullptr) { + output->FormatError("No auth session"); + return brpc::REDIS_CMD_HANDLED; + } if (session->_user_name.empty()) { output->FormatError("No user name"); return brpc::REDIS_CMD_HANDLED; @@ -177,7 +186,7 @@ class AuthCommandHandler : public brpc::RedisCommandHandler { if (_rsimpl->Auth(db_name, password)) { output->SetStatus("OK"); auto auth_session = new AuthSession(db_name, password); - ctx->session = auth_session; + ctx->reset_session(auth_session); } else { output->FormatError("Invalid password for database '%s'", db_name.c_str()); } diff --git a/src/brpc/redis.cpp b/src/brpc/redis.cpp index f23895c757..f8870ae5c1 100644 --- a/src/brpc/redis.cpp +++ b/src/brpc/redis.cpp @@ -379,4 +379,12 @@ void RedisConnContext::Destroy() { } delete this; } + +void RedisConnContext::reset_session(Destroyable* s){ + if (session) { + session->Destroy(); + } + session = s; +} + } // namespace brpc diff --git a/src/brpc/redis.h b/src/brpc/redis.h index 8880f53776..50064519f7 100644 --- a/src/brpc/redis.h +++ b/src/brpc/redis.h @@ -219,11 +219,15 @@ class RedisConnContext : public Destroyable { public: explicit RedisConnContext(const RedisService* rs) : redis_service(rs) - , batched_size(0) {} + , batched_size(0) + , session(nullptr) {} ~RedisConnContext(); // @Destroyable void Destroy() override; + void reset_session(Destroyable* s); + + Destroyable* get_session() { return session; } const RedisService* redis_service; // If user starts a transaction, transaction_handler indicates the @@ -231,11 +235,14 @@ class RedisConnContext : public Destroyable { std::unique_ptr transaction_handler; // >0 if command handler is run in batched mode. int batched_size; + + RedisCommandParser parser; + butil::Arena arena; + +private: // If user is authenticated, session is set. // Keep auth session info in RedisConnContext to distinguish diffrent users( or diffrent db). Destroyable* session; - RedisCommandParser parser; - butil::Arena arena; }; // The Command handler for a redis request. User should impletement Run(). diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 9817b45f10..2d26c45cba 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -192,6 +192,21 @@ cc_test( ], ) +cc_test( + name = "brpc_redis_test", + srcs = glob( + [ + "brpc_redis_unittest.cpp", + ], + ), + copts = COPTS, + deps = [ + ":sstream_workaround", + "//:brpc", + "@com_google_googletest//:gtest", + ], +) + cc_test( name = "bvar_test", srcs = glob( diff --git a/test/brpc_redis_unittest.cpp b/test/brpc_redis_unittest.cpp index 573ab2ed5b..49a3507738 100644 --- a/test/brpc_redis_unittest.cpp +++ b/test/brpc_redis_unittest.cpp @@ -811,10 +811,13 @@ butil::Mutex s_mutex; std::unordered_map m; std::unordered_map int_map; + class RedisServiceImpl : public brpc::RedisService { public: RedisServiceImpl() - : _batch_count(0) {} + : _batch_count(0) + , _user("user1") + , _password("password1") {} brpc::RedisCommandHandlerResult OnBatched(const std::vector& args, brpc::RedisReply* output, bool flush_batched) { @@ -864,18 +867,71 @@ class RedisServiceImpl : public brpc::RedisService { std::vector > _batched_command; int _batch_count; + std::string _user; + std::string _password; }; +class AuthSession : public brpc::Destroyable { +public: + explicit AuthSession(const std::string& user_name, const std::string& password) + : _user_name(user_name), _password(password) {} + + void Destroy() override { + delete this; + } + + const std::string _user_name; + const std::string _password; +}; + +class AuthCommandHandler : public brpc::RedisCommandHandler { +public: + AuthCommandHandler(RedisServiceImpl* rs) + : _rs(rs) {} + + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, + brpc::RedisReply* output, + bool flush_batched) { + if (args.size() < 2) { + output->SetError("ERR wrong number of arguments for 'AUTH' command"); + return brpc::REDIS_CMD_HANDLED; + } + const std::string user(args[1].data(), args[1].size()); + const std::string password(args[2].data(), args[2].size()); + if (_rs->_user != user || _rs->_password != password) { + output->SetError("ERR invalid username/password"); + return brpc::REDIS_CMD_HANDLED; + } + auto auth_session = new AuthSession(user, password); + ctx->reset_session(auth_session); + return brpc::REDIS_CMD_HANDLED; + } + +private: + RedisServiceImpl* _rs; +}; + class SetCommandHandler : public brpc::RedisCommandHandler { public: SetCommandHandler(RedisServiceImpl* rs, bool batch_process = false) : _rs(rs) , _batch_process(batch_process) {} - brpc::RedisCommandHandlerResult Run(const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool flush_batched) { + if (!ctx->session) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } + AuthSession* session = static_cast(ctx->session); + if (!session || (session->_password != _rs->_password) || (session->_user_name != _rs->_user)) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } if (args.size() < 3) { output->SetError("ERR wrong number of arguments for 'set' command"); return brpc::REDIS_CMD_HANDLED; @@ -898,15 +954,26 @@ class SetCommandHandler : public brpc::RedisCommandHandler { bool _batch_process; }; + class GetCommandHandler : public brpc::RedisCommandHandler { public: GetCommandHandler(RedisServiceImpl* rs, bool batch_process = false) : _rs(rs) , _batch_process(batch_process) {} - brpc::RedisCommandHandlerResult Run(const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool flush_batched) { + if (!ctx->session) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } + AuthSession* session = static_cast(ctx->session); + if (!session || (session->_password != _rs->_password) || (session->_user_name != _rs->_user)) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } if (args.size() < 2) { output->SetError("ERR wrong number of arguments for 'get' command"); return brpc::REDIS_CMD_HANDLED; @@ -935,11 +1002,22 @@ class GetCommandHandler : public brpc::RedisCommandHandler { class IncrCommandHandler : public brpc::RedisCommandHandler { public: - IncrCommandHandler() {} + IncrCommandHandler(RedisServiceImpl* rs) + : _rs(rs) {} - brpc::RedisCommandHandlerResult Run(const std::vector& args, + brpc::RedisCommandHandlerResult Run(brpc::RedisConnContext* ctx, + const std::vector& args, brpc::RedisReply* output, bool flush_batched) { + if (!ctx->session) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } + AuthSession* session = static_cast(ctx->session); + if (!session || (session->_password != _rs->_password) || (session->_user_name != _rs->_user)) { + output->SetError("ERR no auth"); + return brpc::REDIS_CMD_HANDLED; + } if (args.size() < 2) { output->SetError("ERR wrong number of arguments for 'incr' command"); return brpc::REDIS_CMD_HANDLED; @@ -951,6 +1029,9 @@ class IncrCommandHandler : public brpc::RedisCommandHandler { output->SetInteger(value); return brpc::REDIS_CMD_HANDLED; } + +private: + RedisServiceImpl* _rs; }; TEST_F(RedisTest, server_sanity) { @@ -959,10 +1040,12 @@ TEST_F(RedisTest, server_sanity) { RedisServiceImpl* rsimpl = new RedisServiceImpl; GetCommandHandler *gh = new GetCommandHandler(rsimpl); SetCommandHandler *sh = new SetCommandHandler(rsimpl); - IncrCommandHandler *ih = new IncrCommandHandler; + AuthCommandHandler *ah = new AuthCommandHandler(rsimpl); + IncrCommandHandler *ih = new IncrCommandHandler(rsimpl); rsimpl->AddCommandHandler("get", gh); rsimpl->AddCommandHandler("set", sh); rsimpl->AddCommandHandler("incr", ih); + rsimpl->AddCommandHandler("auth", ah); server_options.redis_service = rsimpl; brpc::PortRange pr(8081, 8900); ASSERT_EQ(0, server.Start("127.0.0.1", pr, &server_options)); @@ -975,6 +1058,7 @@ TEST_F(RedisTest, server_sanity) { brpc::RedisRequest request; brpc::RedisResponse response; brpc::Controller cntl; + ASSERT_TRUE(request.AddCommand("auth user1 password1")); ASSERT_TRUE(request.AddCommand("get hello")); ASSERT_TRUE(request.AddCommand("get hello2")); ASSERT_TRUE(request.AddCommand("set key1 value1")); @@ -1029,7 +1113,13 @@ TEST_F(RedisTest, server_sanity) { void* incr_thread(void* arg) { brpc::Channel* c = static_cast(arg); - + // do auth + brpc::RedisRequest auth_req; + brpc::RedisResponse auth_resp; + brpc::Controller auth_cntl; + EXPECT_TRUE(auth_req.AddCommand("auth user1 password1")); + c->CallMethod(NULL, &auth_cntl, &auth_req, &auth_resp, NULL); + EXPECT_FALSE(auth_cntl.Failed()) << auth_cntl.ErrorText(); for (int i = 0; i < 5000; ++i) { brpc::RedisRequest request; brpc::RedisResponse response; @@ -1048,8 +1138,10 @@ TEST_F(RedisTest, server_concurrency) { brpc::Server server; brpc::ServerOptions server_options; RedisServiceImpl* rsimpl = new RedisServiceImpl; - IncrCommandHandler *ih = new IncrCommandHandler; + AuthCommandHandler *ah = new AuthCommandHandler(rsimpl); + IncrCommandHandler *ih = new IncrCommandHandler(rsimpl); rsimpl->AddCommandHandler("incr", ih); + rsimpl->AddCommandHandler("auth", ah); server_options.redis_service = rsimpl; brpc::PortRange pr(8081, 8900); ASSERT_EQ(0, server.Start("0.0.0.0", pr, &server_options)); @@ -1132,7 +1224,7 @@ TEST_F(RedisTest, server_command_continue) { RedisServiceImpl* rsimpl = new RedisServiceImpl; rsimpl->AddCommandHandler("get", new GetCommandHandler(rsimpl)); rsimpl->AddCommandHandler("set", new SetCommandHandler(rsimpl)); - rsimpl->AddCommandHandler("incr", new IncrCommandHandler); + rsimpl->AddCommandHandler("incr", new IncrCommandHandler(rsimpl)); rsimpl->AddCommandHandler("multi", new MultiCommandHandler); server_options.redis_service = rsimpl; brpc::PortRange pr(8081, 8900); @@ -1207,6 +1299,8 @@ TEST_F(RedisTest, server_handle_pipeline) { RedisServiceImpl* rsimpl = new RedisServiceImpl; GetCommandHandler* getch = new GetCommandHandler(rsimpl, true); SetCommandHandler* setch = new SetCommandHandler(rsimpl, true); + AuthCommandHandler* authch = new AuthCommandHandler(rsimpl); + rsimpl->AddCommandHandler("auth", authch); rsimpl->AddCommandHandler("get", getch); rsimpl->AddCommandHandler("set", setch); rsimpl->AddCommandHandler("multi", new MultiCommandHandler); @@ -1222,6 +1316,7 @@ TEST_F(RedisTest, server_handle_pipeline) { brpc::RedisRequest request; brpc::RedisResponse response; brpc::Controller cntl; + ASSERT_TRUE(request.AddCommand("auth user1 password1")); ASSERT_TRUE(request.AddCommand("set key1 v1")); ASSERT_TRUE(request.AddCommand("set key2 v2")); ASSERT_TRUE(request.AddCommand("set key3 v3")); From 28291684f4811544fd7ffea8454484356e6adf2c Mon Sep 17 00:00:00 2001 From: lintanghui Date: Tue, 4 Mar 2025 10:56:45 +0800 Subject: [PATCH 4/4] fix uinttest --- test/BUILD.bazel | 14 -------------- test/brpc_redis_unittest.cpp | 24 ++++++++++++++++++++++++ 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/test/BUILD.bazel b/test/BUILD.bazel index 2d26c45cba..1e0ef9667d 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -192,20 +192,6 @@ cc_test( ], ) -cc_test( - name = "brpc_redis_test", - srcs = glob( - [ - "brpc_redis_unittest.cpp", - ], - ), - copts = COPTS, - deps = [ - ":sstream_workaround", - "//:brpc", - "@com_google_googletest//:gtest", - ], -) cc_test( name = "bvar_test", diff --git a/test/brpc_redis_unittest.cpp b/test/brpc_redis_unittest.cpp index 49a3507738..017d5c7e4a 100644 --- a/test/brpc_redis_unittest.cpp +++ b/test/brpc_redis_unittest.cpp @@ -906,6 +906,7 @@ class AuthCommandHandler : public brpc::RedisCommandHandler { } auto auth_session = new AuthSession(user, password); ctx->reset_session(auth_session); + output->SetStatus("OK"); return brpc::REDIS_CMD_HANDLED; } @@ -1059,6 +1060,14 @@ TEST_F(RedisTest, server_sanity) { brpc::RedisResponse response; brpc::Controller cntl; ASSERT_TRUE(request.AddCommand("auth user1 password1")); + channel.CallMethod(NULL, &cntl, &request, &response, NULL); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + ASSERT_EQ(1, response.reply_size()); + ASSERT_EQ(brpc::REDIS_REPLY_STATUS, response.reply(0).type()); + ASSERT_STREQ("OK", response.reply(0).c_str()); + request.Clear(); + response.Clear(); + cntl.Reset(); ASSERT_TRUE(request.AddCommand("get hello")); ASSERT_TRUE(request.AddCommand("get hello2")); ASSERT_TRUE(request.AddCommand("set key1 value1")); @@ -1222,6 +1231,7 @@ TEST_F(RedisTest, server_command_continue) { brpc::Server server; brpc::ServerOptions server_options; RedisServiceImpl* rsimpl = new RedisServiceImpl; + rsimpl->AddCommandHandler("auth", new AuthCommandHandler(rsimpl)); rsimpl->AddCommandHandler("get", new GetCommandHandler(rsimpl)); rsimpl->AddCommandHandler("set", new SetCommandHandler(rsimpl)); rsimpl->AddCommandHandler("incr", new IncrCommandHandler(rsimpl)); @@ -1234,6 +1244,13 @@ TEST_F(RedisTest, server_command_continue) { options.protocol = brpc::PROTOCOL_REDIS; brpc::Channel channel; ASSERT_EQ(0, channel.Init("127.0.0.1", server.listen_address().port, &options)); + // do auth + brpc::RedisRequest auth_req; + brpc::RedisResponse auth_resp; + brpc::Controller auth_cntl; + ASSERT_TRUE(auth_req.AddCommand("auth user1 password1")); + channel.CallMethod(NULL, &auth_cntl, &auth_req, &auth_resp, NULL); + ASSERT_FALSE(auth_cntl.Failed()) << auth_cntl.ErrorText(); { brpc::RedisRequest request; @@ -1317,6 +1334,13 @@ TEST_F(RedisTest, server_handle_pipeline) { brpc::RedisResponse response; brpc::Controller cntl; ASSERT_TRUE(request.AddCommand("auth user1 password1")); + channel.CallMethod(NULL, &cntl, &request, &response, NULL); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + ASSERT_EQ(1, response.reply_size()); + ASSERT_STREQ("OK", response.reply(0).c_str()); + request.Clear(); + response.Clear(); + cntl.Reset(); ASSERT_TRUE(request.AddCommand("set key1 v1")); ASSERT_TRUE(request.AddCommand("set key2 v2")); ASSERT_TRUE(request.AddCommand("set key3 v3"));