Make the forward pass torch.compile compatible#87
Conversation
danieldk
left a comment
There was a problem hiding this comment.
Looks great! Added some comments to polish it up.
| This decorator stores the layer name and original forward method, which will be used | ||
| by the kernelize function to replace the forward implementation with the appropriate | ||
| kernel from the hub. | ||
|
|
There was a problem hiding this comment.
let's explain why we don't replace on the fly (cf compile)
There was a problem hiding this comment.
I added a FAQ with this question in it.
| The kernelized model | ||
| """ | ||
| is_compiling = _is_torchdynamo_compiling() | ||
| needs_backward = model.training |
There was a problem hiding this comment.
does this mean we cannot adapt to the user changing model.training .
Would be nice if we can register two forwards, and switch between them on the fly!
There was a problem hiding this comment.
Yeah, that would be nice! I think to support this we would have to override Module.train and switch forwards based on the value that is passed. I think it makes sense that I do this in a PR after this, since this one is already getting pretty big.
Then we'll remove the needs_backward argument again in the next PR (which should be fine as long we don't make a release).
- Process comment from the PR. - Replacement should be on instances, not the class. - Remove torch compile checks (not relevant during kernelize). We might add it back in a different way in another commit: add an option to `kernelize`.
d56849f to
4485e9b
Compare
torch.compile compatible
LysandreJik
left a comment
There was a problem hiding this comment.
Very nice! Love the new additions and the clear doc, thanks for spending time on this
| A model will not use Hub kernels by default, even if it contains extensible | ||
| layers. To enable the use of Hub kernels in the model, it needs to be | ||
| 'kernelized' using the `kernelize` function. This function traverses the | ||
| model graph and replaces the `forward` methods of extensible layers for which | ||
| Hub kernels are registered. Kernelize can be used as follows: | ||
|
|
||
| ```python | ||
| model = MyModel(...) | ||
| model = kernelize(model) | ||
| ``` |
There was a problem hiding this comment.
I like that opt-in, that's clean
| it signifies that the maintainer intends to keep the `forward` signature | ||
| compatible with layers from the hub. | ||
|
|
||
| ## Kernelizing a model |
This PR updates the flow of replace_kernel_forward_from_hub to ensure the appropriate forward pass is selected before model compilation. It introduces a new function
kernelizethat can be called to replace the forward before the model is compiled. It can also be called after registering new kernels for more flexibility