diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 28546f8332..fa5bba14fa 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -109,7 +109,7 @@ repos: args: [--fix, --spdx] files: | (?x) - [.](cmake|c|cpp|cu|cuh|h|hpp|sh|pxd|py|pyx|rs|java)$| + [.](cmake|c|cpp([.]in)?|cu([.]in)?|cuh|h|hpp|sh|pxd|py|pyx|rs|java)$| CMakeLists[.]txt$| CMakeLists_standalone[.]txt$| meta[.]yaml$| diff --git a/conda/environments/all_cuda-131_arch-aarch64.yaml b/conda/environments/all_cuda-131_arch-aarch64.yaml index ca42510723..31a7ea3024 100644 --- a/conda/environments/all_cuda-131_arch-aarch64.yaml +++ b/conda/environments/all_cuda-131_arch-aarch64.yaml @@ -13,6 +13,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-python>=13.0.1,<14.0 diff --git a/conda/environments/all_cuda-131_arch-x86_64.yaml b/conda/environments/all_cuda-131_arch-x86_64.yaml index a5261a38dc..8638b280e5 100644 --- a/conda/environments/all_cuda-131_arch-x86_64.yaml +++ b/conda/environments/all_cuda-131_arch-x86_64.yaml @@ -13,6 +13,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-python>=13.0.1,<14.0 diff --git a/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml index d33f325c39..ff3d1f053c 100644 --- a/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml +++ b/conda/environments/bench_ann_cuda-131_arch-aarch64.yaml @@ -12,6 +12,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-python>=13.0.1,<14.0 diff --git a/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml index ff5c4f021f..eeed1d1360 100644 --- a/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml +++ b/conda/environments/bench_ann_cuda-131_arch-x86_64.yaml @@ -12,6 +12,7 
@@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-python>=13.0.1,<14.0 diff --git a/conda/environments/go_cuda-131_arch-aarch64.yaml b/conda/environments/go_cuda-131_arch-aarch64.yaml index 135f6a88cc..4c5f1862c9 100644 --- a/conda/environments/go_cuda-131_arch-aarch64.yaml +++ b/conda/environments/go_cuda-131_arch-aarch64.yaml @@ -12,6 +12,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-version=13.1 diff --git a/conda/environments/go_cuda-131_arch-x86_64.yaml b/conda/environments/go_cuda-131_arch-x86_64.yaml index df6a779331..0bd7c0a2d3 100644 --- a/conda/environments/go_cuda-131_arch-x86_64.yaml +++ b/conda/environments/go_cuda-131_arch-x86_64.yaml @@ -12,6 +12,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-version=13.1 diff --git a/conda/environments/rust_cuda-131_arch-aarch64.yaml b/conda/environments/rust_cuda-131_arch-aarch64.yaml index 062cbc8ea0..2c6636e695 100644 --- a/conda/environments/rust_cuda-131_arch-aarch64.yaml +++ b/conda/environments/rust_cuda-131_arch-aarch64.yaml @@ -11,6 +11,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-version=13.1 diff --git a/conda/environments/rust_cuda-131_arch-x86_64.yaml b/conda/environments/rust_cuda-131_arch-x86_64.yaml index 2b96d4a64e..dbe4367816 100644 --- a/conda/environments/rust_cuda-131_arch-x86_64.yaml +++ b/conda/environments/rust_cuda-131_arch-x86_64.yaml @@ -11,6 +11,7 @@ dependencies: - cmake>=3.30.4 - cuda-cudart-dev - cuda-nvcc +- cuda-nvrtc-dev - cuda-nvtx-dev - cuda-profiler-api - cuda-version=13.1 diff --git a/conda/recipes/libcuvs/recipe.yaml b/conda/recipes/libcuvs/recipe.yaml index abd3031a94..a916dbde8e 100644 --- a/conda/recipes/libcuvs/recipe.yaml +++ b/conda/recipes/libcuvs/recipe.yaml @@ 
-75,6 +75,7 @@ cache: - if: cuda_major == "13" then: - libnvjitlink-dev + - cuda-nvrtc-dev - librmm =${{ minor_version }} - libraft-headers =${{ minor_version }} - nccl ${{ nccl_version }} @@ -124,6 +125,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - libraft-headers =${{ minor_version }} @@ -137,6 +139,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc ignore_run_exports: by_name: - cuda-cudart @@ -153,6 +156,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} @@ -192,6 +196,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -206,6 +211,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc ignore_run_exports: by_name: - cuda-cudart @@ -222,6 +228,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} @@ -259,6 +266,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_compatible("cuda-version", upper_bound="x", lower_bound="x") }} - ${{ pin_subpackage("libcuvs-headers", exact=True) }} @@ -273,6 +281,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc ignore_run_exports: by_name: - cuda-cudart @@ -286,6 +295,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc - librmm - mkl - nccl @@ -346,6 +356,10 @@ outputs: - librmm - mkl - nccl + - if: cuda_major == "13" + then: + - 
cuda-nvrtc + - libnvjitlink about: homepage: ${{ load_from_file("python/libcuvs/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/libcuvs/pyproject.toml").project.license }} @@ -426,6 +440,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink-dev + - cuda-nvrtc-dev run: - ${{ pin_subpackage("libcuvs-headers", exact=True) }} - ${{ pin_subpackage("libcuvs", exact=True) }} @@ -439,6 +454,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc ignore_run_exports: by_name: - cuda-cudart @@ -452,6 +468,7 @@ outputs: - if: cuda_major == "13" then: - libnvjitlink + - cuda-nvrtc - librmm - mkl - nccl @@ -526,6 +543,10 @@ outputs: - librmm - mkl - nccl + - if: cuda_major == "13" + then: + - cuda-nvrtc + - libnvjitlink about: homepage: ${{ load_from_file("python/cuvs_bench/pyproject.toml").project.urls.Homepage }} license: ${{ load_from_file("python/cuvs_bench/pyproject.toml").project.license }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d90579812a..8c33623615 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -213,6 +213,12 @@ endif() # ################################################################################################## # * cuvs --------------------------------------------------------------------- if(NOT BUILD_CPU_ONLY) + set(JIT_LTO_TARGET_ARCHITECTURE "") + set(JIT_LTO_COMPILATION OFF) + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) + set(JIT_LTO_TARGET_ARCHITECTURE "75-real") + set(JIT_LTO_COMPILATION ON) + endif() add_library(cuvs_cpp_headers INTERFACE) add_library(cuvs::cuvs_cpp_headers ALIAS cuvs_cpp_headers) @@ -222,6 +228,9 @@ if(NOT BUILD_CPU_ONLY) "$" "$" ) + target_compile_definitions( + cuvs_cpp_headers INTERFACE $<$:CUVS_ENABLE_JIT_LTO> + ) target_link_libraries(cuvs_cpp_headers INTERFACE raft::raft rmm::rmm) add_library( @@ -317,6 +326,10 @@ if(NOT BUILD_CPU_ONLY) CUDA_SEPARABLE_COMPILATION ON POSITION_INDEPENDENT_CODE ON ) + target_compile_definitions( + 
cuvs-cagra-search PRIVATE $<$:CUVS_BUILD_CAGRA_HNSWLIB> + $<$:NVTX_ENABLED> + ) target_link_libraries( cuvs-cagra-search PRIVATE cuvs::cuvs_cpp_headers $ @@ -343,13 +356,6 @@ if(NOT BUILD_CPU_ONLY) ) endif() - set(JIT_LTO_TARGET_ARCHITECTURE "") - set(JIT_LTO_COMPILATION OFF) - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0) - set(JIT_LTO_TARGET_ARCHITECTURE "75-real") - set(JIT_LTO_COMPILATION ON) - endif() - if(JIT_LTO_COMPILATION) # Generate interleaved scan kernel files at build time include(cmake/modules/generate_jit_lto_kernels.cmake) @@ -365,74 +371,216 @@ if(NOT BUILD_CPU_ONLY) "$<$:${CUVS_CUDA_FLAGS}>" ) target_compile_features(jit_lto_kernel_usage_requirements INTERFACE cuda_std_20) + target_compile_definitions(jit_lto_kernel_usage_requirements INTERFACE BUILD_KERNEL) target_link_libraries( jit_lto_kernel_usage_requirements INTERFACE rmm::rmm raft::raft CCCL::CCCL ) - block(PROPAGATE interleaved_scan_files metric_files filter_files post_lambda_files) + block(PROPAGATE jit_lto_kernel_files) set(CMAKE_CUDA_ARCHITECTURES ${JIT_LTO_TARGET_ARCHITECTURE}) generate_jit_lto_kernels( - interleaved_scan_files + jit_lto_kernel_files NAME_FORMAT - "interleaved_scan_capacity_@capacity@_veclen_@veclen@_@ascending_descending@_@compute_norm_name@_data_@type_abbrev@_acc_@acc_abbrev@_idx_@idx_abbrev@" + "ivf_flat_interleaved_scan_capacity_@capacity@_veclen_@veclen@_@ascending_descending@_@compute_norm_name@_data_@type_abbrev@_acc_@acc_abbrev@_idx_@idx_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in" EMBEDDED_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in" - OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/interleaved_scan" + OUTPUT_DIRECTORY 
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/interleaved_scan" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) generate_jit_lto_kernels( - metric_files - NAME_FORMAT "metric_@metric_name@_veclen_@veclen@_data_@type_abbrev@_acc_@acc_abbrev@" + jit_lto_kernel_files + NAME_FORMAT + "ivf_flat_metric_@metric_name@_veclen_@veclen@_data_@type_abbrev@_acc_@acc_abbrev@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in" EMBEDDED_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in" - OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/metric" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/metric" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) generate_jit_lto_kernels( - filter_files - NAME_FORMAT "@filter_name@" + jit_lto_kernel_files + NAME_FORMAT "ivf_flat_@filter_name@" MATRIX_JSON_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json" KERNEL_INPUT_FILE - "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_kernel.cu.in" + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in" EMBEDDED_INPUT_FILE - "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in" - OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/filter" + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/jit_lto_kernels/filter_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/filter" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) generate_jit_lto_kernels( - post_lambda_files - NAME_FORMAT "@post_lambda_name@" + jit_lto_kernel_files + NAME_FORMAT "ivf_flat_@post_lambda_name@" MATRIX_JSON_FILE 
"${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_matrix.json" KERNEL_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in" EMBEDDED_INPUT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in" - OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/post_lambda" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/ivf_flat/post_lambda" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/setup_workspace" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_@pq_bits@pq_@pq_len@subd_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance" 
+ KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT "cagra_dist_op_@metric_tag@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/dist_op" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_apply_normalization_standard@normalization_suffix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_embedded.cpp.in" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_normalization_standard" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_search_single_cta_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in" + EMBEDDED_INPUT_FILE 
+ "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_search_single_cta_p_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_single_cta_p" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_search_multi_cta@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/search_multi_cta" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + 
"cagra_random_pickup@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/random_pickup" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT + "cagra_compute_distance_to_child_nodes@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@_data_@data_abbrev@_query_@query_abbrev@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_embedded.cpp.in" + OUTPUT_DIRECTORY + "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/compute_distance_to_child_nodes" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT "cagra_appy_filter" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_embedded.cpp.in" + OUTPUT_DIRECTORY 
"${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/apply_filter" + KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements + ) + generate_jit_lto_kernels( + jit_lto_kernel_files + NAME_FORMAT "cagra_@filter_name@" + MATRIX_JSON_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json" + KERNEL_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in" + EMBEDDED_INPUT_FILE + "${CMAKE_CURRENT_SOURCE_DIR}/src/neighbors/detail/jit_lto_kernels/filter_embedded.cpp.in" + OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/generated_kernels/cagra/filter" KERNEL_LINK_LIBRARIES jit_lto_kernel_usage_requirements ) endblock() add_library( cuvs_jit_lto_kernels STATIC - ${interleaved_scan_files} - ${metric_files} - ${filter_files} - ${post_lambda_files} + ${jit_lto_kernel_files} src/detail/jit_lto/AlgorithmLauncher.cu src/detail/jit_lto/AlgorithmPlanner.cu src/detail/jit_lto/FragmentDatabase.cu src/detail/jit_lto/FragmentEntry.cu src/detail/jit_lto/nvjitlink_checker.cpp + src/detail/jit_lto/NVRTCLTOFragmentCompiler.cu ) set_target_properties( cuvs_jit_lto_kernels PROPERTIES POSITION_INDEPENDENT_CODE ON CXX_STANDARD 20 @@ -442,7 +590,7 @@ if(NOT BUILD_CPU_ONLY) PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/include" "${CMAKE_CURRENT_SOURCE_DIR}/src" "${CMAKE_CURRENT_SOURCE_DIR}/../c/include" ) - target_link_libraries(cuvs_jit_lto_kernels PRIVATE raft::raft) + target_link_libraries(cuvs_jit_lto_kernels PRIVATE raft::raft CUDA::nvJitLink CUDA::nvrtc) add_library(cuvs::cuvs_jit_lto_kernels ALIAS cuvs_jit_lto_kernels) endif() @@ -689,6 +837,7 @@ if(NOT BUILD_CPU_ONLY) ${CUVS_CTK_MATH_DEPENDENCIES} $ $ + $<$:CUDA::nvrtc> ) target_include_directories( @@ -772,6 +921,7 @@ if(NOT BUILD_CPU_ONLY) $ $ $<$:CUDA::nvJitLink> + $<$:CUDA::nvrtc> $<$:$> ) @@ -832,6 +982,7 @@ SECTIONS PRIVATE $ $<$:CUDA::nvJitLink> + $<$:CUDA::nvrtc> $<$:CUDA::nvtx3> $ $ diff --git a/cpp/cmake/modules/generate_jit_lto_kernels.cmake 
b/cpp/cmake/modules/generate_jit_lto_kernels.cmake index 1454bac97e..e27f432b76 100644 --- a/cpp/cmake/modules/generate_jit_lto_kernels.cmake +++ b/cpp/cmake/modules/generate_jit_lto_kernels.cmake @@ -129,6 +129,11 @@ function(generate_jit_lto_kernels source_list_var) find_package(Python3 REQUIRED COMPONENTS Interpreter) if(_JIT_LTO_MATRIX_JSON_FILE) + set_property( + DIRECTORY + APPEND + PROPERTY CMAKE_CONFIGURE_DEPENDS "${_JIT_LTO_MATRIX_JSON_FILE}" + ) compute_matrix_product(matrix_product MATRIX_JSON_FILE "${_JIT_LTO_MATRIX_JSON_FILE}") else() compute_matrix_product(matrix_product MATRIX_JSON_STRING "${_JIT_LTO_MATRIX_JSON_STRING}") diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 8ecf3686be..ba7f68f09d 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -1,6 +1,6 @@ # ============================================================================= # cmake-format: off -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
# SPDX-License-Identifier: Apache-2.0 # cmake-format: on @@ -60,8 +60,8 @@ endfunction() # To use a different RAFT locally, set the CMake variable # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} + FORK divyegala + PINNED_TAG unneeded-cccl-includes ENABLE_MNMG_DEPENDENCIES OFF ENABLE_NVTX OFF BUILD_STATIC_DEPS ${CUVS_STATIC_RAPIDS_LIBRARIES} diff --git a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp index 18e7f7cb2f..6f551170c4 100644 --- a/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp +++ b/cpp/include/cuvs/detail/jit_lto/AlgorithmLauncher.hpp @@ -34,10 +34,20 @@ struct AlgorithmLauncher { this->call(stream, grid, block, shared_mem, kernel_args); } + template + void dispatch_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, Args&&... args) + { + void* kernel_args[] = {const_cast(static_cast(&args))...}; + this->call_cooperative(stream, grid, block, shared_mem, kernel_args); + } + cudaKernel_t get_kernel() { return this->kernel; } private: void call(cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); + void call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** args); cudaKernel_t kernel; cudaLibrary_t library; }; diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp index aeb170d861..efedf2ba91 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentDatabase.hpp @@ -6,6 +6,7 @@ #pragma once #include +#include #include #include @@ -21,6 +22,7 @@ class FragmentDatabase { FragmentDatabase& operator=(FragmentDatabase const&) = delete; FragmentEntry* get_fragment(std::string const& key); + bool has_fragment(std::string const& key) const; private: FragmentDatabase(); @@ -34,6 +36,11 @@ 
class FragmentDatabase { unsigned char const* blob, std::size_t size); + friend void registerNVRTCFragment(std::string const& key, + std::unique_ptr&& program, + std::size_t size); + + mutable std::mutex cache_mutex_; std::unordered_map> cache; }; @@ -43,3 +50,7 @@ void registerFatbinFragment(std::string const& algo, std::string const& params, unsigned char const* blob, std::size_t size); + +void registerNVRTCFragment(std::string const& key, + std::unique_ptr&& program, + std::size_t size); diff --git a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp index a376068425..3bbe7d31a8 100644 --- a/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp +++ b/cpp/include/cuvs/detail/jit_lto/FragmentEntry.hpp @@ -30,3 +30,12 @@ struct FatbinFragmentEntry final : FragmentEntry { std::size_t data_size = 0; unsigned char const* data_view = nullptr; }; + +struct NVRTCFragmentEntry final : FragmentEntry { + NVRTCFragmentEntry(std::string const& key, std::unique_ptr&& program, std::size_t size); + + virtual bool add_to(nvJitLinkHandle& handle) const; + + std::size_t data_size = 0; + std::unique_ptr program{}; +}; diff --git a/cpp/include/cuvs/detail/jit_lto/NVRTCLTOFragmentCompiler.hpp b/cpp/include/cuvs/detail/jit_lto/NVRTCLTOFragmentCompiler.hpp new file mode 100644 index 0000000000..a1f598b6a5 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/NVRTCLTOFragmentCompiler.hpp @@ -0,0 +1,19 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +struct NVRTCLTOFragmentCompiler { + NVRTCLTOFragmentCompiler(); + + void compile(std::string const& key, std::string const& code) const; + + std::vector standard_compile_opts; +}; + +NVRTCLTOFragmentCompiler& nvrtc_compiler(); diff --git a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp b/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp deleted file mode 100644 index d9ed7e6b0b..0000000000 --- a/cpp/include/cuvs/detail/jit_lto/ivf_flat/interleaved_scan_tags.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. - * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -namespace cuvs::neighbors::ivf_flat::detail { - -// Tag types for data types -struct tag_f {}; -struct tag_h {}; -struct tag_sc {}; -struct tag_uc {}; - -// Tag types for accumulator types -struct tag_acc_f {}; -struct tag_acc_h {}; -struct tag_acc_i {}; -struct tag_acc_ui {}; - -// Tag types for index types -struct tag_idx_l {}; - -// Tag types for filter subtypes -struct tag_filter_bitset_impl {}; -struct tag_filter_none_impl {}; - -// Tag types for sample filter types with full template info -template -struct tag_filter {}; - -// Tag types for distance metrics with full template info -template -struct tag_metric_euclidean {}; - -template -struct tag_metric_inner_product {}; - -// Tag types for post-processing -struct tag_post_identity {}; -struct tag_post_sqrt {}; -struct tag_post_compose {}; - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp b/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp new file mode 100644 index 0000000000..b9d244f799 --- /dev/null +++ b/cpp/include/cuvs/detail/jit_lto/registration_tags.hpp @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::detail::jit_lto { + +struct tag_f {}; +struct tag_h {}; +struct tag_sc {}; +struct tag_uc {}; +struct tag_idx_l {}; +struct tag_filter_none {}; +struct tag_filter_bitset {}; + +} // namespace cuvs::detail::jit_lto + +namespace cuvs::neighbors::cagra::detail { + +using cuvs::detail::jit_lto::tag_f; +using cuvs::detail::jit_lto::tag_filter_bitset; +using cuvs::detail::jit_lto::tag_filter_none; +using cuvs::detail::jit_lto::tag_h; +using cuvs::detail::jit_lto::tag_idx_l; +using cuvs::detail::jit_lto::tag_sc; +using cuvs::detail::jit_lto::tag_uc; + +struct tag_idx_ui {}; +struct tag_dist_f {}; +struct tag_metric_l2 {}; +struct tag_metric_inner_product {}; +struct tag_metric_cosine {}; +struct tag_metric_hamming {}; +struct tag_team_8 {}; +struct tag_team_16 {}; +struct tag_team_32 {}; +struct tag_dim_128 {}; +struct tag_dim_256 {}; +struct tag_dim_512 {}; +struct tag_pq_bits_8 {}; +struct tag_pq_len_2 {}; +struct tag_pq_len_4 {}; +struct tag_codebook_half {}; + +} // namespace cuvs::neighbors::cagra::detail + +namespace cuvs::neighbors::ivf_flat::detail { + +using cuvs::detail::jit_lto::tag_f; +using cuvs::detail::jit_lto::tag_filter_bitset; +using cuvs::detail::jit_lto::tag_filter_none; +using cuvs::detail::jit_lto::tag_h; +using cuvs::detail::jit_lto::tag_idx_l; +using cuvs::detail::jit_lto::tag_sc; +using cuvs::detail::jit_lto::tag_uc; + +struct tag_acc_f {}; +struct tag_acc_h {}; +struct tag_acc_i {}; +struct tag_acc_ui {}; + +template +struct tag_metric_euclidean {}; + +template +struct tag_metric_inner_product {}; + +template +struct tag_metric_custom_udf {}; + +struct tag_post_identity {}; +struct tag_post_sqrt {}; +struct tag_post_compose {}; + +} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/include/cuvs/distance/distance.hpp b/cpp/include/cuvs/distance/distance.hpp index 13c8c7bd7e..df7d45c8a6 100644 --- a/cpp/include/cuvs/distance/distance.hpp +++ 
b/cpp/include/cuvs/distance/distance.hpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2021-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2021-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -59,7 +59,9 @@ enum class DistanceType : int { /** Bitstring Hamming distance **/ BitwiseHamming = 20, /** Precomputed (special value) **/ - Precomputed = 100 + Precomputed = 100, + /** Custom metric UDF **/ + CustomUDF = 101 }; /** diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 23c6dd4944..2b11cd390d 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -71,6 +71,8 @@ struct index_params : cuvs::neighbors::index_params { struct search_params : cuvs::neighbors::search_params { /** The number of clusters to search. */ uint32_t n_probes = 20; + /** Custom metric UDF code. */ + std::optional metric_udf = std::nullopt; }; static_assert(std::is_aggregate_v); @@ -167,6 +169,8 @@ struct index : cuvs::neighbors::index { /** Distance metric used for clustering. */ cuvs::distance::DistanceType metric() const noexcept; + void set_metric(cuvs::distance::DistanceType metric); + /** Whether `centers()` change upon extending the index (ivf_flat::extend). 
*/ bool adaptive_centers() const noexcept; @@ -3047,4 +3051,417 @@ void recompute_internal_state(const raft::resources& res, index 1: wraps packed bytes in a 32-bit word + * + * @tparam T Data type (float, __half, int8_t, uint8_t) + * @tparam AccT Storage/accumulator type (float, __half, int32_t, uint32_t) + * @tparam Veclen Vector length (1, 2, 4, 8, 16) + */ +template +struct point { + using element_type = T; + using storage_type = AccT; + static constexpr int veclen = Veclen; + + storage_type data_; + + point() = default; + __device__ __host__ explicit point(storage_type d) : data_(d) {} + + __device__ __forceinline__ storage_type raw() const { return data_; } + __device__ __forceinline__ storage_type& raw() { return data_; } + + __device__ __host__ static constexpr int size() + { + if constexpr ((std::is_same_v || std::is_same_v) && Veclen > 1) { + return 4; + } else { + return 1; + } + } + + __device__ __host__ static constexpr bool is_packed() + { + return (std::is_same_v || std::is_same_v) && Veclen > 1; + } + + __device__ __forceinline__ T operator[](int i) const + { + if constexpr (std::is_same_v && Veclen > 1) { + return static_cast((data_ >> (i * 8)) & 0xFF); + } else if constexpr (std::is_same_v && Veclen > 1) { + return static_cast((data_ >> (i * 8)) & 0xFF); + } else { + (void)i; + return static_cast(data_); + } + } +}; + +/** + * @brief Base interface for custom distance metrics. + */ +template +struct metric_interface { + using point_type = point; + + virtual __device__ void operator()(AccT& acc, point_type x, point_type y) = 0; + virtual ~metric_interface() = default; +}; + +// ============================================================ +// Helper Operations - Deduce Veclen from point type! 
+// ============================================================ + +/** @brief Squared difference: (x - y)² */ +template +__device__ __forceinline__ AccT squared_diff(point x, point y) +{ + if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffu4(x.raw(), y.raw()); + return __dp4a(diff, diff, AccT{0}); + } else if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffs4(x.raw(), y.raw()); + return __dp4a(diff, diff, static_cast(0)); + } else { + auto diff = x.raw() - y.raw(); + return diff * diff; + } +} + +/** @brief Absolute difference: |x - y| */ +template +__device__ __forceinline__ AccT abs_diff(point x, point y) +{ + if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffu4(x.raw(), y.raw()); + } else if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffs4(x.raw(), y.raw()); + } else { + auto a = x.raw(); + auto b = y.raw(); + return (a > b) ? (a - b) : (b - a); + } +} + +/** @brief Dot product: x · y */ +template +__device__ __forceinline__ AccT dot_product(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + return __dp4a(x.raw(), y.raw(), AccT{0}); + } else { + return x.raw() * y.raw(); + } +} + +/** @brief Element-wise product: x * y */ +template +__device__ __forceinline__ AccT product(point x, point y) +{ + return dot_product(x, y); +} + +/** @brief Element-wise sum: x + y */ +template +__device__ __forceinline__ AccT sum(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + AccT result = 0; + for (int i = 0; i < x.size(); ++i) { + result += static_cast(x[i]) + static_cast(y[i]); + } + return result; + } else { + return x.raw() + y.raw(); + } +} + +/** @brief Maximum element: max(x, y) */ +template +__device__ __forceinline__ AccT max_elem(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + AccT result = 0; + for (int i = 0; i < x.size(); ++i) { + auto xi = static_cast(x[i]); + auto yi = static_cast(y[i]); + auto 
val = (xi > yi) ? xi : yi; + if (val > result) result = val; + } + return result; + } else { + auto a = x.raw(); + auto b = y.raw(); + return (a > b) ? a : b; + } +} + +// ============================================================================ +// String versions for JIT compilation +// ============================================================================ + +constexpr std::string_view point_code = R"( +template +struct point { + using element_type = T; + using storage_type = AccT; + static constexpr int veclen = Veclen; + + storage_type data_; + + point() = default; + __device__ __host__ explicit point(storage_type d) : data_(d) {} + + __device__ __forceinline__ storage_type raw() const { return data_; } + __device__ __forceinline__ storage_type& raw() { return data_; } + + __device__ __host__ static constexpr int size() + { + if constexpr ((std::is_same_v || std::is_same_v) && Veclen > 1) { + return 4; + } else { + return 1; + } + } + + __device__ __host__ static constexpr bool is_packed() + { + return (std::is_same_v || std::is_same_v) && Veclen > 1; + } + + __device__ __forceinline__ T operator[](int i) const + { + if constexpr (std::is_same_v && Veclen > 1) { + return static_cast((data_ >> (i * 8)) & 0xFF); + } else if constexpr (std::is_same_v && Veclen > 1) { + return static_cast((data_ >> (i * 8)) & 0xFF); + } else { + (void)i; + return static_cast(data_); + } + } +}; +)"; + +constexpr std::string_view metric_interface_code = R"( +template +struct metric_interface { + using point_type = point; + + virtual __device__ void operator()(AccT& acc, point_type x, point_type y) = 0; + virtual ~metric_interface() = default; +}; +)"; + +constexpr std::string_view squared_diff_code = R"( +template +__device__ __forceinline__ AccT squared_diff(point x, point y) +{ + if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffu4(x.raw(), y.raw()); + return __dp4a(diff, diff, AccT{0}); + } else if constexpr (std::is_same_v && V > 1) { + auto diff = 
__vabsdiffs4(x.raw(), y.raw()); + return __dp4a(diff, diff, static_cast(0)); + } else { + auto diff = x.raw() - y.raw(); + return diff * diff; + } +} +)"; + +constexpr std::string_view abs_diff_code = R"( +template +__device__ __forceinline__ AccT abs_diff(point x, point y) +{ + if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffu4(x.raw(), y.raw()); + } else if constexpr (std::is_same_v && V > 1) { + auto diff = __vabsdiffs4(x.raw(), y.raw()); + } else { + auto a = x.raw(); + auto b = y.raw(); + return (a > b) ? (a - b) : (b - a); + } +} +)"; + +constexpr std::string_view dot_product_code = R"( +template +__device__ __forceinline__ AccT dot_product(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + return __dp4a(x.raw(), y.raw(), AccT{0}); + } else { + return x.raw() * y.raw(); + } +} +)"; + +constexpr std::string_view product_code = R"( +template +__device__ __forceinline__ AccT product(point x, point y) +{ + return dot_product(x, y); +} +)"; + +constexpr std::string_view sum_code = R"( +template +__device__ __forceinline__ AccT sum(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + AccT result = 0; + for (int i = 0; i < x.size(); ++i) { + result += static_cast(x[i]) + static_cast(y[i]); + } + return result; + } else { + return x.raw() + y.raw(); + } +} +)"; + +constexpr std::string_view max_elem_code = R"( +template +__device__ __forceinline__ AccT max_elem(point x, point y) +{ + if constexpr ((std::is_same_v || std::is_same_v) && V > 1) { + AccT result = 0; + for (int i = 0; i < x.size(); ++i) { + auto xi = static_cast(x[i]); + auto yi = static_cast(y[i]); + auto val = (xi > yi) ? xi : yi; + if (val > result) result = val; + } + return result; + } else { + auto a = x.raw(); + auto b = y.raw(); + return (a > b) ? a : b; + } +} +)"; + +/** + * @brief Preamble code for JIT compilation. 
+ * + * nvrtc doesn't have access to standard library headers, so we define + * the necessary types and utilities inline. + */ +constexpr std::string_view jit_preamble_code = R"( +/* Fixed-width integer types for nvrtc */ +using int8_t = signed char; +using uint8_t = unsigned char; +using int32_t = int; +using uint32_t = unsigned int; +using int64_t = long long; +using uint64_t = unsigned long long; + +/* std::is_same_v implementation for nvrtc */ +namespace std { +template struct is_same { static constexpr bool value = false; }; +template struct is_same { static constexpr bool value = true; }; +template inline constexpr bool is_same_v = is_same::value; +} +)"; + +/** + * @brief Define a custom distance metric with compile-time validation. + * + * This macro creates: + * 1. A struct that inherits from metric_interface (compile-time validation) + * 2. A function NAME_udf() that returns a metric_source for JIT compilation + * 3. A BODY that needs to be compiled by nvrtc and must be valid CUDA device code + * + * @param NAME The name of your metric (becomes struct name and function prefix) + * @param BODY The body of operator()(AccT& acc, point_type x, point_type y) + * + * Available in BODY: + * acc - Accumulated distance (AccT&, modify in-place) + * x, y - Vector elements (point) + * T - Data type (float, __half, int8_t, uint8_t) + * AccT - Accumulator type + * Veclen - Vector length (compile-time constant) + * + * x and y provide: + * x.raw() - Raw packed storage (for power users) + * x[i] - Unpacked element access + * x.size() - Number of elements (4 for packed int8, 1 for float) + * x.is_packed() - Whether data is packed (constexpr) + * + * Helper functions (Veclen deduced automatically!): + * squared_diff(x, y) - (x-y)² optimized for all types + * abs_diff(x, y) - |x-y| optimized for all types + * dot_product(x, y) - x·y optimized for all types + * product(x, y) - element-wise product + * + * Example: + * CUVS_METRIC(my_l2, { + * acc += squared_diff(x, y); // 
Just works for all types! + * }) + * + * CUVS_METRIC(my_chebyshev, { + * for (int i = 0; i < x.size(); ++i) { + * auto diff = (x[i] > y[i]) ? (x[i] - y[i]) : (y[i] - x[i]); + * if (diff > acc) acc = diff; + * } + * }) + */ +#define CUVS_METRIC(NAME, BODY) \ + template \ + struct NAME : cuvs::neighbors::ivf_flat::experimental::udf::metric_interface { \ + using point_type = cuvs::neighbors::ivf_flat::experimental::udf::point; \ + __device__ void operator()(AccT& acc, point_type x, point_type y) override { BODY } \ + }; \ + \ + inline std::string NAME##_udf() \ + { \ + using namespace cuvs::neighbors::ivf_flat::experimental::udf; \ + std::string result; \ + result += jit_preamble_code; \ + result += point_code; \ + result += squared_diff_code; \ + result += abs_diff_code; \ + result += dot_product_code; \ + result += product_code; \ + result += sum_code; \ + result += max_elem_code; \ + result += metric_interface_code; \ + result += R"( \ +template \ +struct )" #NAME R"( : metric_interface { \ + using point_type = point; \ + __device__ void operator()(AccT& acc, point_type x, point_type y) override \ +)" #BODY R"( \ +}; \ + \ +namespace cuvs { namespace neighbors { namespace ivf_flat { namespace detail { \ +template \ +__device__ __forceinline__ void compute_dist(AccT& acc, AccT x, AccT y) \ +{ \ + ::)" #NAME R"( metric; \ + metric(acc, ::point(x), ::point(y)); \ +} \ +}}}} \ +)"; \ + return result; \ + } + +} // namespace experimental::udf +#endif + } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/detail/jit_lto/AlgorithmLauncher.cu b/cpp/src/detail/jit_lto/AlgorithmLauncher.cu index 0402ef8304..ef72a36107 100644 --- a/cpp/src/detail/jit_lto/AlgorithmLauncher.cu +++ b/cpp/src/detail/jit_lto/AlgorithmLauncher.cu @@ -24,7 +24,6 @@ AlgorithmLauncher::AlgorithmLauncher(AlgorithmLauncher&& other) noexcept AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexcept { if (this != &other) { - // Unload current library if it exists if 
(library != nullptr) { cudaLibraryUnload(library); } kernel = other.kernel; library = other.library; @@ -36,18 +35,31 @@ AlgorithmLauncher& AlgorithmLauncher::operator=(AlgorithmLauncher&& other) noexc void AlgorithmLauncher::call( cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args) +{ + cudaLaunchConfig_t config; + config.gridDim = grid; + config.blockDim = block; + config.stream = stream; + config.dynamicSmemBytes = shared_mem; + config.numAttrs = 0; + + RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); +} + +void AlgorithmLauncher::call_cooperative( + cudaStream_t stream, dim3 grid, dim3 block, std::size_t shared_mem, void** kernel_args) { cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; + attribute[0].id = cudaLaunchAttributeCooperative; + attribute[0].val.cooperative = 1; cudaLaunchConfig_t config; config.gridDim = grid; config.blockDim = block; config.stream = stream; - config.attrs = attribute; - config.numAttrs = 1; config.dynamicSmemBytes = shared_mem; + config.numAttrs = 1; + config.attrs = attribute; RAFT_CUDA_TRY(cudaLaunchKernelExC(&config, kernel, kernel_args)); } diff --git a/cpp/src/detail/jit_lto/AlgorithmPlanner.cu b/cpp/src/detail/jit_lto/AlgorithmPlanner.cu index 3f8a3eeba1..0ae28d3ced 100644 --- a/cpp/src/detail/jit_lto/AlgorithmPlanner.cu +++ b/cpp/src/detail/jit_lto/AlgorithmPlanner.cu @@ -5,13 +5,14 @@ #include "nvjitlink_checker.hpp" -#include -#include +#include +#include +#include +#include #include #include #include #include -#include #include #include @@ -55,13 +56,14 @@ std::shared_ptr AlgorithmPlanner::get_launcher() if (launchers.count(launch_key) == 0) { add_entrypoint(); add_device_functions(); + RAFT_LOG_INFO("A first-time JIT compilation has been triggered for your algorithm"); std::string log_message = "JIT compiling launcher for entrypoint: " + 
this->entrypoint + " and device functions: "; for (const auto& device_function : this->device_functions) { log_message += device_function + ","; } log_message.pop_back(); - RAFT_LOG_INFO("%s", log_message.c_str()); + RAFT_LOG_DEBUG("%s", log_message.c_str()); launchers[launch_key] = this->build(); } return launchers[launch_key]; @@ -78,22 +80,19 @@ std::shared_ptr AlgorithmPlanner::build() std::string archs = "-arch=sm_" + std::to_string((major * 10 + minor)); - // Load the generated LTO IR and link them together nvJitLinkHandle handle; - const char* lopts[] = {"-lto", archs.c_str()}; - auto result = nvJitLinkCreate(&handle, 2, lopts); + const char* lopts[] = { + "-lto", "-split-compile=0", "-split-compile-extended=0", "-maxrregcount=64", archs.c_str()}; + auto result = nvJitLinkCreate(&handle, 5, lopts); check_nvjitlink_result(handle, result); for (auto& frag : this->fragments) { frag->add_to(handle); } - // Call to nvJitLinkComplete causes linker to link together all the LTO-IR - // modules perform any optimizations and generate cubin from it. 
result = nvJitLinkComplete(handle); check_nvjitlink_result(handle, result); - // get cubin from nvJitLink size_t cubin_size; result = nvJitLinkGetLinkedCubinSize(handle, &cubin_size); check_nvjitlink_result(handle, result); @@ -105,7 +104,6 @@ std::shared_ptr AlgorithmPlanner::build() result = nvJitLinkDestroy(&handle); RAFT_EXPECTS(result == NVJITLINK_SUCCESS, "nvJitLinkDestroy failed"); - // cubin is linked, so now load it cudaLibrary_t library; RAFT_CUDA_TRY( cudaLibraryLoadData(&library, cubin.get(), nullptr, nullptr, 0, nullptr, nullptr, 0)); @@ -113,17 +111,14 @@ std::shared_ptr AlgorithmPlanner::build() unsigned int kernel_count = 0; RAFT_CUDA_TRY(cudaLibraryGetKernelCount(&kernel_count, library)); - // NOTE: cudaKernel_t does not need to be freed explicitly std::unique_ptr kernels{new cudaKernel_t[kernel_count]}; RAFT_CUDA_TRY(cudaLibraryEnumerateKernels(kernels.get(), kernel_count, library)); - // Filter out EmptyKernel by checking kernel names using cudaFuncGetName const char* empty_kernel_name = "_ZN3cub6detail11EmptyKernelIvEEvv"; std::vector valid_kernels; valid_kernels.reserve(kernel_count); for (unsigned int i = 0; i < kernel_count; ++i) { - // cudaFuncGetName can be used with cudaKernel_t by casting to void* const void* func_ptr = reinterpret_cast(kernels[i]); const char* func_name = nullptr; RAFT_CUDA_TRY(cudaFuncGetName(&func_name, func_ptr)); @@ -131,14 +126,12 @@ std::shared_ptr AlgorithmPlanner::build() bool is_empty_kernel = false; if (func_name != nullptr) { std::string kernel_name(func_name); - // Check if this is EmptyKernel if (kernel_name.find(empty_kernel_name) != std::string::npos || kernel_name == empty_kernel_name) { is_empty_kernel = true; } } - // Only keep the kernel if it's not EmptyKernel if (!is_empty_kernel) { valid_kernels.push_back(kernels[i]); } } diff --git a/cpp/src/detail/jit_lto/FragmentDatabase.cu b/cpp/src/detail/jit_lto/FragmentDatabase.cu index 02ea688a0d..efe7139aa2 100644 --- 
a/cpp/src/detail/jit_lto/FragmentDatabase.cu +++ b/cpp/src/detail/jit_lto/FragmentDatabase.cu @@ -7,6 +7,9 @@ #include #include +#include + +#include FragmentDatabase::FragmentDatabase() {} @@ -25,12 +28,20 @@ FragmentDatabase& fragment_database() return database; } +bool FragmentDatabase::has_fragment(std::string const& key) const { return cache.count(key) > 0; } + FragmentEntry* FragmentDatabase::get_fragment(std::string const& key) { auto& db = fragment_database(); auto val = db.cache.find(key); RAFT_EXPECTS(val != db.cache.end(), "FragmentDatabase: Key not found: %s", key.c_str()); - return val->second.get(); + auto* fragment = val->second.get(); + if (fragment == nullptr) { + RAFT_LOG_WARN("[JIT FRAGMENT] Fragment key exists but entry is NULL: %s (cache size: %zu)", + key.c_str(), + db.cache.size()); + } + return fragment; } void registerFatbinFragment(std::string const& algo, @@ -41,7 +52,23 @@ void registerFatbinFragment(std::string const& algo, auto& planner = fragment_database(); std::string key = algo; if (!params.empty()) { key += "_" + params; } - auto entry_exists = planner.make_cache_entry(key); - if (entry_exists) { return; } - planner.cache[key] = std::make_unique(key, blob, size); + { + std::lock_guard lock(planner.cache_mutex_); + auto entry_exists = planner.make_cache_entry(key); + if (entry_exists) { return; } + planner.cache[key] = std::make_unique(key, blob, size); + } +} + +void registerNVRTCFragment(std::string const& key, + std::unique_ptr&& program, + std::size_t size) +{ + auto& planner = fragment_database(); + { + std::lock_guard lock(planner.cache_mutex_); + auto entry_exists = planner.make_cache_entry(key); + if (entry_exists) { return; } + planner.cache[key] = std::make_unique(key, std::move(program), size); + } } diff --git a/cpp/src/detail/jit_lto/FragmentEntry.cu b/cpp/src/detail/jit_lto/FragmentEntry.cu index af1fb90e58..84caa207d5 100644 --- a/cpp/src/detail/jit_lto/FragmentEntry.cu +++ b/cpp/src/detail/jit_lto/FragmentEntry.cu 
@@ -26,3 +26,19 @@ bool FatbinFragmentEntry::add_to(nvJitLinkHandle& handle) const check_nvjitlink_result(handle, result); return true; } + +NVRTCFragmentEntry::NVRTCFragmentEntry(std::string const& key, + std::unique_ptr&& program, + std::size_t size) + : FragmentEntry(key), program(std::move(program)), data_size(size) +{ +} + +bool NVRTCFragmentEntry::add_to(nvJitLinkHandle& handle) const +{ + auto result = nvJitLinkAddData( + handle, NVJITLINK_INPUT_LTOIR, this->program.get(), this->data_size, this->compute_key.c_str()); + check_nvjitlink_result(handle, result); + + return true; +} diff --git a/cpp/src/detail/jit_lto/NVRTCLTOFragmentCompiler.cu b/cpp/src/detail/jit_lto/NVRTCLTOFragmentCompiler.cu new file mode 100644 index 0000000000..a17226cd01 --- /dev/null +++ b/cpp/src/detail/jit_lto/NVRTCLTOFragmentCompiler.cu @@ -0,0 +1,91 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +#include + +#include "cuda.h" +#include + +#define NVRTC_SAFE_CALL(_call) \ + { \ + nvrtcResult result = _call; \ + std::string error_string = \ + std::string("nvrtc error: ") + std::string(nvrtcGetErrorString(result)); \ + RAFT_EXPECTS(result == NVRTC_SUCCESS, "%s", error_string.c_str()); \ + } + +NVRTCLTOFragmentCompiler::NVRTCLTOFragmentCompiler() +{ + int device = 0; + int major = 0; + int minor = 0; + cudaGetDevice(&device); + cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device); + cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device); + + this->standard_compile_opts.resize(7); + + std::size_t i = 0; + // Use actual GPU architecture for optimal code generation + this->standard_compile_opts[i++] = + std::string{"-arch=sm_" + std::to_string((major * 10 + minor))}; + this->standard_compile_opts[i++] = std::string{"-dlto"}; + this->standard_compile_opts[i++] = std::string{"-rdc=true"}; + this->standard_compile_opts[i++] = 
std::string{"-default-device"}; + this->standard_compile_opts[i++] = std::string{"--gen-opt-lto"}; + // Optimization flags - NVRTC uses different syntax than nvcc + this->standard_compile_opts[i++] = std::string{"--use_fast_math"}; + this->standard_compile_opts[i++] = std::string{"--extra-device-vectorization"}; +} + +void NVRTCLTOFragmentCompiler::compile(std::string const& key, std::string const& code) const +{ + // Check if this fragment is already cached - avoid expensive NVRTC compilation + if (fragment_database().has_fragment(key)) { return; } + + nvrtcProgram prog; + NVRTC_SAFE_CALL( + nvrtcCreateProgram(&prog, code.c_str(), "nvrtc_lto_fragment", 0, nullptr, nullptr)); + + // Convert std::vector to std::vector for nvrtc API + std::vector opts; + opts.reserve(this->standard_compile_opts.size()); + for (const auto& opt : this->standard_compile_opts) { + opts.push_back(opt.c_str()); + } + + nvrtcResult compileResult = nvrtcCompileProgram(prog, // prog + opts.size(), // numOptions + opts.data()); // options + + if (compileResult != NVRTC_SUCCESS) { + // Obtain compilation log from the program. + size_t log_size; + NVRTC_SAFE_CALL(nvrtcGetProgramLogSize(prog, &log_size)); + std::unique_ptr log{new char[log_size]}; + NVRTC_SAFE_CALL(nvrtcGetProgramLog(prog, log.get())); + RAFT_FAIL("nvrtc compile error log: \n%s", log.get()); + } + + // Obtain generated LTO IR from the program. 
+ std::size_t ltoIRSize; + NVRTC_SAFE_CALL(nvrtcGetLTOIRSize(prog, <oIRSize)); + + std::unique_ptr program = std::make_unique(ltoIRSize); + nvrtcGetLTOIR(prog, program.get()); + + NVRTC_SAFE_CALL(nvrtcDestroyProgram(&prog)); + + registerNVRTCFragment(key, std::move(program), ltoIRSize); +} + +NVRTCLTOFragmentCompiler& nvrtc_compiler() +{ + static NVRTCLTOFragmentCompiler compiler; + return compiler; +} diff --git a/cpp/src/neighbors/detail/cagra/cagra_search.cuh b/cpp/src/neighbors/detail/cagra/cagra_search.cuh index f1650980e0..bca8d3314d 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_search.cuh @@ -32,6 +32,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { template const base_type* @@ -205,6 +206,7 @@ struct alignas(device::LOAD_128BIT_T) dataset_descriptor_base_t { auto per_thread_distances = valid ? compute_distance_impl(args.load(), dataset_index) : 0; return device::team_sum(per_thread_distances, team_size_bitshift_from_smem()); } +#endif }; /** @@ -227,6 +229,14 @@ struct dataset_descriptor_host { uint32_t smem_ws_size_in_bytes = 0; uint32_t team_size = 0; + // JIT LTO metadata - stored when descriptor is created + cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; + uint32_t dataset_block_dim = 0; + bool is_vpq = false; + uint32_t pq_bits = 0; + uint32_t pq_len = 0; + // Codebook type is determined by DataT for VPQ (always half for now) + struct state { using ready_t = std::tuple; using init_f = @@ -270,10 +280,21 @@ struct dataset_descriptor_host { }; template - dataset_descriptor_host(const DescriptorImpl& dd_host, InitF init) + dataset_descriptor_host(const DescriptorImpl& dd_host, + InitF init, + cuvs::distance::DistanceType metric_val, + uint32_t dataset_block_dim_val, + bool is_vpq_val = false, + uint32_t pq_bits_val = 0, + uint32_t pq_len_val = 0) : 
value_{std::make_shared(init, sizeof(DescriptorImpl))}, smem_ws_size_in_bytes{dd_host.smem_ws_size_in_bytes()}, - team_size{dd_host.team_size()} + team_size{dd_host.team_size()}, + metric{metric_val}, + dataset_block_dim{dataset_block_dim_val}, + is_vpq{is_vpq_val}, + pq_bits{pq_bits_val}, + pq_len{pq_len_val} { } diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh index ecb09f516c..bb8ea1382f 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_standard-impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -13,7 +13,40 @@ #include namespace cuvs::neighbors::cagra::detail { + +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + +// When JIT LTO is enabled or building kernel fragments, dist_op is an extern function that gets JIT +// linked from fragments Each fragment provides a metric-specific implementation (L2Expanded, +// InnerProduct, etc.) 
The planner will link the appropriate fragment based on the metric Note: +// extern functions cannot be constexpr, so we remove constexpr here Note: These are in the detail +// namespace (not anonymous) so they can be found by JIT linking +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +extern __device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b); + +// Normalization is also JIT linked from fragments (no-op for most metrics, cosine normalization for +// CosineExpanded) The planner will link the appropriate fragment (cosine or noop) based on the +// metric +// QueryT is needed to match the descriptor template signature (always float for normalization) +template +extern __device__ DistanceT apply_normalization_standard( + DistanceT distance, + const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index); + +#endif + namespace { + +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + +// When JIT LTO is disabled, dist_op is a template function with Metric as a template parameter template RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) -> std::enable_if_t @@ -41,18 +74,33 @@ RAFT_DEVICE_INLINE_FUNCTION constexpr auto dist_op(DATA_T a, DATA_T b) const auto v = (a ^ b) & 0xffu; return __popc(v); } + +#endif // #if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) } // namespace -template + typename DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + cuvs::distance::DistanceType Metric +#else + , + typename QueryT +#endif + > struct standard_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; - using QUERY_T = typename std:: +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + // When JIT LTO is disabled, Metric is a template parameter + using QUERY_T = typename std:: conditional_t; +#else + // When JIT LTO is enabled, QueryT is passed as a template parameter + using QUERY_T = QueryT; +#endif using 
base_type::args; using base_type::smem_ws_size_in_bytes; using typename base_type::args_t; @@ -62,7 +110,9 @@ struct standard_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { return sizeof(standard_dataset_descriptor_t) + raft::round_up_safe(dim, DatasetBlockDim) * sizeof(QUERY_T); } + + private: }; template @@ -169,7 +221,6 @@ _RAFT_DEVICE __noinline__ auto setup_workspace_standard( buf[j] = 0; } } - return const_cast(r); } @@ -206,7 +257,6 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( if (k >= dim) break; #pragma unroll for (uint32_t v = 0; v < vlen; v++) { - // Note this loop can go above the dataset_dim for padded arrays. This is not a problem // because: // - Above the last element (dataset_dim-1), the query array is filled with zeros. // - The data buffer has to be also padded with zeros. @@ -215,8 +265,16 @@ RAFT_DEVICE_INLINE_FUNCTION auto compute_distance_standard_worker( d, query_smem_ptr + sizeof(QUERY_T) * device::swizzling(k + v)); +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + // When JIT LTO is enabled or building kernel fragments, dist_op is an extern function (no + // template parameters) + r += dist_op( + d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); +#else + // When JIT LTO is disabled, dist_op is a template function with Metric parameter r += dist_op( d, cuvs::spatial::knn::detail::utils::mapping{}(data[e][v])); +#endif } } } @@ -233,15 +291,32 @@ _RAFT_DEVICE __noinline__ auto compute_distance_standard( args.dim, args.smem_ws_ptr); +#if defined(CUVS_ENABLE_JIT_LTO) || defined(BUILD_KERNEL) + // Normalization is JIT linked from fragments (no-op or cosine normalization) + // The planner links the appropriate fragment based on the metric + distance = + apply_normalization_standard(distance, args, dataset_index); +#else + // When JIT LTO is disabled, kMetric is always available as a compile-time constant if constexpr (DescriptorT::kMetric == 
cuvs::distance::DistanceType::CosineExpanded) { const auto* dataset_norms = DescriptorT::dataset_norms_ptr(args); auto norm = dataset_norms[dataset_index]; if (norm > 0) { distance = distance / norm; } } +#endif return distance; } +#ifndef BUILD_KERNEL +// The init kernel is used for both JIT and non-JIT initialization +// When BUILD_KERNEL is defined, we're building a JIT fragment and don't want this kernel. +// The kernel handles JIT vs non-JIT via ifdef internally template ; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; + + // For CUDA 12 (non-JIT), set the function pointers properly new (out) desc_type(reinterpret_cast( &setup_workspace_standard), reinterpret_cast( @@ -268,8 +347,30 @@ RAFT_KERNEL __launch_bounds__(1, 1) dim, ld, dataset_norms); +#else + // When JIT LTO is enabled, Metric is not a template parameter + using query_t = + std::conditional_t; + using desc_type = + standard_dataset_descriptor_t; + using base_type = typename desc_type::base_type; + + // For JIT, we don't use the function pointers, so set them to nullptr + // The free functions are called directly instead + new (out) desc_type(nullptr, // setup_workspace_impl - not used in JIT + nullptr, // compute_distance_impl - not used in JIT + ptr, + size, + dim, + ld, + dataset_norms); +#endif } +#endif // #ifndef BUILD_KERNEL +#ifndef BUILD_KERNEL +// The init_ function is used for both JIT and non-JIT initialization +// When BUILD_KERNEL is defined, we're building a JIT fragment and don't want this function. 
template ; + using base_type = typename desc_type::base_type; +#else + // When JIT LTO is enabled, Metric is not a template parameter + // QueryT depends on metric: uint8_t for BitwiseHamming, float for others + using query_t = + std::conditional_t; using desc_type = - standard_dataset_descriptor_t; + standard_dataset_descriptor_t; using base_type = typename desc_type::base_type; +#endif RAFT_EXPECTS(Metric != cuvs::distance::DistanceType::CosineExpanded || dataset_norms != nullptr, "Dataset norms must be provided for CosineExpanded metric"); - desc_type dd_host{nullptr, nullptr, ptr, size, dim, ld, dataset_norms}; - return host_type{dd_host, + return host_type{desc_type{nullptr, nullptr, ptr, size, dim, ld, dataset_norms}, [=](dataset_descriptor_base_t* dev_ptr, rmm::cuda_stream_view stream) { + // Use init kernel for both JIT and CUDA 12 + // The kernel handles JIT vs non-JIT via ifdef internally standard_dataset_descriptor_init_kernel <<<1, 1, 0, stream>>>(dev_ptr, ptr, size, dim, ld, dataset_norms); RAFT_CUDA_TRY(cudaPeekAtLastError()); - }}; + }, + Metric, + DatasetBlockDim, + false, // is_vpq + 0, // pq_bits + 0}; // pq_len } +#endif // #ifndef BUILD_KERNEL } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh index cdafb173ed..1cb593830d 100644 --- a/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh +++ b/cpp/src/neighbors/detail/cagra/compute_distance_vpq-impl.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2024, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -15,19 +15,30 @@ namespace cuvs::neighbors::cagra::detail { -template + typename DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + cuvs::distance::DistanceType Metric +#else + , + typename QueryT +#endif + > struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { using base_type = dataset_descriptor_base_t; using CODE_BOOK_T = CodebookT; - using QUERY_T = half; +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + using QUERY_T = half; +#else + using QUERY_T = QueryT; +#endif using base_type::args; using base_type::extra_ptr3; using typename base_type::args_t; @@ -37,7 +48,9 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t uint32_t { /* SMEM workspace layout: @@ -121,6 +134,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(dim, DatasetBlockDim) * sizeof(QUERY_T); } + + private: }; template @@ -347,6 +362,10 @@ _RAFT_DEVICE __noinline__ auto compute_distance_vpq( args.smem_ws_ptr); } +#ifndef BUILD_KERNEL +// The init kernel is not needed when building JIT fragments (BUILD_KERNEL is defined) +// It's only needed for non-JIT initialization. When BUILD_KERNEL is defined, we're building +// a JIT fragment and don't want this kernel to be instantiated. 
template ; + DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + Metric +#else + , + half +#endif + >; using base_type = typename desc_type::base_type; +#ifdef CUVS_ENABLE_JIT_LTO + // For JIT, we don't use the function pointers, so set them to nullptr + // The free functions are called directly instead + new (out) desc_type(nullptr, // setup_workspace_impl - not used in JIT + nullptr, // compute_distance_impl - not used in JIT + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim); +#else + // For CUDA 12 (non-JIT), set the function pointers properly new (out) desc_type( reinterpret_cast(&setup_workspace_vpq), reinterpret_cast(&compute_distance_vpq), @@ -384,8 +423,14 @@ RAFT_KERNEL __launch_bounds__(1, 1) pq_code_book_ptr, size, dim); +#endif } +#endif // #ifndef BUILD_KERNEL +#ifndef BUILD_KERNEL +// The init_ function is not needed when building JIT fragments (BUILD_KERNEL is defined) +// It's only needed for non-JIT initialization. When BUILD_KERNEL is defined, we're building +// a JIT fragment and don't want this host function to be included. 
template ; + DistanceT +#if !defined(CUVS_ENABLE_JIT_LTO) && !defined(BUILD_KERNEL) + , + Metric +#else + , + half +#endif + >; using base_type = typename desc_type::base_type; - desc_type dd_host{nullptr, - nullptr, - encoded_dataset_ptr, - encoded_dataset_dim, - vq_code_book_ptr, - pq_code_book_ptr, - size, - dim}; - return host_type{dd_host, + return host_type{desc_type{nullptr, + nullptr, + encoded_dataset_ptr, + encoded_dataset_dim, + vq_code_book_ptr, + pq_code_book_ptr, + size, + dim}, [=](dataset_descriptor_base_t* dev_ptr, rmm::cuda_stream_view stream) { + // Use init kernel for both JIT and CUDA 12 + // The kernel handles JIT vs non-JIT via ifdef internally vpq_dataset_descriptor_init_kernel +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "apply_filter_kernel", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in new file mode 100644 index 0000000000..eae9b5e32d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template __global__ void apply_filter_kernel_jit<@index_type@, @distance_type@, @source_index_type@>( + const @source_index_type@* const, @index_type@* const, @distance_type@* const, const std::size_t, const std::uint32_t, const std::uint32_t, const @index_type@, uint32_t*, @source_index_type@, @source_index_type@); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json new file mode 100644 index 0000000000..4f14f7d8c0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_filter_matrix.json @@ -0,0 +1,20 @@ +{ + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh new file mode 100644 index 0000000000..c691e58ef6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_cosine_impl.cuh @@ -0,0 +1,37 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// Cosine normalization fragment implementation +// This provides apply_normalization_standard that normalizes by dataset norm (for CosineExpanded +// metric) +// QueryT is needed to match the descriptor template signature, but not used in this function +template +__device__ DistanceT +apply_normalization_standard(DistanceT distance, + const typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // CosineExpanded normalization: divide by dataset norm + const auto* dataset_norms = + standard_dataset_descriptor_t:: + dataset_norms_ptr(args); + auto norm = dataset_norms[dataset_index]; + if (norm > 0) { distance = distance / norm; } + return distance; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_embedded.cpp.in new file mode 100644 index 0000000000..44412d3c8b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + using QueryTag = cuvs::neighbors::cagra::detail::tag_@query_abbrev@; + registerAlgorithm( + "apply_normalization_standard@normalization_suffix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in new file mode 100644 index 0000000000..9606d21ec6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_kernel.cu.in @@ -0,0 +1,14 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>::args_t; +template __device__ @distance_type@ apply_normalization_standard<@team_size@, @dataset_block_dim@, @data_type@, @index_type@, @distance_type@, @query_type@>( + @distance_type@, const args_t, @index_type@); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json new file mode 100644 index 0000000000..077684b5be --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_matrix.json @@ -0,0 +1,66 @@ +{ + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + 
"data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_normalization": [ + { + "normalization_suffix": "_noop", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "normalization_suffix": "_cosine", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh new file mode 100644 index 0000000000..e9b9bc6556 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/apply_normalization_standard_noop_impl.cuh @@ -0,0 +1,31 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" + +namespace cuvs::neighbors::cagra::detail { + +// No-op normalization fragment implementation +// This provides apply_normalization_standard that does nothing (for non-CosineExpanded metrics) +// QueryT is needed to match the descriptor template signature, but not used in this function +template +__device__ DistanceT +apply_normalization_standard(DistanceT distance, + const typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // No normalization needed for non-CosineExpanded metrics + return distance; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp new file mode 100644 index 0000000000..908a27046d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/cagra_planner_base.hpp @@ -0,0 +1,120 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +struct CagraPlannerBase : AlgorithmPlanner { + using AlgorithmPlanner::device_functions; + + CagraPlannerBase(const std::string& entrypoint, const std::string& params) + : AlgorithmPlanner(entrypoint, params) + { + } + void add_setup_workspace_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits = 0, + uint32_t pq_len = 0) + { + std::string key = "setup_workspace"; + if (is_vpq) { + key += "_vpq"; + auto params = make_fragment_key(); + key += "_team_size_" + std::to_string(team_size); + key += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + key += "_" + std::to_string(pq_bits) + "pq_" + std::to_string(pq_len) + "subd"; + if (!params.empty()) { key += "_" + params; } + } else { + key += "_standard_team_size_" + std::to_string(team_size); + key += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + auto params = make_fragment_key(); + if (!params.empty()) { key += "_" + params; } + } + this->device_functions.push_back(key); + } + + void add_compute_distance_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits = 0, + uint32_t pq_len = 0) + { + if (is_vpq) { + std::string key = "compute_distance_vpq"; + auto params = make_fragment_key(); + key += "_team_size_" + std::to_string(team_size); + key += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + key += "_" + std::to_string(pq_bits) + "pq_" + std::to_string(pq_len) + "subd"; + if (!params.empty()) { key += "_" + params; } + this->device_functions.push_back(key); + } else { + std::string key = "compute_distance_standard_team_size_" + std::to_string(team_size); + key += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + auto params 
= make_fragment_key(); + if (!params.empty()) { key += "_" + params; } + this->device_functions.push_back(key); + add_dist_op_device_function(metric); + add_normalization_device_function(metric, team_size, dataset_block_dim); + } + } + + void add_dist_op_device_function(cuvs::distance::DistanceType metric) + { + std::string metric_tag; + switch (metric) { + case cuvs::distance::DistanceType::L2Expanded: + case cuvs::distance::DistanceType::L2Unexpanded: metric_tag = "l2"; break; + case cuvs::distance::DistanceType::InnerProduct: metric_tag = "inner_product"; break; + case cuvs::distance::DistanceType::CosineExpanded: metric_tag = "inner_product"; break; + case cuvs::distance::DistanceType::BitwiseHamming: metric_tag = "hamming"; break; + default: metric_tag = "unknown"; break; + } + auto params = make_fragment_key(); + std::string key = "dist_op_" + metric_tag; + if (!params.empty()) { key += "_" + params; } + this->device_functions.push_back(key); + } + + void add_normalization_device_function(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim) + { + std::string normalization_type; + if (metric == cuvs::distance::DistanceType::CosineExpanded) { + normalization_type = "cosine"; + } else { + normalization_type = "noop"; + } + auto params = make_fragment_key(); + std::string key = "apply_normalization_standard_" + normalization_type; + key += "_team_size_" + std::to_string(team_size); + key += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + if (!params.empty()) { key += "_" + params; } + this->device_functions.push_back(key); + } + + void add_sample_filter_device_function(std::string filter_name) + { + this->device_functions.push_back("sample_filter_" + filter_name); + } +}; + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_embedded.cpp.in new file mode 100644 
index 0000000000..f577c6a6d6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + using QueryTag = cuvs::neighbors::cagra::detail::tag_@query_abbrev@; + registerAlgorithm( + "compute_distance@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in new file mode 100644 index 0000000000..6163dce4ac --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_kernel.cu.in @@ -0,0 +1,14 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail { + +using args_t = typename cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>::args_t; +template __device__ @distance_type@ compute_distance<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>( + const args_t, @index_type@); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json new file mode 100644 index 0000000000..39cf9ad2c5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_matrix.json @@ -0,0 +1,156 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "0", + "pq_bits": "0", + "pq_prefix": "_standard", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": 
"__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh new file mode 100644 index 0000000000..7aa1a12395 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_standard_impl.cuh @@ -0,0 +1,39 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" +#include "../device_common.hpp" // For dataset_descriptor_base_t + +namespace cuvs::neighbors::cagra::detail { + +// Unified compute_distance implementation for standard descriptors +// This is instantiated when PQ_BITS=0, PQ_LEN=0, CodebookT=void +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +__device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // For standard descriptors, PQ_BITS=0, PQ_LEN=0, CodebookT=void + static_assert(PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v, + "Standard descriptor requires PQ_BITS=0, PQ_LEN=0, CodebookT=void"); + + // Reconstruct the descriptor type with QueryT and call compute_distance_standard + using desc_t = + standard_dataset_descriptor_t; + return compute_distance_standard(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_embedded.cpp.in new file mode 100644 index 0000000000..b28a89f667 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "compute_distance_to_child_nodes@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in new file mode 100644 index 0000000000..1a1baed5e7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template __global__ void compute_distance_to_child_nodes_kernel_jit<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@, @source_index_type@, cuvs::neighbors::filtering::none_sample_filter>( + const @index_type@* const, @index_type@* const, @distance_type@* const, const std::size_t, const std::uint32_t, cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, const @index_type@* const, const std::uint32_t, const @source_index_type@*, const @data_type@*, @index_type@* const, const std::uint32_t, @index_type@* const, @distance_type@* const, const std::uint32_t, cuvs::neighbors::filtering::none_sample_filter); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json new file mode 100644 index 0000000000..929165330b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_to_child_nodes_matrix.json @@ -0,0 +1,168 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "0", + "pq_len": "0", + "pq_prefix": "", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": 
"f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "8", + "pq_len": "2", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_bits": "8", + "pq_len": "4", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh new file mode 100644 index 0000000000..e21d48a2f1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl.cuh @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_vpq-impl.cuh" +#include "../device_common.hpp" // For dataset_descriptor_base_t + +namespace cuvs::neighbors::cagra::detail { + +// Unified compute_distance implementation for VPQ descriptors +// This is instantiated when PQ_BITS>0, PQ_LEN>0, CodebookT=half +// QueryT is always half for VPQ +template +__device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // For VPQ descriptors, PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half + static_assert( + PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && std::is_same_v, + "VPQ descriptor requires PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half"); + + // Reconstruct the descriptor type and call compute_distance_vpq + // QueryT is always half for VPQ + using desc_t = cagra_q_dataset_descriptor_t; + return compute_distance_vpq(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl_unified.cuh 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl_unified.cuh new file mode 100644 index 0000000000..e21d48a2f1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/compute_distance_vpq_impl_unified.cuh @@ -0,0 +1,48 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_vpq-impl.cuh" +#include "../device_common.hpp" // For dataset_descriptor_base_t + +namespace cuvs::neighbors::cagra::detail { + +// Unified compute_distance implementation for VPQ descriptors +// This is instantiated when PQ_BITS>0, PQ_LEN>0, CodebookT=half +// QueryT is always half for VPQ +template +__device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index) +{ + // For VPQ descriptors, PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half + static_assert( + PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && std::is_same_v, + "VPQ descriptor requires PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half"); + + // Reconstruct the descriptor type and call compute_distance_vpq + // QueryT is always half for VPQ + using desc_t = cagra_q_dataset_descriptor_t; + return compute_distance_vpq(args, dataset_index); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh new file mode 100644 index 0000000000..983fe93fe3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/device_common_jit.cuh @@ -0,0 +1,247 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" +#include "extern_device_functions.cuh" + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { +namespace device { + +// Helper to check if DescriptorT has kPqBits (VPQ descriptor) +template +struct has_kpq_bits { + template + static auto test(int) -> decltype(U::kPqBits, std::true_type{}); + template + static std::false_type test(...); + static constexpr bool value = decltype(test(0))::value; +}; + +template +inline constexpr bool has_kpq_bits_v = has_kpq_bits::value; + +// JIT version of compute_distance_to_random_nodes - uses dataset_descriptor_base_t* pointer +// Shared between single_cta and multi_cta JIT kernels +// Unified template parameters: TeamSize, DatasetBlockDim, PQ_BITS, PQ_LEN, CodebookT, QueryT +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_random_nodes_jit( + IndexT* __restrict__ result_indices_ptr, // [num_pickup] + DistanceT* __restrict__ result_distances_ptr, // [num_pickup] + dataset_descriptor_base_t* smem_desc, + const uint32_t num_pickup, + const uint32_t num_distilation, + const uint64_t rand_xor_mask, + const IndexT* __restrict__ seed_ptr, // [num_seeds] + const uint32_t num_seeds, + IndexT* __restrict__ visited_hash_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hash_ptr, + const uint32_t traversed_hash_bitlen, + const uint32_t block_id = 0, + const uint32_t num_blocks = 1) +{ + constexpr unsigned warp_size = 32; + + // Get team_size_bits and args directly from base descriptor + using args_t = typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t; + + // Use team_size_bitshift_from_smem since smem_desc is in shared memory + uint32_t team_size_bits = smem_desc->team_size_bitshift_from_smem(); + args_t args = smem_desc->args.load(); + IndexT dataset_size = smem_desc->size; + + const auto max_i = 
raft::round_up_safe(num_pickup, warp_size >> team_size_bits); + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += (blockDim.x >> team_size_bits)) { + const bool valid_i = (i < num_pickup); + + IndexT best_index_team_local = raft::upper_bound(); + DistanceT best_norm2_team_local = raft::upper_bound(); + for (uint32_t j = 0; j < num_distilation; j++) { + // Select a node randomly and compute the distance to it + IndexT seed_index = 0; + if (valid_i) { + uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); + if (seed_ptr && (gid < num_seeds)) { + seed_index = seed_ptr[gid]; + } else { + seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; + } + } + + // CRITICAL: ALL threads in the team must participate in compute_distance and team_sum + // Otherwise warp shuffles will hang. Each thread calls the unified extern function to get + // its per-thread distance, then team_sum reduces across all threads in the team. + DistanceT per_thread_norm2 = 0; + if (valid_i) { + // Use unified compute_distance function (links standard or VPQ fragment at runtime) + per_thread_norm2 = compute_distance(args, seed_index); + } + // Now ALL threads in the team participate in team_sum + const auto norm2_sum = device::team_sum(per_thread_norm2, team_size_bits); + + if (valid_i && (norm2_sum < best_norm2_team_local)) { + best_norm2_team_local = norm2_sum; + best_index_team_local = seed_index; + } + } + + const unsigned lane_id = threadIdx.x & ((1u << team_size_bits) - 1u); + if (valid_i && lane_id == 0) { + if (best_index_team_local != raft::upper_bound()) { + if (hashmap::insert(visited_hash_ptr, visited_hash_bitlen, best_index_team_local) == 0) { + // Deactivate this entry as insertion into visited hash table has failed. 
+ best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } else if ((traversed_hash_ptr != nullptr) && + hashmap::search( + traversed_hash_ptr, traversed_hash_bitlen, best_index_team_local)) { + // Deactivate this entry as it has been already used by others. + best_norm2_team_local = raft::upper_bound(); + best_index_team_local = raft::upper_bound(); + } + } + result_distances_ptr[i] = best_norm2_team_local; + result_indices_ptr[i] = best_index_team_local; + } + } +} + +// JIT version of compute_distance_to_child_nodes - uses dataset_descriptor_base_t* pointer +// Shared between single_cta and multi_cta JIT kernels +// Unified template parameters: TeamSize, DatasetBlockDim, PQ_BITS, PQ_LEN, CodebookT, QueryT +template +RAFT_DEVICE_INLINE_FUNCTION void compute_distance_to_child_nodes_jit( + IndexT* __restrict__ result_child_indices_ptr, + DistanceT* __restrict__ result_child_distances_ptr, + dataset_descriptor_base_t* smem_desc, + const IndexT* __restrict__ knn_graph, + const uint32_t knn_k, + IndexT* __restrict__ visited_hashmap_ptr, + const uint32_t visited_hash_bitlen, + IndexT* __restrict__ traversed_hashmap_ptr, + const uint32_t traversed_hash_bitlen, + const IndexT* __restrict__ parent_indices, + const IndexT* __restrict__ internal_topk_list, + const uint32_t search_width, + int* __restrict__ result_position = nullptr, + const int max_result_position = 0) +{ + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr IndexT invalid_index = ~static_cast(0); + + // Read child indices of parents from knn graph and check if the distance computation is + // necessary. 
+ for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { + const IndexT smem_parent_id = parent_indices[i / knn_k]; + IndexT child_id = invalid_index; + if (smem_parent_id != invalid_index) { + const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; + child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; + } + if (child_id != invalid_index) { + if (hashmap::insert(visited_hashmap_ptr, visited_hash_bitlen, child_id) == 0) { + child_id = invalid_index; + } else if ((traversed_hashmap_ptr != nullptr) && + hashmap::search( + traversed_hashmap_ptr, traversed_hash_bitlen, child_id)) { + child_id = invalid_index; + } + } + if (STATIC_RESULT_POSITION) { + result_child_indices_ptr[i] = child_id; + } else if (child_id != invalid_index) { + int j = atomicSub(result_position, 1) - 1; + result_child_indices_ptr[j] = child_id; + } + } + __syncthreads(); + + // Compute the distance to child nodes using unified extern compute_distance + constexpr unsigned warp_size = 32; + + // Get team_size_bits and args directly from base descriptor + using args_t = typename cuvs::neighbors::cagra::detail:: + dataset_descriptor_base_t::args_t; + + // Use team_size_bitshift_from_smem since smem_desc is in shared memory + uint32_t team_size_bits = smem_desc->team_size_bitshift_from_smem(); + args_t args = smem_desc->args.load(); + + const auto num_k = knn_k * search_width; + const auto max_i = raft::round_up_safe(num_k, warp_size >> team_size_bits); + const bool lead_lane = (threadIdx.x & ((1u << team_size_bits) - 1u)) == 0; + const uint32_t ofst = STATIC_RESULT_POSITION ? 0 : result_position[0]; + + for (uint32_t i = threadIdx.x >> team_size_bits; i < max_i; i += blockDim.x >> team_size_bits) { + const auto j = i + ofst; + const bool valid_i = STATIC_RESULT_POSITION ? (j < num_k) : (j < max_result_position); + const auto child_id = valid_i ? 
result_child_indices_ptr[j] : invalid_index; + + // CRITICAL: ALL threads in the team must participate in compute_distance and team_sum + // Otherwise warp shuffles will hang. Each thread calls the unified extern function to get + // its per-thread distance, then team_sum reduces across all threads in the team. + DistanceT per_thread_dist = 0; + if (child_id != invalid_index) { + // Use unified compute_distance function (links standard or VPQ fragment at runtime) + per_thread_dist = compute_distance(args, child_id); + } else { + // Invalid child_id: lead lane gets upper_bound, others get 0 + per_thread_dist = lead_lane ? raft::upper_bound() : 0; + } + + // Now ALL threads in the team participate in team_sum + DistanceT child_dist = device::team_sum(per_thread_dist, team_size_bits); + __syncwarp(); + + // Store the distance + if (valid_i && lead_lane) { result_child_distances_ptr[j] = child_dist; } + } +} + +} // namespace device +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh new file mode 100644 index 0000000000..ba6c270fa1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_cosine_impl.cuh @@ -0,0 +1,16 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + return -static_cast(a) * static_cast(b); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_embedded.cpp.in new file mode 100644 index 0000000000..a15ab944dd --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_embedded.cpp.in @@ -0,0 +1,23 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) static void register_kernel() +{ + registerAlgorithm( + "dist_op_@metric_tag@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh new file mode 100644 index 0000000000..cd4ed29ac6 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_hamming_impl.cuh @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + const auto v = (a ^ b) & 0xffu; + return __popc(v); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh new file mode 100644 index 0000000000..ba6c270fa1 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_inner_product_impl.cuh @@ -0,0 +1,16 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + return -static_cast(a) * static_cast(b); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in new file mode 100644 index 0000000000..d5d2bb2cfc --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_kernel.cu.in @@ -0,0 +1,14 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include + +namespace cuvs::neighbors::cagra::detail { + +template __device__ @distance_type@ dist_op<@query_type@, @distance_type@>(@query_type@, @query_type@); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh new file mode 100644 index 0000000000..f74b62b4b0 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_l2_impl.cuh @@ -0,0 +1,17 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +namespace cuvs::neighbors::cagra::detail { + +template +__device__ DISTANCE_T dist_op(QUERY_T a, QUERY_T b) +{ + DISTANCE_T diff = a - b; + return diff * diff; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json new file mode 100644 index 0000000000..7f0772ab1f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/dist_op_matrix.json @@ -0,0 +1,25 @@ +{ + "_metric": [ + { + "metric_tag": "l2", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "inner_product", + "query_type": "float", + "query_abbrev": "f" + }, + { + "metric_tag": "hamming", + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh new file mode 100644 index 0000000000..81f5d56f23 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/extern_device_functions.cuh @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance.hpp" +#include + +namespace cuvs::neighbors::cagra::detail { + +template +extern __device__ dataset_descriptor_base_t* setup_workspace( + dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id); + +template +extern __device__ DistanceT +compute_distance(const typename dataset_descriptor_base_t::args_t args, + IndexT dataset_index); +} // namespace cuvs::neighbors::cagra::detail + +namespace cuvs::neighbors::detail { + +template +extern __device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data); + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json new file mode 100644 index 0000000000..d83fbe4b76 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/filter_matrix.json @@ -0,0 +1,15 @@ +{ + "filter_name": [ + "filter_none", + "filter_bitset" + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "namespace": [ + "cuvs::neighbors::detail" + ] +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_embedded.cpp.in new file mode 100644 index 0000000000..49d89f6416 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "random_pickup@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in new file mode 100644 index 0000000000..d5424b780b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +template __global__ void random_pickup_kernel_jit<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>( + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, const @data_type@* const, const std::size_t, const unsigned, const uint64_t, const @index_type@*, const std::uint32_t, @index_type@* const, @distance_type@* const, const std::uint32_t, @index_type@* const, const std::uint32_t); + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json new file mode 100644 index 0000000000..3c014f8580 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/random_pickup_matrix.json @@ -0,0 +1,156 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + 
"query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "0", + "pq_len": "0", + "pq_prefix": "", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "8", + "pq_len": "2", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_bits": "8", + "pq_len": "4", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_embedded.cpp.in 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_embedded.cpp.in new file mode 100644 index 0000000000..a70e22c696 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "search_multi_cta@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh new file mode 100644 index 0000000000..fe985f7275 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_helpers.cuh @@ -0,0 +1,138 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parent( + INDEX_T* const next_parent_indices, + INDEX_T* const itopk_indices, // [itopk_size * 2] + DISTANCE_T* const itopk_distances, // [itopk_size * 2] + INDEX_T* const hash_ptr, + const uint32_t hash_bitlen) +{ + constexpr uint32_t itopk_size = 32; + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + constexpr INDEX_T invalid_index = ~static_cast(0); + + const unsigned warp_id = threadIdx.x / 32; + if (warp_id > 0) { return; } + if (threadIdx.x == 0) { next_parent_indices[0] = invalid_index; } + __syncwarp(); + + int j = -1; + for (unsigned i = threadIdx.x; i < itopk_size * 2; i += 32) { + INDEX_T index = itopk_indices[i]; + int is_invalid = 0; + int is_candidate = 0; + if (index == invalid_index) { + is_invalid = 1; + } else if (index & index_msb_1_mask) { + } else { + is_candidate = 1; + } + + const auto ballot_mask = __ballot_sync(0xffffffff, is_candidate); + const auto candidate_id = __popc(ballot_mask & ((1 << threadIdx.x) - 1)); + for (int k = 0; k < __popc(ballot_mask); k++) { + int flag_done = 0; + if (is_candidate && candidate_id == k) { + is_candidate = 0; + if (hashmap::insert(hash_ptr, hash_bitlen, index)) { + // Use this candidate as next parent + index |= index_msb_1_mask; // set most significant bit as used node + if (i < itopk_size) { + next_parent_indices[0] = i; + itopk_indices[i] = index; + } else { + next_parent_indices[0] = j; + // Move the next parent node from i-th position to j-th position + itopk_indices[j] = index; + itopk_distances[j] = itopk_distances[i]; + itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + } + flag_done = 1; + } else { + // Deactivate the node since it has been used by other CTA. 
+ itopk_indices[i] = invalid_index; + itopk_distances[i] = utils::get_max_value(); + is_invalid = 1; + } + } + if (__any_sync(0xffffffff, (flag_done > 0))) { return; } + } + if (i < itopk_size) { + j = 31 - __clz(__ballot_sync(0xffffffff, is_invalid)); + if (j < 0) { return; } + } + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort(float* distances, // [num_elements] + INDEX_T* indices, // [num_elements] + const uint32_t num_elements) +{ + const unsigned warp_id = threadIdx.x / raft::warp_size(); + if (warp_id > 0) { return; } + const unsigned lane_id = threadIdx.x % raft::warp_size(); + constexpr unsigned N = (MAX_ELEMENTS + (raft::warp_size() - 1)) / raft::warp_size(); + float key[N]; + INDEX_T val[N]; + for (unsigned i = 0; i < N; i++) { + unsigned j = lane_id + (raft::warp_size() * i); + if (j < num_elements) { + key[i] = distances[j]; + val[i] = indices[j]; + } else { + key[i] = utils::get_max_value(); + val[i] = ~static_cast(0); + } + } + /* Warp Sort */ + bitonic::warp_sort(key, val); + /* Store sorted results */ + for (unsigned i = 0; i < N; i++) { + unsigned j = (N * lane_id) + i; + if (j < num_elements) { + distances[j] = key[i]; + indices[j] = val[i]; + } + } +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_64( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<64, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_128( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<128, uint32_t>(distances, indices, num_elements); +} + +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_wrapper_256( + float* distances, // [num_elements] + uint32_t* indices, // [num_elements] + const uint32_t num_elements) +{ + topk_by_bitonic_sort<256, uint32_t>(distances, indices, num_elements); +} + +} // 
namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh new file mode 100644 index 0000000000..2f7cbc1c42 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_jit.cuh @@ -0,0 +1,407 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include + +#include +#include +#include + +#ifdef _CLK_BREAKDOWN +#include +#endif + +#include "../../jit_lto_kernels/filter_data.h" +#include "device_common_jit.cuh" +#include "extern_device_functions.cuh" + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; +using cuvs::neighbors::cagra::detail::device::has_kpq_bits_v; +using cuvs::neighbors::detail::sample_filter; +template +__global__ __launch_bounds__(1024, 1) void search_kernel_jit( + IndexT* const result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* const result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + dataset_descriptor_base_t* dataset_desc, + const DataT* const queries_ptr, // [num_queries, dataset_dim] + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const uint32_t max_elements, + const uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, // [num_queries, search_width] + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + const uint32_t visited_hash_bitlen, + IndexT* const traversed_hashmap_ptr, // [num_queries, 1 << traversed_hash_bitlen] + const uint32_t traversed_hash_bitlen, + const 
uint32_t itopk_size, + const uint32_t min_iteration, + const uint32_t max_iteration, + uint32_t* const num_executed_iterations, /* stats */ + const uint32_t query_id_offset, // Offset to add to query_id when calling filter + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) +{ + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + + auto to_source_index = [source_indices_ptr](INDEX_T x) { + return source_indices_ptr == nullptr ? static_cast(x) : source_indices_ptr[x]; + }; + + const auto num_queries = gridDim.y; + const auto query_id = blockIdx.y; + const auto num_cta_per_query = gridDim.x; + const auto cta_id = blockIdx.x; // local CTA ID + +#ifdef _CLK_BREAKDOWN + uint64_t clk_init = 0; + uint64_t clk_compute_1st_distance = 0; + uint64_t clk_topk = 0; + uint64_t clk_pickup_parents = 0; + uint64_t clk_compute_distance = 0; + uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + // +----------------+---------+---------------------------+ + // | internal_top_k | padding | neighbors of parent nodes | + // | | upto 32 | | + // +----------------+---------+---------------------------+ + // |<--- result_buffer_size_32 --->| + const auto result_buffer_size = itopk_size + graph_degree; + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + assert(result_buffer_size_32 <= max_elements); + + // Get dim and smem_ws_size_in_bytes directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + // Set smem working buffer using unified setup_workspace + // setup_workspace copies the descriptor to shared memory and returns base pointer to smem + 
// descriptor + dataset_descriptor_base_t* smem_desc = + setup_workspace(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ local_visited_hashmap_ptr = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_indices_buffer = + reinterpret_cast(local_visited_hashmap_ptr + hashmap::get_size(visited_hash_bitlen)); + auto* __restrict__ result_position = reinterpret_cast(parent_indices_buffer + 1); + + INDEX_T* const local_traversed_hashmap_ptr = + traversed_hashmap_ptr + (hashmap::get_size(traversed_hash_bitlen) * query_id); + + constexpr INDEX_T invalid_index = ~static_cast(0); + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; + uint32_t block_id = cta_id + (num_cta_per_query * query_id); + uint32_t num_blocks = num_cta_per_query * num_queries; + + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + graph_degree, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + block_id, + num_blocks); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + uint32_t iter = 0; + while (1) { + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Topk with bitonic sort + if constexpr (std::is_same_v) { + // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort + // function (vs post-inlining, this impacts register pressure) + if (max_elements <= 64) { + topk_by_bitonic_sort_wrapper_64( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort_wrapper_128( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort_wrapper_256( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } else { + if (max_elements <= 64) { + topk_by_bitonic_sort<64, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else if (max_elements <= 128) { + topk_by_bitonic_sort<128, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } else { + assert(max_elements <= 256); + topk_by_bitonic_sort<256, INDEX_T>( + result_distances_buffer, result_indices_buffer, result_buffer_size_32); + } + } + } + __syncthreads(); + _CLK_REC(clk_topk); + + if (iter + 1 >= max_iteration) { break; } + + _CLK_START(); + if (threadIdx.x < 32) { + // [1st warp] Pick up a next parent + pickup_next_parent(parent_indices_buffer, + result_indices_buffer, + 
result_distances_buffer, + local_traversed_hashmap_ptr, + traversed_hash_bitlen); + } else { + // [Other warps] Reset visited hashmap + hashmap::init(local_visited_hashmap_ptr, visited_hash_bitlen, 32); + } + __syncthreads(); + _CLK_REC(clk_pickup_parents); + + if ((parent_indices_buffer[0] == invalid_index) && (iter >= min_iteration)) { break; } + + _CLK_START(); + for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + if ((i >= itopk_size) && (index & index_msb_1_mask)) { + // Remove nodes kicked out of the itopk list from the traversed hash table. + hashmap::remove( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index & ~index_msb_1_mask); + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } else { + // Restore visited hashmap by putting nodes on result buffer in it. + index &= ~index_msb_1_mask; + hashmap::insert(local_visited_hashmap_ptr, visited_hash_bitlen, index); + } + } + // Initialize buffer for compute_distance_to_child_nodes. + if (threadIdx.x == blockDim.x - 1) { result_position[0] = result_buffer_size_32; } + __syncthreads(); + + // Compute the norms between child nodes and query node using JIT version + compute_distance_to_child_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + visited_hash_bitlen, + local_traversed_hashmap_ptr, + traversed_hash_bitlen, + parent_indices_buffer, + result_indices_buffer, + 1, + result_position, + result_buffer_size_32); + __syncthreads(); + + // Check the state of the nodes in the result buffer which were not updated + // by the compute_distance_to_child_nodes above, and if it cannot be used as + // a parent node, it is deactivated. 
+ for (uint32_t i = threadIdx.x; i < result_position[0]; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index || index & index_msb_1_mask) { continue; } + if (hashmap::search(local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (unsigned p = threadIdx.x; p < 1; p += blockDim.x) { + if (parent_indices_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset_ptr != nullptr ? &filter_data : nullptr)) { + // If the parent must not be in the resulting top-k list, remove from the parent list + result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_indices_buffer[p]] = invalid_index; + } + } + } + __syncthreads(); + + iter++; + } + + // Filtering - use extern sample_filter function (linked via JIT LTO) + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { + INDEX_T index = result_indices_buffer[i]; + if (index == invalid_index) { continue; } + index &= ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(index), + bitset_ptr != nullptr ? 
&filter_data : nullptr)) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + __syncthreads(); + + // Output search results (1st warp only). + if (threadIdx.x < 32) { + uint32_t offset = 0; + for (uint32_t i = threadIdx.x; i < result_buffer_size_32; i += 32) { + INDEX_T index = result_indices_buffer[i]; + bool is_valid = false; + if (index != invalid_index) { + if (index & index_msb_1_mask) { + is_valid = true; + index &= ~index_msb_1_mask; + } else if ((offset < itopk_size) && + hashmap::insert( + local_traversed_hashmap_ptr, traversed_hash_bitlen, index)) { + // If a node that is not used as a parent can be inserted into + // the traversed hash table, it is considered a valid result. + is_valid = true; + } + } + const auto mask = __ballot_sync(0xffffffff, is_valid); + if (is_valid) { + const auto j = offset + __popc(mask & ((1 << threadIdx.x) - 1)); + if (j < itopk_size) { + uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = index & ~index_msb_1_mask; + if (result_distances_ptr != nullptr) { + DISTANCE_T dist = result_distances_buffer[i]; + result_distances_ptr[k] = dist; + } + } else { + // If it is valid and registered in the traversed hash table but is + // not output as a result, it is removed from the hash table. + hashmap::remove(local_traversed_hashmap_ptr, traversed_hash_bitlen, index); + } + } + offset += __popc(mask); + } + // If the number of outputs is insufficient, fill in with invalid results. 
+ for (uint32_t i = offset + threadIdx.x; i < itopk_size; i += 32) { + uint32_t k = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); + result_indices_ptr[k] = invalid_index; + if (result_distances_ptr != nullptr) { + result_distances_ptr[k] = utils::get_max_value(); + } + } + } + + if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } + +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) && + ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", pickup_parents, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_pickup_parents, + clk_compute_distance); + } +#endif +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in new file mode 100644 index 0000000000..1501a22ff5 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_kernel.cu.in @@ -0,0 +1,14 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +template __global__ void search_kernel_jit<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@, @source_index_type@>( + @index_type@* const, @distance_type@* const, cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, const @data_type@* const, const @index_type@* const, const std::uint32_t, const std::uint32_t, const @source_index_type@*, const unsigned, const uint64_t, const @index_type@*, const std::uint32_t, const std::uint32_t, @index_type@* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, std::uint32_t* const, const std::uint32_t, uint32_t*, @source_index_type@, @source_index_type@); + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json new file mode 100644 index 0000000000..929165330b --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_matrix.json @@ -0,0 +1,168 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + 
"_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "0", + "pq_len": "0", + "pq_prefix": "", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "8", + "pq_len": "2", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_bits": "8", + "pq_len": "4", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp new file mode 100644 index 0000000000..ab819d4bef --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_cta_planner.hpp @@ -0,0 +1,67 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include +#include + +// Use nested namespace syntax to allow inclusion from within parent namespace +namespace cuvs { +namespace neighbors { +namespace cagra { +namespace detail { +namespace multi_cta_search { + +template +struct CagraMultiCtaSearchPlanner + : CagraPlannerBase { + CagraMultiCtaSearchPlanner(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq = false, + uint32_t pq_bits = 0, + uint32_t pq_len = 0) + : CagraPlannerBase( + build_entrypoint_name(metric, team_size, dataset_block_dim, is_vpq, pq_bits, pq_len), + is_vpq ? make_fragment_key() + : make_fragment_key()) + { + } + + static std::string build_entrypoint_name(cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + std::string name = "search_multi_cta"; + if (is_vpq) { name += "_vpq"; } + name += "_team_size_" + std::to_string(team_size); + name += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + if (is_vpq) { name += "_" + std::to_string(pq_bits) + "pq_" + std::to_string(pq_len) + "subd"; } + return name; + } +}; + +} // namespace multi_cta_search +} // namespace detail +} // namespace cagra +} // namespace neighbors +} // namespace cuvs diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh new file mode 100644 index 0000000000..4c44c6ee6a --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_jit.cuh @@ -0,0 +1,305 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+#pragma once
+
+#include "../hashmap.hpp"
+#include "../utils.hpp"
+
+#include
+#include
+#include
+
+#include "../../jit_lto_kernels/filter_data.h"
+#include "extern_device_functions.cuh"
+
+namespace cuvs::neighbors::cagra::detail::multi_kernel_search {
+
+// Detection trait: `value` is true when the expression `U::kPqBits` is well-formed
+// (the first `test` overload is viable); otherwise the variadic overload is chosen
+// and `value` is false.
+template
+struct has_kpq_bits {
+  template
+  static auto test(int) -> decltype(U::kPqBits, std::true_type{});
+  template
+  static std::false_type test(...);
+  static constexpr bool value = decltype(test(0))::value;
+};
+
+// Convenience variable template for has_kpq_bits.
+template
+inline constexpr bool has_kpq_bits_v = has_kpq_bits::value;
+
+// Picks one starting node per (query, team) slot.
+//
+// blockIdx.y selects the query and each team of (1 << team_size_bits) threads owns one
+// pickup slot. The team evaluates `num_distilation` candidate seeds and keeps the one
+// with the smallest distance. A candidate is taken from `seed_ptr` when one exists for
+// this slot and is generated with xorshift64 otherwise. The first thread of the team
+// then inserts the winner into the visited hashmap: on success the winner is written
+// to the result buffers, otherwise the slot is filled with max-value sentinels.
+template
+RAFT_KERNEL random_pickup_kernel_jit(
+  dataset_descriptor_base_t* dataset_desc,
+  const DataT* const queries_ptr,  // [num_queries, dataset_dim]
+  const std::size_t num_pickup,
+  const unsigned num_distilation,
+  const uint64_t rand_xor_mask,
+  const IndexT* seed_ptr,  // [num_queries, num_seeds]
+  const uint32_t num_seeds,
+  IndexT* const result_indices_ptr,       // [num_queries, ldr]
+  DistanceT* const result_distances_ptr,  // [num_queries, ldr]
+  const std::uint32_t ldr,                // (*) ldr >= num_pickup
+  IndexT* const visited_hashmap_ptr,      // [num_queries, 1 << bitlen]
+  const std::uint32_t hash_bitlen)
+{
+  using DATA_T     = DataT;
+  using INDEX_T    = IndexT;
+  using DISTANCE_T = DistanceT;
+
+  // Get team_size_bits directly from base descriptor
+  uint32_t team_size_bits = dataset_desc->team_size_bitshift();
+
+  const auto ldb               = hashmap::get_size(hash_bitlen);
+  const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) >> team_size_bits;
+  const uint32_t query_id      = blockIdx.y;
+  // NOTE(review): threads of out-of-range teams return here, i.e. before the
+  // setup_workspace / __syncthreads() calls below, although the comment below says
+  // setup_workspace must be called by ALL threads. The sibling kernel
+  // compute_distance_to_child_nodes_kernel_jit performs its range check after the
+  // barrier. Confirm that this ordering is intentional.
+  if (global_team_index >= num_pickup) { return; }
+  extern __shared__ uint8_t smem[];
+
+  // Set smem working buffer using unified setup_workspace.
+  // setup_workspace copies the descriptor to shared memory and returns base pointer to
+  // the smem descriptor.
+  // NOTE: setup_workspace must be called by ALL threads (it uses __syncthreads())
+  dataset_descriptor_base_t* smem_desc =
+    setup_workspace(dataset_desc, smem, queries_ptr, query_id);
+  __syncthreads();
+
+  // Load args once for better performance (avoid repeated loads in the loop)
+  using args_t = typename cuvs::neighbors::cagra::detail::
+    dataset_descriptor_base_t::args_t;
+  args_t args         = smem_desc->args.load();
+  IndexT dataset_size = smem_desc->size;
+
+  // NOTE(review): best_index_team_local is only assigned inside the loop below, so it is
+  // read uninitialized by hashmap::insert when num_distilation == 0 (or when no
+  // candidate compares below the initial best distance). Confirm callers rule this out.
+  INDEX_T best_index_team_local;
+  DISTANCE_T best_norm2_team_local = utils::get_max_value();
+  for (unsigned i = 0; i < num_distilation; i++) {
+    INDEX_T seed_index;
+    if (seed_ptr && (global_team_index < num_seeds)) {
+      seed_index = seed_ptr[global_team_index + (num_seeds * query_id)];
+    } else {
+      // Choose a seed node randomly
+      seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size;
+    }
+
+    // CRITICAL: ALL threads in the team must participate in compute_distance and team_sum
+    // Otherwise warp shuffles will hang. Each thread calls the unified extern function to get
+    // its per-thread distance, then team_sum reduces across all threads in the team.
+    DistanceT per_thread_norm2 = 0;
+    // Use unified compute_distance function (planner links standard or VPQ fragment at runtime)
+    per_thread_norm2 = compute_distance(args, seed_index);
+    // Now ALL threads in the team participate in team_sum
+    const auto norm2 = device::team_sum(per_thread_norm2, team_size_bits);
+
+    if (norm2 < best_norm2_team_local) {
+      best_norm2_team_local = norm2;
+      best_index_team_local = seed_index;
+    }
+  }
+
+  const auto store_gmem_index = global_team_index + (ldr * query_id);
+  // Only the first thread of each team publishes the result.
+  if ((threadIdx.x & ((1u << team_size_bits) - 1u)) == 0) {
+    if (hashmap::insert(
+          visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) {
+      result_distances_ptr[store_gmem_index] = best_norm2_team_local;
+      result_indices_ptr[store_gmem_index]   = best_index_team_local;
+    } else {
+      // Insertion failed: mark the slot with max-value sentinels.
+      result_distances_ptr[store_gmem_index] = utils::get_max_value();
+      result_indices_ptr[store_gmem_index]   = utils::get_max_value();
+    }
+  }
+}
+
+// Computes the distance from the query to the children of the current parent nodes.
+//
+// blockIdx.y selects the query and each team owns one (parent, neighbor) slot:
+// `global_team_id / graph_degree` selects the parent entry in `parent_node_list` and
+// `global_team_id % graph_degree` selects the neighbor in the graph row. A child is
+// evaluated only when hashmap::insert reports it as newly inserted; otherwise the
+// distance slot is set to the max value. When a sample filter is compiled in (the
+// `if constexpr` branch at the end), a parent that the filter rejects has its
+// candidate entry and its distance overwritten with max-value sentinels.
+template
+RAFT_KERNEL compute_distance_to_child_nodes_kernel_jit(
+  const IndexT* const parent_node_list,  // [num_queries, search_width]
+  IndexT* const parent_candidates_ptr,   // [num_queries, search_width]
+  DistanceT* const parent_distance_ptr,  // [num_queries, search_width]
+  const std::size_t lds,
+  const std::uint32_t search_width,
+  dataset_descriptor_base_t* dataset_desc,
+  const IndexT* const neighbor_graph_ptr,  // [dataset_size, graph_degree]
+  const std::uint32_t graph_degree,
+  const SourceIndexT* source_indices_ptr,
+  const DataT* query_ptr,             // [num_queries, data_dim]
+  IndexT* const visited_hashmap_ptr,  // [num_queries, 1 << hash_bitlen]
+  const std::uint32_t hash_bitlen,
+  IndexT* const result_indices_ptr,       // [num_queries, ldd]
+  DistanceT* const result_distances_ptr,  // [num_queries, ldd]
+  const std::uint32_t ldd,                // (*) ldd >= search_width * graph_degree
+  SAMPLE_FILTER_T sample_filter)
+{
+  using INDEX_T    = IndexT;
+  using DISTANCE_T = DistanceT;
+
+  // Get team_size_bits directly from base descriptor
+  uint32_t team_size_bits = dataset_desc->team_size_bitshift();
+
+  const auto team_size      = 1u << team_size_bits;
+  const uint32_t ldb        = hashmap::get_size(hash_bitlen);
+  const auto tid            = threadIdx.x + blockDim.x * blockIdx.x;
+  const auto global_team_id = tid >> team_size_bits;
+  const auto query_id       = blockIdx.y;
+
+  extern __shared__ uint8_t smem[];
+  // Load a query using unified setup_workspace.
+  // setup_workspace copies the descriptor to shared memory and returns base pointer to
+  // the smem descriptor.
+  // NOTE: setup_workspace must be called by ALL threads (it uses __syncthreads())
+  dataset_descriptor_base_t* smem_desc =
+    setup_workspace(dataset_desc, smem, query_ptr, query_id);
+
+  __syncthreads();
+  // All early returns below depend only on values shared by the whole team
+  // (global_team_id, query_id), so a team always leaves as a unit.
+  if (global_team_id >= search_width * graph_degree) { return; }
+
+  const std::size_t parent_list_index =
+    parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)];
+
+  if (parent_list_index == utils::get_max_value()) { return; }
+
+  constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value;
+  const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)];
+
+  if (raw_parent_index == utils::get_max_value()) {
+    result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value();
+    return;
+  }
+  // Strip the flag bit stored in the most significant bit of the candidate entry.
+  const auto parent_index = raw_parent_index & ~index_msb_1_mask;
+
+  const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index);
+
+  const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree];
+
+  const auto compute_distance_flag = hashmap::insert(
+    team_size, visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id);
+
+  // Load args once for better performance (avoid repeated loads)
+  using args_t = typename cuvs::neighbors::cagra::detail::
+    dataset_descriptor_base_t::args_t;
+  args_t args = smem_desc->args.load();
+
+  // CRITICAL: ALL threads in the team must participate in compute_distance and team_sum
+  // Otherwise warp shuffles will hang. Each thread calls the unified extern function to get
+  // its per-thread distance, then team_sum reduces across all threads in the team.
+  DISTANCE_T per_thread_norm2 = 0;
+  if (compute_distance_flag) {
+    // Use unified compute_distance function (planner links standard or VPQ fragment at runtime)
+    per_thread_norm2 = compute_distance(args, child_id);
+  }
+  // Now ALL threads in the team participate in team_sum
+  DISTANCE_T norm2 = device::team_sum(per_thread_norm2, team_size_bits);
+
+  // Only the first thread of each team writes the result slot.
+  if (compute_distance_flag) {
+    if ((threadIdx.x & (team_size - 1)) == 0) {
+      result_indices_ptr[ldd * blockIdx.y + global_team_id]   = child_id;
+      result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2;
+    }
+  } else {
+    if ((threadIdx.x & (team_size - 1)) == 0) {
+      result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value();
+    }
+  }
+
+  if constexpr (!std::is_same::value) {
+    // The filter is evaluated on the source index when a mapping is supplied, and on
+    // the parent index itself otherwise.
+    if (!sample_filter(
+          query_id,
+          source_indices_ptr == nullptr ? parent_index : source_indices_ptr[parent_index])) {
+      parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value();
+      parent_distance_ptr[parent_list_index + (lds * query_id)] =
+        utils::get_max_value();
+    }
+  }
+}
+
+using cuvs::neighbors::detail::sample_filter;
+// Applies the sample filter to the result buffer, one thread per (query, slot) pair.
+//
+// Slot `i` of query `j` lives at `i + j * lds`. Entries whose index differs from
+// `~index_msb_1_mask` are passed to sample_filter; rejected entries have both their
+// index and their distance replaced by max-value sentinels.
+template
+RAFT_KERNEL apply_filter_kernel_jit(
+  const SourceIndexT* source_indices_ptr,  // [num_queries, search_width]
+  IndexT* const result_indices_ptr,
+  DistanceT* const result_distances_ptr,
+  const std::size_t lds,
+  const std::uint32_t result_buffer_size,
+  const std::uint32_t num_queries,
+  const IndexT query_id_offset,
+  uint32_t* bitset_ptr,         // Bitset data pointer (nullptr for none_filter) - in global memory
+  SourceIndexT bitset_len,      // Bitset length
+  SourceIndexT original_nbits)  // Original number of bits
+{
+  constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value;
+  const auto tid                    = threadIdx.x + blockIdx.x * blockDim.x;
+  if (tid >= result_buffer_size * num_queries) { return; }
+  const auto i     = tid % result_buffer_size;
+  const auto j     = tid / result_buffer_size;
+  const auto index = i + j * lds;
+
+  if (result_indices_ptr[index] != ~index_msb_1_mask) {
+    // Use extern sample_filter function with 3 params: query_id, node_id, filter_data
+    // filter_data is a void* pointer to bitset_filter_data_t (or nullptr for none_filter)
+    SourceIndexT node_id = source_indices_ptr == nullptr
+                             ? static_cast(result_indices_ptr[index])
+                             : source_indices_ptr[result_indices_ptr[index]];
+
+    // Construct filter_data struct in registers (bitset data is in global memory)
+    cuvs::neighbors::detail::bitset_filter_data_t filter_data(
+      bitset_ptr, bitset_len, original_nbits);
+
+    if (!sample_filter(
+          query_id_offset + j, node_id, bitset_ptr != nullptr ? &filter_data : nullptr)) {
+      result_indices_ptr[index]   = utils::get_max_value();
+      result_distances_ptr[index] = utils::get_max_value();
+    }
+  }
+}
+
+}  // namespace cuvs::neighbors::cagra::detail::multi_kernel_search
diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp
new file mode 100644
index 0000000000..6d1ac46e35
--- /dev/null
+++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_multi_kernel_planner.hpp
@@ -0,0 +1,72 @@
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION.
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include +#include + +// Use nested namespace syntax to allow inclusion from within parent namespace +namespace cuvs::neighbors::cagra::detail { +namespace multi_kernel_search { + +template +struct CagraMultiKernelSearchPlanner + : CagraPlannerBase { + CagraMultiKernelSearchPlanner(cuvs::distance::DistanceType metric, + const std::string& kernel_name, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq = false, + uint32_t pq_bits = 0, + uint32_t pq_len = 0) + : CagraPlannerBase( + build_entrypoint_name( + kernel_name, metric, team_size, dataset_block_dim, is_vpq, pq_bits, pq_len), + // Special case: apply_filter_kernel doesn't use DataTag, only IndexTag, DistanceTag, + // SourceIndexTag + (kernel_name == "apply_filter_kernel") + ? make_fragment_key() + : (is_vpq + ? make_fragment_key() + : make_fragment_key())) + { + } + + private: + static std::string build_entrypoint_name(const std::string& kernel_name, + cuvs::distance::DistanceType metric, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len) + { + if (kernel_name == "apply_filter_kernel") { return kernel_name; } + + std::string name = kernel_name; + if (is_vpq) { name += "_vpq"; } + name += "_team_size_" + std::to_string(team_size); + name += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + if (is_vpq) { name += "_" + std::to_string(pq_bits) + "pq_" + std::to_string(pq_len) + "subd"; } + return name; + } +}; + +} // namespace multi_kernel_search +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh new file mode 100644 index 0000000000..a5af6f1295 --- /dev/null +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_device_helpers.cuh @@ -0,0 +1,663 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only includes - no host-side headers +#include "../bitonic.hpp" +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../utils.hpp" + +#include +#include + +#include + +#include +#include +#include +#include + +#include +#include // For uint4 + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Constants for persistent kernels +constexpr size_t kCacheLineBytes = 64; +constexpr uint32_t kMaxJobsNum = 8192; + +// Worker handle for persistent kernels +struct alignas(kCacheLineBytes) worker_handle_t { + using handle_t = uint64_t; + struct value_t { + uint32_t desc_id; + uint32_t query_id; + }; + union data_t { + handle_t handle; + value_t value; + }; + cuda::atomic data; +}; +static_assert(sizeof(worker_handle_t::value_t) == sizeof(worker_handle_t::handle_t)); +static_assert( + cuda::atomic::is_always_lock_free); + +constexpr worker_handle_t::handle_t kWaitForWork = std::numeric_limits::max(); +constexpr worker_handle_t::handle_t kNoMoreWork = kWaitForWork - 1; + +// Job descriptor for persistent kernels +template +struct alignas(kCacheLineBytes) job_desc_t { + using index_type = typename DATASET_DESCRIPTOR_T::INDEX_T; + using distance_type = typename DATASET_DESCRIPTOR_T::DISTANCE_T; + using data_type = typename DATASET_DESCRIPTOR_T::DATA_T; + // The algorithm input parameters + struct value_t { + uintptr_t result_indices_ptr; // [num_queries, top_k] + distance_type* result_distances_ptr; // [num_queries, top_k] + const data_type* queries_ptr; // [num_queries, dataset_dim] + uint32_t top_k; + uint32_t n_queries; + }; + using blob_elem_type = uint4; + constexpr static inline size_t kBlobSize = + raft::div_rounding_up_safe(sizeof(value_t), sizeof(blob_elem_type)); + // Union 
facilitates loading the input by a warp in a single request + union input_t { + blob_elem_type blob[kBlobSize]; // NOLINT + value_t value; + } input; + // Last thread triggers this flag. + cuda::atomic completion_flag; +}; + +// Pick up next parent nodes from the internal topk list +template +RAFT_DEVICE_INLINE_FUNCTION void pickup_next_parents(std::uint32_t* const terminate_flag, + INDEX_T* const next_parent_indices, + INDEX_T* const internal_topk_indices, + const std::size_t internal_topk_size, + const std::uint32_t search_width) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + + for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { + next_parent_indices[i] = utils::get_max_value(); + } + std::uint32_t itopk_max = internal_topk_size; + if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } + std::uint32_t num_new_parents = 0; + for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { + std::uint32_t jj = j; + if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } + INDEX_T index; + int new_parent = 0; + if (j < internal_topk_size) { + index = internal_topk_indices[jj]; + if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set + new_parent = 1; + } + } + const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); + if (new_parent) { + const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; + if (i < search_width) { + next_parent_indices[i] = jj; + // set most significant bit as used node + internal_topk_indices[jj] |= index_msb_1_mask; + } + } + num_new_parents += __popc(ballot_mask); + if (num_new_parents >= search_width) { break; } + } + if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } +} + +// Helper function for bitonic sort and full +template +RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full( + float* candidate_distances, // [num_candidates] + IdxT* candidate_indices, // [num_candidates] + const std::uint32_t 
num_candidates,
+  const std::uint32_t num_itopk)
+{
+  // Sorts the candidate buffer in place. The sorted entries are written back through
+  // device::swizzling(), and only positions below both num_candidates and num_itopk
+  // are stored.
+  const unsigned lane_id = threadIdx.x % raft::warp_size();
+  const unsigned warp_id = threadIdx.x / raft::warp_size();
+  static_assert(MAX_CANDIDATES <= 256);
+  if constexpr (!MULTI_WARPS) {
+    // Single-warp path: only warp 0 works, all other warps return immediately.
+    if (warp_id > 0) { return; }
+    constexpr unsigned N = (MAX_CANDIDATES + (raft::warp_size() - 1)) / raft::warp_size();
+    float key[N];
+    IdxT val[N];
+    /* Candidates -> Reg */
+    for (unsigned i = 0; i < N; i++) {
+      unsigned j = lane_id + (raft::warp_size() * i);
+      if (j < num_candidates) {
+        key[i] = candidate_distances[j];
+        val[i] = candidate_indices[j];
+      } else {
+        // Pad slots beyond num_candidates with max-value sentinels.
+        key[i] = utils::get_max_value();
+        val[i] = utils::get_max_value();
+      }
+    }
+    /* Sort */
+    bitonic::warp_sort(key, val);
+    /* Reg -> Temp_itopk */
+    for (unsigned i = 0; i < N; i++) {
+      unsigned j = (N * lane_id) + i;
+      if (j < num_candidates && j < num_itopk) {
+        candidate_distances[device::swizzling(j)] = key[i];
+        candidate_indices[device::swizzling(j)]   = val[i];
+      }
+    }
+  } else {
+    assert(blockDim.x >= 64);
+    // Use two warps (64 threads)
+    constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2;
+    static_assert(max_candidates_per_warp <= 128);
+    constexpr unsigned N = (max_candidates_per_warp + (raft::warp_size() - 1)) / raft::warp_size();
+    float key[N];
+    IdxT val[N];
+    if (warp_id < 2) {
+      /* Candidates -> Reg */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned jl = lane_id + (raft::warp_size() * i);
+        unsigned j  = jl + (max_candidates_per_warp * warp_id);
+        if (j < num_candidates) {
+          key[i] = candidate_distances[j];
+          val[i] = candidate_indices[j];
+        } else {
+          key[i] = utils::get_max_value();
+          val[i] = utils::get_max_value();
+        }
+      }
+      /* Sort */
+      bitonic::warp_sort(key, val);
+      /* Reg -> Temp_candidates */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned jl = (N * lane_id) + i;
+        unsigned j  = jl + (max_candidates_per_warp * warp_id);
+        if (j < num_candidates && jl < num_itopk) {
+          candidate_distances[device::swizzling(j)] = key[i];
+          candidate_indices[device::swizzling(j)]   = val[i];
+        }
+      }
+    }
+    __syncthreads();
+
+    unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp;
+    if (warp_id < num_warps_used) {
+      /* Temp_candidates -> Reg */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned jl = (N * lane_id) + i;
+        unsigned kl = max_candidates_per_warp - 1 - jl;
+        unsigned j  = jl + (max_candidates_per_warp * warp_id);
+        unsigned k  = MAX_CANDIDATES - 1 - j;
+        if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue;
+        float temp_key = candidate_distances[device::swizzling(k)];
+        if (key[i] == temp_key) continue;
+        // Compare with the mirrored element k: warp 0 keeps the smaller key of the
+        // pair, any other warp keeps the larger one.
+        if ((warp_id == 0) == (key[i] > temp_key)) {
+          key[i] = temp_key;
+          val[i] = candidate_indices[device::swizzling(k)];
+        }
+      }
+    }
+    if (num_warps_used > 1) { __syncthreads(); }
+    if (warp_id < num_warps_used) {
+      /* Merge */
+      bitonic::warp_merge(key, val, raft::warp_size());
+      /* Reg -> Temp_itopk */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned jl = (N * lane_id) + i;
+        unsigned j  = jl + (max_candidates_per_warp * warp_id);
+        if (j < num_candidates && j < num_itopk) {
+          candidate_distances[device::swizzling(j)] = key[i];
+          candidate_indices[device::swizzling(j)]   = val[i];
+        }
+      }
+    }
+    if (num_warps_used > 1) { __syncthreads(); }
+  }
+}
+
+// Wrapper functions to avoid pre-inlining (impacts register pressure).
+// Each one forwards to topk_by_bitonic_sort_and_full<SIZE, false, uint32_t> for a
+// fixed SIZE of 64, 128 or 256.
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_64_false(
+  float* candidate_distances,        // [num_candidates]
+  std::uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  const std::uint32_t num_itopk)
+{
+  topk_by_bitonic_sort_and_full<64, false, uint32_t>(
+    candidate_distances, candidate_indices, num_candidates, num_itopk);
+}
+
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_128_false(
+  float* candidate_distances,        // [num_candidates]
+  std::uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  const std::uint32_t num_itopk)
+{
+  topk_by_bitonic_sort_and_full<128, false, uint32_t>(
+    candidate_distances, candidate_indices, num_candidates, num_itopk);
+}
+
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_full_wrapper_256_false(
+  float* candidate_distances,        // [num_candidates]
+  std::uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  const std::uint32_t num_itopk)
+{
+  topk_by_bitonic_sort_and_full<256, false, uint32_t>(
+    candidate_distances, candidate_indices, num_candidates, num_itopk);
+}
+
+// TopK by bitonic sort and merge (template version with MAX_ITOPK).
+// Merges the candidate buffer into the itopk buffer. When `first` is true the itopk
+// buffer is read with plain indexing and sorted before merging; otherwise it is read
+// through device::swizzling(). Results are always stored through device::swizzling().
+template
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge(
+  float* itopk_distances,  // [num_itopk]
+  IdxT* itopk_indices,     // [num_itopk]
+  const std::uint32_t num_itopk,
+  float* candidate_distances,  // [num_candidates]
+  IdxT* candidate_indices,     // [num_candidates]
+  const std::uint32_t num_candidates,
+  std::uint32_t* work_buf,
+  const bool first)
+{
+  const unsigned lane_id = threadIdx.x % raft::warp_size();
+  const unsigned warp_id = threadIdx.x / raft::warp_size();
+
+  static_assert(MAX_ITOPK <= 512);
+  if constexpr (!MULTI_WARPS) {
+    static_assert(MAX_ITOPK <= 256);
+    // Single-warp path: only warp 0 works.
+    if (warp_id > 0) { return; }
+    constexpr unsigned N = (MAX_ITOPK + (raft::warp_size() - 1)) / raft::warp_size();
+    float key[N];
+    IdxT val[N];
+    if (first) {
+      /* Load itopk results */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned j = lane_id + (raft::warp_size() * i);
+        if (j < num_itopk) {
+          key[i] = itopk_distances[j];
+          val[i] = itopk_indices[j];
+        } else {
+          key[i] = utils::get_max_value();
+          val[i] = utils::get_max_value();
+        }
+      }
+      /* Warp Sort */
+      bitonic::warp_sort(key, val);
+    } else {
+      /* Load itopk results */
+      for (unsigned i = 0; i < N; i++) {
+        unsigned j = (N * lane_id) + i;
+        if (j < num_itopk) {
+          key[i] = itopk_distances[device::swizzling(j)];
+          val[i] = itopk_indices[device::swizzling(j)];
+        } else {
+          key[i] = utils::get_max_value();
+          val[i] = utils::get_max_value();
+        }
+      }
+    }
+    /* Merge candidates */
+    for (unsigned i = 0; i < N; i++) {
+      unsigned j = (N * lane_id) + i;  // [0:max_itopk-1]
+      // Pair itopk position j with the mirrored candidate position k and keep the
+      // smaller key of the two.
+      unsigned k = MAX_ITOPK - 1 - j;
+      if (k >= num_itopk || k >= num_candidates) continue;
+      float candidate_key = candidate_distances[device::swizzling(k)];
+      if (key[i] > candidate_key) {
+        key[i] = candidate_key;
+        val[i] = candidate_indices[device::swizzling(k)];
+      }
+    }
+    /* Warp Merge */
+    bitonic::warp_merge(key, val, raft::warp_size());
+    /* Store new itopk results */
+    for (unsigned i = 0; i < N; i++) {
+      unsigned j = (N * lane_id) + i;
+      if (j < num_itopk) {
+        itopk_distances[device::swizzling(j)] = key[i];
+        itopk_indices[device::swizzling(j)]   = val[i];
+      }
+    }
+  } else {
+    static_assert(MAX_ITOPK == 512);
+    assert(blockDim.x >= 64);
+    // Use two warps (64 threads) or more
+    constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2;
+    constexpr unsigned N = (max_itopk_per_warp + (raft::warp_size() - 1)) / raft::warp_size();
+    float key[N];
+    IdxT val[N];
+    if (first) {
+      /* Load itopk results (not sorted) */
+      if (warp_id < 2) {
+        for (unsigned i = 0; i < N; i++) {
+          unsigned j = lane_id + (raft::warp_size() * i) + (max_itopk_per_warp * warp_id);
+          if (j < num_itopk) {
+            key[i] = itopk_distances[j];
+            val[i] = itopk_indices[j];
+          } else {
+            key[i] = utils::get_max_value();
+            val[i] = utils::get_max_value();
+          }
+        }
+        /* Warp Sort */
+        bitonic::warp_sort(key, val);
+        /* Store intermediate results */
+        for (unsigned i = 0; i < N; i++) {
+          unsigned j = (N * threadIdx.x) + i;
+          if (j >= num_itopk) continue;
+          itopk_distances[device::swizzling(j)] = key[i];
+          itopk_indices[device::swizzling(j)]   = val[i];
+        }
+      }
+      __syncthreads();
+      if (warp_id < 2) {
+        /* Load intermediate results */
+        for (unsigned i = 0; i < N; i++) {
+          unsigned j = (N * threadIdx.x) + i;
+          unsigned k = MAX_ITOPK - 1 - j;
+          if (k >= num_itopk) continue;
+          float temp_key = itopk_distances[device::swizzling(k)];
+          if (key[i] == temp_key) continue;
+          // Warp 0 keeps the smaller key of the mirrored pair, warp 1 the larger one.
+          if ((warp_id == 0) == (key[i] > temp_key)) {
+            key[i] = temp_key;
+            val[i] = itopk_indices[device::swizzling(k)];
+          }
+        }
+        /* Warp Merge */
+        bitonic::warp_merge(key, val, raft::warp_size());
+      }
+      __syncthreads();
+      /* Store itopk results (sorted) */
+      if (warp_id < 2) {
+        for (unsigned i = 0; i < N; i++) {
+          unsigned j = (N * threadIdx.x) + i;
+          if (j >= num_itopk) continue;
+          itopk_distances[device::swizzling(j)] = key[i];
+          itopk_indices[device::swizzling(j)]   = val[i];
+        }
+      }
+    }
+    const uint32_t num_itopk_div2 = num_itopk / 2;
+    if (threadIdx.x < 3) {
+      // work_buf is used to obtain turning points in 1st and 2nd half of itopk after merge.
+      // All three slots start at num_itopk_div2 and are lowered with atomicMin below.
+      work_buf[threadIdx.x] = num_itopk_div2;
+    }
+    __syncthreads();
+
+    // Merge candidates (using whole threads)
+    for (unsigned k = threadIdx.x; k < (num_candidates < num_itopk ? num_candidates : num_itopk);
+         k += blockDim.x) {
+      const unsigned j          = num_itopk - 1 - k;
+      const float itopk_key     = itopk_distances[device::swizzling(j)];
+      const float candidate_key = candidate_distances[device::swizzling(k)];
+      if (itopk_key > candidate_key) {
+        itopk_distances[device::swizzling(j)] = candidate_key;
+        itopk_indices[device::swizzling(j)]   = candidate_indices[device::swizzling(k)];
+        // Record the smallest replaced position: slot 2 for positions in the first
+        // half, slot 1 (relative to the half) for positions in the second half.
+        if (j < num_itopk_div2) {
+          atomicMin(work_buf + 2, j);
+        } else {
+          atomicMin(work_buf + 1, j - num_itopk_div2);
+        }
+      }
+    }
+    __syncthreads();
+
+    // Merge 1st and 2nd half of itopk (using whole threads)
+    for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) {
+      const unsigned k = j + num_itopk_div2;
+      float key_0      = itopk_distances[device::swizzling(j)];
+      float key_1      = itopk_distances[device::swizzling(k)];
+      if (key_0 > key_1) {
+        itopk_distances[device::swizzling(j)] = key_1;
+        itopk_distances[device::swizzling(k)] = key_0;
+        IdxT val_0                            = itopk_indices[device::swizzling(j)];
+        IdxT val_1                            = itopk_indices[device::swizzling(k)];
+        itopk_indices[device::swizzling(j)]   = val_1;
+        itopk_indices[device::swizzling(k)]   = val_0;
+        atomicMin(work_buf + 0, j);
+      }
+    }
+    if (threadIdx.x == blockDim.x - 1) {
+      if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; }
+    }
+    __syncthreads();
+    // Warp-0 merges 1st half of itopk, warp-1 does 2nd half.
+    if (warp_id < 2) {
+      // Load intermediate itopk results
+      const uint32_t turning_point = work_buf[warp_id];  // turning_point <= num_itopk_div2
+      for (unsigned i = 0; i < N; i++) {
+        unsigned k = num_itopk;
+        unsigned j = (N * lane_id) + i;
+        if (j < turning_point) {
+          k = j + (num_itopk_div2 * warp_id);
+        } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) {
+          j -= (MAX_ITOPK / 2 - num_itopk_div2);
+          if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); }
+        }
+        // k == num_itopk means "no source element": the register slot is padded.
+        if (k < num_itopk) {
+          key[i] = itopk_distances[device::swizzling(k)];
+          val[i] = itopk_indices[device::swizzling(k)];
+        } else {
+          key[i] = utils::get_max_value();
+          val[i] = utils::get_max_value();
+        }
+      }
+      /* Warp Merge */
+      bitonic::warp_merge(key, val, raft::warp_size());
+      /* Store new itopk results */
+      for (unsigned i = 0; i < N; i++) {
+        const unsigned j = (N * lane_id) + i;
+        if (j < num_itopk_div2) {
+          unsigned k                            = j + (num_itopk_div2 * warp_id);
+          itopk_distances[device::swizzling(k)] = key[i];
+          itopk_indices[device::swizzling(k)]   = val[i];
+        }
+      }
+    }
+  }
+}
+
+// Wrapper functions to avoid pre-inlining.
+// Each one forwards to topk_by_bitonic_sort_and_merge<SIZE, false, uint32_t> for a
+// fixed SIZE of 64, 128 or 256.
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_64_false(
+  float* itopk_distances,   // [num_itopk]
+  uint32_t* itopk_indices,  // [num_itopk]
+  const std::uint32_t num_itopk,
+  float* candidate_distances,   // [num_candidates]
+  uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  std::uint32_t* work_buf,
+  const bool first)
+{
+  topk_by_bitonic_sort_and_merge<64, false, uint32_t>(itopk_distances,
+                                                      itopk_indices,
+                                                      num_itopk,
+                                                      candidate_distances,
+                                                      candidate_indices,
+                                                      num_candidates,
+                                                      work_buf,
+                                                      first);
+}
+
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_128_false(
+  float* itopk_distances,   // [num_itopk]
+  uint32_t* itopk_indices,  // [num_itopk]
+  const std::uint32_t num_itopk,
+  float* candidate_distances,   // [num_candidates]
+  uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  std::uint32_t* work_buf,
+  const bool first)
+{
+  topk_by_bitonic_sort_and_merge<128, false, uint32_t>(itopk_distances,
+                                                       itopk_indices,
+                                                       num_itopk,
+                                                       candidate_distances,
+                                                       candidate_indices,
+                                                       num_candidates,
+                                                       work_buf,
+                                                       first);
+}
+
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge_wrapper_256_false(
+  float* itopk_distances,   // [num_itopk]
+  uint32_t* itopk_indices,  // [num_itopk]
+  const std::uint32_t num_itopk,
+  float* candidate_distances,   // [num_candidates]
+  uint32_t* candidate_indices,  // [num_candidates]
+  const std::uint32_t num_candidates,
+  std::uint32_t* work_buf,
+  const bool first)
+{
+  topk_by_bitonic_sort_and_merge<256, false, uint32_t>(itopk_distances,
+                                                       itopk_indices,
+                                                       num_itopk,
+                                                       candidate_distances,
+                                                       candidate_indices,
+                                                       num_candidates,
+                                                       work_buf,
+                                                       first);
+}
+
+// TopK by bitonic sort and merge (runtime version).
+// Dispatches on the runtime bounds `max_candidates` and `max_itopk`: the candidates
+// are sorted first with the matching *_full wrapper, then merged into the itopk list
+// with the matching *_merge wrapper (single-warp case) or with the 512-entry
+// instantiation (multi-warp case).
+template
+RAFT_DEVICE_INLINE_FUNCTION void topk_by_bitonic_sort_and_merge(
+  float* itopk_distances,  // [num_itopk]
+  IdxT* itopk_indices,     // [num_itopk]
+  const std::uint32_t max_itopk,
+  const std::uint32_t num_itopk,
+  float* candidate_distances,  // [num_candidates]
+  IdxT* candidate_indices,     // [num_candidates]
+  const std::uint32_t max_candidates,
+  const std::uint32_t num_candidates,
+  std::uint32_t* work_buf,
+  const bool first)
+{
+  static_assert(std::is_same_v);
+  assert(max_itopk <= 512);
+  assert(max_candidates <= 256);
+  assert(!MULTI_WARPS || blockDim.x >= 64);
+
+  // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_full
+  // function (vs post-inlining, this impacts register pressure)
+  if (max_candidates <= 64) {
+    topk_by_bitonic_sort_and_full_wrapper_64_false(
+      candidate_distances, candidate_indices, num_candidates, num_itopk);
+  } else if (max_candidates <= 128) {
+    topk_by_bitonic_sort_and_full_wrapper_128_false(
+      candidate_distances, candidate_indices, num_candidates, num_itopk);
+  } else {
+    topk_by_bitonic_sort_and_full_wrapper_256_false(
+      candidate_distances, candidate_indices, num_candidates, num_itopk);
+  }
+
+  if constexpr (!MULTI_WARPS) {
+    assert(max_itopk <= 256);
+    // use a non-template wrapper function to avoid pre-inlining the topk_by_bitonic_sort_and_merge
+    // function (vs post-inlining, this impacts register pressure)
+    if (max_itopk <= 64) {
+      topk_by_bitonic_sort_and_merge_wrapper_64_false(itopk_distances,
+                                                      itopk_indices,
+                                                      num_itopk,
+                                                      candidate_distances,
+                                                      candidate_indices,
+                                                      num_candidates,
+                                                      work_buf,
+                                                      first);
+    } else if (max_itopk <= 128) {
+      topk_by_bitonic_sort_and_merge_wrapper_128_false(itopk_distances,
+                                                       itopk_indices,
+                                                       num_itopk,
+                                                       candidate_distances,
+                                                       candidate_indices,
+                                                       num_candidates,
+                                                       work_buf,
+                                                       first);
+    } else {
+      topk_by_bitonic_sort_and_merge_wrapper_256_false(itopk_distances,
+                                                       itopk_indices,
+                                                       num_itopk,
+                                                       candidate_distances,
+                                                       candidate_indices,
+                                                       num_candidates,
+                                                       work_buf,
+                                                       first);
+    }
+  } else {
+    assert(max_itopk > 256);
+    topk_by_bitonic_sort_and_merge<512, MULTI_WARPS, uint32_t>(itopk_distances,
+                                                               itopk_indices,
+                                                               num_itopk,
+                                                               candidate_distances,
+                                                               candidate_indices,
+                                                               num_candidates,
+                                                               work_buf,
+                                                               first);
+  }
+}
+
+// This function moves the invalid index element to the end of the itopk list.
+// Requires: array_length % 32 == 0 and at most one invalid entry in the list.
+template +RAFT_DEVICE_INLINE_FUNCTION void move_invalid_to_end_of_list(IdxT* const index_array, + float* const distance_array, + const std::uint32_t array_length) +{ + constexpr std::uint32_t warp_size = 32; + constexpr std::uint32_t invalid_index = utils::get_max_value(); + const std::uint32_t lane_id = threadIdx.x % warp_size; + + if (threadIdx.x >= warp_size) { return; } + + bool found_invalid = false; + if (array_length % warp_size == 0) { + for (std::uint32_t i = lane_id; i < array_length; i += warp_size) { + const auto index = index_array[i]; + const auto distance = distance_array[i]; + + if (found_invalid) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } else { + // Check if the index is invalid + const auto I_found_invalid = (index == invalid_index); + const auto who_has_invalid = raft::ballot(I_found_invalid); + // Shift if a lower lane holds the invalid entry; a mask avoids a shift by 32 on lane 0 + if (who_has_invalid & ((1u << lane_id) - 1u)) { + index_array[i - 1] = index; + distance_array[i - 1] = distance; + } + + found_invalid = who_has_invalid; + } + } + } + if (lane_id == 0) { + index_array[array_length - 1] = invalid_index; + distance_array[array_length - 1] = utils::get_max_value(); + } +} + +template +RAFT_DEVICE_INLINE_FUNCTION void hashmap_restore(INDEX_T* const hashmap_ptr, + const size_t hashmap_bitlen, + const INDEX_T* itopk_indices, + const uint32_t itopk_size, + const uint32_t first_tid = 0) +{ + constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; + if (threadIdx.x < first_tid) return; + for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { + auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit + hashmap::insert(hashmap_ptr, hashmap_bitlen, key); + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_embedded.cpp.in 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_embedded.cpp.in new file mode 100644 index 0000000000..fbb3735454 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "search_single_cta_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh new file mode 100644 index 0000000000..9d21c13d3f --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_jit.cuh @@ -0,0 +1,734 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +// Device-only helpers - extracted from search_single_cta_kernel-inl.cuh to avoid host-side includes +#include "search_single_cta_device_helpers.cuh" + +// Additional device-side includes needed +#include "../compute_distance-ext.cuh" +#include "../device_common.hpp" +#include "../hashmap.hpp" +#include "../topk_by_radix.cuh" +#include "../utils.hpp" + +#include // For raft::shfl_xor +#include // For raft::round_up_safe +#include + +#include + +#include +#include + +#include // For assert() + +#ifdef _CLK_BREAKDOWN +#include // For printf() in debug mode +#endif + +// Include extern function declarations before namespace so they're available to kernel definitions +#include "../../jit_lto_kernels/filter_data.h" +#include "extern_device_functions.cuh" +// Include shared JIT device functions +#include "device_common_jit.cuh" + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Helper to check if DescriptorT has kPqBits (VPQ descriptor) - use shared version +// Use fully qualified name since it's a template variable +using cuvs::neighbors::cagra::detail::device::has_kpq_bits_v; + +// are defined in search_single_cta_kernel-inl.cuh which is included by the launcher. +// We don't redefine them here to avoid duplicate definitions. 
+ +// Sample filter extern function +// sample_filter is declared in extern_device_functions.cuh +using cuvs::neighbors::detail::sample_filter; + +// JIT versions of compute_distance_to_random_nodes and compute_distance_to_child_nodes +// are now shared in device_common_jit.cuh - use fully qualified names +using cuvs::neighbors::cagra::detail::device::compute_distance_to_child_nodes_jit; +using cuvs::neighbors::cagra::detail::device::compute_distance_to_random_nodes_jit; + +// JIT version of search_core - uses dataset_descriptor_base_t* pointer +// Unified template parameters: TeamSize, DatasetBlockDim, PQ_BITS, PQ_LEN, CodebookT, QueryT +// For standard descriptors: PQ_BITS=0, PQ_LEN=0, CodebookT=void, QueryT=float (or uint8_t for +// BitwiseHamming) For VPQ descriptors: PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half +template +RAFT_DEVICE_INLINE_FUNCTION void search_core( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // 
Original number of bits +{ + using LOAD_T = device::LOAD_128BIT_T; + + auto to_source_index = [source_indices_ptr](IndexT x) { + return source_indices_ptr == nullptr ? static_cast(x) : source_indices_ptr[x]; + }; + +#ifdef _CLK_BREAKDOWN + std::uint64_t clk_init = 0; + std::uint64_t clk_compute_1st_distance = 0; + std::uint64_t clk_topk = 0; + std::uint64_t clk_reset_hash = 0; + std::uint64_t clk_pickup_parents = 0; + std::uint64_t clk_restore_hash = 0; + std::uint64_t clk_compute_distance = 0; + std::uint64_t clk_start; +#define _CLK_START() clk_start = clock64() +#define _CLK_REC(V) V += clock64() - clk_start; +#else +#define _CLK_START() +#define _CLK_REC(V) +#endif + _CLK_START(); + + extern __shared__ uint8_t smem[]; + + // Layout of result_buffer + const auto result_buffer_size = internal_topk + (search_width * graph_degree); + const auto result_buffer_size_32 = raft::round_up_safe(result_buffer_size, 32); + const auto small_hash_size = hashmap::get_size(small_hash_bitlen); + + // Get dim and smem_ws_size directly from base descriptor + uint32_t dim = dataset_desc->args.dim; + uint32_t smem_ws_size_in_bytes = dataset_desc->smem_ws_size_in_bytes(); + + // Set smem working buffer using unified setup_workspace + // setup_workspace copies the descriptor to shared memory and returns base pointer to smem + // descriptor NOTE: setup_workspace must be called by ALL threads (it uses __syncthreads()) + dataset_descriptor_base_t* smem_desc = + setup_workspace(dataset_desc, smem, queries_ptr, query_id); + + auto* __restrict__ result_indices_buffer = + reinterpret_cast(smem + smem_ws_size_in_bytes); + auto* __restrict__ result_distances_buffer = + reinterpret_cast(result_indices_buffer + result_buffer_size_32); + auto* __restrict__ visited_hash_buffer = + reinterpret_cast(result_distances_buffer + result_buffer_size_32); + auto* __restrict__ parent_list_buffer = + reinterpret_cast(visited_hash_buffer + small_hash_size); + auto* __restrict__ topk_ws = 
reinterpret_cast(parent_list_buffer + search_width); + auto* terminate_flag = reinterpret_cast(topk_ws + 3); + auto* __restrict__ smem_work_ptr = reinterpret_cast(terminate_flag + 1); + + // A flag for filtering. + auto filter_flag = terminate_flag; + + if (threadIdx.x == 0) { + terminate_flag[0] = 0; + topk_ws[0] = ~0u; + } + + // Init hashmap + IndexT* local_visited_hashmap_ptr; + if (small_hash_bitlen) { + local_visited_hashmap_ptr = visited_hash_buffer; + } else { + local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * blockIdx.y); + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); + __syncthreads(); + _CLK_REC(clk_init); + + // compute distance to randomly selecting nodes using JIT version + _CLK_START(); + const IndexT* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; + // Get dataset_size directly from base descriptor + IndexT dataset_size = smem_desc->size; + compute_distance_to_random_nodes_jit(result_indices_buffer, + result_distances_buffer, + smem_desc, + result_buffer_size, + num_distilation, + rand_xor_mask, + local_seed_ptr, + num_seeds, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0); + __syncthreads(); + _CLK_REC(clk_compute_1st_distance); + + std::uint32_t iter = 0; + while (1) { + // sort + if constexpr (TOPK_BY_BITONIC_SORT) { + assert(blockDim.x >= 64); + const bool bitonic_sort_and_full_multi_warps = (max_candidates > 128) ? true : false; + + // reset small-hash table. 
+ if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + unsigned hash_start_tid; + if (blockDim.x == 32) { + hash_start_tid = 0; + } else if (blockDim.x == 64) { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 0; + } else { + hash_start_tid = 32; + } + } else { + if (bitonic_sort_and_full_multi_warps || BITONIC_SORT_AND_MERGE_MULTI_WARPS) { + hash_start_tid = 64; + } else { + hash_start_tid = 32; + } + } + hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); + _CLK_REC(clk_reset_hash); + } + + // topk with bitonic sort + _CLK_START(); + // For JIT version, we always check filter_flag at runtime since sample_filter is extern + if (*filter_flag != 0) { + // Move the filtered out index to the end of the itopk list + for (unsigned i = 0; i < search_width; i++) { + move_invalid_to_end_of_list( + result_indices_buffer, result_distances_buffer, internal_topk); + } + if (threadIdx.x == 0) { *terminate_flag = 0; } + } + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + __syncthreads(); + _CLK_REC(clk_topk); + } else { + _CLK_START(); + // topk with radix block sort + topk_by_radix_sort{}(max_itopk, + internal_topk, + result_buffer_size, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + reinterpret_cast(result_distances_buffer), + result_indices_buffer, + nullptr, + topk_ws, + true, + smem_work_ptr); + _CLK_REC(clk_topk); + + // reset small-hash table + if ((iter + 1) % small_hash_reset_interval == 0) { + _CLK_START(); + hashmap::init(local_visited_hashmap_ptr, hash_bitlen); + _CLK_REC(clk_reset_hash); + } + } + __syncthreads(); + + if (iter + 1 == max_iteration) { break; } + + // pick up next parents + if (threadIdx.x < 32) { + _CLK_START(); + 
pickup_next_parents( + terminate_flag, parent_list_buffer, result_indices_buffer, internal_topk, search_width); + _CLK_REC(clk_pickup_parents); + } + + // restore small-hash table by putting internal-topk indices in it + if ((iter + 1) % small_hash_reset_interval == 0) { + const unsigned first_tid = ((blockDim.x <= 32) ? 0 : 32); + _CLK_START(); + hashmap_restore( + local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); + _CLK_REC(clk_restore_hash); + } + __syncthreads(); + + if (*terminate_flag && iter >= min_iteration) { break; } + + __syncthreads(); + // compute the norms between child nodes and query node using JIT version + _CLK_START(); + compute_distance_to_child_nodes_jit(result_indices_buffer + internal_topk, + result_distances_buffer + internal_topk, + smem_desc, + knn_graph, + graph_degree, + local_visited_hashmap_ptr, + hash_bitlen, + (IndexT*)nullptr, + 0, + parent_list_buffer, + result_indices_buffer, + search_width); + // Critical: __syncthreads() must be reached by ALL threads + // If any thread is stuck in compute_distance_to_child_nodes_jit, this will hang + __syncthreads(); + _CLK_REC(clk_compute_distance); + + // Filtering - use extern sample_filter function + if (threadIdx.x == 0) { *filter_flag = 0; } + __syncthreads(); + + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { + if (parent_list_buffer[p] != invalid_index) { + const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (!sample_filter(query_id + query_id_offset, + to_source_index(parent_id), + bitset_ptr != nullptr ? 
&filter_data : nullptr)) { + result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); + result_indices_buffer[parent_list_buffer[p]] = invalid_index; + *filter_flag = 1; + } + } + } + __syncthreads(); + + iter++; + } + + // Post process for filtering - use extern sample_filter function + constexpr IndexT index_msb_1_mask = utils::gen_index_msb_1_mask::value; + const IndexT invalid_index = utils::get_max_value(); + + for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; i += blockDim.x) { + const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + if (node_id != (invalid_index & ~index_msb_1_mask) && + !sample_filter(query_id + query_id_offset, + to_source_index(node_id), + bitset_ptr != nullptr ? &filter_data : nullptr)) { + result_distances_buffer[i] = utils::get_max_value(); + result_indices_buffer[i] = invalid_index; + } + } + + __syncthreads(); + // Move invalid index items to the end of the buffer without sorting the entire buffer + using scan_op_t = cub::WarpScan; + auto& temp_storage = *reinterpret_cast(smem_work_ptr); + + constexpr std::uint32_t warp_size = 32; + if (threadIdx.x < warp_size) { + std::uint32_t num_found_valid = 0; + for (std::uint32_t buffer_offset = 0; buffer_offset < internal_topk; + buffer_offset += warp_size) { + const auto src_position = buffer_offset + threadIdx.x; + const std::uint32_t is_valid_index = + (result_indices_buffer[src_position] & (~index_msb_1_mask)) == invalid_index ? 
0 : 1; + std::uint32_t new_position; + scan_op_t(temp_storage).InclusiveSum(is_valid_index, new_position); + if (is_valid_index) { + const auto dst_position = num_found_valid + (new_position - 1); + result_indices_buffer[dst_position] = result_indices_buffer[src_position]; + result_distances_buffer[dst_position] = result_distances_buffer[src_position]; + } + + num_found_valid += new_position; + for (std::uint32_t offset = (warp_size >> 1); offset > 0; offset >>= 1) { + const auto v = raft::shfl_xor(num_found_valid, offset); + if ((threadIdx.x & offset) == 0) { num_found_valid = v; } + } + + if (num_found_valid >= top_k) { break; } + } + + if (num_found_valid < top_k) { + for (std::uint32_t i = num_found_valid + threadIdx.x; i < internal_topk; i += warp_size) { + result_indices_buffer[i] = invalid_index; + result_distances_buffer[i] = utils::get_max_value(); + } + } + } + + // If the sufficient number of valid indexes are not in the internal topk, pick up from the + // candidate list. + if (top_k > internal_topk || result_indices_buffer[top_k - 1] == invalid_index) { + __syncthreads(); + topk_by_bitonic_sort_and_merge( + result_distances_buffer, + result_indices_buffer, + max_itopk, + internal_topk, + result_distances_buffer + internal_topk, + result_indices_buffer + internal_topk, + max_candidates, + search_width * graph_degree, + topk_ws, + (iter == 0)); + } + __syncthreads(); + + // NB: The indices pointer is tagged with its element size. + const uint32_t index_element_tag = result_indices_ptr & 0x3; + result_indices_ptr ^= index_element_tag; + auto write_indices = + index_element_tag == 3 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 2 + ? [](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : index_element_tag == 1 + ? 
[](uintptr_t ptr, + uint32_t i, + SourceIndexT x) { reinterpret_cast(ptr)[i] = static_cast(x); } + : [](uintptr_t ptr, uint32_t i, SourceIndexT x) { + reinterpret_cast(ptr)[i] = static_cast(x); + }; + for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { + unsigned j = i + (top_k * query_id); + unsigned ii = i; + if constexpr (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } + if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } + + auto internal_index = + result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit + auto source_index = to_source_index(internal_index); + write_indices(result_indices_ptr, j, source_index); + } + if (threadIdx.x == 0 && num_executed_iterations != nullptr) { + num_executed_iterations[query_id] = iter + 1; + } +#ifdef _CLK_BREAKDOWN + if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) { + printf( + "%s:%d " + "query, %d, thread, %d" + ", init, %lu" + ", 1st_distance, %lu" + ", topk, %lu" + ", reset_hash, %lu" + ", pickup_parents, %lu" + ", restore_hash, %lu" + ", distance, %lu" + "\n", + __FILE__, + __LINE__, + query_id, + threadIdx.x, + clk_init, + clk_compute_1st_distance, + clk_topk, + clk_reset_hash, + clk_pickup_parents, + clk_restore_hash, + clk_compute_distance); + } +#endif +} + +// JIT kernel wrapper - calls search_core +// Unified template parameters: TeamSize, DatasetBlockDim, PQ_BITS, PQ_LEN, CodebookT, QueryT +template +RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_jit( + uintptr_t result_indices_ptr, + DistanceT* const result_distances_ptr, + const std::uint32_t top_k, + const DataT* const queries_ptr, + const IndexT* const knn_graph, + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t rand_xor_mask, + const IndexT* seed_ptr, + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, + const std::uint32_t 
max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // Original number of bits +{ + const auto query_id = blockIdx.y; + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset_ptr, + bitset_len, + original_nbits); +} + +// No separate JIT types needed - use non-JIT types directly +// Helper descriptor type for job_desc_t +template +struct job_desc_jit_helper_desc { + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; +}; + +// JIT persistent kernel - uses extern functions and JIT search_core +// Unified template parameters: TeamSize, DatasetBlockDim, PQ_BITS, PQ_LEN, CodebookT, QueryT +template +RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel_p_jit( + worker_handle_t* worker_handles, + job_desc_t>* job_descriptors, + uint32_t* completion_counters, + const IndexT* const knn_graph, // [dataset_size, graph_degree] + const std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const unsigned num_distilation, + const uint64_t 
rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + const uint32_t num_seeds, + IndexT* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + const std::uint32_t max_candidates, + const std::uint32_t max_itopk, + const std::uint32_t internal_topk, + const std::uint32_t search_width, + const std::uint32_t min_iteration, + const std::uint32_t max_iteration, + std::uint32_t* const num_executed_iterations, // [num_queries] + const std::uint32_t hash_bitlen, + const std::uint32_t small_hash_bitlen, + const std::uint32_t small_hash_reset_interval, + const std::uint32_t query_id_offset, // Offset to add to query_id when calling filter + dataset_descriptor_base_t* dataset_desc, + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len, // Bitset length + SourceIndexT original_nbits) // Original number of bits +{ + using job_desc_type = job_desc_t>; + __shared__ typename job_desc_type::input_t job_descriptor; + __shared__ worker_handle_t::data_t worker_data; + + auto& worker_handle = worker_handles[blockIdx.y].data; + uint32_t job_ix; + + while (true) { + // wait the writing phase + if (threadIdx.x == 0) { + worker_handle_t::data_t worker_data_local; + do { + worker_data_local = worker_handle.load(cuda::memory_order_relaxed); + } while (worker_data_local.handle == kWaitForWork); + if (worker_data_local.handle != kNoMoreWork) { + worker_handle.store({kWaitForWork}, cuda::memory_order_relaxed); + } + job_ix = worker_data_local.value.desc_id; + cuda::atomic_thread_fence(cuda::memory_order_acquire, cuda::thread_scope_system); + worker_data = worker_data_local; + } + if (threadIdx.x < raft::WarpSize) { + // Sync one warp and copy descriptor data + static_assert(job_desc_type::kBlobSize <= raft::WarpSize); + constexpr uint32_t kMaxJobsNum = 8192; + job_ix = raft::shfl(job_ix, 0); + if (threadIdx.x < job_desc_type::kBlobSize && job_ix < kMaxJobsNum) { + job_descriptor.blob[threadIdx.x] = 
job_descriptors[job_ix].input.blob[threadIdx.x]; + } + } + __syncthreads(); + if (worker_data.handle == kNoMoreWork) { break; } + + // reading phase + auto result_indices_ptr = job_descriptor.value.result_indices_ptr; + auto* result_distances_ptr = job_descriptor.value.result_distances_ptr; + auto* queries_ptr = job_descriptor.value.queries_ptr; + auto top_k = job_descriptor.value.top_k; + auto n_queries = job_descriptor.value.n_queries; + auto query_id = worker_data.value.query_id; + + // work phase - use JIT search_core + search_core(result_indices_ptr, + result_distances_ptr, + top_k, + queries_ptr, + knn_graph, + graph_degree, + source_indices_ptr, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + visited_hashmap_ptr, + max_candidates, + max_itopk, + internal_topk, + search_width, + min_iteration, + max_iteration, + num_executed_iterations, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + query_id, + query_id_offset, + dataset_desc, + bitset_ptr, + bitset_len, + original_nbits); + + // make sure all writes are visible even for the host + // (e.g. 
when result buffers are in pinned memory) + cuda::atomic_thread_fence(cuda::memory_order_release, cuda::thread_scope_system); + + // arrive to mark the end of the work phase + __syncthreads(); + if (threadIdx.x == 0) { + auto completed_count = atomicInc(completion_counters + job_ix, n_queries - 1) + 1; + if (completed_count >= n_queries) { + job_descriptors[job_ix].completion_flag.store(true, cuda::memory_order_relaxed); + } + } + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in new file mode 100644 index 0000000000..2785ea46ac --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template __global__ __launch_bounds__(1024, 1) void search_kernel_jit<@topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@, @source_index_type@>( + uintptr_t, @distance_type@* const, const std::uint32_t, const @data_type@* const, const @index_type@* const, const std::uint32_t, const @source_index_type@*, const unsigned, const uint64_t, const @index_type@*, const uint32_t, @index_type@* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, std::uint32_t* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, uint32_t*, @source_index_type@, @source_index_type@); 
+ +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json new file mode 100644 index 0000000000..d9f3e97653 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_matrix.json @@ -0,0 +1,208 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_topk_by_bitonic": [ + { + "topk_by_bitonic_sort": "true", + "topk_by_bitonic_sort_str": "topk_by_bitonic_sort" + }, + { + "topk_by_bitonic_sort": "false", + "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort" + } + ], + "_bitonic_sort_and_merge_multi_warps": [ + { + "bitonic_sort_and_merge_multi_warps": "true", + "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps" + }, + { + "bitonic_sort_and_merge_multi_warps": "false", + "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "0", + "pq_len": "0", + "pq_prefix": "", + "pq_suffix": 
"" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_topk_by_bitonic": [ + { + "topk_by_bitonic_sort": "true", + "topk_by_bitonic_sort_str": "topk_by_bitonic_sort" + }, + { + "topk_by_bitonic_sort": "false", + "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort" + } + ], + "_bitonic_sort_and_merge_multi_warps": [ + { + "bitonic_sort_and_merge_multi_warps": "true", + "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps" + }, + { + "bitonic_sort_and_merge_multi_warps": "false", + "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "8", + "pq_len": "2", + "pq_prefix": "", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_bits": "8", + "pq_len": "4", + "pq_prefix": "", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_embedded.cpp.in new file mode 100644 index 0000000000..953a6d7cea --- /dev/null +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + registerAlgorithm( + "search_single_cta_p_@topk_by_bitonic_sort_str@_@bitonic_sort_and_merge_multi_warps_str@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in new file mode 100644 index 0000000000..96e4784101 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template __global__ void search_kernel_p_jit<@topk_by_bitonic_sort@, @bitonic_sort_and_merge_multi_warps@, @team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@, @source_index_type@>( + worker_handle_t*, job_desc_t>*, uint32_t*, const @index_type@* const, const std::uint32_t, const @source_index_type@*, const unsigned, const uint64_t, const @index_type@*, const uint32_t, @index_type@* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, std::uint32_t* const, const std::uint32_t, const std::uint32_t, const std::uint32_t, const std::uint32_t, cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, uint32_t*, @source_index_type@, @source_index_type@); + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json new file mode 100644 index 0000000000..d9f3e97653 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_p_matrix.json @@ -0,0 +1,208 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + "data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + 
"_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_topk_by_bitonic": [ + { + "topk_by_bitonic_sort": "true", + "topk_by_bitonic_sort_str": "topk_by_bitonic_sort" + }, + { + "topk_by_bitonic_sort": "false", + "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort" + } + ], + "_bitonic_sort_and_merge_multi_warps": [ + { + "bitonic_sort_and_merge_multi_warps": "true", + "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps" + }, + { + "bitonic_sort_and_merge_multi_warps": "false", + "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "0", + "pq_len": "0", + "pq_prefix": "", + "pq_suffix": "" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_comma": "" + } + ] + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_source_index": [ + { + "source_index_type": "uint32_t", + "source_index_abbrev": "ui" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "_topk_by_bitonic": [ + { + "topk_by_bitonic_sort": "true", + "topk_by_bitonic_sort_str": "topk_by_bitonic_sort" + }, + { + "topk_by_bitonic_sort": "false", + "topk_by_bitonic_sort_str": "no_topk_by_bitonic_sort" + } + ], + "_bitonic_sort_and_merge_multi_warps": [ + { + "bitonic_sort_and_merge_multi_warps": 
"true", + "bitonic_sort_and_merge_multi_warps_str": "bitonic_sort_and_merge_multi_warps" + }, + { + "bitonic_sort_and_merge_multi_warps": "false", + "bitonic_sort_and_merge_multi_warps_str": "no_bitonic_sort_and_merge_multi_warps" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_bits": "8", + "pq_len": "2", + "pq_prefix": "", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_bits": "8", + "pq_len": "4", + "pq_prefix": "", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_comma": ", " + } + ] + } +] diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp new file mode 100644 index 0000000000..24da6d42ad --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/search_single_cta_planner.hpp @@ -0,0 +1,83 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "cagra_planner_base.hpp" +#include +#include +#include +#include + +namespace cuvs { +namespace neighbors { +namespace cagra { +namespace detail { +namespace single_cta_search { + +template +struct CagraSingleCtaSearchPlanner + : CagraPlannerBase { + CagraSingleCtaSearchPlanner(cuvs::distance::DistanceType metric, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq = false, + uint32_t pq_bits = 0, + uint32_t pq_len = 0, + bool persistent = false) + : CagraPlannerBase( + build_entrypoint_name(metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + team_size, + dataset_block_dim, + is_vpq, + pq_bits, + pq_len, + persistent), + is_vpq ? 
make_fragment_key() + : make_fragment_key()) + { + } + + private: + static std::string build_entrypoint_name(cuvs::distance::DistanceType metric, + bool topk_by_bitonic_sort, + bool bitonic_sort_and_merge_multi_warps, + uint32_t team_size, + uint32_t dataset_block_dim, + bool is_vpq, + uint32_t pq_bits, + uint32_t pq_len, + bool persistent) + { + std::string name = (persistent ? "search_single_cta_p" : "search_single_cta"); + name += std::string(topk_by_bitonic_sort ? "_" : "_no_") + "topk_by_bitonic_sort"; + name += std::string(bitonic_sort_and_merge_multi_warps ? "_" : "_no_") + + "bitonic_sort_and_merge_multi_warps"; + name += "_team_size_" + std::to_string(team_size); + name += "_dataset_block_dim_" + std::to_string(dataset_block_dim); + if (is_vpq) { name += "_" + std::to_string(pq_bits) + "pq_" + std::to_string(pq_len) + "subd"; } + return name; + } +}; + +} // namespace single_cta_search +} // namespace detail +} // namespace cagra +} // namespace neighbors +} // namespace cuvs diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_embedded.cpp.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_embedded.cpp.in new file mode 100644 index 0000000000..2b10e81a17 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_embedded.cpp.in @@ -0,0 +1,26 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include +#include +#include "@embedded_header_file@" + +using namespace cuvs::neighbors::cagra::detail; + +namespace { + +__attribute__((__constructor__)) void register_kernel() +{ + using QueryTag = cuvs::neighbors::cagra::detail::tag_@query_abbrev@; + registerAlgorithm( + "setup_workspace@pq_prefix@_team_size_@team_size@_dataset_block_dim_@dataset_block_dim@@pq_suffix@", + embedded_fatbin, + sizeof(embedded_fatbin)); +} + +} diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in new file mode 100644 index 0000000000..7872328efb --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_kernel.cu.in @@ -0,0 +1,13 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace cuvs::neighbors::cagra::detail { + +template __device__ cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>* setup_workspace<@team_size@, @dataset_block_dim@, @pq_bits@, @pq_len@, @codebook_type@, @data_type@, @index_type@, @distance_type@, @query_type@>( + cuvs::neighbors::cagra::detail::dataset_descriptor_base_t<@data_type@, @index_type@, @distance_type@>*, void*, const @data_type@*, uint32_t); + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json new file mode 100644 index 0000000000..23822b3996 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_matrix.json @@ -0,0 +1,158 @@ +[ + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "__half", + 
"data_abbrev": "h", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + }, + { + "query_type": "uint8_t", + "query_abbrev": "uc" + } + ] + }, + { + "data_type": "int8_t", + "data_abbrev": "sc", + "_query": [ + { + "query_type": "float", + "query_abbrev": "f" + } + ] + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_prefix": "_standard", + "pq_suffix": "", + "pq_bits": "0", + "pq_len": "0" + } + ], + "_codebook": [ + { + "codebook_type": "void", + "codebook_tag": "", + "codebook_tag_comma": "" + } + ], + "impl_file": "setup_workspace_standard_impl.cuh" + }, + { + "_data": [ + { + "data_type": "float", + "data_abbrev": "f" + }, + { + "data_type": "__half", + "data_abbrev": "h" + }, + { + "data_type": "uint8_t", + "data_abbrev": "uc" + }, + { + "data_type": "int8_t", + "data_abbrev": "sc" + } + ], + "_query": [ + { + "query_type": "half", + "query_abbrev": "h" + } + ], + "_index": [ + { + "index_type": "uint32_t", + "index_abbrev": "ui" + } + ], + "_distance": [ + { + "distance_type": "float", + "distance_abbrev": "f" + } + ], + "team_size": [ + "8", + "16", + "32" + ], + "dataset_block_dim": [ + "128", + "256", + "512" + ], + "_pq": [ + { + "pq_len": "2", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_2subd" + }, + { + "pq_len": "4", + "pq_bits": "8", + "pq_prefix": "_vpq", + "pq_suffix": "_8pq_4subd" + } + ], + "_codebook": [ + { + "codebook_type": "half", + "codebook_tag": "tag_codebook_half", + "codebook_tag_comma": ", " + } + ], + "impl_file": "setup_workspace_vpq_impl.cuh" + } +] diff --git 
a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh new file mode 100644 index 0000000000..61d152623d --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_standard_impl.cuh @@ -0,0 +1,47 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_standard-impl.cuh" +#include "../device_common.hpp" + +namespace cuvs::neighbors::cagra::detail { + +// Unified setup_workspace implementation for standard descriptors +// This is instantiated when PQ_BITS=0, PQ_LEN=0, CodebookT=void +// Takes dataset_descriptor_base_t* and reconstructs the derived descriptor inside +// QueryT can be float (for most metrics) or uint8_t (for BitwiseHamming) +template +__device__ dataset_descriptor_base_t* setup_workspace( + dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id) +{ + // For standard descriptors, PQ_BITS=0, PQ_LEN=0, CodebookT=void + static_assert(PQ_BITS == 0 && PQ_LEN == 0 && std::is_same_v, + "Standard descriptor requires PQ_BITS=0, PQ_LEN=0, CodebookT=void"); + + // Reconstruct the descriptor pointer from base pointer with QueryT + using desc_t = + standard_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + // Call the free function directly - it takes DescriptorT as template parameter + const desc_t* result = setup_workspace_standard(desc, smem, queries, query_id); + return const_cast*>( + static_cast*>(result)); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh new file mode 100644 index 0000000000..a6c6956066 --- /dev/null +++ 
b/cpp/src/neighbors/detail/cagra/jit_lto_kernels/setup_workspace_vpq_impl.cuh @@ -0,0 +1,56 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../compute_distance_vpq-impl.cuh" +#include "../device_common.hpp" + +namespace cuvs::neighbors::cagra::detail { + +// Unified setup_workspace implementation for VPQ descriptors +// This is instantiated when PQ_BITS>0, PQ_LEN>0, CodebookT=half +// Takes dataset_descriptor_base_t* and reconstructs the derived descriptor inside +// QueryT is always half for VPQ +template +__device__ dataset_descriptor_base_t* setup_workspace( + dataset_descriptor_base_t* desc_ptr, + void* smem, + const DataT* queries, + uint32_t query_id) +{ + // For VPQ descriptors, PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half + static_assert( + PQ_BITS > 0 && PQ_LEN > 0 && std::is_same_v && std::is_same_v, + "VPQ descriptor requires PQ_BITS>0, PQ_LEN>0, CodebookT=half, QueryT=half"); + + // Reconstruct the descriptor pointer from base pointer + // QueryT is always half for VPQ + using desc_t = cagra_q_dataset_descriptor_t; + const desc_t* desc = static_cast(desc_ptr); + + // Call the free function directly - it takes DescriptorT as template parameter + const desc_t* result = setup_workspace_vpq(desc, smem, queries, query_id); + return const_cast*>( + static_cast*>(result)); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh index 74822c8660..f3871a55f3 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once @@ -30,6 +30,8 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp +#include + #include #include #include @@ -91,10 +93,10 @@ struct search constexpr static bool kNeedIndexCopy = sizeof(INDEX_T) != sizeof(OutputIndexT); uint32_t num_cta_per_query; - lightweight_uvector intermediate_indices; - lightweight_uvector intermediate_distances; + rmm::device_uvector intermediate_indices; + rmm::device_uvector intermediate_distances; size_t topk_workspace_size; - lightweight_uvector topk_workspace; + rmm::device_uvector topk_workspace; search(raft::resources const& res, search_params params, @@ -104,9 +106,9 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - intermediate_indices(res), - intermediate_distances(res), - topk_workspace(res) + intermediate_indices(0, raft::resource::get_cuda_stream(res)), + intermediate_distances(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)) { set_params(res, params); } diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh index bd4d25d8f3..834b7b21ee 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh @@ -1,37 +1,39 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once +#include "../../sample_filter.cuh" +#include "sample_filter_utils.cuh" #include "search_multi_cta_kernel-inl.cuh" #include namespace cuvs::neighbors::cagra::detail::multi_cta_search { -#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ - template void select_and_run( \ - const dataset_descriptor_host& dataset_desc, \ - raft::device_matrix_view graph, \ - const IndexT* source_indices_ptr, \ - uint32_t* topk_indices_ptr, \ - DistanceT* topk_distances_ptr, \ - const DataT* queries_ptr, \ - uint32_t num_queries, \ - const uint32_t* dev_seed_ptr, \ - uint32_t* num_executed_iterations, \ - const search_params& ps, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - uint32_t small_hash_bitlen, \ - int64_t hash_bitlen, \ - uint32_t* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_seeds, \ - SampleFilterT sample_filter, \ +#define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ + template void select_and_run( \ + const dataset_descriptor_host& dataset_desc, \ + raft::device_matrix_view graph, \ + const IndexT* source_indices_ptr, \ + IndexT* topk_indices_ptr, \ + DistanceT* topk_distances_ptr, \ + const DataT* queries_ptr, \ + uint32_t num_queries, \ + const IndexT* dev_seed_ptr, \ + uint32_t* num_executed_iterations, \ + const search_params& ps, \ + uint32_t topk, \ + uint32_t block_size, \ + uint32_t result_buffer_size, \ + uint32_t smem_size, \ + uint32_t small_hash_bitlen, \ + int64_t hash_bitlen, \ + IndexT* hashmap_ptr, \ + uint32_t num_cta_per_query, \ + uint32_t num_seeds, \ + SampleFilterT sample_filter, \ cudaStream_t stream); } // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index ff24724bdc..be1a7a1a56 100644 --- 
a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -14,6 +14,12 @@ #include "topk_for_cagra/topk.h" // TODO replace with raft topk if possible #include "utils.hpp" +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_multi_cta_kernel_launcher_jit.cuh" +#else +#include "set_value_batch.cuh" +#endif + #include #include #include @@ -453,7 +459,8 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( uint32_t k = j + (itopk_size * (cta_id + (num_cta_per_query * query_id))); result_indices_ptr[k] = index & ~index_msb_1_mask; if (result_distances_ptr != nullptr) { - result_distances_ptr[k] = result_distances_buffer[i]; + DISTANCE_T dist = result_distances_buffer[i]; + result_distances_ptr[k] = dist; } } else { // If it is valid and registered in the traversed hash table but is @@ -502,34 +509,6 @@ RAFT_KERNEL __launch_bounds__(1024, 1) search_kernel( #endif } -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - template struct search_kernel_config { // Search kernel function type. 
Note that the actual values for the template value @@ -574,6 +553,31 @@ void select_and_run(const dataset_descriptor_host& dat SampleFilterT sample_filter, cudaStream_t stream) { +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + select_and_run_jit(dataset_desc, + graph, + source_indices_ptr, + topk_indices_ptr, + topk_distances_ptr, + queries_ptr, + num_queries, + dev_seed_ptr, + num_executed_iterations, + ps, + topk, + block_size, + result_buffer_size, + smem_size, + visited_hash_bitlen, + traversed_hash_bitlen, + traversed_hashmap_ptr, + num_cta_per_query, + num_seeds, + sample_filter, + stream); +#else + // Non-JIT path auto kernel = search_kernel_config, SourceIndexT, @@ -630,6 +634,7 @@ void select_and_run(const dataset_descriptor_host& dat sample_filter); }; cuvs::neighbors::detail::safely_launch_kernel_with_smem_size(kernel, smem_size, kernel_launcher); +#endif } } // namespace multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..a18e44df51 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_kernel_launcher_jit.cuh @@ -0,0 +1,262 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_multi_cta_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif + +#include "../smem_utils.cuh" + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/search_multi_cta_planner.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "set_value_batch.cuh" // For set_value_batch +#include "shared_launcher_jit.hpp" // For shared JIT helper functions +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::multi_cta_search { + +// JIT version of select_and_run for multi_cta +template +void select_and_run_jit( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + IndexT* topk_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] + DistanceT* topk_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + // multi_cta_search (params struct) + uint32_t block_size, // + uint32_t result_buffer_size, + uint32_t smem_size, + uint32_t visited_hash_bitlen, + int64_t traversed_hash_bitlen, + IndexT* traversed_hashmap_ptr, + uint32_t num_cta_per_query, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + // Extract bitset data from filter object (if it's a bitset_filter) + uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + 
// Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // Create planner with tags + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + // Create planner and register device functions + // Pass team_size, dataset_block_dim, and VPQ parameters to match the kernel entrypoint name + std::shared_ptr launcher; + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + std::string filter_name = get_sample_filter_name(); + planner.add_sample_filter_device_function(filter_name); + launcher = planner.get_launcher(); + } else { + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + 
dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + std::string filter_name = get_sample_filter_name(); + planner.add_sample_filter_device_function(filter_name); + launcher = planner.get_launcher(); + } else { + using QueryTag = query_type_tag_standard_t; + CagraMultiCtaSearchPlanner + planner(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + std::string filter_name = get_sample_filter_name(); + planner.add_sample_filter_device_function(filter_name); + launcher = planner.get_launcher(); + } + } + + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher"); } + + uint32_t max_elements{}; + if (result_buffer_size <= 64) { + max_elements = 64; + } else if (result_buffer_size <= 128) { + max_elements = 128; + } else if (result_buffer_size <= 256) { + max_elements = 256; + } else { + THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); + } + + // Initialize hash table + const uint32_t traversed_hash_size = hashmap::get_size(traversed_hash_bitlen); + set_value_batch(traversed_hashmap_ptr, + 
traversed_hash_size, + ~static_cast(0), + traversed_hash_size, + num_queries, + stream); + + dim3 block_dims(block_size, 1, 1); + dim3 grid_dims(num_cta_per_query, num_queries, 1); + + // Get the device descriptor pointer + const dataset_descriptor_base_t* dev_desc_base = + dataset_desc.dev_ptr(stream); + const auto* dev_desc = dev_desc_base; + + // Note: dataset_desc is passed by const reference, so it stays alive for the duration of this + // function The descriptor's state is managed by a shared_ptr internally, so no need to explicitly + // keep it alive + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + // graph.extent(1) returns int64_t but kernel expects uint32_t + // traversed_hash_bitlen is int64_t but kernel expects uint32_t + // ps.itopk_size, ps.min_iterations, ps.max_iterations are size_t (8 bytes) but kernel expects + // uint32_t (4 bytes) ps.num_random_samplings is uint32_t but kernel expects unsigned - cast for + // consistency + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t traversed_hash_bitlen_u32 = static_cast(traversed_hash_bitlen); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + auto kernel_launcher = [&](auto const& kernel) -> void { + launcher->dispatch(stream, + grid_dims, + block_dims, + smem_size, + topk_indices_ptr, + topk_distances_ptr, + dev_desc, + queries_ptr, + graph.data_handle(), + max_elements, + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + ps.rand_xor_mask, // uint64_t matches kernel (8 bytes) + dev_seed_ptr, + num_seeds, + visited_hash_bitlen, + 
traversed_hashmap_ptr, + traversed_hash_bitlen_u32, // Cast int64_t to uint32_t + itopk_size_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations, + query_id_offset, // Offset to add to query_id when calling filter + bitset_ptr, + bitset_len, + original_nbits); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + launcher->get_kernel(), smem_size, kernel_launcher); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh index f7d353d864..12f711b31d 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel.cuh @@ -1,9 +1,18 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once +// Include tags header before any namespace declarations to avoid issues when it's included inside +// functions +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_multi_kernel_launcher_jit.cuh" +#include +#endif + +#include "set_value_batch.cuh" + #include "compute_distance-ext.cuh" #include "device_common.hpp" #include "hashmap.hpp" @@ -168,24 +177,47 @@ void random_pickup(const dataset_descriptor_host& data std::uint32_t hash_bitlen, cudaStream_t cuda_stream) { - const auto block_size = 256u; - const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; - const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, - num_queries); - - random_pickup_kernel<<>>( - dataset_desc.dev_ptr(cuda_stream), - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + random_pickup_jit(dataset_desc, + queries_ptr, + num_queries, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen, + cuda_stream); +#else + // Non-JIT path + { + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; + const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + random_pickup_kernel<<>>(dataset_desc.dev_ptr(cuda_stream), + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen); + } +#endif } template @@ -402,30 +434,55 @@ void compute_distance_to_child_nodes( SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream) { - const auto block_size = 128; - const auto 
teams_per_block = block_size / dataset_desc.team_size; - const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, - num_queries); - - compute_distance_to_child_nodes_kernel<<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - lds, - search_width, - dataset_desc.dev_ptr(cuda_stream), - neighbor_graph_ptr, - graph_degree, - source_indices_ptr, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + compute_distance_to_child_nodes_jit(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + num_queries, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter, + cuda_stream); +#else + // Non-JIT path + { + const auto block_size = 128; + const auto teams_per_block = block_size / dataset_desc.team_size; + const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, + num_queries); + + compute_distance_to_child_nodes_kernel<<>>(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_desc.dev_ptr(cuda_stream), + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter); + } +#endif } template @@ -497,17 +554,33 @@ void apply_filter(const SourceIndexT* source_indices_ptr, SAMPLE_FILTER_T sample_filter, cudaStream_t cuda_stream) { - const std::uint32_t block_size = 256; - const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); - - apply_filter_kernel<<>>(source_indices_ptr, - result_indices_ptr, - result_distances_ptr, - lds, - result_buffer_size, - num_queries, - query_id_offset, - 
sample_filter); +#ifdef CUVS_ENABLE_JIT_LTO + // Use JIT version when JIT is enabled + apply_filter_jit(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter, + cuda_stream); +#else + // Non-JIT path + { + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); + + apply_filter_kernel<<>>(source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + sample_filter); + } +#endif } template @@ -542,34 +615,6 @@ void batched_memcpy(T* const dst, // [batch_size, ld_dst] <<>>(dst, ld_dst, src, ld_src, count, batch_size); } -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - // result_buffer (work buffer) for "multi-kernel" // +--------------------+------------------------------+-------------------+ // | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | @@ -629,18 +674,18 @@ struct search using base_type::num_seeds; size_t result_buffer_allocation_size; - lightweight_uvector result_indices; // results_indices_buffer - lightweight_uvector result_distances; // result_distances_buffer - lightweight_uvector parent_node_list; 
- lightweight_uvector topk_hint; - lightweight_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; - lightweight_uvector topk_workspace; + rmm::device_uvector result_indices; // results_indices_buffer + rmm::device_uvector result_distances; // result_distances_buffer + rmm::device_uvector parent_node_list; + rmm::device_uvector topk_hint; + rmm::device_uvector terminate_flag; // dev_terminate_flag, host_terminate_flag.; + rmm::device_uvector topk_workspace; // temporary storage for _find_topk - lightweight_uvector input_keys_storage; - lightweight_uvector output_keys_storage; - lightweight_uvector input_values_storage; - lightweight_uvector output_values_storage; + rmm::device_uvector input_keys_storage; + rmm::device_uvector output_keys_storage; + rmm::device_uvector input_values_storage; + rmm::device_uvector output_values_storage; search(raft::resources const& res, search_params params, @@ -650,16 +695,16 @@ struct search int64_t graph_degree, uint32_t topk) : base_type(res, params, dataset_desc, dim, dataset_size, graph_degree, topk), - result_indices(res), - result_distances(res), - parent_node_list(res), - topk_hint(res), - topk_workspace(res), - terminate_flag(res), - input_keys_storage(res), - output_keys_storage(res), - input_values_storage(res), - output_values_storage(res) + result_indices(0, raft::resource::get_cuda_stream(res)), + result_distances(0, raft::resource::get_cuda_stream(res)), + parent_node_list(0, raft::resource::get_cuda_stream(res)), + topk_hint(0, raft::resource::get_cuda_stream(res)), + topk_workspace(0, raft::resource::get_cuda_stream(res)), + terminate_flag(0, raft::resource::get_cuda_stream(res)), + input_keys_storage(0, raft::resource::get_cuda_stream(res)), + output_keys_storage(0, raft::resource::get_cuda_stream(res)), + input_values_storage(0, raft::resource::get_cuda_stream(res)), + output_values_storage(0, raft::resource::get_cuda_stream(res)) { set_params(res); } @@ -858,6 +903,7 @@ struct search // pickup 
parent nodes uint32_t _small_hash_bitlen = 0; if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } + pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_buffer_allocation_size, itopk_size, @@ -872,9 +918,11 @@ struct search stream); // termination (2) - if (iter + 1 >= min_iterations && get_value(terminate_flag.data(), stream)) { - iter++; - break; + if (iter + 1 >= min_iterations) { + if (get_value(terminate_flag.data(), stream)) { + iter++; + break; + } } // Compute distance to child nodes that are adjacent to the parent node @@ -982,7 +1030,6 @@ struct search num_executed_iterations[i] = iter; } } - RAFT_CUDA_TRY(cudaPeekAtLastError()); } }; diff --git a/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..c3f13d07c3 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_multi_kernel_launcher_jit.cuh @@ -0,0 +1,392 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_multi_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif + +// Tags header should be included before this header (at file scope, not inside functions) +// to avoid namespace definition errors when this header is included inside function bodies + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/search_multi_kernel_planner.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "shared_launcher_jit.hpp" // For shared JIT helper functions +#include +#include +#include +#include +#include + +#include +#include +#include +// - The launcher doesn't need the kernel function definitions +// - The kernel is dispatched via the JIT LTO launcher system +// - Including it would pull in impl files that cause namespace issues + +namespace cuvs::neighbors::cagra::detail::multi_kernel_search { + +// JIT version of random_pickup +template +void random_pickup_jit(const dataset_descriptor_host& dataset_desc, + const DataT* queries_ptr, // [num_queries, dataset_dim] + std::size_t num_queries, + std::size_t num_pickup, + unsigned num_distilation, + uint64_t rand_xor_mask, + const IndexT* seed_ptr, // [num_queries, num_seeds] + uint32_t num_seeds, + IndexT* result_indices_ptr, // [num_queries, ldr] + DistanceT* result_distances_ptr, // [num_queries, ldr] + std::size_t ldr, // (*) ldr >= num_pickup + IndexT* visited_hashmap_ptr, // [num_queries, 1 << bitlen] + std::uint32_t hash_bitlen, + cudaStream_t cuda_stream) +{ + // Create planner with tags + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); // Use IndexT for source + + // Create planner and register device functions + std::shared_ptr launcher; + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + CagraMultiKernelSearchPlanner + 
planner(dataset_desc.metric, + "random_pickup", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = planner.get_launcher(); + } else { + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + "random_pickup", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = planner.get_launcher(); + } else { + using QueryTag = query_type_tag_standard_t; + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + "random_pickup", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + 
dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = planner.get_launcher(); + } + } + + const auto block_size = 256u; + const auto num_teams_per_threadblock = block_size / dataset_desc.team_size; + const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = dataset_desc.dev_ptr(cuda_stream); + + // Cast size_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t ldr_u32 = static_cast(ldr); + + // Dispatch kernel via launcher + launcher->dispatch(cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + dev_desc, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr_u32, // Cast size_t to uint32_t + visited_hashmap_ptr, + hash_bitlen); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of compute_distance_to_child_nodes +template +void compute_distance_to_child_nodes_jit( + const IndexT* parent_node_list, // [num_queries, search_width] + IndexT* const parent_candidates_ptr, // [num_queries, search_width] + DistanceT* const parent_distance_ptr, // [num_queries, search_width] + std::size_t lds, + uint32_t search_width, + const dataset_descriptor_host& dataset_desc, + const IndexT* neighbor_graph_ptr, // [dataset_size, graph_degree] + std::uint32_t graph_degree, + const SourceIndexT* source_indices_ptr, + const DataT* query_ptr, // [num_queries, data_dim] + std::uint32_t num_queries, + IndexT* visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] + std::uint32_t hash_bitlen, + IndexT* result_indices_ptr, // [num_queries, ldd] + DistanceT* result_distances_ptr, // [num_queries, ldd] + std::uint32_t ldd, // (*) ldd >= search_width * 
graph_degree + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + // Create planner with tags + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + // Create planner and register device functions + std::shared_ptr launcher; + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + "compute_distance_to_child_nodes", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = planner.get_launcher(); + } else { + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + "compute_distance_to_child_nodes", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = 
planner.get_launcher(); + } else { + using QueryTag = query_type_tag_standard_t; + CagraMultiKernelSearchPlanner + planner(dataset_desc.metric, + "compute_distance_to_child_nodes", + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + launcher = planner.get_launcher(); + } + } + + const auto block_size = 128; + const auto teams_per_block = block_size / dataset_desc.team_size; + const dim3 grid_size((search_width * graph_degree + teams_per_block - 1) / teams_per_block, + num_queries); + + // Get the device descriptor pointer + const auto* dev_desc = dataset_desc.dev_ptr(cuda_stream); + + // Dispatch kernel via launcher + launcher->dispatch(cuda_stream, + grid_size, + dim3(block_size, 1, 1), + dataset_desc.smem_ws_size_in_bytes, + parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dev_desc, + neighbor_graph_ptr, + graph_degree, + source_indices_ptr, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +// JIT version of apply_filter +template +void apply_filter_jit(const SourceIndexT* source_indices_ptr, + INDEX_T* const result_indices_ptr, + DISTANCE_T* const result_distances_ptr, + const std::size_t lds, + const std::uint32_t result_buffer_size, + const std::uint32_t num_queries, + const INDEX_T query_id_offset, + SAMPLE_FILTER_T sample_filter, + cudaStream_t cuda_stream) +{ + // Extract bitset data from filter object (if it's a bitset_filter) + 
uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + // Note: query_id_offset is already a parameter to this function, so we don't extract it here + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // Create planner with tags + using DataTag = + decltype(get_data_type_tag()); // Not used for apply_filter, but required by planner + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + // Create planner - apply_filter doesn't use dataset_descriptor, so we use dummy values + // The kernel name is "apply_filter_kernel" and build_entrypoint_name will handle it specially + using QueryTag = query_type_tag_standard_t; + using CodebookTag = void; + CagraMultiKernelSearchPlanner + planner(cuvs::distance::DistanceType::L2Expanded, + "apply_filter_kernel", + 8, + 128, + false, + 0, + 0); // Dummy values, not used by apply_filter + + // Add sample filter device function - determine filter type from template parameter + planner.add_sample_filter_device_function(get_sample_filter_name()); + + auto launcher = planner.get_launcher(); + + const std::uint32_t block_size = 256; + const std::uint32_t grid_size = raft::ceildiv(num_queries * result_buffer_size, block_size); + + // Dispatch kernel via launcher with bitset parameters + launcher->dispatch(cuda_stream, + dim3(grid_size, 1, 1), + dim3(block_size, 1, 1), + 0, // 
No shared memory needed + source_indices_ptr, + result_indices_ptr, + result_distances_ptr, + lds, + result_buffer_size, + num_queries, + query_id_offset, + bitset_ptr, + bitset_len, + original_nbits); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); +} + +} // namespace cuvs::neighbors::cagra::detail::multi_kernel_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh index 0fdf0f208b..7a6a12b67e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta.cuh @@ -34,6 +34,7 @@ #include #include +// All includes are done before opening namespace to avoid nested namespace issues namespace cuvs::neighbors::cagra::detail { namespace single_cta_search { diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh index 11b468cfca..d242e13b95 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh @@ -1,13 +1,16 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ #pragma once -#include "search_single_cta_kernel-inl.cuh" #include +// Include explicit instantiations before namespace (launcher includes JIT LTO headers with +// namespace definitions) +#include "search_single_cta_kernel_explicit_inst.cuh" + namespace cuvs::neighbors::cagra::detail::single_cta_search { #define instantiate_kernel_selection(DataT, IndexT, DistanceT, SampleFilterT) \ diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 62600c97dd..3acfd4999e 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -15,6 +15,8 @@ #include "topk_for_cagra/topk.h" // TODO replace with raft topk #include "utils.hpp" +#include + #include #include #include @@ -2215,138 +2217,5 @@ auto get_runner(Args... args) -> std::shared_ptr weak = runner; return runner; } - -template -void select_and_run( - const dataset_descriptor_host& dataset_desc, - raft::device_matrix_view graph, - std::optional> source_indices, - uintptr_t topk_indices_ptr, // [num_queries, topk] - DistanceT* topk_distances_ptr, // [num_queries, topk] - const DataT* queries_ptr, // [num_queries, dataset_dim] - uint32_t num_queries, - const IndexT* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* num_executed_iterations, // [num_queries,] - const search_params& ps, - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - IndexT* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_seeds, - SampleFilterT sample_filter, - cudaStream_t stream) -{ - const SourceIndexT* source_indices_ptr = - source_indices.has_value() ? 
source_indices->data_handle() : nullptr; - - uint32_t max_candidates{}; - if (num_itopk_candidates <= 64) { - max_candidates = 64; - } else if (num_itopk_candidates <= 128) { - max_candidates = 128; - } else if (num_itopk_candidates <= 256) { - max_candidates = 256; - } else { - max_candidates = - 32; // irrelevant, radix based topk is used (see choose_itopk_and_max_candidates) - } - - uint32_t max_itopk{}; - assert(ps.itopk_size <= 512); - if (num_itopk_candidates <= 256) { // bitonic sort - if (ps.itopk_size <= 64) { - max_itopk = 64; - } else if (ps.itopk_size <= 128) { - max_itopk = 128; - } else if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } else { // radix sort - if (ps.itopk_size <= 256) { - max_itopk = 256; - } else { - max_itopk = 512; - } - } - - if (ps.persistent) { - using runner_type = persistent_runner_t; - - get_runner(/* -Note, we're passing the descriptor by reference here, and this reference is going to be passed to a -new spawned thread, which is dangerous. However, the descriptor is copied in that thread before the -control is returned in this thread (in persistent_runner_t constructor), so we're safe. 
-*/ - std::cref(dataset_desc), - graph, - source_indices_ptr, - max_candidates, - num_itopk_candidates, - block_size, - smem_size, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - ps.num_random_samplings, - ps.rand_xor_mask, - num_seeds, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - sample_filter, - ps.persistent_lifetime, - ps.persistent_device_usage) - ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); - } else { - using descriptor_base_type = dataset_descriptor_base_t; - auto kernel = search_kernel_config:: - choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - auto const& kernel_launcher = [&](auto const& kernel) -> void { - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc.dev_ptr(stream), - queries_ptr, - graph.data_handle(), - graph.extent(1), - source_indices_ptr, - ps.num_random_samplings, - ps.rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - max_candidates, - max_itopk, - ps.itopk_size, - ps.search_width, - ps.min_iterations, - ps.max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter); - }; - cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( - kernel, smem_size, kernel_launcher); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } -} } // namespace single_cta_search } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh new file mode 100644 index 0000000000..8f715bbbc4 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_explicit_inst.cuh @@ -0,0 +1,12 @@ +/* + * 
SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifdef CUVS_ENABLE_JIT_LTO +#include "search_single_cta_kernel_launcher_jit.cuh" +#else +#include "search_single_cta_kernel_launcher.cuh" +#endif diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh new file mode 100644 index 0000000000..4b7cb0a623 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher.cuh @@ -0,0 +1,122 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "../smem_utils.cuh" + +#include "search_single_cta_kernel-inl.cuh" // For search_kernel_config, persistent_runner_t, etc. +#include "search_single_cta_kernel_launcher_common.cuh" + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +template +void select_and_run( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const SourceIndexT* source_indices_ptr = + source_indices.has_value() ? 
source_indices->data_handle() : nullptr; + + // Use common logic to compute launch config + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + + if (ps.persistent) { + using runner_type = persistent_runner_t; + + get_runner(/* +Note, we're passing the descriptor by reference here, and this reference is going to be passed to a +new spawned thread, which is dangerous. However, the descriptor is copied in that thread before the +control is returned in this thread (in persistent_runner_t constructor), so we're safe. +*/ + std::cref(dataset_desc), + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage) + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + } else { + using descriptor_base_type = dataset_descriptor_base_t; + auto kernel = search_kernel_config:: + choose_itopk_and_mx_candidates(ps.itopk_size, num_itopk_candidates, block_size); + + dim3 thread_dims(block_size, 1, 1); + dim3 block_dims(1, num_queries, 1); + RAFT_LOG_DEBUG( + "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); + auto const& kernel_launcher = [&](auto const& kernel) -> void { + kernel<<>>(topk_indices_ptr, + topk_distances_ptr, + topk, + dataset_desc.dev_ptr(stream), + queries_ptr, + graph.data_handle(), + graph.extent(1), + source_indices_ptr, + ps.num_random_samplings, + ps.rand_xor_mask, + dev_seed_ptr, + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + num_executed_iterations, + 
hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + sample_filter); + }; + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + kernel, smem_size, kernel_launcher); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh new file mode 100644 index 0000000000..b1e2191fec --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_common.cuh @@ -0,0 +1,63 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// Common logic for computing max_candidates and max_itopk +struct LaunchConfig { + uint32_t max_candidates; + uint32_t max_itopk; + bool topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps; +}; + +inline LaunchConfig compute_launch_config(uint32_t num_itopk_candidates, + uint32_t itopk_size, + uint32_t block_size) +{ + LaunchConfig config{}; + + // Compute max_candidates + if (num_itopk_candidates <= 64) { + config.max_candidates = 64; + } else if (num_itopk_candidates <= 128) { + config.max_candidates = 128; + } else if (num_itopk_candidates <= 256) { + config.max_candidates = 256; + } else { + config.max_candidates = 32; // irrelevant, radix based topk is used + } + + // Compute max_itopk and sort flags + config.topk_by_bitonic_sort = (num_itopk_candidates <= 256); + config.bitonic_sort_and_merge_multi_warps = false; + + if (config.topk_by_bitonic_sort) { + if (itopk_size <= 64) { + config.max_itopk = 64; + } else if (itopk_size <= 128) { + config.max_itopk = 128; + } else if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + config.bitonic_sort_and_merge_multi_warps = 
(block_size >= 64); + } + } else { + if (itopk_size <= 256) { + config.max_itopk = 256; + } else { + config.max_itopk = 512; + } + } + + return config; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh new file mode 100644 index 0000000000..336bff5d81 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_kernel_launcher_jit.cuh @@ -0,0 +1,1022 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "search_single_cta_kernel_launcher_jit.cuh included but CUVS_ENABLE_JIT_LTO not defined!" +#endif + +#include "../smem_utils.cuh" + +#include +#include + +// Include tags header before any other includes that might open namespaces +#include + +#include "compute_distance.hpp" // For dataset_descriptor_host +#include "jit_lto_kernels/search_single_cta_planner.hpp" +#include "sample_filter_utils.cuh" // For CagraSampleFilterWithQueryIdOffset +#include "search_plan.cuh" // For search_params +#include "search_single_cta_kernel-inl.cuh" // For resource_queue_t, local_deque_t, launcher_t, persistent_runner_base_t, etc. +#include "search_single_cta_kernel_launcher_common.cuh" +#include "shared_launcher_jit.hpp" // For shared JIT helper functions + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail::single_cta_search { + +// The launcher uses types from search_single_cta_kernel-inl.cuh (worker_handle_t, job_desc_t) +// The JIT kernel headers define _jit versions that are compatible + +// Forward declarations +template +auto get_runner_jit(Args... 
args) -> std::shared_ptr; + +template +auto create_runner_jit(Args... args) -> std::shared_ptr; + +// Helper functions are now in shared_launcher_jit.hpp + +// JIT-compatible launcher_t that works with worker_handle_t (same as non-JIT version) +struct alignas(kCacheLineBytes) launcher_jit_t { + using job_queue_type = resource_queue_t; + using worker_queue_type = resource_queue_t; + using pending_reads_queue_type = local_deque_t; + using completion_flag_type = cuda::atomic; + + pending_reads_queue_type pending_reads; + job_queue_type& job_ids; + worker_queue_type& idle_worker_ids; + worker_handle_t* worker_handles; + uint32_t job_id; + completion_flag_type* completion_flag; + bool all_done = false; + + static inline constexpr auto kDefaultLatency = std::chrono::nanoseconds(50000); + static inline constexpr auto kMaxExpectedLatency = + kDefaultLatency * std::max(10, kMaxJobsNum / 128); + static inline thread_local auto expected_latency = kDefaultLatency; + const std::chrono::time_point start; + std::chrono::time_point now; + const int64_t pause_factor; + int pause_count = 0; + std::chrono::time_point deadline; + + template + launcher_jit_t(job_queue_type& job_ids, + worker_queue_type& idle_worker_ids, + worker_handle_t* worker_handles, + uint32_t n_queries, + std::chrono::milliseconds max_wait_time, + RecordWork record_work) + : pending_reads{std::min(n_queries, kMaxWorkersPerThread)}, + job_ids{job_ids}, + idle_worker_ids{idle_worker_ids}, + worker_handles{worker_handles}, + job_id{job_ids.pop().wait()}, + completion_flag{record_work(job_id)}, + start{std::chrono::system_clock::now()}, + pause_factor{calc_pause_factor(n_queries)}, + now{start}, + deadline{start + max_wait_time + expected_latency} + { + submit_query(idle_worker_ids.pop().wait(), 0); + for (uint32_t i = 1; i < n_queries; i++) { + auto promised_worker = idle_worker_ids.pop(); + uint32_t worker_id; + while (!promised_worker.test(worker_id)) { + if (pending_reads.try_pop_front(worker_id)) { + bool 
returned_some = false; + for (bool keep_returning = true; keep_returning;) { + if (try_return_worker(worker_id)) { + keep_returning = pending_reads.try_pop_front(worker_id); + returned_some = true; + } else { + pending_reads.push_front(worker_id); + keep_returning = false; + } + } + if (!returned_some) { pause(); } + } else { + worker_id = promised_worker.wait(); + break; + } + } + pause_count = 0; + submit_query(worker_id, i); + if (i >= kSoftMaxWorkersPerThread && pending_reads.try_pop_front(worker_id)) { + if (!try_return_worker(worker_id)) { pending_reads.push_front(worker_id); } + } + } + } + + inline ~launcher_jit_t() noexcept + { + constexpr size_t kWindow = 100; + expected_latency = std::min( + ((kWindow - 1) * expected_latency + now - start) / kWindow, kMaxExpectedLatency); + if (job_id != job_queue_type::kEmpty) { job_ids.push(job_id); } + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + idle_worker_ids.push(worker_id); + } + } + + inline void submit_query(uint32_t worker_id, uint32_t query_id) + { + worker_handles[worker_id].data.store(worker_handle_t::data_t{.value = {job_id, query_id}}, + cuda::memory_order_relaxed); + while (!pending_reads.try_push_back(worker_id)) { + auto pending_worker_id = pending_reads.pop_front(); + while (!try_return_worker(pending_worker_id)) { + pause(); + } + } + pause_count = 0; + } + + inline auto try_return_worker(uint32_t worker_id) -> bool + { + if (all_done || + !is_worker_busy(worker_handles[worker_id].data.load(cuda::memory_order_relaxed).handle)) { + idle_worker_ids.push(worker_id); + return true; + } else { + return false; + } + } + + inline auto is_all_done() + { + if (all_done) { return true; } + all_done = completion_flag->load(cuda::memory_order_relaxed); + return all_done; + } + + [[nodiscard]] inline auto sleep_limit() const + { + constexpr auto kMinWakeTime = std::chrono::nanoseconds(10000); + constexpr double kSleepLimit = 0.6; + return start + expected_latency * kSleepLimit - 
kMinWakeTime; + } + + [[nodiscard]] inline auto overtime_threshold() const + { + constexpr auto kOvertimeFactor = 3; + return start + expected_latency * kOvertimeFactor; + } + + [[nodiscard]] inline auto calc_pause_factor(uint32_t n_queries) const -> uint32_t + { + constexpr uint32_t kMultiplier = 10; + return kMultiplier * raft::div_rounding_up_safe(n_queries, idle_worker_ids.capacity()); + } + + inline void pause() + { + constexpr auto kSpinLimit = 3; + constexpr auto kPauseTimeMin = std::chrono::nanoseconds(1000); + constexpr auto kPauseTimeMax = std::chrono::nanoseconds(50000); + if (pause_count++ < kSpinLimit) { + std::this_thread::yield(); + return; + } + now = std::chrono::system_clock::now(); + auto pause_time_base = std::max(now - start, expected_latency); + auto pause_time = std::clamp(pause_time_base / pause_factor, kPauseTimeMin, kPauseTimeMax); + if (now + pause_time < sleep_limit()) { + std::this_thread::sleep_for(pause_time); + } else if (now <= overtime_threshold()) { + std::this_thread::yield(); + } else if (now <= deadline) { + std::this_thread::sleep_for(pause_time); + } else { + throw raft::exception( + "The calling thread didn't receive the results from the persistent CAGRA kernel within the " + "expected kernel lifetime. 
Here are possible reasons of this failure:\n" + " (1) `persistent_lifetime` search parameter is too small - increase it;\n" + " (2) there is other work being executed on the same device and the kernel failed to " + "progress - decreasing `persistent_device_usage` may help (but not guaranteed);\n" + " (3) there is a bug in the implementation - please report it to cuVS team."); + } + } + + inline void wait() + { + uint32_t worker_id; + while (pending_reads.try_pop_front(worker_id)) { + while (!try_return_worker(worker_id)) { + if (!is_all_done()) { pause(); } + } + } + pause_count = 0; + now = std::chrono::system_clock::now(); + while (!is_all_done()) { + auto till_time = sleep_limit(); + if (now < till_time) { + std::this_thread::sleep_until(till_time); + now = std::chrono::system_clock::now(); + } else { + pause(); + } + } + job_ids.push(job_id); + job_id = job_queue_type::kEmpty; + } +}; + +// JIT persistent runner - uses AlgorithmLauncher instead of kernel function pointer +template +struct alignas(kCacheLineBytes) persistent_runner_jit_t : public persistent_runner_base_t { + using index_type = IndexT; + using distance_type = DistanceT; + using data_type = DataT; + // Use non-JIT types - JIT kernel header will alias _jit versions to these + struct job_desc_helper_desc { + using DATA_T = DataT; + using INDEX_T = IndexT; + using DISTANCE_T = DistanceT; + }; + using job_desc_type = job_desc_t; + + std::shared_ptr launcher; + uint32_t block_size; + dataset_descriptor_host dd_host; + rmm::device_uvector worker_handles; + rmm::device_uvector job_descriptors; + rmm::device_uvector completion_counters; + rmm::device_uvector hashmap; + std::atomic> last_touch; + uint64_t param_hash; + uint32_t* bitset_ptr; // Bitset data pointer (nullptr for none_filter) + SourceIndexT bitset_len; // Bitset length + SourceIndexT original_nbits; // Original number of bits + + static inline auto calculate_parameter_hash( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view 
graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr /* launcher_ptr - not part of hash */, + const void* /* dataset_desc - not part of hash */) -> uint64_t + { + return uint64_t(graph.data_handle()) ^ uint64_t(source_indices_ptr) ^ + dataset_desc.get().team_size ^ num_itopk_candidates ^ block_size ^ smem_size ^ + hash_bitlen ^ small_hash_reset_interval ^ num_random_samplings ^ rand_xor_mask ^ + num_seeds ^ itopk_size ^ search_width ^ min_iterations ^ max_iterations ^ + uint64_t(persistent_lifetime * 1000) ^ uint64_t(persistent_device_usage * 1000); + } + + persistent_runner_jit_t( + std::reference_wrapper> dataset_desc, + raft::device_matrix_view graph, + const SourceIndexT* source_indices_ptr, + uint32_t max_candidates, + uint32_t num_itopk_candidates, + uint32_t block_size, + uint32_t smem_size, + int64_t hash_bitlen, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_random_samplings, + uint64_t rand_xor_mask, + uint32_t num_seeds, + uint32_t max_itopk, + size_t itopk_size, + size_t search_width, + size_t min_iterations, + size_t max_iterations, + SampleFilterT sample_filter, + float persistent_lifetime, + float persistent_device_usage, + std::shared_ptr launcher_ptr, + const void* /* dataset_desc - descriptor contains all needed info */) + : persistent_runner_base_t{persistent_lifetime}, + launcher{launcher_ptr}, + block_size{block_size}, + worker_handles(0, stream, worker_handles_mr), + job_descriptors(kMaxJobsNum, stream, job_descriptor_mr), + 
completion_counters(kMaxJobsNum, stream, device_mr), + hashmap(0, stream, device_mr), + dd_host{dataset_desc.get()}, + param_hash(calculate_parameter_hash(dd_host, + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + num_random_samplings, + rand_xor_mask, + num_seeds, + max_itopk, + itopk_size, + search_width, + min_iterations, + max_iterations, + sample_filter, + persistent_lifetime, + persistent_device_usage, + launcher_ptr, + nullptr)) // descriptor not needed in hash + { + // Extract bitset data from filter object (if it's a bitset_filter) + // Handle both direct bitset_filter and CagraSampleFilterWithQueryIdOffset wrapper + bitset_ptr = nullptr; + bitset_len = 0; + original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + // Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // set kernel launch parameters + dim3 gs = calc_coop_grid_size(block_size, smem_size, persistent_device_usage); + dim3 bs(block_size, 1, 1); + RAFT_LOG_DEBUG( + "Launching JIT persistent kernel with %u threads, %u block %u smem", bs.x, gs.y, smem_size); + + // initialize the job queue + auto* completion_counters_ptr = completion_counters.data(); + auto* job_descriptors_ptr = job_descriptors.data(); + for (uint32_t i = 0; i < kMaxJobsNum; i++) { + auto& jd = 
job_descriptors_ptr[i].input.value; + jd.result_indices_ptr = 0; + jd.result_distances_ptr = nullptr; + jd.queries_ptr = nullptr; + jd.top_k = 0; + jd.n_queries = 0; + job_descriptors_ptr[i].completion_flag.store(false); + job_queue.push(i); + } + + // initialize the worker queue + worker_queue.set_capacity(gs.y); + worker_handles.resize(gs.y, stream); + auto* worker_handles_ptr = worker_handles.data(); + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); + for (uint32_t i = 0; i < gs.y; i++) { + worker_handles_ptr[i].data.store({kWaitForWork}); + worker_queue.push(i); + } + + index_type* hashmap_ptr = nullptr; + if (small_hash_bitlen == 0) { + hashmap.resize(gs.y * hashmap::get_size(hash_bitlen), stream); + hashmap_ptr = hashmap.data(); + } + + // Prepare kernel arguments + // Get the device descriptor pointer - kernel will use the concrete type from template + const auto* dev_desc = dataset_desc.get().dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(itopk_size); + const uint32_t search_width_u32 = static_cast(search_width); + const uint32_t min_iterations_u32 = static_cast(min_iterations); + const uint32_t max_iterations_u32 = static_cast(max_iterations); + const unsigned num_random_samplings_u = static_cast(num_random_samplings); + + // Launch the persistent kernel via AlgorithmLauncher + // The persistent kernel now takes the descriptor pointer directly + launcher->dispatch_cooperative( + stream, + gs, + bs, + smem_size, + worker_handles_ptr, + job_descriptors_ptr, + completion_counters_ptr, + 
graph.data_handle(), + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + rand_xor_mask, // uint64_t matches kernel (8 bytes) + nullptr, // seed_ptr + num_seeds, + hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + nullptr, // num_executed_iterations + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, // Pass descriptor pointer + bitset_ptr, + bitset_len, + original_nbits); + + last_touch.store(std::chrono::system_clock::now(), std::memory_order_relaxed); + } + + ~persistent_runner_jit_t() noexcept override + { + auto whs = worker_handles.data(); + for (auto i = worker_handles.size(); i > 0; i--) { + whs[worker_queue.pop().wait()].data.store({kNoMoreWork}, cuda::memory_order_relaxed); + } + RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream)); + } + + void launch(uintptr_t result_indices_ptr, + distance_type* result_distances_ptr, + const data_type* queries_ptr, + uint32_t num_queries, + uint32_t top_k) + { + launcher_jit_t launcher{job_queue, + worker_queue, + worker_handles.data(), + num_queries, + this->lifetime, + [&job_descriptors = this->job_descriptors, + result_indices_ptr, + result_distances_ptr, + queries_ptr, + top_k, + num_queries](uint32_t job_ix) { + auto& jd = job_descriptors.data()[job_ix].input.value; + auto* cflag = &job_descriptors.data()[job_ix].completion_flag; + jd.result_indices_ptr = result_indices_ptr; + jd.result_distances_ptr = result_distances_ptr; + jd.queries_ptr = queries_ptr; + jd.top_k = top_k; + jd.n_queries = num_queries; + cflag->store(false, cuda::memory_order_relaxed); + 
cuda::atomic_thread_fence(cuda::memory_order_release, + cuda::thread_scope_system); + return cflag; + }}; + + auto prev_touch = last_touch.load(std::memory_order_relaxed); + if (prev_touch + lifetime / 10 < launcher.now) { + last_touch.store(launcher.now, std::memory_order_relaxed); + } + launcher.wait(); + } + + auto calc_coop_grid_size(uint32_t block_size, uint32_t smem_size, float persistent_device_usage) + -> dim3 + { + int ctas_per_sm = 1; + cudaKernel_t kernel_handle = launcher->get_kernel(); + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel_handle, block_size, smem_size)); + int num_sm = raft::getMultiProcessorCount(); + auto n_blocks = static_cast(persistent_device_usage * (ctas_per_sm * num_sm)); + if (n_blocks > kMaxWorkersNum) { + RAFT_LOG_WARN("Limiting the grid size limit due to the size of the queue: %u -> %u", + n_blocks, + kMaxWorkersNum); + n_blocks = kMaxWorkersNum; + } + return {1, n_blocks, 1}; + } +}; + +template +void select_and_run_jit( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + const SourceIndexT* source_indices_ptr = + source_indices.has_value() ? 
source_indices->data_handle() : nullptr; + + // Extract bitset data from filter object (if it's a bitset_filter) + // Handle both direct bitset_filter and CagraSampleFilterWithQueryIdOffset wrapper + uint32_t* bitset_ptr = nullptr; + SourceIndexT bitset_len = 0; + SourceIndexT original_nbits = 0; + uint32_t query_id_offset = 0; + + // Check if it has the wrapper members (CagraSampleFilterWithQueryIdOffset) + if constexpr (requires { + sample_filter.filter; + sample_filter.offset; + }) { + using InnerFilter = decltype(sample_filter.filter); + // Always extract offset for wrapped filters + query_id_offset = sample_filter.offset; + if constexpr (is_bitset_filter::value) { + // Extract bitset data for bitset_filter (works for any bitset_filter instantiation) + auto bitset_view = sample_filter.filter.view(); + bitset_ptr = const_cast(bitset_view.data()); + bitset_len = static_cast(bitset_view.size()); + original_nbits = static_cast(bitset_view.get_original_nbits()); + } + } + + // Use common logic to compute launch config + auto config = compute_launch_config(num_itopk_candidates, ps.itopk_size, block_size); + uint32_t max_candidates = config.max_candidates; + uint32_t max_itopk = config.max_itopk; + bool topk_by_bitonic_sort = config.topk_by_bitonic_sort; + bool bitonic_sort_and_merge_multi_warps = config.bitonic_sort_and_merge_multi_warps; + + // Handle persistent kernels + if (ps.persistent) { + // Use persistent runner for JIT kernels + using runner_type = + persistent_runner_jit_t; + + // Create planner with tags for persistent kernel + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + std::shared_ptr launcher; + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, 
+ bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + true /* persistent */); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } else { + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + true /* persistent */); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } else { + using QueryTag = + query_type_tag_standard_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len, + true /* 
persistent */); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } + } + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA persistent search kernel"); } + + // Use get_runner pattern similar to non-JIT version + const auto* dev_desc_persistent = dataset_desc.dev_ptr(stream); + get_runner_jit(std::cref(dataset_desc), + graph, + source_indices_ptr, + max_candidates, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + small_hash_bitlen, + small_hash_reset_interval, + ps.num_random_samplings, + ps.rand_xor_mask, + num_seeds, + max_itopk, + ps.itopk_size, + ps.search_width, + ps.min_iterations, + ps.max_iterations, + sample_filter, + ps.persistent_lifetime, + ps.persistent_device_usage, + launcher, + dev_desc_persistent) // Pass descriptor pointer + ->launch(topk_indices_ptr, topk_distances_ptr, queries_ptr, num_queries, topk); + return; + } else { + // Create planner with tags for regular kernel + using DataTag = decltype(get_data_type_tag()); + using IndexTag = decltype(get_index_type_tag()); + using DistTag = decltype(get_distance_type_tag()); + using SourceTag = decltype(get_source_index_type_tag()); + + std::shared_ptr launcher; + if (dataset_desc.is_vpq) { + using QueryTag = query_type_tag_vpq_t; + using CodebookTag = codebook_tag_vpq_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + 
dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } else { + using CodebookTag = codebook_tag_standard_t; + if (dataset_desc.metric == cuvs::distance::DistanceType::BitwiseHamming) { + using QueryTag = + query_type_tag_standard_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } else { + using QueryTag = + query_type_tag_standard_t; + CagraSingleCtaSearchPlanner + planner(dataset_desc.metric, + topk_by_bitonic_sort, + bitonic_sort_and_merge_multi_warps, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + + planner.add_setup_workspace_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + 
planner.add_compute_distance_device_function(dataset_desc.metric, + dataset_desc.team_size, + dataset_desc.dataset_block_dim, + dataset_desc.is_vpq, + dataset_desc.pq_bits, + dataset_desc.pq_len); + planner.add_sample_filter_device_function(get_sample_filter_name()); + launcher = planner.get_launcher(); + } + } + if (!launcher) { RAFT_FAIL("Failed to get JIT launcher for CAGRA search kernel"); } + + // Get the device descriptor pointer - dev_ptr() initializes it if needed + const auto* dev_desc = dataset_desc.dev_ptr(stream); + + // Cast size_t/int64_t parameters to match kernel signature exactly + // The dispatch mechanism uses void* pointers, so parameter sizes must match exactly + const uint32_t graph_degree_u32 = static_cast(graph.extent(1)); + const uint32_t hash_bitlen_u32 = static_cast(hash_bitlen); + const uint32_t small_hash_bitlen_u32 = static_cast(small_hash_bitlen); + const uint32_t small_hash_reset_interval_u32 = static_cast(small_hash_reset_interval); + const uint32_t itopk_size_u32 = static_cast(ps.itopk_size); + const uint32_t search_width_u32 = static_cast(ps.search_width); + const uint32_t min_iterations_u32 = static_cast(ps.min_iterations); + const uint32_t max_iterations_u32 = static_cast(ps.max_iterations); + const unsigned num_random_samplings_u = static_cast(ps.num_random_samplings); + + dim3 grid(1, num_queries, 1); + dim3 block(block_size, 1, 1); + + RAFT_LOG_DEBUG("Launching JIT kernel with %u threads, %u blocks, %u smem", + block_size, + num_queries, + smem_size); + + // Dispatch kernel via launcher + auto kernel_launcher = [&](auto const& kernel) -> void { + launcher->dispatch( + stream, + grid, + block, + smem_size, + topk_indices_ptr, + topk_distances_ptr, + topk, + queries_ptr, + graph.data_handle(), + graph_degree_u32, // Cast int64_t to uint32_t + source_indices_ptr, + num_random_samplings_u, // Cast uint32_t to unsigned for consistency + ps.rand_xor_mask, // uint64_t matches kernel (8 bytes) + dev_seed_ptr, + num_seeds, + 
hashmap_ptr, + max_candidates, + max_itopk, + itopk_size_u32, // Cast size_t to uint32_t + search_width_u32, // Cast size_t to uint32_t + min_iterations_u32, // Cast size_t to uint32_t + max_iterations_u32, // Cast size_t to uint32_t + num_executed_iterations, + hash_bitlen_u32, // Cast int64_t to uint32_t + small_hash_bitlen_u32, // Cast size_t to uint32_t + small_hash_reset_interval_u32, // Cast size_t to uint32_t + query_id_offset, // Offset to add to query_id when calling filter + dev_desc, // Pass base pointer - kernel expects concrete type but pointer value is same + bitset_ptr, + bitset_len, + original_nbits); + }; + + cuvs::neighbors::detail::safely_launch_kernel_with_smem_size( + launcher->get_kernel(), smem_size, kernel_launcher); + + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } +} + +// Wrapper to match the non-JIT interface +// This function MUST be called if JIT is enabled +template +void select_and_run( + const dataset_descriptor_host& dataset_desc, + raft::device_matrix_view graph, + std::optional> source_indices, + uintptr_t topk_indices_ptr, // [num_queries, topk] + DistanceT* topk_distances_ptr, // [num_queries, topk] + const DataT* queries_ptr, // [num_queries, dataset_dim] + uint32_t num_queries, + const IndexT* dev_seed_ptr, // [num_queries, num_seeds] + uint32_t* num_executed_iterations, // [num_queries,] + const search_params& ps, + uint32_t topk, + uint32_t num_itopk_candidates, + uint32_t block_size, // + uint32_t smem_size, + int64_t hash_bitlen, + IndexT* hashmap_ptr, + size_t small_hash_bitlen, + size_t small_hash_reset_interval, + uint32_t num_seeds, + SampleFilterT sample_filter, + cudaStream_t stream) +{ + select_and_run_jit(dataset_desc, + graph, + source_indices, + topk_indices_ptr, + topk_distances_ptr, + queries_ptr, + num_queries, + dev_seed_ptr, + num_executed_iterations, + ps, + topk, + num_itopk_candidates, + block_size, + smem_size, + hash_bitlen, + hashmap_ptr, + small_hash_bitlen, + small_hash_reset_interval, + num_seeds, + 
sample_filter, + stream); +} + +// get_runner for JIT persistent runners (similar to non-JIT version) +template +auto get_runner_jit(Args... args) -> std::shared_ptr +{ + static thread_local std::weak_ptr weak; + auto runner = weak.lock(); + if (runner) { + if (runner->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner; + } else { + weak.reset(); + runner.reset(); + } + } + launcher_jit_t::expected_latency = launcher_jit_t::kDefaultLatency; + runner = create_runner_jit(args...); + weak = runner; + return runner; +} + +template +auto create_runner_jit(Args... args) -> std::shared_ptr +{ + std::lock_guard guard(persistent.lock); + std::shared_ptr runner_outer = std::dynamic_pointer_cast(persistent.runner); + if (runner_outer) { + // calculate_parameter_hash needs all args to match constructor signature + // but only uses a subset for the actual hash + if (runner_outer->param_hash == RunnerT::calculate_parameter_hash(args...)) { + return runner_outer; + } else { + runner_outer.reset(); + } + } + persistent.runner.reset(); + + cuda::std::atomic_flag ready{}; + ready.clear(cuda::std::memory_order_relaxed); + std::thread( + [&runner_outer, &ready](Args... thread_args) { + runner_outer = std::make_shared(thread_args...); + auto lifetime = runner_outer->lifetime; + persistent.runner = std::static_pointer_cast(runner_outer); + std::weak_ptr runner_weak = runner_outer; + ready.test_and_set(cuda::std::memory_order_release); + ready.notify_one(); + + while (true) { + std::this_thread::sleep_for(lifetime); + auto runner = runner_weak.lock(); + if (!runner) { return; } + if (runner->last_touch.load(std::memory_order_relaxed) + lifetime < + std::chrono::system_clock::now()) { + std::lock_guard guard(persistent.lock); + if (runner == persistent.runner) { persistent.runner.reset(); } + return; + } + } + }, + args...) 
+ .detach(); + ready.wait(false, cuda::std::memory_order_acquire); + return runner_outer; +} + +} // namespace cuvs::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/set_value_batch.cuh b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh new file mode 100644 index 0000000000..a4433005a7 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/set_value_batch.cuh @@ -0,0 +1,40 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +template +__global__ void set_value_batch_kernel(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size) +{ + const auto tid = threadIdx.x + blockIdx.x * blockDim.x; + if (tid >= count * batch_size) { return; } + const auto batch_id = tid / count; + const auto elem_id = tid % count; + dev_ptr[elem_id + ld * batch_id] = val; +} + +template +void set_value_batch(T* const dev_ptr, + const std::size_t ld, + const T val, + const std::size_t count, + const std::size_t batch_size, + cudaStream_t cuda_stream) +{ + constexpr std::uint32_t block_size = 256; + const auto grid_size = (count * batch_size + block_size - 1) / block_size; + set_value_batch_kernel + <<>>(dev_ptr, ld, val, count, batch_size); +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp new file mode 100644 index 0000000000..33dcf9cbf9 --- /dev/null +++ b/cpp/src/neighbors/detail/cagra/shared_launcher_jit.hpp @@ -0,0 +1,118 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. + * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#ifndef CUVS_ENABLE_JIT_LTO +#error "shared_launcher_jit.hpp included but CUVS_ENABLE_JIT_LTO not defined!" 
+#endif + +// Include tags header before any other includes that might open namespaces +#include + +#include "../../sample_filter.cuh" // For none_sample_filter, bitset_filter + +#include +#include +#include +#include +#include + +namespace cuvs::neighbors::cagra::detail { + +// Helper functions to get tags for JIT LTO +template +constexpr auto get_data_type_tag() +{ + if constexpr (std::is_same_v) { return tag_f{}; } + if constexpr (std::is_same_v) { return tag_h{}; } + if constexpr (std::is_same_v) { return tag_sc{}; } + if constexpr (std::is_same_v) { return tag_uc{}; } +} + +template +constexpr auto get_index_type_tag() +{ + if constexpr (std::is_same_v) { return tag_idx_ui{}; } +} + +template +constexpr auto get_distance_type_tag() +{ + if constexpr (std::is_same_v) { return tag_dist_f{}; } +} + +template +constexpr auto get_source_index_type_tag() +{ + if constexpr (std::is_same_v) { return tag_idx_ui{}; } + if constexpr (std::is_same_v) { return tag_idx_l{}; } +} + +template +struct query_type_tag_standard { + using type = std::conditional_t, + tag_uc, + tag_f>; +}; + +template +using query_type_tag_standard_t = typename query_type_tag_standard::type; + +template +using query_type_tag_vpq_t = tag_h; + +template +using query_type_tag_standard_l2_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_inner_product_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_cosine_t = + query_type_tag_standard_t; +template +using query_type_tag_standard_hamming_t = + query_type_tag_standard_t; + +using codebook_tag_vpq_t = tag_codebook_half; +using codebook_tag_standard_t = void; + +// Helper trait to detect if a type is a bitset_filter (regardless of template parameters) +template +struct is_bitset_filter : std::false_type {}; + +template +struct is_bitset_filter> + : std::true_type {}; + +template +std::string get_sample_filter_name() +{ + using namespace cuvs::neighbors::filtering; + using DecayedFilter = std::decay_t; + + // 
First check for none_sample_filter (the only unwrapped case) + if constexpr (std::is_same_v) { + return "filter_none_source_index_ui"; + } + + // All other filters are wrapped in CagraSampleFilterWithQueryIdOffset + // Access the inner filter type via decltype + if constexpr (requires { std::declval().filter; }) { + using InnerFilter = decltype(std::declval().filter); + if constexpr (is_bitset_filter::value || + std::is_same_v> || + std::is_same_v>) { + return "filter_bitset_source_index_ui"; + } + } + + // Default to none filter for unknown types + return "filter_none_source_index_ui"; +} + +} // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh b/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh new file mode 100644 index 0000000000..415fae7075 --- /dev/null +++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_bitset.cuh @@ -0,0 +1,77 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include "filter_data.h" + +namespace cuvs::neighbors::detail { + +// Inline implementation of bitset_view::test() to avoid including bitset.cuh +// which transitively includes Thrust +template +__device__ inline bool bitset_view_test(const bitset_t* bitset_ptr, + index_t bitset_len, + index_t original_nbits, + index_t sample_index) +{ + constexpr index_t bitset_element_size = sizeof(bitset_t) * 8; + const index_t nbits = sizeof(bitset_t) * 8; + index_t bit_index = 0; + index_t bit_offset = 0; + + if (original_nbits == 0 || nbits == original_nbits) { + bit_index = sample_index / bitset_element_size; + bit_offset = sample_index % bitset_element_size; + } else { + // Handle original_nbits != nbits case + const index_t original_bit_index = sample_index / original_nbits; + const index_t original_bit_offset = sample_index % original_nbits; + bit_index = original_bit_index * original_nbits / nbits; + bit_offset = 0; + if (original_nbits > nbits) { + bit_index += original_bit_offset / nbits; + bit_offset = original_bit_offset % nbits; + } else { + index_t ratio = nbits / original_nbits; + bit_offset += (original_bit_index % ratio) * original_nbits; + bit_offset += original_bit_offset % nbits; + } + } + const bitset_t bit_element = bitset_ptr[bit_index]; + const bool is_bit_set = (bit_element & (bitset_t{1} << bit_offset)) != 0; + return is_bit_set; +} + +// Unified sample_filter: takes query_id, node_id, and void* filter_data +// Used by both CAGRA and IVF Flat +// For IVF Flat: node_id should be computed from (cluster_ix, sample_ix) using inds_ptrs from +// filter_data +template +__device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data) +{ + // bitset_filter checks if the node_id is in the bitset + // filter_data points to bitset_filter_data_t struct + if (filter_data == nullptr) { + return true; // No filter data, allow all + } + + auto* bitset_data = 
static_cast*>(filter_data); + if (bitset_data->bitset_ptr == nullptr) { + return true; // No bitset provided, allow all + } + + // Directly test the bitset without needing bitset_filter wrapper + // bitset_view_test returns true if the bit is set (node_id is in the bitset) + // The bitset marks allowed indices (same as non-JIT bitset_filter which returns test() directly) + // Return true if the bit is set (node is allowed), false if not set (node should be filtered out) + bool is_in_bitset = bitset_view_test( + bitset_data->bitset_ptr, bitset_data->bitset_len, bitset_data->original_nbits, node_id); + // If node_id is in the bitset (allowed), return true to allow it + // If node_id is not in the bitset, return false to reject it + return is_in_bitset; +} + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h b/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h new file mode 100644 index 0000000000..9fc4336872 --- /dev/null +++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_data.h @@ -0,0 +1,28 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include +#include + +namespace cuvs::neighbors::detail { + +// Structure to hold bitset filter data +// This is passed as void* to the extern sample_filter function +// Used by both CAGRA and IVF Flat +template +struct bitset_filter_data_t { + uint32_t* bitset_ptr; // Pointer to bitset data in global memory + SourceIndexT bitset_len; // Length of bitset array + SourceIndexT original_nbits; // Original number of bits + + __device__ bitset_filter_data_t(uint32_t* ptr, SourceIndexT len, SourceIndexT nbits) + : bitset_ptr(ptr), bitset_len(len), original_nbits(nbits) + { + } +}; + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in b/cpp/src/neighbors/detail/jit_lto_kernels/filter_embedded.cpp.in similarity index 61% rename from cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in rename to cpp/src/neighbors/detail/jit_lto_kernels/filter_embedded.cpp.in index a5a7299b73..3e00b20e9a 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_embedded.cpp.in +++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_embedded.cpp.in @@ -1,10 +1,8 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. 
- #include #include "@embedded_header_file@" @@ -13,7 +11,7 @@ namespace { __attribute__((__constructor__)) void register_kernel() { registerAlgorithm( - "@filter_name@", + "sample_filter_@filter_name@_source_index_@source_index_abbrev@", embedded_fatbin, sizeof(embedded_fatbin)); } diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in b/cpp/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in new file mode 100644 index 0000000000..7350f6bb58 --- /dev/null +++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_kernel.cu.in @@ -0,0 +1,14 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +namespace @namespace@ { + +// Instantiate the sample_filter device function template +// CAGRA style: sample_filter(query_id, node_id, filter_data) +template __device__ bool sample_filter<@source_index_type@>(uint32_t, @source_index_type@, void*); + +} // namespace @namespace@ diff --git a/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh b/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh new file mode 100644 index 0000000000..e3ca5496c1 --- /dev/null +++ b/cpp/src/neighbors/detail/jit_lto_kernels/filter_none.cuh @@ -0,0 +1,22 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#pragma once + +#include + +namespace cuvs::neighbors::detail { + +// Unified sample_filter: takes query_id, node_id, and void* filter_data +// Used by both CAGRA and IVF Flat +template +__device__ bool sample_filter(uint32_t query_id, SourceIndexT node_id, void* filter_data) +{ + // none_sample_filter always returns true (no filtering) + // filter_data is ignored (can be nullptr) + return true; +} + +} // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/detail/smem_utils.cuh b/cpp/src/neighbors/detail/smem_utils.cuh index 41c95c0ccd..8e9cde4ff2 100644 --- a/cpp/src/neighbors/detail/smem_utils.cuh +++ b/cpp/src/neighbors/detail/smem_utils.cuh @@ -8,37 +8,24 @@ #include #include +#include #include +#include namespace cuvs::neighbors::detail { -/** - * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size. - * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed - * atomically. - * - * Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus - * guarantees that the kernel is safe to launch. - * - * @tparam KernelT The type of the kernel. - * @tparam InvocationT The type of the invocation function. - * @param kernel The kernel function address (for whom the smem-size is specified). - * @param smem_size The size of the dynamic shared memory to be set. - * @param launch The kernel launch function/lambda. - */ template -void safely_launch_kernel_with_smem_size(KernelT const& kernel, - uint32_t smem_size, - KernelLauncherT const& launch) +void safely_launch_kernel_with_smem_size_impl(KernelT const& kernel, + uint32_t smem_size, + KernelLauncherT const& launch, + std::mutex& mutex, + std::atomic& current_smem_size) { - // the last smem size is parameterized by the kernel thanks to the template parameter. 
- static std::atomic current_smem_size{0}; auto last_smem_size = current_smem_size.load(std::memory_order_relaxed); if (smem_size > last_smem_size) { // We still need a mutex for the critical section: actualize last_smem_size and set the // attribute. - static auto mutex = std::mutex{}; - auto guard = std::lock_guard{mutex}; + auto guard = std::lock_guard{mutex}; if (!current_smem_size.compare_exchange_strong( last_smem_size, smem_size, std::memory_order_relaxed, std::memory_order_relaxed)) { // The value has been updated by another thread between the load and the mutex acquisition. @@ -59,4 +46,52 @@ void safely_launch_kernel_with_smem_size(KernelT const& kernel, return launch(kernel); } +/** + * @brief (Thread-)Safely invoke a kernel with a maximum dynamic shared memory size. + * This is required because the sequence `cudaFuncSetAttribute` + kernel launch is not executed + * atomically. + * + * Used this way, the cudaFuncAttributeMaxDynamicSharedMemorySize can only grow and thus + * guarantees that the kernel is safe to launch. + * + * @tparam KernelT The type of the kernel. + * @tparam InvocationT The type of the invocation function. + * @param kernel The kernel function address (for whom the smem-size is specified). + * @param smem_size The size of the dynamic shared memory to be set. + * @param launch The kernel launch function/lambda. 
+ */ +// Specialization for cudaKernel_t (JIT LTO kernels) - track by kernel pointer +template +void safely_launch_kernel_with_smem_size(cudaKernel_t kernel, + uint32_t smem_size, + KernelLauncherT const& launch) +{ + // For JIT kernels, track by kernel pointer since all cudaKernel_t have the same type + static std::unordered_map>> + jit_smem_sizes; + static std::mutex map_mutex; // static: one mutex shared by all callers must guard the static map + + std::pair>* current_smem_size; + { + std::lock_guard map_lock{map_mutex}; + current_smem_size = &jit_smem_sizes[kernel]; + } + safely_launch_kernel_with_smem_size_impl( + kernel, smem_size, launch, current_smem_size->first, current_smem_size->second); +} + +// General template for regular function pointers +template +void safely_launch_kernel_with_smem_size(KernelT const& kernel, + uint32_t smem_size, + KernelLauncherT const& launch) +{ + // the last smem size is parameterized by the kernel thanks to the template parameter. + static std::atomic current_smem_size{0}; + static std::mutex mutex; + + safely_launch_kernel_with_smem_size_impl( + kernel, smem_size, launch, mutex, current_smem_size); +} + } // namespace cuvs::neighbors::detail diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index 4c0bb3644a..1324a1d41a 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -1293,6 +1293,7 @@ struct select_interleaved_scan_kernel { */ template void ivfflat_interleaved_scan(const index& index, + const search_params& params, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh index 81833a63b1..cfb4982adf 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_explicit_inst.cuh @@ -23,6 +23,7 @@ typename cuvs::spatial::knn::detail::utils::config::value_t, \ IdxT, \ SampleFilterT>(const index& index, \ + const search_params& params, \ const T* queries, \ const uint32_t* coarse_query_results, \ const uint32_t n_queries, \ diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh index 1d63c52adb..e293362327 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_ext.cuh @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
* SPDX-License-Identifier: Apache-2.0 */ @@ -17,6 +17,7 @@ namespace cuvs::neighbors::ivf_flat::detail { template void ivfflat_interleaved_scan(const index& index, + const search_params& params, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, @@ -39,6 +40,7 @@ void ivfflat_interleaved_scan(const index& index, typename cuvs::spatial::knn::detail::utils::config::value_t, \ IdxT, \ SampleFilterT>(const index& index, \ + const search_params& params, \ const T* queries, \ const uint32_t* coarse_query_results, \ const uint32_t n_queries, \ diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh index be8652dd59..256e37221a 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan_jit.cuh @@ -5,10 +5,12 @@ #pragma once +#include "../detail/jit_lto_kernels/filter_data.h" #include "../ivf_common.cuh" #include "jit_lto_kernels/interleaved_scan_planner.hpp" #include -#include +#include +#include #include #include @@ -52,17 +54,27 @@ constexpr auto get_idx_type_tag() if constexpr (std::is_same_v) { return tag_idx_l{}; } } +// Convert type to string for JIT code generation +template +constexpr const char* type_name() +{ + if constexpr (std::is_same_v) { return "float"; } + if constexpr (std::is_same_v) { return "__half"; } + if constexpr (std::is_same_v) { return "int8_t"; } + if constexpr (std::is_same_v) { return "uint8_t"; } + if constexpr (std::is_same_v) { return "int32_t"; } + if constexpr (std::is_same_v) { return "uint32_t"; } + if constexpr (std::is_same_v) { return "int64_t"; } +} + template constexpr auto get_filter_type_tag() { using namespace cuvs::neighbors::filtering; - // Determine the filter implementation tag - if constexpr (std::is_same_v) { - return tag_filter{}; - } + if constexpr (std::is_same_v) { return tag_filter_none{}; } if constexpr (std::is_same_v>) { - return 
tag_filter{}; + return tag_filter_bitset{}; } } @@ -71,20 +83,21 @@ constexpr auto get_metric_name() { if constexpr (std::is_same_v>) { return "euclidean"; - } - if constexpr (std::is_same_v>) { + } else if constexpr (std::is_same_v>) { return "inner_prod"; + } else if constexpr (std::is_same_v>) { + return "metric_udf"; } } template constexpr auto get_filter_name() { - if constexpr (std::is_same_v>) { - return "filter_none"; + if constexpr (std::is_same_v) { + return "filter_none_source_index_l"; } - if constexpr (std::is_same_v>) { - return "filter_bitset"; + if constexpr (std::is_same_v) { + return "filter_bitset_source_index_l"; } } @@ -128,6 +141,7 @@ template void launch_kernel(const index& index, + const search_params& params, const T* queries, const uint32_t* coarse_index, const uint32_t num_queries, @@ -153,9 +167,38 @@ void launch_kernel(const index& index, decltype(get_acc_type_tag()), decltype(get_idx_type_tag())>( Capacity, Veclen, Ascending, ComputeNorm); - kernel_planner.template add_metric_device_function()), - decltype(get_acc_type_tag())>( - get_metric_name(), Veclen); + if (params.metric_udf.has_value()) { + std::string metric_udf = params.metric_udf.value(); + // Add explicit template instantiation with actual types + metric_udf += "\ntemplate void cuvs::neighbors::ivf_flat::detail::compute_dist<"; + metric_udf += std::to_string(Veclen); + metric_udf += ", "; + metric_udf += type_name(); + metric_udf += ", "; + metric_udf += type_name(); + metric_udf += ">("; + metric_udf += type_name(); + metric_udf += "&, "; + metric_udf += type_name(); + metric_udf += ", "; + metric_udf += type_name(); + metric_udf += ");\n"; + // Include hash of UDF source in key to differentiate different UDFs + auto udf_hash = std::to_string(std::hash{}(metric_udf)); + std::string metric_name = "metric_udf_" + udf_hash; + auto& nvrtc_lto_compiler = nvrtc_compiler(); + std::string key = + metric_name + "_veclen_" + std::to_string(Veclen) + "_" + + make_fragment_key()), 
decltype(get_acc_type_tag())>(); + nvrtc_lto_compiler.compile(key, metric_udf); + kernel_planner.template add_metric_device_function()), + decltype(get_acc_type_tag())>( + metric_name, Veclen); + } else { + kernel_planner.template add_metric_device_function()), + decltype(get_acc_type_tag())>( + get_metric_name(), Veclen); + } kernel_planner.add_filter_device_function(get_filter_name()); kernel_planner.add_post_lambda_device_function(get_post_lambda_name()); auto kernel_launcher = kernel_planner.get_launcher(); @@ -182,6 +225,9 @@ void launch_kernel(const index& index, return; } + // Pass individual filter parameters like CAGRA does + // The kernel will construct filter_data struct internally when needed + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); dim3 grid_dim(grid_dim_x, grid_dim_y, 1); @@ -209,7 +255,6 @@ void launch_kernel(const index& index, max_samples, chunk_indices, index.dim(), - // sample_filter, inds_ptrs, bitset_ptr.value_or(nullptr), bitset_len.value_or(0), @@ -289,6 +334,17 @@ void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... arg tag_post_compose>( std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when // adding here a new metric. 
+ case cuvs::distance::DistanceType::CustomUDF: + return launch_kernel, + tag_post_identity>(std::forward(args)...); default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } } @@ -390,6 +446,7 @@ struct select_interleaved_scan_kernel { */ template void ivfflat_interleaved_scan(const index& index, + const search_params& params, const T* queries, const uint32_t* coarse_query_results, const uint32_t n_queries, @@ -424,6 +481,7 @@ void ivfflat_interleaved_scan(const index& index, select_min, metric, index, + params, queries, coarse_query_results, n_queries, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 3379e7b8dc..f7384294f5 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -26,6 +26,8 @@ #include +#include + namespace cuvs::neighbors::ivf_flat::detail { using namespace cuvs::spatial::knn::detail; // NOLINT @@ -38,6 +40,7 @@ auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool template void search_impl(raft::resources const& handle, const cuvs::neighbors::ivf_flat::index& index, + const search_params& params, const T* queries, uint32_t n_queries, uint32_t queries_offset, @@ -190,6 +193,7 @@ void search_impl(raft::resources const& handle, // query the gridDimX size to store probes topK output ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, + params, nullptr, nullptr, n_queries, @@ -245,6 +249,7 @@ void search_impl(raft::resources const& handle, ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, + params, queries, coarse_indices_dev.data(), n_queries, @@ -350,6 +355,7 @@ inline void search_with_filtering(raft::resources const& handle, search_impl(handle, index, + params, queries + offset_q * index.dim(), queries_batch, offset_q, @@ -383,6 +389,13 @@ void search_with_filtering(raft::resources const& handle, RAFT_EXPECTS(queries.extent(1) == index.dim(), "Number 
of query dimensions should equal number of dimensions in the index."); + // Save original metric and temporarily set to CustomUDF if using UDF + auto original_metric = index.metric(); + if (params.metric_udf.has_value()) { + const_cast>&>(index).set_metric( + cuvs::distance::DistanceType::CustomUDF); + } + search_with_filtering(handle, params, index, @@ -392,6 +405,12 @@ void search_with_filtering(raft::resources const& handle, neighbors.data_handle(), distances.data_handle(), sample_filter); + + // Restore original metric + if (params.metric_udf.has_value()) { + const_cast>&>(index).set_metric( + original_metric); + } } template diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh deleted file mode 100644 index 07fc4a21f5..0000000000 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh +++ /dev/null @@ -1,30 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "../../sample_filter.cuh" - -namespace cuvs::neighbors::ivf_flat::detail { - -template -__device__ bool sample_filter(index_t* const* const inds_ptrs, - const uint32_t query_ix, - const uint32_t cluster_ix, - const uint32_t sample_ix, - uint32_t* bitset_ptr, - index_t bitset_len, - index_t original_nbits) -{ - auto bitset_view = - raft::core::bitset_view{bitset_ptr, bitset_len, original_nbits}; - auto bitset_filter = cuvs::neighbors::filtering::bitset_filter{bitset_view}; - auto ivf_to_sample_filter = cuvs::neighbors::filtering:: - ivf_to_sample_filter>{ - inds_ptrs, bitset_filter}; - return ivf_to_sample_filter(query_ix, cluster_ix, sample_ix); -} - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_kernel.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_kernel.cu.in deleted file mode 100644 index a4c2f18f53..0000000000 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_kernel.cu.in +++ /dev/null @@ -1,15 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: Apache-2.0 - */ - -// This file is auto-generated. Do not edit manually. 
- -#include <@header_file@> - -namespace cuvs::neighbors::ivf_flat::detail { - -// Instantiate the device function template -template __device__ bool sample_filter(int64_t* const* const, const uint32_t, const uint32_t, const uint32_t, uint32_t*, int64_t, int64_t); - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json index 6ceebe78c3..f82737eb93 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_matrix.json @@ -1,12 +1,15 @@ { - "_filter": [ + "filter_name": [ + "filter_none", + "filter_bitset" + ], + "_source_index": [ { - "filter_name": "filter_none", - "header_file": "neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh" - }, - { - "filter_name": "filter_bitset", - "header_file": "neighbors/ivf_flat/jit_lto_kernels/filter_bitset.cuh" + "source_index_type": "int64_t", + "source_index_abbrev": "l" } + ], + "namespace": [ + "cuvs::neighbors::detail" ] } diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh deleted file mode 100644 index aad15d64bc..0000000000 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/filter_none.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. 
- * SPDX-License-Identifier: Apache-2.0 - */ - -#pragma once - -#include "../../sample_filter.cuh" - -namespace cuvs::neighbors::ivf_flat::detail { - -template -__device__ constexpr bool sample_filter(index_t* const* const inds_ptrs, - const uint32_t query_ix, - const uint32_t cluster_ix, - const uint32_t sample_ix, - uint32_t* bitset_ptr, - index_t bitset_len, - index_t original_nbits) -{ - return true; -} - -} // namespace cuvs::neighbors::ivf_flat::detail diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in index 7e247543b0..9270d254fc 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_embedded.cpp.in @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. - #include -#include +#include #include "@embedded_header_file@" using namespace cuvs::neighbors::ivf_flat::detail; diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in index b7ada4de5b..b199f0e4d8 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_kernel.cu.in @@ -3,13 +3,12 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. 
- #include namespace cuvs::neighbors::ivf_flat::detail { // Instantiate the kernel template +// Pass individual filter parameters like CAGRA does template __global__ void interleaved_scan_kernel<@capacity@, @veclen@, @ascending_value@, @compute_norm_value@, @data_type@, @acc_type@, @idx_type@>( const uint32_t, const @data_type@*, const uint32_t*, const @data_type@* const*, const uint32_t*, const uint32_t, const uint32_t, const uint32_t, const uint32_t, const uint32_t*, const uint32_t, diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp index 77592e8fc3..a539a4e065 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/interleaved_scan_planner.hpp @@ -31,7 +31,7 @@ struct InterleavedScanPlanner : AlgorithmPlanner { void add_filter_device_function(std::string filter_name) { - auto key = filter_name; + auto key = "sample_filter_" + filter_name; this->device_functions.push_back(key); } diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh index 3a14fe8afd..c5ff7b35f3 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/ivf_flat_interleaved_scan_kernel.cuh @@ -5,10 +5,8 @@ #pragma once -#include "../../ivf_common.cuh" - -#include - +#include "../../detail/cagra/jit_lto_kernels/extern_device_functions.cuh" +#include "../../detail/jit_lto_kernels/filter_data.h" #include #include #include @@ -19,20 +17,29 @@ namespace cuvs::neighbors::ivf_flat::detail { +// Define kIndexGroupSize locally to avoid including ivf_flat.hpp which transitively includes Thrust +constexpr static uint32_t kIndexGroupSize = 32; + +// Define dummy_block_sort_t locally to avoid including ivf_common.cuh which 
transitively includes +// Thrust +template +struct dummy_block_sort_t { + using queue_t = raft::matrix::detail::select::warpsort:: + warp_sort_distributed; + template + __device__ dummy_block_sort_t(int k, Args...) {}; +}; + static constexpr int kThreadsPerBlock = 128; // These extern device functions are linked at runtime using JIT-LTO. template extern __device__ void compute_dist(AccT& acc, AccT x, AccT y); -template -extern __device__ bool sample_filter(index_t* const* const inds_ptrs, - const uint32_t query_ix, - const uint32_t cluster_ix, - const uint32_t sample_ix, - uint32_t* bitset_ptr, - index_t bitset_len, - index_t original_nbits); +// Unified sample_filter interface: takes query_id, node_id, and void* filter_data +// For IVF Flat: node_id should be computed from (cluster_ix, sample_ix) using inds_ptrs from +// filter_data sample_filter is declared in extern_device_functions.cuh (shared with CAGRA) +using cuvs::neighbors::detail::sample_filter; template extern __device__ T post_process(T val); @@ -725,9 +732,8 @@ struct flat_block_sort { }; template -struct flat_block_sort<0, Ascending, T, IdxT> - : ivf::detail::dummy_block_sort_t { - using type = ivf::detail::dummy_block_sort_t; +struct flat_block_sort<0, Ascending, T, IdxT> : dummy_block_sort_t { + using type = dummy_block_sort_t; }; template @@ -776,10 +782,11 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) const uint32_t max_samples, const uint32_t* chunk_indices, const uint32_t dim, - IdxT* const* const inds_ptrs, - uint32_t* bitset_ptr, - IdxT bitset_len, - IdxT original_nbits, + IdxT* const* const inds_ptrs, // Always needed for IVF Flat to convert + // (list_id, vec_id) to node_id + uint32_t* bitset_ptr, // Bitset data pointer (nullptr for none_filter) + IdxT bitset_len, // Bitset length + IdxT original_nbits, // Original number of bits uint32_t* neighbors, float* distances) { @@ -805,7 +812,7 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) } // Copy a part of the query into shared memory 
for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); + raft::copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); __syncthreads(); using local_topk_t = block_sort_t; @@ -851,13 +858,15 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) // This is the vector a given lane/thread handles const uint32_t vec_id = group_id * raft::WarpSize + lane_id; - const bool valid = vec_id < list_length && sample_filter(inds_ptrs, - queries_offset + blockIdx.y, - list_id, - vec_id, - bitset_ptr, - bitset_len, - original_nbits); + // For IVF Flat, convert (list_id, vec_id) to node_id using inds_ptrs + const IdxT node_id = inds_ptrs[list_id][vec_id]; + // Construct filter_data struct (bitset data is in global memory) + cuvs::neighbors::detail::bitset_filter_data_t filter_data( + bitset_ptr, bitset_len, original_nbits); + const bool valid = + vec_id < list_length && + sample_filter( + queries_offset + blockIdx.y, node_id, bitset_ptr != nullptr ? &filter_data : nullptr); if (valid) { // Process first shm_assisted_dim dimensions (always using shared memory) diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in index b951476565..a979411143 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_embedded.cpp.in @@ -3,10 +3,8 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. 
- #include -#include +#include #include "@embedded_header_file@" using namespace cuvs::neighbors::ivf_flat::detail; diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in index a67956db58..09dedc2bb2 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/metric_kernel.cu.in @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. - #include <@header_file@> namespace cuvs::neighbors::ivf_flat::detail { diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in index a2e3f1ea03..b3449e8e17 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_embedded.cpp.in @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. - #include #include "@embedded_header_file@" diff --git a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in index 363964dd42..99823843c6 100644 --- a/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in +++ b/cpp/src/neighbors/ivf_flat/jit_lto_kernels/post_lambda_kernel.cu.in @@ -3,8 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ -// This file is auto-generated. Do not edit manually. - #include <@header_file@> namespace cuvs::neighbors::ivf_flat::detail { diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp index 77b24d4690..6ab162a117 100644 --- a/cpp/src/neighbors/ivf_flat_index.cpp +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -1,5 +1,5 @@ /* - * SPDX-FileCopyrightText: Copyright (c) 2024, NVIDIA CORPORATION. 
+ * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION. * SPDX-License-Identifier: Apache-2.0 */ @@ -60,6 +60,12 @@ cuvs::distance::DistanceType index::metric() const noexcept return metric_; } +template +void index::set_metric(cuvs::distance::DistanceType metric) +{ + metric_ = metric; +} + template bool index::adaptive_centers() const noexcept { diff --git a/cpp/src/neighbors/refine/refine_device.cuh b/cpp/src/neighbors/refine/refine_device.cuh index 10873891a6..e09bb82efe 100644 --- a/cpp/src/neighbors/refine/refine_device.cuh +++ b/cpp/src/neighbors/refine/refine_device.cuh @@ -103,6 +103,7 @@ void refine_device( cuvs::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( refinement_index, + cuvs::neighbors::ivf_flat::search_params(), queries.data_handle(), fake_coarse_idx.data(), static_cast(n_queries), diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index 35794adf9b..164f0599c1 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -131,6 +131,15 @@ ConfigureTest( PERCENT 100 ) +if(JIT_LTO_COMPILATION) + ConfigureTest( + NAME NEIGHBORS_ANN_IVF_FLAT_UDF_TEST + PATH neighbors/ann_ivf_flat/test_udf.cu + GPUS 1 + PERCENT 100 + ) +endif() + ConfigureTest( NAME NEIGHBORS_ANN_IVF_PQ_TEST PATH neighbors/ann_ivf_pq/test_float_int64_t.cu neighbors/ann_ivf_pq/test_int8_t_int64_t.cu diff --git a/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu b/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu new file mode 100644 index 0000000000..ef6486ed4e --- /dev/null +++ b/cpp/tests/neighbors/ann_ivf_flat/test_udf.cu @@ -0,0 +1,387 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include + +#include +#include +#include +#include +#include +#include + +#include + +namespace cuvs::neighbors::ivf_flat { + +// ============================================================================ +// Define custom metrics using the UDF macro +// ============================================================================ + +// Custom L2 (squared Euclidean) metric - should match built-in L2 +CUVS_METRIC(custom_l2, { acc += squared_diff(x, y); }) + +// ============================================================================ +// Test data traits for different types +// ============================================================================ + +template +struct TestDataTraits; + +template <> +struct TestDataTraits { + static constexpr int64_t dim = 4; + static constexpr int64_t num_db_vecs = 8; + + static std::vector database() + { + // 4-dimensional float dataset + // Vectors arranged for easy distance verification: + // db[0] = [0, 0, 0, 0] - origin + // db[1] = [1, 0, 0, 0] - unit along x + // db[2] = [0, 1, 0, 0] - unit along y + // db[3] = [0, 0, 1, 0] - unit along z + // db[4] = [1, 1, 0, 0] - diagonal in xy + // db[5] = [2, 0, 0, 0] - 2 units along x + // db[6] = [1, 1, 1, 1] - all ones + // db[7] = [3, 4, 0, 0] - for 3-4-5 triangle + return { + 0.0f, 0.0f, 0.0f, 0.0f, // db[0]: origin + 1.0f, 0.0f, 0.0f, 0.0f, // db[1]: L2 dist from origin = 1 + 0.0f, 1.0f, 0.0f, 0.0f, // db[2]: L2 dist from origin = 1 + 0.0f, 0.0f, 1.0f, 0.0f, // db[3]: L2 dist from origin = 1 + 1.0f, 1.0f, 0.0f, 0.0f, // db[4]: L2 dist from origin = 2 + 2.0f, 0.0f, 0.0f, 0.0f, // db[5]: L2 dist from origin = 4 + 1.0f, 1.0f, 1.0f, 1.0f, // db[6]: L2 dist from origin = 4 + 3.0f, 4.0f, 0.0f, 0.0f, // db[7]: L2 dist from origin = 25 + }; + } + + static std::vector queries() + { + // query[0] = origin - nearest is db[0] (dist=0) + // query[1] = [1,0,0,0] - nearest is db[1] (dist=0) + return { + 0.0f, + 0.0f, + 0.0f, + 0.0f, // 
query[0]: origin + 1.0f, + 0.0f, + 0.0f, + 0.0f, // query[1]: same as db[1] + }; + } + + // Expected: query[0] nearest is db[0] with distance 0 + static int64_t expected_nearest_idx_q0() { return 0; } + static float expected_nearest_dist_q0() { return 0.0f; } + + // Expected: query[1] nearest is db[1] with distance 0 + static int64_t expected_nearest_idx_q1() { return 1; } + static float expected_nearest_dist_q1() { return 0.0f; } +}; + +template <> +struct TestDataTraits { + static constexpr int64_t dim = 16; + static constexpr int64_t num_db_vecs = 8; + + static std::vector database() + { + // 16-dimensional int8 dataset to test vectorized SIMD intrinsics + return { + // db[0]: all zeros + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + // db[1]: unit in first dim + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + // db[2]: unit in second dim + 0, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + // db[3]: all ones - L2 dist from zeros = 16 + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + // db[4]: first 12 dims are 2 - L2 dist from zeros = 48 + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 0, + 0, + 0, + 0, + // db[5]: all twos - L2 dist from zeros = 64 + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + // db[6]: alternating 1,0 - L2 dist from zeros = 8 + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + // db[7]: alternating 0,1 - L2 dist from zeros = 8 + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + 0, + 1, + }; + } + + static std::vector queries() + { + // query[0] = all zeros - nearest is db[0] (dist=0) + // query[1] = all ones - nearest is db[3] (dist=0) + return { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // query[0] + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // query[1] + }; + } + + // Expected: query[0] nearest is db[0] 
with distance 0 + static int64_t expected_nearest_idx_q0() { return 0; } + static float expected_nearest_dist_q0() { return 0.0f; } + + // Expected: query[1] nearest is db[3] with distance 0 + static int64_t expected_nearest_idx_q1() { return 3; } + static float expected_nearest_dist_q1() { return 0.0f; } +}; + +// ============================================================================ +// Templated test fixture +// ============================================================================ + +template +class IvfFlatUdfTest : public ::testing::Test { + protected: + using Traits = TestDataTraits; + + void SetUp() override + { + database_ = Traits::database(); + queries_ = Traits::queries(); + num_db_vecs_ = Traits::num_db_vecs; + num_queries_ = 2; + dim_ = Traits::dim; + k_ = 4; + n_lists_ = 2; + n_probes_ = 2; + } + + raft::resources handle_; + std::vector database_; + std::vector queries_; + int64_t num_db_vecs_; + int64_t num_queries_; + int64_t dim_; + int64_t k_; + uint32_t n_lists_; + uint32_t n_probes_; +}; + +using TestTypes = ::testing::Types; +TYPED_TEST_SUITE(IvfFlatUdfTest, TestTypes); + +// ============================================================================ +// Test: UDF L2 metric matches built-in L2 and produces correct distances +// ============================================================================ + +TYPED_TEST(IvfFlatUdfTest, CustomL2MatchesBuiltIn) +{ + using T = TypeParam; + using Traits = TestDataTraits; + + auto stream = raft::resource::get_cuda_stream(this->handle_); + + // Copy data to device + rmm::device_uvector d_database(this->num_db_vecs_ * this->dim_, stream); + rmm::device_uvector d_queries(this->num_queries_ * this->dim_, stream); + raft::copy(d_database.data(), this->database_.data(), this->database_.size(), stream); + raft::copy(d_queries.data(), this->queries_.data(), this->queries_.size(), stream); + + auto database_view = raft::make_device_matrix_view( + d_database.data(), this->num_db_vecs_, this->dim_); 
+ auto queries_view = raft::make_device_matrix_view( + d_queries.data(), this->num_queries_, this->dim_); + + // Build index with L2 metric + ivf_flat::index_params index_params; + index_params.n_lists = this->n_lists_; + index_params.metric = cuvs::distance::DistanceType::L2Expanded; + + auto idx = ivf_flat::build(this->handle_, index_params, database_view); + + // Allocate output buffers + rmm::device_uvector d_indices_builtin(this->num_queries_ * this->k_, stream); + rmm::device_uvector d_distances_builtin(this->num_queries_ * this->k_, stream); + rmm::device_uvector d_indices_udf(this->num_queries_ * this->k_, stream); + rmm::device_uvector d_distances_udf(this->num_queries_ * this->k_, stream); + + auto indices_builtin_view = raft::make_device_matrix_view( + d_indices_builtin.data(), this->num_queries_, this->k_); + auto distances_builtin_view = raft::make_device_matrix_view( + d_distances_builtin.data(), this->num_queries_, this->k_); + auto indices_udf_view = raft::make_device_matrix_view( + d_indices_udf.data(), this->num_queries_, this->k_); + auto distances_udf_view = raft::make_device_matrix_view( + d_distances_udf.data(), this->num_queries_, this->k_); + + // Search with built-in metric + ivf_flat::search_params search_params_builtin; + search_params_builtin.n_probes = this->n_probes_; + + ivf_flat::search(this->handle_, + search_params_builtin, + idx, + queries_view, + indices_builtin_view, + distances_builtin_view); + + // Search with custom UDF metric + ivf_flat::search_params search_params_udf; + search_params_udf.n_probes = this->n_probes_; + search_params_udf.metric_udf = custom_l2_udf(); + + ivf_flat::search( + this->handle_, search_params_udf, idx, queries_view, indices_udf_view, distances_udf_view); + + // Copy results to host + std::vector h_indices_builtin(this->num_queries_ * this->k_); + std::vector h_distances_builtin(this->num_queries_ * this->k_); + std::vector h_indices_udf(this->num_queries_ * this->k_); + std::vector 
h_distances_udf(this->num_queries_ * this->k_); + + raft::copy( + h_indices_builtin.data(), d_indices_builtin.data(), this->num_queries_ * this->k_, stream); + raft::copy( + h_distances_builtin.data(), d_distances_builtin.data(), this->num_queries_ * this->k_, stream); + raft::copy(h_indices_udf.data(), d_indices_udf.data(), this->num_queries_ * this->k_, stream); + raft::copy(h_distances_udf.data(), d_distances_udf.data(), this->num_queries_ * this->k_, stream); + raft::resource::sync_stream(this->handle_); + + // Verify UDF results match built-in results + for (int64_t i = 0; i < this->num_queries_ * this->k_; ++i) { + EXPECT_EQ(h_indices_udf[i], h_indices_builtin[i]) << "Index mismatch at position " << i; + EXPECT_NEAR(h_distances_udf[i], h_distances_builtin[i], 1e-5f) + << "Distance mismatch at position " << i; + } + + // Verify expected distances for query[0] + EXPECT_EQ(h_indices_udf[0], Traits::expected_nearest_idx_q0()) + << "Query[0] nearest neighbor index mismatch"; + EXPECT_NEAR(h_distances_udf[0], Traits::expected_nearest_dist_q0(), 1e-5f) + << "Query[0] nearest neighbor distance mismatch"; + + // Verify expected distances for query[1] + int64_t q1_offset = this->k_; + EXPECT_EQ(h_indices_udf[q1_offset], Traits::expected_nearest_idx_q1()) + << "Query[1] nearest neighbor index mismatch"; + EXPECT_NEAR(h_distances_udf[q1_offset], Traits::expected_nearest_dist_q1(), 1e-5f) + << "Query[1] nearest neighbor distance mismatch"; +} + +} // namespace cuvs::neighbors::ivf_flat diff --git a/dependencies.yaml b/dependencies.yaml index 329265aae1..0954d3f329 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -332,6 +332,7 @@ dependencies: cuda: "13.*" packages: - libnvjitlink-dev + - cuda-nvrtc-dev - matrix: cuda: "12.*" packages: @@ -343,12 +344,12 @@ dependencies: cuda: "12.*" use_cuda_wheels: "true" packages: - - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink]==12.* + - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink,nvrtc]==12.* - matrix: 
cuda: "13.*" use_cuda_wheels: "true" packages: - - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink]==13.* + - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink,nvrtc]==13.* - matrix: use_cuda_wheels: "false" packages: @@ -356,7 +357,7 @@ dependencies: # (just as a source of documentation, as this populates pyproject.toml in source control) - matrix: packages: - - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink]>=12,<14 + - cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink,nvrtc]>=12,<14 depends_on_cupy: common: - output_types: conda @@ -371,7 +372,7 @@ dependencies: - matrix: cuda: "12.*" packages: - - cupy-cuda12x>=13.6.0 + - cupy-cuda12x>=13.6.0,<14.0 # fallback to CUDA 13 versions if 'cuda' is '13.*' or not provided - matrix: packages: diff --git a/python/libcuvs/pyproject.toml b/python/libcuvs/pyproject.toml index f43bc35dbf..c7e0a57515 100644 --- a/python/libcuvs/pyproject.toml +++ b/python/libcuvs/pyproject.toml @@ -19,7 +19,7 @@ authors = [ license = "Apache-2.0" requires-python = ">=3.11" dependencies = [ - "cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink]>=12,<14", + "cuda-toolkit[cublas,curand,cusolver,cusparse,nvjitlink,nvrtc]>=12,<14", "libraft==26.4.*,>=0.0.0a0", "librmm==26.4.*,>=0.0.0a0", ] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`.