diff --git a/src/Memory.cs b/src/Memory.cs index eb7e322a..89e0afa1 100644 --- a/src/Memory.cs +++ b/src/Memory.cs @@ -72,25 +72,83 @@ public uint GetSize() return Native.wasmtime_memory_size(store.Context.handle, this.memory); } + /// + /// Gets the current length of the memory, in bytes. + /// + /// Returns the current length of the memory, in bytes. + public long GetLength() + { + return checked((long)(nuint)Native.wasmtime_memory_data_size(store.Context.handle, this.memory)); + } + + /// + /// Returns a pointer to the start of the memory. The length for which the pointer + /// is valid can be retrieved with . + /// + /// Returns a pointer to the start of the memory. + /// + /// + /// The pointer may become invalid if the memory grows. + /// + /// This may happen if the memory is explicitly requested to grow or + /// grows as a result of WebAssembly execution. + /// + /// + /// Therefore, the returned pointer should not be used after calling the grow method or + /// after calling into WebAssembly code. + /// + /// + public unsafe IntPtr GetPointer() + { + var data = Native.wasmtime_memory_data(store.Context.handle, this.memory); + return (nint)data; + } + /// /// Gets the span of the memory. /// /// Returns the span of the memory. + /// The memory has more than 32767 pages. /// + /// /// The span may become invalid if the memory grows. /// /// This may happen if the memory is explicitly requested to grow or /// grows as a result of WebAssembly execution. - /// + /// + /// /// Therefore, the returned span should not be used after calling the grow method or /// after calling into WebAssembly code. + /// /// - public unsafe Span GetSpan() + [Obsolete("This method will throw an OverflowException if the memory has more than 32767 pages. " + + "Use the " + nameof(GetSpan) + " overload taking an address and a length.")] + public Span GetSpan() { - var context = store.Context; - var data = Native.wasmtime_memory_data(context.handle, this.memory); - var size = Convert.ToInt32(Native.wasmtime_memory_data_size(context.handle, this.memory).ToUInt32()); - return new Span(data, size); + return GetSpan(0, checked((int)GetLength())); + } + + /// + /// Gets a span of a section of the memory. + /// + /// Returns the span of a section of the memory. + /// The zero-based address of the start of the span. + /// The length of the span. + /// + /// + /// The span may become invalid if the memory grows. + /// + /// This may happen if the memory is explicitly requested to grow or + /// grows as a result of WebAssembly execution. + /// + /// + /// Therefore, the returned span should not be used after calling the grow method or + /// after calling into WebAssembly code. + /// + /// + public Span GetSpan(long address, int length) + { + return GetSpan(address, length); } /// @@ -98,19 +156,78 @@ public unsafe Span GetSpan() /// /// The zero-based address of the start of the span. /// Returns the span of the memory. + /// The memory exceeds the byte length that can be + /// represented by a . /// + /// /// The span may become invalid if the memory grows. /// /// This may happen if the memory is explicitly requested to grow or /// grows as a result of WebAssembly execution. - /// + /// + /// /// Therefore, the returned span should not be used after calling the grow method or /// after calling into WebAssembly code. + /// + /// + /// Note that WebAssembly always uses little endian as byte order. On platforms + /// that use big endian, you will need to convert numeric values accordingly. + /// /// public unsafe Span GetSpan(int address) where T : unmanaged { - return MemoryMarshal.Cast(GetSpan()[address..]); + return GetSpan(address, checked((int)((GetLength() - address) / sizeof(T)))); + } + + /// + /// Gets a span of a section of the memory. + /// + /// Returns the span of a section of the memory. + /// The zero-based address of the start of the span. + /// The length of the span. + /// + /// + /// The span may become invalid if the memory grows. + /// + /// This may happen if the memory is explicitly requested to grow or + /// grows as a result of WebAssembly execution. + /// + /// + /// Therefore, the returned span should not be used after calling the grow method or + /// after calling into WebAssembly code. + /// + /// + /// Note that WebAssembly always uses little endian as byte order. On platforms + /// that use big endian, you will need to convert numeric values accordingly. + /// + /// + public unsafe Span GetSpan(long address, int length) + where T : unmanaged + { + if (address < 0) + { + throw new ArgumentOutOfRangeException(nameof(address)); + } + + if (length < 0) + { + throw new ArgumentOutOfRangeException(nameof(length)); + } + + var context = store.Context; + var data = Native.wasmtime_memory_data(context.handle, this.memory); + var memoryLength = this.GetLength(); + + // Note: A Span can span more than 2 GiB bytes if sizeof(T) > 1. + long byteLength = (long)length * sizeof(T); + + if (address > memoryLength - byteLength) + { + throw new ArgumentException("The specified address and length exceed the Memory's bounds."); + } + + return new Span((T*)(data + address), length); } /// @@ -119,10 +236,16 @@ public unsafe Span GetSpan(int address) /// Type of the struct to read. Ensure layout in C# is identical to layout in WASM. /// The zero-based address to read from. /// Returns the struct read from memory. - public T Read(int address) + /// + /// + /// Note that WebAssembly always uses little endian as byte order. On platforms + /// that use big endian, you will need to convert numeric values accordingly. + /// + /// + public T Read(long address) where T : unmanaged { - return GetSpan(address)[0]; + return GetSpan(address, 1)[0]; } /// @@ -131,10 +254,16 @@ public T Read(int address) /// /// The zero-based address to read from. /// The struct to write. - public void Write(int address, T value) + /// + /// + /// Note that WebAssembly always uses little endian as byte order. On platforms + /// that use big endian, you will need to convert numeric values accordingly. + /// + /// + public void Write(long address, T value) where T : unmanaged { - GetSpan(address)[0] = value; + GetSpan(address, 1)[0] = value; } /// @@ -144,14 +273,14 @@ public void Write(int address, T value) /// The length of bytes to read. /// The encoding to use when reading the string; if null, UTF-8 encoding is used. /// Returns the string read from memory. - public string ReadString(int address, int length, Encoding? encoding = null) + public string ReadString(long address, int length, Encoding? encoding = null) { if (encoding is null) { encoding = Encoding.UTF8; } - return encoding.GetString(GetSpan().Slice(address, length)); + return encoding.GetString(GetSpan(address, length)); } /// @@ -159,9 +288,15 @@ public string ReadString(int address, int length, Encoding? encoding = null) /// /// The zero-based address to read from. /// Returns the string read from memory. - public string ReadNullTerminatedString(int address) + public string ReadNullTerminatedString(long address) { - var slice = GetSpan().Slice(address); + if (address < 0) + { + throw new ArgumentOutOfRangeException(nameof(address)); + } + + // We can only read a maximum of 2 GiB. + var slice = GetSpan(address, (int)Math.Min(int.MaxValue, GetLength() - address)); var terminator = slice.IndexOf((byte)0); if (terminator == -1) { @@ -178,14 +313,19 @@ public string ReadNullTerminatedString(int address) /// The string to write. /// The encoding to use when writing the string; if null, UTF-8 encoding is used. /// Returns the number of bytes written. - public int WriteString(int address, string value, Encoding? encoding = null) + public int WriteString(long address, string value, Encoding? encoding = null) { + if (address < 0) + { + throw new ArgumentOutOfRangeException(nameof(address)); + } + if (encoding is null) { encoding = Encoding.UTF8; } - return encoding.GetBytes(value, GetSpan().Slice(address)); + return encoding.GetBytes(value, GetSpan(address, (int)Math.Min(int.MaxValue, GetLength() - address))); } /// @@ -193,9 +333,9 @@ public int WriteString(int address, string value, Encoding? encoding = null) /// /// The zero-based address to read from. /// Returns the byte read from memory. - public byte ReadByte(int address) + public byte ReadByte(long address) { - return GetSpan()[address]; + return GetSpan(address, sizeof(byte))[0]; } /// @@ -203,9 +343,9 @@ public byte ReadByte(int address) /// /// The zero-based address to write to. /// The byte to write. - public void WriteByte(int address, byte value) + public void WriteByte(long address, byte value) { - GetSpan()[address] = value; + GetSpan(address, sizeof(byte))[0] = value; } /// @@ -213,9 +353,9 @@ public void WriteByte(int address, byte value) /// /// The zero-based address to read from. /// Returns the short read from memory. - public short ReadInt16(int address) + public short ReadInt16(long address) { - return BinaryPrimitives.ReadInt16LittleEndian(GetSpan().Slice(address, 2)); + return BinaryPrimitives.ReadInt16LittleEndian(GetSpan(address, sizeof(short))); } /// @@ -223,9 +363,9 @@ public short ReadInt16(int address) /// /// The zero-based address to write to. /// The short to write. - public void WriteInt16(int address, short value) + public void WriteInt16(long address, short value) { - BinaryPrimitives.WriteInt16LittleEndian(GetSpan().Slice(address, 2), value); + BinaryPrimitives.WriteInt16LittleEndian(GetSpan(address, sizeof(short)), value); } /// @@ -233,9 +373,9 @@ public void WriteInt16(int address, short value) /// /// The zero-based address to read from. /// Returns the int read from memory. - public int ReadInt32(int address) + public int ReadInt32(long address) { - return BinaryPrimitives.ReadInt32LittleEndian(GetSpan().Slice(address, 4)); + return BinaryPrimitives.ReadInt32LittleEndian(GetSpan(address, sizeof(int))); } /// @@ -243,9 +383,9 @@ public int ReadInt32(int address) /// /// The zero-based address to write to. /// The int to write. - public void WriteInt32(int address, int value) + public void WriteInt32(long address, int value) { - BinaryPrimitives.WriteInt32LittleEndian(GetSpan().Slice(address, 4), value); + BinaryPrimitives.WriteInt32LittleEndian(GetSpan(address, sizeof(int)), value); } /// @@ -253,9 +393,9 @@ public void WriteInt32(int address, int value) /// /// The zero-based address to read from. /// Returns the long read from memory. - public long ReadInt64(int address) + public long ReadInt64(long address) { - return BinaryPrimitives.ReadInt64LittleEndian(GetSpan().Slice(address, 8)); + return BinaryPrimitives.ReadInt64LittleEndian(GetSpan(address, sizeof(long))); } /// @@ -263,9 +403,9 @@ public long ReadInt64(int address) /// /// The zero-based address to write to. /// The long to write. - public void WriteInt64(int address, long value) + public void WriteInt64(long address, long value) { - BinaryPrimitives.WriteInt64LittleEndian(GetSpan().Slice(address, 8), value); + BinaryPrimitives.WriteInt64LittleEndian(GetSpan(address, sizeof(long)), value); } /// @@ -273,7 +413,7 @@ public void WriteInt64(int address, long value) /// /// The zero-based address to read from. /// Returns the IntPtr read from memory. - public IntPtr ReadIntPtr(int address) + public IntPtr ReadIntPtr(long address) { if (IntPtr.Size == 4) { @@ -287,15 +427,15 @@ public IntPtr ReadIntPtr(int address) /// /// The zero-based address to write to. /// The IntPtr to write. - public void WriteIntPtr(int address, IntPtr value) + public void WriteIntPtr(long address, IntPtr value) { if (IntPtr.Size == 4) { - WriteInt32(address, value.ToInt32()); + WriteInt32(address, (int)value); } else { - WriteInt64(address, value.ToInt64()); + WriteInt64(address, (long)value); } } @@ -304,13 +444,9 @@ public void WriteIntPtr(int address, IntPtr value) /// /// The zero-based address to read from. /// Returns the single read from memory. - public float ReadSingle(int address) + public float ReadSingle(long address) { - unsafe - { - var i = ReadInt32(address); - return *((float*)&i); - } + return BitConverter.Int32BitsToSingle(ReadInt32(address)); } /// @@ -318,12 +454,9 @@ public float ReadSingle(int address) /// /// The zero-based address to write to. /// The single to write. - public void WriteSingle(int address, float value) + public void WriteSingle(long address, float value) { - unsafe - { - WriteInt32(address, *(int*)&value); - } + WriteInt32(address, BitConverter.SingleToInt32Bits(value)); } /// @@ -331,13 +464,9 @@ public void WriteSingle(int address, float value) /// /// The zero-based address to read from. /// Returns the double read from memory. - public double ReadDouble(int address) + public double ReadDouble(long address) { - unsafe - { - var i = ReadInt64(address); - return *((double*)&i); - } + return BitConverter.Int64BitsToDouble(ReadInt64(address)); } /// @@ -345,12 +474,9 @@ public double ReadDouble(int address) /// /// The zero-based address to write to. /// The double to write. - public void WriteDouble(int address, double value) + public void WriteDouble(long address, double value) { - unsafe - { - WriteInt64(address, *(long*)&value); - } + WriteInt64(address, BitConverter.DoubleToInt64Bits(value)); } /// diff --git a/tests/MemoryAccessTests.cs b/tests/MemoryAccessTests.cs new file mode 100644 index 00000000..d1afa5d4 --- /dev/null +++ b/tests/MemoryAccessTests.cs @@ -0,0 +1,99 @@ +using System; +using System.Linq; +using FluentAssertions; +using Xunit; + +namespace Wasmtime.Tests +{ + public class MemoryAccessFixture : ModuleFixture + { + protected override string ModuleFileName => "MemoryAccess.wat"; + } + + public class MemoryAccessTests : IClassFixture, IDisposable + { + private MemoryAccessFixture Fixture { get; set; } + + private Store Store { get; set; } + + private Linker Linker { get; set; } + + public MemoryAccessTests(MemoryAccessFixture fixture) + { + Fixture = fixture; + Store = new Store(Fixture.Engine); + Linker = new Linker(Fixture.Engine); + } + + [Fact] + public unsafe void ItCanAccessMemoryWith65536Pages() + { + var instance = Linker.Instantiate(Store, Fixture.Module); + var memory = instance.GetMemory("mem"); + + memory.GetLength().Should().Be(uint.MaxValue + 1L); + + memory.ReadInt32(0).Should().Be(0); + memory.ReadInt32(0L).Should().Be(0); + + memory.WriteInt64(100, 1234); + memory.ReadInt64(100L).Should().Be(1234); + + memory.ReadByte(uint.MaxValue).Should().Be(0x63); + memory.ReadInt16(uint.MaxValue - 1).Should().Be(0x6364); + memory.ReadInt32(uint.MaxValue - 3).Should().Be(0x63646500); + + memory.ReadSingle(uint.MaxValue - 3).Should().Be(4.2131355E+21F); + + var span = memory.GetSpan(uint.MaxValue - 1, 2); + span.SequenceEqual(new byte[] { 0x64, 0x63 }).Should().BeTrue(); + + var int16Span = memory.GetSpan(0, int.MaxValue); + int16Span[int.MaxValue - 1].Should().Be(0x6500); + + int16Span = memory.GetSpan(2); + int16Span[int.MaxValue - 1].Should().Be(0x6364); + + byte* ptr = (byte*)memory.GetPointer(); + ptr += uint.MaxValue; + (*ptr).Should().Be(0x63); + + string str1 = "Hello World"; + memory.WriteString(uint.MaxValue - str1.Length, str1); + memory.ReadString(uint.MaxValue - str1.Length, str1.Length).Should().Be(str1); + } + + [Fact] + public void ItThrowsForOutOfBoundsAccess() + { + var instance = Linker.Instantiate(Store, Fixture.Module); + var memory = instance.GetMemory("mem"); + +#pragma warning disable CS0618 // Type or member is obsolete + Action action = () => memory.GetSpan(); +#pragma warning restore CS0618 // Type or member is obsolete + action.Should().Throw(); + + action = () => memory.GetSpan(0); + action.Should().Throw(); + + action = () => memory.GetSpan(-1L, 0); + action.Should().Throw(); + + action = () => memory.GetSpan(0L, -1); + action.Should().Throw(); + + action = () => memory.ReadInt16(uint.MaxValue); + action.Should().Throw(); + + action = () => memory.GetSpan(uint.MaxValue, 1); + action.Should().Throw(); + } + + public void Dispose() + { + Store.Dispose(); + Linker.Dispose(); + } + } +} diff --git a/tests/Modules/MemoryAccess.wat b/tests/Modules/MemoryAccess.wat new file mode 100644 index 00000000..f64b97a4 --- /dev/null +++ b/tests/Modules/MemoryAccess.wat @@ -0,0 +1,15 @@ +(module + (memory (export "mem") 65536) + (func $start + i32.const 4294967295 + i32.const 99 + i32.store8 + i32.const 4294967294 + i32.const 100 + i32.store8 + i32.const 4294967293 + i32.const 101 + i32.store8 + ) + (start $start) +)