Skip to content

Commit a6750a2

Browse files
committed
feat(stream): add TLS GraphStage engine
Motivation: The stream TLS path still depends on the legacy actor/FanoutProcessor infrastructure. A GraphStage engine is needed for the stream internals while preserving the existing Pekko TLSActor SSLEngine state machine semantics without changing the legacy actor implementation. Modification: - Add TlsGraphStage as a GraphStage adapter for the existing Pekko TLS pump phases. - Reuse the Pekko TCP direct BufferPool for TLS transport buffers and allocate application buffers from SSLEngine session sizes. - Add a pekko.stream.materializer.tls.engine selector with legacy-actor as the default and graph-stage as the opt-in engine. - Run the shared TLS regression matrix against both legacy and GraphStage paths and add focused GraphStage edge-case coverage. - Add TLS JMH benchmarks for cold handshake and warm round-trip scenarios. Result: The GraphStage path is opt-in, the legacy TLSActor remains untouched, and TLS close, truncation, renegotiation, failure-alert, and TLS 1.3 behavior are covered by regression tests. Tests: - stream / scalafmtCheck - stream-tests / scalafmtCheck - bench-jmh / scalafmtCheck - stream / Test / compile - stream-tests / Test / testOnly org.apache.pekko.stream.io.TlsSpec org.apache.pekko.stream.io.TlsGraphStageSpec org.apache.pekko.stream.io.TlsGraphStageEdgeCasesSpec org.apache.pekko.stream.io.TlsGraphStageIsolatedSpec - git diff --check - red-flag rg scan for prior suspicious port markers References: - #2878 - #2860
1 parent ff79912 commit a6750a2

13 files changed

Lines changed: 2036 additions & 40 deletions

File tree

actor/src/main/scala/org/apache/pekko/util/ByteString.scala

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,10 @@ object ByteString {
413413
override def copyToBuffer(buffer: ByteBuffer): Int =
414414
writeToBuffer(buffer, offset = 0)
415415

416+
/** INTERNAL API: Specialized for internal use, copying from an offset without slicing. */
417+
private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int): Int =
418+
writeToBuffer(buffer, offset)
419+
416420
/** INTERNAL API: Specialized for internal use, writing multiple ByteString1C into the same ByteBuffer. */
417421
private[pekko] def writeToBuffer(buffer: ByteBuffer, offset: Int): Int = {
418422
val copyLength = Math.max(0, Math.min(buffer.remaining, length - offset))
@@ -550,13 +554,17 @@ object ByteString {
550554
}
551555

552556
override def copyToBuffer(buffer: ByteBuffer): Int =
553-
writeToBuffer(buffer)
557+
writeToBuffer(buffer, offset = 0)
558+
559+
/** INTERNAL API: Specialized for internal use, copying from an offset without slicing. */
560+
private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int): Int =
561+
writeToBuffer(buffer, offset)
554562

555563
/** INTERNAL API: Specialized for internal use, writing multiple ByteString1C into the same ByteBuffer. */
556-
private[pekko] def writeToBuffer(buffer: ByteBuffer): Int = {
557-
val copyLength = Math.min(buffer.remaining, length)
564+
private[pekko] def writeToBuffer(buffer: ByteBuffer, offset: Int): Int = {
565+
val copyLength = Math.max(0, Math.min(buffer.remaining, length - offset))
558566
if (copyLength > 0) {
559-
buffer.put(bytes, startIndex, copyLength)
567+
buffer.put(bytes, startIndex + offset, copyLength)
560568
}
561569
copyLength
562570
}
@@ -944,12 +952,28 @@ object ByteString {
944952

945953
def isCompact: Boolean = if (bytestrings.length == 1) bytestrings.head.isCompact else false
946954

947-
override def copyToBuffer(buffer: ByteBuffer): Int = {
948-
val it = bytestrings.iterator
955+
override def copyToBuffer(buffer: ByteBuffer): Int =
956+
copyToBuffer(buffer, offset = 0)
957+
958+
/** INTERNAL API: Specialized for internal use, copying from an offset without slicing. */
959+
private[pekko] override def copyToBuffer(buffer: ByteBuffer, offset: Int): Int = {
960+
var remainingOffset = offset
949961
var written = 0
950-
while (it.hasNext && buffer.hasRemaining) {
951-
written += it.next().writeToBuffer(buffer)
962+
var i = 0
963+
val count = bytestrings.length
964+
965+
while (i < count && buffer.hasRemaining) {
966+
val fragment = bytestrings(i)
967+
val fragmentLength = fragment.length
968+
if (remainingOffset >= fragmentLength) {
969+
remainingOffset -= fragmentLength
970+
} else {
971+
written += fragment.writeToBuffer(buffer, remainingOffset)
972+
remainingOffset = 0
973+
}
974+
i += 1
952975
}
976+
953977
written
954978
}
955979

@@ -1770,6 +1794,11 @@ sealed abstract class ByteString
17701794
*/
17711795
def copyToBuffer(@nowarn("msg=never used") buffer: ByteBuffer): Int
17721796

1797+
/** INTERNAL API: Copy bytes to a ByteBuffer from a ByteString offset without allocating a slice. */
1798+
private[pekko] def copyToBuffer(buffer: ByteBuffer, offset: Int): Int =
1799+
if (offset <= 0) copyToBuffer(buffer)
1800+
else drop(offset).copyToBuffer(buffer)
1801+
17731802
/**
17741803
* Create a new ByteString with all contents compacted into a single,
17751804
* full byte array.
2.34 KB
Binary file not shown.
857 Bytes
Binary file not shown.
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.pekko.stream.io
19+
20+
import java.security.{ KeyStore, SecureRandom }
21+
import java.util.concurrent.TimeUnit
22+
import javax.net.ssl.{ KeyManagerFactory, SSLContext, SSLEngine, SSLSession, TrustManagerFactory }
23+
24+
import scala.concurrent.Await
25+
import scala.concurrent.duration._
26+
import scala.util.{ Success, Try }
27+
28+
import com.typesafe.config.{ Config, ConfigFactory }
29+
import org.openjdk.jmh.annotations._
30+
31+
import org.apache.pekko
32+
import pekko.NotUsed
33+
import pekko.actor.ActorSystem
34+
import pekko.stream._
35+
import pekko.stream.TLSProtocol._
36+
import pekko.stream.impl.io.{ TlsGraphStage, TlsModule }
37+
import pekko.stream.scaladsl._
38+
import pekko.util.ByteString
39+
40+
/**
41+
* JMH benchmark comparing the legacy actor-based TLS path (`TlsModule`) to the
42+
* GraphStage path (`TlsGraphStage`).
43+
*
44+
* - `warmRoundTrip` drives a fixed payload through a client+server echo loop, with
45+
* the SSL engines reused across invocations (one materialization per @Setup).
46+
* This isolates per-record encrypt/decrypt overhead — handshake cost is amortized
47+
* away by the iteration count.
48+
* - `coldHandshake` measures the cost of materializing a fresh client+server pair
49+
* and completing the TLS handshake before transferring a tiny payload. This
50+
* represents short-lived connections (e.g. HTTPS request/response).
51+
*
52+
* Run with:
53+
* {{{
54+
* sbt "bench-jmh/Jmh/run -i 5 -wi 3 -f1 -t1 .*TlsBenchmark.*"
55+
* }}}
56+
*/
57+
@State(Scope.Benchmark)
58+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
59+
@BenchmarkMode(Array(Mode.Throughput))
60+
@Warmup(iterations = 5, time = 1, timeUnit = TimeUnit.SECONDS)
61+
@Measurement(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
62+
@Fork(1)
63+
class TlsBenchmark {
64+
65+
private val config: Config = ConfigFactory.parseString("""
66+
pekko {
67+
log-config-on-start = off
68+
log-dead-letters-during-shutdown = off
69+
stdout-loglevel = "OFF"
70+
loglevel = "OFF"
71+
actor.default-dispatcher {
72+
throughput = 1024
73+
}
74+
actor.default-mailbox {
75+
mailbox-type = "org.apache.pekko.dispatch.SingleConsumerOnlyUnboundedMailbox"
76+
}
77+
}""".stripMargin).withFallback(ConfigFactory.load())
78+
79+
implicit var system: ActorSystem = _
80+
private var sslContext: SSLContext = _
81+
private var ciphers: Array[String] = _
82+
83+
@Param(Array("legacy", "graphstage"))
84+
var implementation: String = _
85+
86+
// 256 B = control message; 4 KiB = typical HTTP request; 64 KiB = streaming chunk
87+
@Param(Array("256", "4096", "65536"))
88+
var payloadSize: Int = _
89+
90+
private var payload: ByteString = _
91+
private var payloads: scala.collection.immutable.IndexedSeq[SslTlsOutbound] = _
92+
93+
@Setup
94+
def setup(): Unit = {
95+
system = ActorSystem("TlsBenchmark", config)
96+
SystemMaterializer(system).materializer
97+
98+
sslContext = TlsBenchmark.initSslContext("TLSv1.2")
99+
ciphers = TlsBenchmark.TLS12Ciphers.toArray
100+
101+
payload = ByteString(Array.fill[Byte](payloadSize)('a'.toByte))
102+
payloads = (0 until TlsBenchmark.WarmRoundTripRecords).map(_ => SendBytes(payload))
103+
}
104+
105+
@TearDown
106+
def shutdown(): Unit = {
107+
Await.result(system.terminate(), 10.seconds)
108+
}
109+
110+
private def engine(role: TLSRole): SSLEngine = {
111+
val e = sslContext.createSSLEngine()
112+
e.setUseClientMode(role == Client)
113+
e.setEnabledCipherSuites(ciphers)
114+
e.setEnabledProtocols(Array("TLSv1.2"))
115+
e
116+
}
117+
118+
private def makeBidi(
119+
role: TLSRole,
120+
closing: TLSClosing,
121+
verifySession: SSLSession => Try[Unit] = _ => Success(()))
122+
: BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
123+
implementation match {
124+
case "legacy" =>
125+
BidiFlow.fromGraph(
126+
TlsModule(Attributes.none, () => engine(role), verifySession, closing))
127+
case "graphstage" =>
128+
graphStageBidi(role, closing, verifySession)
129+
}
130+
131+
private def graphStageBidi(
132+
role: TLSRole,
133+
closing: TLSClosing,
134+
verifySession: SSLSession => Try[Unit])
135+
: BidiFlow[SslTlsOutbound, ByteString, ByteString, SslTlsInbound, NotUsed] =
136+
BidiFlow
137+
.fromGraph(new TlsGraphStage(() => engine(role), verifySession, closing))
138+
.withAttributes(TlsGraphStage.StreamTlsAttributes)
139+
140+
/**
141+
* Warm round-trip: 1000 payloads through a fresh client+server pair. The
142+
* handshake is amortized over the records and the sink counts bytes instead
143+
* of concatenating payloads, keeping the measurement focused on TLS work.
144+
*/
145+
@Benchmark
146+
@OperationsPerInvocation(1000)
147+
def warmRoundTrip(): Unit = {
148+
val records = TlsBenchmark.WarmRoundTripRecords
149+
val expected = payload.size * records
150+
val client = makeBidi(Client, IgnoreComplete)
151+
val server = makeBidi(Server, IgnoreComplete)
152+
val echo = Flow[SslTlsInbound].collect { case SessionBytes(_, b) => SendBytes(b) }
153+
154+
val done = Source(payloads)
155+
.via(client.atop(server.reversed).join(echo))
156+
.collect { case SessionBytes(_, b) => b }
157+
.scan(0)((acc, b) => acc + b.size)
158+
.dropWhile(_ < expected)
159+
.runWith(Sink.headOption)
160+
161+
Await.result(done, 30.seconds)
162+
}
163+
164+
/**
165+
* Cold handshake: each invocation builds a fresh client+server pair and
166+
* completes the handshake by exchanging one configured payload. The sink
167+
* counts bytes only, which avoids charging ByteString concatenation to the
168+
* TLS implementation being tested.
169+
*/
170+
@Benchmark
171+
def coldHandshake(): Unit = {
172+
val client = makeBidi(Client, IgnoreComplete)
173+
val server = makeBidi(Server, IgnoreComplete)
174+
val expected = payload.size
175+
val echo = Flow[SslTlsInbound].collect { case SessionBytes(_, b) => SendBytes(b) }
176+
177+
val done = Source
178+
.single[SslTlsOutbound](SendBytes(payload))
179+
.via(client.atop(server.reversed).join(echo))
180+
.collect { case SessionBytes(_, b) => b }
181+
.scan(0)((acc, b) => acc + b.size)
182+
.dropWhile(_ < expected)
183+
.runWith(Sink.headOption)
184+
185+
Await.result(done, 30.seconds)
186+
}
187+
}
188+
189+
object TlsBenchmark {
190+
191+
final val WarmRoundTripRecords = 1000
192+
193+
val TLS12Ciphers: Set[String] = Set(
194+
"TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
195+
"TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384")
196+
197+
def initSslContext(protocol: String): SSLContext = {
198+
val password = "changeme"
199+
200+
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType)
201+
keyStore.load(getClass.getResourceAsStream("/keystore"), password.toCharArray)
202+
203+
val trustStore = KeyStore.getInstance(KeyStore.getDefaultType)
204+
trustStore.load(getClass.getResourceAsStream("/truststore"), password.toCharArray)
205+
206+
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm)
207+
keyManagerFactory.init(keyStore, password.toCharArray)
208+
209+
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
210+
trustManagerFactory.init(trustStore)
211+
212+
val context = SSLContext.getInstance(protocol)
213+
context.init(keyManagerFactory.getKeyManagers, trustManagerFactory.getTrustManagers, new SecureRandom)
214+
context
215+
}
216+
}

0 commit comments

Comments
 (0)