Skip to content

Enable Shared Function in LiftTransformParam Pass#16717

Merged
vinx13 merged 12 commits intoapache:mainfrom
zxybazh:feature/2024-03-13/enable-share-func-in-lift-transform-params-pass
Mar 19, 2024
Merged

Enable Shared Function in LiftTransformParam Pass#16717
vinx13 merged 12 commits intoapache:mainfrom
zxybazh:feature/2024-03-13/enable-share-func-in-lift-transform-params-pass

Conversation

@zxybazh
Copy link
Copy Markdown
Member

@zxybazh zxybazh commented Mar 14, 2024

This PR enables specifying a list of function names to extract shared transform parameters. A single parameter transformation function will be produced, containing the preprocessing steps common across each function whose name is in a given function name list.

Unit tests are passing except the one that has no shared prepocessing (transpose), skipping for now, will follow-up in another PR.

Cherry-picked from @vinx13's working branch.

vinx13 and others added 7 commits March 13, 2024 22:45
Currently, the `relax.transform.LiftTransformParams` pass produces a
separate `transform_params` function for every function in the
`IRModule`.  In most cases, the functions in an `IRModule` all accept
the same set of model weights (e.g. `"prefill"` and `"decode"` in a
transformer model).  However, the lifted `*_transform_params`
functions may be different for each inference function.

The goal is to introduce a new optional parameter `shared_transform`
for `LiftTransformParams`.  If set, a single parameter transformation
function should be generated for the entire `IRModule`, rather than
one parameter transformation function for each original function.

Because the shared parameter transformation function must be
compatible with all existing functions, it should only contain
parameter transformation steps that are common across all input
functions.
Comment thread src/relax/transform/lift_transform_params.cc Outdated
@zxybazh zxybazh marked this pull request as ready for review March 15, 2024 16:01
Copy link
Copy Markdown
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, looks good, and thank you for the improvement! A couple of comments on it, then it should be good to go!

Comment thread include/tvm/relax/transform.h Outdated
Comment thread include/tvm/relax/transform.h Outdated
Comment thread python/tvm/relax/transform/transform.py Outdated
Comment thread python/tvm/relax/transform/transform.py Outdated
Comment thread src/relax/transform/lift_transform_params.cc Outdated
Comment thread src/relax/transform/lift_transform_params.cc Outdated
Comment thread src/relax/transform/lift_transform_params.cc Outdated
@zxybazh
Copy link
Copy Markdown
Member Author

zxybazh commented Mar 18, 2024

@Lunderberg thank you for the detailed review, comments are all addressed, please take another look when you got time, thanks a lot!

Copy link
Copy Markdown
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for making the changes, and LGTM!

@vinx13 vinx13 merged commit ff6ce9c into apache:main Mar 19, 2024
thaisacs pushed a commit to thaisacs/tvm that referenced this pull request Apr 3, 2024
* [WIP] LiftTransformParams for multiple functions

* pass test

* [In-Progress] Define desired behavior for shared LiftTransformParams

Currently, the `relax.transform.LiftTransformParams` pass produces a
separate `transform_params` function for every function in the
`IRModule`.  In most cases, the functions in an `IRModule` all accept
the same set of model weights (e.g. `"prefill"` and `"decode"` in a
transformer model).  However, the lifted `*_transform_params`
functions may be different for each inference function.

The goal is to introduce a new optional parameter `shared_transform`
for `LiftTransformParams`.  If set, a single parameter transformation
function should be generated for the entire `IRModule`, rather than
one parameter transformation function for each original function.

Because the shared parameter transformation function must be
compatible with all existing functions, it should only contain
parameter transformation steps that are common across all input
functions.

* [TIR] Implemented shared lift transform params

* Comments & skip test.

* Linting.

* Avoid c++20 feature to pass CI.

* Remove unused code.

* Fix interface as suggested.

* Fix docs.

* Fix interface as suggested.

* Move code for readability.

---------

Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Eric Lunderberg <elunderberg@octoml.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants