From 2c20d57047d41ef507256096035eba0973e0b08d Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Sat, 18 Apr 2026 09:10:44 +0000 Subject: [PATCH 1/9] add direct worker --- udf/worker/README.md | 158 +++++- udf/worker/core/pom.xml | 4 + .../udf/worker/core/WorkerConnection.scala | 45 ++ .../udf/worker/core/WorkerDispatcher.scala | 6 +- .../spark/udf/worker/core/WorkerSession.scala | 96 +++- .../core/direct/DirectWorkerDispatcher.scala | 429 +++++++++++++++ .../core/direct/DirectWorkerProcess.scala | 136 +++++ .../core/direct/DirectWorkerSession.scala | 56 ++ .../core/DirectWorkerDispatcherSuite.scala | 512 ++++++++++++++++++ .../worker/core/WorkerAbstractionSuite.scala | 25 - .../proto/src/main/protobuf/common.proto | 7 + .../proto/src/main/protobuf/worker_spec.proto | 29 +- .../udf/worker/WorkerSpecification.scala | 27 - 13 files changed, 1436 insertions(+), 94 deletions(-) create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala create mode 100644 udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala delete mode 100644 udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala delete mode 100644 udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala diff --git a/udf/worker/README.md b/udf/worker/README.md index fa27430b62b62..8f0ef476d21d9 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -5,44 +5,162 @@ Package structure for the UDF worker framework described in ## Overview -Spark processes a UDF by first obtaining a **WorkerDispatcher** from the worker -specification (plus context such as security scope). The dispatcher manages the -actual worker processes behind the scenes -- pooling, reuse, and termination are -all invisible to Spark. +Spark processes a UDF by obtaining a **WorkerDispatcher** from a worker +specification. The dispatcher manages workers behind the scenes. From +the dispatcher, Spark gets a **WorkerSession** -- one per UDF invocation -- +with an Iterator-to-Iterator `process` API that streams input batches +through the worker and returns result batches. -From the dispatcher, Spark gets a **WorkerSession**, which represents one single -UDF execution and can carry per-execution state. A WorkerSession is not 1-to-1 -mapped to an actual worker -- multiple sessions may share the same underlying -worker when it is reused. Worker reuse is managed by each dispatcher -implementation based on the worker specification. +``` +WorkerSpecification -- how to create and configure workers + | + v +WorkerDispatcher -- manages workers, creates sessions + | + v +WorkerSession -- one UDF execution + | 1. session.init(InitMessage(payload, inputSchema, outputSchema)) + | 2. val results = session.process(inputBatches) + | 3. session.close() +``` + +How workers are created depends on the dispatcher implementation. The +framework currently provides **direct worker creation** (local OS +processes) and is designed for future **indirect creation** (via a +provisioning service or daemon). ## Sub-packages ``` udf/worker/ -├── proto/ Protobuf definition of the worker specification -│ (UDFWorkerSpecification). -│ WorkerSpecification -- typed Scala wrapper around the protobuf spec. -└── core/ Engine-side APIs (all @Experimental): - WorkerDispatcher -- manages workers for one spec; creates sessions. - WorkerSession -- represents one single UDF execution. - WorkerSecurityScope -- security boundary for connection pooling. +├── proto/ +│ worker_spec.proto -- UDFWorkerSpecification protobuf (+ generated Java classes) +│ common.proto -- shared enums (UDFWorkerDataFormat, etc.) +│ +└── core/ -- abstract interfaces + WorkerDispatcher.scala -- creates sessions, manages worker lifecycle + WorkerSession.scala -- per-UDF init/process/cancel/close + InitMessage + WorkerConnection.scala -- transport channel abstraction + WorkerSecurityScope.scala -- security boundary for worker pooling + │ + └── direct/ -- "direct" creation: local OS processes + DirectWorkerDispatcher.scala -- spawns processes, env lifecycle + DirectWorkerProcess.scala -- OS process + connection + UDS socket + DirectWorkerSession.scala -- session backed by a direct process +``` + +The `core/` package defines abstract interfaces that are independent of how +workers are created. The `core/direct/` sub-package implements "direct" +worker creation where Spark spawns local OS processes. Future packages +(e.g., `core/indirect/`) can implement alternative creation modes such as +obtaining workers from a provisioning service or daemon. + +### Direct worker creation + +`DirectWorkerDispatcher` spawns worker processes locally. On the first +session, it runs the optional environment lifecycle callables from the +`WorkerSpecification`: + +- **`environmentVerification`** -- checks if the environment is ready + (exit 0 = ready). When it succeeds, installation is skipped. +- **`installation`** -- prepares the environment (installs runtime, + dependencies, worker binaries). Only runs when verification is absent + or fails. +- **`environmentCleanup`** -- runs after the dispatcher is closed or on + JVM shutdown to clean up temporary resources. + +Environment setup runs **once per dispatcher** (not per session). +Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed. + +## Basic usage (Scala) + +```scala +import org.apache.spark.udf.worker.{ + DirectWorker, ProcessCallable, UDFProtoCommunicationPattern, + UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, + UnixDomainSocket, WorkerCapabilities, WorkerConnection, WorkerEnvironment} +import org.apache.spark.udf.worker.core._ + +// 1. Define a worker spec (direct creation mode). +val spec = UDFWorkerSpecification.newBuilder() + .setEnvironment(WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("python").addCommand("-c").addCommand("import my_udf_worker").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("pip").addCommand("install").addCommand("my_udf_worker").build()) + .build()) + .setCapabilities(WorkerCapabilities.newBuilder() + .addSupportedDataFormats(UDFWorkerDataFormat.ARROW) + .addSupportedCommunicationPatterns( + UDFProtoCommunicationPattern.BIDIRECTIONAL_STREAMING) + .build()) + .setDirect(DirectWorker.newBuilder() + .setRunner(ProcessCallable.newBuilder() + .addCommand("python").addCommand("-m").addCommand("my_udf_worker").build()) + .setProperties(UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnection.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .build()) + .build()) + .build() + +// 2. Create a dispatcher. Use a protocol-specific subclass of +// DirectWorkerDispatcher (e.g., gRPC over UDS). +val dispatcher: WorkerDispatcher = ... + +// 3. Create a session for one UDF execution. +val session = dispatcher.createSession(securityScope = None) +try { + // 4. Initialize with the serialized function and schemas. + session.init(InitMessage( + functionPayload = serializedFunction, + inputSchema = arrowInputSchema, + outputSchema = arrowOutputSchema)) + + // 5. Process data -- Iterator in, Iterator out. + val results: Iterator[Array[Byte]] = + session.process(inputBatches) + + // Consume results lazily. + results.foreach(processResultBatch) +} finally { + session.close() +} + +// 6. Shut down all workers. +dispatcher.close() ``` ## Build SBT: ``` -build/sbt "udf-worker-core/compile" -build/sbt "udf-worker-core/test" +build/sbt "udf-worker-proto/compile" "udf-worker-core/compile" ``` Maven: ``` -./build/mvn -pl udf/worker/proto,udf/worker/core -am compile -./build/mvn -pl udf/worker/proto,udf/worker/core -am test +build/mvn compile -pl udf/worker/proto,udf/worker/core -am ``` +## Test + +SBT: +``` +build/sbt "udf-worker-core/test" +``` + +## Current status + +This is the **first MVP** providing the core abstraction layer and the +direct worker dispatcher. +The following are left as TODOs: + +- **Connection pooling** -- reuse workers across sessions +- **Security scope isolation** -- partition pools by `WorkerSecurityScope` +- **Indirect worker creation** -- obtain workers from a service or daemon +- **Protocol-specific implementations** -- e.g., gRPC over UDS + ## Design references * [SPIP Language-agnostic UDF Protocol for Spark](https://docs.google.com/document/d/19Whzq127QxVt2Luk0EClgaDtcpBsFUp67NcVdKKyPF8/edit?tab=t.0) diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml index 69088d284365f..724ff9a1e0d2e 100644 --- a/udf/worker/core/pom.xml +++ b/udf/worker/core/pom.xml @@ -47,6 +47,10 @@ spark-udf-worker-proto_${scala.binary.version} ${project.version} + + org.slf4j + slf4j-api + org.scala-lang scala-library diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala new file mode 100644 index 0000000000000..1f78c961c9f75 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * A transport-level connection to a running UDF worker process. + * + * A [[WorkerConnection]] represents the communication channel between the + * Spark engine and a single worker process (e.g., a gRPC channel over a + * Unix domain socket, or a raw TCP socket). It is owned by a worker + * process wrapper (e.g., [[direct.DirectWorkerProcess]]) and shared + * across all [[WorkerSession]]s that use that process. + * + * Implementations wrap the concrete transport and expose only lifecycle + * methods. Data transmission happens at the [[WorkerSession]] level, not + * here -- this class is solely about whether the channel is open. + * + * '''Relationship to other classes (direct creation mode):''' + * {{{ + * DirectWorkerProcess 1 --- 1 WorkerConnection (transport over UDS) + * DirectWorkerProcess 1 --- * WorkerSession (UDF executions) + * }}} + */ +@Experimental +abstract class WorkerConnection extends AutoCloseable { + /** Returns true if the underlying transport channel is still usable. */ + def isActive: Boolean +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala index 58fabbaea00df..008cfc2993a09 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerDispatcher.scala @@ -17,11 +17,11 @@ package org.apache.spark.udf.worker.core import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.WorkerSpecification +import org.apache.spark.udf.worker.UDFWorkerSpecification /** * :: Experimental :: - * Manages workers for a single [[WorkerSpecification]] and hides worker details from Spark. + * Manages workers for a single [[UDFWorkerSpecification]] and hides worker details from Spark. * * A [[WorkerDispatcher]] is created from a worker specification (plus context such * as security scope). It owns the underlying worker processes and connections, @@ -31,7 +31,7 @@ import org.apache.spark.udf.worker.WorkerSpecification @Experimental trait WorkerDispatcher extends AutoCloseable { - def workerSpec: WorkerSpecification + def workerSpec: UDFWorkerSpecification /** * Creates a [[WorkerSession]] that maps to one single UDF execution. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index 83c392a895b66..3724e37a6dc6f 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -20,18 +20,96 @@ import org.apache.spark.annotation.Experimental /** * :: Experimental :: - * Represents one single UDF execution. + * Carries all information needed to initialize a UDF execution on a worker. * - * A [[WorkerSession]] is obtained from [[WorkerDispatcher#createSession]] and - * can carry per-execution state for that UDF invocation. Implementations may - * add concrete data-processing methods and lifecycle hooks as needed. + * This message is passed to [[WorkerSession#init]] and contains the function + * definition, schemas, and any additional configuration. It is designed to be + * extended in future versions with new fields (e.g., UDF shape, data format, + * Spark context metadata, chaining information) without breaking existing + * worker implementations. * - * A WorkerSession is not 1-to-1 mapped to an actual worker process. Multiple - * WorkerSessions may be backed by the same worker when the worker is reused. - * Worker reuse and pooling are managed by each [[WorkerDispatcher]] - * implementation based on the [[WorkerSpecification]]. + * @param functionPayload serialized function (e.g., pickled Python, JVM bytes) + * @param inputSchema serialized input schema (e.g., Arrow schema bytes) + * @param outputSchema serialized output schema (e.g., Arrow schema bytes) + * @param properties additional key-value configuration. Can carry + * protocol-specific or engine-specific metadata that + * does not yet have a dedicated field. + */ +@Experimental +case class InitMessage( + functionPayload: Array[Byte], + inputSchema: Array[Byte], + outputSchema: Array[Byte], + properties: Map[String, String] = Map.empty) + +/** + * :: Experimental :: + * One UDF execution on a worker -- the main interface Spark uses to run UDFs. + * + * A [[WorkerSession]] is the '''per-UDF-invocation''' handle that Spark + * obtains from [[WorkerDispatcher#createSession]]. It carries the full + * init / data-stream / finish lifecycle for a single UDF evaluation. + * + * A [[WorkerSession]] does ''not'' own the underlying worker or its + * transport channel -- those are managed by the [[WorkerDispatcher]]. + * Multiple sessions may share the same worker when the worker supports + * concurrency. + * + * '''Usage:''' + * {{{ + * val session = dispatcher.createSession(securityScope = None) + * try { + * session.init(InitMessage(functionPayload, inputSchema, outputSchema)) + * val results = session.process(inputBatches) + * results.foreach(handleBatch) + * } finally { + * session.close() + * } + * }}} + * + * '''Lifecycle:''' + * - [[init]] must be called exactly once before [[process]]. + * - [[process]] must be called at most once per session. + * - [[close]] must always be called (use try-finally). + * - [[cancel]] may be called at any time to abort execution. */ @Experimental abstract class WorkerSession extends AutoCloseable { - override def close(): Unit = {} + + /** + * Initializes the UDF execution. Must be called exactly once before + * [[process]]. + * + * @param message the initialization parameters including the serialized + * function, input/output schemas, and configuration. + */ + def init(message: InitMessage): Unit + + /** + * Processes input data through the worker and returns results. + * + * Follows Spark's Iterator-to-Iterator pattern: input batches are streamed + * to the worker, and result batches are lazily pulled from the returned + * iterator. The session sends a Finish signal to the worker when the input + * iterator is exhausted. + * + * Must be called at most once per session. + * + * @param input iterator of raw input data batches (e.g., Arrow IPC) + * @return iterator of raw result data batches + */ + def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] + + /** + * Requests cancellation of the current UDF execution. + * + * '''Thread-safety:''' implementations must allow [[cancel]] to be called + * from a thread different from the one driving [[process]] (typically a + * task interruption thread). It may be invoked at any point after + * [[init]] and should be a no-op if execution has already finished. + */ + def cancel(): Unit + + /** Closes this session and releases resources. */ + override def close(): Unit } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala new file mode 100644 index 0000000000000..58e12e76afb20 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.io.File +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.util.UUID +import java.util.concurrent.TimeUnit + +import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ +import scala.util.control.NonFatal + +import org.slf4j.{Logger, LoggerFactory} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, + WorkerSecurityScope, WorkerSession} +import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, + EnvironmentState} + +/** + * :: Experimental :: + * A [[WorkerDispatcher]] that creates workers by spawning local OS processes + * ("direct" creation mode from the worker specification). + * + * On the first [[createSession]], the dispatcher ensures the environment is + * ready (verify / install) and registers the cleanup hook. Currently spawns + * a fresh worker per session; pooling/reuse is TODO. + * + * Subclasses implement [[createConnection]] and [[createSessionForWorker]] + * to provide protocol-specific behavior (e.g., gRPC, raw sockets). + * + * For workers obtained through a provisioning service or daemon (indirect + * creation), see the `indirect` package (TODO). + * + * @param workerSpec worker specification (proto) + * @param logger SLF4J logger for dispatcher-internal messages. Callers may + * inject their own logger (e.g., backed by Spark's `Logging` + * trait in an engine context) to route messages through their + * own logging configuration. Defaults to an SLF4J logger for + * this class. + */ +@Experimental +abstract class DirectWorkerDispatcher( + override val workerSpec: UDFWorkerSpecification, + protected val logger: Logger = + LoggerFactory.getLogger(classOf[DirectWorkerDispatcher])) + extends WorkerDispatcher { + + // TODO: Connection pooling -- reuse idle workers across sessions. + // TODO: Security scope isolation -- partition pool by WorkerSecurityScope. + + // Multi-connection workers (e.g., a separate control channel) are a future + // extension; today the proto field is `repeated` but the engine requires + // exactly one. TCP transport is declared in the proto but not yet + // implemented; the engine currently only supports UDS. + { + val props = workerSpec.getDirect.getProperties + val n = props.getConnectionsCount + require(n == 1, + s"DirectWorker.properties.connections must have exactly one entry, got $n") + val conn = props.getConnections(0) + require(conn.hasUnixDomainSocket, + "DirectWorker currently only supports UNIX domain socket transport, " + + s"got ${conn.getTransportCase}") + } + + // worker_spec.proto documents that verification is only meaningful together + // with installation -- verification exists so the engine can skip running + // installation when the environment is already prepared. A verification + // callable with no installation callable would either always succeed (no-op) + // or always fail (worker spawn then fails) -- both user errors worth + // catching at spec-validation time. + { + val env = workerSpec.getEnvironment + require(!env.hasEnvironmentVerification || env.hasInstallation, + "WorkerEnvironment.environment_verification requires installation to be set") + } + + private val SOCKET_POLL_INTERVAL_MS = 100L + private val DEFAULT_INIT_TIMEOUT_MS = 10000L + private val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L + private val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L + private val PROCESS_OUTPUT_TAIL_LINES = 50 + + /** + * Maximum time to wait for a setup/verify/cleanup callable to finish. + * Subclasses may override this to accommodate slow installation steps + * (e.g., a large dependency install). Defaults to 120 seconds. + */ + protected def callableTimeoutMs: Long = DEFAULT_CALLABLE_TIMEOUT_MS + + private val initTimeoutMs: Long = { + val props = workerSpec.getDirect.getProperties + if (props.hasInitializationTimeoutMs && props.getInitializationTimeoutMs > 0) { + props.getInitializationTimeoutMs.toLong + } else { + DEFAULT_INIT_TIMEOUT_MS + } + } + + private val gracefulTimeoutMs: Long = { + val props = workerSpec.getDirect.getProperties + if (props.hasGracefulTerminationTimeoutMs && props.getGracefulTerminationTimeoutMs > 0) { + props.getGracefulTerminationTimeoutMs.toLong + } else { + DEFAULT_GRACEFUL_TIMEOUT_MS + } + } + + // The socket directory is removed explicitly in close(). deleteOnExit is + // deliberately not registered: it is redundant with the explicit cleanup, + // it leaks memory in long-lived JVMs (the JDK retains the path string for + // the process lifetime), and it only works on empty directories. + private val socketDir = Files.createTempDirectory("spark-udf-worker") + private val workers = new ArrayBuffer[DirectWorkerProcess]() + private val workersLock = new Object + + @volatile private var environmentState: EnvironmentState = EnvironmentState.Pending + private val environmentLock = new Object + private var cleanupHook: Option[Thread] = None + + /** Creates a protocol-specific connection to a worker at the given socket path. */ + protected def createConnection(socketPath: String): WorkerConnection + + /** Creates a protocol-specific session for the given worker. */ + protected def createSessionForWorker(worker: DirectWorkerProcess): WorkerSession + + override def createSession( + securityScope: Option[WorkerSecurityScope]): WorkerSession = { + // Pooling keyed by security scope is not yet implemented. Accepting a + // non-None scope here would silently create a one-off worker and give + // the caller a false expectation of isolation, so reject it until the + // dispatcher actually honors the scope. + require(securityScope.isEmpty, + "securityScope is not supported yet; pass None until pooling lands") + ensureEnvironmentReady() + val worker = spawnWorker() + workersLock.synchronized { workers += worker } + worker.acquireSession() + try { + createSessionForWorker(worker) + } catch { + case e: Exception => + worker.releaseSession() + workersLock.synchronized { workers -= worker } + try { + worker.close() + } catch { + case NonFatal(closeEx) => + logger.warn("Error closing worker after session creation failed", closeEx) + } + throw e + } + } + + override def close(): Unit = { + // TODO: Close workers in parallel. Worst-case shutdown today is + // N * gracefulTimeoutMs because each worker waits for SIGTERM to + // complete before the next one is signalled. A small pool of + // short-lived threads would bound shutdown to ~gracefulTimeoutMs. + workersLock.synchronized { + workers.foreach { w => + try { + w.close() + } catch { + case NonFatal(e) => + logger.warn(s"Error closing worker at ${w.socketPath}", e) + } + } + workers.clear() + } + try { + val dir = socketDir.toFile + if (dir.exists()) { + val remaining = dir.listFiles() + if (remaining != null) remaining.foreach(_.delete()) + dir.delete() + } + } catch { + case NonFatal(e) => + logger.warn(s"Error cleaning up socket directory $socketDir", e) + } + deregisterEnvironmentCleanupHook() + runEnvironmentCleanup() + } + + // -- Environment lifecycle ------------------------------------------------- + + // TODO: Handle permanently unrecoverable environment failures (e.g., wrong + // CPU architecture, unavailable system resources) differently from transient + // ones. Currently all failures are treated as permanent, but some callers + // may want to distinguish retriable vs. fatal failures. + private def ensureEnvironmentReady(): Unit = { + environmentLock.synchronized { + environmentState match { + case EnvironmentState.Ready | EnvironmentState.CleanedUp => + return + case EnvironmentState.Failed(msg) => + throw new RuntimeException(s"Environment setup previously failed: $msg") + case EnvironmentState.Pending => + } + + val env = workerSpec.getEnvironment + val verified = env.hasEnvironmentVerification && + runCallable(env.getEnvironmentVerification).exitCode == 0 + if (!verified && env.hasInstallation) { + val result = runCallable(env.getInstallation) + if (result.exitCode != 0) { + val detail = s"exit code ${result.exitCode}\n${result.outputTail}" + environmentState = EnvironmentState.Failed(detail) + throw new RuntimeException( + s"Environment installation failed with $detail") + } + } + + registerEnvironmentCleanupHook() + environmentState = EnvironmentState.Ready + } + } + + /** + * Registers the JVM shutdown hook that runs the cleanup callable. + * + * '''Caller must hold `environmentLock`''' -- this method reads and + * writes `cleanupHook` without its own synchronization. It is only + * called from `ensureEnvironmentReady`, which already owns the lock. + */ + private def registerEnvironmentCleanupHook(): Unit = { + if (cleanupHook.isDefined) return + if (workerSpec.getEnvironment.hasEnvironmentCleanup) { + val hook = new Thread(() => runEnvironmentCleanup(), "udf-env-cleanup") + cleanupHook = Some(hook) + // scalastyle:off runtimeaddshutdownhook + Runtime.getRuntime.addShutdownHook(hook) + // scalastyle:on runtimeaddshutdownhook + } + } + + private def deregisterEnvironmentCleanupHook(): Unit = { + environmentLock.synchronized { + cleanupHook.foreach { hook => + try { + Runtime.getRuntime.removeShutdownHook(hook) + } catch { + case _: IllegalStateException => // JVM already shutting down + } + cleanupHook = None + } + } + } + + private def runEnvironmentCleanup(): Unit = { + environmentLock.synchronized { + environmentState match { + case EnvironmentState.CleanedUp => return + case _ => + } + if (workerSpec.getEnvironment.hasEnvironmentCleanup) { + try { + val result = runCallable(workerSpec.getEnvironment.getEnvironmentCleanup) + if (result.exitCode != 0) { + logger.warn(s"Environment cleanup exited with code ${result.exitCode}" + + s"\n${result.outputTail}") + } + } catch { + case NonFatal(e) => logger.warn("Environment cleanup failed", e) + } + } + environmentState = EnvironmentState.CleanedUp + } + } + + // -- Process helpers ------------------------------------------------------- + + /** + * Runs a [[ProcessCallable]] synchronously and returns the result. + * Always throws on timeout; callers check `exitCode` for non-timeout failures. + */ + private[core] def runCallable(callable: ProcessCallable): CallableResult = { + val cmd = (callable.getCommandList.asScala ++ callable.getArgumentsList.asScala).toSeq + require(cmd.nonEmpty, + "ProcessCallable must have at least one entry in command or arguments") + val outputFile = Files.createTempFile("udf-callable-", ".log") + try { + val process = launchProcess( + cmd, callable.getEnvironmentVariablesMap.asScala.toMap, outputFile.toFile) + val timeoutMs = callableTimeoutMs + if (!process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) { + process.destroyForcibly() + val tail = readOutputTail(outputFile.toFile) + throw new RuntimeException( + s"Callable timed out after ${timeoutMs}ms: " + + s"${cmd.mkString(" ")}\n$tail") + } + val tail = readOutputTail(outputFile.toFile) + CallableResult(process.exitValue(), tail) + } finally { + Files.deleteIfExists(outputFile) + } + } + + private def spawnWorker(): DirectWorkerProcess = { + val runner = workerSpec.getDirect.getRunner + val baseCmd = (runner.getCommandList.asScala ++ runner.getArgumentsList.asScala).toSeq + require(baseCmd.nonEmpty, + "DirectWorker.runner must have at least one entry in command or arguments") + val workerId = UUID.randomUUID().toString + val socketPath = socketDir.resolve(s"worker-$workerId.sock").toString + // Per the ProcessCallable contract in worker_spec.proto, the engine must + // always pass --id (worker identifier for logs) and --connection (the + // engine-assigned endpoint, format depending on transport). + val cmd = baseCmd ++ Seq("--id", workerId, "--connection", socketPath) + val env = runner.getEnvironmentVariablesMap.asScala.toMap + val outputFile = Files.createTempFile("udf-worker-", ".log") + val process = launchProcess(cmd, env, outputFile.toFile) + + try { + waitForSocket(socketPath, process, outputFile.toFile) + val connection = createConnection(socketPath) + // Ownership of `outputFile` transfers to the DirectWorkerProcess: it + // remains valid for the child's file descriptor and is deleted in + // DirectWorkerProcess.close(). + new DirectWorkerProcess( + process, connection, socketPath, outputFile, gracefulTimeoutMs, logger) + } catch { + case e: Exception => + if (process.isAlive) process.destroyForcibly() + // If the worker (or createConnection) had already created the socket + // file, remove it so it doesn't linger until dispatcher.close(). + try Files.deleteIfExists(new File(socketPath).toPath) catch { + case NonFatal(cleanupEx) => + logger.debug(s"Failed to clean up socket file $socketPath", cleanupEx) + } + Files.deleteIfExists(outputFile) + throw e + } + } + + /** + * Starts an OS process. stdout and stderr are merged and redirected to the + * given file so that output can be read back for error reporting. + */ + private def launchProcess( + command: Seq[String], + env: Map[String, String], + outputFile: File): Process = { + val builder = new ProcessBuilder(command: _*) + env.foreach { case (k, v) => builder.environment().put(k, v) } + builder.redirectErrorStream(true) + builder.redirectOutput(outputFile) + builder.start() + } + + private def waitForSocket( + socketPath: String, + process: Process, + outputFile: File): Unit = { + val file = new File(socketPath) + // Ensure at least one poll attempt even for very small init timeouts, + // so we don't declare a premature timeout before the worker has any + // chance to create the socket. + val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) + var attempts = 0 + while (!file.exists() && attempts < maxAttempts) { + if (!process.isAlive) { + val tail = readOutputTail(outputFile) + throw new RuntimeException( + s"Worker exited with code ${process.exitValue()} " + + s"before creating socket at $socketPath\n$tail") + } + Thread.sleep(SOCKET_POLL_INTERVAL_MS) + attempts += 1 + } + if (!file.exists()) { + val tail = readOutputTail(outputFile) + if (process.isAlive) process.destroyForcibly() + throw new RuntimeException( + s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") + } + } + + private def readOutputTail(file: File): String = { + if (!file.exists() || file.length() == 0) return "" + val src = scala.io.Source.fromFile(file, StandardCharsets.UTF_8.name()) + try { + val lines = src.getLines().toVector + val tail = lines.takeRight(PROCESS_OUTPUT_TAIL_LINES) + if (tail.isEmpty) "" + else "Process output (last lines):\n" + tail.mkString("\n") + } catch { + case NonFatal(e) => + logger.debug(s"Failed to read process output from $file", e) + "" + } finally { + src.close() + } + } +} + +private[direct] object DirectWorkerDispatcher { + /** Result of running a [[ProcessCallable]]. */ + private[core] case class CallableResult(exitCode: Int, outputTail: String) + + private[direct] sealed trait EnvironmentState + private[direct] object EnvironmentState { + case object Pending extends EnvironmentState + case object Ready extends EnvironmentState + case class Failed(detail: String) extends EnvironmentState + case object CleanedUp extends EnvironmentState + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala new file mode 100644 index 0000000000000..b505bb29c1f83 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.io.File +import java.nio.file.{Files, Path} +import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicInteger + +import scala.util.control.NonFatal + +import org.slf4j.{Logger, LoggerFactory} + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.core.WorkerConnection + +/** + * :: Experimental :: + * A locally-spawned OS process running a UDF worker, together with the + * transport connection to it. + * + * A [[DirectWorkerProcess]] combines three things: + * - An OS '''process''' (the worker binary, started by the dispatcher). + * - A '''[[WorkerConnection]]''' (the transport channel to that process). + * - A '''socket path''' (a UDS socket file) that both sides use. + * + * Multiple [[DirectWorkerSession]]s may share the same process when the + * worker supports concurrent UDFs. The [[acquireSession]]/[[releaseSession]] + * ref-count tracks how many sessions are active. + * + * Closing tears down everything: closes the connection, sends SIGTERM + * (then SIGKILL), and removes the socket file and the process output log. + * + * @param process the OS process handle + * @param connection the transport connection to this worker + * @param socketPath the UDS socket path used by this worker + * @param outputFile the merged stdout/stderr log for this worker. + * Kept open for the lifetime of the worker (so it + * remains a valid target for the child's file + * descriptor and so its contents can be inspected + * while the worker runs) and deleted in [[close]]. + * @param gracefulTimeoutMs milliseconds to wait after SIGTERM before + * escalating to SIGKILL. + * @param logger SLF4J logger for process-level messages. Defaults + * to an SLF4J logger for this class; the dispatcher + * passes its own logger so all messages share a + * category. + */ +@Experimental +class DirectWorkerProcess( + val process: Process, + val connection: WorkerConnection, + val socketPath: String, + val outputFile: Path, + val gracefulTimeoutMs: Long, + protected val logger: Logger = + LoggerFactory.getLogger(classOf[DirectWorkerProcess])) + extends AutoCloseable { + + // The active-session ref-count below is scaffolding for future connection + // pooling. With pooling, the dispatcher would keep an idle worker alive + // when its ref-count drops to 0 and hand it out to the next session. + // TODO: Idle timeout tracking and concurrent session capacity. + // + // Until pooling lands, the dispatcher spawns one worker per session and + // tears it down at dispatcher close; the ref-count is informational only. + + private val activeSessionCount = new AtomicInteger(0) + + /** Number of sessions currently using this worker. */ + def activeSessions: Int = activeSessionCount.get() + + /** Increments the active session count. */ + def acquireSession(): Unit = activeSessionCount.incrementAndGet() + + /** Decrements the active session count. */ + def releaseSession(): Unit = activeSessionCount.updateAndGet(c => math.max(0, c - 1)) + + /** Returns true if the OS process is running and the connection is usable. */ + def isAlive: Boolean = process.isAlive && connection.isActive + + /** + * Shuts down the connection, then terminates the OS process. + * Sends SIGTERM first; escalates to SIGKILL after [[gracefulTimeoutMs]]. + */ + override def close(): Unit = { + try { + connection.close() + } catch { + case NonFatal(e) => + logger.warn(s"Error closing connection to worker at $socketPath", e) + } + + if (process.isAlive) { + process.destroy() // SIGTERM + try { + if (!process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS)) { + process.destroyForcibly() // SIGKILL + } + } catch { + case _: InterruptedException => + process.destroyForcibly() + Thread.currentThread().interrupt() + } + } + + try { + val f = new File(socketPath) + if (f.exists()) f.delete() + } catch { + case NonFatal(e) => + logger.warn(s"Error cleaning up socket file $socketPath", e) + } + + try { + Files.deleteIfExists(outputFile) + } catch { + case NonFatal(e) => + logger.warn(s"Error cleaning up worker output file $outputFile", e) + } + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala new file mode 100644 index 0000000000000..7cdc5329350e3 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerSession} + +/** + * :: Experimental :: + * A [[WorkerSession]] backed by a locally-spawned [[DirectWorkerProcess]]. + * + * This is the session type returned by [[DirectWorkerDispatcher]]. It ties + * the session lifecycle to the worker's ref-count: the dispatcher increments + * the count before construction, and [[close]] decrements it, so the + * dispatcher knows when a worker process is idle and can be terminated or + * reused. + * + * Subclasses implement the protocol-specific data transmission + * ([[init]], [[process]], [[cancel]]). + * + * @param workerProcess the direct worker process backing this session. + * Internal to the `core` package and test code -- the + * worker handle is a dispatcher implementation detail, + * not part of the public WorkerSession API. + */ +@Experimental +abstract class DirectWorkerSession( + private[core] val workerProcess: DirectWorkerProcess) extends WorkerSession { + + private val released = new AtomicBoolean(false) + + /** The connection to the worker for this session. */ + def connection: WorkerConnection = workerProcess.connection + + override def close(): Unit = { + if (released.compareAndSet(false, true)) { + workerProcess.releaseSession() + } + } +} diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala new file mode 100644 index 0000000000000..5a770ad8b86ce --- /dev/null +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -0,0 +1,512 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.io.File +import java.nio.file.Files + +// scalastyle:off funsuite +import org.scalatest.BeforeAndAfterEach +import org.scalatest.funsuite.AnyFunSuite + +import org.apache.spark.udf.worker.{ + DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, + UDFWorkerSpecification, UnixDomainSocket, + WorkerConnection => WorkerConnectionProto, WorkerEnvironment} +import org.apache.spark.udf.worker.core.direct.{DirectWorkerDispatcher, + DirectWorkerProcess, DirectWorkerSession} + +/** + * A [[WorkerConnection]] test implementation that considers the connection + * active as long as the socket file exists on disk. + */ +class SocketFileConnection(socketPath: String) extends WorkerConnection { + override def isActive: Boolean = new File(socketPath).exists() + override def close(): Unit = { + val f = new File(socketPath) + if (f.exists()) f.delete() + } +} + +/** + * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't + * need actual data transmission. + */ +class StubWorkerSession( + workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) { + + override def init(message: InitMessage): Unit = {} + + override def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = + Iterator.empty + + override def cancel(): Unit = {} +} + +/** + * A [[DirectWorkerDispatcher]] subclass for testing that uses a socket-file + * connection and stub sessions instead of a real protocol implementation. + */ +class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification) + extends DirectWorkerDispatcher(spec) { + + override protected def createConnection(socketPath: String): WorkerConnection = { + new SocketFileConnection(socketPath) + } + + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + new StubWorkerSession(worker) + } +} + +/** + * Tests for [[DirectWorkerDispatcher]] process lifecycle: spawning workers + * and terminating them on close. + */ +class DirectWorkerDispatcherSuite + extends AnyFunSuite with BeforeAndAfterEach { +// scalastyle:on funsuite + + private val echoWorkerScript = + """ + |#!/bin/bash + |SOCKET_PATH="" + |while [[ $# -gt 0 ]]; do + | case "$1" in + | --connection) SOCKET_PATH="$2"; shift 2 ;; + | *) shift ;; + | esac + |done + |cleanup() { rm -f "$SOCKET_PATH"; exit 0; } + |trap cleanup SIGTERM + |touch "$SOCKET_PATH" + |while true; do sleep 1; done + """.stripMargin.trim + + private def defaultRunner: ProcessCallable = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand(echoWorkerScript).addCommand("--") + .build() + + private def udsProperties: UDFWorkerProperties = UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnectionProto.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance) + .build()) + .build() + + private def directWorker(runner: ProcessCallable): DirectWorker = + DirectWorker.newBuilder().setRunner(runner).setProperties(udsProperties).build() + + private def specWithRunner(runner: ProcessCallable): UDFWorkerSpecification = + UDFWorkerSpecification.newBuilder() + .setDirect(directWorker(runner)) + .build() + + private def specWithEnv( + runner: ProcessCallable = defaultRunner, + env: WorkerEnvironment): UDFWorkerSpecification = + UDFWorkerSpecification.newBuilder() + .setEnvironment(env) + .setDirect(directWorker(runner)) + .build() + + private var dispatcher: TestDirectWorkerDispatcher = _ + + override def afterEach(): Unit = { + if (dispatcher != null) { + dispatcher.close() + dispatcher = null + } + super.afterEach() + } + + test("creates a worker and session") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] + val worker = session.workerProcess + + assert(worker.isAlive, "worker should be alive after creation") + assert(worker.activeSessions == 1, "should have 1 active session") + assert(new File(worker.socketPath).exists(), "socket file should exist") + + session.close() + assert(worker.activeSessions == 0, "should have 0 sessions after close") + } + + test("concurrent createSession calls produce distinct workers") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val threads = 8 + val sessions = new java.util.concurrent.ConcurrentLinkedQueue[StubWorkerSession]() + val startGate = new java.util.concurrent.CountDownLatch(1) + val doneGate = new java.util.concurrent.CountDownLatch(threads) + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + (1 to threads).foreach { _ => + new Thread(() => { + try { + startGate.await() + sessions.add( + dispatcher.createSession(None).asInstanceOf[StubWorkerSession]) + } catch { + case t: Throwable => errors.add(t) + } finally { + doneGate.countDown() + } + }).start() + } + startGate.countDown() + assert(doneGate.await(30, java.util.concurrent.TimeUnit.SECONDS), + "createSession threads did not finish in time") + + assert(errors.isEmpty, + s"unexpected errors during concurrent createSession: ${errors.toArray.mkString(", ")}") + assert(sessions.size == threads, "expected one session per thread") + + val workerObjects = sessions.toArray.map(_.asInstanceOf[StubWorkerSession].workerProcess) + assert(workerObjects.distinct.length == threads, + "each session should have its own DirectWorkerProcess") + + sessions.toArray.foreach(_.asInstanceOf[StubWorkerSession].close()) + } + + test("close shuts down all workers via SIGTERM") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session1 = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] + val session2 = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] + + val worker1 = session1.workerProcess + val worker2 = session2.workerProcess + + session1.close() + session2.close() + dispatcher.close() + dispatcher = null + + assert(!worker1.process.isAlive, "worker1 should be terminated") + assert(!worker2.process.isAlive, "worker2 should be terminated") + } + + // -- Error-path tests ------------------------------------------------------- + + test("worker is cleaned up when createSessionForWorker throws") { + // A dispatcher whose createSessionForWorker always throws. The spawned + // worker must be terminated rather than leaked until dispatcher.close(). + var capturedWorker: DirectWorkerProcess = null + val failingDispatcher = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection(socketPath: String): WorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + capturedWorker = worker + throw new RuntimeException("session creation failed") + } + } + + try { + val ex = intercept[RuntimeException] { + failingDispatcher.createSession(None) + } + assert(ex.getMessage.contains("session creation failed")) + assert(capturedWorker != null, "worker should have been spawned before the failure") + assert(!capturedWorker.process.isAlive, + "worker process should have been terminated after session creation failed") + assert(capturedWorker.activeSessions == 0, + "worker session count should be released after failure") + } finally { + failingDispatcher.close() + } + } + + test("DirectWorker without connections is rejected") { + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder().setRunner(defaultRunner).build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("exactly one entry"), + s"expected connections-count error, got: ${ex.getMessage}") + } + + test("DirectWorker with non-UDS transport is rejected") { + val tcpProperties = UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnectionProto.newBuilder() + .setTcp(LocalTcpConnection.getDefaultInstance).build()) + .build() + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(tcpProperties).build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("UNIX domain socket"), + s"expected UDS-only error, got: ${ex.getMessage}") + } + + test("DirectWorker with multiple connections is rejected") { + val twoConnections = UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnectionProto.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .addConnections(WorkerConnectionProto.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .build() + val badSpec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(twoConnections).build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(badSpec) + } + assert(ex.getMessage.contains("exactly one entry"), + s"expected connections-count error, got: ${ex.getMessage}") + } + + test("socket file is cleaned up when createConnection throws") { + val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]() + val failingDispatcher = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection(socketPath: String): WorkerConnection = { + capturedSocketPaths.add(socketPath) + throw new RuntimeException("connection creation failed") + } + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } + try { + val ex = intercept[RuntimeException] { + failingDispatcher.createSession(None) + } + assert(ex.getMessage.contains("connection creation failed")) + assert(capturedSocketPaths.size == 1, "createConnection should have been called once") + val socketPath = capturedSocketPaths.peek() + assert(!new File(socketPath).exists(), + s"socket file $socketPath should have been cleaned up") + } finally { + failingDispatcher.close() + } + } + + test("empty ProcessCallable command is rejected with a clear error") { + val emptyRunner = ProcessCallable.newBuilder().build() + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(emptyRunner)) + val ex = intercept[IllegalArgumentException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("at least one entry"), + s"expected explicit empty-command error, got: ${ex.getMessage}") + } + + test("spawnWorker fails when worker process exits immediately") { + val runner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("echo 'fatal: bad config' >&2; exit 42").addCommand("--") + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(runner)) + + val ex = intercept[RuntimeException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("exited with code 42"), + s"expected early-exit error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("fatal: bad config"), + s"expected process output in error, got: ${ex.getMessage}") + } + + // -- Environment lifecycle tests ------------------------------------------- + + test("skips installation when verification succeeds") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand("exit 0").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(!markerFile.exists(), + "installation should not run when verification succeeds") + } + + test("runs installation when verification fails") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand("exit 1").build()) + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(markerFile.exists(), + "installation should run when verification fails") + markerFile.delete() + } + + test("runs installation when no verification callable is provided") { + val markerFile = Files.createTempFile("env-install-marker", ".txt").toFile + markerFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${markerFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(markerFile.exists(), + "installation should run when no verification is defined") + markerFile.delete() + } + + test("installation failure throws with process output and prevents worker creation") { + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("echo 'missing dependency: libfoo' >&2; exit 7").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val ex = intercept[RuntimeException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("exit code 7"), + s"expected installation failure, got: ${ex.getMessage}") + assert(ex.getMessage.contains("missing dependency: libfoo"), + s"expected process output in error, got: ${ex.getMessage}") + } + + test("environment setup runs only once across multiple sessions") { + val counterFile = Files.createTempFile("env-counter", ".txt").toFile + counterFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"echo invoked >> ${counterFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val s1 = dispatcher.createSession(None); s1.close() + val s2 = dispatcher.createSession(None); s2.close() + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run exactly once, but ran ${lines.size} time(s)") + counterFile.delete() + } + + test("failed environment setup is not retried on subsequent createSession") { + val counterFile = Files.createTempFile("env-failed-counter", ".txt").toFile + counterFile.delete() + + // Installation script appends a line every time it runs, then always + // fails. The first createSession should run it; the second should be + // rejected immediately without re-running. + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand( + s"echo invoked >> ${counterFile.getAbsolutePath}; exit 1").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val first = intercept[RuntimeException] { dispatcher.createSession(None) } + assert(first.getMessage.contains("installation failed"), + s"expected first-attempt installation failure, got: ${first.getMessage}") + + val second = intercept[RuntimeException] { dispatcher.createSession(None) } + assert(second.getMessage.contains("previously failed"), + s"expected cached failure on retry, got: ${second.getMessage}") + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run only once across failed retries, got ${lines.size}") + counterFile.delete() + } + + test("non-None securityScope is rejected until pooling lands") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + val scope = new WorkerSecurityScope { + override def equals(obj: Any): Boolean = obj.isInstanceOf[this.type] + override def hashCode(): Int = 0 + } + val ex = intercept[IllegalArgumentException] { + dispatcher.createSession(Some(scope)) + } + assert(ex.getMessage.contains("not supported yet"), + s"expected unsupported-scope error, got: ${ex.getMessage}") + } + + test("verification without installation is rejected") { + val env = WorkerEnvironment.newBuilder() + .setEnvironmentVerification(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand("exit 0").build()) + .build() + val ex = intercept[IllegalArgumentException] { + new TestDirectWorkerDispatcher(specWithEnv(env = env)) + } + assert(ex.getMessage.contains("installation"), + s"expected installation-required error, got: ${ex.getMessage}") + } + + test("cleanup runs on dispatcher close") { + val cleanupMarker = Files.createTempFile("env-cleanup-marker", ".txt").toFile + cleanupMarker.delete() + + val env = WorkerEnvironment.newBuilder() + .setEnvironmentCleanup(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand(s"touch ${cleanupMarker.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val session = dispatcher.createSession(None) + session.close() + + assert(!cleanupMarker.exists(), + "cleanup should not run until dispatcher is closed") + + dispatcher.close() + dispatcher = null + + assert(cleanupMarker.exists(), + "cleanup should run when dispatcher is closed") + cleanupMarker.delete() + } +} diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala deleted file mode 100644 index 42f53af07424a..0000000000000 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/WorkerAbstractionSuite.scala +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.udf.worker.core - -import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite - -class WorkerAbstractionSuite - extends AnyFunSuite { // scalastyle:ignore funsuite - - test("dummy") {} -} diff --git a/udf/worker/proto/src/main/protobuf/common.proto b/udf/worker/proto/src/main/protobuf/common.proto index 9c50cdd7a7e4b..ee032def73efe 100644 --- a/udf/worker/proto/src/main/protobuf/common.proto +++ b/udf/worker/proto/src/main/protobuf/common.proto @@ -32,6 +32,13 @@ enum UDFWorkerDataFormat { } // The UDF execution type/shape. +// +// BIDIRECTIONAL_STREAMING is the only pattern supported by the engine for +// now. It may be possible to express all UDF types (scalar, mapPartitions, +// and eventually UDAF/UDTF/streaming) on top of this single pattern by +// framing their phases as messages on the stream, but that is a design +// question worth revisiting as additional UDF types are added -- for +// example, aggregation may prefer a multi-round or specialized pattern. enum UDFProtoCommunicationPattern { UDF_PROTO_COMMUNICATION_PATTERN_UNSPECIFIED = 0; diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto index f2eacf2b3ce35..a6c1c910ec443 100644 --- a/udf/worker/proto/src/main/protobuf/worker_spec.proto +++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto @@ -140,13 +140,17 @@ message WorkerCapabilities { // Whether multiple, concurrent UDF // connections are supported by this worker // (for example via multi-threading). - // + // // In the first implementation of the engine-side // worker specification, this property will not be used. - // + // // Usage of this property can be enabled in the future if the // engine implements more advanced resource management (TBD). // + // TODO: wire this into planning/scheduling -- SPIP worker-spec §2.4 + // "Parallelism" describes the intended use (e.g., multiplex tasks onto + // a single worker vs. spawn multiple workers per executor). + // // (Optional) optional bool supports_concurrent_udfs = 3; @@ -190,22 +194,27 @@ message UDFWorkerProperties { // (Optional) optional int32 graceful_termination_timeout_ms = 2; - // The connection this [[DirectWorker]] supports. Note that a single - // connection is sufficient to run multiple UDFs and (gRPC) services. + // The connections this [[DirectWorker]] supports. A single connection is + // sufficient to run multiple UDFs and (gRPC) services; multi-connection + // workers (e.g., a separate control channel for stateful streaming) are + // a future extension and are not yet supported by the engine -- today + // exactly one connection must be specified. // - // On [[DirectWorker]] creation, connection information - // is passed to the callable as a string parameter. + // On [[DirectWorker]] creation, connection information + // is passed to the callable as a string parameter. // The string format depends on the [[WorkerConnection]]: - // + // // For example, when using TCP, the callable argument will be: // --connection PORT // Here is a concrete example // --connection 8080 - // + // // For the format of each specific transport type, see the comments below. // - // (Required) - WorkerConnection connection = 3; + // (Required) Exactly one entry today; the field is repeated to allow + // additional connections (e.g., data + control) to be added without a + // schema-breaking migration. + repeated WorkerConnection connections = 3; } message WorkerConnection { diff --git a/udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala b/udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala deleted file mode 100644 index e25b99b69990c..0000000000000 --- a/udf/worker/proto/src/main/scala/org/apache/spark/udf/worker/WorkerSpecification.scala +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.udf.worker - -import org.apache.spark.annotation.Experimental - -/** - * :: Experimental :: - * Typed Scala wrapper around the protobuf [[UDFWorkerSpecification]]. - */ -@Experimental -class WorkerSpecification(val proto: UDFWorkerSpecification) { -} From 30ffef7c57513698607f9857a319c72a9247a777 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Sat, 18 Apr 2026 09:28:06 +0000 Subject: [PATCH 2/9] drop direct log4j dep. --- udf/worker/core/pom.xml | 4 -- .../spark/udf/worker/core/WorkerLogger.scala | 51 +++++++++++++++++++ .../core/direct/DirectWorkerDispatcher.scala | 17 +++---- .../core/direct/DirectWorkerProcess.scala | 15 +++--- 4 files changed, 64 insertions(+), 23 deletions(-) create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml index 724ff9a1e0d2e..69088d284365f 100644 --- a/udf/worker/core/pom.xml +++ b/udf/worker/core/pom.xml @@ -47,10 +47,6 @@ spark-udf-worker-proto_${scala.binary.version} ${project.version} - - org.slf4j - slf4j-api - org.scala-lang scala-library diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala new file mode 100644 index 0000000000000..a8f135f688908 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerLogger.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Minimal logging surface used by the udf/worker framework. + * + * The framework deliberately does not depend on SLF4J (or any other + * concrete logging backend) so callers can embed it without dragging a + * specific logger onto the classpath. Embedders should supply an + * adapter that forwards to their preferred backend (Spark's `Logging` + * trait, SLF4J, java.util.logging, etc.). + * + * Only the methods actually used by the framework are exposed. + * Messages are passed by-name so the formatting cost is avoided when + * the backend decides to drop the event. + */ +@Experimental +trait WorkerLogger { + def warn(msg: => String): Unit + def warn(msg: => String, t: Throwable): Unit + def debug(msg: => String): Unit + def debug(msg: => String, t: Throwable): Unit +} + +object WorkerLogger { + /** Discards all messages. Default for callers that don't wire up logging. */ + val NoOp: WorkerLogger = new WorkerLogger { + override def warn(msg: => String): Unit = () + override def warn(msg: => String, t: Throwable): Unit = () + override def debug(msg: => String): Unit = () + override def debug(msg: => String, t: Throwable): Unit = () + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 58e12e76afb20..a0edac3460ecb 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -26,12 +26,10 @@ import scala.collection.mutable.ArrayBuffer import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal -import org.slf4j.{Logger, LoggerFactory} - import org.apache.spark.annotation.Experimental import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, - WorkerSecurityScope, WorkerSession} + WorkerLogger, WorkerSecurityScope, WorkerSession} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, EnvironmentState} @@ -51,17 +49,16 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR * creation), see the `indirect` package (TODO). * * @param workerSpec worker specification (proto) - * @param logger SLF4J logger for dispatcher-internal messages. Callers may - * inject their own logger (e.g., backed by Spark's `Logging` - * trait in an engine context) to route messages through their - * own logging configuration. Defaults to an SLF4J logger for - * this class. + * @param logger [[WorkerLogger]] used for dispatcher-internal messages. + * The framework does not depend on any concrete logging + * backend; callers should pass an adapter that forwards + * to their preferred logger (Spark's `Logging` trait, + * SLF4J, etc.). Defaults to [[WorkerLogger.NoOp]]. */ @Experimental abstract class DirectWorkerDispatcher( override val workerSpec: UDFWorkerSpecification, - protected val logger: Logger = - LoggerFactory.getLogger(classOf[DirectWorkerDispatcher])) + protected val logger: WorkerLogger = WorkerLogger.NoOp) extends WorkerDispatcher { // TODO: Connection pooling -- reuse idle workers across sessions. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index b505bb29c1f83..cf18df7be3562 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -23,10 +23,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.util.control.NonFatal -import org.slf4j.{Logger, LoggerFactory} - import org.apache.spark.annotation.Experimental -import org.apache.spark.udf.worker.core.WorkerConnection +import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} /** * :: Experimental :: @@ -55,10 +53,10 @@ import org.apache.spark.udf.worker.core.WorkerConnection * while the worker runs) and deleted in [[close]]. * @param gracefulTimeoutMs milliseconds to wait after SIGTERM before * escalating to SIGKILL. - * @param logger SLF4J logger for process-level messages. Defaults - * to an SLF4J logger for this class; the dispatcher - * passes its own logger so all messages share a - * category. + * @param logger [[WorkerLogger]] used for process-level + * messages. Defaults to [[WorkerLogger.NoOp]]; + * the dispatcher normally passes its own + * logger so all messages share a category. */ @Experimental class DirectWorkerProcess( @@ -67,8 +65,7 @@ class DirectWorkerProcess( val socketPath: String, val outputFile: Path, val gracefulTimeoutMs: Long, - protected val logger: Logger = - LoggerFactory.getLogger(classOf[DirectWorkerProcess])) + protected val logger: WorkerLogger = WorkerLogger.NoOp) extends AutoCloseable { // The active-session ref-count below is scaffolding for future connection From 0f103c8a0841ff87607e493d24b5d4cad01a9b01 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Sat, 18 Apr 2026 10:13:56 +0000 Subject: [PATCH 3/9] improvements --- udf/worker/README.md | 4 +- .../core/direct/DirectWorkerDispatcher.scala | 226 +++++++++++------- .../core/direct/DirectWorkerProcess.scala | 24 +- .../core/direct/DirectWorkerSession.scala | 7 + .../core/DirectWorkerDispatcherSuite.scala | 27 ++- 5 files changed, 188 insertions(+), 100 deletions(-) diff --git a/udf/worker/README.md b/udf/worker/README.md index 8f0ef476d21d9..b3c634c9e0bf8 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -12,7 +12,7 @@ with an Iterator-to-Iterator `process` API that streams input batches through the worker and returns result batches. ``` -WorkerSpecification -- how to create and configure workers +UDFWorkerSpecification -- how to create and configure workers | v WorkerDispatcher -- manages workers, creates sessions @@ -59,7 +59,7 @@ obtaining workers from a provisioning service or daemon. `DirectWorkerDispatcher` spawns worker processes locally. On the first session, it runs the optional environment lifecycle callables from the -`WorkerSpecification`: +`UDFWorkerSpecification`: - **`environmentVerification`** -- checks if the environment is ready (exit 0 = ready). When it succeeds, installation is skipped. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index a0edac3460ecb..0993c0888b80a 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -16,13 +16,13 @@ */ package org.apache.spark.udf.worker.core.direct -import java.io.File +import java.io.{BufferedReader, File, FileInputStream, InputStreamReader} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.util.UUID import java.util.concurrent.TimeUnit -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, Queue => MQueue} import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -31,7 +31,9 @@ import org.apache.spark.udf.worker.{ProcessCallable, UDFWorkerSpecification} import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, WorkerLogger, WorkerSecurityScope, WorkerSession} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, - EnvironmentState} + DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS, + EnvironmentState, MAX_OUTPUT_SCAN_BYTES, PROCESS_OUTPUT_TAIL_LINES, + SOCKET_POLL_INTERVAL_MS} /** * :: Experimental :: @@ -64,38 +66,8 @@ abstract class DirectWorkerDispatcher( // TODO: Connection pooling -- reuse idle workers across sessions. // TODO: Security scope isolation -- partition pool by WorkerSecurityScope. - // Multi-connection workers (e.g., a separate control channel) are a future - // extension; today the proto field is `repeated` but the engine requires - // exactly one. TCP transport is declared in the proto but not yet - // implemented; the engine currently only supports UDS. - { - val props = workerSpec.getDirect.getProperties - val n = props.getConnectionsCount - require(n == 1, - s"DirectWorker.properties.connections must have exactly one entry, got $n") - val conn = props.getConnections(0) - require(conn.hasUnixDomainSocket, - "DirectWorker currently only supports UNIX domain socket transport, " + - s"got ${conn.getTransportCase}") - } - - // worker_spec.proto documents that verification is only meaningful together - // with installation -- verification exists so the engine can skip running - // installation when the environment is already prepared. A verification - // callable with no installation callable would either always succeed (no-op) - // or always fail (worker spawn then fails) -- both user errors worth - // catching at spec-validation time. - { - val env = workerSpec.getEnvironment - require(!env.hasEnvironmentVerification || env.hasInstallation, - "WorkerEnvironment.environment_verification requires installation to be set") - } - - private val SOCKET_POLL_INTERVAL_MS = 100L - private val DEFAULT_INIT_TIMEOUT_MS = 10000L - private val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L - private val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L - private val PROCESS_OUTPUT_TAIL_LINES = 50 + validateTransportSupport() + validateEnvironmentCallables() /** * Maximum time to wait for a setup/verify/cleanup callable to finish. @@ -155,16 +127,24 @@ abstract class DirectWorkerDispatcher( try { createSessionForWorker(worker) } catch { - case e: Exception => - worker.releaseSession() - workersLock.synchronized { workers -= worker } - try { - worker.close() - } catch { - case NonFatal(closeEx) => - logger.warn("Error closing worker after session creation failed", closeEx) - } + case e: InterruptedException => + Thread.currentThread().interrupt() + cleanupFailedSession(worker) throw e + case NonFatal(e) => + cleanupFailedSession(worker) + throw e + } + } + + private def cleanupFailedSession(worker: DirectWorkerProcess): Unit = { + worker.releaseSession() + workersLock.synchronized { workers -= worker } + try { + worker.close() + } catch { + case NonFatal(closeEx) => + logger.warn("Error closing worker after session creation failed", closeEx) } } @@ -209,27 +189,25 @@ abstract class DirectWorkerDispatcher( environmentLock.synchronized { environmentState match { case EnvironmentState.Ready | EnvironmentState.CleanedUp => - return + // Already set up (or torn down); nothing to do. case EnvironmentState.Failed(msg) => throw new RuntimeException(s"Environment setup previously failed: $msg") case EnvironmentState.Pending => + val env = workerSpec.getEnvironment + val verified = env.hasEnvironmentVerification && + runCallable(env.getEnvironmentVerification).exitCode == 0 + if (!verified && env.hasInstallation) { + val result = runCallable(env.getInstallation) + if (result.exitCode != 0) { + val detail = s"exit code ${result.exitCode}\n${result.outputTail}" + environmentState = EnvironmentState.Failed(detail) + throw new RuntimeException( + s"Environment installation failed with $detail") + } + } + registerEnvironmentCleanupHook() + environmentState = EnvironmentState.Ready } - - val env = workerSpec.getEnvironment - val verified = env.hasEnvironmentVerification && - runCallable(env.getEnvironmentVerification).exitCode == 0 - if (!verified && env.hasInstallation) { - val result = runCallable(env.getInstallation) - if (result.exitCode != 0) { - val detail = s"exit code ${result.exitCode}\n${result.outputTail}" - environmentState = EnvironmentState.Failed(detail) - throw new RuntimeException( - s"Environment installation failed with $detail") - } - } - - registerEnvironmentCleanupHook() - environmentState = EnvironmentState.Ready } } @@ -267,21 +245,22 @@ abstract class DirectWorkerDispatcher( private def runEnvironmentCleanup(): Unit = { environmentLock.synchronized { environmentState match { - case EnvironmentState.CleanedUp => return + case EnvironmentState.CleanedUp => + // Already cleaned up; nothing to do. case _ => - } - if (workerSpec.getEnvironment.hasEnvironmentCleanup) { - try { - val result = runCallable(workerSpec.getEnvironment.getEnvironmentCleanup) - if (result.exitCode != 0) { - logger.warn(s"Environment cleanup exited with code ${result.exitCode}" + - s"\n${result.outputTail}") + if (workerSpec.getEnvironment.hasEnvironmentCleanup) { + try { + val result = runCallable(workerSpec.getEnvironment.getEnvironmentCleanup) + if (result.exitCode != 0) { + logger.warn(s"Environment cleanup exited with code ${result.exitCode}" + + s"\n${result.outputTail}") + } + } catch { + case NonFatal(e) => logger.warn("Environment cleanup failed", e) + } } - } catch { - case NonFatal(e) => logger.warn("Environment cleanup failed", e) - } + environmentState = EnvironmentState.CleanedUp } - environmentState = EnvironmentState.CleanedUp } } @@ -338,19 +317,30 @@ abstract class DirectWorkerDispatcher( new DirectWorkerProcess( process, connection, socketPath, outputFile, gracefulTimeoutMs, logger) } catch { - case e: Exception => - if (process.isAlive) process.destroyForcibly() - // If the worker (or createConnection) had already created the socket - // file, remove it so it doesn't linger until dispatcher.close(). - try Files.deleteIfExists(new File(socketPath).toPath) catch { - case NonFatal(cleanupEx) => - logger.debug(s"Failed to clean up socket file $socketPath", cleanupEx) - } - Files.deleteIfExists(outputFile) + case e: InterruptedException => + Thread.currentThread().interrupt() + cleanupFailedSpawn(process, socketPath, outputFile) + throw e + case NonFatal(e) => + cleanupFailedSpawn(process, socketPath, outputFile) throw e } } + private def cleanupFailedSpawn( + process: Process, + socketPath: String, + outputFile: java.nio.file.Path): Unit = { + if (process.isAlive) process.destroyForcibly() + // If the worker (or createConnection) had already created the socket + // file, remove it so it doesn't linger until dispatcher.close(). + try Files.deleteIfExists(new File(socketPath).toPath) catch { + case NonFatal(cleanupEx) => + logger.debug(s"Failed to clean up socket file $socketPath", cleanupEx) + } + Files.deleteIfExists(outputFile) + } + /** * Starts an OS process. stdout and stderr are merged and redirected to the * given file so that output can be read back for error reporting. @@ -394,25 +384,85 @@ abstract class DirectWorkerDispatcher( } } + // Reads at most the final MAX_OUTPUT_SCAN_BYTES of `file` and returns the + // last PROCESS_OUTPUT_TAIL_LINES lines via a fixed-size ring buffer, so a + // runaway worker that writes gigabytes of output does not OOM the caller + // during error reporting. private def readOutputTail(file: File): String = { if (!file.exists() || file.length() == 0) return "" - val src = scala.io.Source.fromFile(file, StandardCharsets.UTF_8.name()) + val fileLen = file.length() + val startPos = math.max(0L, fileLen - MAX_OUTPUT_SCAN_BYTES) + val fis = new FileInputStream(file) try { - val lines = src.getLines().toVector - val tail = lines.takeRight(PROCESS_OUTPUT_TAIL_LINES) - if (tail.isEmpty) "" - else "Process output (last lines):\n" + tail.mkString("\n") + var remaining = startPos + while (remaining > 0) { + val n = fis.skip(remaining) + if (n <= 0) remaining = 0 else remaining -= n + } + val reader = new BufferedReader( + new InputStreamReader(fis, StandardCharsets.UTF_8)) + // If we started mid-line, the first line is partial -- discard it so + // the tail never shows a line fragment. + if (startPos > 0) reader.readLine() + val buffer = new MQueue[String]() + var line = reader.readLine() + while (line != null) { + if (buffer.size >= PROCESS_OUTPUT_TAIL_LINES) buffer.dequeue() + buffer.enqueue(line) + line = reader.readLine() + } + if (buffer.isEmpty) "" + else "Process output (last lines):\n" + buffer.mkString("\n") } catch { case NonFatal(e) => logger.debug(s"Failed to read process output from $file", e) "" } finally { - src.close() + fis.close() } } + + // -- Spec validation ------------------------------------------------------- + + // Multi-connection workers (e.g., a separate control channel) are a future + // extension; today the proto field is `repeated` but the engine requires + // exactly one. TCP transport is declared in the proto but not yet + // implemented; the engine currently only supports UDS. + private def validateTransportSupport(): Unit = { + val props = workerSpec.getDirect.getProperties + val n = props.getConnectionsCount + require(n == 1, + s"DirectWorker.properties.connections must have exactly one entry, got $n") + val conn = props.getConnections(0) + require(conn.hasUnixDomainSocket, + "DirectWorker currently only supports UNIX domain socket transport, " + + s"got ${conn.getTransportCase}") + } + + // worker_spec.proto documents that verification is only meaningful together + // with installation -- verification exists so the engine can skip running + // installation when the environment is already prepared. A verification + // callable with no installation callable would either always succeed (no-op) + // or always fail (worker spawn then fails) -- both user errors worth + // catching at spec-validation time. + private def validateEnvironmentCallables(): Unit = { + val env = workerSpec.getEnvironment + require(!env.hasEnvironmentVerification || env.hasInstallation, + "WorkerEnvironment.environment_verification requires installation to be set") + } } private[direct] object DirectWorkerDispatcher { + private[direct] val SOCKET_POLL_INTERVAL_MS = 100L + private[direct] val DEFAULT_INIT_TIMEOUT_MS = 10000L + private[direct] val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L + private[direct] val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L + private[direct] val PROCESS_OUTPUT_TAIL_LINES = 50 + // Cap the amount of log file scanned by readOutputTail so a runaway worker + // producing gigabytes of output cannot OOM the caller during error + // reporting. The tail is still limited to PROCESS_OUTPUT_TAIL_LINES. + private[direct] val MAX_OUTPUT_SCAN_BYTES = 1024L * 1024L // 1 MiB + /** Result of running a [[ProcessCallable]]. */ private[core] case class CallableResult(exitCode: Int, outputTail: String) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index cf18df7be3562..d1d40fd705d27 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -19,7 +19,7 @@ package org.apache.spark.udf.worker.core.direct import java.io.File import java.nio.file.{Files, Path} import java.util.concurrent.TimeUnit -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import scala.util.control.NonFatal @@ -77,6 +77,7 @@ class DirectWorkerProcess( // tears it down at dispatcher close; the ref-count is informational only. private val activeSessionCount = new AtomicInteger(0) + private val closed = new AtomicBoolean(false) /** Number of sessions currently using this worker. */ def activeSessions: Int = activeSessionCount.get() @@ -84,8 +85,20 @@ class DirectWorkerProcess( /** Increments the active session count. */ def acquireSession(): Unit = activeSessionCount.incrementAndGet() - /** Decrements the active session count. */ - def releaseSession(): Unit = activeSessionCount.updateAndGet(c => math.max(0, c - 1)) + /** + * Decrements the active session count. Logs a warning and resets to zero + * if the count goes negative, which would indicate an unbalanced + * acquire/release (a bug we want to surface rather than paper over, + * especially once pooling consumes this count). + */ + def releaseSession(): Unit = { + val c = activeSessionCount.decrementAndGet() + if (c < 0) { + logger.warn( + s"releaseSession called without a matching acquireSession (count=$c)") + activeSessionCount.set(0) + } + } /** Returns true if the OS process is running and the connection is usable. */ def isAlive: Boolean = process.isAlive && connection.isActive @@ -93,8 +106,13 @@ class DirectWorkerProcess( /** * Shuts down the connection, then terminates the OS process. * Sends SIGTERM first; escalates to SIGKILL after [[gracefulTimeoutMs]]. + * Idempotent: only the first call performs teardown; subsequent calls + * are no-ops so the dispatcher's close path and the createSession error + * path can both invoke close without double-releasing resources. */ override def close(): Unit = { + if (!closed.compareAndSet(false, true)) return + try { connection.close() } catch { diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala index 7cdc5329350e3..f92df2a907ba4 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -48,6 +48,13 @@ abstract class DirectWorkerSession( /** The connection to the worker for this session. */ def connection: WorkerConnection = workerProcess.connection + // TODO: Introduce an idle timeout so the dispatcher tears down a worker + // whose ref-count has dropped to zero and stayed there for some interval. + // Without that, sessions release back to a worker that is never reaped + // until dispatcher.close(), which leaks one process + UDS + FDs per + // session over the dispatcher's lifetime. The timeout should live on + // the dispatcher (it owns the worker pool) and be pluggable from the + // worker spec. override def close(): Unit = { if (released.compareAndSet(false, true)) { workerProcess.releaseSession() diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 5a770ad8b86ce..e312d803281ed 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.udf.worker.core import java.io.File import java.nio.file.Files +import scala.jdk.CollectionConverters._ + // scalastyle:off funsuite import org.scalatest.BeforeAndAfterEach import org.scalatest.funsuite.AnyFunSuite @@ -134,10 +136,21 @@ class DirectWorkerDispatcherSuite super.afterEach() } + // Narrow the publicly-typed WorkerSession returned by `createSession` back + // down to StubWorkerSession in one place, with a descriptive failure if + // the cast is ever wrong, so individual tests don't scatter `asInstanceOf` + // (which would throw ClassCastException rather than a useful message). + private def createStubSession(): StubWorkerSession = + dispatcher.createSession(None) match { + case stub: StubWorkerSession => stub + case other => fail( + s"Expected StubWorkerSession, got ${other.getClass.getSimpleName}") + } + test("creates a worker and session") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) - val session = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] + val session = createStubSession() val worker = session.workerProcess assert(worker.isAlive, "worker should be alive after creation") @@ -161,8 +174,7 @@ class DirectWorkerDispatcherSuite new Thread(() => { try { startGate.await() - sessions.add( - dispatcher.createSession(None).asInstanceOf[StubWorkerSession]) + sessions.add(createStubSession()) } catch { case t: Throwable => errors.add(t) } finally { @@ -178,18 +190,19 @@ class DirectWorkerDispatcherSuite s"unexpected errors during concurrent createSession: ${errors.toArray.mkString(", ")}") assert(sessions.size == threads, "expected one session per thread") - val workerObjects = sessions.toArray.map(_.asInstanceOf[StubWorkerSession].workerProcess) + val sessionList = sessions.asScala.toList + val workerObjects = sessionList.map(_.workerProcess) assert(workerObjects.distinct.length == threads, "each session should have its own DirectWorkerProcess") - sessions.toArray.foreach(_.asInstanceOf[StubWorkerSession].close()) + sessionList.foreach(_.close()) } test("close shuts down all workers via SIGTERM") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) - val session1 = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] - val session2 = dispatcher.createSession(None).asInstanceOf[StubWorkerSession] + val session1 = createStubSession() + val session2 = createStubSession() val worker1 = session1.workerProcess val worker2 = session2.workerProcess From be7f65f44991359787df1ba391f483e7c5f22953 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Thu, 23 Apr 2026 10:16:48 +0000 Subject: [PATCH 4/9] cleanup hook - currently per session, but open to idle pooling in future. use a map to manage workers --- .../core/direct/DirectWorkerDispatcher.scala | 63 +++++++++++------- .../core/direct/DirectWorkerProcess.scala | 49 +++++++++++--- .../core/direct/DirectWorkerSession.scala | 12 ++-- .../core/DirectWorkerDispatcherSuite.scala | 64 +++++++++++++++++++ 4 files changed, 148 insertions(+), 40 deletions(-) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 0993c0888b80a..5e0f111e76fa8 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -20,9 +20,9 @@ import java.io.{BufferedReader, File, FileInputStream, InputStreamReader} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.util.UUID -import java.util.concurrent.TimeUnit +import java.util.concurrent.{ConcurrentHashMap, TimeUnit} -import scala.collection.mutable.{ArrayBuffer, Queue => MQueue} +import scala.collection.mutable.{Queue => MQueue} import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal @@ -41,8 +41,9 @@ import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableR * ("direct" creation mode from the worker specification). * * On the first [[createSession]], the dispatcher ensures the environment is - * ready (verify / install) and registers the cleanup hook. Currently spawns - * a fresh worker per session; pooling/reuse is TODO. + * ready (verify / install) and registers the cleanup hook. Each session + * currently gets a fresh worker that is terminated when the session closes + * (the single-reference case of the future pooling policy). * * Subclasses implement [[createConnection]] and [[createSessionForWorker]] * to provide protocol-specific behavior (e.g., gRPC, raw sockets). @@ -99,8 +100,12 @@ abstract class DirectWorkerDispatcher( // it leaks memory in long-lived JVMs (the JDK retains the path string for // the process lifetime), and it only works on empty directories. private val socketDir = Files.createTempDirectory("spark-udf-worker") - private val workers = new ArrayBuffer[DirectWorkerProcess]() - private val workersLock = new Object + // Keyed by worker id so release is O(1) and lock-free. Iterating for + // shutdown is weakly consistent, which is fine: no new workers should be + // spawning once close() is running, and any in-flight releaseWorker that + // overlaps with close() is idempotent on both sides (remove-missing is a + // no-op and DirectWorkerProcess.close() is CAS-guarded). + private val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() @volatile private var environmentState: EnvironmentState = EnvironmentState.Pending private val environmentLock = new Object @@ -122,29 +127,40 @@ abstract class DirectWorkerDispatcher( "securityScope is not supported yet; pass None until pooling lands") ensureEnvironmentReady() val worker = spawnWorker() - workersLock.synchronized { workers += worker } + workers.put(worker.id, worker) worker.acquireSession() try { createSessionForWorker(worker) } catch { case e: InterruptedException => Thread.currentThread().interrupt() - cleanupFailedSession(worker) + // Drop the ref-count to 0, firing `releaseWorker` via the worker's + // callback to remove and tear down the worker. + worker.releaseSession() throw e case NonFatal(e) => - cleanupFailedSession(worker) + worker.releaseSession() throw e } } - private def cleanupFailedSession(worker: DirectWorkerProcess): Unit = { - worker.releaseSession() - workersLock.synchronized { workers -= worker } + /** + * Called from [[DirectWorkerProcess.releaseSession]] when the last + * active session on `worker` closes. Today this always terminates the + * worker; with pooling, this is where the decision to reuse vs. evict + * will live (idle-pool handoff, capacity limits, health checks). + * + * Safe to invoke after dispatcher [[close]] has already reaped this + * worker: the worker's own idempotent close guard turns the second + * teardown into a no-op. + */ + private def releaseWorker(worker: DirectWorkerProcess): Unit = { + workers.remove(worker.id) try { worker.close() } catch { - case NonFatal(closeEx) => - logger.warn("Error closing worker after session creation failed", closeEx) + case NonFatal(e) => + logger.warn(s"Error closing worker ${worker.id}", e) } } @@ -153,17 +169,15 @@ abstract class DirectWorkerDispatcher( // N * gracefulTimeoutMs because each worker waits for SIGTERM to // complete before the next one is signalled. A small pool of // short-lived threads would bound shutdown to ~gracefulTimeoutMs. - workersLock.synchronized { - workers.foreach { w => - try { - w.close() - } catch { - case NonFatal(e) => - logger.warn(s"Error closing worker at ${w.socketPath}", e) - } + workers.values().iterator().asScala.foreach { w => + try { + w.close() + } catch { + case NonFatal(e) => + logger.warn(s"Error closing worker ${w.id}", e) } - workers.clear() } + workers.clear() try { val dir = socketDir.toFile if (dir.exists()) { @@ -315,7 +329,8 @@ abstract class DirectWorkerDispatcher( // remains valid for the child's file descriptor and is deleted in // DirectWorkerProcess.close(). new DirectWorkerProcess( - process, connection, socketPath, outputFile, gracefulTimeoutMs, logger) + workerId, process, connection, socketPath, outputFile, + gracefulTimeoutMs, logger, onLastSessionReleased = releaseWorker) } catch { case e: InterruptedException => Thread.currentThread().interrupt() diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index d1d40fd705d27..9c530471dcdad 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -43,6 +43,11 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} * Closing tears down everything: closes the connection, sends SIGTERM * (then SIGKILL), and removes the socket file and the process output log. * + * @param id stable identifier for this worker (a UUID + * generated by the dispatcher and passed to the + * worker binary as `--id`). Used as the map key + * in the dispatcher's worker registry and in + * diagnostic messages. * @param process the OS process handle * @param connection the transport connection to this worker * @param socketPath the UDS socket path used by this worker @@ -57,24 +62,39 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} * messages. Defaults to [[WorkerLogger.NoOp]]; * the dispatcher normally passes its own * logger so all messages share a category. + * @param onLastSessionReleased + * callback fired when the session ref-count + * transitions to 0. The dispatcher wires this to + * its reuse policy: today, terminate the worker; + * with pooling, return it to an idle pool or + * evict it. Runs on the thread that calls + * [[releaseSession]] (typically the session-close + * thread). May fire multiple times across a + * worker's lifetime -- once per 0-transition -- + * and runs without holding the count at 0: a + * concurrent `acquireSession` can push the count + * back up before the callback returns, so a + * pooling dispatcher must arbitrate reuse itself + * rather than assume "no active sessions" here. */ @Experimental class DirectWorkerProcess( + val id: String, val process: Process, val connection: WorkerConnection, val socketPath: String, val outputFile: Path, val gracefulTimeoutMs: Long, - protected val logger: WorkerLogger = WorkerLogger.NoOp) + protected val logger: WorkerLogger = WorkerLogger.NoOp, + private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) extends AutoCloseable { - // The active-session ref-count below is scaffolding for future connection - // pooling. With pooling, the dispatcher would keep an idle worker alive - // when its ref-count drops to 0 and hand it out to the next session. + // The ref-count drives worker teardown: when the last session releases, + // the dispatcher's `onLastSessionReleased` callback decides what to do + // with the worker. Today that callback terminates the worker (no reuse); + // with pooling it will return the worker to an idle pool subject to + // capacity / timeout. // TODO: Idle timeout tracking and concurrent session capacity. - // - // Until pooling lands, the dispatcher spawns one worker per session and - // tears it down at dispatcher close; the ref-count is informational only. private val activeSessionCount = new AtomicInteger(0) private val closed = new AtomicBoolean(false) @@ -86,8 +106,10 @@ class DirectWorkerProcess( def acquireSession(): Unit = activeSessionCount.incrementAndGet() /** - * Decrements the active session count. Logs a warning and resets to zero - * if the count goes negative, which would indicate an unbalanced + * Decrements the active session count. Fires + * [[onLastSessionReleased]] when the count transitions to zero so the + * dispatcher can apply its reuse policy. Logs a warning and resets to + * zero if the count goes negative, which would indicate an unbalanced * acquire/release (a bug we want to surface rather than paper over, * especially once pooling consumes this count). */ @@ -97,6 +119,15 @@ class DirectWorkerProcess( logger.warn( s"releaseSession called without a matching acquireSession (count=$c)") activeSessionCount.set(0) + } else if (c == 0) { + // Swallow callback errors so a misbehaving dispatcher cannot turn + // `session.close()` into an exception -- close must always succeed + // from the caller's perspective. The dispatcher callback has its + // own error handling for the underlying worker teardown. + try onLastSessionReleased(this) catch { + case NonFatal(e) => + logger.warn(s"onLastSessionReleased callback failed for worker $id", e) + } } } diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala index f92df2a907ba4..fabf146296229 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -48,13 +48,11 @@ abstract class DirectWorkerSession( /** The connection to the worker for this session. */ def connection: WorkerConnection = workerProcess.connection - // TODO: Introduce an idle timeout so the dispatcher tears down a worker - // whose ref-count has dropped to zero and stayed there for some interval. - // Without that, sessions release back to a worker that is never reaped - // until dispatcher.close(), which leaks one process + UDS + FDs per - // session over the dispatcher's lifetime. The timeout should live on - // the dispatcher (it owns the worker pool) and be pluggable from the - // worker spec. + // Releasing the session decrements the worker's ref-count; when it hits + // zero the worker's `onLastSessionReleased` callback (wired by the + // dispatcher) decides whether to terminate it or return it to a pool. + // Today it is always terminated, so closing the session is enough to + // reap the worker. override def close(): Unit = { if (released.compareAndSet(false, true)) { workerProcess.releaseSession() diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index e312d803281ed..0a6d528eb29b4 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -216,6 +216,70 @@ class DirectWorkerDispatcherSuite assert(!worker2.process.isAlive, "worker2 should be terminated") } + test("closing a session terminates its worker") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val session = createStubSession() + val worker = session.workerProcess + val socketFile = new File(worker.socketPath) + + assert(worker.process.isAlive, "worker should be alive before session close") + assert(socketFile.exists(), "socket file should exist before session close") + + session.close() + + // The session-close path is synchronous: SIGTERM is sent and the process + // is reaped before `close` returns. + assert(!worker.process.isAlive, + "worker process should be terminated when the session closes") + assert(!socketFile.exists(), + "socket file should be cleaned up when the session closes") + } + + test("concurrent session.close and dispatcher.close do not double-close the worker") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + + val sessions = (1 to 4).map(_ => createStubSession()) + val workers = sessions.map(_.workerProcess) + + val barrier = new java.util.concurrent.CyclicBarrier(sessions.size + 1) + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + val sessionThreads = sessions.map { s => + val t = new Thread(() => { + try { + barrier.await() + s.close() + } catch { + case t: Throwable => errors.add(t) + } + }) + t.start() + t + } + + val dispatcherThread = new Thread(() => { + try { + barrier.await() + dispatcher.close() + } catch { + case t: Throwable => errors.add(t) + } + }) + dispatcherThread.start() + + sessionThreads.foreach(_.join(30000)) + dispatcherThread.join(30000) + dispatcher = null + + assert(errors.isEmpty, + s"unexpected errors during concurrent close: ${errors.toArray.mkString(", ")}") + workers.foreach { w => + assert(!w.process.isAlive, + s"worker at ${w.socketPath} should be terminated after concurrent close") + } + } + // -- Error-path tests ------------------------------------------------------- test("worker is cleaned up when createSessionForWorker throws") { From 8733a594e2964ceb9ea05010e6a91b8666e760a8 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Fri, 24 Apr 2026 06:03:52 +0000 Subject: [PATCH 5/9] address PR review feedback Concurrency / correctness: - Introduce AtomicBoolean `closed` on DirectWorkerDispatcher; reject createSession after close; re-check after workers.put so any worker spawned concurrently with close() tears itself down via the ref-count callback instead of leaking. close() itself is idempotent via CAS. - Acquire the session ref-count BEFORE publishing to `workers`; a concurrent close() iterating the map can no longer tear down a worker whose caller is about to increment the count. - Introduce DirectWorkerDispatcher.destroyForciblyAndReap: SIGKILL plus a bounded waitFor so the kernel actually reaps the child. Used from cleanupFailedSpawn, waitForSocket, runCallable timeout, and both kill paths in DirectWorkerProcess.close(). Separate SIGKILL_REAP_TIMEOUT_MS (5s, distinct from gracefulTimeoutMs); logs a warning with a context tag if the reap window expires. - waitForSocket terminal branch reports the worker's exit code when it exited cleanly before creating the socket, instead of the ambiguous "did not create within Nms" message. - Mark `closed`, `workers`, `cleanupHook` as `private[this]` so the lifecycle state cannot be touched from other instances. Error model: - Introduce DirectWorkerException (extends RuntimeException) in the direct package. Replace all runtime-failure RuntimeException throws in DirectWorkerDispatcher so callers can catch the specific type instead of every RuntimeException. IllegalArgumentException and IllegalStateException continue to signal programming errors. Environment lifecycle: - Register the cleanup hook up front in the Pending branch whenever environment_cleanup is configured, independent of verify/install. Cleanup is user-defined and may tear down worker state, temp files, and other runtime artifacts beyond install output, so it should honor the user's configuration regardless of whether setup ran. Early registration also covers partial-install failure. - assert(Thread.holdsLock(environmentLock)) in registerEnvironmentCleanupHook so the docstring contract fails loud if a future caller forgets the lock. - TODO noting the per-dispatcher shutdown-hook retention cost in long-running drivers, pointing at a shared coordinator as the follow-up. WorkerSession contract: - Make init and process final on WorkerSession. Guard both with AtomicBoolean and delegate to abstract doInit / doProcess. Subclasses get "exactly once" / "at most once" for free and cannot bypass the contract. - Rewrite the InitMessage docstring: drop the PythonInputPartition / UpdateDelegationTokens cross-references and mark it as a placeholder that will be replaced by a proto message once the UDF protocol lands. Security: - Create the socket directory with POSIX 0700 via PosixFilePermissions.asFileAttribute so UDS sockets inside are not reachable by other local users. Non-POSIX filesystems fall back to File.setXxx; that path WARNs if any setXxx returns false so operators see the degraded mode instead of silently shipping a world-accessible directory. Polish: - Rename proto message WorkerConnection -> WorkerConnectionSpec to disambiguate from the core Scala WorkerConnection abstraction. Drop the WorkerConnectionProto import alias in the test suite and update the README usage snippet. - Fix typo in worker_spec.proto: engine-assinged -> engine-assigned. - DirectWorkerProcess Scaladoc: say one process per session today; the ref-count is scaffolding for future pooling, not live sharing. - readOutputTail uses FileChannel.position for an O(1) unambiguous seek instead of the FileInputStream.skip loop. - Extract throwClosed() and throwWorkerExitedBeforeSocket() helpers to deduplicate error-message construction. - Catch IOException on the outputFile delete in cleanupFailedSpawn so a cleanup failure does not mask the original spawn error. - destroyForciblyAndReap early-returns on InterruptedException to avoid a spurious "still alive after SIGKILL" warning when the full reap window was never actually waited. - Scaladoc on close() noting it does not drain in-flight createSession calls; brace the CAS early-return for readability. Tests: - createSession after close is rejected. - Socket directory is owner-only (0700) on POSIX; use `assume` instead of a silent skip on non-POSIX. - Socket directory is removed after dispatcher.close. - SIGKILL escalation: worker traps SIGTERM; assert the process is reaped after close() and that close waited at least gracefulTimeoutMs. Uses a 500ms graceful window to keep the test bounded. - initTimeoutMs: worker stays alive but never binds the socket; assert the "did not create socket" error and reap. - callableTimeoutMs: install sleeps past the timeout; assert "Callable timed out" and reap. - Concurrent createSession still installs exactly once -- races many threads through ensureEnvironmentReady. - Strengthen "distinct workers" test with a socketPath uniqueness check (object identity alone is insufficient). - TODO on StubWorkerSession noting cancel() needs real coverage once a concrete session impl lands. --- udf/worker/README.md | 4 +- .../spark/udf/worker/core/WorkerSession.scala | 52 +++- .../core/direct/DirectWorkerDispatcher.scala | 236 ++++++++++++++--- .../core/direct/DirectWorkerException.scala | 34 +++ .../core/direct/DirectWorkerProcess.scala | 16 +- .../core/DirectWorkerDispatcherSuite.scala | 237 +++++++++++++++++- .../proto/src/main/protobuf/worker_spec.proto | 12 +- 7 files changed, 536 insertions(+), 55 deletions(-) create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala diff --git a/udf/worker/README.md b/udf/worker/README.md index b3c634c9e0bf8..2d6ee3f09cf02 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -78,7 +78,7 @@ Workers are terminated via SIGTERM/SIGKILL when the dispatcher is closed. import org.apache.spark.udf.worker.{ DirectWorker, ProcessCallable, UDFProtoCommunicationPattern, UDFWorkerDataFormat, UDFWorkerProperties, UDFWorkerSpecification, - UnixDomainSocket, WorkerCapabilities, WorkerConnection, WorkerEnvironment} + UnixDomainSocket, WorkerCapabilities, WorkerConnectionSpec, WorkerEnvironment} import org.apache.spark.udf.worker.core._ // 1. Define a worker spec (direct creation mode). @@ -98,7 +98,7 @@ val spec = UDFWorkerSpecification.newBuilder() .setRunner(ProcessCallable.newBuilder() .addCommand("python").addCommand("-m").addCommand("my_udf_worker").build()) .setProperties(UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnection.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) .build()) .build()) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index 3724e37a6dc6f..7052fd16312e8 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.udf.worker.core +import java.util.concurrent.atomic.AtomicBoolean + import org.apache.spark.annotation.Experimental /** @@ -28,6 +30,13 @@ import org.apache.spark.annotation.Experimental * Spark context metadata, chaining information) without breaking existing * worker implementations. * + * Placeholder until the UDF protocol lands: this Scala case class will be + * replaced by a generated proto message once the wire protocol is + * introduced. Do not rely on case-class equality -- `Array[Byte]` fields + * compare by reference under the default case-class `equals`/`hashCode`, + * so do not use [[InitMessage]] as a hash-based collection key or + * compare instances by value without first wrapping the byte arrays. + * * @param functionPayload serialized function (e.g., pickled Python, JVM bytes) * @param inputSchema serialized input schema (e.g., Arrow schema bytes) * @param outputSchema serialized output schema (e.g., Arrow schema bytes) @@ -72,18 +81,34 @@ case class InitMessage( * - [[process]] must be called at most once per session. * - [[close]] must always be called (use try-finally). * - [[cancel]] may be called at any time to abort execution. + * + * The lifecycle is enforced by this class: [[init]] and [[process]] are + * `final` and delegate to [[doInit]] / [[doProcess]] after checking + * AtomicBoolean guards. Subclasses implement the protocol-specific work + * in [[doInit]] and [[doProcess]] and do not need to re-check the + * contract themselves. */ @Experimental abstract class WorkerSession extends AutoCloseable { + private val initialized = new AtomicBoolean(false) + private val processed = new AtomicBoolean(false) + /** * Initializes the UDF execution. Must be called exactly once before * [[process]]. * + * Throws `IllegalStateException` if called more than once. + * * @param message the initialization parameters including the serialized * function, input/output schemas, and configuration. */ - def init(message: InitMessage): Unit + final def init(message: InitMessage): Unit = { + if (!initialized.compareAndSet(false, true)) { + throw new IllegalStateException("init has already been called on this session") + } + doInit(message) + } /** * Processes input data through the worker and returns results. @@ -93,12 +118,33 @@ abstract class WorkerSession extends AutoCloseable { * iterator. The session sends a Finish signal to the worker when the input * iterator is exhausted. * - * Must be called at most once per session. + * Must be called after [[init]] and at most once per session. + * Throws `IllegalStateException` if called before [[init]] or more than once. * * @param input iterator of raw input data batches (e.g., Arrow IPC) * @return iterator of raw result data batches */ - def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] + final def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = { + if (!initialized.get()) { + throw new IllegalStateException("process called before init") + } + if (!processed.compareAndSet(false, true)) { + throw new IllegalStateException("process has already been called on this session") + } + doProcess(input) + } + + /** + * Subclass hook for [[init]]. Called exactly once, after the lifecycle + * guard has verified init has not already run. + */ + protected def doInit(message: InitMessage): Unit + + /** + * Subclass hook for [[process]]. Called at most once, after the + * lifecycle guard has verified init has run and process has not. + */ + protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] /** * Requests cancellation of the current UDF execution. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 5e0f111e76fa8..68ee680b86190 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -18,9 +18,11 @@ package org.apache.spark.udf.worker.core.direct import java.io.{BufferedReader, File, FileInputStream, InputStreamReader} import java.nio.charset.StandardCharsets -import java.nio.file.Files +import java.nio.file.{Files, Path} +import java.nio.file.attribute.PosixFilePermissions import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.{Queue => MQueue} import scala.jdk.CollectionConverters._ @@ -99,17 +101,27 @@ abstract class DirectWorkerDispatcher( // deliberately not registered: it is redundant with the explicit cleanup, // it leaks memory in long-lived JVMs (the JDK retains the path string for // the process lifetime), and it only works on empty directories. - private val socketDir = Files.createTempDirectory("spark-udf-worker") + // + // Created with POSIX 0700 so UDS sockets inside are not reachable by other + // local users. On non-POSIX filesystems the JDK rejects the attribute with + // UnsupportedOperationException; fall back to owner-only permissions via + // the File API, which is best-effort but still narrows the mode. + private val socketDir: Path = createPrivateTempDirectory() // Keyed by worker id so release is O(1) and lock-free. Iterating for - // shutdown is weakly consistent, which is fine: no new workers should be - // spawning once close() is running, and any in-flight releaseWorker that - // overlaps with close() is idempotent on both sides (remove-missing is a - // no-op and DirectWorkerProcess.close() is CAS-guarded). - private val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() + // shutdown is weakly consistent, which is fine: any createSession racing + // with close() is caught by the `closed` flag checks in createSession, and + // any in-flight releaseWorker that overlaps with close() is idempotent on + // both sides (remove-missing is a no-op, DirectWorkerProcess.close() is + // CAS-guarded). + private[this] val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() + // Flips to true on close(); createSession rejects afterwards and any + // already-spawned worker caught between `workers.put` and the post-publish + // check tears itself down instead of leaking. + private[this] val closed = new AtomicBoolean(false) @volatile private var environmentState: EnvironmentState = EnvironmentState.Pending private val environmentLock = new Object - private var cleanupHook: Option[Thread] = None + private[this] var cleanupHook: Option[Thread] = None /** Creates a protocol-specific connection to a worker at the given socket path. */ protected def createConnection(socketPath: String): WorkerConnection @@ -125,10 +137,21 @@ abstract class DirectWorkerDispatcher( // dispatcher actually honors the scope. require(securityScope.isEmpty, "securityScope is not supported yet; pass None until pooling lands") + if (closed.get()) throwClosed() ensureEnvironmentReady() val worker = spawnWorker() - workers.put(worker.id, worker) + // Acquire the session ref-count BEFORE publishing to `workers`: otherwise + // close() could iterate `workers`, tear this worker down, and leave the + // subsequent acquireSession handing the caller a dead worker. worker.acquireSession() + workers.put(worker.id, worker) + // Re-check after publishing. close() may have iterated `workers` before + // our put (orphan case) or after (already closed here). Either way, + // releasing our ref-count fires the callback that removes and closes. + if (closed.get()) { + worker.releaseSession() + throwClosed() + } try { createSessionForWorker(worker) } catch { @@ -164,7 +187,29 @@ abstract class DirectWorkerDispatcher( } } + private def throwClosed(): Nothing = + throw new IllegalStateException("Dispatcher is closed") + + /** + * Shuts down the dispatcher: terminates tracked workers, removes the + * socket directory, and runs environment cleanup. Idempotent via CAS -- + * only the first caller performs teardown; subsequent calls are no-ops. + * + * Does not block in-flight `createSession` calls. A createSession caller + * racing past the post-publish `closed` check will still produce a + * spawned worker; that worker tears itself down via the ref-count + * callback, but the subprocess teardown may outlive this `close()`. + * Callers needing a fully quiescent state must externally synchronize + * with their `createSession` callers. + */ override def close(): Unit = { + // Flip `closed` first so that any concurrent createSession either bails + // on its fast-path check or cleans up its own worker via the post-publish + // check. Only the first caller performs the teardown; subsequent calls + // are no-ops. + if (!closed.compareAndSet(false, true)) { + return + } // TODO: Close workers in parallel. Worst-case shutdown today is // N * gracefulTimeoutMs because each worker waits for SIGTERM to // complete before the next one is signalled. A small pool of @@ -205,9 +250,18 @@ abstract class DirectWorkerDispatcher( case EnvironmentState.Ready | EnvironmentState.CleanedUp => // Already set up (or torn down); nothing to do. case EnvironmentState.Failed(msg) => - throw new RuntimeException(s"Environment setup previously failed: $msg") + throw new DirectWorkerException(s"Environment setup previously failed: $msg") case EnvironmentState.Pending => val env = workerSpec.getEnvironment + // Register the cleanup hook up front. Cleanup is user-defined + // and may tear down more than install artifacts (worker state, + // temp files, etc.), so we honor it whenever it is configured, + // independent of whether verify/install run. Registering before + // any callable runs also ensures a partially-successful install + // (e.g., half-copied files) still gets cleaned up at JVM + // shutdown if the caller never reaches dispatcher.close(). The + // hook is a no-op when environment_cleanup is not configured. + registerEnvironmentCleanupHook() val verified = env.hasEnvironmentVerification && runCallable(env.getEnvironmentVerification).exitCode == 0 if (!verified && env.hasInstallation) { @@ -215,16 +269,24 @@ abstract class DirectWorkerDispatcher( if (result.exitCode != 0) { val detail = s"exit code ${result.exitCode}\n${result.outputTail}" environmentState = EnvironmentState.Failed(detail) - throw new RuntimeException( + throw new DirectWorkerException( s"Environment installation failed with $detail") } } - registerEnvironmentCleanupHook() environmentState = EnvironmentState.Ready } } } + // TODO: Share a single JVM shutdown hook across all dispatchers in the + // process. Today each dispatcher (when environment_cleanup is set) + // registers its own Thread, which the JVM retains until shutdown. + // In a long-running driver that creates many dispatchers without + // explicit close() (e.g., failed-task paths), this keeps per-dispatcher + // memory live for the process lifetime. A shared cleanup coordinator + // draining a ConcurrentLinkedQueue would collapse the N hooks + // into one. + /** * Registers the JVM shutdown hook that runs the cleanup callable. * @@ -233,6 +295,8 @@ abstract class DirectWorkerDispatcher( * called from `ensureEnvironmentReady`, which already owns the lock. */ private def registerEnvironmentCleanupHook(): Unit = { + assert(Thread.holdsLock(environmentLock), + "registerEnvironmentCleanupHook must be called while holding environmentLock") if (cleanupHook.isDefined) return if (workerSpec.getEnvironment.hasEnvironmentCleanup) { val hook = new Thread(() => runEnvironmentCleanup(), "udf-env-cleanup") @@ -294,9 +358,10 @@ abstract class DirectWorkerDispatcher( cmd, callable.getEnvironmentVariablesMap.asScala.toMap, outputFile.toFile) val timeoutMs = callableTimeoutMs if (!process.waitFor(timeoutMs, TimeUnit.MILLISECONDS)) { - process.destroyForcibly() + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"callable timeout: ${cmd.head}") val tail = readOutputTail(outputFile.toFile) - throw new RuntimeException( + throw new DirectWorkerException( s"Callable timed out after ${timeoutMs}ms: " + s"${cmd.mkString(" ")}\n$tail") } @@ -345,15 +410,57 @@ abstract class DirectWorkerDispatcher( private def cleanupFailedSpawn( process: Process, socketPath: String, - outputFile: java.nio.file.Path): Unit = { - if (process.isAlive) process.destroyForcibly() + outputFile: Path): Unit = { + DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, "failed spawn") // If the worker (or createConnection) had already created the socket // file, remove it so it doesn't linger until dispatcher.close(). try Files.deleteIfExists(new File(socketPath).toPath) catch { case NonFatal(cleanupEx) => logger.debug(s"Failed to clean up socket file $socketPath", cleanupEx) } - Files.deleteIfExists(outputFile) + // Swallow IOException here so we don't replace the original spawn + // failure with a cleanup failure. + try Files.deleteIfExists(outputFile) catch { + case NonFatal(cleanupEx) => + logger.debug(s"Failed to clean up worker output file $outputFile", cleanupEx) + } + } + + /** + * Creates a temp directory with owner-only permissions (0700 on POSIX). + * Falls back to a best-effort `File.setXxx` on non-POSIX filesystems + * that cannot honor the attribute. The fallback is racy and weaker + * than the POSIX path; the logger call surfaces when the platform + * refuses the `setXxx` calls so operators see the degraded mode. + */ + private def createPrivateTempDirectory(): Path = { + val attr = PosixFilePermissions.asFileAttribute( + PosixFilePermissions.fromString("rwx------")) + try { + Files.createTempDirectory("spark-udf-worker", attr) + } catch { + case _: UnsupportedOperationException => + // Non-POSIX filesystem. The dir exists with default perms between + // `createTempDirectory` and the `setXxx` calls below, so this + // fallback is TOCTOU-racy by nature. + val dir = Files.createTempDirectory("spark-udf-worker") + val f = dir.toFile + // Strip group/other access, then restore owner rwx. `&` (non-short- + // circuiting) so every call is attempted before we decide whether + // to warn; any `false` return means the platform silently refused + // the change and the directory is less private than advertised. + val applied = + f.setReadable(false, false) & f.setWritable(false, false) & + f.setExecutable(false, false) & f.setReadable(true, true) & + f.setWritable(true, true) & f.setExecutable(true, true) + if (!applied) { + logger.warn( + s"Could not fully restrict permissions on $dir; socket " + + s"directory may be accessible to other local users on this " + + s"filesystem") + } + dir + } } /** @@ -382,23 +489,36 @@ abstract class DirectWorkerDispatcher( val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) var attempts = 0 while (!file.exists() && attempts < maxAttempts) { - if (!process.isAlive) { - val tail = readOutputTail(outputFile) - throw new RuntimeException( - s"Worker exited with code ${process.exitValue()} " + - s"before creating socket at $socketPath\n$tail") - } + if (!process.isAlive) throwWorkerExitedBeforeSocket(process, socketPath, outputFile) Thread.sleep(SOCKET_POLL_INTERVAL_MS) attempts += 1 } if (!file.exists()) { - val tail = readOutputTail(outputFile) - if (process.isAlive) process.destroyForcibly() - throw new RuntimeException( - s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") + if (process.isAlive) { + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"init timeout $socketPath") + val tail = readOutputTail(outputFile) + throw new DirectWorkerException( + s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") + } else { + // The worker exited between the last file.exists() poll and here + // without creating the socket -- report the exit code rather than + // an ambiguous "did not create" message. + throwWorkerExitedBeforeSocket(process, socketPath, outputFile) + } } } + private def throwWorkerExitedBeforeSocket( + process: Process, + socketPath: String, + outputFile: File): Nothing = { + val tail = readOutputTail(outputFile) + throw new DirectWorkerException( + s"Worker exited with code ${process.exitValue()} " + + s"before creating socket at $socketPath\n$tail") + } + // Reads at most the final MAX_OUTPUT_SCAN_BYTES of `file` and returns the // last PROCESS_OUTPUT_TAIL_LINES lines via a fixed-size ring buffer, so a // runaway worker that writes gigabytes of output does not OOM the caller @@ -409,11 +529,10 @@ abstract class DirectWorkerDispatcher( val startPos = math.max(0L, fileLen - MAX_OUTPUT_SCAN_BYTES) val fis = new FileInputStream(file) try { - var remaining = startPos - while (remaining > 0) { - val n = fis.skip(remaining) - if (n <= 0) remaining = 0 else remaining -= n - } + // FileChannel.position is O(1) and unambiguously seeks, unlike + // FileInputStream.skip which is allowed to return 0 even when bytes + // remain. + if (startPos > 0) fis.getChannel.position(startPos) val reader = new BufferedReader( new InputStreamReader(fis, StandardCharsets.UTF_8)) // If we started mid-line, the first line is partial -- discard it so @@ -477,6 +596,59 @@ private[direct] object DirectWorkerDispatcher { // producing gigabytes of output cannot OOM the caller during error // reporting. The tail is still limited to PROCESS_OUTPUT_TAIL_LINES. private[direct] val MAX_OUTPUT_SCAN_BYTES = 1024L * 1024L // 1 MiB + // Bound on how long we wait for the kernel to reap a SIGKILL'd child. + // Distinct from the user-configurable gracefulTimeoutMs: SIGKILL is + // unblockable, so the only delay is kernel latency (milliseconds in the + // common case). Five seconds is generous; if we exceed it the child is + // likely stuck in uninterruptible I/O (D-state) and further waiting + // won't help. + private[direct] val SIGKILL_REAP_TIMEOUT_MS = 5000L + + /** + * Sends SIGKILL to `process` and waits up to [[SIGKILL_REAP_TIMEOUT_MS]] + * for the kernel to reap it. + * + * `destroyForcibly()` returns before the child has been reaped from the + * process table. Without a bounded `waitFor` the child lingers as a + * zombie until the JVM itself exits; in long-lived drivers with repeated + * failed spawns or session teardowns, the zombies accumulate. Callers + * should use this helper instead of calling `destroyForcibly()` directly. + * + * Behaviour: + * - No-op if the process is already dead. + * - If the reap times out, logs a warning and returns -- we do not + * block forever on a wedged child. The process remains a zombie; + * this is a real (but rare) operational condition the warning + * surfaces. + * - If the current thread is interrupted while waiting, re-raises + * the interrupt and returns without logging a zombie warning -- + * we have not actually waited the full reap window. + * + * @param context short human-readable tag included in the warn log so + * operators can correlate a wedged child with its source + * (e.g. worker id, or the callable that was killed). + */ + private[direct] def destroyForciblyAndReap( + process: Process, + logger: WorkerLogger, + context: String = ""): Unit = { + if (!process.isAlive) return + process.destroyForcibly() + val reaped = try { + process.waitFor(SIGKILL_REAP_TIMEOUT_MS, TimeUnit.MILLISECONDS) + } catch { + case _: InterruptedException => + Thread.currentThread().interrupt() + return + } + if (!reaped && process.isAlive) { + val suffix = if (context.nonEmpty) s" [$context]" else "" + logger.warn( + s"Process ${process.pid()}$suffix still alive ${SIGKILL_REAP_TIMEOUT_MS}ms " + + s"after SIGKILL; leaving behind as zombie " + + s"(likely stuck in uninterruptible kernel state)") + } + } /** Result of running a [[ProcessCallable]]. */ private[core] case class CallableResult(exitCode: Int, outputTail: String) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala new file mode 100644 index 0000000000000..3f2737571a56a --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * Thrown by [[DirectWorkerDispatcher]] for runtime failures: worker + * spawn problems, environment setup or cleanup failures, callable + * timeouts, and socket-establishment timeouts. + * + * Distinguished from `IllegalArgumentException` (bad spec) and + * `IllegalStateException` (using a closed dispatcher), which indicate + * programming errors. Catching this type lets callers handle runtime + * failures specifically without catching every `RuntimeException`. + */ +@Experimental +class DirectWorkerException(message: String, cause: Throwable = null) + extends RuntimeException(message, cause) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index 9c530471dcdad..3c196650ebcbe 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -36,9 +36,10 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} * - A '''[[WorkerConnection]]''' (the transport channel to that process). * - A '''socket path''' (a UDS socket file) that both sides use. * - * Multiple [[DirectWorkerSession]]s may share the same process when the - * worker supports concurrent UDFs. The [[acquireSession]]/[[releaseSession]] - * ref-count tracks how many sessions are active. + * The dispatcher currently creates one process per session and tears it + * down when the session closes; the [[acquireSession]]/[[releaseSession]] + * ref-count is scaffolding for future pooling, where a single process + * will back multiple concurrent sessions. * * Closing tears down everything: closes the connection, sends SIGTERM * (then SIGKILL), and removes the socket file and the process output log. @@ -155,11 +156,16 @@ class DirectWorkerProcess( process.destroy() // SIGTERM try { if (!process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS)) { - process.destroyForcibly() // SIGKILL + // Graceful window expired; escalate to SIGKILL. destroyForciblyAndReap + // waits for the kernel to actually reap the child so we don't leak + // a zombie here. + DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, s"worker $id") } } catch { case _: InterruptedException => - process.destroyForcibly() + // Interrupted mid-wait: kill forcibly so we don't leave the child + // behind, then re-raise the interrupt on the caller. + DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, s"worker $id") Thread.currentThread().interrupt() } } diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 0a6d528eb29b4..09735e9395fde 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.udf.worker.core import java.io.File -import java.nio.file.Files +import java.nio.file.{Files, Path} +import java.nio.file.attribute.PosixFileAttributeView import scala.jdk.CollectionConverters._ @@ -27,8 +28,8 @@ import org.scalatest.funsuite.AnyFunSuite import org.apache.spark.udf.worker.{ DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, - UDFWorkerSpecification, UnixDomainSocket, - WorkerConnection => WorkerConnectionProto, WorkerEnvironment} + UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, + WorkerEnvironment} import org.apache.spark.udf.worker.core.direct.{DirectWorkerDispatcher, DirectWorkerProcess, DirectWorkerSession} @@ -47,13 +48,21 @@ class SocketFileConnection(socketPath: String) extends WorkerConnection { /** * A stub [[DirectWorkerSession]] for process-lifecycle tests that don't * need actual data transmission. + * + * TODO: [[cancel]] is a no-op here. Once a concrete [[DirectWorkerSession]] + * with real data-plane wiring lands, add tests exercising cancel() in + * particular: cancel from a different thread than process(), cancel + * after process() has returned, and cancel before init (should be a + * no-op). Tracking the thread-safety contract in the docstring on + * [[org.apache.spark.udf.worker.core.WorkerSession.cancel]]. */ class StubWorkerSession( workerProcess: DirectWorkerProcess) extends DirectWorkerSession(workerProcess) { - override def init(message: InitMessage): Unit = {} + override protected def doInit(message: InitMessage): Unit = {} - override def process(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = + override protected def doProcess( + input: Iterator[Array[Byte]]): Iterator[Array[Byte]] = Iterator.empty override def cancel(): Unit = {} @@ -105,7 +114,7 @@ class DirectWorkerDispatcherSuite .build() private def udsProperties: UDFWorkerProperties = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionProto.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance) .build()) .build() @@ -194,6 +203,13 @@ class DirectWorkerDispatcherSuite val workerObjects = sessionList.map(_.workerProcess) assert(workerObjects.distinct.length == threads, "each session should have its own DirectWorkerProcess") + // Object-identity is not sufficient on its own: a future regression + // that accidentally shared underlying transport resources could still + // hand out distinct DirectWorkerProcess wrappers pointing at the same + // socket. Verify socket paths are unique too. + val socketPaths = workerObjects.map(_.socketPath) + assert(socketPaths.distinct.length == threads, + s"each worker should have its own socket path, got $socketPaths") sessionList.foreach(_.close()) } @@ -216,6 +232,54 @@ class DirectWorkerDispatcherSuite assert(!worker2.process.isAlive, "worker2 should be terminated") } + test("close escalates to SIGKILL when worker ignores SIGTERM") { + // The worker traps SIGTERM so the graceful stop is ineffective; the + // dispatcher must escalate to SIGKILL via destroyForciblyAndReap. + // Using a short gracefulTimeoutMs (500ms) keeps the test bounded: + // max close time is gracefulTimeoutMs + SIGKILL_REAP_TIMEOUT_MS. + val sigtermIgnoringScript = + """ + |#!/bin/bash + |SOCKET_PATH="" + |while [[ $# -gt 0 ]]; do + | case "$1" in + | --connection) SOCKET_PATH="$2"; shift 2 ;; + | *) shift ;; + | esac + |done + |touch "$SOCKET_PATH" + |trap '' SIGTERM + |while true; do sleep 1; done + """.stripMargin.trim + val runner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c").addCommand(sigtermIgnoringScript).addCommand("--") + .build() + val shortGracefulProps = UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setGracefulTerminationTimeoutMs(500) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(runner).setProperties(shortGracefulProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val session = createStubSession() + val worker = session.workerProcess + assert(worker.process.isAlive, "worker should be alive before close") + + val closeStart = System.nanoTime() + session.close() + val closeElapsedMs = (System.nanoTime() - closeStart) / 1000000L + + assert(!worker.process.isAlive, + s"worker should have been SIGKILLed after ignoring SIGTERM (took ${closeElapsedMs}ms)") + assert(closeElapsedMs >= 500L, + s"close should have waited for gracefulTimeoutMs before escalating, " + + s"took ${closeElapsedMs}ms") + } + test("closing a session terminates its worker") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) @@ -280,6 +344,53 @@ class DirectWorkerDispatcherSuite } } + test("createSession after close is rejected") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + dispatcher.close() + + val ex = intercept[IllegalStateException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("closed"), + s"expected dispatcher-closed error, got: ${ex.getMessage}") + dispatcher = null + } + + test("socket directory is owner-only (0700) on POSIX") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + // Drive one createSession so a worker (and therefore the socket dir) is + // observable via session.workerProcess.socketPath. + val session = createStubSession() + val socketDir: Path = new File(session.workerProcess.socketPath).toPath.getParent + session.close() + + val view = Files.getFileAttributeView(socketDir, classOf[PosixFileAttributeView]) + // Skip explicitly on non-POSIX filesystems rather than silently pass, + // so a CI environment without POSIX attributes is visible in the + // test report instead of giving false confidence. + assume(view != null, s"POSIX file attributes required to check $socketDir") + val perms = view.readAttributes().permissions().asScala.toSet + val expected = java.nio.file.attribute.PosixFilePermissions + .fromString("rwx------").asScala.toSet + assert(perms == expected, + s"socket directory $socketDir should be 0700, got ${perms.mkString(",")}") + } + + test("socket directory is removed after dispatcher.close") { + dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) + val session = createStubSession() + val socketDir = new File(session.workerProcess.socketPath).toPath.getParent.toFile + assert(socketDir.exists(), + s"socket directory $socketDir should exist while a session is open") + session.close() + + dispatcher.close() + dispatcher = null + + assert(!socketDir.exists(), + s"socket directory $socketDir should be removed after dispatcher.close") + } + // -- Error-path tests ------------------------------------------------------- test("worker is cleaned up when createSessionForWorker throws") { @@ -324,7 +435,7 @@ class DirectWorkerDispatcherSuite test("DirectWorker with non-UDS transport is rejected") { val tcpProperties = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionProto.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() .setTcp(LocalTcpConnection.getDefaultInstance).build()) .build() val badSpec = UDFWorkerSpecification.newBuilder() @@ -340,9 +451,9 @@ class DirectWorkerDispatcherSuite test("DirectWorker with multiple connections is rejected") { val twoConnections = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionProto.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) - .addConnections(WorkerConnectionProto.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) .build() val badSpec = UDFWorkerSpecification.newBuilder() @@ -407,6 +518,34 @@ class DirectWorkerDispatcherSuite s"expected process output in error, got: ${ex.getMessage}") } + test("spawnWorker times out when worker stays alive but never creates socket") { + // Distinct from the "process exits immediately" case: here the worker + // process is healthy but simply doesn't bind the socket, so the + // dispatcher must time out and SIGKILL-reap it rather than wait forever. + val hangingRunner = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("while true; do sleep 1; done").addCommand("--") + .build() + val shortInitProps = UDFWorkerProperties.newBuilder() + .addConnections(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setInitializationTimeoutMs(500) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(hangingRunner).setProperties(shortInitProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val ex = intercept[RuntimeException] { + dispatcher.createSession(None) + } + assert(ex.getMessage.contains("did not create socket"), + s"expected init-timeout error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("500ms"), + s"expected timeout value in error, got: ${ex.getMessage}") + } + // -- Environment lifecycle tests ------------------------------------------- test("skips installation when verification succeeds") { @@ -486,6 +625,36 @@ class DirectWorkerDispatcherSuite s"expected process output in error, got: ${ex.getMessage}") } + test("installation that exceeds callableTimeoutMs is killed and reported") { + // Installation sleeps longer than callableTimeoutMs; the dispatcher + // must SIGKILL-reap it and surface a "Callable timed out" error + // rather than hang the caller. + val slowInstall = ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand("sleep 30").build() + val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build() + val shortTimeoutDispatcher = + new DirectWorkerDispatcher(specWithEnv(env = env)) { + override protected def callableTimeoutMs: Long = 500L + override protected def createConnection(socketPath: String): WorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } + try { + val ex = intercept[RuntimeException] { + shortTimeoutDispatcher.createSession(None) + } + assert(ex.getMessage.contains("Callable timed out"), + s"expected callable-timeout error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("500ms"), + s"expected timeout value in error, got: ${ex.getMessage}") + } finally { + shortTimeoutDispatcher.close() + } + } + test("environment setup runs only once across multiple sessions") { val counterFile = Files.createTempFile("env-counter", ".txt").toFile counterFile.delete() @@ -507,6 +676,56 @@ class DirectWorkerDispatcherSuite counterFile.delete() } + test("concurrent createSession still installs exactly once") { + // The sequential single-install test above cannot catch a missing + // lock around ensureEnvironmentReady. Race many createSession calls + // with an install script that takes long enough for the threads to + // queue on environmentLock, then verify it still ran exactly once. + val counterFile = Files.createTempFile("env-concurrent-install", ".txt").toFile + counterFile.delete() + + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand( + s"sleep 0.2; echo invoked >> ${counterFile.getAbsolutePath}").build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(specWithEnv(env = env)) + + val threads = 4 + val startGate = new java.util.concurrent.CountDownLatch(1) + val doneGate = new java.util.concurrent.CountDownLatch(threads) + val sessions = new java.util.concurrent.ConcurrentLinkedQueue[WorkerSession]() + val errors = new java.util.concurrent.ConcurrentLinkedQueue[Throwable]() + + (1 to threads).foreach { _ => + new Thread(() => { + try { + startGate.await() + sessions.add(dispatcher.createSession(None)) + } catch { + case t: Throwable => errors.add(t) + } finally { + doneGate.countDown() + } + }).start() + } + startGate.countDown() + assert(doneGate.await(30, java.util.concurrent.TimeUnit.SECONDS), + "createSession threads did not finish in time") + assert(errors.isEmpty, + s"unexpected errors during concurrent createSession: ${errors.toArray.mkString(", ")}") + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run exactly once under concurrent createSession, " + + s"but ran ${lines.size} time(s)") + + sessions.asScala.foreach(_.close()) + counterFile.delete() + } + test("failed environment setup is not retried on subsequent createSession") { val counterFile = Files.createTempFile("env-failed-counter", ".txt").toFile counterFile.delete() diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto index a6c1c910ec443..29718f792b43e 100644 --- a/udf/worker/proto/src/main/protobuf/worker_spec.proto +++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto @@ -202,7 +202,7 @@ message UDFWorkerProperties { // // On [[DirectWorker]] creation, connection information // is passed to the callable as a string parameter. - // The string format depends on the [[WorkerConnection]]: + // The string format depends on the [[WorkerConnectionSpec]]: // // For example, when using TCP, the callable argument will be: // --connection PORT @@ -214,10 +214,14 @@ message UDFWorkerProperties { // (Required) Exactly one entry today; the field is repeated to allow // additional connections (e.g., data + control) to be added without a // schema-breaking migration. - repeated WorkerConnection connections = 3; + repeated WorkerConnectionSpec connections = 3; } -message WorkerConnection { +// Describes one connection (transport endpoint) that a [[DirectWorker]] +// exposes. This is a configuration message -- the live transport object +// used by the engine at runtime is the Scala abstraction +// `org.apache.spark.udf.worker.core.WorkerConnection`. +message WorkerConnectionSpec { // (Required) oneof transport { UnixDomainSocket unix_domain_socket = 1; @@ -284,7 +288,7 @@ message ProcessCallable { // // --connection // The value of the connection argument is a string with - // engine-assinged connection parameters. See [[UDFWorkerProperties]] + // engine-assigned connection parameters. See [[UDFWorkerProperties]] // for details. // // (Optional) From 8384d0cc17023d5f81366c686ac61c1fec749159 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Fri, 24 Apr 2026 14:20:16 +0000 Subject: [PATCH 6/9] address comments. 1. use one connection for now. 2. apply a max cap for time out, 3. better cleanup logic. --- udf/worker/README.md | 2 +- .../udf/worker/core/WorkerConnection.scala | 6 + .../core/direct/DirectWorkerDispatcher.scala | 96 +++++++++------ .../core/direct/DirectWorkerProcess.scala | 115 ++++++++++++------ .../core/DirectWorkerDispatcherSuite.scala | 69 +++++++---- .../proto/src/main/protobuf/worker_spec.proto | 15 +-- 6 files changed, 195 insertions(+), 108 deletions(-) diff --git a/udf/worker/README.md b/udf/worker/README.md index 2d6ee3f09cf02..b843c430d0e04 100644 --- a/udf/worker/README.md +++ b/udf/worker/README.md @@ -98,7 +98,7 @@ val spec = UDFWorkerSpecification.newBuilder() .setRunner(ProcessCallable.newBuilder() .addCommand("python").addCommand("-m").addCommand("my_udf_worker").build()) .setProperties(UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) .build()) .build()) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala index 1f78c961c9f75..e1c4dd324ee67 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala @@ -28,6 +28,12 @@ import org.apache.spark.annotation.Experimental * process wrapper (e.g., [[direct.DirectWorkerProcess]]) and shared * across all [[WorkerSession]]s that use that process. * + * One connection, many sessions: a worker process exposes a single + * server-side endpoint (e.g., one UDS path, one TCP port), and multiple + * concurrent UDF executions share it. For gRPC transports the channel + * multiplexes independent streams per session; for a raw-socket transport + * the connection still represents the shared underlying channel. + * * Implementations wrap the concrete transport and expose only lifecycle * methods. Data transmission happens at the [[WorkerSession]] level, not * here -- this class is solely about whether the channel is open. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 68ee680b86190..7ba1beb5e4a3e 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -34,8 +34,8 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, WorkerLogger, WorkerSecurityScope, WorkerSession} import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS, - EnvironmentState, MAX_OUTPUT_SCAN_BYTES, PROCESS_OUTPUT_TAIL_LINES, - SOCKET_POLL_INTERVAL_MS} + ENGINE_MAX_TIMEOUT_MS, EnvironmentState, MAX_OUTPUT_SCAN_BYTES, + PROCESS_OUTPUT_TAIL_LINES, SOCKET_POLL_INTERVAL_MS} /** * :: Experimental :: @@ -79,22 +79,45 @@ abstract class DirectWorkerDispatcher( */ protected def callableTimeoutMs: Long = DEFAULT_CALLABLE_TIMEOUT_MS - private val initTimeoutMs: Long = { + // Per worker_spec.proto, the engine applies a maximum cap to each + // proto-provided worker timeout (initialization_timeout_ms, + // graceful_termination_timeout_ms) so a misbehaving spec cannot make the + // engine wait arbitrarily long during startup or shutdown. The cap is + // [[ENGINE_MAX_TIMEOUT_MS]] (fixed at 30s for now). The dispatcher-internal + // `callableTimeoutMs` is subclass-controlled, not user-controlled, and is + // not subject to this cap. + // Exposed at package visibility so the core-level test suite can assert + // the engine-side clamp applied (kept out of the public API). + private[core] val initTimeoutMs: Long = { val props = workerSpec.getDirect.getProperties - if (props.hasInitializationTimeoutMs && props.getInitializationTimeoutMs > 0) { + val raw = if (props.hasInitializationTimeoutMs && props.getInitializationTimeoutMs > 0) { props.getInitializationTimeoutMs.toLong } else { DEFAULT_INIT_TIMEOUT_MS } + clampTimeout("initialization_timeout_ms", raw) } private val gracefulTimeoutMs: Long = { val props = workerSpec.getDirect.getProperties - if (props.hasGracefulTerminationTimeoutMs && props.getGracefulTerminationTimeoutMs > 0) { + val raw = if (props.hasGracefulTerminationTimeoutMs && + props.getGracefulTerminationTimeoutMs > 0) { props.getGracefulTerminationTimeoutMs.toLong } else { DEFAULT_GRACEFUL_TIMEOUT_MS } + clampTimeout("graceful_termination_timeout_ms", raw) + } + + private def clampTimeout(field: String, raw: Long): Long = { + if (raw > ENGINE_MAX_TIMEOUT_MS) { + logger.warn( + s"Worker-provided $field=${raw}ms exceeds engine maximum " + + s"${ENGINE_MAX_TIMEOUT_MS}ms; using ${ENGINE_MAX_TIMEOUT_MS}ms instead") + ENGINE_MAX_TIMEOUT_MS + } else { + raw + } } // The socket directory is removed explicitly in close(). deleteOnExit is @@ -386,46 +409,31 @@ abstract class DirectWorkerDispatcher( val env = runner.getEnvironmentVariablesMap.asScala.toMap val outputFile = Files.createTempFile("udf-worker-", ".log") val process = launchProcess(cmd, env, outputFile.toFile) + // Wrap the raw resources into a closeable bundle immediately so any + // subsequent failure (waitForSocket, createConnection) can dispose them + // through the same path the happy-path teardown uses. + val artifacts = new WorkerArtifacts(process, socketPath, outputFile, logger) try { waitForSocket(socketPath, process, outputFile.toFile) val connection = createConnection(socketPath) - // Ownership of `outputFile` transfers to the DirectWorkerProcess: it - // remains valid for the child's file descriptor and is deleted in - // DirectWorkerProcess.close(). + // Ownership of `artifacts` transfers to the DirectWorkerProcess: they + // remain alive for the duration of the worker and are disposed via + // DirectWorkerProcess.close() -> artifacts.close(). new DirectWorkerProcess( - workerId, process, connection, socketPath, outputFile, - gracefulTimeoutMs, logger, onLastSessionReleased = releaseWorker) + workerId, artifacts, connection, gracefulTimeoutMs, logger, + onLastSessionReleased = releaseWorker) } catch { case e: InterruptedException => Thread.currentThread().interrupt() - cleanupFailedSpawn(process, socketPath, outputFile) + artifacts.close() throw e case NonFatal(e) => - cleanupFailedSpawn(process, socketPath, outputFile) + artifacts.close() throw e } } - private def cleanupFailedSpawn( - process: Process, - socketPath: String, - outputFile: Path): Unit = { - DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, "failed spawn") - // If the worker (or createConnection) had already created the socket - // file, remove it so it doesn't linger until dispatcher.close(). - try Files.deleteIfExists(new File(socketPath).toPath) catch { - case NonFatal(cleanupEx) => - logger.debug(s"Failed to clean up socket file $socketPath", cleanupEx) - } - // Swallow IOException here so we don't replace the original spawn - // failure with a cleanup failure. - try Files.deleteIfExists(outputFile) catch { - case NonFatal(cleanupEx) => - logger.debug(s"Failed to clean up worker output file $outputFile", cleanupEx) - } - } - /** * Creates a temp directory with owner-only permissions (0700 on POSIX). * Falls back to a best-effort `File.setXxx` on non-POSIX filesystems @@ -558,16 +566,13 @@ abstract class DirectWorkerDispatcher( // -- Spec validation ------------------------------------------------------- - // Multi-connection workers (e.g., a separate control channel) are a future - // extension; today the proto field is `repeated` but the engine requires - // exactly one. TCP transport is declared in the proto but not yet - // implemented; the engine currently only supports UDS. + // TCP transport is declared in the proto but not yet implemented; the + // engine currently only supports UDS. private def validateTransportSupport(): Unit = { val props = workerSpec.getDirect.getProperties - val n = props.getConnectionsCount - require(n == 1, - s"DirectWorker.properties.connections must have exactly one entry, got $n") - val conn = props.getConnections(0) + require(props.hasConnection, + "DirectWorker.properties.connection must be set") + val conn = props.getConnection require(conn.hasUnixDomainSocket, "DirectWorker currently only supports UNIX domain socket transport, " + s"got ${conn.getTransportCase}") @@ -591,6 +596,19 @@ private[direct] object DirectWorkerDispatcher { private[direct] val DEFAULT_INIT_TIMEOUT_MS = 10000L private[direct] val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L private[direct] val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L + // Engine-side cap for proto-provided worker timeouts + // (initialization_timeout_ms, graceful_termination_timeout_ms). Fixed at + // 30s for now, matching the example value in worker_spec.proto; + // intentionally not configurable until we have a concrete need. The + // defaults below must stay at or under this cap so the clamp path only + // triggers on user-provided values, never on dispatcher defaults. + private[direct] val ENGINE_MAX_TIMEOUT_MS = 30000L + require(DEFAULT_INIT_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, + s"DEFAULT_INIT_TIMEOUT_MS ($DEFAULT_INIT_TIMEOUT_MS) must not exceed " + + s"ENGINE_MAX_TIMEOUT_MS ($ENGINE_MAX_TIMEOUT_MS)") + require(DEFAULT_GRACEFUL_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, + s"DEFAULT_GRACEFUL_TIMEOUT_MS ($DEFAULT_GRACEFUL_TIMEOUT_MS) must not " + + s"exceed ENGINE_MAX_TIMEOUT_MS ($ENGINE_MAX_TIMEOUT_MS)") private[direct] val PROCESS_OUTPUT_TAIL_LINES = 50 // Cap the amount of log file scanned by readOutputTail so a runaway worker // producing gigabytes of output cannot OOM the caller during error diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index 3c196650ebcbe..470df8a6c4431 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -32,33 +32,33 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} * transport connection to it. * * A [[DirectWorkerProcess]] combines three things: - * - An OS '''process''' (the worker binary, started by the dispatcher). + * - An OS '''process''' (the worker binary, started by the dispatcher), + * along with its socket file and output log -- all owned jointly by a + * [[WorkerArtifacts]] bundle. * - A '''[[WorkerConnection]]''' (the transport channel to that process). - * - A '''socket path''' (a UDS socket file) that both sides use. + * - A session '''ref-count''' for future pooling. * * The dispatcher currently creates one process per session and tears it * down when the session closes; the [[acquireSession]]/[[releaseSession]] * ref-count is scaffolding for future pooling, where a single process * will back multiple concurrent sessions. * - * Closing tears down everything: closes the connection, sends SIGTERM - * (then SIGKILL), and removes the socket file and the process output log. + * Closing this wrapper closes the connection, attempts a graceful + * SIGTERM, then delegates the forced kill + file cleanup to + * [[WorkerArtifacts.close]]. * * @param id stable identifier for this worker (a UUID * generated by the dispatcher and passed to the * worker binary as `--id`). Used as the map key * in the dispatcher's worker registry and in * diagnostic messages. - * @param process the OS process handle - * @param connection the transport connection to this worker - * @param socketPath the UDS socket path used by this worker - * @param outputFile the merged stdout/stderr log for this worker. - * Kept open for the lifetime of the worker (so it - * remains a valid target for the child's file - * descriptor and so its contents can be inspected - * while the worker runs) and deleted in [[close]]. + * @param artifacts owns the OS process, socket file, and output + * log. The same bundle is used for spawn-failure + * teardown before this wrapper exists, so the + * dispose logic is centralised there. + * @param connection the transport connection to this worker. * @param gracefulTimeoutMs milliseconds to wait after SIGTERM before - * escalating to SIGKILL. + * escalating (via `artifacts.close`) to SIGKILL. * @param logger [[WorkerLogger]] used for process-level * messages. Defaults to [[WorkerLogger.NoOp]]; * the dispatcher normally passes its own @@ -81,10 +81,8 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} @Experimental class DirectWorkerProcess( val id: String, - val process: Process, + private[direct] val artifacts: WorkerArtifacts, val connection: WorkerConnection, - val socketPath: String, - val outputFile: Path, val gracefulTimeoutMs: Long, protected val logger: WorkerLogger = WorkerLogger.NoOp, private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) @@ -100,6 +98,15 @@ class DirectWorkerProcess( private val activeSessionCount = new AtomicInteger(0) private val closed = new AtomicBoolean(false) + /** The OS process handle for this worker. */ + def process: Process = artifacts.process + + /** The UDS socket path used by this worker. */ + def socketPath: String = artifacts.socketPath + + /** Path to the merged stdout/stderr log for this worker. */ + def outputFile: Path = artifacts.outputFile + /** Number of sessions currently using this worker. */ def activeSessions: Int = activeSessionCount.get() @@ -136,11 +143,17 @@ class DirectWorkerProcess( def isAlive: Boolean = process.isAlive && connection.isActive /** - * Shuts down the connection, then terminates the OS process. - * Sends SIGTERM first; escalates to SIGKILL after [[gracefulTimeoutMs]]. + * Shuts down the connection, attempts a graceful SIGTERM, then disposes + * the worker artifacts (forced kill + reap + file cleanup). + * * Idempotent: only the first call performs teardown; subsequent calls * are no-ops so the dispatcher's close path and the createSession error * path can both invoke close without double-releasing resources. + * + * If the graceful SIGTERM wait expires (or is interrupted), we fall + * through to [[WorkerArtifacts.close]], which issues SIGKILL and reaps. + * If the process has already exited voluntarily, artifacts.close()'s + * kill step is a no-op and we just clean up the files. */ override def close(): Unit = { if (!closed.compareAndSet(false, true)) return @@ -155,32 +168,66 @@ class DirectWorkerProcess( if (process.isAlive) { process.destroy() // SIGTERM try { - if (!process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS)) { - // Graceful window expired; escalate to SIGKILL. destroyForciblyAndReap - // waits for the kernel to actually reap the child so we don't leak - // a zombie here. - DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, s"worker $id") - } + // Whether the wait returns true (clean exit) or false (timeout), + // we fall through to artifacts.close(); the helper's SIGKILL step + // is a no-op on an already-dead process. + process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS) } catch { case _: InterruptedException => - // Interrupted mid-wait: kill forcibly so we don't leave the child - // behind, then re-raise the interrupt on the caller. - DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, s"worker $id") Thread.currentThread().interrupt() } } - try { - val f = new File(socketPath) - if (f.exists()) f.delete() - } catch { + artifacts.close() + } +} + +/** + * Closeable bundle whose [[close]] always does a forced kill (SIGKILL + + * reap) followed by socket-file and output-log cleanup. Graceful SIGTERM + * is the responsibility of higher layers that know the worker was actually + * running (see [[DirectWorkerProcess#close]]). + * + * Holds the per-worker OS-level resources that are born together at spawn + * time -- the child [[Process]], its UDS socket file, and the merged + * stdout/stderr log file -- so the happy-path teardown and the spawn- + * failure teardown (before a `DirectWorkerProcess` wrapper even exists) + * share one dispose implementation. Callers should not dispose the + * individual fields directly. + * + * On the spawn-failure path the worker may never have installed signal + * handlers; jumping straight to SIGKILL avoids paying a graceful-timeout + * window on a process that wasn't listening anyway. + */ +private[direct] final class WorkerArtifacts( + val process: Process, + val socketPath: String, + val outputFile: Path, + private[this] val logger: WorkerLogger) extends AutoCloseable { + + private[this] val closed = new AtomicBoolean(false) + + /** + * Idempotently tears down the three resources in a fixed order: + * 1. SIGKILL the process and wait for the kernel to reap it. + * 2. Delete the socket file if it was created. + * 3. Delete the output (stdout/stderr) log file. + * + * Each step is guarded so a failure in one does not prevent the next. + * Subsequent calls return immediately. + */ + override def close(): Unit = { + if (!closed.compareAndSet(false, true)) return + + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"worker artifacts $socketPath") + + try Files.deleteIfExists(new File(socketPath).toPath) catch { case NonFatal(e) => logger.warn(s"Error cleaning up socket file $socketPath", e) } - try { - Files.deleteIfExists(outputFile) - } catch { + try Files.deleteIfExists(outputFile) catch { case NonFatal(e) => logger.warn(s"Error cleaning up worker output file $outputFile", e) } diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 09735e9395fde..ea4a5baf21ef7 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -114,7 +114,7 @@ class DirectWorkerDispatcherSuite .build() private def udsProperties: UDFWorkerProperties = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance) .build()) .build() @@ -255,7 +255,7 @@ class DirectWorkerDispatcherSuite .addCommand("bash").addCommand("-c").addCommand(sigtermIgnoringScript).addCommand("--") .build() val shortGracefulProps = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) .setGracefulTerminationTimeoutMs(500) .build() @@ -344,6 +344,43 @@ class DirectWorkerDispatcherSuite } } + test("worker-provided graceful timeout is capped at the engine-side maximum") { + // The proto documents an engine-configurable maximum (fixed at 30s today). + // A 60s spec value should be clamped down. + val oversizedProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setGracefulTerminationTimeoutMs(60000) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(oversizedProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + val session = createStubSession() + assert(session.workerProcess.gracefulTimeoutMs == 30000L, + s"graceful timeout should be capped at 30000ms, " + + s"got ${session.workerProcess.gracefulTimeoutMs}") + session.close() + } + + test("worker-provided init timeout is capped at the engine-side maximum") { + val oversizedProps = UDFWorkerProperties.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() + .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) + .setInitializationTimeoutMs(60000) + .build() + val spec = UDFWorkerSpecification.newBuilder() + .setDirect(DirectWorker.newBuilder() + .setRunner(defaultRunner).setProperties(oversizedProps).build()) + .build() + dispatcher = new TestDirectWorkerDispatcher(spec) + + assert(dispatcher.initTimeoutMs == 30000L, + s"init timeout should be capped at 30000ms, got ${dispatcher.initTimeoutMs}") + } + test("createSession after close is rejected") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) dispatcher.close() @@ -422,20 +459,20 @@ class DirectWorkerDispatcherSuite } } - test("DirectWorker without connections is rejected") { + test("DirectWorker without a connection is rejected") { val badSpec = UDFWorkerSpecification.newBuilder() .setDirect(DirectWorker.newBuilder().setRunner(defaultRunner).build()) .build() val ex = intercept[IllegalArgumentException] { new TestDirectWorkerDispatcher(badSpec) } - assert(ex.getMessage.contains("exactly one entry"), - s"expected connections-count error, got: ${ex.getMessage}") + assert(ex.getMessage.contains("connection must be set"), + s"expected missing-connection error, got: ${ex.getMessage}") } test("DirectWorker with non-UDS transport is rejected") { val tcpProperties = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() .setTcp(LocalTcpConnection.getDefaultInstance).build()) .build() val badSpec = UDFWorkerSpecification.newBuilder() @@ -449,24 +486,6 @@ class DirectWorkerDispatcherSuite s"expected UDS-only error, got: ${ex.getMessage}") } - test("DirectWorker with multiple connections is rejected") { - val twoConnections = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() - .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) - .addConnections(WorkerConnectionSpec.newBuilder() - .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) - .build() - val badSpec = UDFWorkerSpecification.newBuilder() - .setDirect(DirectWorker.newBuilder() - .setRunner(defaultRunner).setProperties(twoConnections).build()) - .build() - val ex = intercept[IllegalArgumentException] { - new TestDirectWorkerDispatcher(badSpec) - } - assert(ex.getMessage.contains("exactly one entry"), - s"expected connections-count error, got: ${ex.getMessage}") - } - test("socket file is cleaned up when createConnection throws") { val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]() val failingDispatcher = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { @@ -527,7 +546,7 @@ class DirectWorkerDispatcherSuite .addCommand("while true; do sleep 1; done").addCommand("--") .build() val shortInitProps = UDFWorkerProperties.newBuilder() - .addConnections(WorkerConnectionSpec.newBuilder() + .setConnection(WorkerConnectionSpec.newBuilder() .setUnixDomainSocket(UnixDomainSocket.getDefaultInstance).build()) .setInitializationTimeoutMs(500) .build() diff --git a/udf/worker/proto/src/main/protobuf/worker_spec.proto b/udf/worker/proto/src/main/protobuf/worker_spec.proto index 29718f792b43e..83dac4f962e5f 100644 --- a/udf/worker/proto/src/main/protobuf/worker_spec.proto +++ b/udf/worker/proto/src/main/protobuf/worker_spec.proto @@ -194,11 +194,10 @@ message UDFWorkerProperties { // (Optional) optional int32 graceful_termination_timeout_ms = 2; - // The connections this [[DirectWorker]] supports. A single connection is - // sufficient to run multiple UDFs and (gRPC) services; multi-connection - // workers (e.g., a separate control channel for stateful streaming) are - // a future extension and are not yet supported by the engine -- today - // exactly one connection must be specified. + // A [[DirectWorker]] exposes one server-side connection endpoint (a + // UDS path or a TCP port) that all sessions on the worker share. + // Multi-connection workers (e.g., separate data and control channels) + // are not supported in this release. // // On [[DirectWorker]] creation, connection information // is passed to the callable as a string parameter. @@ -211,10 +210,8 @@ message UDFWorkerProperties { // // For the format of each specific transport type, see the comments below. // - // (Required) Exactly one entry today; the field is repeated to allow - // additional connections (e.g., data + control) to be added without a - // schema-breaking migration. - repeated WorkerConnectionSpec connections = 3; + // (Required) + WorkerConnectionSpec connection = 3; } // Describes one connection (transport endpoint) that a [[DirectWorker]] From 4d2ad8246974c63ebc7383ac160bf5b82accadaf Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Fri, 24 Apr 2026 14:37:38 +0000 Subject: [PATCH 7/9] trimmed documentation as some code self-explains --- .../udf/worker/core/WorkerConnection.scala | 14 +- .../spark/udf/worker/core/WorkerSession.scala | 33 +-- .../core/direct/DirectWorkerDispatcher.scala | 239 +++++------------- .../core/direct/DirectWorkerProcess.scala | 135 +++------- .../core/direct/DirectWorkerSession.scala | 5 - 5 files changed, 115 insertions(+), 311 deletions(-) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala index e1c4dd324ee67..82b2fff8df585 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerConnection.scala @@ -28,15 +28,13 @@ import org.apache.spark.annotation.Experimental * process wrapper (e.g., [[direct.DirectWorkerProcess]]) and shared * across all [[WorkerSession]]s that use that process. * - * One connection, many sessions: a worker process exposes a single - * server-side endpoint (e.g., one UDS path, one TCP port), and multiple - * concurrent UDF executions share it. For gRPC transports the channel - * multiplexes independent streams per session; for a raw-socket transport - * the connection still represents the shared underlying channel. + * One connection, many sessions: the worker exposes a single server-side + * endpoint that all sessions share. For gRPC, per-session work lives on + * multiplexed streams over this channel. * - * Implementations wrap the concrete transport and expose only lifecycle - * methods. Data transmission happens at the [[WorkerSession]] level, not - * here -- this class is solely about whether the channel is open. + * Implementations expose only lifecycle. Data transmission happens at + * the [[WorkerSession]] level -- this class is solely about whether the + * channel is open. * * '''Relationship to other classes (direct creation mode):''' * {{{ diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala index 7052fd16312e8..f4c4091688c94 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/WorkerSession.scala @@ -25,17 +25,11 @@ import org.apache.spark.annotation.Experimental * Carries all information needed to initialize a UDF execution on a worker. * * This message is passed to [[WorkerSession#init]] and contains the function - * definition, schemas, and any additional configuration. It is designed to be - * extended in future versions with new fields (e.g., UDF shape, data format, - * Spark context metadata, chaining information) without breaking existing - * worker implementations. + * definition, schemas, and any additional configuration. * - * Placeholder until the UDF protocol lands: this Scala case class will be - * replaced by a generated proto message once the wire protocol is - * introduced. Do not rely on case-class equality -- `Array[Byte]` fields - * compare by reference under the default case-class `equals`/`hashCode`, - * so do not use [[InitMessage]] as a hash-based collection key or - * compare instances by value without first wrapping the byte arrays. + * Placeholder: will be replaced by a generated proto message once the + * UDF wire protocol lands. Do not rely on case-class equality -- + * `Array[Byte]` fields compare by reference. * * @param functionPayload serialized function (e.g., pickled Python, JVM bytes) * @param inputSchema serialized input schema (e.g., Arrow schema bytes) @@ -82,11 +76,10 @@ case class InitMessage( * - [[close]] must always be called (use try-finally). * - [[cancel]] may be called at any time to abort execution. * - * The lifecycle is enforced by this class: [[init]] and [[process]] are - * `final` and delegate to [[doInit]] / [[doProcess]] after checking - * AtomicBoolean guards. Subclasses implement the protocol-specific work - * in [[doInit]] and [[doProcess]] and do not need to re-check the - * contract themselves. + * The lifecycle is enforced here: [[init]] and [[process]] are `final` + * and delegate to [[doInit]] / [[doProcess]] after AtomicBoolean guards. + * Subclasses implement the protocol-specific work and do not re-check + * the contract. */ @Experimental abstract class WorkerSession extends AutoCloseable { @@ -134,16 +127,10 @@ abstract class WorkerSession extends AutoCloseable { doProcess(input) } - /** - * Subclass hook for [[init]]. Called exactly once, after the lifecycle - * guard has verified init has not already run. - */ + /** Subclass hook for [[init]]. Called once, after the guard. */ protected def doInit(message: InitMessage): Unit - /** - * Subclass hook for [[process]]. Called at most once, after the - * lifecycle guard has verified init has run and process has not. - */ + /** Subclass hook for [[process]]. Called at most once, after the guard. */ protected def doProcess(input: Iterator[Array[Byte]]): Iterator[Array[Byte]] /** diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 7ba1beb5e4a3e..08f7af05850ef 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -79,15 +79,10 @@ abstract class DirectWorkerDispatcher( */ protected def callableTimeoutMs: Long = DEFAULT_CALLABLE_TIMEOUT_MS - // Per worker_spec.proto, the engine applies a maximum cap to each - // proto-provided worker timeout (initialization_timeout_ms, - // graceful_termination_timeout_ms) so a misbehaving spec cannot make the - // engine wait arbitrarily long during startup or shutdown. The cap is - // [[ENGINE_MAX_TIMEOUT_MS]] (fixed at 30s for now). The dispatcher-internal - // `callableTimeoutMs` is subclass-controlled, not user-controlled, and is - // not subject to this cap. - // Exposed at package visibility so the core-level test suite can assert - // the engine-side clamp applied (kept out of the public API). + // Proto-provided timeouts are clamped to ENGINE_MAX_TIMEOUT_MS. The + // dispatcher-internal callableTimeoutMs above is subclass-controlled and + // not subject to the cap. + // Package-private for test access. private[core] val initTimeoutMs: Long = { val props = workerSpec.getDirect.getProperties val raw = if (props.hasInitializationTimeoutMs && props.getInitializationTimeoutMs > 0) { @@ -120,26 +115,10 @@ abstract class DirectWorkerDispatcher( } } - // The socket directory is removed explicitly in close(). deleteOnExit is - // deliberately not registered: it is redundant with the explicit cleanup, - // it leaks memory in long-lived JVMs (the JDK retains the path string for - // the process lifetime), and it only works on empty directories. - // - // Created with POSIX 0700 so UDS sockets inside are not reachable by other - // local users. On non-POSIX filesystems the JDK rejects the attribute with - // UnsupportedOperationException; fall back to owner-only permissions via - // the File API, which is best-effort but still narrows the mode. + // Removed explicitly in close(). deleteOnExit is avoided because the JDK + // retains the path for the JVM lifetime, which leaks in long-lived drivers. private val socketDir: Path = createPrivateTempDirectory() - // Keyed by worker id so release is O(1) and lock-free. Iterating for - // shutdown is weakly consistent, which is fine: any createSession racing - // with close() is caught by the `closed` flag checks in createSession, and - // any in-flight releaseWorker that overlaps with close() is idempotent on - // both sides (remove-missing is a no-op, DirectWorkerProcess.close() is - // CAS-guarded). private[this] val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() - // Flips to true on close(); createSession rejects afterwards and any - // already-spawned worker caught between `workers.put` and the post-publish - // check tears itself down instead of leaking. private[this] val closed = new AtomicBoolean(false) @volatile private var environmentState: EnvironmentState = EnvironmentState.Pending @@ -154,23 +133,17 @@ abstract class DirectWorkerDispatcher( override def createSession( securityScope: Option[WorkerSecurityScope]): WorkerSession = { - // Pooling keyed by security scope is not yet implemented. Accepting a - // non-None scope here would silently create a one-off worker and give - // the caller a false expectation of isolation, so reject it until the - // dispatcher actually honors the scope. require(securityScope.isEmpty, "securityScope is not supported yet; pass None until pooling lands") if (closed.get()) throwClosed() ensureEnvironmentReady() val worker = spawnWorker() - // Acquire the session ref-count BEFORE publishing to `workers`: otherwise - // close() could iterate `workers`, tear this worker down, and leave the - // subsequent acquireSession handing the caller a dead worker. + // Acquire before publish: a concurrent close() iterating `workers` must + // not tear down this worker before we hand it to the caller. worker.acquireSession() workers.put(worker.id, worker) - // Re-check after publishing. close() may have iterated `workers` before - // our put (orphan case) or after (already closed here). Either way, - // releasing our ref-count fires the callback that removes and closes. + // Re-check for close() that ran concurrently. Releasing fires the + // ref-count callback, which removes and tears down the worker. if (closed.get()) { worker.releaseSession() throwClosed() @@ -180,8 +153,6 @@ abstract class DirectWorkerDispatcher( } catch { case e: InterruptedException => Thread.currentThread().interrupt() - // Drop the ref-count to 0, firing `releaseWorker` via the worker's - // callback to remove and tear down the worker. worker.releaseSession() throw e case NonFatal(e) => @@ -191,14 +162,10 @@ abstract class DirectWorkerDispatcher( } /** - * Called from [[DirectWorkerProcess.releaseSession]] when the last - * active session on `worker` closes. Today this always terminates the - * worker; with pooling, this is where the decision to reuse vs. evict - * will live (idle-pool handoff, capacity limits, health checks). - * - * Safe to invoke after dispatcher [[close]] has already reaped this - * worker: the worker's own idempotent close guard turns the second - * teardown into a no-op. + * Invoked when a worker's last session closes. Terminates the worker + * today; future pooling can reuse it here instead. Safe to call after + * dispatcher close -- the worker's own CAS-idempotent close makes a + * second teardown a no-op. */ private def releaseWorker(worker: DirectWorkerProcess): Unit = { workers.remove(worker.id) @@ -214,29 +181,18 @@ abstract class DirectWorkerDispatcher( throw new IllegalStateException("Dispatcher is closed") /** - * Shuts down the dispatcher: terminates tracked workers, removes the - * socket directory, and runs environment cleanup. Idempotent via CAS -- - * only the first caller performs teardown; subsequent calls are no-ops. - * - * Does not block in-flight `createSession` calls. A createSession caller - * racing past the post-publish `closed` check will still produce a - * spawned worker; that worker tears itself down via the ref-count - * callback, but the subprocess teardown may outlive this `close()`. - * Callers needing a fully quiescent state must externally synchronize - * with their `createSession` callers. + * Terminates tracked workers, removes the socket directory, and runs + * environment cleanup. Idempotent via CAS. Does not drain in-flight + * createSession calls -- a worker spawned racing with close tears + * itself down through the ref-count callback, which may outlive this + * method. */ override def close(): Unit = { - // Flip `closed` first so that any concurrent createSession either bails - // on its fast-path check or cleans up its own worker via the post-publish - // check. Only the first caller performs the teardown; subsequent calls - // are no-ops. if (!closed.compareAndSet(false, true)) { return } - // TODO: Close workers in parallel. Worst-case shutdown today is - // N * gracefulTimeoutMs because each worker waits for SIGTERM to - // complete before the next one is signalled. A small pool of - // short-lived threads would bound shutdown to ~gracefulTimeoutMs. + // TODO: close workers in parallel -- today shutdown is serialised at + // N * gracefulTimeoutMs worst case. workers.values().iterator().asScala.foreach { w => try { w.close() @@ -263,27 +219,18 @@ abstract class DirectWorkerDispatcher( // -- Environment lifecycle ------------------------------------------------- - // TODO: Handle permanently unrecoverable environment failures (e.g., wrong - // CPU architecture, unavailable system resources) differently from transient - // ones. Currently all failures are treated as permanent, but some callers - // may want to distinguish retriable vs. fatal failures. + // TODO: distinguish retriable vs permanent environment failures. private def ensureEnvironmentReady(): Unit = { environmentLock.synchronized { environmentState match { case EnvironmentState.Ready | EnvironmentState.CleanedUp => - // Already set up (or torn down); nothing to do. case EnvironmentState.Failed(msg) => throw new DirectWorkerException(s"Environment setup previously failed: $msg") case EnvironmentState.Pending => val env = workerSpec.getEnvironment - // Register the cleanup hook up front. Cleanup is user-defined - // and may tear down more than install artifacts (worker state, - // temp files, etc.), so we honor it whenever it is configured, - // independent of whether verify/install run. Registering before - // any callable runs also ensures a partially-successful install - // (e.g., half-copied files) still gets cleaned up at JVM - // shutdown if the caller never reaches dispatcher.close(). The - // hook is a no-op when environment_cleanup is not configured. + // Register up front so a partially-successful install still gets + // torn down at JVM shutdown if dispatcher.close is never called. + // No-op when environment_cleanup is not configured. registerEnvironmentCleanupHook() val verified = env.hasEnvironmentVerification && runCallable(env.getEnvironmentVerification).exitCode == 0 @@ -301,22 +248,10 @@ abstract class DirectWorkerDispatcher( } } - // TODO: Share a single JVM shutdown hook across all dispatchers in the - // process. Today each dispatcher (when environment_cleanup is set) - // registers its own Thread, which the JVM retains until shutdown. - // In a long-running driver that creates many dispatchers without - // explicit close() (e.g., failed-task paths), this keeps per-dispatcher - // memory live for the process lifetime. A shared cleanup coordinator - // draining a ConcurrentLinkedQueue would collapse the N hooks - // into one. + // TODO: share one JVM shutdown hook across all dispatchers in the + // process. Each live dispatcher is retained by the JVM until shutdown. - /** - * Registers the JVM shutdown hook that runs the cleanup callable. - * - * '''Caller must hold `environmentLock`''' -- this method reads and - * writes `cleanupHook` without its own synchronization. It is only - * called from `ensureEnvironmentReady`, which already owns the lock. - */ + /** Registers the JVM shutdown hook that runs the cleanup callable. */ private def registerEnvironmentCleanupHook(): Unit = { assert(Thread.holdsLock(environmentLock), "registerEnvironmentCleanupHook must be called while holding environmentLock") @@ -347,7 +282,6 @@ abstract class DirectWorkerDispatcher( environmentLock.synchronized { environmentState match { case EnvironmentState.CleanedUp => - // Already cleaned up; nothing to do. case _ => if (workerSpec.getEnvironment.hasEnvironmentCleanup) { try { @@ -402,24 +336,19 @@ abstract class DirectWorkerDispatcher( "DirectWorker.runner must have at least one entry in command or arguments") val workerId = UUID.randomUUID().toString val socketPath = socketDir.resolve(s"worker-$workerId.sock").toString - // Per the ProcessCallable contract in worker_spec.proto, the engine must - // always pass --id (worker identifier for logs) and --connection (the - // engine-assigned endpoint, format depending on transport). + // Proto contract: the engine must pass --id and --connection. val cmd = baseCmd ++ Seq("--id", workerId, "--connection", socketPath) val env = runner.getEnvironmentVariablesMap.asScala.toMap val outputFile = Files.createTempFile("udf-worker-", ".log") val process = launchProcess(cmd, env, outputFile.toFile) - // Wrap the raw resources into a closeable bundle immediately so any - // subsequent failure (waitForSocket, createConnection) can dispose them - // through the same path the happy-path teardown uses. + // Bundle raw resources so spawn-failure teardown reuses the happy-path + // dispose path. val artifacts = new WorkerArtifacts(process, socketPath, outputFile, logger) try { waitForSocket(socketPath, process, outputFile.toFile) val connection = createConnection(socketPath) - // Ownership of `artifacts` transfers to the DirectWorkerProcess: they - // remain alive for the duration of the worker and are disposed via - // DirectWorkerProcess.close() -> artifacts.close(). + // Ownership of `artifacts` transfers to the DirectWorkerProcess. new DirectWorkerProcess( workerId, artifacts, connection, gracefulTimeoutMs, logger, onLastSessionReleased = releaseWorker) @@ -436,10 +365,9 @@ abstract class DirectWorkerDispatcher( /** * Creates a temp directory with owner-only permissions (0700 on POSIX). - * Falls back to a best-effort `File.setXxx` on non-POSIX filesystems - * that cannot honor the attribute. The fallback is racy and weaker - * than the POSIX path; the logger call surfaces when the platform - * refuses the `setXxx` calls so operators see the degraded mode. + * On non-POSIX filesystems falls back to best-effort `File.setXxx`, + * which is TOCTOU-racy and weaker; a WARN surfaces if the platform + * refuses the setters. */ private def createPrivateTempDirectory(): Path = { val attr = PosixFilePermissions.asFileAttribute( @@ -448,15 +376,10 @@ abstract class DirectWorkerDispatcher( Files.createTempDirectory("spark-udf-worker", attr) } catch { case _: UnsupportedOperationException => - // Non-POSIX filesystem. The dir exists with default perms between - // `createTempDirectory` and the `setXxx` calls below, so this - // fallback is TOCTOU-racy by nature. val dir = Files.createTempDirectory("spark-udf-worker") val f = dir.toFile - // Strip group/other access, then restore owner rwx. `&` (non-short- - // circuiting) so every call is attempted before we decide whether - // to warn; any `false` return means the platform silently refused - // the change and the directory is less private than advertised. + // `&` (non-short-circuiting) so every setter is attempted even if + // an earlier one refused. val applied = f.setReadable(false, false) & f.setWritable(false, false) & f.setExecutable(false, false) & f.setReadable(true, true) & @@ -491,9 +414,8 @@ abstract class DirectWorkerDispatcher( process: Process, outputFile: File): Unit = { val file = new File(socketPath) - // Ensure at least one poll attempt even for very small init timeouts, - // so we don't declare a premature timeout before the worker has any - // chance to create the socket. + // At least one poll so very small initTimeouts don't trip a premature + // timeout before the worker has any chance to create the socket. val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) var attempts = 0 while (!file.exists() && attempts < maxAttempts) { @@ -509,9 +431,8 @@ abstract class DirectWorkerDispatcher( throw new DirectWorkerException( s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") } else { - // The worker exited between the last file.exists() poll and here - // without creating the socket -- report the exit code rather than - // an ambiguous "did not create" message. + // Worker exited after the last poll without creating the socket; + // prefer the exit-code message over the ambiguous "did not create". throwWorkerExitedBeforeSocket(process, socketPath, outputFile) } } @@ -527,24 +448,18 @@ abstract class DirectWorkerDispatcher( s"before creating socket at $socketPath\n$tail") } - // Reads at most the final MAX_OUTPUT_SCAN_BYTES of `file` and returns the - // last PROCESS_OUTPUT_TAIL_LINES lines via a fixed-size ring buffer, so a - // runaway worker that writes gigabytes of output does not OOM the caller - // during error reporting. + // Bounded scan so a runaway worker that writes gigabytes of output does + // not OOM the caller during error reporting. private def readOutputTail(file: File): String = { if (!file.exists() || file.length() == 0) return "" val fileLen = file.length() val startPos = math.max(0L, fileLen - MAX_OUTPUT_SCAN_BYTES) val fis = new FileInputStream(file) try { - // FileChannel.position is O(1) and unambiguously seeks, unlike - // FileInputStream.skip which is allowed to return 0 even when bytes - // remain. if (startPos > 0) fis.getChannel.position(startPos) val reader = new BufferedReader( new InputStreamReader(fis, StandardCharsets.UTF_8)) - // If we started mid-line, the first line is partial -- discard it so - // the tail never shows a line fragment. + // Discard the first (partial) line when we seeked into the middle. if (startPos > 0) reader.readLine() val buffer = new MQueue[String]() var line = reader.readLine() @@ -578,12 +493,9 @@ abstract class DirectWorkerDispatcher( s"got ${conn.getTransportCase}") } - // worker_spec.proto documents that verification is only meaningful together - // with installation -- verification exists so the engine can skip running - // installation when the environment is already prepared. A verification - // callable with no installation callable would either always succeed (no-op) - // or always fail (worker spawn then fails) -- both user errors worth - // catching at spec-validation time. + // Verification exists to short-circuit installation when the environment + // is already prepared, so requiring installation alongside verification + // catches user errors at spec-validation time. private def validateEnvironmentCallables(): Unit = { val env = workerSpec.getEnvironment require(!env.hasEnvironmentVerification || env.hasInstallation, @@ -596,55 +508,28 @@ private[direct] object DirectWorkerDispatcher { private[direct] val DEFAULT_INIT_TIMEOUT_MS = 10000L private[direct] val DEFAULT_CALLABLE_TIMEOUT_MS = 120000L private[direct] val DEFAULT_GRACEFUL_TIMEOUT_MS = 5000L - // Engine-side cap for proto-provided worker timeouts - // (initialization_timeout_ms, graceful_termination_timeout_ms). Fixed at - // 30s for now, matching the example value in worker_spec.proto; - // intentionally not configurable until we have a concrete need. The - // defaults below must stay at or under this cap so the clamp path only - // triggers on user-provided values, never on dispatcher defaults. + // Engine-side cap on proto-provided worker timeouts. The defaults below + // must stay at or under this cap so the clamp only fires on + // user-provided values. private[direct] val ENGINE_MAX_TIMEOUT_MS = 30000L - require(DEFAULT_INIT_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, - s"DEFAULT_INIT_TIMEOUT_MS ($DEFAULT_INIT_TIMEOUT_MS) must not exceed " + - s"ENGINE_MAX_TIMEOUT_MS ($ENGINE_MAX_TIMEOUT_MS)") - require(DEFAULT_GRACEFUL_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, - s"DEFAULT_GRACEFUL_TIMEOUT_MS ($DEFAULT_GRACEFUL_TIMEOUT_MS) must not " + - s"exceed ENGINE_MAX_TIMEOUT_MS ($ENGINE_MAX_TIMEOUT_MS)") + require(DEFAULT_INIT_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS && + DEFAULT_GRACEFUL_TIMEOUT_MS <= ENGINE_MAX_TIMEOUT_MS, + "default timeouts must not exceed ENGINE_MAX_TIMEOUT_MS") private[direct] val PROCESS_OUTPUT_TAIL_LINES = 50 - // Cap the amount of log file scanned by readOutputTail so a runaway worker - // producing gigabytes of output cannot OOM the caller during error - // reporting. The tail is still limited to PROCESS_OUTPUT_TAIL_LINES. private[direct] val MAX_OUTPUT_SCAN_BYTES = 1024L * 1024L // 1 MiB - // Bound on how long we wait for the kernel to reap a SIGKILL'd child. - // Distinct from the user-configurable gracefulTimeoutMs: SIGKILL is - // unblockable, so the only delay is kernel latency (milliseconds in the - // common case). Five seconds is generous; if we exceed it the child is - // likely stuck in uninterruptible I/O (D-state) and further waiting - // won't help. + // 5s bounds the wait for the kernel to reap a SIGKILL'd child. SIGKILL + // is unblockable, so exceeding this usually means the process is stuck + // in uninterruptible I/O (D-state) and further waiting will not help. private[direct] val SIGKILL_REAP_TIMEOUT_MS = 5000L /** - * Sends SIGKILL to `process` and waits up to [[SIGKILL_REAP_TIMEOUT_MS]] - * for the kernel to reap it. - * - * `destroyForcibly()` returns before the child has been reaped from the - * process table. Without a bounded `waitFor` the child lingers as a - * zombie until the JVM itself exits; in long-lived drivers with repeated - * failed spawns or session teardowns, the zombies accumulate. Callers - * should use this helper instead of calling `destroyForcibly()` directly. - * - * Behaviour: - * - No-op if the process is already dead. - * - If the reap times out, logs a warning and returns -- we do not - * block forever on a wedged child. The process remains a zombie; - * this is a real (but rare) operational condition the warning - * surfaces. - * - If the current thread is interrupted while waiting, re-raises - * the interrupt and returns without logging a zombie warning -- - * we have not actually waited the full reap window. + * SIGKILL `process` and wait up to [[SIGKILL_REAP_TIMEOUT_MS]] for the + * kernel to reap it. `destroyForcibly()` alone returns before the child + * is reaped, which leaks a zombie until JVM exit. On reap-timeout logs + * a warning; on interrupt re-raises the interrupt and returns. * - * @param context short human-readable tag included in the warn log so - * operators can correlate a wedged child with its source - * (e.g. worker id, or the callable that was killed). + * @param context short tag included in the timeout warning so operators + * can correlate a stuck child with its source. */ private[direct] def destroyForciblyAndReap( process: Process, diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index 470df8a6c4431..d840d4352bbb4 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -28,55 +28,25 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} /** * :: Experimental :: - * A locally-spawned OS process running a UDF worker, together with the - * transport connection to it. + * A locally-spawned OS process running a UDF worker, together with its + * transport connection. Wraps a [[WorkerArtifacts]] bundle (process + + * socket file + output log) and a [[WorkerConnection]], plus a session + * ref-count scaffolding for future pooling -- today one process per + * session. * - * A [[DirectWorkerProcess]] combines three things: - * - An OS '''process''' (the worker binary, started by the dispatcher), - * along with its socket file and output log -- all owned jointly by a - * [[WorkerArtifacts]] bundle. - * - A '''[[WorkerConnection]]''' (the transport channel to that process). - * - A session '''ref-count''' for future pooling. + * Closing sends SIGTERM, waits up to [[gracefulTimeoutMs]], then + * delegates forced kill + file cleanup to [[WorkerArtifacts.close]]. * - * The dispatcher currently creates one process per session and tears it - * down when the session closes; the [[acquireSession]]/[[releaseSession]] - * ref-count is scaffolding for future pooling, where a single process - * will back multiple concurrent sessions. - * - * Closing this wrapper closes the connection, attempts a graceful - * SIGTERM, then delegates the forced kill + file cleanup to - * [[WorkerArtifacts.close]]. - * - * @param id stable identifier for this worker (a UUID - * generated by the dispatcher and passed to the - * worker binary as `--id`). Used as the map key - * in the dispatcher's worker registry and in - * diagnostic messages. - * @param artifacts owns the OS process, socket file, and output - * log. The same bundle is used for spawn-failure - * teardown before this wrapper exists, so the - * dispose logic is centralised there. - * @param connection the transport connection to this worker. - * @param gracefulTimeoutMs milliseconds to wait after SIGTERM before - * escalating (via `artifacts.close`) to SIGKILL. - * @param logger [[WorkerLogger]] used for process-level - * messages. Defaults to [[WorkerLogger.NoOp]]; - * the dispatcher normally passes its own - * logger so all messages share a category. - * @param onLastSessionReleased - * callback fired when the session ref-count - * transitions to 0. The dispatcher wires this to - * its reuse policy: today, terminate the worker; - * with pooling, return it to an idle pool or - * evict it. Runs on the thread that calls - * [[releaseSession]] (typically the session-close - * thread). May fire multiple times across a - * worker's lifetime -- once per 0-transition -- - * and runs without holding the count at 0: a - * concurrent `acquireSession` can push the count - * back up before the callback returns, so a - * pooling dispatcher must arbitrate reuse itself - * rather than assume "no active sessions" here. + * @param id stable worker identifier (UUID passed to the binary as `--id`). + * @param artifacts process + socket + output-log, disposed together. + * @param connection transport channel to this worker. + * @param gracefulTimeoutMs wait after SIGTERM before escalating to SIGKILL. + * @param logger [[WorkerLogger]] for process-level messages. + * @param onLastSessionReleased fires when the ref-count hits 0. Runs on + * the thread calling [[releaseSession]]. May fire more than once + * across a worker's lifetime; a concurrent `acquireSession` can + * re-increment the count before the callback returns, so pooling + * dispatchers must arbitrate reuse themselves. */ @Experimental class DirectWorkerProcess( @@ -88,12 +58,7 @@ class DirectWorkerProcess( private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) extends AutoCloseable { - // The ref-count drives worker teardown: when the last session releases, - // the dispatcher's `onLastSessionReleased` callback decides what to do - // with the worker. Today that callback terminates the worker (no reuse); - // with pooling it will return the worker to an idle pool subject to - // capacity / timeout. - // TODO: Idle timeout tracking and concurrent session capacity. + // TODO: idle-timeout tracking and concurrent session capacity. private val activeSessionCount = new AtomicInteger(0) private val closed = new AtomicBoolean(false) @@ -114,12 +79,9 @@ class DirectWorkerProcess( def acquireSession(): Unit = activeSessionCount.incrementAndGet() /** - * Decrements the active session count. Fires - * [[onLastSessionReleased]] when the count transitions to zero so the - * dispatcher can apply its reuse policy. Logs a warning and resets to - * zero if the count goes negative, which would indicate an unbalanced - * acquire/release (a bug we want to surface rather than paper over, - * especially once pooling consumes this count). + * Decrements the active session count. Fires [[onLastSessionReleased]] + * on the 0-transition. A negative count indicates an unbalanced + * acquire/release; we log and reset to 0 rather than silently mask it. */ def releaseSession(): Unit = { val c = activeSessionCount.decrementAndGet() @@ -128,10 +90,7 @@ class DirectWorkerProcess( s"releaseSession called without a matching acquireSession (count=$c)") activeSessionCount.set(0) } else if (c == 0) { - // Swallow callback errors so a misbehaving dispatcher cannot turn - // `session.close()` into an exception -- close must always succeed - // from the caller's perspective. The dispatcher callback has its - // own error handling for the underlying worker teardown. + // Swallow callback errors so session.close cannot throw. try onLastSessionReleased(this) catch { case NonFatal(e) => logger.warn(s"onLastSessionReleased callback failed for worker $id", e) @@ -143,17 +102,9 @@ class DirectWorkerProcess( def isAlive: Boolean = process.isAlive && connection.isActive /** - * Shuts down the connection, attempts a graceful SIGTERM, then disposes - * the worker artifacts (forced kill + reap + file cleanup). - * - * Idempotent: only the first call performs teardown; subsequent calls - * are no-ops so the dispatcher's close path and the createSession error - * path can both invoke close without double-releasing resources. - * - * If the graceful SIGTERM wait expires (or is interrupted), we fall - * through to [[WorkerArtifacts.close]], which issues SIGKILL and reaps. - * If the process has already exited voluntarily, artifacts.close()'s - * kill step is a no-op and we just clean up the files. + * Shuts down the connection, sends SIGTERM, waits up to + * [[gracefulTimeoutMs]], then disposes artifacts (SIGKILL + file + * cleanup). Idempotent via CAS. */ override def close(): Unit = { if (!closed.compareAndSet(false, true)) return @@ -168,9 +119,8 @@ class DirectWorkerProcess( if (process.isAlive) { process.destroy() // SIGTERM try { - // Whether the wait returns true (clean exit) or false (timeout), - // we fall through to artifacts.close(); the helper's SIGKILL step - // is a no-op on an already-dead process. + // Ignore the return value: artifacts.close() SIGKILLs if still + // alive and no-ops if already dead. process.waitFor(gracefulTimeoutMs, TimeUnit.MILLISECONDS) } catch { case _: InterruptedException => @@ -183,21 +133,14 @@ class DirectWorkerProcess( } /** - * Closeable bundle whose [[close]] always does a forced kill (SIGKILL + - * reap) followed by socket-file and output-log cleanup. Graceful SIGTERM - * is the responsibility of higher layers that know the worker was actually - * running (see [[DirectWorkerProcess#close]]). - * - * Holds the per-worker OS-level resources that are born together at spawn - * time -- the child [[Process]], its UDS socket file, and the merged - * stdout/stderr log file -- so the happy-path teardown and the spawn- - * failure teardown (before a `DirectWorkerProcess` wrapper even exists) - * share one dispose implementation. Callers should not dispose the - * individual fields directly. + * Closeable bundle of per-worker OS resources: the child [[Process]], its + * UDS socket file, and its merged stdout/stderr log. [[close]] always + * SIGKILL-reaps then deletes the files; graceful SIGTERM is the higher + * layer's responsibility (see [[DirectWorkerProcess#close]]). * - * On the spawn-failure path the worker may never have installed signal - * handlers; jumping straight to SIGKILL avoids paying a graceful-timeout - * window on a process that wasn't listening anyway. + * One dispose implementation shared by the happy-path teardown and the + * spawn-failure path (which runs before a `DirectWorkerProcess` wrapper + * exists). */ private[direct] final class WorkerArtifacts( val process: Process, @@ -208,13 +151,9 @@ private[direct] final class WorkerArtifacts( private[this] val closed = new AtomicBoolean(false) /** - * Idempotently tears down the three resources in a fixed order: - * 1. SIGKILL the process and wait for the kernel to reap it. - * 2. Delete the socket file if it was created. - * 3. Delete the output (stdout/stderr) log file. - * - * Each step is guarded so a failure in one does not prevent the next. - * Subsequent calls return immediately. + * Idempotently SIGKILLs the process, deletes the socket file, deletes + * the output log. Each step is guarded so a failure in one does not + * skip the next. */ override def close(): Unit = { if (!closed.compareAndSet(false, true)) return diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala index fabf146296229..7cdc5329350e3 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerSession.scala @@ -48,11 +48,6 @@ abstract class DirectWorkerSession( /** The connection to the worker for this session. */ def connection: WorkerConnection = workerProcess.connection - // Releasing the session decrements the worker's ref-count; when it hits - // zero the worker's `onLastSessionReleased` callback (wired by the - // dispatcher) decides whether to terminate it or return it to a pool. - // Today it is always terminated, so closing the session is enough to - // reap the worker. override def close(): Unit = { if (released.compareAndSet(false, true)) { workerProcess.releaseSession() From 9f5e3eefb43fdee4ebba9b9fa86c8128e3926cf1 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Fri, 24 Apr 2026 22:55:46 +0000 Subject: [PATCH 8/9] address more comments. --- .../core/direct/DirectWorkerDispatcher.scala | 23 ++- .../core/direct/DirectWorkerException.scala | 12 ++ .../core/DirectWorkerDispatcherSuite.scala | 144 +++++++++++++++++- 3 files changed, 171 insertions(+), 8 deletions(-) diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 08f7af05850ef..97ee9408dab7c 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -235,7 +235,18 @@ abstract class DirectWorkerDispatcher( val verified = env.hasEnvironmentVerification && runCallable(env.getEnvironmentVerification).exitCode == 0 if (!verified && env.hasInstallation) { - val result = runCallable(env.getInstallation) + // Treat any install failure (timeout or non-zero exit) as + // permanent. A partially-completed install can leave files on + // disk that a retry would race with; retry policy belongs in + // the future predicate (see TODO above). + val result = try { + runCallable(env.getInstallation) + } catch { + case e: DirectWorkerException => + environmentState = EnvironmentState.Failed( + s"installation failed: ${e.getMessage}") + throw e + } if (result.exitCode != 0) { val detail = s"exit code ${result.exitCode}\n${result.outputTail}" environmentState = EnvironmentState.Failed(detail) @@ -253,8 +264,10 @@ abstract class DirectWorkerDispatcher( /** Registers the JVM shutdown hook that runs the cleanup callable. */ private def registerEnvironmentCleanupHook(): Unit = { - assert(Thread.holdsLock(environmentLock), - "registerEnvironmentCleanupHook must be called while holding environmentLock") + if (!Thread.holdsLock(environmentLock)) { + throw new IllegalStateException( + "registerEnvironmentCleanupHook must be called while holding environmentLock") + } if (cleanupHook.isDefined) return if (workerSpec.getEnvironment.hasEnvironmentCleanup) { val hook = new Thread(() => runEnvironmentCleanup(), "udf-env-cleanup") @@ -318,7 +331,7 @@ abstract class DirectWorkerDispatcher( DirectWorkerDispatcher.destroyForciblyAndReap( process, logger, s"callable timeout: ${cmd.head}") val tail = readOutputTail(outputFile.toFile) - throw new DirectWorkerException( + throw new DirectWorkerTimeoutException( s"Callable timed out after ${timeoutMs}ms: " + s"${cmd.mkString(" ")}\n$tail") } @@ -428,7 +441,7 @@ abstract class DirectWorkerDispatcher( DirectWorkerDispatcher.destroyForciblyAndReap( process, logger, s"init timeout $socketPath") val tail = readOutputTail(outputFile) - throw new DirectWorkerException( + throw new DirectWorkerTimeoutException( s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") } else { // Worker exited after the last poll without creating the socket; diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala index 3f2737571a56a..b0ece15eae38f 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerException.scala @@ -32,3 +32,15 @@ import org.apache.spark.annotation.Experimental @Experimental class DirectWorkerException(message: String, cause: Throwable = null) extends RuntimeException(message, cause) + +/** + * :: Experimental :: + * A [[DirectWorkerException]] caused specifically by a timeout: a worker + * that did not bind its socket within `initialization_timeout_ms`, or a + * setup callable (verify / install / cleanup) that exceeded + * `callableTimeoutMs`. Exposed as a distinct type so callers can choose + * different retry / escalation paths for timeouts vs other failures. + */ +@Experimental +class DirectWorkerTimeoutException(message: String, cause: Throwable = null) + extends DirectWorkerException(message, cause) diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index ea4a5baf21ef7..53507b6db235f 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -31,7 +31,8 @@ import org.apache.spark.udf.worker.{ UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, WorkerEnvironment} import org.apache.spark.udf.worker.core.direct.{DirectWorkerDispatcher, - DirectWorkerProcess, DirectWorkerSession} + DirectWorkerException, DirectWorkerProcess, DirectWorkerSession, + DirectWorkerTimeoutException} /** * A [[WorkerConnection]] test implementation that considers the connection @@ -344,6 +345,99 @@ class DirectWorkerDispatcherSuite } } + test("close racing with in-flight createSession does not leak the worker") { + // The acquire-before-publish + post-publish closed re-check pattern in + // createSession is designed for this race: thread A is mid-spawn when + // thread B calls close(). Thread A must either throw IllegalStateException + // (post-publish check caught the close) or receive a session whose worker + // is reaped by close()'s iteration. No orphan process or socket file + // should remain in either case. + val readyLatch = new java.util.concurrent.CountDownLatch(1) + val releaseLatch = new java.util.concurrent.CountDownLatch(1) + val capturedWorkers = + new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]() + val racing = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection(socketPath: String): WorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + capturedWorkers.add(worker) + readyLatch.countDown() + // Block here so dispatcher.close() runs while createSession is in + // flight. Use a generous wait so a slow CI doesn't time out. + if (!releaseLatch.await(30, java.util.concurrent.TimeUnit.SECONDS)) { + fail("releaseLatch never fired -- test orchestration broken") + } + new StubWorkerSession(worker) + } + } + try { + val outcome = new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]() + val createThread = new Thread(() => { + try { + val s = racing.createSession(None) + outcome.set(Right(s)) + } catch { + case t: Throwable => outcome.set(Left(t)) + } + }, "createSession-racer") + createThread.start() + + // Wait for thread A to have published the worker and entered the + // blocking override. + assert(readyLatch.await(10, java.util.concurrent.TimeUnit.SECONDS), + "createSession thread never reached createSessionForWorker") + + val closeThread = new Thread(() => racing.close(), "close-racer") + closeThread.start() + // Give close() time to flip `closed` and iterate workers. + Thread.sleep(200) + + // Now release the in-flight createSession. + releaseLatch.countDown() + + createThread.join(10000) + closeThread.join(10000) + assert(!createThread.isAlive, "createSession thread did not finish") + assert(!closeThread.isAlive, "close thread did not finish") + + val captured = capturedWorkers.toArray(Array.empty[DirectWorkerProcess]) + assert(captured.length == 1, + s"expected exactly one worker spawned, got ${captured.length}") + val worker = captured(0) + + outcome.get() match { + case Left(e: IllegalStateException) => + // Contractually allowed, but unreachable with this orchestration: + // readyLatch only fires after createSession has cleared both + // `closed` checks, so B's close cannot flip `closed` in time for + // A to observe it. Kept defensive so a future internal change + // that introduces a new window is still covered. + assert(e.getMessage.contains("closed"), + s"expected dispatcher-closed error, got: ${e.getMessage}") + case Left(other) => + fail(s"unexpected exception from racing createSession: $other") + case Right(_) => + // close() iterated the published worker and tore it down; the + // returned session points at a worker that should now be dead. + } + + // Whichever path won, the worker must not still be running and the + // socket file must be gone. + val deadline = System.currentTimeMillis() + 5000 + while (worker.process.isAlive && System.currentTimeMillis() < deadline) { + Thread.sleep(50) + } + assert(!worker.process.isAlive, + s"worker process should be terminated after close, still alive at ${worker.socketPath}") + assert(!new java.io.File(worker.socketPath).exists(), + s"socket file ${worker.socketPath} should have been removed") + } finally { + releaseLatch.countDown() + racing.close() + } + } + test("worker-provided graceful timeout is capped at the engine-side maximum") { // The proto documents an engine-configurable maximum (fixed at 30s today). // A 60s spec value should be clamped down. @@ -556,7 +650,7 @@ class DirectWorkerDispatcherSuite .build() dispatcher = new TestDirectWorkerDispatcher(spec) - val ex = intercept[RuntimeException] { + val ex = intercept[DirectWorkerTimeoutException] { dispatcher.createSession(None) } assert(ex.getMessage.contains("did not create socket"), @@ -662,7 +756,7 @@ class DirectWorkerDispatcherSuite new StubWorkerSession(worker) } try { - val ex = intercept[RuntimeException] { + val ex = intercept[DirectWorkerTimeoutException] { shortTimeoutDispatcher.createSession(None) } assert(ex.getMessage.contains("Callable timed out"), @@ -775,6 +869,50 @@ class DirectWorkerDispatcherSuite counterFile.delete() } + test("installation timeout transitions to Failed and is not retried") { + val counterFile = Files.createTempFile("env-timeout-counter", ".txt").toFile + counterFile.delete() + + // Install appends to a counter file, then sleeps past callableTimeoutMs + // so runCallable times out. The dispatcher must mark the env Failed + // and reject the next createSession without re-running install. + val env = WorkerEnvironment.newBuilder() + .setInstallation(ProcessCallable.newBuilder() + .addCommand("bash").addCommand("-c") + .addCommand( + s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build()) + .build() + val timeoutDispatcher = new DirectWorkerDispatcher(specWithEnv(env = env)) { + override protected def callableTimeoutMs: Long = 500L + override protected def createConnection(socketPath: String): WorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } + try { + val first = intercept[DirectWorkerTimeoutException] { + timeoutDispatcher.createSession(None) + } + assert(first.getMessage.contains("Callable timed out"), + s"expected callable-timeout error, got: ${first.getMessage}") + + val second = intercept[DirectWorkerException] { + timeoutDispatcher.createSession(None) + } + assert(second.getMessage.contains("previously failed"), + s"expected cached failure on retry, got: ${second.getMessage}") + + val src = scala.io.Source.fromFile(counterFile) + val lines = try src.getLines().toList finally src.close() + assert(lines.size == 1, + s"installation should run only once across timed-out retries, got ${lines.size}") + } finally { + timeoutDispatcher.close() + counterFile.delete() + } + } + test("non-None securityScope is rejected until pooling lands") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) val scope = new WorkerSecurityScope { From c393de729f7123c9ff55b1770a43fb7531997ae5 Mon Sep 17 00:00:00 2001 From: Haiyang Sun Date: Sat, 25 Apr 2026 00:00:51 +0000 Subject: [PATCH 9/9] improved connection abstraction - move connection protocol out from base direct dispatcher, also improved cleanup logic. --- .../core/UnixSocketWorkerConnection.scala | 41 +++++ .../DirectUnixSocketWorkerDispatcher.scala | 145 +++++++++++++++ .../core/direct/DirectWorkerDispatcher.scala | 167 +++++++----------- .../core/direct/DirectWorkerProcess.scala | 58 +++--- .../core/DirectWorkerDispatcherSuite.scala | 122 +++++++------ 5 files changed, 338 insertions(+), 195 deletions(-) create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala create mode 100644 udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala new file mode 100644 index 0000000000000..b3b40d16e7443 --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/UnixSocketWorkerConnection.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core + +import java.io.File + +import org.apache.spark.annotation.Experimental + +/** + * :: Experimental :: + * A [[WorkerConnection]] over a Unix domain socket. Owns the socket + * path and removes the socket file on [[close]]. Subclasses provide the + * protocol-specific channel (e.g. gRPC over UDS) and may override + * [[close]] to add transport-level shutdown -- they should call + * `super.close()` to ensure the socket file is removed. + * + * [[close]] is idempotent: deleting an already-removed file is a no-op. + */ +@Experimental +abstract class UnixSocketWorkerConnection(val socketPath: String) + extends WorkerConnection { + + override def close(): Unit = { + val f = new File(socketPath) + if (f.exists()) f.delete() + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala new file mode 100644 index 0000000000000..8da0354187e4f --- /dev/null +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectUnixSocketWorkerDispatcher.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.udf.worker.core.direct + +import java.io.File +import java.nio.file.{Files, Path} +import java.nio.file.attribute.PosixFilePermissions + +import org.apache.spark.annotation.Experimental +import org.apache.spark.udf.worker.UDFWorkerSpecification +import org.apache.spark.udf.worker.core.{UnixSocketWorkerConnection, WorkerLogger} +import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.SOCKET_POLL_INTERVAL_MS + +/** + * :: Experimental :: + * A [[DirectWorkerDispatcher]] using Unix domain sockets as the worker + * transport. Allocates a private 0700 socket directory at construction; + * each worker is given a UDS path inside it. + * + * Concrete subclasses implement [[createConnection]] (with a UDS protocol + * of choice) and [[createSessionForWorker]]. + */ +@Experimental +abstract class DirectUnixSocketWorkerDispatcher( + workerSpec: UDFWorkerSpecification, + logger: WorkerLogger = WorkerLogger.NoOp) + extends DirectWorkerDispatcher(workerSpec, logger) { + + // Removed explicitly in closeTransport(). deleteOnExit is avoided because + // the JDK retains the path for the JVM lifetime, which leaks in + // long-lived drivers. + private val socketDir: Path = createPrivateTempDirectory() + + override protected def newEndpointAddress(workerId: String): String = + socketDir.resolve(s"worker-$workerId.sock").toString + + override protected def waitForReady( + address: String, + process: Process, + outputFile: File): Unit = { + val file = new File(address) + // At least one poll so very small initTimeouts don't trip a premature + // timeout before the worker has any chance to create the socket. + val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) + var attempts = 0 + while (!file.exists() && attempts < maxAttempts) { + if (!process.isAlive) throwWorkerExitedBeforeSocket(process, address, outputFile) + Thread.sleep(SOCKET_POLL_INTERVAL_MS) + attempts += 1 + } + if (!file.exists()) { + if (process.isAlive) { + DirectWorkerDispatcher.destroyForciblyAndReap( + process, logger, s"init timeout $address") + val tail = readOutputTail(outputFile) + throw new DirectWorkerTimeoutException( + s"Worker did not create socket at $address within ${initTimeoutMs}ms\n$tail") + } else { + // Worker exited after the last poll without creating the socket; + // prefer the exit-code message over the ambiguous "did not create". + throwWorkerExitedBeforeSocket(process, address, outputFile) + } + } + } + + override protected def cleanupEndpointAddress(address: String): Unit = { + Files.deleteIfExists(new File(address).toPath) + } + + override protected def closeTransport(): Unit = { + val dir = socketDir.toFile + if (dir.exists()) { + val remaining = dir.listFiles() + if (remaining != null) remaining.foreach(_.delete()) + dir.delete() + } + } + + override protected def validateTransportSupport(): Unit = { + val props = workerSpec.getDirect.getProperties + require(props.hasConnection, + "DirectWorker.properties.connection must be set") + val conn = props.getConnection + require(conn.hasUnixDomainSocket, + "DirectUnixSocketWorkerDispatcher requires UNIX domain socket transport, " + + s"got ${conn.getTransportCase}") + } + + override protected def createConnection(address: String): UnixSocketWorkerConnection + + private def throwWorkerExitedBeforeSocket( + process: Process, + address: String, + outputFile: File): Nothing = { + val tail = readOutputTail(outputFile) + throw new DirectWorkerException( + s"Worker exited with code ${process.exitValue()} " + + s"before creating socket at $address\n$tail") + } + + /** + * Creates a temp directory with owner-only permissions (0700 on POSIX). + * On non-POSIX filesystems falls back to best-effort `File.setXxx`, + * which is TOCTOU-racy and weaker; a WARN surfaces if the platform + * refuses the setters. + */ + private def createPrivateTempDirectory(): Path = { + val attr = PosixFilePermissions.asFileAttribute( + PosixFilePermissions.fromString("rwx------")) + try { + Files.createTempDirectory("spark-udf-worker", attr) + } catch { + case _: UnsupportedOperationException => + val dir = Files.createTempDirectory("spark-udf-worker") + val f = dir.toFile + // `&` (non-short-circuiting) so every setter is attempted even if + // an earlier one refused. + val applied = + f.setReadable(false, false) & f.setWritable(false, false) & + f.setExecutable(false, false) & f.setReadable(true, true) & + f.setWritable(true, true) & f.setExecutable(true, true) + if (!applied) { + logger.warn( + s"Could not fully restrict permissions on $dir; socket " + + s"directory may be accessible to other local users on this " + + s"filesystem") + } + dir + } + } +} diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala index 97ee9408dab7c..afaf23791d80f 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerDispatcher.scala @@ -19,7 +19,6 @@ package org.apache.spark.udf.worker.core.direct import java.io.{BufferedReader, File, FileInputStream, InputStreamReader} import java.nio.charset.StandardCharsets import java.nio.file.{Files, Path} -import java.nio.file.attribute.PosixFilePermissions import java.util.UUID import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicBoolean @@ -35,7 +34,7 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerDispatcher, import org.apache.spark.udf.worker.core.direct.DirectWorkerDispatcher.{CallableResult, DEFAULT_CALLABLE_TIMEOUT_MS, DEFAULT_GRACEFUL_TIMEOUT_MS, DEFAULT_INIT_TIMEOUT_MS, ENGINE_MAX_TIMEOUT_MS, EnvironmentState, MAX_OUTPUT_SCAN_BYTES, - PROCESS_OUTPUT_TAIL_LINES, SOCKET_POLL_INTERVAL_MS} + PROCESS_OUTPUT_TAIL_LINES} /** * :: Experimental :: @@ -115,9 +114,6 @@ abstract class DirectWorkerDispatcher( } } - // Removed explicitly in close(). deleteOnExit is avoided because the JDK - // retains the path for the JVM lifetime, which leaks in long-lived drivers. - private val socketDir: Path = createPrivateTempDirectory() private[this] val workers = new ConcurrentHashMap[String, DirectWorkerProcess]() private[this] val closed = new AtomicBoolean(false) @@ -125,8 +121,43 @@ abstract class DirectWorkerDispatcher( private val environmentLock = new Object private[this] var cleanupHook: Option[Thread] = None - /** Creates a protocol-specific connection to a worker at the given socket path. */ - protected def createConnection(socketPath: String): WorkerConnection + /** + * Allocates a fresh endpoint address for a new worker. The string is + * passed to the worker binary as `--connection
`. + */ + protected def newEndpointAddress(workerId: String): String + + /** + * Waits for the worker process to be ready to accept connections at + * `address`. Throws [[DirectWorkerTimeoutException]] on timeout, or + * [[DirectWorkerException]] if the process exits early. + */ + protected def waitForReady( + address: String, + process: Process, + outputFile: File): Unit + + /** + * Best-effort per-endpoint cleanup, called from the spawn-failure path + * before any [[WorkerArtifacts]] / [[WorkerConnection]] exists. + */ + protected def cleanupEndpointAddress(address: String): Unit + + /** + * Cleans up dispatcher-level transport state (e.g., a UDS socket + * directory). Called from [[close]]. + */ + protected def closeTransport(): Unit + + /** + * Validates the worker spec's transport choice. Subclasses declare + * which transports they support. Called from the base constructor; + * implementations must only read base-class state (`workerSpec`). + */ + protected def validateTransportSupport(): Unit + + /** Creates a protocol-specific connection to a worker at the given address. */ + protected def createConnection(address: String): WorkerConnection /** Creates a protocol-specific session for the given worker. */ protected def createSessionForWorker(worker: DirectWorkerProcess): WorkerSession @@ -202,16 +233,9 @@ abstract class DirectWorkerDispatcher( } } workers.clear() - try { - val dir = socketDir.toFile - if (dir.exists()) { - val remaining = dir.listFiles() - if (remaining != null) remaining.foreach(_.delete()) - dir.delete() - } - } catch { + try closeTransport() catch { case NonFatal(e) => - logger.warn(s"Error cleaning up socket directory $socketDir", e) + logger.warn("Error cleaning up transport state", e) } deregisterEnvironmentCleanupHook() runEnvironmentCleanup() @@ -348,62 +372,42 @@ abstract class DirectWorkerDispatcher( require(baseCmd.nonEmpty, "DirectWorker.runner must have at least one entry in command or arguments") val workerId = UUID.randomUUID().toString - val socketPath = socketDir.resolve(s"worker-$workerId.sock").toString + val address = newEndpointAddress(workerId) // Proto contract: the engine must pass --id and --connection. - val cmd = baseCmd ++ Seq("--id", workerId, "--connection", socketPath) + val cmd = baseCmd ++ Seq("--id", workerId, "--connection", address) val env = runner.getEnvironmentVariablesMap.asScala.toMap val outputFile = Files.createTempFile("udf-worker-", ".log") val process = launchProcess(cmd, env, outputFile.toFile) - // Bundle raw resources so spawn-failure teardown reuses the happy-path - // dispose path. - val artifacts = new WorkerArtifacts(process, socketPath, outputFile, logger) try { - waitForSocket(socketPath, process, outputFile.toFile) - val connection = createConnection(socketPath) - // Ownership of `artifacts` transfers to the DirectWorkerProcess. + waitForReady(address, process, outputFile.toFile) + val connection = createConnection(address) + val artifacts = new WorkerArtifacts(process, connection, outputFile, logger) new DirectWorkerProcess( - workerId, artifacts, connection, gracefulTimeoutMs, logger, + workerId, artifacts, gracefulTimeoutMs, logger, onLastSessionReleased = releaseWorker) } catch { case e: InterruptedException => Thread.currentThread().interrupt() - artifacts.close() + cleanupRawSpawn(process, address, outputFile) throw e case NonFatal(e) => - artifacts.close() + cleanupRawSpawn(process, address, outputFile) throw e } } - /** - * Creates a temp directory with owner-only permissions (0700 on POSIX). - * On non-POSIX filesystems falls back to best-effort `File.setXxx`, - * which is TOCTOU-racy and weaker; a WARN surfaces if the platform - * refuses the setters. - */ - private def createPrivateTempDirectory(): Path = { - val attr = PosixFilePermissions.asFileAttribute( - PosixFilePermissions.fromString("rwx------")) - try { - Files.createTempDirectory("spark-udf-worker", attr) - } catch { - case _: UnsupportedOperationException => - val dir = Files.createTempDirectory("spark-udf-worker") - val f = dir.toFile - // `&` (non-short-circuiting) so every setter is attempted even if - // an earlier one refused. - val applied = - f.setReadable(false, false) & f.setWritable(false, false) & - f.setExecutable(false, false) & f.setReadable(true, true) & - f.setWritable(true, true) & f.setExecutable(true, true) - if (!applied) { - logger.warn( - s"Could not fully restrict permissions on $dir; socket " + - s"directory may be accessible to other local users on this " + - s"filesystem") - } - dir + // Pre-WorkerArtifacts cleanup: the connection has not been built yet, + // so we have no bundle to close(). Each step is independent. + private def cleanupRawSpawn(p: Process, address: String, outputFile: Path): Unit = { + DirectWorkerDispatcher.destroyForciblyAndReap(p, logger, "failed spawn") + try cleanupEndpointAddress(address) catch { + case NonFatal(e) => + logger.debug(s"Failed to clean up endpoint address $address", e) + } + try Files.deleteIfExists(outputFile) catch { + case NonFatal(e) => + logger.debug(s"Failed to clean up worker output file $outputFile", e) } } @@ -422,48 +426,9 @@ abstract class DirectWorkerDispatcher( builder.start() } - private def waitForSocket( - socketPath: String, - process: Process, - outputFile: File): Unit = { - val file = new File(socketPath) - // At least one poll so very small initTimeouts don't trip a premature - // timeout before the worker has any chance to create the socket. - val maxAttempts = math.max(1, (initTimeoutMs / SOCKET_POLL_INTERVAL_MS).toInt) - var attempts = 0 - while (!file.exists() && attempts < maxAttempts) { - if (!process.isAlive) throwWorkerExitedBeforeSocket(process, socketPath, outputFile) - Thread.sleep(SOCKET_POLL_INTERVAL_MS) - attempts += 1 - } - if (!file.exists()) { - if (process.isAlive) { - DirectWorkerDispatcher.destroyForciblyAndReap( - process, logger, s"init timeout $socketPath") - val tail = readOutputTail(outputFile) - throw new DirectWorkerTimeoutException( - s"Worker did not create socket at $socketPath within ${initTimeoutMs}ms\n$tail") - } else { - // Worker exited after the last poll without creating the socket; - // prefer the exit-code message over the ambiguous "did not create". - throwWorkerExitedBeforeSocket(process, socketPath, outputFile) - } - } - } - - private def throwWorkerExitedBeforeSocket( - process: Process, - socketPath: String, - outputFile: File): Nothing = { - val tail = readOutputTail(outputFile) - throw new DirectWorkerException( - s"Worker exited with code ${process.exitValue()} " + - s"before creating socket at $socketPath\n$tail") - } - // Bounded scan so a runaway worker that writes gigabytes of output does // not OOM the caller during error reporting. - private def readOutputTail(file: File): String = { + protected def readOutputTail(file: File): String = { if (!file.exists() || file.length() == 0) return "" val fileLen = file.length() val startPos = math.max(0L, fileLen - MAX_OUTPUT_SCAN_BYTES) @@ -494,18 +459,6 @@ abstract class DirectWorkerDispatcher( // -- Spec validation ------------------------------------------------------- - // TCP transport is declared in the proto but not yet implemented; the - // engine currently only supports UDS. - private def validateTransportSupport(): Unit = { - val props = workerSpec.getDirect.getProperties - require(props.hasConnection, - "DirectWorker.properties.connection must be set") - val conn = props.getConnection - require(conn.hasUnixDomainSocket, - "DirectWorker currently only supports UNIX domain socket transport, " + - s"got ${conn.getTransportCase}") - } - // Verification exists to short-circuit installation when the environment // is already prepared, so requiring installation alongside verification // catches user errors at spec-validation time. diff --git a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala index d840d4352bbb4..f4b5c1df63193 100644 --- a/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala +++ b/udf/worker/core/src/main/scala/org/apache/spark/udf/worker/core/direct/DirectWorkerProcess.scala @@ -16,7 +16,6 @@ */ package org.apache.spark.udf.worker.core.direct -import java.io.File import java.nio.file.{Files, Path} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} @@ -30,16 +29,15 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} * :: Experimental :: * A locally-spawned OS process running a UDF worker, together with its * transport connection. Wraps a [[WorkerArtifacts]] bundle (process + - * socket file + output log) and a [[WorkerConnection]], plus a session - * ref-count scaffolding for future pooling -- today one process per - * session. + * connection + output log) plus a session ref-count scaffolding for + * future pooling -- today one process per session. * * Closing sends SIGTERM, waits up to [[gracefulTimeoutMs]], then - * delegates forced kill + file cleanup to [[WorkerArtifacts.close]]. + * delegates connection close + forced kill + file cleanup to + * [[WorkerArtifacts.close]]. * * @param id stable worker identifier (UUID passed to the binary as `--id`). - * @param artifacts process + socket + output-log, disposed together. - * @param connection transport channel to this worker. + * @param artifacts process + connection + output-log, disposed together. * @param gracefulTimeoutMs wait after SIGTERM before escalating to SIGKILL. * @param logger [[WorkerLogger]] for process-level messages. * @param onLastSessionReleased fires when the ref-count hits 0. Runs on @@ -52,7 +50,6 @@ import org.apache.spark.udf.worker.core.{WorkerConnection, WorkerLogger} class DirectWorkerProcess( val id: String, private[direct] val artifacts: WorkerArtifacts, - val connection: WorkerConnection, val gracefulTimeoutMs: Long, protected val logger: WorkerLogger = WorkerLogger.NoOp, private[direct] val onLastSessionReleased: DirectWorkerProcess => Unit = _ => ()) @@ -66,8 +63,8 @@ class DirectWorkerProcess( /** The OS process handle for this worker. */ def process: Process = artifacts.process - /** The UDS socket path used by this worker. */ - def socketPath: String = artifacts.socketPath + /** The transport connection for this worker. */ + def connection: WorkerConnection = artifacts.connection /** Path to the merged stdout/stderr log for this worker. */ def outputFile: Path = artifacts.outputFile @@ -102,20 +99,13 @@ class DirectWorkerProcess( def isAlive: Boolean = process.isAlive && connection.isActive /** - * Shuts down the connection, sends SIGTERM, waits up to - * [[gracefulTimeoutMs]], then disposes artifacts (SIGKILL + file + * Sends SIGTERM, waits up to [[gracefulTimeoutMs]] for the worker to + * exit, then disposes artifacts (connection close + SIGKILL + file * cleanup). Idempotent via CAS. */ override def close(): Unit = { if (!closed.compareAndSet(false, true)) return - try { - connection.close() - } catch { - case NonFatal(e) => - logger.warn(s"Error closing connection to worker at $socketPath", e) - } - if (process.isAlive) { process.destroy() // SIGTERM try { @@ -134,38 +124,36 @@ class DirectWorkerProcess( /** * Closeable bundle of per-worker OS resources: the child [[Process]], its - * UDS socket file, and its merged stdout/stderr log. [[close]] always - * SIGKILL-reaps then deletes the files; graceful SIGTERM is the higher - * layer's responsibility (see [[DirectWorkerProcess#close]]). - * - * One dispose implementation shared by the happy-path teardown and the - * spawn-failure path (which runs before a `DirectWorkerProcess` wrapper - * exists). + * transport [[WorkerConnection]], and its merged stdout/stderr log. + * [[close]] runs connection close (which for UDS removes the socket + * file), then SIGKILL-reaps the process, then deletes the output log. + * Graceful SIGTERM is the higher layer's responsibility (see + * [[DirectWorkerProcess#close]]). */ private[direct] final class WorkerArtifacts( val process: Process, - val socketPath: String, + val connection: WorkerConnection, val outputFile: Path, private[this] val logger: WorkerLogger) extends AutoCloseable { private[this] val closed = new AtomicBoolean(false) /** - * Idempotently SIGKILLs the process, deletes the socket file, deletes - * the output log. Each step is guarded so a failure in one does not - * skip the next. + * Idempotently closes the connection (transport teardown + any + * transport-specific cleanup such as deleting a UDS socket file), + * SIGKILL-reaps the process, and deletes the output log. Each step + * is guarded so a failure in one does not skip the next. */ override def close(): Unit = { if (!closed.compareAndSet(false, true)) return - DirectWorkerDispatcher.destroyForciblyAndReap( - process, logger, s"worker artifacts $socketPath") - - try Files.deleteIfExists(new File(socketPath).toPath) catch { + try connection.close() catch { case NonFatal(e) => - logger.warn(s"Error cleaning up socket file $socketPath", e) + logger.warn("Error closing worker connection", e) } + DirectWorkerDispatcher.destroyForciblyAndReap(process, logger, "worker artifacts") + try Files.deleteIfExists(outputFile) catch { case NonFatal(e) => logger.warn(s"Error cleaning up worker output file $outputFile", e) diff --git a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala index 53507b6db235f..60f5e2211b702 100644 --- a/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala +++ b/udf/worker/core/src/test/scala/org/apache/spark/udf/worker/core/DirectWorkerDispatcherSuite.scala @@ -30,20 +30,18 @@ import org.apache.spark.udf.worker.{ DirectWorker, LocalTcpConnection, ProcessCallable, UDFWorkerProperties, UDFWorkerSpecification, UnixDomainSocket, WorkerConnectionSpec, WorkerEnvironment} -import org.apache.spark.udf.worker.core.direct.{DirectWorkerDispatcher, +import org.apache.spark.udf.worker.core.direct.{DirectUnixSocketWorkerDispatcher, DirectWorkerException, DirectWorkerProcess, DirectWorkerSession, DirectWorkerTimeoutException} /** * A [[WorkerConnection]] test implementation that considers the connection - * active as long as the socket file exists on disk. + * active as long as the socket file exists on disk. Inherits socket-file + * deletion from [[UnixSocketWorkerConnection.close]]. */ -class SocketFileConnection(socketPath: String) extends WorkerConnection { +class SocketFileConnection(socketPath: String) + extends UnixSocketWorkerConnection(socketPath) { override def isActive: Boolean = new File(socketPath).exists() - override def close(): Unit = { - val f = new File(socketPath) - if (f.exists()) f.delete() - } } /** @@ -70,20 +68,20 @@ class StubWorkerSession( } /** - * A [[DirectWorkerDispatcher]] subclass for testing that uses a socket-file - * connection and stub sessions instead of a real protocol implementation. + * A [[DirectUnixSocketWorkerDispatcher]] subclass for testing that uses + * a socket-file connection and stub sessions instead of a real protocol + * implementation. */ class TestDirectWorkerDispatcher(spec: UDFWorkerSpecification) - extends DirectWorkerDispatcher(spec) { + extends DirectUnixSocketWorkerDispatcher(spec) { - override protected def createConnection(socketPath: String): WorkerConnection = { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = new SocketFileConnection(socketPath) - } override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = { + worker: DirectWorkerProcess): WorkerSession = new StubWorkerSession(worker) - } } /** @@ -157,6 +155,14 @@ class DirectWorkerDispatcherSuite s"Expected StubWorkerSession, got ${other.getClass.getSimpleName}") } + // The whole suite uses UDS as the only transport, so reaching past the + // generic WorkerConnection abstraction to read the socket path is fine. + private def udsPath(w: DirectWorkerProcess): String = w.connection match { + case uds: UnixSocketWorkerConnection => uds.socketPath + case other => fail( + s"Expected UnixSocketWorkerConnection, got ${other.getClass.getSimpleName}") + } + test("creates a worker and session") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) @@ -165,7 +171,7 @@ class DirectWorkerDispatcherSuite assert(worker.isAlive, "worker should be alive after creation") assert(worker.activeSessions == 1, "should have 1 active session") - assert(new File(worker.socketPath).exists(), "socket file should exist") + assert(new File(udsPath(worker)).exists(), "socket file should exist") session.close() assert(worker.activeSessions == 0, "should have 0 sessions after close") @@ -208,7 +214,7 @@ class DirectWorkerDispatcherSuite // that accidentally shared underlying transport resources could still // hand out distinct DirectWorkerProcess wrappers pointing at the same // socket. Verify socket paths are unique too. - val socketPaths = workerObjects.map(_.socketPath) + val socketPaths = workerObjects.map(udsPath) assert(socketPaths.distinct.length == threads, s"each worker should have its own socket path, got $socketPaths") @@ -286,7 +292,7 @@ class DirectWorkerDispatcherSuite val session = createStubSession() val worker = session.workerProcess - val socketFile = new File(worker.socketPath) + val socketFile = new File(udsPath(worker)) assert(worker.process.isAlive, "worker should be alive before session close") assert(socketFile.exists(), "socket file should exist before session close") @@ -341,7 +347,7 @@ class DirectWorkerDispatcherSuite s"unexpected errors during concurrent close: ${errors.toArray.mkString(", ")}") workers.foreach { w => assert(!w.process.isAlive, - s"worker at ${w.socketPath} should be terminated after concurrent close") + s"worker at ${udsPath(w)} should be terminated after concurrent close") } } @@ -356,8 +362,9 @@ class DirectWorkerDispatcherSuite val releaseLatch = new java.util.concurrent.CountDownLatch(1) val capturedWorkers = new java.util.concurrent.ConcurrentLinkedQueue[DirectWorkerProcess]() - val racing = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { - override protected def createConnection(socketPath: String): WorkerConnection = + val racing = new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = new SocketFileConnection(socketPath) override protected def createSessionForWorker( worker: DirectWorkerProcess): WorkerSession = { @@ -372,7 +379,8 @@ class DirectWorkerDispatcherSuite } } try { - val outcome = new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]() + val outcome = + new java.util.concurrent.atomic.AtomicReference[Either[Throwable, WorkerSession]]() val createThread = new Thread(() => { try { val s = racing.createSession(None) @@ -428,10 +436,11 @@ class DirectWorkerDispatcherSuite while (worker.process.isAlive && System.currentTimeMillis() < deadline) { Thread.sleep(50) } + val sockPath = udsPath(worker) assert(!worker.process.isAlive, - s"worker process should be terminated after close, still alive at ${worker.socketPath}") - assert(!new java.io.File(worker.socketPath).exists(), - s"socket file ${worker.socketPath} should have been removed") + s"worker process should be terminated after close, still alive at $sockPath") + assert(!new java.io.File(sockPath).exists(), + s"socket file $sockPath should have been removed") } finally { releaseLatch.countDown() racing.close() @@ -490,9 +499,9 @@ class DirectWorkerDispatcherSuite test("socket directory is owner-only (0700) on POSIX") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) // Drive one createSession so a worker (and therefore the socket dir) is - // observable via session.workerProcess.socketPath. + // observable via the UDS connection's path. val session = createStubSession() - val socketDir: Path = new File(session.workerProcess.socketPath).toPath.getParent + val socketDir: Path = new File(udsPath(session.workerProcess)).toPath.getParent session.close() val view = Files.getFileAttributeView(socketDir, classOf[PosixFileAttributeView]) @@ -510,7 +519,7 @@ class DirectWorkerDispatcherSuite test("socket directory is removed after dispatcher.close") { dispatcher = new TestDirectWorkerDispatcher(specWithRunner(defaultRunner)) val session = createStubSession() - val socketDir = new File(session.workerProcess.socketPath).toPath.getParent.toFile + val socketDir = new File(udsPath(session.workerProcess)).toPath.getParent.toFile assert(socketDir.exists(), s"socket directory $socketDir should exist while a session is open") session.close() @@ -528,15 +537,17 @@ class DirectWorkerDispatcherSuite // A dispatcher whose createSessionForWorker always throws. The spawned // worker must be terminated rather than leaked until dispatcher.close(). var capturedWorker: DirectWorkerProcess = null - val failingDispatcher = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { - override protected def createConnection(socketPath: String): WorkerConnection = - new SocketFileConnection(socketPath) - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = { - capturedWorker = worker - throw new RuntimeException("session creation failed") + val failingDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = { + capturedWorker = worker + throw new RuntimeException("session creation failed") + } } - } try { val ex = intercept[RuntimeException] { @@ -582,15 +593,17 @@ class DirectWorkerDispatcherSuite test("socket file is cleaned up when createConnection throws") { val capturedSocketPaths = new java.util.concurrent.ConcurrentLinkedQueue[String]() - val failingDispatcher = new DirectWorkerDispatcher(specWithRunner(defaultRunner)) { - override protected def createConnection(socketPath: String): WorkerConnection = { - capturedSocketPaths.add(socketPath) - throw new RuntimeException("connection creation failed") + val failingDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithRunner(defaultRunner)) { + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = { + capturedSocketPaths.add(socketPath) + throw new RuntimeException("connection creation failed") + } + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) } - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = - new StubWorkerSession(worker) - } try { val ex = intercept[RuntimeException] { failingDispatcher.createSession(None) @@ -747,9 +760,10 @@ class DirectWorkerDispatcherSuite .addCommand("sleep 30").build() val env = WorkerEnvironment.newBuilder().setInstallation(slowInstall).build() val shortTimeoutDispatcher = - new DirectWorkerDispatcher(specWithEnv(env = env)) { + new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) { override protected def callableTimeoutMs: Long = 500L - override protected def createConnection(socketPath: String): WorkerConnection = + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = new SocketFileConnection(socketPath) override protected def createSessionForWorker( worker: DirectWorkerProcess): WorkerSession = @@ -882,14 +896,16 @@ class DirectWorkerDispatcherSuite .addCommand( s"echo invoked >> ${counterFile.getAbsolutePath}; sleep 30").build()) .build() - val timeoutDispatcher = new DirectWorkerDispatcher(specWithEnv(env = env)) { - override protected def callableTimeoutMs: Long = 500L - override protected def createConnection(socketPath: String): WorkerConnection = - new SocketFileConnection(socketPath) - override protected def createSessionForWorker( - worker: DirectWorkerProcess): WorkerSession = - new StubWorkerSession(worker) - } + val timeoutDispatcher = + new DirectUnixSocketWorkerDispatcher(specWithEnv(env = env)) { + override protected def callableTimeoutMs: Long = 500L + override protected def createConnection( + socketPath: String): UnixSocketWorkerConnection = + new SocketFileConnection(socketPath) + override protected def createSessionForWorker( + worker: DirectWorkerProcess): WorkerSession = + new StubWorkerSession(worker) + } try { val first = intercept[DirectWorkerTimeoutException] { timeoutDispatcher.createSession(None)