Skip to content

Commit 3a6de23

Browse files
committed
feat: add generic filter iterator and use for fitlering
1 parent da4a194 commit 3a6de23

File tree

5 files changed

+162
-76
lines changed

5 files changed

+162
-76
lines changed

routing/http/server/filters.go

Lines changed: 64 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
package server
22

33
import (
4+
"errors"
45
"reflect"
56
"slices"
67
"strings"
78

89
"github.com/ipfs/boxo/routing/http/types"
10+
"github.com/ipfs/boxo/routing/http/types/iter"
911
"github.com/multiformats/go-multiaddr"
1012
)
1113

@@ -18,18 +20,68 @@ func parseFilter(param string) []string {
1820
return strings.Split(strings.ToLower(param), ",")
1921
}
2022

21-
func filterProviders(providers []types.Record, filterAddrs, filterProtocols []string) []types.Record {
23+
// applyFiltersToIter applies the filters to the given iterator and returns a new iterator.
24+
func applyFiltersToIter(recordsIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) iter.ResultIter[types.Record] {
25+
mappedIter := iter.Map(recordsIter, func(v iter.Result[types.Record]) iter.Result[types.Record] {
26+
if v.Err != nil || v.Val == nil {
27+
return v
28+
}
29+
30+
switch v.Val.GetSchema() {
31+
case types.SchemaPeer:
32+
record, ok := v.Val.(*types.PeerRecord)
33+
if !ok {
34+
logger.Errorw("problem casting find providers record", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String())
35+
// TODO: Do we want to let failed type assertions to pass through?
36+
return v
37+
}
38+
39+
record = applyFilters(record, filterAddrs, filterProtocols)
40+
if record == nil {
41+
return iter.Result[types.Record]{Err: errors.New("record is nil")}
42+
}
43+
v.Val = record
44+
45+
//lint:ignore SA1019 // ignore staticcheck
46+
case types.SchemaBitswap:
47+
//lint:ignore SA1019 // ignore staticcheck
48+
record, ok := v.Val.(*types.BitswapRecord)
49+
if !ok {
50+
logger.Errorw("problem casting find providers record", "Schema", v.Val.GetSchema(), "Type", reflect.TypeOf(v).String())
51+
// TODO: Do we want to let failed type assertions to pass through?
52+
return v
53+
}
54+
peerRecord := types.FromBitswapRecord(record)
55+
peerRecord = applyFilters(peerRecord, filterAddrs, filterProtocols)
56+
if peerRecord == nil {
57+
return iter.Result[types.Record]{Err: errors.New("record is nil")}
58+
}
59+
v.Val = peerRecord
60+
}
61+
return v
62+
})
63+
64+
// filter out nil results and errors
65+
filteredIter := iter.Filter(mappedIter, func(v iter.Result[types.Record]) bool {
66+
return v.Err == nil && v.Val != nil
67+
})
68+
69+
return filteredIter
70+
}
71+
72+
func filterRecords(records []types.Record, filterAddrs, filterProtocols []string) []types.Record {
2273
if len(filterAddrs) == 0 && len(filterProtocols) == 0 {
23-
return providers
74+
return records
2475
}
2576

26-
filtered := make([]types.Record, 0, len(providers))
77+
filtered := make([]types.Record, 0, len(records))
2778

28-
for _, provider := range providers {
29-
if schema := provider.GetSchema(); schema == types.SchemaPeer {
30-
peer, ok := provider.(*types.PeerRecord)
79+
for _, record := range records {
80+
// TODO: Handle SchemaBitswap
81+
if schema := record.GetSchema(); schema == types.SchemaPeer {
82+
peer, ok := record.(*types.PeerRecord)
3183
if !ok {
32-
logger.Errorw("problem casting find providers result", "Schema", provider.GetSchema(), "Type", reflect.TypeOf(provider).String())
84+
logger.Errorw("problem casting find providers result", "Schema", record.GetSchema(), "Type", reflect.TypeOf(record).String())
3385
// if the type assertion fails, we exlude record from results
3486
continue
3587
}
@@ -42,7 +94,7 @@ func filterProviders(providers []types.Record, filterAddrs, filterProtocols []st
4294

4395
} else {
4496
// Will we ever encounter the SchemaBitswap type? Evidence seems to suggest that no longer
45-
logger.Errorw("encountered unknown provider schema", "Schema", provider.GetSchema(), "Type", reflect.TypeOf(provider).String())
97+
logger.Errorw("encountered unknown provider schema", "Schema", record.GetSchema(), "Type", reflect.TypeOf(record).String())
4698
}
4799
}
48100
return filtered
@@ -51,6 +103,10 @@ func filterProviders(providers []types.Record, filterAddrs, filterProtocols []st
51103
// Applies the filters. Returns nil if the provider does not pass the protocols filter
52104
// The address filter is more complicated because it potentially modifies the Addrs slice.
53105
func applyFilters(provider *types.PeerRecord, filterAddrs, filterProtocols []string) *types.PeerRecord {
106+
if len(filterAddrs) == 0 && len(filterProtocols) == 0 {
107+
return provider
108+
}
109+
54110
if !applyProtocolFilter(provider.Protocols, filterProtocols) {
55111
// If the provider doesn't match any of the passed protocols, the provider is omitted from the response.
56112
return nil

routing/http/server/server.go

Lines changed: 5 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -235,84 +235,21 @@ func (s *server) findProviders(w http.ResponseWriter, httpReq *http.Request) {
235235
func (s *server) findProvidersJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) {
236236
defer provIter.Close()
237237

238-
providers, err := iter.ReadAllResults(provIter)
238+
filteredIter := applyFiltersToIter(provIter, filterAddrs, filterProtocols)
239+
providers, err := iter.ReadAllResults(filteredIter)
239240
if err != nil {
240241
writeErr(w, "FindProviders", http.StatusInternalServerError, fmt.Errorf("delegate error: %w", err))
241242
return
242243
}
243244

244-
filteredProviders := filterProviders(providers, filterAddrs, filterProtocols)
245-
246245
writeJSONResult(w, "FindProviders", jsontypes.ProvidersResponse{
247-
Providers: filteredProviders,
246+
Providers: providers,
248247
})
249248
}
250249
func (s *server) findProvidersNDJSON(w http.ResponseWriter, provIter iter.ResultIter[types.Record], filterAddrs, filterProtocols []string) {
251-
defer provIter.Close()
252-
253-
w.Header().Set("Content-Type", mediaTypeNDJSON)
254-
w.Header().Add("Vary", "Accept")
255-
w.Header().Set("Last-Modified", time.Now().UTC().Format(http.TimeFormat))
256-
257-
hasResults := false
258-
for provIter.Next() {
259-
res := provIter.Val()
260-
if res.Err != nil {
261-
logger.Errorw("ndjson iterator error", "Error", res.Err)
262-
return
263-
}
264-
265-
// handle filtering per record as we iterate
266-
if len(filterAddrs) > 0 || len(filterProtocols) > 0 {
267-
switch v := res.Val.(type) {
268-
case *types.PeerRecord:
269-
record := applyFilters(v, filterAddrs, filterProtocols)
270-
if record == nil {
271-
// if the record is nil, we skip it
272-
continue
273-
}
274-
res.Val = record
275-
default:
276-
logger.Warn("unexpected type for res.Val, expected types.PeerRecord")
277-
continue
278-
}
279-
}
280-
281-
// don't use an encoder because we can't easily differentiate writer errors from encoding errors
282-
b, err := drjson.MarshalJSONBytes(res.Val)
283-
if err != nil {
284-
logger.Errorw("ndjson marshal error", "Error", err)
285-
return
286-
}
287-
288-
if !hasResults {
289-
hasResults = true
290-
// There's results, cache useful result for longer
291-
setCacheControl(w, maxAgeWithResults, maxStale)
292-
}
250+
filteredIter := applyFiltersToIter(provIter, filterAddrs, filterProtocols)
293251

294-
_, err = w.Write(b)
295-
if err != nil {
296-
logger.Warn("ndjson write error", "Error", err)
297-
return
298-
}
299-
300-
_, err = w.Write([]byte{'\n'})
301-
if err != nil {
302-
logger.Warn("ndjson write error", "Error", err)
303-
return
304-
}
305-
306-
if f, ok := w.(http.Flusher); ok {
307-
f.Flush()
308-
}
309-
}
310-
311-
if !hasResults {
312-
// There weren't results, cache for shorter and send 404
313-
setCacheControl(w, maxAgeWithoutResults, maxStale)
314-
w.WriteHeader(http.StatusNotFound)
315-
}
252+
writeResultsIterNDJSON(w, filteredIter)
316253
}
317254

318255
func (s *server) findPeers(w http.ResponseWriter, r *http.Request) {

routing/http/types/iter/filter.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package iter
2+
3+
// Filter returns an iterator that filters out values that don't satisfy the predicate f.
4+
func Filter[T any](iter Iter[T], f func(t T) bool) *FilterIter[T] {
5+
return &FilterIter[T]{iter: iter, f: f}
6+
}
7+
8+
type FilterIter[T any] struct {
9+
iter Iter[T]
10+
f func(T) bool
11+
12+
done bool
13+
val T
14+
}
15+
16+
func (f *FilterIter[T]) Next() bool {
17+
if f.done {
18+
return false
19+
}
20+
21+
ok := f.iter.Next()
22+
f.done = !ok
23+
24+
if f.done {
25+
return false
26+
}
27+
28+
f.val = f.iter.Val()
29+
30+
if f.f(f.val) {
31+
return true
32+
}
33+
34+
return f.Next()
35+
}
36+
37+
func (f *FilterIter[T]) Val() T {
38+
return f.val
39+
}
40+
41+
func (f *FilterIter[T]) Close() error {
42+
return f.iter.Close()
43+
}
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
package iter
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestFilter(t *testing.T) {
11+
for _, c := range []struct {
12+
input Iter[int]
13+
f func(int) bool
14+
expResults []int
15+
}{
16+
{
17+
input: FromSlice([]int{1, 2, 3, 4}),
18+
f: func(i int) bool { return i%2 == 0 },
19+
expResults: []int{2, 4},
20+
},
21+
{
22+
input: FromSlice([]int{}),
23+
f: func(i int) bool { return i%2 == 0 },
24+
expResults: nil,
25+
},
26+
{
27+
input: FromSlice([]int{1, 3, 5, 100}),
28+
f: func(i int) bool { return i > 2 },
29+
expResults: []int{3, 5, 100},
30+
},
31+
} {
32+
t.Run(fmt.Sprintf("%v", c.input), func(t *testing.T) {
33+
iter := Filter(c.input, c.f)
34+
var res []int
35+
for iter.Next() {
36+
res = append(res, iter.Val())
37+
}
38+
assert.Equal(t, c.expResults, res)
39+
})
40+
}
41+
}

routing/http/types/record_peer.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,12 @@ func (pr PeerRecord) MarshalJSON() ([]byte, error) {
7979

8080
return drjson.MarshalJSONBytes(m)
8181
}
82+
83+
func FromBitswapRecord(br *BitswapRecord) *PeerRecord {
84+
return &PeerRecord{
85+
Schema: SchemaPeer,
86+
ID: br.ID,
87+
Addrs: br.Addrs,
88+
Protocols: []string{br.Protocol},
89+
}
90+
}

0 commit comments

Comments
 (0)