-
Notifications
You must be signed in to change notification settings - Fork 969
Add Gemma 4 MLX install-path support #19065
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
5f455a2
0a822bd
fd78741
0e00290
3a26baa
0bf5fc4
90e5577
ee272c3
ca37250
818a51d
6e520dd
391cde4
19d6f09
41e3a51
9d3f841
719d2e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -444,26 +444,50 @@ def _make_io_slots(self): # noqa: C901 | |
| else: | ||
| raise NotImplementedError(f"Support for input {arg} is not implemented") | ||
|
|
||
| placeholder_nodes = { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I don't follow this change. Why is gemma4 sensitive to this?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I got here by diffing a previously working Gemma 4. What changed there was the slot assignment for the two rotary constants used by sliding-window vs full attention. This change was just to make that assignment deterministic instead of depending on raw placeholder traversal order. Gemma 4 is where I noticed it because that model exercises both constants in the same path. If you’d prefer, I can drop this |
||
| node.name: node for node in self.ep.graph.nodes if node.op == "placeholder" | ||
| } | ||
|
|
||
| # Allocate placeholder-backed slots in graph-signature order instead of | ||
| # raw FX node traversal order. This keeps lifted constant tids stable | ||
| # across equivalent exports, which matters for models like Gemma 4 that | ||
| # carry multiple rotary constant placeholders with similar structure. | ||
| for name in constant_tensors: | ||
| node = placeholder_nodes.get(name) | ||
| if node is None or node.users == {}: | ||
| continue | ||
| self.make_or_get_slot(node, id_space=IdSpace.Constant) | ||
|
|
||
| for name in user_inputs: | ||
| node = placeholder_nodes.get(name) | ||
| if node is None or node.users == {}: | ||
| continue | ||
| val = node.meta.get("val", None) | ||
| if isinstance(val, torch.Tensor) and not val.is_contiguous(): | ||
| raise ValueError( | ||
| f"MLX backend requires contiguous input tensors, " | ||
| f"but input '{node.name}' has non-contiguous strides. " | ||
| f"shape={list(val.shape)}, stride={list(val.stride())}. " | ||
| f"Ensure example inputs passed to torch.export.export() " | ||
| f"are contiguous (call .contiguous() on them)." | ||
| ) | ||
| self.make_or_get_slot(node, id_space=IdSpace.Input) | ||
|
|
||
| for name in mutable_buffers: | ||
| node = placeholder_nodes.get(name) | ||
| if node is None or node.users == {}: | ||
| continue | ||
| self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) | ||
|
|
||
| classified_placeholders = ( | ||
| set(constant_tensors) | set(user_inputs) | set(mutable_buffers) | ||
| ) | ||
|
|
||
| for node in self.ep.graph.nodes: | ||
| if node.op == "placeholder": | ||
| if node.users == {}: | ||
| continue | ||
| if node.name in constant_tensors: | ||
| self.make_or_get_slot(node, id_space=IdSpace.Constant) | ||
| elif node.name in user_inputs: | ||
| val = node.meta.get("val", None) | ||
| if isinstance(val, torch.Tensor) and not val.is_contiguous(): | ||
| raise ValueError( | ||
| f"MLX backend requires contiguous input tensors, " | ||
| f"but input '{node.name}' has non-contiguous strides. " | ||
| f"shape={list(val.shape)}, stride={list(val.stride())}. " | ||
| f"Ensure example inputs passed to torch.export.export() " | ||
| f"are contiguous (call .contiguous() on them)." | ||
| ) | ||
| self.make_or_get_slot(node, id_space=IdSpace.Input) | ||
| elif node.name in mutable_buffers: | ||
| self.make_or_get_slot(node, id_space=IdSpace.MutableBuffer) | ||
| else: | ||
| if node.name not in classified_placeholders: | ||
| raise NotImplementedError( | ||
| f"Support for placeholder {node.name} is not implemented" | ||
| ) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why no embedding?