Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Fix style in Reduce.cpp and op_registry.py
  • Loading branch information
alexdean08 committed Aug 4, 2025
commit f891c4b9a638fec90d661ff44500b70f8df03b05
15 changes: 12 additions & 3 deletions backends/vulkan/op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,13 +385,22 @@ def check_reduce_node(node: torch.fx.Node) -> bool:
if memory_layout is not None:
for dim in dim_list:
# For WIDTH_PACKED layout, dimension 3 (W) is packed
if memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED and dim == 3:
if (
memory_layout == VkMemoryLayout.TENSOR_WIDTH_PACKED
and dim == 3
):
return False
# For HEIGHT_PACKED layout, dimension 2 (H) is packed
elif memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED and dim == 2:
elif (
memory_layout == VkMemoryLayout.TENSOR_HEIGHT_PACKED
and dim == 2
):
return False
# For CHANNELS_PACKED layout, dimension 1 (C) is packed
elif memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED and dim == 1:
elif (
memory_layout == VkMemoryLayout.TENSOR_CHANNELS_PACKED
and dim == 1
):
return False
except (AssertionError, KeyError, AttributeError):
# If we can't get memory layout information, we'll assume the dims aren't packed
Expand Down
33 changes: 20 additions & 13 deletions backends/vulkan/runtime/graph/ops/impl/Reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ void resize_reduce2d_node(
vTensorPtr in = graph->get_tensor(args[1].refs[0]);

// Extract the dimensions to reduce over
const std::vector<int64_t> dims_list = graph->extract_int_or_symint_list(resize_args.at(0));
const std::vector<int64_t> dims_list =
graph->extract_int_or_symint_list(resize_args.at(0));
int32_t reduce_dim1_nchw = dims_list[0];
int32_t reduce_dim2_nchw = dims_list[1];

Expand Down Expand Up @@ -161,24 +162,25 @@ void add_reduce2d_node(
const ValueRef dims_ref,
const ValueRef out,
const std::string& op_name) {

VK_CHECK_COND(
!graph.is_buffer_storage(in) && !graph.is_buffer_storage(out),
"Vulkan reduction only supports texture storage");

const int64_t ndim = graph.dim_of(in);

// Extract the two dimensions to reduce over
const std::vector<int64_t> dims_list = graph.extract_int_or_symint_list(dims_ref);
VK_CHECK_COND(dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

const std::vector<int64_t> dims_list =
graph.extract_int_or_symint_list(dims_ref);
VK_CHECK_COND(
dims_list.size() == 2, "reduce2d requires exactly 2 dimensions");

int32_t reduce_dim1 = normalize(dims_list[0], ndim);
int32_t reduce_dim2 = normalize(dims_list[1], ndim);

// Convert to WHCN format
reduce_dim1 = nchw_dim_to_whcn_dim(reduce_dim1, ndim);
reduce_dim2 = nchw_dim_to_whcn_dim(reduce_dim2, ndim);

// Check that none of the reduction dims are packed
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim1);
VK_CHECK_COND(graph.packed_dim_of(in) != reduce_dim2);
Expand All @@ -193,7 +195,7 @@ void add_reduce2d_node(
VK_CHECK_COND(graph.concat_dim_of(out) != reduce_dim2);
}

std::string kernel_name = op_name + "2d"; // Add "2d" suffix
std::string kernel_name = op_name + "2d"; // Add "2d" suffix
kernel_name.reserve(kShaderNameReserve);
add_dtype_suffix(kernel_name, graph.dtype_of(out));

Expand All @@ -206,8 +208,10 @@ void add_reduce2d_node(
}
}

const ValueRef reduce_dim1_whcn_ref = graph.get_or_add_value_for_int(reduce_dim1);
const ValueRef reduce_dim2_whcn_ref = graph.get_or_add_value_for_int(reduce_dim2);
const ValueRef reduce_dim1_whcn_ref =
graph.get_or_add_value_for_int(reduce_dim1);
const ValueRef reduce_dim2_whcn_ref =
graph.get_or_add_value_for_int(reduce_dim2);
const ValueRef group_dim_whcn_ref = graph.get_or_add_value_for_int(group_dim);

graph.execute_nodes().emplace_back(new DynamicDispatchNode(
Expand All @@ -224,15 +228,18 @@ void add_reduce2d_node(
// Specialization Constants
{graph.packed_dim_of(out), reduce_dim1, reduce_dim2, group_dim},
// Resize Args
{dims_ref, reduce_dim1_whcn_ref, reduce_dim2_whcn_ref, group_dim_whcn_ref},
{dims_ref,
reduce_dim1_whcn_ref,
reduce_dim2_whcn_ref,
group_dim_whcn_ref},
// Resizing Logic
resize_reduce2d_node));
}

#define DEFINE_REDUCE_FN(op_name, out_arg_idx) \
void op_name(ComputeGraph& graph, const std::vector<ValueRef>& args) { \
const std::vector<int64_t> dims_list = \
graph.extract_int_or_symint_list(args[1]); \
graph.extract_int_or_symint_list(args[1]); \
if (dims_list.size() == 1) { \
const int64_t dim_val = dims_list.at(0); \
const ValueRef dim_ref = graph.get_or_add_value_for_int(dim_val); \
Expand Down