Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 95 additions & 85 deletions src/qs_codec/utils/encode_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,24 @@
class EncodeUtils:
"""A collection of encode utility methods used by the library."""

# HEX_TABLE[b] is the percent-escape of byte value b: "%" followed by two
# uppercase hexadecimal digits ("%00" .. "%FF").
HEX_TABLE: t.Tuple[str, ...] = tuple(f"%{i:02X}" for i in range(256))
"""Hex table of all 256 characters"""

# Code points of the ASCII digits and letters; the base of every set below.
SAFE_ALPHA: t.Set[int] = set(range(0x30, 0x3A)) | set(range(0x41, 0x5B)) | set(range(0x61, 0x7B))
"""0-9, A-Z, a-z"""

# Code points that `escape` copies through unchanged.
SAFE_POINTS: t.Set[int] = SAFE_ALPHA | {0x40, 0x2A, 0x5F, 0x2D, 0x2B, 0x2E, 0x2F}
"""0-9, A-Z, a-z, @, *, _, -, +, ., /"""

# Same as SAFE_POINTS, plus parentheses, for the RFC 1738 format.
RFC1738_SAFE_POINTS: t.Set[int] = SAFE_POINTS | {0x28, 0x29}
"""0-9, A-Z, a-z, @, *, _, -, +, ., /, (, )"""

# Code points that `_is_safe_char` reports as safe.
SAFE_CHARS: t.Set[int] = SAFE_ALPHA | {0x2D, 0x2E, 0x5F, 0x7E}
"""0-9, A-Z, a-z, -, ., _, ~"""

# Same as SAFE_CHARS, plus parentheses, for the RFC 1738 format.
RFC1738_SAFE_CHARS: t.Set[int] = SAFE_CHARS | {0x28, 0x29}
"""0-9, A-Z, a-z, -, ., _, ~, (, )"""

@classmethod
def escape(
cls,
Expand All @@ -27,36 +42,22 @@ def escape(

https://developer.mozilla.org/en-US/docs/web/javascript/reference/global_objects/escape
"""
# Build a set of "safe" code points.
safe_points: t.Set[int] = cls.RFC1738_SAFE_POINTS if format == Format.RFC1738 else cls.SAFE_POINTS

buffer: t.List[str] = []

i: int
for i, _ in enumerate(string):
char: str
for i, char in enumerate(string):
# Use code_unit_at if it does more than ord()
c: int = code_unit_at(string, i)

# These 69 characters are safe for escaping
# ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789@*_+-./
if (
(0x30 <= c <= 0x39) # 0-9
or (0x41 <= c <= 0x5A) # A-Z
or (0x61 <= c <= 0x7A) # a-z
or c == 0x40 # @
or c == 0x2A # *
or c == 0x5F # _
or c == 0x2D # -
or c == 0x2B # +
or c == 0x2E # .
or c == 0x2F # /
or (format == Format.RFC1738 and (c == 0x28 or c == 0x29)) # ( )
):
buffer.append(string[i])
continue

if c < 256:
buffer.extend([f"%{c.to_bytes(1, 'big').hex().upper().zfill(2)}"])
continue

buffer.extend([f"%u{c.to_bytes(2, 'big').hex().upper().zfill(4)}"])

if c in safe_points:
buffer.append(char)
elif c < 256:
buffer.append(f"%{c:02X}")
else:
buffer.append(f"%u{c:04X}")
return "".join(buffer)

@classmethod
Expand All @@ -70,83 +71,92 @@ def encode(
if value is None or not isinstance(value, (int, float, Decimal, Enum, str, bool, bytes)):
return ""

string: str
if isinstance(value, bytes):
string = value.decode("utf-8")
elif isinstance(value, bool):
string = str(value).lower()
elif isinstance(value, str):
string = value
else:
string = str(value)
string: str = cls._convert_value_to_string(value)

if value == "":
if not string:
return ""

if charset == Charset.LATIN1:
return re.sub(
r"%u[0-9a-f]{4}",
lambda match: f"%26%23{int(match.group(0)[2:], 16)}%3B",
cls.escape(cls.to_surrogates(string), format),
cls.escape(cls._to_surrogates(string), format),
flags=re.IGNORECASE,
)

return cls._encode_string(string, format)

@staticmethod
def _convert_value_to_string(value: t.Any) -> str:
"""Convert the value to a string based on its type."""
if isinstance(value, bytes):
return value.decode("utf-8")
elif isinstance(value, bool):
return str(value).lower()
elif isinstance(value, str):
return value
else:
return str(value)

@classmethod
def _encode_string(cls, string: str, format: t.Optional[Format]) -> str:
    """Percent-encode ``string``, copying through characters that are safe for ``format``.

    For every index of ``string`` the value returned by ``code_unit_at`` is
    tested with ``_is_safe_char``. Safe characters are appended unchanged;
    all others are replaced by the escape sequences from ``_encode_char``.
    """
    buffer: t.List[str] = []

    i: int
    char: str
    for i, char in enumerate(string):
        c: int = code_unit_at(string, i)

        if cls._is_safe_char(c, format):
            buffer.append(char)
        else:
            # NOTE(review): for a surrogate pair, _encode_surrogate_pair also
            # reads index i + 1, but this loop still visits that index on the
            # next iteration. Confirm against code_unit_at's indexing that no
            # unit is encoded twice or skipped.
            buffer.extend(cls._encode_char(string, i, c))

    return "".join(buffer)

@classmethod
def _is_safe_char(cls, c: int, format: t.Optional[Format]) -> bool:
    """Tell whether code point ``c`` may be emitted without percent-encoding.

    ``Format.RFC1738`` is checked against ``RFC1738_SAFE_CHARS``; every other
    format value, including ``None``, is checked against ``SAFE_CHARS``.
    """
    if format == Format.RFC1738:
        return c in cls.RFC1738_SAFE_CHARS
    return c in cls.SAFE_CHARS

@classmethod
def _encode_char(cls, string: str, i: int, c: int) -> t.List[str]:
    """Return the percent-escapes for the UTF-8 bytes of the value ``c``.

    ``string`` and ``i`` are only used when ``c`` lies in the surrogate range
    (0xD800-0xDFFF), in which case the work is delegated to
    ``_encode_surrogate_pair``.

    NOTE(review): the three-byte branch assumes ``c`` is at most 0xFFFF;
    confirm that callers never pass a larger value.
    """
    table = cls.HEX_TABLE

    if c < 0x80:  # one byte (ASCII)
        return [table[c]]

    if c < 0x800:  # two bytes
        lead = 0xC0 | (c >> 6)
        tail = 0x80 | (c & 0x3F)
        return [table[lead], table[tail]]

    if 0xD800 <= c < 0xE000:  # surrogate: needs the following unit as well
        return cls._encode_surrogate_pair(string, i, c)

    # three bytes
    return [
        table[0xE0 | (c >> 12)],
        table[0x80 | ((c >> 6) & 0x3F)],
        table[0x80 | (c & 0x3F)],
    ]

@classmethod
def _encode_surrogate_pair(cls, string: str, i: int, c: int) -> t.List[str]:
    """Return the four percent-escapes for the code point formed by a surrogate pair.

    ``c`` supplies the upper ten bits and the unit that ``code_unit_at``
    returns for index ``i + 1`` supplies the lower ten bits; the combined
    value is offset by 0x10000 and written out as four UTF-8 bytes.
    """
    trailing: int = code_unit_at(string, i + 1)
    code_point: int = 0x10000 + (((c & 0x3FF) << 10) | (trailing & 0x3FF))
    return [
        cls.HEX_TABLE[0xF0 | (code_point >> 18)],
        cls.HEX_TABLE[0x80 | ((code_point >> 12) & 0x3F)],
        cls.HEX_TABLE[0x80 | ((code_point >> 6) & 0x3F)],
        cls.HEX_TABLE[0x80 | (code_point & 0x3F)],
    ]

@staticmethod
def to_surrogates(string: str) -> str:
def _to_surrogates(string: str) -> str:
"""Convert characters in the string that are outside the BMP (i.e. code points > 0xFFFF) into their corresponding surrogate pair."""
result: t.List[str] = []
buffer: t.List[str] = []

ch: str
for ch in string:
Expand All @@ -156,11 +166,11 @@ def to_surrogates(string: str) -> str:
cp -= 0x10000
high: int = 0xD800 + (cp >> 10)
low: int = 0xDC00 + (cp & 0x3FF)
result.append(chr(high))
result.append(chr(low))
buffer.append(chr(high))
buffer.append(chr(low))
else:
result.append(ch)
return "".join(result)
buffer.append(ch)
return "".join(buffer)

@staticmethod
def serialize_date(dt: datetime) -> str:
Expand Down
67 changes: 47 additions & 20 deletions tests/unit/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,9 +483,9 @@ def test_merges_array_into_object(self) -> None:
) == {"foo": {"bar": "baz", "0": "xyzzy"}}

def test_combine_both_arrays(self) -> None:
a = [1]
b = [2]
combined = Utils.combine(a, b)
a: t.List[int] = [1]
b: t.List[int] = [2]
combined: t.List[int] = Utils.combine(a, b)

assert a == [1]
assert b == [2]
Expand All @@ -494,31 +494,31 @@ def test_combine_both_arrays(self) -> None:
assert combined == [1, 2]

def test_combine_one_array_one_non_array(self) -> None:
    """``Utils.combine`` of a scalar and a list yields a new ``[1, 2]`` list and leaves the inputs untouched."""
    a_n: int = 1
    a: t.List[int] = [a_n]
    b_n: int = 2
    b: t.List[int] = [b_n]

    # scalar on the left, list on the right
    combined_an_b: t.List[int] = Utils.combine(a_n, b)
    assert b == [b_n]
    assert a_n is not combined_an_b
    assert a is not combined_an_b
    assert b_n is not combined_an_b
    assert b is not combined_an_b
    assert combined_an_b == [1, 2]

    # list on the left, scalar on the right
    combined_a_bn: t.List[int] = Utils.combine(a, b_n)
    assert a == [a_n]
    assert a_n is not combined_a_bn
    assert a is not combined_a_bn
    assert b_n is not combined_a_bn
    assert b is not combined_a_bn
    assert combined_a_bn == [1, 2]

def test_combine_neither_is_an_array(self) -> None:
a = 1
b = 2
combined = Utils.combine(a, b)
a: int = 1
b: int = 2
combined: t.List[int] = Utils.combine(a, b)

assert a is not combined
assert b is not combined
Expand Down Expand Up @@ -579,4 +579,31 @@ def test_remove_undefined_from_map(self) -> None:
],
)
def test_to_surrogates(self, input_str: str, expected: str) -> None:
assert EncodeUtils.to_surrogates(input_str) == expected
assert EncodeUtils._to_surrogates(input_str) == expected

@pytest.mark.parametrize(
    "char, format, expected",
    [
        # Digits and letters are safe.
        ("a", Format.RFC3986, True),
        ("Z", Format.RFC3986, True),
        ("0", Format.RFC3986, True),
        # Punctuation listed in SAFE_CHARS: -, ., _, ~
        ("-", Format.RFC3986, True),
        (".", Format.RFC3986, True),
        ("_", Format.RFC3986, True),
        ("~", Format.RFC3986, True),
        # Parentheses are safe only under RFC 1738.
        ("(", Format.RFC3986, False),
        (")", Format.RFC3986, False),
        ("(", Format.RFC1738, True),
        (")", Format.RFC1738, True),
        # Unsafe under both formats.
        ("@", Format.RFC3986, False),
        ("@", Format.RFC1738, False),
        ("*", Format.RFC3986, False),
        ("*", Format.RFC1738, False),
    ],
)
def test_is_safe_char(self, char: str, format: Format, expected: bool) -> None:
    """``_is_safe_char`` returns ``expected`` for the code point of ``char`` under ``format``."""
    code_point: int = ord(char)
    assert EncodeUtils._is_safe_char(code_point, format) is expected