diff --git a/clickhouse/base/socket.cpp b/clickhouse/base/socket.cpp index e0f8fb1c..e9ed5fd2 100644 --- a/clickhouse/base/socket.cpp +++ b/clickhouse/base/socket.cpp @@ -114,11 +114,25 @@ void SetNonBlock(SOCKET fd, bool value) { void SetTimeout(SOCKET fd, const SocketTimeoutParams& timeout_params) { #if defined(_unix_) - timeval recv_timeout { .tv_sec = timeout_params.recv_timeout.count(), .tv_usec = 0 }; - setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout)); + timeval recv_timeout{ timeout_params.recv_timeout.count() / 1000, static_cast(timeout_params.recv_timeout.count() % 1000 * 1000) }; + auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, &recv_timeout, sizeof(recv_timeout)); - timeval send_timeout { .tv_sec = timeout_params.send_timeout.count(), .tv_usec = 0 }; - setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout)); + timeval send_timeout{ timeout_params.send_timeout.count() / 1000, static_cast(timeout_params.send_timeout.count() % 1000 * 1000) }; + auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, &send_timeout, sizeof(send_timeout)); + + if (recv_ret == -1 || send_ret == -1) { + throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout"); + } +#else + DWORD recv_timeout = static_cast(timeout_params.recv_timeout.count()); + auto recv_ret = setsockopt(fd, SOL_SOCKET, SO_RCVTIMEO, (const char*)&recv_timeout, sizeof(DWORD)); + + DWORD send_timeout = static_cast(timeout_params.send_timeout.count()); + auto send_ret = setsockopt(fd, SOL_SOCKET, SO_SNDTIMEO, (const char*)&send_timeout, sizeof(DWORD)); + + if (recv_ret == SOCKET_ERROR || send_ret == SOCKET_ERROR) { + throw std::system_error(getSocketErrorCode(), getErrorCategory(), "fail to set socket timeout"); + } #endif }; @@ -244,6 +258,10 @@ Socket::Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_pa : handle_(SocketConnect(addr, timeout_params)) {} +Socket::Socket(const NetworkAddress & addr) + : handle_(SocketConnect(addr, SocketTimeoutParams{})) +{} + Socket::Socket(Socket&& other) noexcept : handle_(other.handle_) { diff --git a/clickhouse/base/socket.h b/clickhouse/base/socket.h index b3d916e1..c68f250d 100644 --- a/clickhouse/base/socket.h +++ b/clickhouse/base/socket.h @@ -83,13 +83,14 @@ class SocketFactory { struct SocketTimeoutParams { - const std::chrono::seconds recv_timeout {0}; - const std::chrono::seconds send_timeout {0}; + std::chrono::milliseconds recv_timeout{ 0 }; + std::chrono::milliseconds send_timeout{ 0 }; }; class Socket : public SocketBase { public: Socket(const NetworkAddress& addr, const SocketTimeoutParams& timeout_params); + Socket(const NetworkAddress& addr); Socket(Socket&& other) noexcept; Socket& operator=(Socket&& other) noexcept; diff --git a/ut/socket_ut.cpp b/ut/socket_ut.cpp index 36b6a65b..5a263435 100644 --- a/ut/socket_ut.cpp +++ b/ut/socket_ut.cpp @@ -18,7 +18,7 @@ TEST(Socketcase, connecterror) { std::this_thread::sleep_for(std::chrono::seconds(1)); try { - Socket socket(addr, SocketTimeoutParams {}); + Socket socket(addr); } catch (const std::system_error& e) { FAIL(); } @@ -26,7 +26,7 @@ TEST(Socketcase, connecterror) { std::this_thread::sleep_for(std::chrono::seconds(1)); server.stop(); try { - Socket socket(addr, SocketTimeoutParams {}); + Socket socket(addr); FAIL(); } catch (const std::system_error& e) { ASSERT_NE(EINPROGRESS,e.code().value()); @@ -43,14 +43,20 @@ TEST(Socketcase, timeoutrecv) { std::this_thread::sleep_for(std::chrono::seconds(1)); try { - Socket socket(addr, SocketTimeoutParams { .recv_timeout = Seconds(5), .send_timeout = Seconds(5) }); + Socket socket(addr, SocketTimeoutParams { Seconds(5), Seconds(5) }); std::unique_ptr ptr_input_stream = socket.makeInputStream(); char buf[1024]; ptr_input_stream->Read(buf, sizeof(buf)); - } catch (const std::system_error& e) { - ASSERT_EQ(EAGAIN, e.code().value()); + } + catch (const std::system_error& e) { +#if defined(_unix_) + auto expected = EAGAIN; +#else + auto expected = WSAETIMEDOUT; +#endif + ASSERT_EQ(expected, e.code().value()); } std::this_thread::sleep_for(std::chrono::seconds(1));