-- Applies replicate to the size n after converting it with i64.i32.
-- Only the size argument is given here; the value to replicate is
-- supplied by the caller.
11def replicate32 n = replicate (i64 .i32 n )
22
-- Turns the raw indices `is` into bin indices: each element is reduced
-- with (% k) and the result is then converted with i64.i32.
-- NOTE(review): assumes k is nonzero; nothing here guards the modulo.
3- def bins k is = map (% k ) is |> map i64 .i32
3+ def bins k is = map (% k ) is |> map i64 .i32
44
55-- Simple cases with addition operator, which can be translated
66-- directly to atomic addition.
@@ -12,7 +12,7 @@ def bins k is = map (%k) is |> map i64.i32
1212-- random input { 10000 [1000000]i32 [1000000]i32 } auto output
1313-- random input { 100000 [1000000]i32 [1000000]i32 } auto output
1414
-- Entry point: histogram of i32 values combined with (+).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the i32 values; 0 is passed as the neutral element.
-- hist is not defined in this view; presumably the Futhark prelude
-- histogram function.
15- entry sum_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : []i32 =
15+ entry sum_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : []i32 =
1616 hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs
1717
1818-- An f32 requires a little more work from the compiler.
@@ -24,7 +24,7 @@ entry sum_i32 [n] (k: i32) (is : [n]i32) (vs : [n]i32) : []i32 =
2424-- random input { 10000 [1000000]i32 [1000000]f32 } auto output
2525-- random input { 100000 [1000000]i32 [1000000]f32 } auto output
2626
-- Entry point: histogram of f32 values combined with (+).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the f32 values; 0 is passed as the neutral element.
-- hist is not defined in this view; presumably the Futhark prelude
-- histogram function.
27- entry sum_f32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]f32 ) : []f32 =
27+ entry sum_f32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]f32 ) : []f32 =
2828 hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs
2929
3030-- Do both!
@@ -36,9 +36,10 @@ entry sum_f32 [n] (k: i32) (is : [n]i32) (vs : [n]f32) : []f32 =
3636-- random input { 10000 [1000000]i32 [1000000]i32 [1000000]f32 } auto output
3737-- random input { 100000 [1000000]i32 [1000000]i32 [1000000]f32 } auto output
3838
-- Entry point: two additive histograms over the same indices, returned
-- as a pair.
--   k   : number of bins (converted with i64.i32 for hist).
--   is  : per-element indices; `bins k is` is evaluated for each hist.
--   vs1 : i32 values for the first component of the result.
--   vs2 : f32 values for the second component of the result.
-- Both hist calls use (+) with 0 as the neutral element.
39- entry sum_i32_f32 [n ] (k : i32 ) (is : [n ]i32 ) (vs1 : [n ]i32 ) (vs2 : [n ]f32 ) : ([]i32 , []f32 ) =
40- (hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs1 ,
41- hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs2 )
39+ entry sum_i32_f32 [n ] (k : i32 ) (is : [n ]i32 ) (vs1 : [n ]i32 ) (vs2 : [n ]f32 ) : ([]i32 , []f32 ) =
40+ ( hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs1
41+ , hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs2
42+ )
4243
4344-- Now a fancier operator, but because the payload is an i32, an
4445-- efficient implementation is possible.
@@ -50,10 +51,10 @@ entry sum_i32_f32 [n] (k: i32) (is : [n]i32) (vs1 : [n]i32) (vs2 : [n]f32) : ([]
5051-- random input { 10000 [1000000]i32 [1000000]i32 } auto output
5152-- random input { 100000 [1000000]i32 [1000000]i32 } auto output
5253
-- Returns whichever argument has the larger absolute value (per
-- i32.abs), with its sign intact. When the absolute values are equal
-- the first argument x is returned.
53- def absmax (x : i32 ) (y : i32 ): i32 =
54+ def absmax (x : i32 ) (y : i32 ) : i32 =
5455 if i32 .abs x < i32 .abs y then y else x
5556
-- Entry point: histogram of i32 values combined with absmax instead of
-- addition.
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the i32 values; 0 is passed as the neutral element.
56- entry absmax_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : []i32 =
57+ entry absmax_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : []i32 =
5758 hist absmax 0 (i64 .i32 k ) (bins k is ) vs
5859
5960-- Now a vectorised operator. If the compiler is clever, it can
@@ -65,9 +66,9 @@ entry absmax_i32 [n] (k: i32) (is : [n]i32) (vs : [n]i32) : []i32 =
6566-- random input { 10000 [10000]i32 [1000000]i32 } auto output
6667-- random input { 10000 [1000]i32 [1000000]i32 } auto output
6768
-- Entry point: histogram whose elements are vectors, combined
-- elementwise with map2 (+).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : m indices, mapped to bins by `bins k is`.
--   vs : n flat i32 values, regrouped below before being passed to hist.
-- l = n / m is the length of the neutral element `replicate l 0`.
-- vs is passed through `sized (m * l)` and then `unflatten`.
-- NOTE(review): presumably this yields m rows of length l and requires
-- n to equal m * l; sized/unflatten are not defined here - TODO confirm.
68- entry sum_vec_i32 [n ][m ] (k : i32 ) (is : [m ]i32 ) (vs : [n ]i32 ) : [][]i32 =
69- let l = n / m
70- let vs' = unflatten (sized (m * l ) vs )
69+ entry sum_vec_i32 [n ] [m ] (k : i32 ) (is : [m ]i32 ) (vs : [n ]i32 ) : [][]i32 =
70+ let l = n / m
71+ let vs' = unflatten (sized (m * l ) vs )
7172 in hist (map2 (+ )) (replicate l 0 ) (i64 .i32 k ) (bins k is ) vs'
7273
7374-- An operator that the compiler really cannot do anything clever
@@ -81,12 +82,42 @@ entry sum_vec_i32 [n][m] (k: i32) (is : [m]i32) (vs : [n]i32) : [][]i32 =
8182-- random input { 100000 [1000000]i32 [1000000]i32 } auto output
8283
-- Combines two (value, index) pairs into one.
-- The pair with the strictly larger value wins. When the values are
-- equal, the pair with the smaller index wins; if the indices are also
-- equal, the first pair (x, i) is returned.
8384def argmax_op ((x : i32 ), (i : i32 )) ((y : i32 ), (j : i32 )) =
84- if y > x then (y , j )
85- else if y < x then (x , i )
86- else if i > j then (y , j )
85+ if y > x
86+ then (y , j )
87+ else if y < x
88+ then (x , i )
89+ else if i > j
90+ then (y , j )
8791 else (x , i )
8892
-- Entry point: per-bin argmax, returned as two arrays (values, indices).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the i32 values; each is zipped with its position taken from
--        `iota n` (converted with i32.i64) so the pair carries its index.
-- The neutral element is (i32.lowest, -1). The histogram of pairs is
-- split into the two result arrays by unzip.
89- entry argmax_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : ([]i32 , []i32 ) =
90- hist argmax_op (i32 .lowest , -1 ) (i64 .i32 k )
91- (bins k is ) (zip vs (map i32 .i64 (iota n )))
93+ entry argmax_i32 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i32 ) : ([]i32 , []i32 ) =
94+ hist argmax_op
95+ (i32 .lowest , -1 )
96+ (i64 .i32 k )
97+ (bins k is )
98+ (zip vs (map i32 .i64 (iota n )))
9299 |> unzip
100+
101+ -- f16 addition is not directly supported.
102+ -- ==
103+ -- entry: sum_f16
104+ -- random input { 10 [1000000]i32 [1000000]f16 } auto output
105+ -- random input { 100 [1000000]i32 [1000000]f16 } auto output
106+ -- random input { 1000 [1000000]i32 [1000000]f16 } auto output
107+ -- random input { 10000 [1000000]i32 [1000000]f16 } auto output
108+ -- random input { 100000 [1000000]i32 [1000000]f16 } auto output
109+
-- Entry point: histogram of f16 values combined with (+).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the f16 values; 0 is passed as the neutral element.
110+ entry sum_f16 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]f16 ) : []f16 =
111+ hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs
112+
113+ -- i8 addition is not directly supported.
114+ -- ==
115+ -- entry: sum_i8
116+ -- random input { 10 [1000000]i32 [1000000]i8 } auto output
117+ -- random input { 100 [1000000]i32 [1000000]i8 } auto output
118+ -- random input { 1000 [1000000]i32 [1000000]i8 } auto output
119+ -- random input { 10000 [1000000]i32 [1000000]i8 } auto output
120+ -- random input { 100000 [1000000]i32 [1000000]i8 } auto output
121+
-- Entry point: histogram of i8 values combined with (+).
--   k  : number of bins (converted with i64.i32 for hist).
--   is : per-element indices, mapped to bins by `bins k is`.
--   vs : the i8 values; 0 is passed as the neutral element.
-- NOTE(review): i8 sums over many elements presumably wrap around;
-- confirm this is intended for the benchmark.
122+ entry sum_i8 [n ] (k : i32 ) (is : [n ]i32 ) (vs : [n ]i8 ) : []i8 =
123+ hist (+ ) 0 (i64 .i32 k ) (bins k is ) vs
0 commit comments