28 changes: 21 additions & 7 deletions backends/mlx/ops.py
@@ -50,6 +50,7 @@
     AsStridedNode,
     AsTypeNode,
     Atan2Node,
+    BitwiseInvertNode,
     BroadcastToNode,
     CeilNode,
     ClipNode,
@@ -3066,27 +3067,40 @@ def _where_handler(P: MLXProgramBuilder, n: Node) -> Slot:
 
 @REGISTRY.register(target=[torch.ops.aten.bitwise_not.default])
 def _bitwise_not_handler(P: MLXProgramBuilder, n: Node) -> Slot:
-    """Handle aten.bitwise_not - for boolean tensors, dispatch to logical_not."""
+    """Handle aten.bitwise_not - logical_not for bool, bitwise_invert for integers."""
     args = P.args(n)
     require_args(args, 1, 1, "aten.bitwise_not")
     require_kwargs(P.kwargs(n), set(), "aten.bitwise_not")
     x_meta = n.args[0].meta.get("val")
+    out = P.make_or_get_slot(n)
 
-    if x_meta is not None and x_meta.dtype == torch.bool:
-        # For boolean tensors, bitwise_not is equivalent to logical_not
-        out = P.make_or_get_slot(n)
+    if x_meta is None or not hasattr(x_meta, "dtype"):
+        raise NotImplementedError(
+            "aten.bitwise_not requires known input dtype metadata for MLX lowering"
+        )
+
+    if x_meta.dtype == torch.bool:
         P.emit(
             LogicalNotNode(
                 x=P.slot_to_tid(args[0]),
                 out=P.slot_to_tid(out),
             )
         )
         return out
+    elif x_meta.dtype in {
+        torch.int32,
+        torch.int64,
+    }:
+        P.emit(
+            BitwiseInvertNode(
+                x=P.slot_to_tid(args[0]),
+                out=P.slot_to_tid(out),
+            )
+        )
     else:
         raise NotImplementedError(
-            f"aten.bitwise_not is only supported for boolean tensors. "
-            f"Got dtype={x_meta.dtype if x_meta else 'unknown'}"
+            f"aten.bitwise_not on dtype {x_meta.dtype} is not supported for MLX lowering"
         )
+    return out
 
 
 @REGISTRY.register(
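
For reference, the dispatch above mirrors PyTorch's own semantics for `~`: on bool tensors `bitwise_not` coincides with `logical_not`, and on signed integers it is two's-complement inversion, i.e. `~x == -x - 1`. A minimal eager-mode sketch (plain PyTorch, independent of the MLX backend):

import torch

# bool: bitwise_not coincides with logical_not
b = torch.tensor([True, False, True])
assert torch.equal(torch.bitwise_not(b), torch.logical_not(b))

# signed integers: two's-complement inversion, so ~x == -x - 1
x = torch.tensor([0, 1, -7, 42], dtype=torch.int32)
assert torch.equal(torch.bitwise_not(x), -x - 1)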
11 changes: 11 additions & 0 deletions backends/mlx/runtime/MLXInterpreter.h
@@ -1380,6 +1380,13 @@ inline void exec_logical_not(
   st.set_tensor(n.out, logical_not(st.const_tensor_ref(n.x), s));
 }
 
+inline void exec_bitwise_invert(
+    const BitwiseInvertNode& n,
+    ExecutionState& st,
+    StreamOrDevice s) {
+  st.set_tensor(n.out, bitwise_invert(st.const_tensor_ref(n.x), s));
+}
+
 inline void exec_logical_and(
     const LogicalAndNode& n,
     ExecutionState& st,
@@ -2028,6 +2035,10 @@ class Interpreter {
       case OpCode::LOGICAL_NOT:
        ops::exec_logical_not(std::get<LogicalNotNode>(instr.node), st, s);
        break;
+      case OpCode::BITWISE_INVERT:
+        ops::exec_bitwise_invert(
+            std::get<BitwiseInvertNode>(instr.node), st, s);
+        break;
       case OpCode::LOGICAL_AND:
        ops::exec_logical_and(std::get<LogicalAndNode>(instr.node), st, s);
        break;
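
The new runtime node is a thin wrapper over MLX's `bitwise_invert`. A quick way to sanity-check the mapping from Python (a sketch that assumes the `mlx` package is installed and that `mx.bitwise_invert` is the Python counterpart of the C++ call above):

import mlx.core as mx
import numpy as np
import torch

# Assumption: mx.bitwise_invert matches the C++ bitwise_invert
# dispatched by the interpreter for BITWISE_INVERT nodes.
x = torch.tensor([0, 1, -7, 42], dtype=torch.int32)
expected = torch.bitwise_not(x).numpy()
got = np.array(mx.bitwise_invert(mx.array(x.numpy())))
assert (got == expected).all()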
8 changes: 7 additions & 1 deletion backends/mlx/serialization/schema.fbs
@@ -562,6 +562,11 @@ table LogicalNotNode {
   out: Tid (required);
 }
 
+table BitwiseInvertNode {
+  x: Tid (required);
+  out: Tid (required);
+}
+
 table LogicalAndNode {
   a: Tid (required);
   b: Tid (required);
@@ -1113,7 +1118,8 @@ union OpNode {
   GatherMmNode,
   GatherQmmNode,
   ScanNode,
-  MetalKernelNode
+  MetalKernelNode,
+  BitwiseInvertNode
   // BC: Add new op nodes here (append only)
 }
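
The `// BC: Add new op nodes here (append only)` rule matters because FlatBuffers identifies a union member by its declaration position (a 1-based tag; 0 is NONE), so appending `BitwiseInvertNode` after `MetalKernelNode` leaves every previously serialized tag valid. A small illustration of the invariant (plain Python; the tags shown are positional assumptions, not values read from generated code):

# Union tags follow declaration order; appending preserves existing tags.
old_union = ["GatherMmNode", "GatherQmmNode", "ScanNode", "MetalKernelNode"]
new_union = old_union + ["BitwiseInvertNode"]  # appended, never inserted

for tag, name in enumerate(old_union, start=1):  # tag 0 is reserved for NONE
    assert new_union[tag - 1] == name  # old serialized programs still decode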
1 change: 1 addition & 0 deletions backends/mlx/test/test_ops.py
@@ -4111,6 +4111,7 @@ def create_model(self) -> nn.Module:
     {"op_name": "abs", "op_fn": torch.abs},
     {"op_name": "neg", "op_fn": torch.neg},
     {"op_name": "logical_not","op_fn": torch.logical_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},
+    {"op_name": "bitwise_not_int", "op_fn": torch.bitwise_not, "shapes": _SHAPES_3, "dtypes": [torch.int32, torch.int64], "input_fn": _int_input_fn()},
Comment thread on this line:

Contributor Author:
I consciously decided to test only what was asked to be added; let me know if you also need the bool test case.

Contributor:
Test whatever dtypes you assert as supported in the handler.
Also, can you paste the outcome of a successful test run in the PR description?

Contributor Author:
Is the test summary in the description of the PR any good?

Contributor Author:
Not advertised anymore whatsoever.

{"op_name": "isnan", "op_fn": torch.isnan, "shapes": _SHAPES_3, "dtypes": [torch.float32, torch.float16, torch.bfloat16], "input_fn": _nan_input_fn()},
# activations
{"op_name": "relu", "op_fn": torch.relu, "shapes": [(2, 3, 4), (10,), (4, 8), (2, 8, 16), (1, 128, 64)], "dtypes": [torch.float32], "input_fn": _input_fn(scale=2, offset=-1)},
Expand Down
Loading
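
If the bool case raised in the thread above were added, a row mirroring the existing `logical_not` entry would suffice (a sketch reusing the file's existing helpers; `_bool_input_fn` and the shape list are taken from the surrounding table, and the row itself is hypothetical):

# Hypothetical bool coverage, mirroring the logical_not entry above
{"op_name": "bitwise_not_bool", "op_fn": torch.bitwise_not, "shapes": [(2, 3, 4), (10,), (4, 8)], "dtypes": [torch.bool], "input_fn": _bool_input_fn()},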