diff --git a/src/qs_codec/utils/encode_utils.py b/src/qs_codec/utils/encode_utils.py index 0cc519b..372506c 100644 --- a/src/qs_codec/utils/encode_utils.py +++ b/src/qs_codec/utils/encode_utils.py @@ -14,9 +14,24 @@ class EncodeUtils: """A collection of encode utility methods used by the library.""" - HEX_TABLE: t.Tuple[str, ...] = tuple(f"%{i.to_bytes(1, 'big').hex().upper().zfill(2)}" for i in range(256)) + HEX_TABLE: t.Tuple[str, ...] = tuple(f"%{i:02X}" for i in range(256)) """Hex table of all 256 characters""" + SAFE_ALPHA: t.Set[int] = set(range(0x30, 0x3A)) | set(range(0x41, 0x5B)) | set(range(0x61, 0x7B)) + """0-9, A-Z, a-z""" + + SAFE_POINTS: t.Set[int] = SAFE_ALPHA | {0x40, 0x2A, 0x5F, 0x2D, 0x2B, 0x2E, 0x2F} + """0-9, A-Z, a-z, @, *, _, -, +, ., /""" + + RFC1738_SAFE_POINTS: t.Set[int] = SAFE_POINTS | {0x28, 0x29} + """0-9, A-Z, a-z, @, *, _, -, +, ., /, (, )""" + + SAFE_CHARS: t.Set[int] = SAFE_ALPHA | {0x2D, 0x2E, 0x5F, 0x7E} + """0-9, A-Z, a-z, -, ., _, ~""" + + RFC1738_SAFE_CHARS = SAFE_CHARS | {0x28, 0x29} + """0-9, A-Z, a-z, -, ., _, ~, (, )""" + @classmethod def escape( cls, @@ -27,36 +42,22 @@ def escape( https://developer.mozilla.org/en-US/docs/web/javascript/reference/global_objects/escape """ + # Build a set of "safe" code points. + safe_points: t.Set[int] = cls.RFC1738_SAFE_POINTS if format == Format.RFC1738 else cls.SAFE_POINTS + buffer: t.List[str] = [] i: int - for i, _ in enumerate(string): + char: str + for i, char in enumerate(string): + # Use code_unit_at if it does more than ord() c: int = code_unit_at(string, i) - - # These 69 characters are safe for escaping - # ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789@*_+-./ - if ( - (0x30 <= c <= 0x39) # 0-9 - or (0x41 <= c <= 0x5A) # A-Z - or (0x61 <= c <= 0x7A) # a-z - or c == 0x40 # @ - or c == 0x2A # * - or c == 0x5F # _ - or c == 0x2D # - - or c == 0x2B # + - or c == 0x2E # . - or c == 0x2F # / - or (format == Format.RFC1738 and (c == 0x28 or c == 0x29)) # ( ) - ): - buffer.append(string[i]) - continue - - if c < 256: - buffer.extend([f"%{c.to_bytes(1, 'big').hex().upper().zfill(2)}"]) - continue - - buffer.extend([f"%u{c.to_bytes(2, 'big').hex().upper().zfill(4)}"]) - + if c in safe_points: + buffer.append(char) + elif c < 256: + buffer.append(f"%{c:02X}") + else: + buffer.append(f"%u{c:04X}") return "".join(buffer) @classmethod @@ -70,83 +71,92 @@ def encode( if value is None or not isinstance(value, (int, float, Decimal, Enum, str, bool, bytes)): return "" - string: str - if isinstance(value, bytes): - string = value.decode("utf-8") - elif isinstance(value, bool): - string = str(value).lower() - elif isinstance(value, str): - string = value - else: - string = str(value) + string: str = cls._convert_value_to_string(value) - if value == "": + if not string: return "" if charset == Charset.LATIN1: return re.sub( r"%u[0-9a-f]{4}", lambda match: f"%26%23{int(match.group(0)[2:], 16)}%3B", - cls.escape(cls.to_surrogates(string), format), + cls.escape(cls._to_surrogates(string), format), flags=re.IGNORECASE, ) + return cls._encode_string(string, format) + + @staticmethod + def _convert_value_to_string(value: t.Any) -> str: + """Convert the value to a string based on its type.""" + if isinstance(value, bytes): + return value.decode("utf-8") + elif isinstance(value, bool): + return str(value).lower() + elif isinstance(value, str): + return value + else: + return str(value) + + @classmethod + def _encode_string(cls, string: str, format: t.Optional[Format]) -> str: + """Encode the string to a URL-encoded format.""" buffer: t.List[str] = [] i: int for i, _ in enumerate(string): c: int = code_unit_at(string, i) - if ( - c == 0x2D # - - or c == 0x2E # . - or c == 0x5F # _ - or c == 0x7E # ~ - or (0x30 <= c <= 0x39) # 0-9 - or (0x41 <= c <= 0x5A) # a-z - or (0x61 <= c <= 0x7A) # A-Z - or (format == Format.RFC1738 and (c == 0x28 or c == 0x29)) # ( ) - ): + if cls._is_safe_char(c, format): buffer.append(string[i]) - continue - elif c < 0x80: # ASCII - buffer.extend([cls.HEX_TABLE[c]]) - continue - elif c < 0x800: # 2 bytes - buffer.extend( - [ - cls.HEX_TABLE[0xC0 | (c >> 6)], - cls.HEX_TABLE[0x80 | (c & 0x3F)], - ], - ) - continue - elif c < 0xD800 or c >= 0xE000: # 3 bytes - buffer.extend( - [ - cls.HEX_TABLE[0xE0 | (c >> 12)], - cls.HEX_TABLE[0x80 | ((c >> 6) & 0x3F)], - cls.HEX_TABLE[0x80 | (c & 0x3F)], - ], - ) - continue else: - i += 1 - c = 0x10000 + (((c & 0x3FF) << 10) | (code_unit_at(string, i) & 0x3FF)) - buffer.extend( - [ - cls.HEX_TABLE[0xF0 | (c >> 18)], - cls.HEX_TABLE[0x80 | ((c >> 12) & 0x3F)], - cls.HEX_TABLE[0x80 | ((c >> 6) & 0x3F)], - cls.HEX_TABLE[0x80 | (c & 0x3F)], - ], - ) + buffer.extend(cls._encode_char(string, i, c)) return "".join(buffer) + @classmethod + def _is_safe_char(cls, c: int, format: t.Optional[Format]) -> bool: + """Check if the character (given by its code point) is safe to be included in the URL without encoding.""" + return c in cls.RFC1738_SAFE_CHARS if format == Format.RFC1738 else c in cls.SAFE_CHARS + + @classmethod + def _encode_char(cls, string: str, i: int, c: int) -> t.List[str]: + """Encode a single character to its URL-encoded representation.""" + if c < 0x80: # ASCII + return [cls.HEX_TABLE[c]] + elif c < 0x800: # 2 bytes + return [ + cls.HEX_TABLE[0xC0 | (c >> 6)], + cls.HEX_TABLE[0x80 | (c & 0x3F)], + ] + elif c < 0xD800 or c >= 0xE000: # 3 bytes + return [ + cls.HEX_TABLE[0xE0 | (c >> 12)], + cls.HEX_TABLE[0x80 | ((c >> 6) & 0x3F)], + cls.HEX_TABLE[0x80 | (c & 0x3F)], + ] + else: + return cls._encode_surrogate_pair(string, i, c) + + @classmethod + def _encode_surrogate_pair(cls, string: str, i: int, c: int) -> t.List[str]: + """Encode a surrogate pair character to its URL-encoded representation.""" + buffer: t.List[str] = [] + c = 0x10000 + (((c & 0x3FF) << 10) | (code_unit_at(string, i + 1) & 0x3FF)) + buffer.extend( + [ + cls.HEX_TABLE[0xF0 | (c >> 18)], + cls.HEX_TABLE[0x80 | ((c >> 12) & 0x3F)], + cls.HEX_TABLE[0x80 | ((c >> 6) & 0x3F)], + cls.HEX_TABLE[0x80 | (c & 0x3F)], + ], + ) + return buffer + @staticmethod - def to_surrogates(string: str) -> str: + def _to_surrogates(string: str) -> str: """Convert characters in the string that are outside the BMP (i.e. code points > 0xFFFF) into their corresponding surrogate pair.""" - result: t.List[str] = [] + buffer: t.List[str] = [] ch: str for ch in string: @@ -156,11 +166,11 @@ def to_surrogates(string: str) -> str: cp -= 0x10000 high: int = 0xD800 + (cp >> 10) low: int = 0xDC00 + (cp & 0x3FF) - result.append(chr(high)) - result.append(chr(low)) + buffer.append(chr(high)) + buffer.append(chr(low)) else: - result.append(ch) - return "".join(result) + buffer.append(ch) + return "".join(buffer) @staticmethod def serialize_date(dt: datetime) -> str: diff --git a/tests/unit/utils_test.py b/tests/unit/utils_test.py index c20034b..52c7f79 100644 --- a/tests/unit/utils_test.py +++ b/tests/unit/utils_test.py @@ -483,9 +483,9 @@ def test_merges_array_into_object(self) -> None: ) == {"foo": {"bar": "baz", "0": "xyzzy"}} def test_combine_both_arrays(self) -> None: - a = [1] - b = [2] - combined = Utils.combine(a, b) + a: t.List[int] = [1] + b: t.List[int] = [2] + combined: t.List[int] = Utils.combine(a, b) assert a == [1] assert b == [2] @@ -494,31 +494,31 @@ def test_combine_both_arrays(self) -> None: assert combined == [1, 2] def test_combine_one_array_one_non_array(self) -> None: - aN = 1 - a = [aN] - bN = 2 - b = [bN] - - combined_an_b = Utils.combine(aN, b) - assert b == [bN] - assert aN is not combined_an_b + a_n: int = 1 + a: t.List[int] = [a_n] + b_n: int = 2 + b: t.List[int] = [b_n] + + combined_an_b: t.List[int] = Utils.combine(a_n, b) + assert b == [b_n] + assert a_n is not combined_an_b assert a is not combined_an_b - assert bN is not combined_an_b + assert b_n is not combined_an_b assert b is not combined_an_b assert combined_an_b == [1, 2] - combined_a_bn = Utils.combine(a, bN) - assert a == [aN] - assert aN is not combined_a_bn + combined_a_bn = Utils.combine(a, b_n) + assert a == [a_n] + assert a_n is not combined_a_bn assert a is not combined_a_bn - assert bN is not combined_a_bn + assert b_n is not combined_a_bn assert b is not combined_a_bn assert combined_a_bn == [1, 2] def test_combine_neither_is_an_array(self) -> None: - a = 1 - b = 2 - combined = Utils.combine(a, b) + a: int = 1 + b: int = 2 + combined: t.List[int] = Utils.combine(a, b) assert a is not combined assert b is not combined @@ -579,4 +579,31 @@ def test_remove_undefined_from_map(self) -> None: ], ) def test_to_surrogates(self, input_str: str, expected: str) -> None: - assert EncodeUtils.to_surrogates(input_str) == expected + assert EncodeUtils._to_surrogates(input_str) == expected + + @pytest.mark.parametrize( + "char, format, expected", + [ + # Alphanumeric characters (always safe) + ("a", Format.RFC3986, True), + ("Z", Format.RFC3986, True), + ("0", Format.RFC3986, True), + # The safe punctuation in SAFE_CHARS: -, ., _, ~ + ("-", Format.RFC3986, True), + (".", Format.RFC3986, True), + ("_", Format.RFC3986, True), + ("~", Format.RFC3986, True), + # Parentheses are not in SAFE_CHARS but are in RFC1738_SAFE_CHARS. + ("(", Format.RFC3986, False), + (")", Format.RFC3986, False), + ("(", Format.RFC1738, True), + (")", Format.RFC1738, True), + # Characters that are not safe in either case. + ("@", Format.RFC3986, False), + ("@", Format.RFC1738, False), + ("*", Format.RFC3986, False), + ("*", Format.RFC1738, False), + ], + ) + def test_is_safe_char(self, char: str, format: Format, expected: bool) -> None: + assert EncodeUtils._is_safe_char(ord(char), format) is expected