diff --git a/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs b/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs index 000c0d5f7c5..e23258fa78f 100644 --- a/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs +++ b/src/Mono.Android/Xamarin.Android.Net/AndroidMessageHandler.cs @@ -104,6 +104,116 @@ public void Reset () } } + sealed class CancellationAwareResponseStream : Stream + { + readonly Stream stream; + readonly HttpURLConnection httpConnection; + int streamDisposed; + + public CancellationAwareResponseStream (Stream stream, HttpURLConnection httpConnection) + { + this.stream = stream ?? throw new ArgumentNullException (nameof (stream)); + this.httpConnection = httpConnection ?? throw new ArgumentNullException (nameof (httpConnection)); + } + + public override bool CanRead => stream.CanRead; + public override bool CanSeek => stream.CanSeek; + public override bool CanWrite => stream.CanWrite; + public override long Length => stream.Length; + + public override long Position { + get => stream.Position; + set => stream.Position = value; + } + + protected override void Dispose (bool disposing) + { + if (disposing) { + DisposeStream (); + } + + base.Dispose (disposing); + } + + public override void Flush () => stream.Flush (); + + public override async Task CopyToAsync (Stream destination, int bufferSize, CancellationToken cancellationToken) + { + cancellationToken.ThrowIfCancellationRequested (); + + using (cancellationToken.Register (QueueAbortRead, useSynchronizationContext: false)) { + try { + await stream.CopyToAsync (destination, bufferSize, cancellationToken).ConfigureAwait (false); + } catch (Exception ex) when (ShouldMapToCancellation (ex, cancellationToken)) { + throw new System.OperationCanceledException ("Response body read was canceled.", ex, cancellationToken); + } + } + } + + public override int Read (byte[] buffer, int offset, int count) => stream.Read (buffer, offset, count); + + public override Task ReadAsync (byte[] buffer, int offset, int count, CancellationToken cancellationToken) => ReadAsync (buffer.AsMemory (offset, count), cancellationToken).AsTask (); + + // StreamContent uses this overload on modern runtimes, so the wrapper must handle its ValueTask-based contract. + public override async ValueTask ReadAsync (Memory buffer, CancellationToken cancellationToken = default) + { + cancellationToken.ThrowIfCancellationRequested (); + + using (cancellationToken.Register (QueueAbortRead, useSynchronizationContext: false)) { + try { + return await stream.ReadAsync (buffer, cancellationToken).ConfigureAwait (false); + } catch (Exception ex) when (ShouldMapToCancellation (ex, cancellationToken)) { + throw new System.OperationCanceledException ("Response body read was canceled.", ex, cancellationToken); + } + } + } + + public override long Seek (long offset, SeekOrigin origin) => stream.Seek (offset, origin); + + public override void SetLength (long value) => stream.SetLength (value); + + public override void Write (byte[] buffer, int offset, int count) => stream.Write (buffer, offset, count); + + void QueueAbortRead () => + Task.Run (AbortRead).ContinueWith ( + task => Logger.Log (LogLevel.Info, LOG_APP, $"Response body cancellation exception: {task.Exception}"), + CancellationToken.None, + TaskContinuationOptions.OnlyOnFaulted | TaskContinuationOptions.ExecuteSynchronously, + TaskScheduler.Default); + + void AbortRead () + { + try { + httpConnection.Disconnect (); + } catch (Exception ex) { + Logger.Log (LogLevel.Info, LOG_APP, $"Disconnection exception: {ex}"); + } + + try { + DisposeStream (); + } catch (Exception ex) { + Logger.Log (LogLevel.Info, LOG_APP, $"Response stream close exception: {ex}"); + } + } + + void DisposeStream () + { + if (Interlocked.Exchange (ref streamDisposed, 1) == 0) + stream.Dispose (); + } + + static bool ShouldMapToCancellation (Exception ex, CancellationToken cancellationToken) + { + return cancellationToken.IsCancellationRequested && + ex is global::System.IO.IOException + or Java.IO.IOException + or InvalidDataException + or ObjectDisposedException + or WebException; + } + + } + internal const string LOG_APP = "monodroid-net"; const string GZIP_ENCODING = "gzip"; @@ -903,10 +1013,10 @@ Stream GetDecompressionWrapper (URLConnection httpConnection, Stream inputStream return ret ?? inputStream; } - HttpContent GetContent (URLConnection httpConnection, Stream contentStream, ContentState contentState) + HttpContent GetContent (HttpURLConnection httpConnection, Stream contentStream, ContentState contentState) { Stream inputStream = GetDecompressionWrapper (httpConnection, new BufferedStream (contentStream), contentState); - return new StreamContent (inputStream); + return new StreamContent (new CancellationAwareResponseStream (inputStream, httpConnection)); } bool HandleRedirect (HttpStatusCode redirectCode, HttpURLConnection httpConnection, RequestRedirectionState redirectState, out bool disposeRet) diff --git a/tests/Mono.Android-Tests/Mono.Android-Tests/Mono.Android.NET-Tests.csproj b/tests/Mono.Android-Tests/Mono.Android-Tests/Mono.Android.NET-Tests.csproj index c916be26314..a69ce7d5b53 100644 --- a/tests/Mono.Android-Tests/Mono.Android-Tests/Mono.Android.NET-Tests.csproj +++ b/tests/Mono.Android-Tests/Mono.Android-Tests/Mono.Android.NET-Tests.csproj @@ -136,6 +136,7 @@ + diff --git a/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs b/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs new file mode 100644 index 00000000000..0d6e5439a9a --- /dev/null +++ b/tests/Mono.Android-Tests/Mono.Android-Tests/Xamarin.Android.Net/AndroidMessageHandlerCancellationTests.cs @@ -0,0 +1,203 @@ +#nullable enable + +using System; +using System.Net; +using System.Net.Http; +using System.Net.Sockets; +using System.Threading; +using System.Threading.Tasks; + +using Xamarin.Android.Net; + +using NUnit.Framework; + +namespace Xamarin.Android.NetTests +{ + [TestFixture] + [Category ("AndroidMessageHandlerCancellation")] + [Category ("InetAccess")] + public class AndroidMessageHandlerCancellationTests + { + const int StalledResponseContentLength = 1024 * 1024; + const int BodyReadBlockDelayMilliseconds = 250; + const int PromptCancellationTimeoutMilliseconds = 3000; + + static readonly byte[] InitialResponseChunk = [42]; + StalledResponseServer? stalledResponseServer; + + [SetUp] + public void SetUp () + { + stalledResponseServer = new StalledResponseServer (); + } + + [TearDown] + public void TearDown () + { + var server = stalledResponseServer; + stalledResponseServer = null; + + // NUnitLite used by the on-device tests does not support async TearDown methods. + if (server != null) + server.Stop (); + } + + [Test] + public async Task ResponseContentReadBodyReadCancellationIsPrompt () + { + var server = stalledResponseServer ?? throw new InvalidOperationException ("The stalled response server was not initialized."); + using var handler = new AndroidMessageHandler (); + using var client = new HttpClient (handler); + using var cts = new CancellationTokenSource (); + using var request = new HttpRequestMessage (HttpMethod.Get, $"http://localhost:{server.Port}/"); + + Task readTask = client.SendAsync (request, HttpCompletionOption.ResponseContentRead, cts.Token); + + await WaitForBodyReadToBlock (server.BodyStartedTask).ConfigureAwait (false); + cts.Cancel (); + await AssertCanceledPromptly (readTask, server.ReleaseResponseBody).ConfigureAwait (false); + } + + [Test] + public async Task ResponseHeadersReadBodyReadCancellationIsPrompt () + { + var server = stalledResponseServer ?? throw new InvalidOperationException ("The stalled response server was not initialized."); + using var handler = new AndroidMessageHandler (); + using var client = new HttpClient (handler); + using var request = new HttpRequestMessage (HttpMethod.Get, $"http://localhost:{server.Port}/"); + using var response = await client.SendAsync (request, HttpCompletionOption.ResponseHeadersRead).ConfigureAwait (false); + using var readCts = new CancellationTokenSource (); + + Task readContentTask = response.Content.ReadAsByteArrayAsync (readCts.Token); + + await WaitForBodyReadToBlock (server.BodyStartedTask).ConfigureAwait (false); + readCts.Cancel (); + await AssertCanceledPromptly (readContentTask, server.ReleaseResponseBody).ConfigureAwait (false); + } + + static int GetAvailablePort () + { + using var tcpListener = new TcpListener (IPAddress.Loopback, 0); + tcpListener.Start (); + int port = ((IPEndPoint) tcpListener.LocalEndpoint).Port; + tcpListener.Stop (); + return port; + } + + static async Task WaitForBodyReadToBlock (Task bodyStarted) + { + var completed = await Task.WhenAny (bodyStarted, Task.Delay (PromptCancellationTimeoutMilliseconds)).ConfigureAwait (false); + if (completed != bodyStarted) + Assert.Fail ($"The test server did not start sending a response body within {PromptCancellationTimeoutMilliseconds}ms."); + + await bodyStarted.ConfigureAwait (false); + await Task.Delay (BodyReadBlockDelayMilliseconds).ConfigureAwait (false); + } + + static async Task AssertCanceledPromptly (Task readTask, Action releaseBody) + { + var completed = await Task.WhenAny (readTask, Task.Delay (PromptCancellationTimeoutMilliseconds)).ConfigureAwait (false); + if (completed != readTask) { + releaseBody (); + await ObserveReadTaskAfterRelease (readTask).ConfigureAwait (false); + Assert.Fail ($"Response body read did not observe cancellation within {PromptCancellationTimeoutMilliseconds}ms."); + } + + try { + await readTask.ConfigureAwait (false); + Assert.Fail ("Response body read completed successfully after cancellation."); + } catch (OperationCanceledException) { + return; + } + } + + static async Task ObserveReadTaskAfterRelease (Task readTask) + { + var completed = await Task.WhenAny (readTask, Task.Delay (PromptCancellationTimeoutMilliseconds)).ConfigureAwait (false); + if (completed != readTask) + return; + + try { + await readTask.ConfigureAwait (false); + } catch (Exception ex) { + Console.WriteLine ($"Exception after releasing stalled response body: {ex}"); + } + } + + sealed class StalledResponseServer + { + readonly HttpListener listener; + readonly TaskCompletionSource bodyStarted = new TaskCompletionSource (TaskCreationOptions.RunContinuationsAsynchronously); + readonly TaskCompletionSource releaseBody = new TaskCompletionSource (TaskCreationOptions.RunContinuationsAsynchronously); + readonly Task serverTask; + + public StalledResponseServer () + { + Port = GetAvailablePort (); + listener = new HttpListener (); + listener.Prefixes.Add ($"http://localhost:{Port}/"); + listener.Start (); + + serverTask = ServeStalledResponseBody (); + } + + public int Port { get; } + + public Task BodyStartedTask => bodyStarted.Task; + + public void Stop () + { + ReleaseResponseBody (); + listener.Close (); + ObserveServerTask ().GetAwaiter ().GetResult (); + } + + public void ReleaseResponseBody () + { + releaseBody.TrySetResult (true); + } + + async Task ServeStalledResponseBody () + { + try { + var context = await listener.GetContextAsync ().ConfigureAwait (false); + using var response = context.Response; + response.StatusCode = 200; + response.ContentLength64 = StalledResponseContentLength; + await response.OutputStream.WriteAsync (InitialResponseChunk, 0, InitialResponseChunk.Length).ConfigureAwait (false); + await response.OutputStream.FlushAsync ().ConfigureAwait (false); + bodyStarted.TrySetResult (true); + + await releaseBody.Task.ConfigureAwait (false); + await WriteRemainingResponseBody (response).ConfigureAwait (false); + } catch (Exception ex) { + if (!BodyStartedTask.IsCompleted) { + bodyStarted.TrySetException (ex); + return; + } + Console.WriteLine ($"Exception while serving stalled response body: {ex}"); + } + } + + async Task WriteRemainingResponseBody (HttpListenerResponse response) + { + var buffer = new byte [4096]; + int remainingBytes = StalledResponseContentLength - InitialResponseChunk.Length; + while (remainingBytes > 0) { + int bytesToWrite = Math.Min (remainingBytes, buffer.Length); + await response.OutputStream.WriteAsync (buffer, 0, bytesToWrite).ConfigureAwait (false); + remainingBytes -= bytesToWrite; + } + } + + async Task ObserveServerTask () + { + var completed = await Task.WhenAny (serverTask, Task.Delay (PromptCancellationTimeoutMilliseconds)).ConfigureAwait (false); + if (completed != serverTask) + return; + + await serverTask.ConfigureAwait (false); + } + } + } +}