Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 65 additions & 7 deletions zstd/src/main/java/io/github/dfa1/zstd/Zstd.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,44 @@ public static byte[] compress(byte[] src) {
/// @return a self-describing zstd frame
public static byte[] compress(byte[] src, int level) {
    // Fail fast on null so the exception names the parameter, then treat the
    // whole array as the range handed to the ranged overload.
    final byte[] input = Objects.requireNonNull(src, "src");
    return compress(input, 0, input.length, level);
}

/// Compresses a slice of `src` at the library default level. The slice starts at
/// `offset` and is `length` bytes long, so a caller whose payload sits inside a
/// larger buffer does not have to copy the slice out before compressing it.
///
/// @param src    array that contains the bytes to compress
/// @param offset position in `src` of the first byte to compress
/// @param length how many bytes to compress
/// @return a self-describing zstd frame
/// @throws IndexOutOfBoundsException if `offset` and `length` do not describe a
///                                   valid range inside `src`
public static byte[] compress(byte[] src, int offset, int length) {
    final int level = defaultCompressionLevel();
    return compress(src, offset, length, level);
}

/// Compresses the `length`-byte sub-range of `src` starting at `offset`, at the
/// given level. Lets a caller holding a payload inside a larger buffer compress
/// it without copying the sub-range out first.
///
/// @param src buffer holding the bytes to compress
/// @param offset index of the first byte to compress
/// @param length number of bytes to compress
/// @param level compression level in [[#minCompressionLevel()], [#maxCompressionLevel()]];
/// higher is smaller but slower
/// @return a self-describing zstd frame
/// @throws IndexOutOfBoundsException if `offset` and `length` do not denote a
/// valid range within `src`
public static byte[] compress(byte[] src, int offset, int length, int level) {
Objects.requireNonNull(src, "src");
Objects.checkFromIndexSize(offset, length, src.length);
try (Arena arena = Arena.ofConfined()) {
MemorySegment in = copyIn(arena, src);
long bound = compressBound(src.length);
MemorySegment in = copyIn(arena, src, offset, length);
long bound = compressBound(length);
MemorySegment out = arena.allocate(bound);
long written = NativeCall.checkReturnValue(() -> (long) Bindings.COMPRESS.invokeExact(
out, bound, in, (long) src.length, level));
out, bound, in, (long) length, level));
return copyOut(out, written);
}
}
Expand Down Expand Up @@ -101,11 +133,33 @@ private static int toArrayLength(long size) {
/// @throws ZstdException if the frame is invalid or larger than `maxSize`
public static byte[] decompress(byte[] compressed, int maxSize) {
    // Fail fast on null so the exception names the parameter, then decode the
    // whole array through the ranged overload.
    final byte[] frame = Objects.requireNonNull(compressed, "compressed");
    return decompress(frame, 0, frame.length, maxSize);
}

/// Decompresses the `length`-byte sub-range of `compressed` starting at `offset`
/// into a buffer of at most `maxSize` bytes. Lets a caller decode a frame embedded
/// inside a larger buffer without copying the frame out first; otherwise identical
/// to [#decompress(byte[], int)].
///
/// As with [#decompress(byte[], int)], `maxSize` caps the allocation and decode,
/// so this is the safe entry point for **untrusted** input.
///
/// @param compressed buffer holding a complete zstd frame
/// @param offset index of the first byte of the frame
/// @param length number of bytes the frame occupies
/// @param maxSize upper bound on the decompressed length
/// @return the original bytes (length ≤ `maxSize`)
/// @throws IndexOutOfBoundsException if `offset` and `length` do not denote a
/// valid range within `compressed`
/// @throws ZstdException if the frame is invalid or larger than `maxSize`
public static byte[] decompress(byte[] compressed, int offset, int length, int maxSize) {
Objects.requireNonNull(compressed, "compressed");
Objects.checkFromIndexSize(offset, length, compressed.length);
try (Arena arena = Arena.ofConfined()) {
MemorySegment in = copyIn(arena, compressed);
MemorySegment in = copyIn(arena, compressed, offset, length);
MemorySegment out = arena.allocate(Math.max(maxSize, 1));
long written = NativeCall.checkReturnValue(() -> (long) Bindings.DECOMPRESS.invokeExact(
out, (long) maxSize, in, (long) compressed.length));
out, (long) maxSize, in, (long) length));
return copyOut(out, written);
}
}
Expand Down Expand Up @@ -316,8 +370,12 @@ public static int versionNumber() {
// Native-call status checking and segment guards live in NativeCall.

/// Copies all of `src` into a fresh native segment allocated from `arena`.
/// Delegates to the ranged overload with the whole array as the range, so the
/// allocation and the copy happen exactly once.
///
/// @param arena arena that owns the returned segment
/// @param src   bytes to copy
/// @return a segment allocated from `arena` whose leading bytes are `src`
static MemorySegment copyIn(Arena arena, byte[] src) {
    return copyIn(arena, src, 0, src.length);
}

/// Copies `length` bytes of `src`, starting at `offset`, into a fresh native
/// segment allocated from `arena`.
///
/// @param arena  arena that owns the returned segment
/// @param src    array holding the bytes to copy
/// @param offset index in `src` of the first byte to copy
/// @param length number of bytes to copy
/// @return a segment of `max(length, 1)` bytes whose first `length` bytes are the
///         requested slice
static MemorySegment copyIn(Arena arena, byte[] src, int offset, int length) {
    // At least one byte is requested even for an empty slice, so the segment
    // is never zero-length.
    MemorySegment seg = arena.allocate(Math.max(length, 1));
    MemorySegment.copy(src, offset, seg, JAVA_BYTE, 0, length);
    return seg;
}

Expand Down
147 changes: 147 additions & 0 deletions zstd/src/test/java/io/github/dfa1/zstd/ZstdTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
Expand Down Expand Up @@ -150,6 +151,152 @@ void rejectsNullInputWithANamedMessage() {
}
}

@Nested
class RangedOverloads {

    @ParameterizedTest
    @MethodSource("io.github.dfa1.zstd.ZstdTestSupport#bytes")
    void rangedCompressRoundTripsToTheSubRange(byte[] payload) {
        // Given: the payload sits between 7 leading and 5 trailing filler bytes
        byte[] surrounding = embed(payload, 7, 5);

        // When: only the payload slice is compressed and the frame is decoded again
        byte[] compressed = Zstd.compress(surrounding, 7, payload.length);
        byte[] decoded = Zstd.decompress(compressed, payload.length);

        // Then: exactly the slice is recovered, none of the filler
        assertThat(decoded).isEqualTo(payload);
    }

    @ParameterizedTest
    @MethodSource("io.github.dfa1.zstd.ZstdTestSupport#levels")
    void rangedCompressEqualsCompressingTheExtractedSubRange(int level) {
        // Given: a payload wrapped in filler on both sides
        byte[] payload = ZstdTestSupport.randomBytes(42, 4096);
        byte[] surrounding = embed(payload, 13, 9);

        // When: the slice is compressed in place, and a copied-out slice is
        // compressed through the whole-array overload
        byte[] inPlace = Zstd.compress(surrounding, 13, payload.length, level);
        byte[] slice = Arrays.copyOfRange(surrounding, 13, 13 + payload.length);
        byte[] viaCopy = Zstd.compress(slice, level);

        // Then: both routes yield byte-identical frames
        assertThat(inPlace).as("level %d", level).isEqualTo(viaCopy);
    }

    @Test
    void fullRangeEqualsTheWholeArrayOverload() {
        // Given: an arbitrary payload
        byte[] payload = ZstdTestSupport.randomBytes(99, 2048);

        // When: it is compressed as the range [0, length) and as a whole array
        byte[] viaRange = Zstd.compress(payload, 0, payload.length);
        byte[] viaWholeArray = Zstd.compress(payload);

        // Then: the two frames match
        assertThat(viaRange).isEqualTo(viaWholeArray);
    }

    @Test
    void emptyRangeRoundTrips() {
        // Given: a buffer with content, of which a zero-length slice is selected
        byte[] source = "ignored payload".getBytes(StandardCharsets.UTF_8);

        // When: the empty slice goes through compress and decompress
        byte[] compressed = Zstd.compress(source, 4, 0);
        byte[] decoded = Zstd.decompress(compressed, 0);

        // Then: nothing comes back
        assertThat(decoded).isEmpty();
    }

    @Test
    void rangedDecompressDecodesAnEmbeddedFrame() {
        // Given: a frame placed at index 11 of a buffer 20 bytes longer than the frame
        byte[] payload = ZstdTestSupport.randomBytes(7, 3000);
        byte[] frame = Zstd.compress(payload);
        byte[] container = new byte[frame.length + 20];
        System.arraycopy(frame, 0, container, 11, frame.length);

        // When: only the frame's slice of the container is decoded
        byte[] decoded = Zstd.decompress(container, 11, frame.length, payload.length);

        // Then: the payload is recovered
        assertThat(decoded).isEqualTo(payload);
    }

    @Test
    void rejectsNegativeOffset() {
        // Given: a 16-byte array
        byte[] source = new byte[16];

        // When / Then: a negative offset is refused with an index error
        assertThatThrownBy(() -> Zstd.compress(source, -1, 4))
                .isInstanceOf(IndexOutOfBoundsException.class);
    }

    @Test
    void rejectsNegativeLength() {
        // Given: a 16-byte array
        byte[] source = new byte[16];

        // When / Then: a negative length is refused with an index error
        assertThatThrownBy(() -> Zstd.compress(source, 0, -1))
                .isInstanceOf(IndexOutOfBoundsException.class);
    }

    @Test
    void rejectsRangeBeyondTheArray() {
        // Given: a 16-byte array
        byte[] source = new byte[16];

        // When / Then: a range ending past the array is refused with an index error
        assertThatThrownBy(() -> Zstd.compress(source, 10, 10))
                .isInstanceOf(IndexOutOfBoundsException.class);
    }

    @Test
    void rejectsNullSource() {
        // When / Then: each ranged overload rejects null and names its own parameter
        assertThatThrownBy(() -> Zstd.compress(null, 0, 0))
                .isInstanceOf(NullPointerException.class)
                .hasMessageContaining("src");
        assertThatThrownBy(() -> Zstd.decompress(null, 0, 0, 0))
                .isInstanceOf(NullPointerException.class)
                .hasMessageContaining("compressed");
    }

    @Test
    void rejectsRangedDecompressBeyondTheArray() {
        // Given: a 16-byte array
        byte[] source = new byte[16];

        // When / Then: a decompress range ending past the array is refused
        assertThatThrownBy(() -> Zstd.decompress(source, 10, 10, 16))
                .isInstanceOf(IndexOutOfBoundsException.class);
    }

    // Builds a buffer made of `before` filler bytes, then `payload`, then `after`
    // filler bytes, so a test can address the payload as a genuine sub-range.
    private static byte[] embed(byte[] payload, int before, int after) {
        byte[] padded = new byte[before + payload.length + after];
        Arrays.fill(padded, (byte) 0x5A);
        System.arraycopy(payload, 0, padded, before, payload.length);
        return padded;
    }
}

@Nested
class Metadata {

Expand Down
Loading