Skip to content

Commit 094c9fa

Browse files
committed
auto-vectorize using clang instead
1 parent 9a112b2 commit 094c9fa

File tree

10 files changed

+326
-403
lines changed

10 files changed

+326
-403
lines changed

bruteforce_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313

1414
/*
1515
cpu: 13th Gen Intel(R) Core(TM) i7-13700K
16-
BenchmarkIndex/search-24 3030 383868 ns/op 265 B/op 2 allocs/op
16+
BenchmarkIndex/search-24 3807 316587 ns/op 264 B/op 2 allocs/op
1717
*/
1818
func BenchmarkIndex(b *testing.B) {
1919
data, err := loadDataset()

internal/cosine/cosine_apple.c

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,27 @@
33
// Licensed under the MIT license. See LICENSE file in the project root for details.
44

55
#include <stdint.h>
6-
#include <arm_neon.h>
76
#include <math.h>
87

98
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
10-
float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y
11-
float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x
12-
float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y
13-
14-
uint64_t i;
15-
for (i = 0; i + 3 < size; i += 4) {
16-
float32x4_t x_vec = vld1q_f32(x + i);
17-
float32x4_t y_vec = vld1q_f32(y + i);
18-
19-
sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec);
20-
sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec);
21-
sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec);
22-
}
23-
24-
// Sum the elements of the vectors
25-
float dot_xy = vaddvq_f32(sum_xy);
26-
float norm_x = vaddvq_f32(sum_xx);
27-
float norm_y = vaddvq_f32(sum_yy);
28-
29-
// Handle any remaining elements
30-
for (; i < size; i++) {
31-
dot_xy += x[i] * y[i];
32-
norm_x += x[i] * x[i];
33-
norm_y += y[i] * y[i];
9+
float sum_xy = 0.0f;
10+
float sum_xx = 0.0f;
11+
float sum_yy = 0.0f;
12+
13+
#pragma clang loop vectorize(enable) interleave_count(4)
14+
for (uint64_t i = 0; i < size; i++) {
15+
sum_xy += x[i] * y[i]; // Sum of x * y
16+
sum_xx += x[i] * x[i]; // Sum of x * x
17+
sum_yy += y[i] * y[i]; // Sum of y * y
3418
}
3519

36-
37-
// Avoid division by zero
38-
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
20+
// Calculate the final result
21+
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
3922
if (denominator == 0.0f) {
4023
*result = (double)0.0f;
4124
return;
4225
}
4326

44-
double cosine_similarity = (double)dot_xy / (double)denominator;
27+
double cosine_similarity = (double)sum_xy / (double)denominator;
4528
*result = cosine_similarity;
4629
}

internal/cosine/cosine_avx.c

Lines changed: 12 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2,58 +2,28 @@
22
// Copyright (c) Roman Atachiants and contributors. All rights reserved.
33
// Licensed under the MIT license. See LICENSE file in the project root for details.
44

5-
65
#include <stdint.h>
7-
#include <immintrin.h>
86
#include <math.h>
97

108
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
11-
__m256 sum_xy = _mm256_setzero_ps(); // Sum of x * y
12-
__m256 sum_xx = _mm256_setzero_ps(); // Sum of x * x
13-
__m256 sum_yy = _mm256_setzero_ps(); // Sum of y * y
14-
15-
uint64_t i;
16-
for (i = 0; i <= size - 8; i += 8) {
17-
__m256 x_vec = _mm256_loadu_ps(x + i);
18-
__m256 y_vec = _mm256_loadu_ps(y + i);
19-
20-
sum_xy = _mm256_fmadd_ps(x_vec, y_vec, sum_xy); // sum_xy += x_vec * y_vec
21-
sum_xx = _mm256_fmadd_ps(x_vec, x_vec, sum_xx); // sum_xx += x_vec * x_vec
22-
sum_yy = _mm256_fmadd_ps(y_vec, y_vec, sum_yy); // sum_yy += y_vec * y_vec
23-
}
24-
25-
// Sum elements of sum_xy
26-
__m256 temp_xy = _mm256_hadd_ps(sum_xy, sum_xy); // Sum adjacent pairs
27-
temp_xy = _mm256_hadd_ps(temp_xy, temp_xy); // Sum adjacent quadruples
28-
__m128 sum_xy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xy), _mm256_extractf128_ps(temp_xy, 1));
29-
float dot_xy = _mm_cvtss_f32(sum_xy_128); // Extract final sum
30-
31-
// Sum elements of sum_xx
32-
__m256 temp_xx = _mm256_hadd_ps(sum_xx, sum_xx);
33-
temp_xx = _mm256_hadd_ps(temp_xx, temp_xx);
34-
__m128 sum_xx_128 = _mm_add_ps(_mm256_castps256_ps128(temp_xx), _mm256_extractf128_ps(temp_xx, 1));
35-
float norm_x = _mm_cvtss_f32(sum_xx_128);
36-
37-
// Sum elements of sum_yy
38-
__m256 temp_yy = _mm256_hadd_ps(sum_yy, sum_yy);
39-
temp_yy = _mm256_hadd_ps(temp_yy, temp_yy);
40-
__m128 sum_yy_128 = _mm_add_ps(_mm256_castps256_ps128(temp_yy), _mm256_extractf128_ps(temp_yy, 1));
41-
float norm_y = _mm_cvtss_f32(sum_yy_128);
42-
43-
// Handle remaining elements (if any)
44-
for (; i < size; i++) {
45-
dot_xy += x[i] * y[i];
46-
norm_x += x[i] * x[i];
47-
norm_y += y[i] * y[i];
9+
float sum_xy = 0.0f;
10+
float sum_xx = 0.0f;
11+
float sum_yy = 0.0f;
12+
13+
#pragma clang loop vectorize(enable) interleave_count(4)
14+
for (uint64_t i = 0; i < size; i++) {
15+
sum_xy += x[i] * y[i]; // Sum of x * y
16+
sum_xx += x[i] * x[i]; // Sum of x * x
17+
sum_yy += y[i] * y[i]; // Sum of y * y
4818
}
4919

50-
// Avoid division by zero
51-
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
20+
// Calculate the final result
21+
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
5222
if (denominator == 0.0f) {
5323
*result = (double)0.0f;
5424
return;
5525
}
5626

57-
double cosine_similarity = (double)dot_xy / (double)denominator;
27+
double cosine_similarity = (double)sum_xy / (double)denominator;
5828
*result = cosine_similarity;
5929
}

internal/cosine/cosine_neon.c

Lines changed: 12 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,44 +3,27 @@
33
// Licensed under the MIT license. See LICENSE file in the project root for details.
44

55
#include <stdint.h>
6-
#include <arm_neon.h>
76
#include <math.h>
87

98
void f32_cosine_distance(const float *x, const float *y, double *result, const uint64_t size) {
10-
float32x4_t sum_xy = vdupq_n_f32(0.0f); // Sum of x * y
11-
float32x4_t sum_xx = vdupq_n_f32(0.0f); // Sum of x * x
12-
float32x4_t sum_yy = vdupq_n_f32(0.0f); // Sum of y * y
13-
14-
uint64_t i;
15-
for (i = 0; i + 3 < size; i += 4) {
16-
float32x4_t x_vec = vld1q_f32(x + i);
17-
float32x4_t y_vec = vld1q_f32(y + i);
18-
19-
sum_xy = vmlaq_f32(sum_xy, x_vec, y_vec);
20-
sum_xx = vmlaq_f32(sum_xx, x_vec, x_vec);
21-
sum_yy = vmlaq_f32(sum_yy, y_vec, y_vec);
22-
}
23-
24-
// Sum the elements of the vectors
25-
float dot_xy = vaddvq_f32(sum_xy);
26-
float norm_x = vaddvq_f32(sum_xx);
27-
float norm_y = vaddvq_f32(sum_yy);
28-
29-
// Handle any remaining elements
30-
for (; i < size; i++) {
31-
dot_xy += x[i] * y[i];
32-
norm_x += x[i] * x[i];
33-
norm_y += y[i] * y[i];
9+
float sum_xy = 0.0f;
10+
float sum_xx = 0.0f;
11+
float sum_yy = 0.0f;
12+
13+
#pragma clang loop vectorize(enable) interleave_count(4)
14+
for (uint64_t i = 0; i < size; i++) {
15+
sum_xy += x[i] * y[i]; // Sum of x * y
16+
sum_xx += x[i] * x[i]; // Sum of x * x
17+
sum_yy += y[i] * y[i]; // Sum of y * y
3418
}
3519

36-
37-
// Avoid division by zero
38-
float denominator = sqrtf(norm_x) * sqrtf(norm_y);
20+
// Calculate the final result
21+
float denominator = sqrtf(sum_xx) * sqrtf(sum_yy);
3922
if (denominator == 0.0f) {
4023
*result = (double)0.0f;
4124
return;
4225
}
4326

44-
double cosine_similarity = (double)dot_xy / (double)denominator;
27+
double cosine_similarity = (double)sum_xy / (double)denominator;
4528
*result = cosine_similarity;
4629
}

internal/cosine/simd/cosine_apple.s

Lines changed: 82 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -8,120 +8,109 @@ TEXT ·f32_cosine_distance(SB), $0-32
88
MOVD size+24(FP), R3
99
WORD $0xa9bf7bfd // stp x29, x30, [sp, #-16]! ; 16-byte Folded Spill
1010
WORD $0x910003fd // mov x29, sp
11-
WORD $0xf100107f // cmp x3, #4
12-
WORD $0x54000223 // b.lo LBB0_4
13-
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
14-
WORD $0x52800068 // mov w8, #3
15-
WORD $0xaa0103e9 // mov x9, x1
16-
WORD $0xaa0003ea // mov x10, x0
17-
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
18-
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
11+
WORD $0x2f00e400 // movi d0, #0000000000000000
12+
WORD $0xb4000103 // cbz x3, LBB0_3
13+
WORD $0xf100407f // cmp x3, #16
14+
WORD $0x54000182 // b.hs LBB0_4
15+
WORD $0xd2800008 // mov x8, #0
16+
WORD $0x2f00e401 // movi d1, #0000000000000000
17+
WORD $0x2f00e405 // movi d5, #0000000000000000
18+
WORD $0x2f00e411 // movi d17, #0000000000000000
19+
WORD $0x1400003a // b LBB0_7
1920

20-
BB0_2:
21-
WORD $0x3cc10543 // ldr q3, [x10], #16
22-
WORD $0x3cc10524 // ldr q4, [x9], #16
23-
WORD $0x4e23cc80 // fmla.4s v0, v4, v3
24-
WORD $0x4e23cc61 // fmla.4s v1, v3, v3
25-
WORD $0x4e24cc82 // fmla.4s v2, v4, v4
26-
WORD $0x91001108 // add x8, x8, #4
27-
WORD $0xeb03011f // cmp x8, x3
28-
WORD $0x54ffff23 // b.lo LBB0_2
29-
WORD $0x927ef46b // and x11, x3, #0xfffffffffffffffc
30-
WORD $0x14000005 // b LBB0_5
21+
BB0_3:
22+
WORD $0x2f00e402 // movi d2, #0000000000000000
23+
WORD $0x2f00e401 // movi d1, #0000000000000000
24+
WORD $0x1e21c042 // fsqrt s2, s2
25+
WORD $0x1e202048 // fcmp s2, #0.0
26+
WORD $0x540008a0 // b.eq LBB0_10
27+
WORD $0x14000047 // b LBB0_11
3128

3229
BB0_4:
33-
WORD $0xd280000b // mov x11, #0
34-
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
3530
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
36-
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
37-
38-
BB0_5:
39-
WORD $0x6e20d400 // faddp.4s v0, v0, v0
40-
WORD $0x7e30d800 // faddp.2s s0, v0
41-
WORD $0x6e21d421 // faddp.4s v1, v1, v1
42-
WORD $0x7e30d821 // faddp.2s s1, v1
43-
WORD $0x6e22d442 // faddp.4s v2, v2, v2
44-
WORD $0x7e30d842 // faddp.2s s2, v2
45-
WORD $0xeb03017f // cmp x11, x3
46-
WORD $0x54000702 // b.hs LBB0_13
47-
WORD $0xcb0b0069 // sub x9, x3, x11
48-
WORD $0xf100213f // cmp x9, #8
49-
WORD $0x54000062 // b.hs LBB0_8
50-
WORD $0xaa0b03e8 // mov x8, x11
51-
WORD $0x14000028 // b LBB0_11
52-
53-
BB0_8:
54-
WORD $0x927df12a // and x10, x9, #0xfffffffffffffff8
55-
WORD $0x8b0a0168 // add x8, x11, x10
31+
WORD $0x927cec68 // and x8, x3, #0xfffffffffffffff0
32+
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
33+
WORD $0x91008009 // add x9, x0, #32
5634
WORD $0x6f00e403 // movi.2d v3, #0000000000000000
35+
WORD $0x9100802a // add x10, x1, #32
5736
WORD $0x6f00e404 // movi.2d v4, #0000000000000000
58-
WORD $0x6e040444 // mov.s v4[0], v2[0]
59-
WORD $0x6f00e402 // movi.2d v2, #0000000000000000
60-
WORD $0x6e040422 // mov.s v2[0], v1[0]
61-
WORD $0x6f00e401 // movi.2d v1, #0000000000000000
62-
WORD $0x6e040401 // mov.s v1[0], v0[0]
63-
WORD $0xd37ef56b // lsl x11, x11, #2
64-
WORD $0x9100416c // add x12, x11, #16
65-
WORD $0x8b0c000b // add x11, x0, x12
66-
WORD $0x8b0c002c // add x12, x1, x12
67-
WORD $0xaa0a03ed // mov x13, x10
37+
WORD $0xaa0803eb // mov x11, x8
6838
WORD $0x6f00e405 // movi.2d v5, #0000000000000000
69-
WORD $0x6f00e400 // movi.2d v0, #0000000000000000
39+
WORD $0x6f00e406 // movi.2d v6, #0000000000000000
40+
WORD $0x6f00e407 // movi.2d v7, #0000000000000000
41+
WORD $0x6f00e410 // movi.2d v16, #0000000000000000
42+
WORD $0x6f00e411 // movi.2d v17, #0000000000000000
43+
WORD $0x6f00e412 // movi.2d v18, #0000000000000000
44+
WORD $0x6f00e413 // movi.2d v19, #0000000000000000
45+
WORD $0x6f00e414 // movi.2d v20, #0000000000000000
7046

71-
BB0_9:
72-
WORD $0xad7f9d66 // ldp q6, q7, [x11, #-16]
73-
WORD $0xad7fc590 // ldp q16, q17, [x12, #-16]
74-
WORD $0x4e26ce01 // fmla.4s v1, v16, v6
75-
WORD $0x4e27ce20 // fmla.4s v0, v17, v7
76-
WORD $0x4e26ccc2 // fmla.4s v2, v6, v6
77-
WORD $0x4e27cce5 // fmla.4s v5, v7, v7
78-
WORD $0x4e30ce04 // fmla.4s v4, v16, v16
79-
WORD $0x4e31ce23 // fmla.4s v3, v17, v17
80-
WORD $0x9100816b // add x11, x11, #32
81-
WORD $0x9100818c // add x12, x12, #32
82-
WORD $0xf10021ad // subs x13, x13, #8
83-
WORD $0x54fffea1 // b.ne LBB0_9
84-
WORD $0x4e21d400 // fadd.4s v0, v0, v1
85-
WORD $0x6e20d400 // faddp.4s v0, v0, v0
86-
WORD $0x7e30d800 // faddp.2s s0, v0
87-
WORD $0x4e22d4a1 // fadd.4s v1, v5, v2
47+
BB0_5:
48+
WORD $0xad7f5935 // ldp q21, q22, [x9, #-32]
49+
WORD $0xacc26137 // ldp q23, q24, [x9], #64
50+
WORD $0xad7f6959 // ldp q25, q26, [x10, #-32]
51+
WORD $0xacc2715b // ldp q27, q28, [x10], #64
52+
WORD $0x4e35cf21 // fmla.4s v1, v25, v21
53+
WORD $0x4e36cf42 // fmla.4s v2, v26, v22
54+
WORD $0x4e37cf63 // fmla.4s v3, v27, v23
55+
WORD $0x4e38cf84 // fmla.4s v4, v28, v24
56+
WORD $0x4e35cea5 // fmla.4s v5, v21, v21
57+
WORD $0x4e36cec6 // fmla.4s v6, v22, v22
58+
WORD $0x4e37cee7 // fmla.4s v7, v23, v23
59+
WORD $0x4e38cf10 // fmla.4s v16, v24, v24
60+
WORD $0x4e39cf31 // fmla.4s v17, v25, v25
61+
WORD $0x4e3acf52 // fmla.4s v18, v26, v26
62+
WORD $0x4e3bcf73 // fmla.4s v19, v27, v27
63+
WORD $0x4e3ccf94 // fmla.4s v20, v28, v28
64+
WORD $0xf100416b // subs x11, x11, #16
65+
WORD $0x54fffde1 // b.ne LBB0_5
66+
WORD $0x4e31d651 // fadd.4s v17, v18, v17
67+
WORD $0x4e31d671 // fadd.4s v17, v19, v17
68+
WORD $0x4e31d691 // fadd.4s v17, v20, v17
69+
WORD $0x6e31d631 // faddp.4s v17, v17, v17
70+
WORD $0x7e30da31 // faddp.2s s17, v17
71+
WORD $0x4e25d4c5 // fadd.4s v5, v6, v5
72+
WORD $0x4e25d4e5 // fadd.4s v5, v7, v5
73+
WORD $0x4e25d605 // fadd.4s v5, v16, v5
74+
WORD $0x6e25d4a5 // faddp.4s v5, v5, v5
75+
WORD $0x7e30d8a5 // faddp.2s s5, v5
76+
WORD $0x4e21d441 // fadd.4s v1, v2, v1
77+
WORD $0x4e21d461 // fadd.4s v1, v3, v1
78+
WORD $0x4e21d481 // fadd.4s v1, v4, v1
8879
WORD $0x6e21d421 // faddp.4s v1, v1, v1
8980
WORD $0x7e30d821 // faddp.2s s1, v1
90-
WORD $0x4e24d462 // fadd.4s v2, v3, v4
91-
WORD $0x6e22d442 // faddp.4s v2, v2, v2
92-
WORD $0x7e30d842 // faddp.2s s2, v2
93-
WORD $0xeb0a013f // cmp x9, x10
94-
WORD $0x54000180 // b.eq LBB0_13
81+
WORD $0xeb03011f // cmp x8, x3
82+
WORD $0x54000180 // b.eq LBB0_9
9583

96-
BB0_11:
84+
BB0_7:
9785
WORD $0xcb080069 // sub x9, x3, x8
9886
WORD $0xd37ef50a // lsl x10, x8, #2
9987
WORD $0x8b0a0028 // add x8, x1, x10
10088
WORD $0x8b0a000a // add x10, x0, x10
10189

102-
BB0_12:
103-
WORD $0xbc404543 // ldr s3, [x10], #4
104-
WORD $0xbc404504 // ldr s4, [x8], #4
105-
WORD $0x1f030080 // fmadd s0, s4, s3, s0
106-
WORD $0x1f030461 // fmadd s1, s3, s3, s1
107-
WORD $0x1f040882 // fmadd s2, s4, s4, s2
90+
BB0_8:
91+
WORD $0xbc404542 // ldr s2, [x10], #4
92+
WORD $0xbc404503 // ldr s3, [x8], #4
93+
WORD $0x1f020461 // fmadd s1, s3, s2, s1
94+
WORD $0x1f021445 // fmadd s5, s2, s2, s5
95+
WORD $0x1f034471 // fmadd s17, s3, s3, s17
10896
WORD $0xf1000529 // subs x9, x9, #1
109-
WORD $0x54ffff41 // b.ne LBB0_12
97+
WORD $0x54ffff41 // b.ne LBB0_8
11098

111-
BB0_13:
112-
WORD $0x1e210841 // fmul s1, s2, s1
113-
WORD $0x1e21c022 // fsqrt s2, s1
114-
WORD $0x2f00e401 // movi d1, #0000000000000000
99+
BB0_9:
100+
WORD $0x1e3108a2 // fmul s2, s5, s17
101+
WORD $0x1e22c021 // fcvt d1, s1
102+
WORD $0x1e21c042 // fsqrt s2, s2
115103
WORD $0x1e202048 // fcmp s2, #0.0
116-
WORD $0x54000081 // b.ne LBB0_15
117-
WORD $0xfd000041 // str d1, [x2]
104+
WORD $0x54000081 // b.ne LBB0_11
105+
106+
BB0_10:
107+
WORD $0xfd000040 // str d0, [x2]
118108
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
119109
WORD $0xd65f03c0 // ret
120110

121-
BB0_15:
122-
WORD $0x1e22c000 // fcvt d0, s0
123-
WORD $0x1e22c041 // fcvt d1, s2
124-
WORD $0x1e611801 // fdiv d1, d0, d1
125-
WORD $0xfd000041 // str d1, [x2]
111+
BB0_11:
112+
WORD $0x1e22c040 // fcvt d0, s2
113+
WORD $0x1e601820 // fdiv d0, d1, d0
114+
WORD $0xfd000040 // str d0, [x2]
126115
WORD $0xa8c17bfd // ldp x29, x30, [sp], #16 ; 16-byte Folded Reload
127116
WORD $0xd65f03c0 // ret

0 commit comments

Comments
 (0)