diff --git a/numcodecs/zstd.pyx b/numcodecs/zstd.pyx index f93da633..7b88a8da 100644 --- a/numcodecs/zstd.pyx +++ b/numcodecs/zstd.pyx @@ -15,6 +15,9 @@ from .abc import Codec from libc.stdlib cimport malloc, realloc, free +cdef extern from "stdint.h": + cdef size_t SIZE_MAX + cdef extern from "zstd.h": unsigned ZSTD_versionNumber() nogil @@ -202,8 +205,8 @@ def decompress(source, dest=None): Py_buffer* dest_pb char* dest_ptr size_t source_size, dest_size, decompressed_size - size_t nbytes, cbytes, blocksize size_t dest_nbytes + unsigned long long content_size # obtain source memoryview source_mv = ensure_continguous_memoryview(source) @@ -214,14 +217,20 @@ def decompress(source, dest=None): source_size = source_pb.len try: - - # determine uncompressed size - dest_size = ZSTD_getFrameContentSize(source_ptr, source_size) - if dest_size == 0 or dest_size == ZSTD_CONTENTSIZE_ERROR: + # determine uncompressed size using unsigned long long for full range + content_size = ZSTD_getFrameContentSize(source_ptr, source_size) + if content_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None: + return stream_decompress(source_pb) + elif content_size == ZSTD_CONTENTSIZE_UNKNOWN: + # dest is not None + # set dest_size based on dest + pass + elif content_size == ZSTD_CONTENTSIZE_ERROR or content_size == 0: raise RuntimeError('Zstd decompression error: invalid input data') + elif content_size > (SIZE_MAX): + raise RuntimeError('Zstd decompression error: content size too large for platform') - if dest_size == ZSTD_CONTENTSIZE_UNKNOWN and dest is None: - return stream_decompress(source_pb) + dest_size = content_size # setup destination buffer if dest is None: @@ -236,7 +245,7 @@ def decompress(source, dest=None): dest_ptr = dest_pb.buf dest_nbytes = dest_pb.len - if dest_size == ZSTD_CONTENTSIZE_UNKNOWN: + if content_size == ZSTD_CONTENTSIZE_UNKNOWN: dest_size = dest_nbytes # validate output buffer