Make the forward pass torch.compile compatible#87

Merged
danieldk merged 11 commits into
huggingface:mainfrom
MekkCyber:compile_compatible_2nd
Jun 3, 2025

@MekkCyber
Contributor

This PR updates the flow of `replace_kernel_forward_from_hub` so that the appropriate forward pass is selected before the model is compiled. It introduces a new function, `kernelize`, which replaces the forward before compilation. `kernelize` can also be called again after registering new kernels, for more flexibility.

Member

@danieldk danieldk left a comment

Looks great! Added some comments to polish it up.

Collaborator

@ArthurZucker ArthurZucker left a comment

Nice!

Comment thread src/kernels/layer.py Outdated
This decorator stores the layer name and original forward method, which will be used
by the kernelize function to replace the forward implementation with the appropriate
kernel from the hub.
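
A minimal sketch of what such a decorator could look like, in plain Python with no `torch` dependency. The names `mark_extensible_layer`, `kernel_layer_name`, `_original_forward`, and `MyActivation` are hypothetical and chosen for illustration only; they are not taken from `src/kernels/layer.py`.

```python
from typing import Callable, Type, TypeVar

T = TypeVar("T")


def mark_extensible_layer(layer_name: str) -> Callable[[Type[T]], Type[T]]:
    """Hypothetical decorator: record metadata, but do not touch `forward`."""

    def decorator(cls: Type[T]) -> Type[T]:
        # Store the name under which a replacement kernel would be looked up.
        cls.kernel_layer_name = layer_name
        # Keep a handle on the original forward so it can be used as a fallback.
        cls._original_forward = cls.forward
        return cls

    return decorator


@mark_extensible_layer("MyActivation")
class MyActivation:
    def forward(self, x: float) -> float:
        # ReLU-like placeholder computation.
        return max(x, 0.0)
```

In this sketch the decorator only annotates the class. Choosing and installing a replacement forward is left to a separate, explicit step.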

Collaborator

Let's explain why we don't replace on the fly (cf. `torch.compile`).

Member

I added a FAQ with this question in it.

Comment thread src/kernels/layer.py Outdated
The kernelized model
"""
is_compiling = _is_torchdynamo_compiling()
needs_backward = model.training

Collaborator

Does this mean we cannot adapt to the user changing `model.training`?
It would be nice if we could register two forwards and switch between them on the fly!

Member

Yeah, that would be nice! I think to support this we would have to override Module.train and switch forwards based on the value that is passed. I think it makes sense that I do this in a PR after this, since this one is already getting pretty big.

Then we'll remove the `needs_backward` argument again in the next PR (which should be fine as long as we don't make a release).
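
The idea of overriding `train` to switch forwards could be sketched roughly as follows. This is plain Python rather than a real `torch.nn.Module`, and `SwitchingLayer` with its two forward methods is hypothetical; it only mirrors the `train(mode)` / `eval()` convention and is not the implementation from the follow-up PR.

```python
class SwitchingLayer:
    """Hypothetical layer that keeps two forwards and selects one in `train`."""

    def __init__(self) -> None:
        self.training = True
        # One forward per mode; a real layer might map these to different kernels.
        self._forwards = {True: self._forward_train, False: self._forward_eval}
        self.forward = self._forwards[self.training]

    def _forward_train(self, x):
        return ("train", x)

    def _forward_eval(self, x):
        return ("eval", x)

    def train(self, mode: bool = True) -> "SwitchingLayer":
        # The switch happens here, once per mode change, not on every call.
        self.training = mode
        self.forward = self._forwards[mode]
        return self

    def eval(self) -> "SwitchingLayer":
        return self.train(False)
```

Because the selection is done inside `train`, each call to `forward` stays free of any mode check.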

MekkCyber and others added 7 commits June 2, 2025 13:57
- Process comment from the PR.
- Replacement should be on instances, not the class.
- Remove torch compile checks (not relevant during kernelize). We
  might add it back in a different way in another commit: add an
  option to `kernelize`.
@danieldk danieldk force-pushed the compile_compatible_2nd branch from d56849f to 4485e9b Compare June 2, 2025 19:03
@danieldk danieldk marked this pull request as draft June 2, 2025 19:04
@danieldk danieldk changed the title Make the forward pass torch.compile compatible Make the forward pass torch.compile compatible Jun 3, 2025
@danieldk danieldk marked this pull request as ready for review June 3, 2025 12:31
Member

@LysandreJik LysandreJik left a comment

Very nice! Love the new additions and the clear doc, thanks for spending time on this.

Comment thread docs/layers.md
Comment on lines +48 to +57
A model will not use Hub kernels by default, even if it contains extensible
layers. To enable the use of Hub kernels in the model, it needs to be
'kernelized' using the `kernelize` function. This function traverses the
model graph and replaces the `forward` methods of extensible layers for which
Hub kernels are registered. Kernelize can be used as follows:

```python
model = MyModel(...)
model = kernelize(model)
```
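
The traversal described above, combined with the commit note that replacement should happen on instances rather than on the class, might look roughly like this. It is a self-contained sketch: `Module`, `REGISTRY`, `hub_activation_forward`, and `kernelize_sketch` are hypothetical stand-ins, not the library's actual API.

```python
import types


class Module:
    """Tiny stand-in for a module tree; only tracks child modules."""

    def __init__(self) -> None:
        self._children = []

    def add(self, child: "Module") -> "Module":
        self._children.append(child)
        return child

    def modules(self):
        # Yield this module and then all descendants, depth-first.
        yield self
        for child in self._children:
            yield from child.modules()


class Activation(Module):
    kernel_layer_name = "Activation"  # marks the layer as extensible

    def forward(self, x):
        return x + 1


class Model(Module):
    def __init__(self) -> None:
        super().__init__()
        self.act = self.add(Activation())

    def forward(self, x):
        return self.act.forward(x)


def hub_activation_forward(self, x):
    # Pretend this is an optimized kernel with the same signature and result.
    return x + 1


# Hypothetical registry: layer name -> replacement forward.
REGISTRY = {"Activation": hub_activation_forward}


def kernelize_sketch(model: Module) -> Module:
    for module in model.modules():
        name = getattr(type(module), "kernel_layer_name", None)
        replacement = REGISTRY.get(name)
        if replacement is not None:
            # Bind to this instance only; the class keeps its original forward.
            module.forward = types.MethodType(replacement, module)
    return model
```

Since only instances are patched, a model that is never passed to `kernelize_sketch` keeps its original forwards, which matches the opt-in behaviour the docs describe.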

Member

I like that opt-in, that's clean

Comment thread docs/layers.md
it signifies that the maintainer intends to keep the `forward` signature
compatible with layers from the hub.

## Kernelizing a model

Member

Very clear doc

@danieldk danieldk merged commit 32ec496 into huggingface:main Jun 3, 2025
10 checks passed
4 participants