Skip to content

Commit f1fa9c4

Browse files
Merge pull request docker-archive-public#3799 from mrburrito/aws-iam-roles
Update AWS Credentials lookup to use AWS default provider chain.
2 parents befce21 + a26a550 commit f1fa9c4

File tree

5 files changed

+205
-119
lines changed

5 files changed

+205
-119
lines changed

drivers/amazonec2/amazonec2.go

Lines changed: 19 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,22 @@ var (
5151
dockerPort = 2376
5252
swarmPort = 3376
5353
errorNoPrivateSSHKey = errors.New("using --amazonec2-keypair-name also requires --amazonec2-ssh-keypath")
54-
errorMissingAccessKeyOption = errors.New("amazonec2 driver requires the --amazonec2-access-key option or proper credentials in ~/.aws/credentials")
55-
errorMissingSecretKeyOption = errors.New("amazonec2 driver requires the --amazonec2-secret-key option or proper credentials in ~/.aws/credentials")
54+
errorMissingCredentials = errors.New("amazonec2 driver requires AWS credentials configured with the --amazonec2-access-key and --amazonec2-secret-key options, environment variables, ~/.aws/credentials, or an instance role")
5655
errorNoVPCIdFound = errors.New("amazonec2 driver requires either the --amazonec2-subnet-id or --amazonec2-vpc-id option or an AWS Account with a default vpc-id")
5756
errorDisableSSLWithoutCustomEndpoint = errors.New("using --amazonec2-insecure-transport also requires --amazonec2-endpoint")
5857
)
5958

6059
type Driver struct {
6160
*drivers.BaseDriver
62-
clientFactory func() Ec2Client
63-
awsCredentials awsCredentials
64-
Id string
65-
AccessKey string
66-
SecretKey string
67-
SessionToken string
68-
Region string
69-
AMI string
70-
SSHKeyID int
61+
clientFactory func() Ec2Client
62+
awsCredentialsFactory func() awsCredentials
63+
Id string
64+
AccessKey string
65+
SecretKey string
66+
SessionToken string
67+
Region string
68+
AMI string
69+
SSHKeyID int
7170
// ExistingKey keeps track of whether the key was created by us or we used an existing one. If an existing one was used, we shouldn't delete it when the machine is deleted.
7271
ExistingKey bool
7372
KeyName string
@@ -281,10 +280,10 @@ func NewDriver(hostName, storePath string) *Driver {
281280
MachineName: hostName,
282281
StorePath: storePath,
283282
},
284-
awsCredentials: &defaultAWSCredentials{},
285283
}
286284

287285
driver.clientFactory = driver.buildClient
286+
driver.awsCredentialsFactory = driver.buildCredentials
288287

289288
return driver
290289
}
@@ -293,7 +292,7 @@ func (d *Driver) buildClient() Ec2Client {
293292
config := aws.NewConfig()
294293
alogger := AwsLogger()
295294
config = config.WithRegion(d.Region)
296-
config = config.WithCredentials(d.awsCredentials.NewStaticCredentials(d.AccessKey, d.SecretKey, d.SessionToken))
295+
config = config.WithCredentials(d.awsCredentialsFactory().Credentials())
297296
config = config.WithLogger(alogger)
298297
config = config.WithLogLevel(aws.LogDebugWithHTTPBody)
299298
config = config.WithMaxRetries(d.RetryCount)
@@ -304,6 +303,10 @@ func (d *Driver) buildClient() Ec2Client {
304303
return ec2.New(session.New(config))
305304
}
306305

306+
func (d *Driver) buildCredentials() awsCredentials {
307+
return NewAWSCredentials(d.AccessKey, d.SecretKey, d.SessionToken)
308+
}
309+
307310
func (d *Driver) getClient() Ec2Client {
308311
return d.clientFactory()
309312
}
@@ -363,24 +366,9 @@ func (d *Driver) SetConfigFromFlags(flags drivers.DriverOptions) error {
363366
return errorNoPrivateSSHKey
364367
}
365368

366-
if d.AccessKey == "" && d.SecretKey == "" {
367-
credentials, err := d.awsCredentials.NewSharedCredentials("", "").Get()
368-
if err != nil {
369-
log.Debug("Could not load credentials from ~/.aws/credentials")
370-
} else {
371-
log.Debug("Successfully loaded credentials from ~/.aws/credentials")
372-
d.AccessKey = credentials.AccessKeyID
373-
d.SecretKey = credentials.SecretAccessKey
374-
d.SessionToken = credentials.SessionToken
375-
}
376-
}
377-
378-
if d.AccessKey == "" {
379-
return errorMissingAccessKeyOption
380-
}
381-
382-
if d.SecretKey == "" {
383-
return errorMissingSecretKeyOption
369+
_, err = d.awsCredentialsFactory().Credentials().Get()
370+
if err != nil {
371+
return errorMissingCredentials
384372
}
385373

386374
if d.VpcId == "" {

drivers/amazonec2/amazonec2_test.go

Lines changed: 12 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,9 @@ func TestValidateAwsRegionInvalid(t *testing.T) {
165165

166166
func TestFindDefaultVPC(t *testing.T) {
167167
driver := NewDriver("machineFoo", "path")
168-
driver.clientFactory = func() Ec2Client { return &fakeEC2WithLogin{} }
168+
driver.clientFactory = func() Ec2Client {
169+
return &fakeEC2WithLogin{}
170+
}
169171

170172
vpc, err := driver.getDefaultVPCId()
171173

@@ -191,7 +193,7 @@ func TestDefaultVPCIsMissing(t *testing.T) {
191193

192194
func TestGetRegionZoneForDefaultEndpoint(t *testing.T) {
193195
driver := NewCustomTestDriver(&fakeEC2WithLogin{})
194-
driver.awsCredentials = &fileCredentials{}
196+
driver.awsCredentialsFactory = NewValidAwsCredentials
195197
options := &commandstest.FakeFlagger{
196198
Data: map[string]interface{}{
197199
"name": "test",
@@ -210,7 +212,7 @@ func TestGetRegionZoneForDefaultEndpoint(t *testing.T) {
210212

211213
func TestGetRegionZoneForCustomEndpoint(t *testing.T) {
212214
driver := NewCustomTestDriver(&fakeEC2WithLogin{})
213-
driver.awsCredentials = &fileCredentials{}
215+
driver.awsCredentialsFactory = NewValidAwsCredentials
214216
options := &commandstest.FakeFlagger{
215217
Data: map[string]interface{}{
216218
"name": "test",
@@ -242,9 +244,10 @@ func TestDescribeAccountAttributeFails(t *testing.T) {
242244
assert.Empty(t, vpc)
243245
}
244246

245-
func TestAccessKeyIsMandatory(t *testing.T) {
247+
func TestAwsCredentialsAreRequired(t *testing.T) {
246248
driver := NewTestDriver()
247-
driver.awsCredentials = &cliCredentials{}
249+
driver.awsCredentialsFactory = NewErrorAwsCredentials
250+
248251
options := &commandstest.FakeFlagger{
249252
Data: map[string]interface{}{
250253
"name": "test",
@@ -254,47 +257,12 @@ func TestAccessKeyIsMandatory(t *testing.T) {
254257
}
255258

256259
err := driver.SetConfigFromFlags(options)
257-
258-
assert.Equal(t, err, errorMissingAccessKeyOption)
259-
}
260-
261-
func TestAccessKeyIsMandatoryEvenIfSecretKeyIsPassed(t *testing.T) {
262-
driver := NewTestDriver()
263-
driver.awsCredentials = &cliCredentials{}
264-
options := &commandstest.FakeFlagger{
265-
Data: map[string]interface{}{
266-
"name": "test",
267-
"amazonec2-secret-key": "123",
268-
"amazonec2-region": "us-east-1",
269-
"amazonec2-zone": "e",
270-
},
271-
}
272-
273-
err := driver.SetConfigFromFlags(options)
274-
275-
assert.Equal(t, err, errorMissingAccessKeyOption)
276-
}
277-
278-
func TestSecretKeyIsMandatory(t *testing.T) {
279-
driver := NewTestDriver()
280-
driver.awsCredentials = &cliCredentials{}
281-
options := &commandstest.FakeFlagger{
282-
Data: map[string]interface{}{
283-
"name": "test",
284-
"amazonec2-access-key": "foobar",
285-
"amazonec2-region": "us-east-1",
286-
"amazonec2-zone": "e",
287-
},
288-
}
289-
290-
err := driver.SetConfigFromFlags(options)
291-
292-
assert.Equal(t, err, errorMissingSecretKeyOption)
260+
assert.Equal(t, err, errorMissingCredentials)
293261
}
294262

295-
func TestLoadingFromCredentialsWorked(t *testing.T) {
263+
func TestValidAwsCredentialsAreAccepted(t *testing.T) {
296264
driver := NewCustomTestDriver(&fakeEC2WithLogin{})
297-
driver.awsCredentials = &fileCredentials{}
265+
driver.awsCredentialsFactory = NewValidAwsCredentials
298266
options := &commandstest.FakeFlagger{
299267
Data: map[string]interface{}{
300268
"name": "test",
@@ -304,36 +272,12 @@ func TestLoadingFromCredentialsWorked(t *testing.T) {
304272
}
305273

306274
err := driver.SetConfigFromFlags(options)
307-
308-
assert.NoError(t, err)
309-
assert.Equal(t, "access", driver.AccessKey)
310-
assert.Equal(t, "secret", driver.SecretKey)
311-
assert.Equal(t, "token", driver.SessionToken)
312-
}
313-
314-
func TestPassingBothCLIArgWorked(t *testing.T) {
315-
driver := NewCustomTestDriver(&fakeEC2WithLogin{})
316-
driver.awsCredentials = &cliCredentials{}
317-
options := &commandstest.FakeFlagger{
318-
Data: map[string]interface{}{
319-
"name": "test",
320-
"amazonec2-access-key": "foobar",
321-
"amazonec2-secret-key": "123",
322-
"amazonec2-region": "us-east-1",
323-
"amazonec2-zone": "e",
324-
},
325-
}
326-
327-
err := driver.SetConfigFromFlags(options)
328-
329275
assert.NoError(t, err)
330-
assert.Equal(t, "foobar", driver.AccessKey)
331-
assert.Equal(t, "123", driver.SecretKey)
332276
}
333277

334278
func TestEndpointIsMandatoryWhenSSLDisabled(t *testing.T) {
335279
driver := NewTestDriver()
336-
driver.awsCredentials = &cliCredentials{}
280+
driver.awsCredentialsFactory = NewValidAwsCredentials
337281
options := &commandstest.FakeFlagger{
338282
Data: map[string]interface{}{
339283
"name": "test",
Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,63 @@
11
package amazonec2
22

3-
import "github.com/aws/aws-sdk-go/aws/credentials"
3+
import (
4+
"github.com/aws/aws-sdk-go/aws/credentials"
5+
"github.com/aws/aws-sdk-go/aws/session"
6+
)
47

58
type awsCredentials interface {
6-
NewStaticCredentials(id, secret, token string) *credentials.Credentials
9+
Credentials() *credentials.Credentials
10+
}
11+
12+
type ProviderFactory interface {
13+
NewStaticProvider(id, secret, token string) credentials.Provider
14+
}
15+
16+
type defaultAWSCredentials struct {
17+
AccessKey string
18+
SecretKey string
19+
SessionToken string
20+
providerFactory ProviderFactory
21+
fallbackProvider awsCredentials
22+
}
23+
24+
func NewAWSCredentials(id, secret, token string) *defaultAWSCredentials {
25+
creds := defaultAWSCredentials{
26+
AccessKey: id,
27+
SecretKey: secret,
28+
SessionToken: token,
29+
fallbackProvider: &AwsDefaultCredentialsProvider{},
30+
providerFactory: &defaultProviderFactory{},
31+
}
32+
return &creds
33+
}
734

8-
NewSharedCredentials(filename, profile string) *credentials.Credentials
35+
func (c *defaultAWSCredentials) Credentials() *credentials.Credentials {
36+
providers := []credentials.Provider{}
37+
if c.AccessKey != "" && c.SecretKey != "" {
38+
providers = append(providers, c.providerFactory.NewStaticProvider(c.AccessKey, c.SecretKey, c.SessionToken))
39+
}
40+
if c.fallbackProvider != nil {
41+
fallbackCreds, err := c.fallbackProvider.Credentials().Get()
42+
if err == nil {
43+
providers = append(providers, &credentials.StaticProvider{Value: fallbackCreds})
44+
}
45+
}
46+
return credentials.NewChainCredentials(providers)
947
}
1048

11-
type defaultAWSCredentials struct{}
49+
type AwsDefaultCredentialsProvider struct{}
1250

13-
func (c *defaultAWSCredentials) NewStaticCredentials(id, secret, token string) *credentials.Credentials {
14-
return credentials.NewStaticCredentials(id, secret, token)
51+
func (c *AwsDefaultCredentialsProvider) Credentials() *credentials.Credentials {
52+
return session.New().Config.Credentials
1553
}
1654

17-
func (c *defaultAWSCredentials) NewSharedCredentials(filename, profile string) *credentials.Credentials {
18-
return credentials.NewSharedCredentials(filename, profile)
55+
type defaultProviderFactory struct{}
56+
57+
func (c *defaultProviderFactory) NewStaticProvider(id, secret, token string) credentials.Provider {
58+
return &credentials.StaticProvider{Value: credentials.Value{
59+
AccessKeyID: id,
60+
SecretAccessKey: secret,
61+
SessionToken: token,
62+
}}
1963
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package amazonec2
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
"testing"
6+
)
7+
8+
func TestAccessKeyIsMandatoryWhenSystemCredentialsAreNotPresent(t *testing.T) {
9+
awsCreds := NewAWSCredentials("", "", "")
10+
awsCreds.fallbackProvider = nil
11+
12+
_, err := awsCreds.Credentials().Get()
13+
assert.Error(t, err)
14+
}
15+
16+
func TestAccessKeyIsMandatoryEvenIfSecretKeyIsPassedWhenSystemCredentialsAreNotPresent(t *testing.T) {
17+
awsCreds := NewAWSCredentials("", "secret", "")
18+
awsCreds.fallbackProvider = nil
19+
20+
_, err := awsCreds.Credentials().Get()
21+
assert.Error(t, err)
22+
}
23+
24+
func TestSecretKeyIsMandatoryWhenSystemCredentialsAreNotPresent(t *testing.T) {
25+
awsCreds := NewAWSCredentials("access", "", "")
26+
awsCreds.fallbackProvider = nil
27+
28+
_, err := awsCreds.Credentials().Get()
29+
assert.Error(t, err)
30+
}
31+
32+
func TestFallbackCredentialsAreLoadedWhenAccessKeyAndSecretKeyAreMissing(t *testing.T) {
33+
awsCreds := NewAWSCredentials("", "", "")
34+
awsCreds.fallbackProvider = &fallbackCredentials{}
35+
36+
creds, err := awsCreds.Credentials().Get()
37+
38+
assert.NoError(t, err)
39+
assert.Equal(t, "fallback_access", creds.AccessKeyID)
40+
assert.Equal(t, "fallback_secret", creds.SecretAccessKey)
41+
assert.Equal(t, "fallback_token", creds.SessionToken)
42+
}
43+
44+
func TestFallbackCredentialsAreLoadedWhenAccessKeyIsMissing(t *testing.T) {
45+
awsCreds := NewAWSCredentials("", "secret", "")
46+
awsCreds.fallbackProvider = &fallbackCredentials{}
47+
48+
creds, err := awsCreds.Credentials().Get()
49+
50+
assert.NoError(t, err)
51+
assert.Equal(t, "fallback_access", creds.AccessKeyID)
52+
assert.Equal(t, "fallback_secret", creds.SecretAccessKey)
53+
assert.Equal(t, "fallback_token", creds.SessionToken)
54+
}
55+
56+
func TestFallbackCredentialsAreLoadedWhenSecretKeyIsMissing(t *testing.T) {
57+
awsCreds := NewAWSCredentials("access", "", "")
58+
awsCreds.fallbackProvider = &fallbackCredentials{}
59+
60+
creds, err := awsCreds.Credentials().Get()
61+
62+
assert.NoError(t, err)
63+
assert.Equal(t, "fallback_access", creds.AccessKeyID)
64+
assert.Equal(t, "fallback_secret", creds.SecretAccessKey)
65+
assert.Equal(t, "fallback_token", creds.SessionToken)
66+
}
67+
68+
func TestOptionCredentialsAreLoadedWhenAccessKeyAndSecretKeyAreProvided(t *testing.T) {
69+
awsCreds := NewAWSCredentials("access", "secret", "")
70+
awsCreds.fallbackProvider = &fallbackCredentials{}
71+
72+
creds, err := awsCreds.Credentials().Get()
73+
74+
assert.NoError(t, err)
75+
assert.Equal(t, "access", creds.AccessKeyID)
76+
assert.Equal(t, "secret", creds.SecretAccessKey)
77+
assert.Equal(t, "", creds.SessionToken)
78+
}
79+
80+
func TestFallbackCredentialsAreLoadedIfStaticCredentialsGenerateError(t *testing.T) {
81+
awsCreds := NewAWSCredentials("access", "secret", "token")
82+
awsCreds.fallbackProvider = &fallbackCredentials{}
83+
awsCreds.providerFactory = &errorCredentialsProvider{}
84+
85+
creds, err := awsCreds.Credentials().Get()
86+
87+
assert.NoError(t, err)
88+
assert.Equal(t, "fallback_access", creds.AccessKeyID)
89+
assert.Equal(t, "fallback_secret", creds.SecretAccessKey)
90+
assert.Equal(t, "fallback_token", creds.SessionToken)
91+
}
92+
93+
func TestErrorGeneratedWhenAllProvidersGenerateErrors(t *testing.T) {
94+
awsCreds := NewAWSCredentials("access", "secret", "token")
95+
awsCreds.fallbackProvider = &errorFallbackCredentials{}
96+
awsCreds.providerFactory = &errorCredentialsProvider{}
97+
98+
_, err := awsCreds.Credentials().Get()
99+
assert.Error(t, err)
100+
}

0 commit comments

Comments
 (0)