Skip to content

Commit 6a9b6f4

Browse files
Merge pull request google#149 from doyensec:socket_normalization
PiperOrigin-RevId: 867446717 Change-Id: I0e89dea0e87814758cf70d62b83c7ade3c9b9eea
2 parents 01a5c71 + 85e6cfb commit 6a9b6f4

File tree

11 files changed

+1581
-0
lines changed

11 files changed

+1581
-0
lines changed
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.tsunami.common.net.socket;
17+
18+
import static com.google.common.base.Preconditions.checkArgument;
19+
import static com.google.common.base.Preconditions.checkNotNull;
20+
21+
import com.google.common.flogger.GoogleLogger;
22+
import java.io.IOException;
23+
import java.net.InetAddress;
24+
import java.net.InetSocketAddress;
25+
import java.net.Socket;
26+
import java.time.Duration;
27+
import javax.net.SocketFactory;
28+
import javax.net.ssl.SSLSocket;
29+
import javax.net.ssl.SSLSocketFactory;
30+
31+
/**
32+
* Default implementation of {@link TsunamiSocketFactory} that creates TCP sockets.
33+
*
34+
* <p>This implementation wraps the standard Java {@link SocketFactory} and {@link SSLSocketFactory}
35+
* to ensure that all created sockets have proper timeout settings configured, preventing plugins
36+
* from hanging indefinitely.
37+
*/
38+
public final class DefaultTsunamiSocketFactory implements TsunamiSocketFactory {
39+
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();
40+
41+
private final SocketFactory socketFactory;
42+
private final SSLSocketFactory sslSocketFactory;
43+
private final Duration defaultConnectTimeout;
44+
private final Duration defaultReadTimeout;
45+
46+
/**
47+
* Creates a new DefaultTsunamiSocketFactory.
48+
*
49+
* @param socketFactory the underlying socket factory to use for plain TCP connections
50+
* @param sslSocketFactory the underlying SSL socket factory to use for SSL/TLS connections
51+
* @param defaultConnectTimeout the default timeout for establishing connections
52+
* @param defaultReadTimeout the default timeout for read operations
53+
*/
54+
public DefaultTsunamiSocketFactory(
55+
SocketFactory socketFactory,
56+
SSLSocketFactory sslSocketFactory,
57+
Duration defaultConnectTimeout,
58+
Duration defaultReadTimeout) {
59+
this.socketFactory = checkNotNull(socketFactory);
60+
this.sslSocketFactory = checkNotNull(sslSocketFactory);
61+
this.defaultConnectTimeout = checkNotNull(defaultConnectTimeout);
62+
this.defaultReadTimeout = checkNotNull(defaultReadTimeout);
63+
64+
checkArgument(!defaultConnectTimeout.isNegative(), "Connect timeout cannot be negative");
65+
checkArgument(!defaultReadTimeout.isNegative(), "Read timeout cannot be negative");
66+
67+
logger.atInfo().log(
68+
"TsunamiSocketFactory initialized with connect timeout: %s, read timeout: %s",
69+
defaultConnectTimeout, defaultReadTimeout);
70+
}
71+
72+
@Override
73+
public Socket createSocket(String host, int port) throws IOException {
74+
return createSocket(host, port, defaultConnectTimeout, defaultReadTimeout);
75+
}
76+
77+
@Override
78+
public Socket createSocket(String host, int port, Duration timeout) throws IOException {
79+
return createSocket(host, port, timeout, timeout);
80+
}
81+
82+
@Override
83+
public Socket createSocket(String host, int port, Duration connectTimeout, Duration readTimeout)
84+
throws IOException {
85+
checkNotNull(host);
86+
checkArgument(port > 0 && port <= 65535, "Port must be between 1 and 65535");
87+
checkNotNull(connectTimeout);
88+
checkNotNull(readTimeout);
89+
90+
logger.atFine().log(
91+
"Creating socket to %s:%d with connect timeout %s, read timeout %s",
92+
host, port, connectTimeout, readTimeout);
93+
94+
Socket socket = socketFactory.createSocket();
95+
configureAndConnect(socket, new InetSocketAddress(host, port), connectTimeout, readTimeout);
96+
return socket;
97+
}
98+
99+
@Override
100+
public Socket createSocket(InetAddress address, int port) throws IOException {
101+
return createSocket(address, port, defaultConnectTimeout, defaultReadTimeout);
102+
}
103+
104+
@Override
105+
public Socket createSocket(InetAddress address, int port, Duration timeout) throws IOException {
106+
return createSocket(address, port, timeout, timeout);
107+
}
108+
109+
@Override
110+
public Socket createSocket(
111+
InetAddress address, int port, Duration connectTimeout, Duration readTimeout)
112+
throws IOException {
113+
checkNotNull(address);
114+
checkArgument(port > 0 && port <= 65535, "Port must be between 1 and 65535");
115+
checkNotNull(connectTimeout);
116+
checkNotNull(readTimeout);
117+
118+
logger.atFine().log(
119+
"Creating socket to %s:%d with connect timeout %s, read timeout %s",
120+
address.getHostAddress(), port, connectTimeout, readTimeout);
121+
122+
Socket socket = socketFactory.createSocket();
123+
configureAndConnect(socket, new InetSocketAddress(address, port), connectTimeout, readTimeout);
124+
return socket;
125+
}
126+
127+
@Override
128+
public Socket createUnconnectedSocket() throws IOException {
129+
Socket socket = socketFactory.createSocket();
130+
socket.setSoTimeout((int) defaultReadTimeout.toMillis());
131+
logger.atFine().log("Created unconnected socket with read timeout %s", defaultReadTimeout);
132+
return socket;
133+
}
134+
135+
@Override
136+
public SSLSocket createSslSocket(String host, int port) throws IOException {
137+
return createSslSocket(host, port, defaultConnectTimeout, defaultReadTimeout);
138+
}
139+
140+
@Override
141+
public SSLSocket createSslSocket(String host, int port, Duration timeout) throws IOException {
142+
return createSslSocket(host, port, timeout, timeout);
143+
}
144+
145+
@Override
146+
public SSLSocket createSslSocket(
147+
String host, int port, Duration connectTimeout, Duration readTimeout) throws IOException {
148+
checkNotNull(host);
149+
checkArgument(port > 0 && port <= 65535, "Port must be between 1 and 65535");
150+
checkNotNull(connectTimeout);
151+
checkNotNull(readTimeout);
152+
153+
logger.atFine().log(
154+
"Creating SSL socket to %s:%d with connect timeout %s, read timeout %s",
155+
host, port, connectTimeout, readTimeout);
156+
157+
// Create a plain socket first, connect with timeout, then wrap with SSL
158+
Socket plainSocket = socketFactory.createSocket();
159+
configureAndConnect(
160+
plainSocket, new InetSocketAddress(host, port), connectTimeout, readTimeout);
161+
162+
SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket(plainSocket, host, port, true);
163+
sslSocket.setSoTimeout((int) readTimeout.toMillis());
164+
sslSocket.startHandshake();
165+
166+
return sslSocket;
167+
}
168+
169+
@Override
170+
public SSLSocket createSslSocket(InetAddress address, int port) throws IOException {
171+
return createSslSocket(address, port, defaultConnectTimeout, defaultReadTimeout);
172+
}
173+
174+
@Override
175+
public SSLSocket createSslSocket(InetAddress address, int port, Duration timeout)
176+
throws IOException {
177+
return createSslSocket(address, port, timeout, timeout);
178+
}
179+
180+
@Override
181+
public SSLSocket createSslSocket(
182+
InetAddress address, int port, Duration connectTimeout, Duration readTimeout)
183+
throws IOException {
184+
checkNotNull(address);
185+
checkArgument(port > 0 && port <= 65535, "Port must be between 1 and 65535");
186+
checkNotNull(connectTimeout);
187+
checkNotNull(readTimeout);
188+
189+
logger.atFine().log(
190+
"Creating SSL socket to %s:%d with connect timeout %s, read timeout %s",
191+
address.getHostAddress(), port, connectTimeout, readTimeout);
192+
193+
// Create a plain socket first, connect with timeout, then wrap with SSL
194+
Socket plainSocket = socketFactory.createSocket();
195+
configureAndConnect(
196+
plainSocket, new InetSocketAddress(address, port), connectTimeout, readTimeout);
197+
198+
SSLSocket sslSocket =
199+
(SSLSocket)
200+
sslSocketFactory.createSocket(plainSocket, address.getHostAddress(), port, true);
201+
sslSocket.setSoTimeout((int) readTimeout.toMillis());
202+
sslSocket.startHandshake();
203+
204+
return sslSocket;
205+
}
206+
207+
@Override
208+
public SSLSocket wrapWithSsl(Socket socket, String host, int port, boolean autoClose)
209+
throws IOException {
210+
checkNotNull(socket);
211+
checkNotNull(host);
212+
checkArgument(port > 0 && port <= 65535, "Port must be between 1 and 65535");
213+
214+
logger.atFine().log("Wrapping existing socket with SSL for host %s:%d", host, port);
215+
216+
SSLSocket sslSocket = (SSLSocket) sslSocketFactory.createSocket(socket, host, port, autoClose);
217+
// Preserve the timeout from the original socket if set, otherwise use default
218+
int originalTimeout = socket.getSoTimeout();
219+
if (originalTimeout > 0) {
220+
sslSocket.setSoTimeout(originalTimeout);
221+
} else {
222+
sslSocket.setSoTimeout((int) defaultReadTimeout.toMillis());
223+
}
224+
sslSocket.startHandshake();
225+
226+
return sslSocket;
227+
}
228+
229+
@Override
230+
public Duration getDefaultConnectTimeout() {
231+
return defaultConnectTimeout;
232+
}
233+
234+
@Override
235+
public Duration getDefaultReadTimeout() {
236+
return defaultReadTimeout;
237+
}
238+
239+
/**
240+
* Configures socket options and connects to the specified address with timeout.
241+
*
242+
* @param socket the socket to configure and connect
243+
* @param address the address to connect to
244+
* @param connectTimeout the timeout for establishing the connection
245+
* @param readTimeout the timeout for read operations
246+
* @throws IOException if an I/O error occurs
247+
*/
248+
private void configureAndConnect(
249+
Socket socket, InetSocketAddress address, Duration connectTimeout, Duration readTimeout)
250+
throws IOException {
251+
// Set read timeout before connecting
252+
socket.setSoTimeout((int) readTimeout.toMillis());
253+
254+
// Enable TCP keep-alive to detect dead connections
255+
socket.setKeepAlive(true);
256+
257+
// Disable Nagle's algorithm for better latency in security scanning
258+
socket.setTcpNoDelay(true);
259+
260+
// Connect with timeout
261+
socket.connect(address, (int) connectTimeout.toMillis());
262+
263+
logger.atFine().log(
264+
"Socket connected to %s with SO_TIMEOUT=%dms", address, readTimeout.toMillis());
265+
}
266+
}

0 commit comments

Comments
 (0)