1212#[ pymodule]
1313pub ( crate ) mod _struct {
1414 use crate :: {
15- builtins:: { float, PyBaseExceptionRef , PyBytesRef , PyStr , PyStrRef , PyTupleRef , PyTypeRef } ,
15+ builtins:: {
16+ float, PyBaseExceptionRef , PyBytes , PyBytesRef , PyStr , PyStrRef , PyTupleRef , PyTypeRef ,
17+ } ,
1618 common:: str:: wchar_t,
1719 function:: { ArgBytesLike , ArgIntoBool , ArgMemoryBuffer , IntoPyObject , PosArgs } ,
1820 protocol:: PyIterReturn ,
1921 slots:: { IteratorIterable , SlotConstructor , SlotIterator } ,
20- utils:: Either ,
21- PyObjectRef , PyRef , PyResult , PyValue , TryFromObject , VirtualMachine ,
22+ PyObjectRef , PyRef , PyResult , PyValue , TryFromObject , TypeProtocol , VirtualMachine ,
2223 } ;
2324 use crossbeam_utils:: atomic:: AtomicCell ;
2425 use half:: f16;
@@ -202,6 +203,39 @@ pub(crate) mod _struct {
202203
203204 const OVERFLOW_MSG : & str = "total struct size too long" ;
204205
206+ struct IntoStructFormatBytes ( PyStrRef ) ;
207+
208+ impl TryFromObject for IntoStructFormatBytes {
209+ fn try_from_object ( vm : & VirtualMachine , obj : PyObjectRef ) -> PyResult < Self > {
210+ // CPython turns str to bytes but we do reversed way here
211+ // The only performance difference is this transition cost
212+ let fmt = match_class ! {
213+ match obj {
214+ s @ PyStr => if s. is_ascii( ) {
215+ Some ( s)
216+ } else {
217+ None
218+ } ,
219+ b @ PyBytes => if b. is_ascii( ) {
220+ Some ( unsafe {
221+ PyStr :: new_ascii_unchecked( b. as_bytes( ) . to_vec( ) )
222+ } . into_ref( vm) )
223+ } else {
224+ None
225+ } ,
226+ other => return Err ( vm. new_type_error( format!( "Struct() argument 1 must be a str or bytes object, not {}" , other. class( ) . name( ) ) ) ) ,
227+ }
228+ } . ok_or_else ( || vm. new_unicode_decode_error ( "Struct format must be a ascii string" . to_owned ( ) ) ) ?;
229+ Ok ( IntoStructFormatBytes ( fmt) )
230+ }
231+ }
232+
233+ impl IntoStructFormatBytes {
234+ fn format_spec ( & self , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
235+ FormatSpec :: parse ( self . 0 . as_str ( ) . as_bytes ( ) , vm)
236+ }
237+ }
238+
205239 #[ derive( Debug , Clone ) ]
206240 pub ( crate ) struct FormatSpec {
207241 endianness : Endianness ,
@@ -211,24 +245,8 @@ pub(crate) mod _struct {
211245 }
212246
213247 impl FormatSpec {
214- fn decode_and_parse (
215- vm : & VirtualMachine ,
216- fmt : & Either < PyStrRef , PyBytesRef > ,
217- ) -> PyResult < FormatSpec > {
218- let decoded_fmt = match fmt {
219- Either :: A ( string) => string. as_str ( ) ,
220- Either :: B ( bytes) if bytes. is_ascii ( ) => std:: str:: from_utf8 ( bytes) . unwrap ( ) ,
221- _ => {
222- return Err ( vm. new_unicode_decode_error (
223- "Struct format must be a ascii string" . to_owned ( ) ,
224- ) )
225- }
226- } ;
227- FormatSpec :: parse ( decoded_fmt, vm)
228- }
229-
230- pub fn parse ( fmt : & str , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
231- let mut chars = fmt. bytes ( ) . peekable ( ) ;
248+ pub fn parse ( fmt : & [ u8 ] , vm : & VirtualMachine ) -> PyResult < FormatSpec > {
249+ let mut chars = fmt. iter ( ) . copied ( ) . peekable ( ) ;
232250
233251 // First determine "@", "<", ">","!" or "="
234252 let endianness = parse_endianness ( & mut chars) ;
@@ -399,10 +417,10 @@ pub(crate) mod _struct {
399417 let mut repeat = 0isize ;
400418 while let Some ( b'0' ..=b'9' ) = chars. peek ( ) {
401419 if let Some ( c) = chars. next ( ) {
402- let current_digit = ( c as char ) . to_digit ( 10 ) . unwrap ( ) as isize ;
420+ let current_digit = c - b'0' ;
403421 repeat = repeat
404422 . checked_mul ( 10 )
405- . and_then ( |r| r. checked_add ( current_digit) )
423+ . and_then ( |r| r. checked_add ( current_digit as _ ) )
406424 . ok_or_else ( || OVERFLOW_MSG . to_owned ( ) ) ?;
407425 }
408426 }
@@ -486,20 +504,26 @@ pub(crate) mod _struct {
486504 }
487505 buffer_len - ( -offset as usize )
488506 } else {
489- if offset as usize >= buffer_len {
507+ let offset = offset as usize ;
508+ let ( op, op_action) = if is_pack {
509+ ( "pack_into" , "packing" )
510+ } else {
511+ ( "unpack_from" , "unpacking" )
512+ } ;
513+ if offset >= buffer_len {
490514 let msg = format ! (
491515 "{op} requires a buffer of at least {required} bytes for {op_action} {needed} \
492516 bytes at offset {offset} (actual buffer size is {buffer_len})",
493- op = if is_pack { "pack_into" } else { "unpack_from" } ,
494- op_action = if is_pack { "packing" } else { "unpacking" } ,
517+ op = op ,
518+ op_action = op_action ,
495519 required = needed + offset as usize ,
496520 needed = needed,
497521 offset = offset,
498522 buffer_len = buffer_len
499523 ) ;
500524 return Err ( new_struct_error ( vm, msg) ) ;
501525 }
502- offset as usize
526+ offset
503527 } ;
504528
505529 if ( buffer_len - offset_from_start) < needed {
@@ -717,24 +741,19 @@ pub(crate) mod _struct {
717741 }
718742
719743 #[ pyfunction]
720- fn pack (
721- fmt : Either < PyStrRef , PyBytesRef > ,
722- args : PosArgs ,
723- vm : & VirtualMachine ,
724- ) -> PyResult < Vec < u8 > > {
725- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
726- format_spec. pack ( args. into_vec ( ) , vm)
744+ fn pack ( fmt : IntoStructFormatBytes , args : PosArgs , vm : & VirtualMachine ) -> PyResult < Vec < u8 > > {
745+ fmt. format_spec ( vm) ?. pack ( args. into_vec ( ) , vm)
727746 }
728747
729748 #[ pyfunction]
730749 fn pack_into (
731- fmt : Either < PyStrRef , PyBytesRef > ,
750+ fmt : IntoStructFormatBytes ,
732751 buffer : ArgMemoryBuffer ,
733752 offset : isize ,
734753 args : PosArgs ,
735754 vm : & VirtualMachine ,
736755 ) -> PyResult < ( ) > {
737- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
756+ let format_spec = fmt . format_spec ( vm) ?;
738757 let offset = get_buffer_offset ( buffer. len ( ) , offset, format_spec. size , true , vm) ?;
739758 buffer. with_ref ( |data| format_spec. pack_into ( & mut data[ offset..] , args. into_vec ( ) , vm) )
740759 }
@@ -757,11 +776,11 @@ pub(crate) mod _struct {
757776
758777 #[ pyfunction]
759778 fn unpack (
760- fmt : Either < PyStrRef , PyBytesRef > ,
779+ fmt : IntoStructFormatBytes ,
761780 buffer : ArgBytesLike ,
762781 vm : & VirtualMachine ,
763782 ) -> PyResult < PyTupleRef > {
764- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
783+ let format_spec = fmt . format_spec ( vm) ?;
765784 buffer. with_ref ( |buf| format_spec. unpack ( buf, vm) )
766785 }
767786
@@ -774,11 +793,11 @@ pub(crate) mod _struct {
774793
775794 #[ pyfunction]
776795 fn unpack_from (
777- fmt : Either < PyStrRef , PyBytesRef > ,
796+ fmt : IntoStructFormatBytes ,
778797 args : UpdateFromArgs ,
779798 vm : & VirtualMachine ,
780799 ) -> PyResult < PyTupleRef > {
781- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
800+ let format_spec = fmt . format_spec ( vm) ?;
782801 let offset =
783802 get_buffer_offset ( args. buffer . len ( ) , args. offset , format_spec. size , false , vm) ?;
784803 args. buffer
@@ -849,47 +868,42 @@ pub(crate) mod _struct {
849868
850869 #[ pyfunction]
851870 fn iter_unpack (
852- fmt : Either < PyStrRef , PyBytesRef > ,
871+ fmt : IntoStructFormatBytes ,
853872 buffer : ArgBytesLike ,
854873 vm : & VirtualMachine ,
855874 ) -> PyResult < UnpackIterator > {
856- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt ) ?;
875+ let format_spec = fmt . format_spec ( vm) ?;
857876 UnpackIterator :: new ( vm, format_spec, buffer)
858877 }
859878
860879 #[ pyfunction]
861- fn calcsize ( fmt : Either < PyStrRef , PyBytesRef > , vm : & VirtualMachine ) -> PyResult < usize > {
862- let format_spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
863- Ok ( format_spec. size )
880+ fn calcsize ( fmt : IntoStructFormatBytes , vm : & VirtualMachine ) -> PyResult < usize > {
881+ Ok ( fmt. format_spec ( vm) ?. size )
864882 }
865883
866884 #[ pyattr]
867885 #[ pyclass( name = "Struct" ) ]
868886 #[ derive( Debug , PyValue ) ]
869887 struct PyStruct {
870888 spec : FormatSpec ,
871- fmt_str : PyStrRef ,
889+ format : PyStrRef ,
872890 }
873891
874892 impl SlotConstructor for PyStruct {
875- type Args = Either < PyStrRef , PyBytesRef > ;
893+ type Args = IntoStructFormatBytes ;
876894
877895 fn py_new ( cls : PyTypeRef , fmt : Self :: Args , vm : & VirtualMachine ) -> PyResult {
878- let spec = FormatSpec :: decode_and_parse ( vm, & fmt) ?;
879- let fmt_str = match fmt {
880- Either :: A ( s) => s,
881- Either :: B ( b) => PyStr :: from ( std:: str:: from_utf8 ( b. as_bytes ( ) ) . unwrap ( ) )
882- . into_ref_with_type ( vm, vm. ctx . types . str_type . clone ( ) ) ?,
883- } ;
884- PyStruct { spec, fmt_str } . into_pyresult_with_type ( vm, cls)
896+ let spec = fmt. format_spec ( vm) ?;
897+ let format = fmt. 0 ;
898+ PyStruct { spec, format } . into_pyresult_with_type ( vm, cls)
885899 }
886900 }
887901
888902 #[ pyimpl( with( SlotConstructor ) ) ]
889903 impl PyStruct {
890904 #[ pyproperty]
891905 fn format ( & self ) -> PyStrRef {
892- self . fmt_str . clone ( )
906+ self . format . clone ( )
893907 }
894908
895909 #[ pyproperty]
0 commit comments