Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.apache.flink.runtime.state.StreamStateHandle;
import org.apache.flink.util.ExceptionUtils;
import org.apache.flink.util.FileUtils;
import org.apache.flink.util.FlinkRuntimeException;
import org.apache.flink.util.IOUtils;
import org.apache.flink.util.concurrent.FutureUtils;
import org.apache.flink.util.function.ThrowingRunnable;
Expand All @@ -44,10 +43,13 @@
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand Down Expand Up @@ -84,6 +86,7 @@ public void transferAllStateDataToDirectory(
CloseableRegistry internalCloser = new CloseableRegistry();
// Make sure we also react to external close signals.
closeableRegistry.registerCloseable(internalCloser);
AtomicReference<Throwable> rawException = new AtomicReference<>();
try {
// We have to wait for all futures to be completed, to make sure in
// case of failure that we will clean up all the files
Expand All @@ -96,6 +99,12 @@ public void transferAllStateDataToDirectory(
runnable,
transfer.getExecutorService()))
.collect(Collectors.toList()));

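// Capture the raw CompletionException before get() strips it. get() reports only the
// cause of the CompletionException (wrapped in an ExecutionException), losing the
// suppressed list that holds all parallel thread failures. The callback is not guaranteed
// to have run by the time get() unblocks, so the catch block below falls back to the
// exception thrown by get() when no raw exception has been captured yet.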
downloadFuture.whenComplete((v, t) -> rawException.set(t));

Exception interruptedException = null;
while (!downloadFuture.isDone() || downloadFuture.isCompletedExceptionally()) {
try {
Expand All @@ -114,14 +123,8 @@ public void transferAllStateDataToDirectory(
.map(StateHandleDownloadSpec::getDownloadDestination)
.map(Path::toFile)
.forEach(FileUtils::deleteDirectoryQuietly);
// Error reporting
Throwable throwable = ExceptionUtils.stripExecutionException(e);
throwable = ExceptionUtils.stripException(throwable, RuntimeException.class);
if (throwable instanceof IOException) {
throw (IOException) throwable;
} else {
throw new FlinkRuntimeException("Failed to download data for state handles.", e);
}
Throwable raw = rawException.get();
throw buildDownloadException(raw != null ? raw : e);
} finally {
// Unregister and close the internal closer.
if (closeableRegistry.unregisterCloseable(internalCloser)) {
Expand Down Expand Up @@ -261,6 +264,40 @@ private void downloadDataForStateHandle(
}
}

/**
 * Builds a descriptive {@link IOException} from a potentially cascaded failure across multiple
 * parallel download threads.
 *
 * <p>The given throwable and all of its suppressed throwables are treated as individual
 * failures. Each failure is stripped of its {@code ExecutionException}, {@link
 * CompletionException} and {@link RuntimeException} wrappers, and the results are deduplicated
 * by class name and message (first occurrence wins, encounter order is kept).
 *
 * <p>If exactly one distinct failure remains, it is returned directly when it is an {@link
 * IOException}, or wrapped in a new {@link IOException} otherwise. If several distinct failures
 * remain, a merged {@link IOException} is returned whose message lists all of them, whose cause
 * is the first failure, and whose suppressed list holds the remaining ones.
 *
 * @param rawException the throwable the combined download future completed with, or the
 *     exception thrown while waiting for it; {@code null} yields a generic exception
 * @return the exception to report to the caller, never {@code null}
 */
private static IOException buildDownloadException(Throwable rawException) {
    if (rawException == null) {
        // Defensive guard: the caller throws the returned value, so never return null or
        // fail with a NullPointerException that would hide the download failure.
        return new IOException("Failed to download data for state handles.");
    }

    // When the caller falls back to the exception thrown by Future#get(), the failure arrives
    // wrapped in an ExecutionException. Peel that layer first so that the suppressed failures
    // attached to the underlying throwable are picked up as well.
    Throwable primary = ExceptionUtils.stripExecutionException(rawException);

    List<Throwable> failures = new ArrayList<>();
    failures.add(primary);
    for (Throwable suppressed : primary.getSuppressed()) {
        failures.add(suppressed);
    }
    if (primary != rawException) {
        for (Throwable suppressed : rawException.getSuppressed()) {
            failures.add(suppressed);
        }
    }

    // LinkedHashMap keeps the encounter order, so the primary failure stays first.
    Map<String, Throwable> unique = new LinkedHashMap<>();
    for (Throwable failure : failures) {
        Throwable root = stripDownloadWrappers(failure);
        unique.putIfAbsent(root.getClass().getName() + ":" + root.getMessage(), root);
    }

    Throwable first = unique.values().iterator().next();
    if (unique.size() == 1) {
        return first instanceof IOException ? (IOException) first : new IOException(first);
    }

    String summary =
            unique.values().stream()
                    .map(t -> t.getClass().getSimpleName() + ": " + t.getMessage())
                    .collect(Collectors.joining(" | "));
    IOException merged =
            new IOException(
                    unique.size() + " downloads failed with distinct errors: [" + summary + "]",
                    first);
    // The first failure is already the cause; attach the others as suppressed.
    unique.values().stream().skip(1).forEach(merged::addSuppressed);
    return merged;
}

/**
 * Repeatedly removes {@code ExecutionException}, {@link CompletionException} and {@link
 * RuntimeException} wrappers until none of them applies any more, so that interleaved wrapper
 * layers are all removed. A wrapper without a cause is returned as is.
 */
private static Throwable stripDownloadWrappers(Throwable failure) {
    Throwable current = failure;
    while (true) {
        Throwable stripped = ExceptionUtils.stripExecutionException(current);
        stripped = ExceptionUtils.stripException(stripped, CompletionException.class);
        stripped = ExceptionUtils.stripException(stripped, RuntimeException.class);
        if (stripped == current) {
            return current;
        }
        current = stripped;
    }
}

@Override
public void close() throws IOException {
this.transfer.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import java.util.UUID;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -230,11 +231,178 @@ public void testMultiThreadCleanupOnFailure() throws Exception {
Assert.assertFalse(closeableRegistry.isClosed());
}

/**
 * Verifies that a lone failing download reports its root cause as the direct cause of the
 * thrown {@link IOException}, and that the message is not the merged "downloads failed"
 * summary.
 */
@Test
public void testSingleDownloadFailureSurfacedDirectly() throws Exception {
    final IOException rootCause = new IOException("file not found on remote storage");
    final StreamStateHandle failingHandle = new ThrowingStateHandle(rootCause);
    final List<HandleAndLocalPath> sharedState =
            singletonList(HandleAndLocalPath.of(failingHandle, "state"));

    final IncrementalRemoteKeyedStateHandle stateHandle =
            new IncrementalRemoteKeyedStateHandle(
                    UUID.randomUUID(),
                    KeyGroupRange.EMPTY_KEY_GROUP_RANGE,
                    1L,
                    sharedState,
                    emptyList(),
                    failingHandle);

    try (RocksDBStateDownloader downloader = new RocksDBStateDownloader(1)) {
        final StateHandleDownloadSpec downloadSpec =
                new StateHandleDownloadSpec(stateHandle, temporaryFolder.newFolder().toPath());
        downloader.transferAllStateDataToDirectory(
                singletonList(downloadSpec), new CloseableRegistry());
        fail("Expected IOException");
    } catch (IOException e) {
        assertEquals(rootCause, e.getCause());

        final String message = e.getMessage();
        final boolean looksMerged = message != null && message.contains("downloads failed");
        Assert.assertFalse(
                "Single failure should not produce a merged message, got: " + message,
                looksMerged);
    }
}

/**
 * Verifies that when one of several parallel downloads fails, the message of the injected root
 * cause can be found in the thrown {@link IOException}: on its direct cause, inside its own
 * message, or somewhere along its cause chain.
 */
@Test
public void testRootCauseVisibleAmongCascadeFailures() throws Exception {
    final int numRemoteHandles = 3;
    final int numSubHandles = 6;
    final byte[][][] contents = createContents(numRemoteHandles, numSubHandles);

    final List<StateHandleDownloadSpec> downloadRequests = new ArrayList<>(numRemoteHandles);
    for (int handleIndex = 0; handleIndex < numRemoteHandles; ++handleIndex) {
        final StateHandleDownloadSpec request =
                createDownloadRequestForContent(
                        temporaryFolder.newFolder().toPath(),
                        contents[handleIndex],
                        handleIndex);
        downloadRequests.add(request);
    }

    // Inject a single failing handle into the first request.
    final IOException rootCause = new IOException("state file missing from remote storage");
    final HandleAndLocalPath failingEntry =
            HandleAndLocalPath.of(new ThrowingStateHandle(rootCause), "error-handle");
    downloadRequests.get(0).getStateHandle().getSharedState().add(failingEntry);

    try (RocksDBStateDownloader downloader = new RocksDBStateDownloader(5)) {
        downloader.transferAllStateDataToDirectory(downloadRequests, new CloseableRegistry());
        fail("Expected IOException");
    } catch (IOException e) {
        final String expectedMessage = rootCause.getMessage();
        final Throwable directCause = e.getCause();
        final String actualMessage = e.getMessage();

        final boolean onDirectCause =
                directCause != null && expectedMessage.equals(directCause.getMessage());
        final boolean insideMessage =
                actualMessage != null && actualMessage.contains(expectedMessage);
        final boolean inCauseChain =
                ExceptionUtils.findThrowable(e, t -> expectedMessage.equals(t.getMessage()))
                        .isPresent();

        Assert.assertTrue(
                "Root cause '" + expectedMessage + "' should be visible in exception, got: " + e,
                onDirectCause || insideMessage || inCauseChain);
    }
}

/**
 * Verifies that when several downloads fail with different exceptions, the thrown {@link
 * IOException} carries the merged summary message and that message mentions every distinct
 * failure. The failing handles share a {@link CyclicBarrier} sized to the number of handles, so
 * each one waits for the others before throwing.
 */
@Test
public void testMultipleDistinctFailuresMergedInMessage() throws Exception {
    final int parallelism = 3;
    final CyclicBarrier barrier = new CyclicBarrier(parallelism);

    final IOException causeA = new IOException("error-A: bucket not accessible");
    final IOException causeB = new IOException("error-B: file not found");
    final IOException causeC = new IOException("error-C: read timeout");

    final List<HandleAndLocalPath> sharedState = new ArrayList<>();
    sharedState.add(
            HandleAndLocalPath.of(new BarrierThrowingStateHandle(barrier, causeA), "s1"));
    sharedState.add(
            HandleAndLocalPath.of(new BarrierThrowingStateHandle(barrier, causeB), "s2"));
    sharedState.add(
            HandleAndLocalPath.of(new BarrierThrowingStateHandle(barrier, causeC), "s3"));

    final IncrementalRemoteKeyedStateHandle stateHandle =
            new IncrementalRemoteKeyedStateHandle(
                    UUID.randomUUID(),
                    KeyGroupRange.EMPTY_KEY_GROUP_RANGE,
                    1L,
                    sharedState,
                    emptyList(),
                    sharedState.get(0).getHandle());

    try (RocksDBStateDownloader downloader = new RocksDBStateDownloader(parallelism)) {
        final StateHandleDownloadSpec downloadSpec =
                new StateHandleDownloadSpec(stateHandle, temporaryFolder.newFolder().toPath());
        downloader.transferAllStateDataToDirectory(
                singletonList(downloadSpec), new CloseableRegistry());
        fail("Expected IOException");
    } catch (IOException e) {
        final String message = e.getMessage();

        Assert.assertTrue(
                "Expected merged error message, got: " + message,
                message != null && message.contains("downloads failed with distinct errors"));
        Assert.assertTrue("Expected causeA in message", message.contains(causeA.getMessage()));
        Assert.assertTrue("Expected causeB in message", message.contains(causeB.getMessage()));
        Assert.assertTrue("Expected causeC in message", message.contains(causeC.getMessage()));
    }
}

/** Asserts that the file located at {@code path} contains exactly the {@code expected} bytes. */
private void assertStateContentEqual(byte[] expected, Path path) throws IOException {
    assertArrayEquals(expected, Files.readAllBytes(Paths.get(path.toUri())));
}

/**
 * A {@link StreamStateHandle} whose {@link #openInputStream()} first waits at a shared {@link
 * CyclicBarrier} (for at most 30 seconds) and then throws the configured exception. Handles
 * sharing one barrier therefore only start failing once all barrier parties have arrived.
 */
private static class BarrierThrowingStateHandle implements TestStreamStateHandle {
    private static final long serialVersionUID = 1L;

    /** Barrier shared with the other failing handles of the same test. */
    private final CyclicBarrier barrier;

    /** Exception thrown from {@link #openInputStream()} after the barrier has been passed. */
    private final IOException exception;

    private BarrierThrowingStateHandle(CyclicBarrier barrier, IOException exception) {
        this.barrier = barrier;
        this.exception = exception;
    }

    /**
     * Waits at the barrier and then always throws.
     *
     * @throws IOException the configured exception once the barrier was passed, or an {@link
     *     IOException} wrapping the barrier failure if waiting did not succeed
     */
    @Override
    public FSDataInputStream openInputStream() throws IOException {
        try {
            barrier.await(30, TimeUnit.SECONDS);
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                // Restore the interrupt flag so the executing thread still observes the
                // interruption after it has been translated into an IOException.
                Thread.currentThread().interrupt();
            }
            throw new IOException("Barrier interrupted", e);
        }
        throw exception;
    }

    @Override
    public Optional<byte[]> asBytesIfInMemory() {
        return Optional.empty();
    }

    @Override
    public void discardState() {}

    @Override
    public long getStateSize() {
        return 0;
    }
}

private static class SpecifiedException extends IOException {
SpecifiedException(String message) {
super(message);
Expand Down