diff --git a/.gitignore b/.gitignore index b01de2507..96b9f616a 100644 --- a/.gitignore +++ b/.gitignore @@ -9,3 +9,5 @@ ace __pycache__ crates/guest-rust/src/cabi_realloc.o wit_component + +/wit-bindgen.sln diff --git a/Cargo.lock b/Cargo.lock index 018b2d15e..7ebbe7ef6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1460,6 +1460,7 @@ dependencies = [ "clap", "heck", "indexmap", + "regex", "wasm-metadata 0.247.0", "wit-bindgen-core", "wit-component", diff --git a/crates/core/src/abi.rs b/crates/core/src/abi.rs index 4f8bace95..0c4635419 100644 --- a/crates/core/src/abi.rs +++ b/crates/core/src/abi.rs @@ -1256,7 +1256,7 @@ impl<'a, B: Bindgen> Generator<'a, B> { // based on slightly different logic for the `task.return` // intrinsic. // - // Note that in the async import case teh code below deals with the CM function being lowered, + // Note that in the async import case the code below deals with the CM function being lowered, // not the core function that is underneath that (i.e. func.result may be empty, // where the associated core function underneath must have a i32 status code result) let (lower_to_memory, async_flat_results) = match (async_, &func.result) { diff --git a/crates/csharp/Cargo.toml b/crates/csharp/Cargo.toml index 47dfa7d5d..c1773a5e4 100644 --- a/crates/csharp/Cargo.toml +++ b/crates/csharp/Cargo.toml @@ -28,6 +28,7 @@ heck = { workspace = true } clap = { workspace = true, optional = true } anyhow = { workspace = true } indexmap = { workspace = true } +regex = "1.12.3" [features] default = ["aot"] diff --git a/crates/csharp/src/AsyncSupport.cs b/crates/csharp/src/AsyncSupport.cs index 46cdbc1bb..37c617511 100644 --- a/crates/csharp/src/AsyncSupport.cs +++ b/crates/csharp/src/AsyncSupport.cs @@ -12,7 +12,7 @@ public enum EventCode { Cancel = 6, } -public enum CallbackCode : uint +public enum CallbackCode : int { Exit = 0, Yield = 1, @@ -68,6 +68,9 @@ private static class Interop [global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[context-get-0]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] internal static unsafe extern ContextTask* ContextGet(); + + [global::System.Runtime.InteropServices.DllImport("$root", EntryPoint = "[subtask-drop]"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] + internal static unsafe extern void SubtaskDrop(int handle); } public static int WaitableSetNew() @@ -130,7 +133,7 @@ public static unsafe void ContextSet(ContextTask* contextTask) } // unsafe because we are using pointers. - public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr) + public static unsafe int Callback(EventWaitable e) { ContextTask* contextTaskPtr = ContextGet(); @@ -139,7 +142,7 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr) if (e.IsDropped) { - waitableInfoState.FutureStream.OtherSideDropped(); + waitableInfoState.FutureStream!.OtherSideDropped(); } if (e.IsCompleted || e.IsDropped) @@ -148,13 +151,25 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr) waitables.Remove(e.Waitable, out _); if (e.IsSubtask) { - // TODO: Handle/lift async function return values. - waitableInfoState.SetResult(0 /* not used */); + switch (e.SubtaskStatus) + { + case { IsStarting: true }: + throw new Exception("unexpected subtask status Starting " + e.Code); + + case { IsStarted: true }: + break; + + case { IsReturned: true }: + waitableInfoState.SetResult(e.WaitableCount); + Interop.SubtaskDrop(e.Waitable); + break; + + default: + throw new Exception("TODO: subtask status " + e.Code); + } } else { - waitableInfoState.FutureStream.FreeBuffer(); - if (e.IsDropped) { waitableInfoState.SetException(new StreamDroppedException()); @@ -170,10 +185,10 @@ public static unsafe uint Callback(EventWaitable e, ContextTask* contextPtr) { ContextSet(null); Marshal.FreeHGlobal((IntPtr)contextTaskPtr); - return (uint)CallbackCode.Exit; + return (int)CallbackCode.Exit; } - return (uint)CallbackCode.Wait | (uint)(contextTaskPtr->WaitableSetHandle << 4); + return (int)CallbackCode.Wait | (int)(contextTaskPtr->WaitableSetHandle << 4); } throw new NotImplementedException($"WaitableStatus not implemented {e.WaitableStatus.State} in set {contextTaskPtr->WaitableSetHandle}"); @@ -186,7 +201,7 @@ internal static unsafe Task TaskFromStatus(uint status) status = status & 0xF; var tcs = new TaskCompletionSource(); - if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) + if (subtaskStatus.IsStarting || subtaskStatus.IsStarted) { ContextTask* contextTaskPtr = ContextGet(); if (contextTaskPtr == null) @@ -198,10 +213,10 @@ internal static unsafe Task TaskFromStatus(uint status) return tcs.Task; } - else if (subtaskStatus.IsSubtaskReturned) + else if (subtaskStatus.IsReturned) { tcs.SetResult(0); - return Task.CompletedTask; + return tcs.Task; } else { @@ -214,7 +229,7 @@ public static unsafe Task TaskFromStatus(uint status, Func liftFunc) { var subtaskStatus = new SubtaskStatus(status); - if (subtaskStatus.IsSubtaskStarting || subtaskStatus.IsSubtaskStarted) + if (subtaskStatus.IsStarting || subtaskStatus.IsStarted) { ContextTask* contextTaskPtr = ContextGet(); if (contextTaskPtr == null) { @@ -227,7 +242,7 @@ public static unsafe Task TaskFromStatus(uint status, Func liftFunc) return tcs.Task; } - else if (subtaskStatus.IsSubtaskReturned) + else if (subtaskStatus.IsReturned) { var tcs = new TaskCompletionSource(); tcs.SetResult(liftFunc()); @@ -242,11 +257,30 @@ public static unsafe Task TaskFromStatus(uint status, Func liftFunc) // Placeholder, TODO: Needs implementing for async functions that return values. internal class LiftingTaskCompletionSource : TaskCompletionSource { - internal LiftingTaskCompletionSource(TaskCompletionSource innerTaskCompletionSource, Func _liftFunc) + internal LiftingTaskCompletionSource(TaskCompletionSource innerTaskCompletionSource, Func liftFunc) { innerTaskCompletionSource.Task.ContinueWith(t => { - throw new NotImplementedException("lifting results from async functions not implemented yet"); - }); + if (t.Status == TaskStatus.RanToCompletion) + { + try + { + SetResult(liftFunc()); + } + catch(Exception e) + { + SetException(e); + } + } + else if (t.Status == TaskStatus.Faulted) + { + SetException(t.Exception!); + } + else if (t.Status == TaskStatus.Canceled) + { + SetCanceled(); + } + throw new NotImplementedException("LiftingTaskCompletionSource unexpected task status " + t.Status); + }, TaskContinuationOptions.ExecuteSynchronously); } } @@ -272,26 +306,17 @@ internal LiftingTaskCompletionSource(TaskCompletionSource innerTaskCompleti public delegate uint StreamWrite(int handle, IntPtr buffer, uint length); public delegate uint StreamRead(int handle, IntPtr buffer, uint length); -public delegate void Lower(object payload, uint size); +public delegate Array Lift(IntPtr buffer, Array? resultBuffer); +public delegate void Lower(object payload, List cleanups); public delegate uint CancelRead(int handle); public delegate uint CancelWrite(int handle); -public interface ICancelableRead -{ - uint CancelRead(int handle); -} - -public interface ICancelableWrite -{ - uint CancelWrite(int handle); -} - public interface ICancelable { uint Cancel(); } -public class CancelableRead(ICancelableRead cancelableVTable, int handle) : ICancelable +public class CancelableRead(IVTable cancelableVTable, int handle) : ICancelable { public uint Cancel() { @@ -299,7 +324,7 @@ public uint Cancel() } } -public class CancelableWrite(ICancelableWrite cancelableVTable, int handle) : ICancelable +public class CancelableWrite(IVTable cancelableVTable, int handle) : ICancelable { public uint Cancel() { @@ -307,17 +332,35 @@ public uint Cancel() } } -public struct FutureVTable : ICancelableRead, ICancelableWrite +/// +/// Common to all VTables. +/// +public interface IVTable +{ + uint CancelRead(int handle); + uint CancelWrite(int handle); + uint Size { get; set; } + uint Align { get; set; } +} + +public struct FutureVTable : IVTable { + // Generated code even if we are not using futures, so disable the warning. +#pragma warning disable 649 internal New New; internal FutureRead Read; internal FutureWrite Write; internal DropReader DropReader; internal DropWriter DropWriter; + internal Lift? Lift; internal Lower? Lower; internal CancelWrite CancelWriteDelegate; internal CancelRead CancelReadDelegate; +#pragma warning disable 649 + // The size and alignment of the buffer. + public uint Size { get; set; } + public uint Align { get; set; } public uint CancelRead(int handle) { return CancelReadDelegate(handle); @@ -329,16 +372,24 @@ public uint CancelWrite(int handle) } } -public struct StreamVTable : ICancelableRead, ICancelableWrite +public struct StreamVTable : IVTable { + // Generated code even if we are not using streams, so disable the warning. +#pragma warning disable 649 internal New New; internal StreamRead Read; internal StreamWrite Write; internal DropReader DropReader; internal DropWriter DropWriter; + internal Lift? Lift; internal Lower? Lower; internal CancelWrite CancelWriteDelegate; internal CancelRead CancelReadDelegate; +#pragma warning disable 649 + + // The size and alignment of the buffer. + public uint Size { get; set; } + public uint Align { get; set; } public uint CancelRead(int handle) { @@ -353,7 +404,6 @@ public uint CancelWrite(int handle) internal interface IFutureStream : IDisposable { - void FreeBuffer(); // Called when notified the other side is dropped. void OtherSideDropped(); } @@ -403,7 +453,6 @@ internal static (StreamReader, StreamWriter) RawStreamNew(StreamVTable public abstract class ReaderBase : IFutureStream { - private GCHandle? bufferHandle; private bool writerDropped; internal ReaderBase(int handle) @@ -424,24 +473,54 @@ internal int TakeHandle() return handle; } - protected GCHandle LiftBuffer(T[] buffer) + protected unsafe IntPtr GetBuffer(int length, T[]? userBuffer, IVTable vTable, List cleanups) + { + // For primitive, blittable types, TODO: this probably does not align 100% with the component ABI? + if (typeof(T).IsPrimitive || typeof(T).IsValueType) + { + T[] buffer; + // For Streams, the user passes the buffer, so use that. + if(userBuffer != null) + { + buffer = userBuffer; + } + else + { + buffer = new T[length]; + } + var handle = GCHandle.Alloc(buffer, GCHandleType.Pinned); + cleanups.Add(() => handle.Free()); + return handle.AddrOfPinnedObject(); + } + else + { + System.Diagnostics.Debug.Assert(vTable.Size > 0, $"Did not compute size for {typeof(T)}."); + IntPtr bufferPtr = (IntPtr)global::System.Runtime.InteropServices.NativeMemory.AlignedAlloc(vTable.Size, vTable.Align); + cleanups.Add(() => global::System.Runtime.InteropServices.NativeMemory.Free((void*)bufferPtr)); + return bufferPtr; + } + } + + protected unsafe T[] LiftBuffer(IntPtr buffer, T[] resultBuffer, Lift? liftFunc) { // For primitive, blittable types if (typeof(T).IsPrimitive || typeof(T).IsValueType) { - return GCHandle.Alloc(buffer, GCHandleType.Pinned); + // TODO array length > 1 + resultBuffer[0] = *(T*)buffer; } else { - // TODO: create buffers for lowered stream types and then lift - throw new NotImplementedException("reading from futures types that require lifting"); + liftFunc(buffer, resultBuffer); } + + return resultBuffer; } internal abstract uint VTableRead(IntPtr bufferPtr, int length); // unsafe as we are working with pointers. - internal unsafe ComponentTask ReadInternal(Func liftBuffer, int length, ICancelableRead cancelableRead) + internal unsafe ComponentTask ReadInternal(IntPtr buffer, int length, IVTable vtable) { if (Handle == 0) { @@ -453,11 +532,10 @@ internal unsafe ComponentTask ReadInternal(Func liftBuffer, int throw new StreamDroppedException(); } - bufferHandle = liftBuffer(); - var status = new WaitableStatus(VTableRead(bufferHandle == null ? IntPtr.Zero : bufferHandle.Value.AddrOfPinnedObject(), length)); + var status = new WaitableStatus(VTableRead(buffer, length)); if (status.IsBlocked) { - var task = new ComponentTask(new CancelableRead(cancelableRead, Handle)); + var task = new ComponentTask(new CancelableRead(vtable, Handle)); ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); if(contextTaskPtr == null) { @@ -475,11 +553,6 @@ internal unsafe ComponentTask ReadInternal(Func liftBuffer, int throw new NotImplementedException(status.State.ToString()); } - void IFutureStream.FreeBuffer() - { - bufferHandle?.Free(); - } - void IFutureStream.OtherSideDropped() { writerDropped = true; @@ -519,7 +592,7 @@ internal FutureReader(int handle, FutureVTable vTable) : base(handle) public ComponentTask Read() { - return ReadInternal(() => null, 0, VTable); + return ReadInternal(IntPtr.Zero, 0, VTable); } internal override uint VTableRead(IntPtr ptr, int length) @@ -539,8 +612,9 @@ public class FutureReader(int handle, FutureVTable vTable) : ReaderBase(handl public ComponentTask Read() { - T[] buf = new T[1]; - ComponentTask internalTask = ReadInternal(() => LiftBuffer(buf), 1, VTable); + var cleanups = new List(); + var buf = GetBuffer(1, null /* We need the buffer created for us */, VTable, cleanups); + ComponentTask internalTask = ReadInternal(buf, 1, VTable); // Wrap the task so we can return a T and not the number of Ts read ComponentTask readTask = new(new DelegatingCancelable(internalTask)); @@ -549,13 +623,25 @@ public ComponentTask Read() { if (it.IsCompletedSuccessfully) { - readTask.SetResult(buf[0]); + try + { + readTask.SetResult(((T[])VTable.Lift(buf, new T[1]))[0]); + } + catch(Exception e) + { + readTask.SetException(e); + } } else if (!it.IsCanceled) { //TODO throw new NotImplementedException("faulted future read not implemented"); } + + foreach(var cleanup in cleanups) + { + cleanup(); + } }); return readTask; } @@ -598,7 +684,7 @@ public StreamReader(int handle, StreamVTable vTable) : base(handle) public ComponentTask Read(int length) { - return ReadInternal(() => null, length, VTable); + return ReadInternal(IntPtr.Zero, length, VTable); } internal override uint VTableRead(IntPtr ptr, int length) @@ -616,9 +702,31 @@ public class StreamReader(int handle, StreamVTable vTable) : ReaderBase(hand { public StreamVTable VTable { get; private set; } = vTable; - public ComponentTask Read(T[] buffer) + public ComponentTask Read(T[] resultBuffer) { - return ReadInternal(() => LiftBuffer(buffer), buffer.Length, VTable); + var cleanups = new List(); + var buf = GetBuffer(resultBuffer.Length, resultBuffer, VTable, cleanups); + + var task = ReadInternal(buf, resultBuffer.Length, VTable); + task.ContinueWith(it => + { + if (it.IsCompletedSuccessfully) + { + VTable.Lift(buf, resultBuffer); + } + else if (!it.IsCanceled) + { + //TODO + throw new NotImplementedException("faulted stream read not implemented"); + } + + foreach(var cleanup in cleanups) + { + cleanup(); + } + }); + + return task; } internal override uint VTableRead(IntPtr ptr, int length) @@ -634,7 +742,7 @@ internal override void VTableDrop() public abstract class WriterBase : IFutureStream { - private GCHandle? bufferHandle; + private nint bufferPtr; private bool readerDropped; private bool canDrop; @@ -659,7 +767,7 @@ internal int TakeHandle() internal abstract uint VTableWrite(IntPtr bufferPtr, int length); // unsafe as we are working with pointers. - internal unsafe ComponentTask WriteInternal(Func lowerPayload, int length, ICancelableWrite cancelable) + internal unsafe ComponentTask WriteInternal(Func, nint> lowerPayload, int length, IVTable cancelable) { if (Handle == 0) { @@ -670,15 +778,21 @@ internal unsafe ComponentTask WriteInternal(Func lowerPayload, i { throw new StreamDroppedException(); } - bufferHandle = lowerPayload(); + var cleanups = new List(); + bufferPtr = lowerPayload(cleanups); - var status = new WaitableStatus(VTableWrite(bufferHandle == null ? IntPtr.Zero : bufferHandle.Value.AddrOfPinnedObject(), length)); + var status = new WaitableStatus(VTableWrite(bufferPtr, length)); canDrop = true; // We can only call drop once something has been written. if (status.IsBlocked) { var tcs = new ComponentTask(new CancelableWrite(cancelable, Handle)); tcs.ContinueWith(t => { + foreach(var cleanup in cleanups) + { + cleanup(); + } + if (t.IsCanceled) { canDrop = false; @@ -696,18 +810,12 @@ internal unsafe ComponentTask WriteInternal(Func lowerPayload, i if (status.IsCompleted) { - bufferHandle?.Free(); return ComponentTask.FromResult((int)status.Count); } throw new NotImplementedException($"Unsupported write status {status.State}"); } - void IFutureStream.FreeBuffer() - { - bufferHandle?.Free(); - } - void IFutureStream.OtherSideDropped() { readerDropped = true; @@ -742,7 +850,7 @@ public class FutureWriter(int handle, FutureVTable vTable) : WriterBase(handle) public ComponentTask Write() { - return WriteInternal(() => null, 0, VTable); + return WriteInternal(_ => 0, 0, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -760,23 +868,23 @@ public class FutureWriter(int handle, FutureVTable vTable) : WriterBase(handl { public FutureVTable VTable { get; private set; } = vTable; - private GCHandle LowerPayload(T[] payload) + private nint LowerPayload(T payload, List cleanups) { if (VTable.Lower == null) { - return GCHandle.Alloc(payload, GCHandleType.Pinned); + return GCHandle.Alloc(payload, GCHandleType.Pinned).AddrOfPinnedObject(); } else { // Lower the payload - throw new NotSupportedException("StreamWriter Write where the payload must be lowered."); - // var loweredPayload = VTable.Lower(payload); + VTable.Lower(payload, cleanups); + return InteropReturnArea.returnArea.AddressOfReturnArea(); } } public ComponentTask Write(T payload) { - return WriteInternal(() => LowerPayload([payload]), 1, VTable); + return WriteInternal(cleanups => LowerPayload(payload, cleanups), 1, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -796,7 +904,7 @@ public class StreamWriter(int handle, StreamVTable vTable) : WriterBase(handle) public ComponentTask Write() { - return WriteInternal(() => null, 0, VTable); + return WriteInternal(_ => 0, 0, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -812,26 +920,26 @@ internal override void VTableDrop() public class StreamWriter(int handle, StreamVTable vTable) : WriterBase(handle) { - private GCHandle bufferHandle; + private nint bufferPtr; public StreamVTable VTable { get; private set; } = vTable; - private GCHandle LowerPayload(T[] payload) + private nint LowerPayload(T[] payload, List cleanups) { if (VTable.Lower == null) { - return GCHandle.Alloc(payload, GCHandleType.Pinned); + return GCHandle.Alloc(payload, GCHandleType.Pinned).AddrOfPinnedObject(); } else { // Lower the payload - throw new NotSupportedException("StreamWriter Write where the payload must be lowered."); - // var loweredPayload = VTable.Lower(payload); + VTable.Lower(payload, cleanups); + return InteropReturnArea.returnArea.AddressOfReturnArea(); } } public ComponentTask Write(T[] payload) { - return WriteInternal(() => LowerPayload(payload), payload.Length, VTable); + return WriteInternal(cleanups => LowerPayload(payload, cleanups), payload.Length, VTable); } internal override uint VTableWrite(IntPtr bufferPtr, int length) @@ -988,7 +1096,7 @@ public void SetResult(T result) tcs.SetResult(result); } - public static ComponentTask FromResult(T result) + public static ComponentTask FromResult(T result) { var task = new ComponentTask(); task.tcs.SetResult(result); diff --git a/crates/csharp/src/FutureCommonSupport.cs b/crates/csharp/src/FutureCommonSupport.cs index 0db3f3347..aac3d5b0d 100644 --- a/crates/csharp/src/FutureCommonSupport.cs +++ b/crates/csharp/src/FutureCommonSupport.cs @@ -16,18 +16,17 @@ public readonly struct SubtaskStatus (uint status) { public uint State => status & 0xf; public int Handle => (int)(status >> 4); - public bool IsSubtaskStarting => State == 0; - public bool IsSubtaskStarted => State == 1; - public bool IsSubtaskReturned => State == 2; - public bool IsSubtaskStartedCancelled => State == 3; - public bool IsSubtaskReturnedCancelled => State == 4; + public bool IsStarting => State == 0; + public bool IsStarted => State == 1; + public bool IsReturned => State == 2; + public bool IsStartedCancelled => State == 3; + public bool IsReturnedCancelled => State == 4; } public readonly struct EventWaitable { public EventWaitable(EventCode eventCode, uint waitable, uint code) { - Console.WriteLine($"EventWaitable with code {code}"); EventCode = eventCode; Waitable = (int)waitable; Code = code; @@ -52,6 +51,6 @@ public EventWaitable(EventCode eventCode, uint waitable, uint code) public readonly SubtaskStatus SubtaskStatus; public readonly int WaitableCount => (int)Code >> 4; public bool IsDropped => !IsSubtask && WaitableStatus.IsDropped; - public bool IsCompleted => IsSubtask && SubtaskStatus.IsSubtaskReturned || !IsSubtask && WaitableStatus.IsCompleted; + public bool IsCompleted => IsSubtask && SubtaskStatus.IsReturned || !IsSubtask && WaitableStatus.IsCompleted; } diff --git a/crates/csharp/src/function.rs b/crates/csharp/src/function.rs index 046aa2491..c35f00cb1 100644 --- a/crates/csharp/src/function.rs +++ b/crates/csharp/src/function.rs @@ -2,6 +2,7 @@ use crate::csharp_ident::ToCSharpIdent; use crate::interface::{InterfaceGenerator, ParameterType, variant_new_func_name}; use crate::world_generator::CSharp; use heck::ToUpperCamelCase; +use regex::Regex; use std::fmt::Write; use std::mem; use std::ops::Deref; @@ -299,6 +300,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { target: String, func_name: String, oper: String, + use_await: bool, ) -> String { let ret = self.locals.tmp("ret"); if self.interface_gen.csharp_gen.opts.with_wit_results { @@ -312,7 +314,7 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { .type_name_with_qualifier(&func.result.unwrap(), true); let is_async = InterfaceGenerator::is_async(&func.kind); - if is_async { + if is_async && !use_await { uwriteln!(self.src, "Task<{ty}> {ret};"); } else { uwriteln!(self.src, "{ty} {ret};"); @@ -362,7 +364,11 @@ impl<'a, 'b> FunctionBindgen<'a, 'b> { } else { format!("{target}.{func_name}({oper})") }; - uwriteln!(self.src, "{ret} = {head}{val}{tail};"); + uwriteln!( + self.src, + "{ret} = {}{head}{val}{tail};", + if use_await { "await " } else { "" } + ); if !self.results.is_empty() { self.interface_gen.csharp_gen.needs_wit_exception = true; let cases = cases.join("\n"); @@ -1229,6 +1235,71 @@ impl Bindgen for FunctionBindgen<'_, '_> { } let is_async = InterfaceGenerator::is_async(self.kind); + if is_async && self.interface_gen.direction == Direction::Export { + // The UCO method cannot be async so we create and call another async method. + // This allows us to follow the same codegen pattern as other languages. + self.interface_gen.csharp_gen.needs_async_support = true; + let async_func_name = format!("{}Async", self.func_name.to_upper_camel_case()); + + uwriteln!( + self.src, + r#"var task = {async_func_name}({oper}); + if (task.IsCompletedSuccessfully) + {{ + return (int)CallbackCode.Exit; + }} + + // TODO: Defer dropping borrowed resources until a result is returned. + ContextTask* contextTaskPtr = AsyncSupport.ContextGet(); + + return (int)CallbackCode.Wait | (int)(contextTaskPtr->WaitableSetHandle << 4); + }} + "# + ); + + // Start the Async function + uwriteln!( + self.src, + r#"public static async Task {async_func_name}({}) + {{ + var cleanups = new global::System.Collections.Generic.List(); + + "#, + func.params + .iter() + .enumerate() + .map(|(i, p)| { + let mut param_type = + self.interface_gen.type_name_with_qualifier(&p.ty, false); + + // Resource types need the Impl class to be distinguised. + match p.ty { + Type::Id(type_id) => { + let id = dealias(self.interface_gen.resolve, type_id); + + let kind = &self.interface_gen.resolve.types[id].kind; + match kind { + TypeDefKind::Handle(handle) => { + let (Handle::Own(ty) | Handle::Borrow(ty)) = handle; + param_type = + self.interface_gen.csharp_gen.all_resources + [&ty] + .export_impl_name(); + } + _ => {} + } + } + _ => {} + } + + format!("{} {}", param_type, strip_lift(&operands[i])) + }) + .collect::>() + .join(", ") + ); + self.needs_cleanup = true; + } + match self.kind { FunctionKind::Constructor(id) => { let target = @@ -1251,20 +1322,26 @@ impl Bindgen for FunctionBindgen<'_, '_> { match func.result { None => { if is_async { - uwriteln!(self.src, "var ret = {target}.{func_name}({oper});"); + uwriteln!(self.src, "await {target}.{func_name}({oper});"); } else { uwriteln!(self.src, "{target}.{func_name}({oper});"); } } Some(_ty) => { - let ret = self.handle_result_call(func, target, func_name, oper); + let ret = self.handle_result_call( + func, + target, + func_name, + oper, + is_async && self.interface_gen.direction == Direction::Export, + ); results.push(ret); } } } } - if is_async { + if is_async && self.interface_gen.direction == Direction::Import { self.interface_gen.csharp_gen.needs_async_support = true; let name = self.func_name.to_upper_camel_case(); let ret_param = match func.result { @@ -1557,11 +1634,17 @@ impl Bindgen for FunctionBindgen<'_, '_> { Instruction::FutureLower { payload, ty: _ } | Instruction::StreamLower { payload, ty: _ } => { let op = &operands[0]; - let generic_type_name = match payload { - Some(generic_type) => &self - .interface_gen - .type_name_with_qualifier(generic_type, false), - None => "", + let (generic_type_name, generic_type_name_with_qualifier) = match payload { + Some(generic_type) => { + let name = self + .interface_gen + .type_name_with_qualifier(generic_type, false); + let qualified_name = self + .interface_gen + .type_name_with_qualifier(generic_type, true); + (name, qualified_name) + } + None => (String::new(), String::new()), }; match inst { @@ -1569,6 +1652,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { self.interface_gen.add_future( self.func_name, &generic_type_name, + &generic_type_name_with_qualifier, **payload, ); } @@ -1576,6 +1660,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { self.interface_gen.add_stream( self.func_name, &generic_type_name, + &generic_type_name_with_qualifier, **payload, ); } @@ -1584,8 +1669,26 @@ impl Bindgen for FunctionBindgen<'_, '_> { results.push(format!("{op}.TakeHandle()")); } - Instruction::AsyncTaskReturn { name: _, params: _ } => { + Instruction::AsyncTaskReturn { name, params: _ } => { + let name = name + .strip_prefix("[task-return]") + .unwrap() + .to_upper_camel_case(); uwriteln!(self.src, "// TODO: task_cancel.forget();"); + if self.interface_gen.direction == Direction::Export { + uwriteln!(self.src, "{name}TaskReturn({});", operands.join(", ")); + + if self.needs_cleanup { + uwriteln!( + self.src, + " + foreach (var cleanup in cleanups) + {{ + cleanup(); + }}" + ); + } + } } Instruction::FutureLift { payload, ty: _ } @@ -1636,6 +1739,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { self.interface_gen.add_future( self.func_name, &generic_type_name, + &generic_type_name_with_qualifier, **payload, ); } @@ -1643,6 +1747,7 @@ impl Bindgen for FunctionBindgen<'_, '_> { self.interface_gen.add_stream( self.func_name, &generic_type_name, + &generic_type_name_with_qualifier, **payload, ); } @@ -1768,6 +1873,14 @@ impl Bindgen for FunctionBindgen<'_, '_> { } } +// TODO: this is not great, we want the underlying parameter, but it is passed in operands already lifted. +// This regex will transform unchecked((uint)(p0)) -> p0 +pub fn strip_lift(lifted_param: &String) -> String { + let re = Regex::new(r"(?x)unchecked\(\s*\(\s*\w+\s*\)\s*\(\s*(\w+)\s*\)\s*\)").unwrap(); + let out = re.replace_all(lifted_param, "$1"); + out.into_owned() +} + /// Dereference any number `TypeDefKind::Type` aliases to retrieve the target type. pub fn dealias(resolve: &Resolve, mut id: TypeId) -> TypeId { loop { diff --git a/crates/csharp/src/interface.rs b/crates/csharp/src/interface.rs index af3c36e78..bc76aff9d 100644 --- a/crates/csharp/src/interface.rs +++ b/crates/csharp/src/interface.rs @@ -20,6 +20,8 @@ use wit_parser::{ Record, Resolve, Result_, Tuple, Type, TypeDefKind, TypeId, TypeOwner, Variant, WorldKey, }; +const MAX_FLAT_PARAMS: usize = 16; + pub(crate) struct InterfaceFragment { pub(crate) csharp_src: String, pub(crate) csharp_interop_src: String, @@ -41,9 +43,11 @@ impl InterfaceTypeAndFragments { } } +#[derive(Clone)] pub(crate) struct FutureInfo { pub name: String, pub generic_type_name: String, + pub qualified_generic_type_name: String, pub ty: Option, } @@ -230,18 +234,23 @@ impl InterfaceGenerator<'_> { let mut generated_future_types: HashSet> = HashSet::new(); let (_namespace, interface_name) = &CSharp::get_class_name_from_qualified_name(self.name); let interop_name = format!("{}Interop", interface_name.strip_prefix("I").unwrap()); - let (futures_or_streams, stream_length_param) = if is_future { - (&self.futures, "") + // avoid the immutable self borrow + let (futures_or_streams, stream_length_param): (Vec, &str) = if is_future { + (self.futures.iter().cloned().collect(), "") } else { - (&self.streams, ", uint length") + (self.streams.iter().cloned().collect(), ", uint length") }; + let mut index = 0; + let mut size = 0; + let mut align = 0; for future in futures_or_streams { // This code originally copied from Rust codegen generate_payload. // See the rust codegen for the comment - essentially we canonicalize to one per type. let canonical_payload = match future.ty { Some(Type::Id(id)) => { let id = self.csharp_gen.types.get_representative_type(id); + match self.resolve.types[id].kind { TypeDefKind::Type(t) => Some(t), _ => Some(Type::Id(id)), @@ -254,10 +263,23 @@ impl InterfaceGenerator<'_> { continue; } } - + let mut lift_func = "null".to_string(); + let mut lower_func = "null".to_string(); let future_name = &future.name; let generic_type_name = &future.generic_type_name; let upper_camel_future_type = generic_type_name.to_upper_camel_case(); + + if let Some(payload) = canonical_payload { + //TODO: wasm64 + size = self.csharp_gen.sizes.size(&payload).size_wasm32(); + align = self.csharp_gen.sizes.align(&payload).align_wasm32(); + + if needs_ptr(&payload) { + lift_func = format!("{future_stream_name}Lift{upper_camel_future_type}"); + lower_func = format!("{future_stream_name}Lower{upper_camel_future_type}"); + } + } + uwrite!( self.csharp_interop_src, r#" @@ -270,6 +292,10 @@ impl InterfaceGenerator<'_> { DropWriter = {future_stream_name}DropWriter{upper_camel_future_type}, CancelReadDelegate = {future_stream_name}CancelRead{upper_camel_future_type}, CancelWriteDelegate = {future_stream_name}CancelWrite{upper_camel_future_type}, + Lift = {lift_func}, + Lower = {lower_func}, + Size = {size}, + Align = {align}, }}; "# ); @@ -364,6 +390,55 @@ impl InterfaceGenerator<'_> { direction: Some(self.direction), }); + if lift_func != "null" + && let Some(payload_type) = canonical_payload + { + let (lift, result) = self.lift_from_memory("buffer", &payload_type); + + uwrite!( + self.csharp_interop_src, + r#" + public static unsafe Array {future_stream_name}Lift{upper_camel_future_type}(IntPtr buffer, Array resultBuffer) + {{ + {lift} + + // TODO: length > 1 + resultBuffer.SetValue({result}, 0); + return resultBuffer; + }} + "# + ); + } + + if lower_func != "null" + && let Some(payload_type) = canonical_payload + { + let lower_code = self.lower_to_memory("ptr", "typedToLower", &payload_type); + let size_align = self.csharp_gen.sizes.params(Some(&payload_type)); + // TODO: Wasm64 + self.csharp_gen.return_area_size = self + .csharp_gen + .return_area_size + .max(size_align.size.size_wasm32()); + self.csharp_gen.return_area_align = self + .csharp_gen + .return_area_align + .max(size_align.align.align_wasm32()); + self.csharp_gen.needs_export_return_area = true; + let qualified_generic_type_name = &future.qualified_generic_type_name; + uwrite!( + self.csharp_interop_src, + r#" + public static unsafe void {future_stream_name}Lower{upper_camel_future_type}(object toLower, List cleanups) + {{ + var ptr = InteropReturnArea.returnArea.AddressOfReturnArea(); + var typedToLower = ({qualified_generic_type_name})toLower; + {lower_code} + }} + "# + ); + } + if !bool_generic_new_added { self.csharp_gen .interface_fragments @@ -472,9 +547,9 @@ impl InterfaceGenerator<'_> { if requires_async_return_buffer_param { if param_list.is_empty() { - "void *taskResultBuffer".to_string() + "nint taskResultBuffer".to_string() } else { - format!("{param_list}, void *taskResultBuffer") + format!("{param_list}, nint taskResultBuffer") } } else { param_list @@ -673,12 +748,19 @@ impl InterfaceGenerator<'_> { let name = func.name.to_upper_camel_case(); let raw_name = format!("IImportsInterop.{name}WasmInterop.wasmImport{name}"); - let wasm_params = wasm_params - .iter() - .map(|v| v.as_str()) - .chain(func.result.map(|_| "address")) - .collect::>() - .join(", "); + let mut wasm_param_refs: Vec<&str> = wasm_params.iter().map(|v| v.as_str()).collect(); + + let async_return_buffer_prefixed; + + if func.result.is_some() { + // Build a new owned string with the prefix + async_return_buffer_prefixed = + Some(format!("(nint){}", async_return_buffer.as_ref().unwrap())); + // Borrow from that owned string + wasm_param_refs.push(async_return_buffer_prefixed.as_ref().unwrap().as_str()); + } + + let wasm_params = wasm_param_refs.join(", "); // TODO: lift expr let code = format!( @@ -686,19 +768,23 @@ impl InterfaceGenerator<'_> { var {async_status_var} = {raw_name}({wasm_params}); " ); - src = format!("{code}{}", bindgen.src); + src = format!("{code}"); + + let bindgen_src = bindgen.src; if let Some(buffer) = async_return_buffer { let ty = bindgen.result_type.expect("expected a result type"); - let lift_expr = abi::lift_from_memory( - bindgen.interface_gen.resolve, - &mut bindgen, - buffer.clone(), - &ty, - ); + let (lift_expr, res) = self.lift_from_memory("address", &ty); + uwriteln!(src, "{}", bindgen_src); let return_type = self.type_name_with_qualifier(&ty, true); - let lift_func = format!("() => {lift_expr}"); + let lift_func = format!( + "() => {{ + {lift_expr} + + return {res}; + }}" + ); uwriteln!( src, " @@ -807,7 +893,7 @@ var {async_status_var} = {raw_name}({wasm_params}); .join(";\n"); let wasm_result_type = if async_ { - "uint" + "int" } else { match &sig.results[..] { [] => "void", @@ -881,24 +967,13 @@ var {async_status_var} = {raw_name}({wasm_params}); "# ); - // TODO: Get the results from a static dictionary? - if sig.results.len() > 0 { - uwriteln!( - self.csharp_interop_src, - r#" - throw new NotImplementedException("callbacks with parameters are not yet implemented."); - }} - "# - ); - } else { - uwriteln!( - self.csharp_interop_src, - r#" - return (uint)AsyncSupport.Callback(e, (ContextTask *)IntPtr.Zero); + uwriteln!( + self.csharp_interop_src, + r#" + return (uint)AsyncSupport.Callback(e); }} "# - ); - } + ); } if abi::guest_export_needs_post_return(self.resolve, func) { @@ -961,29 +1036,77 @@ var {async_status_var} = {raw_name}({wasm_params}); interop_class_name = format!("Exports.{interop_class_name}"); } - // TODO: The task return function can take up to 16 core parameters. - let (task_return_param_sig, task_return_param) = match &sig.results[..] { - [] => (String::new(), String::new()), - [_result] => (format!("{wasm_result_type} result"), "result".to_string()), - _ => unreachable!(), + let resolve = &self.resolve; + let (task_return_param_sig_vec, task_return_param_vec): (Vec, Vec) = + func.result + .map(|ty| { + let mut storage = vec![abi::WasmType::I32; MAX_FLAT_PARAMS]; + let mut flat = abi::FlatTypes::new(&mut storage); + if resolve.push_flat(&ty, &mut flat) { + flat.to_vec() + } else { + vec![abi::WasmType::I32] + } + }) + .unwrap_or_default() + .into_iter() + .enumerate() + .map(|(i, ty)| { + let ty = crate::world_generator::wasm_type(ty); + (format!("{ty} arg{i}"), format!("arg{i}")) + }) + .unzip(); + + let task_return_param_sig = task_return_param_sig_vec.join(", "); + let task_return_param = task_return_param_vec.join(", "); + + let resource_type_name = match func.kind { + FunctionKind::Method(resource_type_id) + | FunctionKind::AsyncMethod(resource_type_id) + | FunctionKind::Static(resource_type_id) + | FunctionKind::Constructor(resource_type_id) => { + format!( + "Method{}", + self.csharp_gen.all_resources[&resource_type_id] + .name + .to_upper_camel_case() + ) + } + _ => String::new(), }; uwriteln!( self.src, r#" - public static void {camel_name}TaskReturn({task_return_param_sig} ) + public static void {camel_name}TaskReturn({task_return_param_sig}) {{ - {interop_class_name}.{camel_name}TaskReturn({task_return_param}); + {interop_class_name}.{resource_type_name}{camel_name}TaskReturn({task_return_param}); }} "# ); + let resource_type_name = match func.kind { + FunctionKind::Method(resource_type_id) + | FunctionKind::AsyncMethod(resource_type_id) + | FunctionKind::Static(resource_type_id) + | FunctionKind::Constructor(resource_type_id) => { + format!( + "Method{}", + self.csharp_gen.all_resources[&resource_type_id] + .name + .to_upper_camel_case() + ) + } + _ => String::new(), + }; + let task_return_name = format!("{resource_type_name}{camel_name}"); + uwriteln!( self.csharp_interop_src, r#" // TODO: The task return function can take up to 16 core parameters. [global::System.Runtime.InteropServices.DllImportAttribute("[export]{import_module}", EntryPoint = "[task-return]{wasm_func_name}"), global::System.Runtime.InteropServices.WasmImportLinkageAttribute] - public static extern void {camel_name}TaskReturn({task_return_param_sig}); + public static extern void {task_return_name}TaskReturn({task_return_param_sig}); "# ); } @@ -1457,11 +1580,13 @@ var {async_status_var} = {raw_name}({wasm_params}); &mut self, func_name: &str, generic_type_name: &str, + qualified_generic_type_name: &str, ty: Option, ) { self.futures.push(FutureInfo { name: func_name.to_string(), generic_type_name: generic_type_name.to_string(), + qualified_generic_type_name: qualified_generic_type_name.to_string(), ty: ty, }); } @@ -1470,14 +1595,56 @@ var {async_status_var} = {raw_name}({wasm_params}); &mut self, func_name: &str, generic_type_name: &str, + qualified_generic_type_name: &str, ty: Option, ) { self.streams.push(FutureInfo { name: func_name.to_string(), generic_type_name: generic_type_name.to_string(), + qualified_generic_type_name: qualified_generic_type_name.to_string(), ty: ty, }); } + + fn lift_from_memory(&mut self, address: &str, ty: &Type) -> (String, String) { + let boxed: Box<[String]> = Vec::new().into_boxed_slice(); + + let mut f = FunctionBindgen::new( + self, + "", + &FunctionKind::AsyncFreestanding, + boxed, + vec![], + ParameterType::ABI, + None, + ); + let result = abi::lift_from_memory(f.interface_gen.resolve, &mut f, address.into(), ty); + + (f.src, result) + } + + fn lower_to_memory(&mut self, address: &str, value: &str, ty: &Type) -> String { + let boxed: Box<[String]> = Vec::new().into_boxed_slice(); + + let mut f = FunctionBindgen::new( + self, + "", + &FunctionKind::AsyncFreestanding, + boxed, + vec![], + ParameterType::ABI, + None, + ); + abi::lower_to_memory( + f.interface_gen.resolve, + &mut f, + address.into(), + value.into(), + ty, + ); + + f.src + } } impl<'a> CoreInterfaceGenerator<'a> for InterfaceGenerator<'a> { @@ -1781,6 +1948,26 @@ impl<'a> CoreInterfaceGenerator<'a> for InterfaceGenerator<'a> { } } +// TODO: Is this not publicly available elsewhere, a function that says if a type needs to be returned as a pointer +fn needs_ptr(ty: &Type) -> bool { + // TODO: This list is not complete, e.g. handles should be here. + match *ty { + Type::Bool + | Type::S8 + | Type::U8 + | Type::S16 + | Type::U16 + | Type::S32 + | Type::U32 + | Type::S64 + | Type::U64 + | Type::Char + | Type::F32 + | Type::F64 => false, + _ => true, + } +} + // Handles the tag being the same name as the variant, which would cause a method with the same name as the type in C# which is not valid. pub fn variant_new_func_name(variant_name: &String, tag: &String) -> String { if *tag == *variant_name { diff --git a/crates/csharp/src/world_generator.rs b/crates/csharp/src/world_generator.rs index 21f6a52b9..7199b66fd 100644 --- a/crates/csharp/src/world_generator.rs +++ b/crates/csharp/src/world_generator.rs @@ -585,14 +585,25 @@ impl WorldGenerator for CSharp { ) } + if self.needs_async_support { + self.needs_export_return_area = true; + } + // Declare a statically-allocated return area, if needed. We only do // this for export bindings, because import bindings allocate their // return-area on the stack. if self.needs_export_return_area { let mut ret_area_str = String::new(); - let (array_size, element_type) = - dotnet_aligned_array(self.return_area_size, self.return_area_align); + //TODO: only generate if used. Currently we need this for any async function, even if it returns void. + let (array_size, element_type) = if self.return_area_size == 0 { + (1, "byte".to_owned()) + } else { + crate::world_generator::dotnet_aligned_array( + self.return_area_size, + self.return_area_align, + ) + }; uwrite!( ret_area_str, diff --git a/crates/test/src/csharp.rs b/crates/test/src/csharp.rs index e95adc3b3..a4d017e18 100644 --- a/crates/test/src/csharp.rs +++ b/crates/test/src/csharp.rs @@ -48,7 +48,6 @@ impl LanguageMethods for Csharp { | "async-resource-func.wit" | "import-export-resource.wit" | "issue-1433.wit" - | "issue-1598.wit" | "named-fixed-length-list.wit" | "map.wit" ) diff --git a/tests/runtime-async/async/ping-pong/runner.cs b/tests/runtime-async/async/ping-pong/runner.cs new file mode 100644 index 000000000..f5621eb1b --- /dev/null +++ b/tests/runtime-async/async/ping-pong/runner.cs @@ -0,0 +1,36 @@ +using System.Diagnostics; +using RunnerWorld.wit.Imports.my.test; +using RunnerWorld; + +public class RunnerWorldExportsImpl +{ + public static async Task Run() + { + try + { + string pingResult; + { + var (reader, writer) = IIImports.FutureNewString(); + var pingTask = IIImports.Ping(reader, "world"); + await writer.Write("hello"); + var pingFutureResult = await pingTask; + var result = await pingFutureResult.Read(); + Debug.Assert(result == "helloworld"); + + pingResult = result; + } + + { + var (reader, writer) = IIImports.FutureNewString(); + var pongTask = IIImports.Pong(reader); + await writer.Write(pingResult); + var pongResult = await pongTask; + Debug.Assert(pongResult == "helloworld"); + } + } + catch(Exception e) + { + Console.WriteLine(e); + } + } +} diff --git a/tests/runtime-async/async/ping-pong/test.cs b/tests/runtime-async/async/ping-pong/test.cs new file mode 100644 index 000000000..fbfc7b4b2 --- /dev/null +++ b/tests/runtime-async/async/ping-pong/test.cs @@ -0,0 +1,29 @@ +using System.Diagnostics; +using System.Runtime.InteropServices; +using System.Threading.Tasks; + +namespace TestWorld.wit.Exports.my.test +{ + public class IExportsImpl : IIExports + { + public static async Task> Ping(FutureReader future, string s) + { + var msg = (await future.Read()) + s; + var (newFutureReader, newFutureWriter) = IIExports.FutureNewString(); + var writeTask = newFutureWriter.Write(msg); + writeTask.ContinueWith(t => + { + if(t.Exception != null) + { + Debug.Fail("Exception in returned future write." + t.Exception); + } + }); + return newFutureReader; + } + + public static async Task Pong(FutureReader future) + { + return await future.Read(); + } + } +}