[Unity][Analysis] Add utility for collecting compile-time bindings#16312
Merged
Lunderberg merged 2 commits intoapache:unityfrom Jan 4, 2024
Merged
Conversation
Whether an optimizations should be performed may depend on when the variables in an expression are known. For example, consider a LoRA-adjusted model, with base weights `W` of shape `[m,n]`, LoRA components `A` and `B` with shapes `[r,n]` and `[m,r]` respectively, and activations `x` with shape `[n,1]`. The LoRA-adjusted matmul could be computed either as `(W + B*A)*x` or as `(W*x + B*(A*x))`. If `A` and `B` are provided at run-time, then computing `(W + B*(A*x))` requires significantly fewer computations. * `(W + B*A)*x`: `m*n*(2*r + 3)` operations 1. `B*A`: `2*m*n*r` operations using a naive matmul 2. Adding `W` to (1): `m*n` operations 3. Multiplying `x` by (2): `2*m*n` operations * `(W*x + B*(A*x))`: (2*m*n + r*(2*n + 2*m + 1)) 1. `W*x`: `2*m*n` operations 2. `A*x`: `2*r*n` operations 3. Multiplying `B` by (2): `2*m*r` operations 4. Adding (1) and (3)`: `m` operations However, if `A` and `B` are known at compile-time, then computing `(W + B*A)*x` groups all compile-time values together, allowing them to be computed earlier (i.e. using `LiftTransformParams`) * `(W + B*A)*x`: `2*m*n` operations 1. `B*A`: 0 operations, computed at compile-time 2. Adding `W` to (1): 0 operations, computed at compile-time 3. Multiplying `x` by (2): `2*m*n` operations Since the choice of optimized expression depends on which parameters can be computed at compile-time, it is useful to have a utility that identifies values that can be computed at compile-time.
This was referenced Dec 29, 2023
…s_pr_16312 Pull in bugfix from apache#16322
Contributor
|
The logic seems fine here, but what is the |
Contributor
Author
That's correct. The attribute was initially added for |
slyubomirsky
approved these changes
Jan 4, 2024
Contributor
slyubomirsky
left a comment
There was a problem hiding this comment.
This makes sense then, thanks for implementing the pass.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Whether an optimizations should be performed may depend on when the variables in an expression are known.
For example, consider a LoRA-adjusted model, with base weights
Wof shape[m,n], LoRA componentsAandBwith shapes[r,n]and[m,r]respectively, and activationsxwith shape[n,1]. The LoRA-adjusted matmul could be computed either as(W + B*A)*xor as(W*x + B*(A*x)).If
AandBare provided at run-time, then computing(W + B*(A*x))requires significantly fewer computations.(W + B*A)*x:m*n*(2*r + 3)operationsB*A:2*m*n*roperations using a naive matmulWto (1):m*noperationsxby (2):2*m*noperations(W*x + B*(A*x)): (2mn + r*(2n + 2m + 1))W*x:2*m*noperationsA*x:2*r*noperationsBby (2):2*m*roperations:m` operationsHowever, if
AandBare known at compile-time, then computing `(Wgroups all compile-time values together, allowing them to be computed earlier (i.e. usingLiftTransformParams`)(W + B*A)*x:2*m*noperationsB*A: 0 operations, computed at compile-timeWto (1): 0 operations, computed at compile-timexby (2):2*m*noperationsSince the choice of optimized expression depends on which parameters can be computed at compile-time, it is useful to have a utility that identifies values that can be computed at compile-time.