1313from dataclasses import dataclass
1414from typing import ClassVar , List , Literal , Optional , Tuple
1515
16+ from executorch .exir ._serialize ._cord import Cord
1617from executorch .exir ._serialize ._dataclass import _DataclassEncoder , _json_to_dataclass
1718from executorch .exir ._serialize ._flatbuffer import (
1819 _FlatbufferResult ,
2930 Program ,
3031 SubsegmentOffsets ,
3132)
33+ from executorch .exir .tensor import ALIGNMENT
3234
3335
3436# Byte order of numbers written to program headers. Always little-endian
@@ -240,15 +242,15 @@ def _get_extended_header(program_data: bytes) -> Optional[_ExtendedHeader]:
240242
241243
242244def _extract_delegate_segments (
243- program : Program , segments : List [bytes ], segment_alignment : int
245+ program : Program ,
246+ segments : List [Cord ],
244247) -> None :
245- """The input program and segments list are modified in place.
248+ """Extracts the delegate segments inlined in the program into a list of buffers.
249+ The program is modified in-place to remove the delegate data.
246250
247251 Args:
248252 program: The program to extract segments from. Modified in-place.
249- segments: A list to which extracted segments will be appended. Modified in-place.
250- segment_alignment: Alignment in bytes. The starting offset of each
251- segment will be aligned to this value.
253+ segments: A list of buffers to append extracted segments to. Modified in-place.
252254 """
253255 remaining_inline : List [BackendDelegateInlineData ] = []
254256 inline_indices_seen : set [int ] = set ()
@@ -278,24 +280,11 @@ def _extract_delegate_segments(
278280 if inline .data :
279281 # Move the delegate data out of the program.
280282 segment_index = len (segments )
281- segments .append (inline .data )
283+ segments .append (Cord ( inline .data ) )
282284 delegate .processed = BackendDelegateDataReference (
283285 location = DataLocation .SEGMENT ,
284286 index = segment_index ,
285287 )
286-
287- # Update the segment list in the root Program object.
288- prev_end = (
289- program .segments [- 1 ].offset + program .segments [- 1 ].size
290- if program .segments
291- else 0
292- )
293- program .segments .append (
294- DataSegment (
295- offset = _aligned_size (prev_end , segment_alignment ),
296- size = len (inline .data ),
297- ),
298- )
299288 else :
300289 # Not moving into a segment. Keep it inline, but update the
301290 # index.
@@ -321,183 +310,32 @@ def _extract_delegate_segments(
321310def _extract_constant_segment (
322311 constant_buffer : List [Buffer ],
323312 tensor_alignment : int ,
324- ) -> Tuple [bytes , List [int ]]:
325- """Copies the tensors from the provided list into a single buffer and tracks the offsets
326- of each tensor.
313+ ) -> Tuple [Cord , List [int ]]:
314+ """Copies the tensors from the provided list into a Cord and tracks the offsets
315+ of each tensor.
327316
317+ Args:
328318 constant_buffer: list of Buffers from which to extract constants from. Not modified.
329- tensor_alignment: Alignment in bytes. The starting offset of each tensor in the
330- constant segment will be aligned to this value. Default to 16 .
319+ tensor_alignment: Alignment in bytes. Each tensor in the cord will be padded to align
320+ with this value. Defaults to ALIGNMENT .
331321
332322 Returns:
333323 A tuple of (constant segment, list of offsets for each tensor in the segment)
334324 """
335- constant_segment_data : bytearray = bytearray ()
325+ constant_segment_data : Cord = Cord ()
336326 constant_segment_offsets : List [int ] = []
337327 current_offset : int = 0
338328 for i in range (len (constant_buffer )):
339329 buffer = constant_buffer [i ]
330+ constant_segment_data .append (buffer .storage )
340331 buffer_length = len (buffer .storage )
341332 pad_length = _padding_required (buffer_length , tensor_alignment )
342-
343- # Append each constant buffer to the constant segment.
344- constant_segment_data += buffer .storage
345- # Add padding for all but the last tensor.
346333 if i < len (constant_buffer ) - 1 :
347- constant_segment_data += b"\x00 " * pad_length
348-
349- # Append constant data offset.
334+ constant_segment_data .append (b"\x00 " * pad_length )
350335 constant_segment_offsets .append (current_offset )
351336 current_offset += buffer_length + pad_length
352- return bytes (constant_segment_data ), constant_segment_offsets
353-
354-
355- def _extract_segments (
356- program : Program ,
357- extract_delegate_segments : bool ,
358- extract_constant_segment : bool ,
359- segment_alignment : int ,
360- constant_tensor_alignment : int ,
361- ) -> Tuple [Program , List [bytes ]]:
362- """Extracts constant and/or delegate data from a given Program into separate segments.
363-
364- Args:
365- program: The Program to extract segments from.
366- extract_delegate_segments: Whether to extract delegate data blobs from the program.
367- extract_constant_segment: Whether to extract constant data from the program.
368- segment_alignment: Alignment in bytes. The starting offset of each
369- segment will be aligned to this value in the output data.
370- constant_tensor_alignment: Alignment in bytes. The starting offset of each tensor
371- in the constant segment will be aligned to this value.
372- Returns:
373- A tuple of (modified program, list of segment data).
374- Raises:
375- ValueError, if the program already contains segments.
376- """
377- if program .segments :
378- raise ValueError (
379- f"Program already has { len (program .segments )} segments: "
380- + f"{ repr (program .segments )} "
381- )
382-
383- # Don't modify the original program.
384- # TODO(T144120904): Could avoid yet more huge copies with a more shallow
385- # copy, reusing the actual data blobs.
386- program = copy .deepcopy (program )
387-
388- # Segment data to be written to the file following the flatbuffer data.
389- segments : List [bytes ] = []
390-
391- if extract_constant_segment :
392- constant_segment_data , constant_segment_offsets = _extract_constant_segment (
393- program .constant_buffer , tensor_alignment = constant_tensor_alignment
394- )
395-
396- if constant_segment_data :
397- # Append constant_segment_data to the list of segments if non-empty.
398- segments .append (constant_segment_data )
399- # Append constant_segment offset to the list of DataSegments. Added as the
400- # first segment here, but it's not mandatory that the constant segment be
401- # the first one in the file.
402- program .segments .append (
403- DataSegment (offset = 0 , size = len (constant_segment_data ))
404- )
405-
406- # Fill in constant_segment offsets and clear the constant buffer; only one of
407- # constant_segment and constant_buffer should be non-empty.
408- program .constant_segment = SubsegmentOffsets (
409- segment_index = 0 , offsets = constant_segment_offsets
410- )
411- program .constant_buffer = []
412-
413- if extract_delegate_segments :
414- _extract_delegate_segments (
415- program , segments = segments , segment_alignment = segment_alignment
416- )
417- return program , segments
418-
419-
420- def _append_segments (
421- program_data : bytes ,
422- segments : List [bytes ],
423- alignment : int ,
424- segment_table : List [DataSegment ],
425- base_offset : int ,
426- ) -> bytes :
427- """Appends segments to the end of the program data.
428-
429- Appends each element of `segments` to `program_data`, with '\0 ' padding to
430- ensure that the offset of each segment is aligned to `alignment`.
431-
432- Args:
433- program_data: The flatbuffer-serialized Program.
434- segments: The list of segments to append to `program_data`.
435- alignment: Alignment in bytes. The starting offset of each
436- segment will be aligned to this value in the output data.
437- segment_table: The expected offsets and sizes of each element in
438- `segments`. This is typically `program.segments`. Must have the
439- same length as `segments`.
440- base_offset: The expected segment base offset from the extended header.
441- Should point to the aligned offset following the end of
442- `program_data`.
443- Returns:
444- A copy of `program_data` with the segment data and padding appended.
445- If there are no segments, returns `program_data` directly.
446- Raises:
447- ValueError: If the length of `segments` doesn't match the length of
448- `segment_table`.
449- """
450- if len (segments ) != len (segment_table ):
451- raise ValueError (
452- f"Segments length { len (segments )} does not match "
453- + f"segment_table length { len (segment_table )} "
454- )
455- if not segments :
456- return program_data
457-
458- # The pieces that will be concatenated to create the output data.
459- # `program_data` will be its first element.
460- padded_segments : List [bytes ] = []
461- # Length of all elements in padded_segments. Only used for assertions.
462- current_offset : int = 0
463- for i , segment in enumerate ([program_data ] + segments ):
464- # Add padding if necessary to align the start of this segment.
465- pad_length : int = _padding_required (current_offset , alignment )
466- if pad_length > 0 :
467- padded_segments .append (b"\x00 " * pad_length )
468- current_offset += pad_length
469-
470- # Make sure that we're about to add this segment to the offset that
471- # agrees with program.segments. Skip the first entry, which is the
472- # Program itself and isn't included in program.segments.
473- if i == 1 :
474- # The first real segment should start at the base offset.
475- assert current_offset == base_offset , (
476- f"Offset of first segment { current_offset } "
477- + f"!= base_offset { base_offset } "
478- )
479- if i > 0 :
480- # Adding a real segment, not `program_data`.
481- expected_segment = segment_table [i - 1 ]
482- expected_offset = base_offset + expected_segment .offset
483- assert current_offset == expected_offset , (
484- f"Segment { i } offset { current_offset } "
485- + f"!= expected offset { expected_offset } "
486- + f"(base { base_offset } + { expected_segment .offset } ) "
487- )
488- assert expected_segment .size == len (segment ), (
489- f"Segment { i } size { len (segment )} "
490- + f"!= expected size { expected_segment .size } "
491- )
492-
493- # Add the payload. If this is the final segment, it does not need
494- # padding after it.
495- padded_segments .append (segment )
496- current_offset += len (segment )
497337
498- # Use join() instead of appending to avoid O(n) reallocation of these
499- # potentially-large buffers.
500- return b"" .join (padded_segments )
338+ return constant_segment_data , constant_segment_offsets
501339
502340
503341def serialize_pte_binary (
@@ -524,9 +362,8 @@ def serialize_pte_binary(
524362 into a separate segment.
525363 segment_alignment: Alignment in bytes. The starting offset of each
526364 segment will be aligned to this value in the output data.
527- constant_tensor_alignment: If provided, the minimum alignment of tensor
528- buffers in the program. Must be a power of 2. If not provided, uses
529- the value in the schema file.
365+ constant_tensor_alignment: The minimum alignment of tensor
366+ buffers in the program. Must be a power of 2. Defaults to ALIGNMENT.
530367 delegate_alignment: If provided, the minimum alignment of delegate data
531368 in the program. Must be a power of 2. If not provided, uses the
532369 value in the schema file.
@@ -535,20 +372,53 @@ def serialize_pte_binary(
535372 """
536373 # Default tensor alignment.
537374 if constant_tensor_alignment is None :
538- constant_tensor_alignment = 16
375+ constant_tensor_alignment = ALIGNMENT
539376
540- # Segment data to be written to the file following the flatbuffer data.
541- segments : List [bytes ] = []
377+ # Don't modify the original program.
378+ # TODO(T144120904): Could avoid yet more huge copies with a more shallow
379+ # copy, reusing the actual data blobs.
380+ program = copy .deepcopy (program )
381+
382+ # Store extracted segment data; this may be constant data or delegate data.
383+ segments : List [Cord ] = []
384+
385+ if extract_constant_segment :
386+ constant_segment_data , constant_segment_offsets = _extract_constant_segment (
387+ program .constant_buffer , tensor_alignment = constant_tensor_alignment
388+ )
389+ if len (constant_segment_data ) > 0 :
390+ # Update program.constant_segment with constant subsegment offset information.
391+ program .constant_segment = SubsegmentOffsets (
392+ segment_index = len (segments ), offsets = constant_segment_offsets
393+ )
394+ # Clear the constant buffer, as constant data will be stored in segments.
395+ program .constant_buffer = []
396+ # Add to the aggregate segments cord.
397+ segments .append (constant_segment_data )
542398
543- # Extract constant segment and delegate segments, if requested.
544- if extract_constant_segment or extract_delegate_segments :
545- program , segments = _extract_segments (
546- program = program ,
547- extract_delegate_segments = extract_delegate_segments ,
548- extract_constant_segment = extract_constant_segment ,
549- segment_alignment = segment_alignment ,
550- constant_tensor_alignment = constant_tensor_alignment ,
399+ if extract_delegate_segments :
400+ _extract_delegate_segments (program , segments )
401+
402+ # Append all segments into a single Cord, adding any necessary padding to ensure that
403+ # each segment begins at the required alignment.
404+ # Update program.segments with the offsets to each segment.
405+ segments_data = Cord ()
406+ for data in segments :
407+ prev_end = (
408+ (program .segments [- 1 ].offset + program .segments [- 1 ].size )
409+ if program .segments
410+ else 0
411+ )
412+ program .segments .append (
413+ DataSegment (
414+ offset = _aligned_size (prev_end , segment_alignment ), size = len (data )
415+ )
551416 )
417+ # Add to aggregate segments cord with padding.
418+ padding_length = _padding_required (len (segments_data ), segment_alignment )
419+ if padding_length > 0 :
420+ segments_data .append (b"\x00 " * padding_length )
421+ segments_data .append (data )
552422
553423 # Convert to a standard flatbuffer binary.
554424 result : _FlatbufferResult = _program_json_to_flatbuffer (
@@ -558,7 +428,7 @@ def serialize_pte_binary(
558428 )
559429
560430 # If there are no segments present, do not insert the extended header.
561- if not segments :
431+ if len ( segments_data ) == 0 :
562432 return result .data
563433
564434 # Size of the header to insert. Its size is padded to the largest
@@ -572,7 +442,7 @@ def serialize_pte_binary(
572442 # Offset to the first segment, or zero if there are no segments.
573443 segment_base_offset : int = (
574444 _aligned_size (input_size = program_size , alignment = segment_alignment )
575- if segments
445+ if len ( segments_data ) > 0
576446 else 0
577447 )
578448
@@ -600,18 +470,21 @@ def serialize_pte_binary(
600470 assert eh .program_size == program_size
601471 assert eh .segment_base_offset == segment_base_offset
602472
603- if segments :
604- # Add segments to the end of the data, in order, with the appropriate
605- # padding.
606- program_data = _append_segments (
607- program_data = program_data ,
608- segments = segments ,
609- alignment = segment_alignment ,
610- segment_table = program .segments ,
611- base_offset = segment_base_offset ,
612- )
613-
614- return program_data
473+ # Construct the final pte file containing:
474+ # - program data; written to offset 0.
475+ # - segments data (optional); aligned to segment_alignment.
476+ pte_data = Cord (program_data )
477+ if len (segments_data ) > 0 :
478+ padding_length = _padding_required (len (pte_data ), segment_alignment )
479+ pte_data .append (b"\x00 " * padding_length )
480+ # The first segment after program data should start at the segment base offset.
481+ assert (
482+ len (pte_data ) == segment_base_offset
483+ ), f"Offset of first segment { len (pte_data )} != segment base offset { segment_base_offset } "
484+ pte_data .append (segments_data )
485+
486+ # TODO(lfq): this creates a copy of all the data; once we update existing callsites this will change.
487+ return bytes (pte_data )
615488
616489
617490def _restore_segments (program : Program , segment_data : bytes ) -> Program :
0 commit comments