-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathannoy.r
More file actions
executable file
·129 lines (126 loc) · 4.49 KB
/
annoy.r
File metadata and controls
executable file
·129 lines (126 loc) · 4.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#-------------------------------------------------------------------------------
#' Annoy: approximate nearest neighbours
#'
#' Runs annoy: builds an index from `data` (unless a ready-made `index` is
#' supplied) and searches it for the `k` nearest neighbours of each row of
#' `query`.
#'
#' This code is directly lifted from Seurat's annoy implementation here:
#' https://github.com/satijalab/seurat/blob/master/R/clustering.R
#'
#' More details on knn implementations in this paper:
#' https://www.ncbi.nlm.nih.gov/pmc/articles/PMC11014608/
#'
#' The reason for using this over exact NN is the speed,
#' but other flavours of approximate NN are available.
#'
#' @param data Data to build the index with. It is also read for the `ndim`
#' entry of the result, so it is needed even when `index` is supplied
#' @param query A set of data to be queried against data
#' @param metric Distance metric; can be one of "euclidean", "cosine",
#' "manhattan", "hamming"
#' @param n.trees More trees gives higher precision when querying
#' @param k Number of neighbors
#' @param search.k During the query it will inspect up to search_k nodes which
#' gives you a run-time tradeoff between better accuracy and speed
#' @param include.distance Include the corresponding distances
#' @param index optional index object, will be recomputed if not provided
#'
#' @return The list returned by `AnnoySearch()` with two elements added:
#' `idx`, the index that was searched, and `alg.info`, a list holding `metric`
#' and `ndim` (the number of columns of `data`).
#' @keywords internal
AnnoyNN <- function(data,
                    query = data,
                    metric = "euclidean",
                    n.trees = 50,
                    k,
                    search.k = -1,
                    include.distance = TRUE,
                    index = NULL) {
  # Reuse a caller-supplied index; only build a fresh one when none was given.
  annoy_index <- if (is.null(index)) {
    AnnoyBuildIndex(data = data, metric = metric, n.trees = n.trees)
  } else {
    index
  }

  result <- AnnoySearch(
    index = annoy_index,
    query = query,
    k = k,
    search.k = search.k,
    include.distance = include.distance
  )

  # Hand back the index and the settings alongside the search result so the
  # caller can reuse the index for later queries.
  result$idx <- annoy_index
  result$alg.info <- list(metric = metric, ndim = ncol(x = data))
  result
}
#-------------------------------------------------------------------------------
#' Build the annoy index
#'
#' Creates an RcppAnnoy index for the requested metric, adds every row of
#' `data` as one item and then builds the index with `n.trees` trees.
#'
#' This code is directly lifted from Seurat's annoy implementation here:
#' https://github.com/satijalab/seurat/blob/master/R/clustering.R
#'
#' @param data Data to build the index with; each row is added as one item and
#' the number of columns sets the dimensionality of the index
#' @param metric Distance metric; can be one of "euclidean", "cosine",
#' "manhattan", "hamming". Any other value raises an "Invalid metric" error
#' @param n.trees More trees gives higher precision when querying
#'
#' @return The built index object
#'
#' @importFrom RcppAnnoy AnnoyEuclidean AnnoyAngular AnnoyManhattan AnnoyHamming
#' @keywords internal
AnnoyBuildIndex <- function(data, metric = "euclidean", n.trees = 50) {
  n_dims <- ncol(x = data)
  # switch() evaluates only the selected arm, so exactly one index is created;
  # an unrecognised metric falls through to the unnamed stop() arm.
  annoy_index <- switch(
    EXPR = metric,
    "euclidean" = methods::new(Class = RcppAnnoy::AnnoyEuclidean, n_dims),
    "cosine" = methods::new(Class = RcppAnnoy::AnnoyAngular, n_dims),
    "manhattan" = methods::new(Class = RcppAnnoy::AnnoyManhattan, n_dims),
    "hamming" = methods::new(Class = RcppAnnoy::AnnoyHamming, n_dims),
    stop("Invalid metric")
  )
  # Rows are stored under zero-based item ids (`ii - 1`); AnnoySearch adds the
  # 1 back when reporting neighbours. seq_len() rather than seq() so that a
  # zero-row `data` skips the loop instead of iterating over c(1, 0).
  for (ii in seq_len(nrow(x = data))) {
    annoy_index$addItem(ii - 1, data[ii, ])
  }
  annoy_index$build(n.trees)
  annoy_index
}
#-------------------------------------------------------------------------------
#' Search an Annoy approximate nearest neighbor index
#'
#' This code is directly lifted from Seurat's annoy implementation here:
#' https://github.com/satijalab/seurat/blob/master/R/clustering.R
#'
#' @param index Annoy index, built with AnnoyBuildIndex
#' @param query A set of data to be queried against the index; one search is
#' run per row
#' @param k Number of neighbors
#' @param search.k During the query it will inspect up to search_k nodes which
#' gives you a run-time tradeoff between better accuracy and speed
#' @param include.distance Include the corresponding distances in the result
#'
#' @return A list with 'nn.idx' (for each row of 'query', the 1-based index of
#' the nearest k elements in the index) and 'nn.dists' (the distances of the
#' nearest k elements). Both are matrices with one row per query row and k
#' columns. 'nn.dists' is all NA when 'include.distance' is FALSE. When the
#' index returns fewer than k neighbours for a row, the remaining entries of
#' that row are NA. For an angular ("cosine") index the distances are
#' converted to cosine distance.
#'
#' @importFrom future plan
#' @importFrom future.apply future_lapply
#' @keywords internal
AnnoySearch <- function(index, query, k, search.k = -1, include.distance = TRUE) {
  n <- nrow(x = query)
  idx <- matrix(nrow = n, ncol = k)
  dist <- matrix(nrow = n, ncol = k)
  # Angular indices need their distances converted to cosine distance below.
  convert <- methods::is(index, "Rcpp_AnnoyAngular")
  # Unless a multicore plan is active, switch to a sequential plan for the
  # duration of this call and restore the caller's plan on exit.
  # NOTE(review): presumably because the index object cannot be shipped to
  # separate worker processes -- confirm.
  if (!inherits(x = plan(), what = "multicore")) {
    oplan <- plan(strategy = "sequential")
    on.exit(plan(oplan), add = TRUE)
  }
  # seq_len() rather than 1:n so an empty query produces empty results instead
  # of iterating over c(1, 0).
  res <- future_lapply(X = seq_len(n), FUN = function(x) {
    res <- index$getNNsByVectorList(query[x, ], k, search.k, include.distance)
    # Convert from Angular to Cosine distance.
    # The full element name must be assigned: `res$dist <- ...` reads
    # `distance` through partial matching but writes a *new* `dist` element,
    # leaving the returned `res$distance` unconverted.
    if (convert) {
      res$distance <- 0.5 * (res$distance * res$distance)
    }
    # Item ids come back zero-based; shift them to R's 1-based indexing.
    list(res$item + 1, res$distance)
  })
  for (i in seq_len(n)) {
    # Fill only as many columns as neighbours were returned. A short result
    # leaves trailing NA instead of being recycled across the row.
    items <- res[[i]][[1]]
    idx[i, seq_along(items)] <- items
    if (include.distance) {
      dists <- res[[i]][[2]]
      dist[i, seq_along(dists)] <- dists
    }
  }
  list(nn.idx = idx, nn.dists = dist)
}