Skip to content

Commit 407532f

Browse files
jasonodonnellcalvn
andauthored
creds/aws: Add support for DSA signature verification for EC2 (#12340) (#12361)
* creds/aws: import pkcs7 verification package * Add DSA support * changelog * Add DSA to correct verify function * Remove unneeded tests * Fix backend test * Update builtin/credential/aws/pkcs7/README.md Co-authored-by: Calvin Leung Huang <1883212+calvn@users.noreply.github.com> * Update builtin/credential/aws/path_login.go Co-authored-by: Calvin Leung Huang <1883212+calvn@users.noreply.github.com> Co-authored-by: Calvin Leung Huang <1883212+calvn@users.noreply.github.com> Co-authored-by: Calvin Leung Huang <1883212+calvn@users.noreply.github.com>
1 parent b882dde commit 407532f

File tree

18 files changed

+2973
-6
lines changed

18 files changed

+2973
-6
lines changed

builtin/credential/aws/backend_test.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,6 +1129,22 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
11291129
}
11301130
}
11311131

1132+
// Configure additional metadata to be returned for ec2 logins.
1133+
identity := map[string]interface{}{
1134+
"ec2_metadata": []string{"instance_id", "region", "ami_id"},
1135+
}
1136+
1137+
// store the identity
1138+
_, err = b.HandleRequest(context.Background(), &logical.Request{
1139+
Operation: logical.UpdateOperation,
1140+
Storage: storage,
1141+
Path: "config/identity",
1142+
Data: identity,
1143+
})
1144+
if err != nil {
1145+
t.Fatal(err)
1146+
}
1147+
11321148
loginInput := map[string]interface{}{
11331149
"pkcs7": pkcs7,
11341150
"nonce": "vault-client-nonce",
@@ -1241,6 +1257,7 @@ func TestBackendAcc_LoginWithInstanceIdentityDocAndAccessListIdentity(t *testing
12411257
delete(loginInput, "pkcs7")
12421258
loginInput["identity"] = identityDoc
12431259
loginInput["signature"] = identityDocSig
1260+
12441261
resp, err = b.HandleRequest(context.Background(), loginRequest)
12451262
if err != nil {
12461263
t.Fatal(err)

builtin/credential/aws/path_login.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@ import (
2020
awsClient "github.com/aws/aws-sdk-go/aws/client"
2121
"github.com/aws/aws-sdk-go/service/ec2"
2222
"github.com/aws/aws-sdk-go/service/iam"
23-
"github.com/fullsailor/pkcs7"
2423
"github.com/hashicorp/errwrap"
2524
cleanhttp "github.com/hashicorp/go-cleanhttp"
2625
"github.com/hashicorp/go-retryablehttp"
2726
uuid "github.com/hashicorp/go-uuid"
27+
"github.com/hashicorp/vault/builtin/credential/aws/pkcs7"
2828
"github.com/hashicorp/vault/sdk/framework"
2929
"github.com/hashicorp/vault/sdk/helper/awsutil"
3030
"github.com/hashicorp/vault/sdk/helper/cidrutil"
@@ -348,8 +348,8 @@ func (b *backend) parseIdentityDocument(ctx context.Context, s logical.Storage,
348348

349349
// Verify extracts the authenticated attributes in the PKCS#7 signature, and verifies
350350
// the authenticity of the content using 'dsa.PublicKey' embedded in the public certificate.
351-
if pkcs7Data.Verify() != nil {
352-
return nil, fmt.Errorf("failed to verify the signature")
351+
if err := pkcs7Data.Verify(); err != nil {
352+
return nil, fmt.Errorf("failed to verify the signature: %w", err)
353353
}
354354

355355
// Check if the signature has content inside of it
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# PKCS7
2+
3+
This code is used to verify PKCS7 signatures for the EC2 auth method. The code
4+
was forked from [mozilla-services/pkcs7](https://github.com/mozilla-services/pkcs7)
5+
and modified for Vault.
Lines changed: 251 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,251 @@
1+
package pkcs7
2+
3+
import (
4+
"bytes"
5+
"errors"
6+
)
7+
8+
var encodeIndent = 0
9+
10+
type asn1Object interface {
11+
EncodeTo(writer *bytes.Buffer) error
12+
}
13+
14+
type asn1Structured struct {
15+
tagBytes []byte
16+
content []asn1Object
17+
}
18+
19+
func (s asn1Structured) EncodeTo(out *bytes.Buffer) error {
20+
//fmt.Printf("%s--> tag: % X\n", strings.Repeat("| ", encodeIndent), s.tagBytes)
21+
encodeIndent++
22+
inner := new(bytes.Buffer)
23+
for _, obj := range s.content {
24+
err := obj.EncodeTo(inner)
25+
if err != nil {
26+
return err
27+
}
28+
}
29+
encodeIndent--
30+
out.Write(s.tagBytes)
31+
encodeLength(out, inner.Len())
32+
out.Write(inner.Bytes())
33+
return nil
34+
}
35+
36+
type asn1Primitive struct {
37+
tagBytes []byte
38+
length int
39+
content []byte
40+
}
41+
42+
func (p asn1Primitive) EncodeTo(out *bytes.Buffer) error {
43+
_, err := out.Write(p.tagBytes)
44+
if err != nil {
45+
return err
46+
}
47+
if err = encodeLength(out, p.length); err != nil {
48+
return err
49+
}
50+
//fmt.Printf("%s--> tag: % X length: %d\n", strings.Repeat("| ", encodeIndent), p.tagBytes, p.length)
51+
//fmt.Printf("%s--> content length: %d\n", strings.Repeat("| ", encodeIndent), len(p.content))
52+
out.Write(p.content)
53+
54+
return nil
55+
}
56+
57+
func ber2der(ber []byte) ([]byte, error) {
58+
if len(ber) == 0 {
59+
return nil, errors.New("ber2der: input ber is empty")
60+
}
61+
//fmt.Printf("--> ber2der: Transcoding %d bytes\n", len(ber))
62+
out := new(bytes.Buffer)
63+
64+
obj, _, err := readObject(ber, 0)
65+
if err != nil {
66+
return nil, err
67+
}
68+
obj.EncodeTo(out)
69+
70+
// if offset < len(ber) {
71+
// return nil, fmt.Errorf("ber2der: Content longer than expected. Got %d, expected %d", offset, len(ber))
72+
//}
73+
74+
return out.Bytes(), nil
75+
}
76+
77+
// encodes lengths that are longer than 127 into string of bytes
78+
func marshalLongLength(out *bytes.Buffer, i int) (err error) {
79+
n := lengthLength(i)
80+
81+
for ; n > 0; n-- {
82+
err = out.WriteByte(byte(i >> uint((n-1)*8)))
83+
if err != nil {
84+
return
85+
}
86+
}
87+
88+
return nil
89+
}
90+
91+
// computes the byte length of an encoded length value
92+
func lengthLength(i int) (numBytes int) {
93+
numBytes = 1
94+
for i > 255 {
95+
numBytes++
96+
i >>= 8
97+
}
98+
return
99+
}
100+
101+
// encodes the length in DER format
102+
// If the length fits in 7 bits, the value is encoded directly.
103+
//
104+
// Otherwise, the number of bytes to encode the length is first determined.
105+
// This number is likely to be 4 or less for a 32bit length. This number is
106+
// added to 0x80. The length is encoded in big endian encoding follow after
107+
//
108+
// Examples:
109+
// length | byte 1 | bytes n
110+
// 0 | 0x00 | -
111+
// 120 | 0x78 | -
112+
// 200 | 0x81 | 0xC8
113+
// 500 | 0x82 | 0x01 0xF4
114+
//
115+
func encodeLength(out *bytes.Buffer, length int) (err error) {
116+
if length >= 128 {
117+
l := lengthLength(length)
118+
err = out.WriteByte(0x80 | byte(l))
119+
if err != nil {
120+
return
121+
}
122+
err = marshalLongLength(out, length)
123+
if err != nil {
124+
return
125+
}
126+
} else {
127+
err = out.WriteByte(byte(length))
128+
if err != nil {
129+
return
130+
}
131+
}
132+
return
133+
}
134+
135+
func readObject(ber []byte, offset int) (asn1Object, int, error) {
136+
berLen := len(ber)
137+
if offset >= berLen {
138+
return nil, 0, errors.New("ber2der: offset is after end of ber data")
139+
}
140+
tagStart := offset
141+
b := ber[offset]
142+
offset++
143+
if offset >= berLen {
144+
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
145+
}
146+
tag := b & 0x1F // last 5 bits
147+
if tag == 0x1F {
148+
tag = 0
149+
for ber[offset] >= 0x80 {
150+
tag = tag*128 + ber[offset] - 0x80
151+
offset++
152+
if offset > berLen {
153+
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
154+
}
155+
}
156+
// jvehent 20170227: this doesn't appear to be used anywhere...
157+
//tag = tag*128 + ber[offset] - 0x80
158+
offset++
159+
if offset > berLen {
160+
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
161+
}
162+
}
163+
tagEnd := offset
164+
165+
kind := b & 0x20
166+
if kind == 0 {
167+
debugprint("--> Primitive\n")
168+
} else {
169+
debugprint("--> Constructed\n")
170+
}
171+
// read length
172+
var length int
173+
l := ber[offset]
174+
offset++
175+
if offset > berLen {
176+
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
177+
}
178+
hack := 0
179+
if l > 0x80 {
180+
numberOfBytes := (int)(l & 0x7F)
181+
if numberOfBytes > 4 { // int is only guaranteed to be 32bit
182+
return nil, 0, errors.New("ber2der: BER tag length too long")
183+
}
184+
if numberOfBytes == 4 && (int)(ber[offset]) > 0x7F {
185+
return nil, 0, errors.New("ber2der: BER tag length is negative")
186+
}
187+
if (int)(ber[offset]) == 0x0 {
188+
return nil, 0, errors.New("ber2der: BER tag length has leading zero")
189+
}
190+
debugprint("--> (compute length) indicator byte: %x\n", l)
191+
debugprint("--> (compute length) length bytes: % X\n", ber[offset:offset+numberOfBytes])
192+
for i := 0; i < numberOfBytes; i++ {
193+
length = length*256 + (int)(ber[offset])
194+
offset++
195+
if offset > berLen {
196+
return nil, 0, errors.New("ber2der: cannot move offset forward, end of ber data reached")
197+
}
198+
}
199+
} else if l == 0x80 {
200+
// find length by searching content
201+
markerIndex := bytes.LastIndex(ber[offset:], []byte{0x0, 0x0})
202+
if markerIndex == -1 {
203+
return nil, 0, errors.New("ber2der: Invalid BER format")
204+
}
205+
length = markerIndex
206+
hack = 2
207+
debugprint("--> (compute length) marker found at offset: %d\n", markerIndex+offset)
208+
} else {
209+
length = (int)(l)
210+
}
211+
if length < 0 {
212+
return nil, 0, errors.New("ber2der: invalid negative value found in BER tag length")
213+
}
214+
//fmt.Printf("--> length : %d\n", length)
215+
contentEnd := offset + length
216+
if contentEnd > len(ber) {
217+
return nil, 0, errors.New("ber2der: BER tag length is more than available data")
218+
}
219+
debugprint("--> content start : %d\n", offset)
220+
debugprint("--> content end : %d\n", contentEnd)
221+
debugprint("--> content : % X\n", ber[offset:contentEnd])
222+
var obj asn1Object
223+
if kind == 0 {
224+
obj = asn1Primitive{
225+
tagBytes: ber[tagStart:tagEnd],
226+
length: length,
227+
content: ber[offset:contentEnd],
228+
}
229+
} else {
230+
var subObjects []asn1Object
231+
for offset < contentEnd {
232+
var subObj asn1Object
233+
var err error
234+
subObj, offset, err = readObject(ber[:contentEnd], offset)
235+
if err != nil {
236+
return nil, 0, err
237+
}
238+
subObjects = append(subObjects, subObj)
239+
}
240+
obj = asn1Structured{
241+
tagBytes: ber[tagStart:tagEnd],
242+
content: subObjects,
243+
}
244+
}
245+
246+
return obj, contentEnd + hack, nil
247+
}
248+
249+
func debugprint(format string, a ...interface{}) {
250+
//fmt.Printf(format, a)
251+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package pkcs7
2+
3+
import (
4+
"bytes"
5+
"encoding/asn1"
6+
"strings"
7+
"testing"
8+
)
9+
10+
func TestBer2Der(t *testing.T) {
11+
// indefinite length fixture
12+
ber := []byte{0x30, 0x80, 0x02, 0x01, 0x01, 0x00, 0x00}
13+
expected := []byte{0x30, 0x03, 0x02, 0x01, 0x01}
14+
der, err := ber2der(ber)
15+
if err != nil {
16+
t.Fatalf("ber2der failed with error: %v", err)
17+
}
18+
if !bytes.Equal(der, expected) {
19+
t.Errorf("ber2der result did not match.\n\tExpected: % X\n\tActual: % X", expected, der)
20+
}
21+
22+
if der2, err := ber2der(der); err != nil {
23+
t.Errorf("ber2der on DER bytes failed with error: %v", err)
24+
} else {
25+
if !bytes.Equal(der, der2) {
26+
t.Error("ber2der is not idempotent")
27+
}
28+
}
29+
var thing struct {
30+
Number int
31+
}
32+
rest, err := asn1.Unmarshal(der, &thing)
33+
if err != nil {
34+
t.Errorf("Cannot parse resulting DER because: %v", err)
35+
} else if len(rest) > 0 {
36+
t.Errorf("Resulting DER has trailing data: % X", rest)
37+
}
38+
}
39+
40+
func TestBer2Der_Negatives(t *testing.T) {
41+
fixtures := []struct {
42+
Input []byte
43+
ErrorContains string
44+
}{
45+
{[]byte{0x30, 0x85}, "tag length too long"},
46+
{[]byte{0x30, 0x84, 0x80, 0x0, 0x0, 0x0}, "length is negative"},
47+
{[]byte{0x30, 0x82, 0x0, 0x1}, "length has leading zero"},
48+
{[]byte{0x30, 0x80, 0x1, 0x2}, "Invalid BER format"},
49+
{[]byte{0x30, 0x03, 0x01, 0x02}, "length is more than available data"},
50+
{[]byte{0x30}, "end of ber data reached"},
51+
}
52+
53+
for _, fixture := range fixtures {
54+
_, err := ber2der(fixture.Input)
55+
if err == nil {
56+
t.Errorf("No error thrown. Expected: %s", fixture.ErrorContains)
57+
}
58+
if !strings.Contains(err.Error(), fixture.ErrorContains) {
59+
t.Errorf("Unexpected error thrown.\n\tExpected: /%s/\n\tActual: %s", fixture.ErrorContains, err.Error())
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)