@@ -41,45 +41,76 @@ func (m *mockContentRouter) Provide(ctx context.Context, req *server.WriteProvid
4141}
4242
4343type testDeps struct {
44- router * mockContentRouter
45- server * httptest.Server
46- peerID peer.ID
47- addrs []multiaddr.Multiaddr
48- client * client
44+ // recordingHandler records requests received on the server side
45+ recordingHandler * recordingHandler
46+ // recordingHTTPClient records responses received on the client side
47+ recordingHTTPClient * recordingHTTPClient
48+ router * mockContentRouter
49+ server * httptest.Server
50+ peerID peer.ID
51+ addrs []multiaddr.Multiaddr
52+ client * client
4953}
5054
51- func makeTestDeps (t * testing.T ) testDeps {
55+ type recordingHandler struct {
56+ http.Handler
57+ f []func (* http.Request )
58+ }
59+
60+ func (h * recordingHandler ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
61+ for _ , f := range h .f {
62+ f (r )
63+ }
64+ h .Handler .ServeHTTP (w , r )
65+ }
66+
67+ type recordingHTTPClient struct {
68+ httpClient
69+ f []func (* http.Response )
70+ }
71+
72+ func (c * recordingHTTPClient ) Do (req * http.Request ) (* http.Response , error ) {
73+ resp , err := c .httpClient .Do (req )
74+ for _ , f := range c .f {
75+ f (resp )
76+ }
77+ return resp , err
78+ }
79+
80+ func makeTestDeps (t * testing.T , clientsOpts []Option , serverOpts []server.Option ) testDeps {
5281 const testUserAgent = "testUserAgent"
5382 peerID , addrs , identity := makeProviderAndIdentity ()
5483 router := & mockContentRouter {}
55- server := httptest .NewServer (server .Handler (router ))
84+ recordingHandler := & recordingHandler {
85+ Handler : server .Handler (router , serverOpts ... ),
86+ f : []func (* http.Request ){
87+ func (r * http.Request ) {
88+ assert .Equal (t , testUserAgent , r .Header .Get ("User-Agent" ))
89+ },
90+ },
91+ }
92+ server := httptest .NewServer (recordingHandler )
5693 t .Cleanup (server .Close )
5794 serverAddr := "http://" + server .Listener .Addr ().String ()
58- c , err := New (serverAddr , WithProviderInfo (peerID , addrs ), WithIdentity (identity ), WithUserAgent (testUserAgent ))
95+ recordingHTTPClient := & recordingHTTPClient {httpClient : defaultHTTPClient }
96+ defaultClientOpts := []Option {
97+ WithProviderInfo (peerID , addrs ),
98+ WithIdentity (identity ),
99+ WithUserAgent (testUserAgent ),
100+ WithHTTPClient (recordingHTTPClient ),
101+ }
102+ c , err := New (serverAddr , append (defaultClientOpts , clientsOpts ... )... )
59103 if err != nil {
60104 panic (err )
61105 }
62- assertUserAgentOverride (t , c , testUserAgent )
63106 return testDeps {
64- router : router ,
65- server : server ,
66- peerID : peerID ,
67- addrs : addrs ,
68- client : c ,
69- }
70- }
71-
72- func assertUserAgentOverride (t * testing.T , c * client , expected string ) {
73- httpClient , ok := c .httpClient .(* http.Client )
74- if ! ok {
75- t .Error ("invalid c.httpClient" )
76- }
77- transport , ok := httpClient .Transport .(* ResponseBodyLimitedTransport )
78- if ! ok {
79- t .Error ("invalid httpClient.Transport" )
80- }
81- if transport .UserAgent != expected {
82- t .Error ("invalid httpClient.Transport.UserAgent" )
107+ recordingHandler : recordingHandler ,
108+ recordingHTTPClient : recordingHTTPClient ,
109+ router : router ,
110+ server : server ,
111+ peerID : peerID ,
112+ addrs : addrs ,
113+ client : c ,
83114 }
84115}
85116
@@ -149,6 +180,10 @@ type osErrContains struct {
149180}
150181
151182func (e * osErrContains ) errContains (t * testing.T , err error ) {
183+ if e .expContains == "" && e .expContainsWin == "" {
184+ assert .NoError (t , err )
185+ return
186+ }
152187 if runtime .GOOS == "windows" && len (e .expContainsWin ) != 0 {
153188 assert .ErrorContains (t , err , e .expContainsWin )
154189 } else {
@@ -163,37 +198,90 @@ func TestClient_FindProviders(t *testing.T) {
163198 }
164199
165200 cases := []struct {
166- name string
167- httpStatusCode int
168- stopServer bool
169- routerProvs []iter.Result [types.ProviderResponse ]
170- routerErr error
171-
172- expProvs []iter.Result [types.ProviderResponse ]
173- expErrContains []osErrContains
201+ name string
202+ httpStatusCode int
203+ stopServer bool
204+ routerProvs []iter.Result [types.ProviderResponse ]
205+ routerErr error
206+ clientRequiresStreaming bool
207+ serverStreamingDisabled bool
208+
209+ expErrContains osErrContains
210+ expProvs []iter.Result [types.ProviderResponse ]
211+ expStreamingResponse bool
212+ expJSONResponse bool
174213 }{
175214 {
176- name : "happy case" ,
177- routerProvs : bitswapProvs ,
178- expProvs : bitswapProvs ,
215+ name : "happy case" ,
216+ routerProvs : bitswapProvs ,
217+ expProvs : bitswapProvs ,
218+ expStreamingResponse : true ,
219+ },
220+ {
221+ name : "server doesn't support streaming" ,
222+ routerProvs : bitswapProvs ,
223+ expProvs : bitswapProvs ,
224+ serverStreamingDisabled : true ,
225+ expJSONResponse : true ,
226+ },
227+ {
228+ name : "client requires streaming but server doesn't support it" ,
229+ serverStreamingDisabled : true ,
230+ clientRequiresStreaming : true ,
231+ expErrContains : osErrContains {expContains : "HTTP error with StatusCode=400: no supported content types" },
179232 },
180233 {
181234 name : "returns an error if there's a non-200 response" ,
182235 httpStatusCode : 500 ,
183- expErrContains : [] osErrContains {{ expContains : "HTTP error with StatusCode=500: " } },
236+ expErrContains : osErrContains {expContains : "HTTP error with StatusCode=500" },
184237 },
185238 {
186239 name : "returns an error if the HTTP client returns a non-HTTP error" ,
187240 stopServer : true ,
188- expErrContains : [] osErrContains { {
241+ expErrContains : osErrContains {
189242 expContains : "connect: connection refused" ,
190243 expContainsWin : "connectex: No connection could be made because the target machine actively refused it." ,
191- }},
244+ },
245+ },
246+ {
247+ name : "returns no providers if the HTTP server returns a 404 respones" ,
248+ httpStatusCode : 404 ,
249+ expProvs : nil ,
192250 },
193251 }
194252 for _ , c := range cases {
195253 t .Run (c .name , func (t * testing.T ) {
196- deps := makeTestDeps (t )
254+ var clientOpts []Option
255+ var serverOpts []server.Option
256+ var onRespReceived []func (* http.Response )
257+ var onReqReceived []func (* http.Request )
258+
259+ if c .serverStreamingDisabled {
260+ serverOpts = append (serverOpts , server .WithStreamingResultsDisabled ())
261+ }
262+ if c .clientRequiresStreaming {
263+ clientOpts = append (clientOpts , WithStreamResultsRequired ())
264+ onReqReceived = append (onReqReceived , func (r * http.Request ) {
265+ assert .Equal (t , mediaTypeNDJSON , r .Header .Get ("Accept" ))
266+ })
267+ }
268+
269+ if c .expStreamingResponse {
270+ onRespReceived = append (onRespReceived , func (r * http.Response ) {
271+ assert .Equal (t , mediaTypeNDJSON , r .Header .Get ("Content-Type" ))
272+ })
273+ }
274+ if c .expJSONResponse {
275+ onRespReceived = append (onRespReceived , func (r * http.Response ) {
276+ assert .Equal (t , mediaTypeJSON , r .Header .Get ("Content-Type" ))
277+ })
278+ }
279+
280+ deps := makeTestDeps (t , clientOpts , serverOpts )
281+
282+ deps .recordingHTTPClient .f = append (deps .recordingHTTPClient .f , onRespReceived ... )
283+ deps .recordingHandler .f = append (deps .recordingHandler .f , onReqReceived ... )
284+
197285 client := deps .client
198286 router := deps .router
199287
@@ -218,12 +306,7 @@ func TestClient_FindProviders(t *testing.T) {
218306
219307 provsIter , err := client .FindProviders (ctx , cid )
220308
221- for _ , exp := range c .expErrContains {
222- exp .errContains (t , err )
223- }
224- if len (c .expErrContains ) == 0 {
225- require .NoError (t , err )
226- }
309+ c .expErrContains .errContains (t , err )
227310
228311 provs := iter.ReadAll [iter.Result [types.ProviderResponse ]](provsIter )
229312 assert .Equal (t , c .expProvs , provs )
@@ -263,8 +346,7 @@ func TestClient_Provide(t *testing.T) {
263346 name : "should return a 403 if the payload signature verification fails" ,
264347 cids : []cid.Cid {},
265348 mangleSignature : true ,
266-
267- expErrContains : "HTTP error with StatusCode=403" ,
349+ expErrContains : "HTTP error with StatusCode=403" ,
268350 },
269351 {
270352 name : "should return error if identity is not provided" ,
@@ -290,7 +372,7 @@ func TestClient_Provide(t *testing.T) {
290372 }
291373 for _ , c := range cases {
292374 t .Run (c .name , func (t * testing.T ) {
293- deps := makeTestDeps (t )
375+ deps := makeTestDeps (t , nil , nil )
294376 client := deps .client
295377 router := deps .router
296378
0 commit comments