Skip to content
Prev Previous commit
Corrected typos/examples in docstring
  • Loading branch information
Lunderberg committed May 11, 2022
commit 7708d9eee37b4067d41dc01f0c2309483f200aad
33 changes: 14 additions & 19 deletions python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ def inverse(self, shape: List[Union[Range, PrimExpr]]) -> "IndexMap":

Throws an error if the function is not bijective.

Paramters
---------
Parameters
----------
shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined.
Expand All @@ -387,23 +387,8 @@ def non_surjective_inverse(

Can be applied to transformations that introduce padding.

Examples
--------

Before unroll, in TensorIR, the IR is:

.. code-block:: python

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
inverse_map

.. code-block:: python

index_map = IndexMap.from_func(lambda i: [i//4, i%4])

Paramters
---------
Parameters
----------
shape: List[Union[Range,PrimExpr]]

The region over which the inverse should be determined.
Expand All @@ -415,6 +400,16 @@ def non_surjective_inverse(

The inverse, and a predicate for which the inverse maps to
a valid index in the input range.

Examples
--------

.. code-block:: python

index_map = IndexMap.from_func(lambda i: [i//4, i%4])
inverse_map, predicate = index_map.non_surjective_inverse([14])
assert inverse_map.is_equivalent_to(IndexMap.from_func(lambda j,k: [4*j + k])
print(predicate) # Prints "(axis0==3) && (axis2 >= 2)"
"""

shape = [dim if isinstance(dim, Range) else Range(0, dim) for dim in shape]
Expand Down