@@ -176,9 +176,9 @@ def _construct_data_iter_tree(
176176
177177 @staticmethod
178178 def from_nested_tuple (
179- data : Tuple ,
180- strides : Tuple ,
181- device : Optional [Tuple ] = None ,
179+ data : Union [ Tuple , int ] ,
180+ strides : Union [ Tuple , int ] ,
181+ device : Optional [Union [ Tuple , int ] ] = None ,
182182 exclusive : Optional [Tuple ] = None ,
183183 from_to : Optional [Tuple [str ]] = None ,
184184 ) -> "TileLayout" :
@@ -189,6 +189,11 @@ def inc_leaf_cnt():
189189 leaf_cnt += 1
190190 return leaf_cnt - 1
191191
192+ if not isinstance (data , tuple ):
193+ data = (data ,)
194+ if not isinstance (strides , tuple ):
195+ strides = (strides ,)
196+
192197 if device is None :
193198 assert exclusive is None , "exclusive must be None if device is None"
194199 assert from_to is None , "from_to must be None if device is None"
@@ -199,6 +204,9 @@ def inc_leaf_cnt():
199204 return TileLayout (data_tree = data_tree )
200205
201206 else :
207+ if not isinstance (device , tuple ):
208+ device = (device ,)
209+
202210 assert from_to is not None , "from_to must be provided if device is provided"
203211 assert isinstance (from_to , tuple ) and len (from_to ) == 2 , "from_to must be a tuple of 2"
204212
@@ -227,6 +235,28 @@ def inc_leaf_cnt():
227235 def tile (outer : "TileLayout" , inner : "TileLayout" ) -> "TileLayout" :
228236 return get_global_func ("tir.TileLayoutTile" )(outer , inner )
229237
238+ @staticmethod
239+ def shard (
240+ shape : Tuple [PrimExpr , int ],
241+ mesh : Tuple ,
242+ strategy : str ,
243+ inner : "TileLayout" ,
244+ from_to : Optional [Tuple [str ]] = None ,
245+ ) -> "TileLayout" :
246+ assert from_to is not None , "from_to must be provided if device is provided"
247+ assert isinstance (from_to , tuple ) and len (from_to ) == 2 , "from_to must be a tuple of 2"
248+
249+ f = get_global_func ("tir.IterTreeFromTuple" )
250+ iter_tree , _ = f (convert_to_object (mesh ))
251+ return get_global_func ("tir.TileLayoutShard" )(
252+ shape ,
253+ iter_tree ,
254+ strategy ,
255+ inner ,
256+ ExecScope .create (from_to [0 ]) if from_to else None ,
257+ ExecScope .create (from_to [1 ]) if from_to else None ,
258+ )
259+
230260 @staticmethod
231261 def normalize (layout : "TileLayout" ) -> "TileLayout" :
232262 return get_global_func ("tir.NormalizeTileLayout" )(layout )
0 commit comments