Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
[ET-VK] Adding function to set push constants in Command buffer.
Pull Request resolved: #7221

This diff adds a function to set push constants in the Command buffer for ET-VK. The changes include adding a new `set_push_constants` function to the CommandBuffer class and modifying the code in the CommandBuffer class to call this new function.
ghstack-source-id: 257227241
@exported-using-ghexport

Differential Revision: [D66714317](https://our.internmc.facebook.com/intern/diff/D66714317/)
  • Loading branch information
trviv authored and kirklandsign committed Dec 10, 2024
commit 85c4a4e9dca49dbe1244e1e0f0a2d6304e4b6b2a
13 changes: 12 additions & 1 deletion backends/vulkan/runtime/api/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ void Context::register_shader_dispatch(
const vkapi::DescriptorSet& descriptors,
vkapi::PipelineBarrier& pipeline_barrier,
const vkapi::ShaderInfo& shader_descriptor,
const utils::uvec3& global_workgroup_size) {
const utils::uvec3& global_workgroup_size,
const void* push_constants_data,
const uint32_t push_constants_size) {
// Adjust the global workgroup size based on the output tile size
uint32_t global_wg_w = utils::div_up(
global_workgroup_size[0u], shader_descriptor.out_tile_size[0u]);
Expand All @@ -145,6 +147,15 @@ void Context::register_shader_dispatch(
cmd_.bind_descriptors(descriptors.get_bind_handle());
cmd_.insert_barrier(pipeline_barrier);

if (push_constants_size > 0 && push_constants_data != nullptr) {
const VkDescriptorSetLayout shader_layout =
shader_layout_cache().retrieve(shader_descriptor.kernel_layout);
const VkPipelineLayout pipeline_layout =
pipeline_layout_cache().retrieve(shader_layout);
cmd_.set_push_constants(
pipeline_layout, push_constants_data, push_constants_size);
}

cmd_.dispatch(effective_global_wg);
}

Expand Down
4 changes: 3 additions & 1 deletion backends/vulkan/runtime/api/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ class Context final {
const vkapi::DescriptorSet&,
vkapi::PipelineBarrier&,
const vkapi::ShaderInfo&,
const utils::uvec3&);
const utils::uvec3&,
const void* = nullptr,
const uint32_t = 0);

void register_blit(
vkapi::PipelineBarrier&,
Expand Down
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/vk_api/Command.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,21 @@ void CommandBuffer::bind_descriptors(VkDescriptorSet descriptors) {
state_ = CommandBuffer::State::DESCRIPTORS_BOUND;
}

void CommandBuffer::set_push_constants(
VkPipelineLayout pipeline_layout,
const void* push_constants_data,
uint32_t push_constants_size) {
if (push_constants_data != nullptr && push_constants_size > 0) {
vkCmdPushConstants(
handle_,
pipeline_layout,
VK_SHADER_STAGE_COMPUTE_BIT,
0,
push_constants_size,
push_constants_data);
}
}

void CommandBuffer::insert_barrier(PipelineBarrier& pipeline_barrier) {
VK_CHECK_COND(
state_ == CommandBuffer::State::DESCRIPTORS_BOUND ||
Expand Down
1 change: 1 addition & 0 deletions backends/vulkan/runtime/vk_api/Command.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class CommandBuffer final {

void bind_pipeline(VkPipeline, VkPipelineLayout, const utils::uvec3);
void bind_descriptors(VkDescriptorSet);
void set_push_constants(VkPipelineLayout, const void*, uint32_t);

void insert_barrier(PipelineBarrier& pipeline_barrier);
void dispatch(const utils::uvec3&);
Expand Down