Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tls_codec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ use std::io::{Read, Write};
mod arrays;
mod primitives;
mod quic_vec;
mod string;
mod tls_vec;
mod varint;

Expand Down
224 changes: 224 additions & 0 deletions tls_codec/src/string.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
use alloc::string::String;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's add a doc comment on top here that this is not specified in the TLS presentation language. But that it's what's implicitly done all the time. And then say how the string is actually being serialized (just its bytes).

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do!


#[cfg(feature = "std")]
use crate::{Deserialize, Serialize};

use crate::{DeserializeBytes, SerializeBytes, Size, TlsByteVecU8};

impl Size for String {
fn tls_serialized_len(&self) -> usize {
self.as_bytes().tls_serialized_len()
}
}

#[cfg(feature = "std")]
impl Serialize for String {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, crate::Error> {
Serialize::tls_serialize(&self.as_bytes(), writer)
}
}

impl SerializeBytes for String {
fn tls_serialize(&self) -> Result<alloc::vec::Vec<u8>, crate::Error> {
SerializeBytes::tls_serialize(&self.as_bytes())
}
}

impl Size for &str {
fn tls_serialized_len(&self) -> usize {
self.as_bytes().tls_serialized_len()
}
}

#[cfg(feature = "std")]
impl Serialize for &str {
fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, crate::Error> {
Serialize::tls_serialize(&self.as_bytes(), writer)
}
}

impl SerializeBytes for &str {
fn tls_serialize(&self) -> Result<alloc::vec::Vec<u8>, crate::Error> {
SerializeBytes::tls_serialize(&self.as_bytes())
}
}

#[cfg(feature = "std")]
impl Deserialize for String {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, crate::Error>
where
Self: Sized,
{
let bytes = TlsByteVecU8::tls_deserialize(bytes)?;

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't right. You serialize to variable length encoding.
This goes back to the comment on top. We need to clearly describe how strings are being encoded if we want to provider a general implementation.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ohh, I thought this would happen by itself. Alright. It looks like VLByteSlice doesn't implement De/SerializeBytes yet, and I would like to impl those as well for nostd support. Should I do that in this PR or a separate one?

String::from_utf8(bytes.into())
.map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}")))
}
}

impl DeserializeBytes for String {
fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), crate::Error>
where
Self: Sized,
{
let (bytes, rest) = TlsByteVecU8::tls_deserialize_bytes(bytes)?;
let text = String::from_utf8(bytes.into())
.map_err(|err| crate::Error::DecodingError(format!("invalid utf8: {err}")))?;

Ok((text, rest))
}
}

#[cfg(test)]
mod tests {
use alloc::string::String;

#[cfg(feature = "std")]
use crate::{Deserialize, Serialize};

use crate::{DeserializeBytes, SerializeBytes, Size};

#[cfg(feature = "std")]
#[test]
fn serialize_empty_string() {
let s = String::new();
let buf = s.tls_serialize_detached().unwrap();
// TlsByteVecU8: 1-byte length prefix (0) + no payload
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

#[test]
fn serialize_empty_str() {
let s = "";

#[cfg(feature = "std")]
{
let mut buf = [0u8; 1];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
// TlsByteVecU8: 1-byte length prefix (0) + no payload
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
// TlsByteVecU8: 1-byte length prefix (0) + no payload
assert_eq!(buf, [0]);
assert_eq!(s.tls_serialized_len(), 1);
}

#[cfg(feature = "std")]
#[test]
fn serialize_hello_string() {
let s = String::from("hello");
let buf = s.tls_serialize_detached().unwrap();
// 1-byte length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

#[test]
fn serialize_hello_str() {
let s = "hello";
#[cfg(feature = "std")]
{
let mut buf = [0u8; 6];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
// 1-byte length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
// 1-byte length prefix (5) + b"hello"
assert_eq!(buf, [5, b'h', b'e', b'l', b'l', b'o']);
assert_eq!(s.tls_serialized_len(), 6);
}

#[cfg(feature = "std")]
#[test]
fn serialize_multibyte_utf8_string() {
// U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC]
let s = String::from("ü");
let buf = s.tls_serialize_detached().unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

#[test]
fn serialize_multibyte_utf8_str() {
// U+00FC = "ü", encoded as 2 bytes in UTF-8: [0xC3, 0xBC]
let s = "ü";
#[cfg(feature = "std")]
{
let mut buf = [0u8; 3];
Serialize::tls_serialize(&s, &mut buf.as_mut_slice()).unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

let buf = SerializeBytes::tls_serialize(&s).unwrap();
assert_eq!(buf, [2, 0xC3, 0xBC]);
assert_eq!(s.tls_serialized_len(), 3);
}

#[cfg(feature = "std")]
#[test]
fn roundtrip_deserialize() {
let original = String::from("roundtrip test");

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using a longer string here would surface the incompatible serialize/deserialize. e.g. it will with let original = String::from_utf8(vec![0x30; 300]).unwrap();.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right, that fails. I'll add it to ensure this is caught.

let buf = original.tls_serialize_detached().unwrap();
let deserialized = String::tls_deserialize_exact(&buf).unwrap();
assert_eq!(original, deserialized);
}

#[cfg(feature = "std")]
#[test]
fn roundtrip_deserialize_empty() {
let original = String::new();
let buf = original.tls_serialize_detached().unwrap();
let deserialized = String::tls_deserialize_exact(&buf).unwrap();
assert_eq!(original, deserialized);
}

#[cfg(feature = "std")]
#[test]
fn deserialize_invalid_utf8() {
// length prefix 2 + two bytes that are not valid UTF-8
let buf: &[u8] = &[2, 0xFF, 0xFE];
let err = String::tls_deserialize_exact(buf).unwrap_err();
assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8")));
}

#[test]
fn deserialize_bytes_hello() {
let input = [5, b'h', b'e', b'l', b'l', b'o'];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "hello");
assert!(rest.is_empty());
assert_eq!(s.tls_serialized_len(), 6);
}

#[test]
fn deserialize_bytes_with_trailing_data() {
// "hi" (length 2) followed by extra byte 0x99
let input = [2, b'h', b'i', 0x99];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "hi");
assert_eq!(rest, [0x99]);
}

#[test]
fn deserialize_bytes_invalid_utf8() {
// length prefix 3 + 3 bytes that form an invalid UTF-8 sequence
let input = [3, 0xED, 0xA0, 0x80]; // surrogates are invalid in UTF-8
let err = String::tls_deserialize_exact_bytes(&input).unwrap_err();
assert!(matches!(err, crate::Error::DecodingError(msg) if msg.contains("invalid utf8")));
}

#[test]
fn deserialize_bytes_empty_string() {
let input = [0];
let (s, rest) = String::tls_deserialize_bytes(&input).unwrap();
assert_eq!(s, "");
assert!(rest.is_empty());
}
}