Skip to content

Commit 3e3290f

Browse files
ajanthanaeneasr
andauthored
feat: add support for response_mode=form_post (#509)
This patch introduces support for `response_mode=form_post` as well as `response_mode` of `none` and `query` and `fragment`. To support this new feature your OAuth2 Client must implement the `fosite.ResponseModeClient` interface. We suggest to always return all response modes there unless you want to explicitly disable one of the response modes: ```go func (c *Client) GetResponseModes() []fosite.ResponseModeType { return []fosite.ResponseModeType{ fosite.ResponseModeDefault, fosite.ResponseModeFormPost, fosite.ResponseModeQuery, fosite.ResponseModeFragment, } } ``` BREAKING CHANGES: As part of this change, methods `GetResponseMode`, `SetDefaultResponseMode`, `GetDefaultResponseMode ` where added to interface `AuthorizeRequester`. Also, methods `GetQuery`, `AddQuery`, and `GetFragment` were merged into one function `GetParameters` and `AddParameter` on the `AuthorizeResponder` interface. Methods on `AuthorizeRequest` and `AuthorizeResponse` changed accordingly and will need to be updated in your codebase. Additionally, the field `Debug` was renamed to `DebugField` and a new method `Debug() string` was added to `RFC6749Error`. Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
1 parent d6e45b3 commit 3e3290f

38 files changed

+1351
-256
lines changed

authorize_error.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,6 @@ import (
2525
"encoding/json"
2626
"fmt"
2727
"net/http"
28-
29-
"github.com/pkg/errors"
3028
)
3129

3230
func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error) {
@@ -66,7 +64,11 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest
6664
query.Add("state", ar.GetState())
6765

6866
var redirectURIString string
69-
if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && !errors.Is(err, ErrUnsupportedResponseType) {
67+
if ar.GetResponseMode() == ResponseModeFormPost {
68+
rw.Header().Add("Content-Type", "text/html;charset=UTF-8")
69+
WriteAuthorizeFormPostResponse(redirectURI.String(), query, GetPostFormHTMLTemplate(*f), rw)
70+
return
71+
} else if ar.GetResponseMode() == ResponseModeFragment {
7072
redirectURIString = redirectURI.String() + "#" + query.Encode()
7173
} else {
7274
for key, values := range redirectURI.Query() {

authorize_error_test.go

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ func TestWriteAuthorizeError(t *testing.T) {
8888
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
8989
req.EXPECT().GetState().Return("foostate")
9090
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"}))
91+
req.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2)
9192
rw.EXPECT().Header().Times(3).Return(header)
9293
rw.EXPECT().WriteHeader(http.StatusFound)
9394
},
@@ -106,6 +107,7 @@ func TestWriteAuthorizeError(t *testing.T) {
106107
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
107108
req.EXPECT().GetState().Return("foostate")
108109
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"}))
110+
req.EXPECT().GetResponseMode().Return(ResponseModeDefault).Times(2)
109111
rw.EXPECT().Header().Times(3).Return(header)
110112
rw.EXPECT().WriteHeader(http.StatusFound)
111113
},
@@ -124,6 +126,7 @@ func TestWriteAuthorizeError(t *testing.T) {
124126
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
125127
req.EXPECT().GetState().Return("foostate")
126128
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"}))
129+
req.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2)
127130
rw.EXPECT().Header().Times(3).Return(header)
128131
rw.EXPECT().WriteHeader(http.StatusFound)
129132
},
@@ -142,6 +145,7 @@ func TestWriteAuthorizeError(t *testing.T) {
142145
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
143146
req.EXPECT().GetState().Return("foostate")
144147
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"foobar"}))
148+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
145149
rw.EXPECT().Header().Times(3).Return(header)
146150
rw.EXPECT().WriteHeader(http.StatusFound)
147151
},
@@ -160,6 +164,7 @@ func TestWriteAuthorizeError(t *testing.T) {
160164
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
161165
req.EXPECT().GetState().Return("foostate")
162166
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"}))
167+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
163168
rw.EXPECT().Header().Times(3).Return(header)
164169
rw.EXPECT().WriteHeader(http.StatusFound)
165170
},
@@ -178,6 +183,7 @@ func TestWriteAuthorizeError(t *testing.T) {
178183
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
179184
req.EXPECT().GetState().Return("foostate")
180185
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"}))
186+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
181187
rw.EXPECT().Header().Times(3).Return(header)
182188
rw.EXPECT().WriteHeader(http.StatusFound)
183189
},
@@ -196,6 +202,7 @@ func TestWriteAuthorizeError(t *testing.T) {
196202
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0]))
197203
req.EXPECT().GetState().Return("foostate")
198204
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"}))
205+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
199206
rw.EXPECT().Header().Times(3).Return(header)
200207
rw.EXPECT().WriteHeader(http.StatusFound)
201208
},
@@ -214,6 +221,7 @@ func TestWriteAuthorizeError(t *testing.T) {
214221
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
215222
req.EXPECT().GetState().Return("foostate")
216223
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"}))
224+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
217225
rw.EXPECT().Header().Times(3).Return(header)
218226
rw.EXPECT().WriteHeader(http.StatusFound)
219227
},
@@ -233,6 +241,7 @@ func TestWriteAuthorizeError(t *testing.T) {
233241
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
234242
req.EXPECT().GetState().Return("foostate")
235243
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"}))
244+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
236245
rw.EXPECT().Header().Times(3).Return(header)
237246
rw.EXPECT().WriteHeader(http.StatusFound)
238247
},
@@ -252,6 +261,7 @@ func TestWriteAuthorizeError(t *testing.T) {
252261
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
253262
req.EXPECT().GetState().Return("foostate")
254263
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"id_token"}))
264+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
255265
rw.EXPECT().Header().Times(3).Return(header)
256266
rw.EXPECT().WriteHeader(http.StatusFound)
257267
},
@@ -271,6 +281,7 @@ func TestWriteAuthorizeError(t *testing.T) {
271281
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
272282
req.EXPECT().GetState().Return("foostate")
273283
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"}))
284+
req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2)
274285
rw.EXPECT().Header().Times(3).Return(header)
275286
rw.EXPECT().WriteHeader(http.StatusFound)
276287
},
@@ -282,6 +293,24 @@ func TestWriteAuthorizeError(t *testing.T) {
282293
assert.Equal(t, "no-cache", header.Get("Pragma"))
283294
},
284295
},
296+
{
297+
debug: true,
298+
err: ErrInvalidRequest.WithDebug("with-debug"),
299+
mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) {
300+
req.EXPECT().IsRedirectURIValid().Return(true)
301+
req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1]))
302+
req.EXPECT().GetState().Return("foostate")
303+
req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"}))
304+
req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(1)
305+
rw.EXPECT().Header().Times(3).Return(header)
306+
rw.EXPECT().Write(gomock.Any()).AnyTimes()
307+
},
308+
checkHeader: func(t *testing.T, k int) {
309+
assert.Equal(t, "no-store", header.Get("Cache-Control"))
310+
assert.Equal(t, "no-cache", header.Get("Pragma"))
311+
assert.Equal(t, "text/html;charset=UTF-8", header.Get("Content-Type"))
312+
},
313+
},
285314
} {
286315
t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) {
287316
oauth2 := &Fosite{

authorize_helper.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
package fosite
2323

2424
import (
25+
"fmt"
26+
"html/template"
27+
"io"
2528
"net/url"
2629
"regexp"
2730
"strings"
@@ -30,6 +33,21 @@ import (
3033
"github.com/pkg/errors"
3134
)
3235

36+
var FormPostDefaultTemplate = template.Must(template.New("form_post").Parse(`<html>
37+
<head>
38+
<title>Submit This Form</title>
39+
</head>
40+
<body onload="javascript:document.forms[0].submit()">
41+
<form method="post" action="{{ .RedirURL }}">
42+
{{ range $key,$value := .Parameters }}
43+
{{ range $parameter:= $value}}
44+
<input type="hidden" name="{{$key}}" value="{{$parameter}}"/>
45+
{{end}}
46+
{{ end }}
47+
</form>
48+
</body>
49+
</html>`))
50+
3351
// MatchRedirectURIWithClientRedirectURIs if the given uri is a registered redirect uri. Does not perform
3452
// uri validation.
3553
//
@@ -182,3 +200,35 @@ func IsLocalhost(redirectURI *url.URL) bool {
182200
hn := redirectURI.Hostname()
183201
return strings.HasSuffix(hn, ".localhost") || hn == "127.0.0.1" || hn == "::1" || hn == "localhost"
184202
}
203+
204+
func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, template *template.Template, rw io.Writer) {
205+
_ = template.Execute(rw, struct {
206+
RedirURL string
207+
Parameters url.Values
208+
}{
209+
RedirURL: redirectURL,
210+
Parameters: parameters,
211+
})
212+
}
213+
214+
func URLSetFragment(source *url.URL, fragment url.Values) {
215+
var f string
216+
for k, v := range fragment {
217+
for _, vv := range v {
218+
if len(f) != 0 {
219+
f += fmt.Sprintf("&%s=%s", k, vv)
220+
} else {
221+
f += fmt.Sprintf("%s=%s", k, vv)
222+
}
223+
}
224+
}
225+
source.Fragment = f
226+
}
227+
228+
func GetPostFormHTMLTemplate(f Fosite) *template.Template {
229+
formPostHTMLTemplate := f.FormPostHTMLTemplate
230+
if formPostHTMLTemplate == nil {
231+
formPostHTMLTemplate = FormPostDefaultTemplate
232+
}
233+
return formPostHTMLTemplate
234+
}

0 commit comments

Comments
 (0)