aboutsummaryrefslogtreecommitdiff
path: root/credentials.go
blob: a9412da2dcbf517537456cf5e1f0eabe6766a39d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package main

import (
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/credentials"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/sts"
	jww "github.com/spf13/jwalterweatherman"
)

// REFRESH_INTERVAL is the period between scheduled role re-assumptions.
// Three refreshes fit inside an hour (at 19, 38 and 57 minutes), so even if
// two consecutive refreshes fail there is still one more attempt before a
// one-hour credential runs out. NOTE(review): the one-hour lifetime assumes
// the STS AssumeRole default, since no DurationSeconds is requested when the
// role is assumed.
const REFRESH_INTERVAL = 19 * time.Minute

// CredentialHandler keeps a role credential obtained through STS AssumeRole
// up to date and hands the current credential to consumers.
type CredentialHandler interface {
	// SetBootstrapCredential supplies the credential used to assume the role
	// when no session credential is available.
	SetBootstrapCredential(*credentials.Credentials)

	// Start runs the handler's refresh loop.
	Start()

	// Output returns the channel on which the current credential is delivered.
	Output() chan *IAMCredentials

	// InGoodState reports whether the current credential was obtained
	// successfully.
	InGoodState() bool
}

// credentialHandler is the concrete CredentialHandler returned by
// NewCredentialHandler.
type credentialHandler struct {
	// Parameters forwarded to the AWS SDK by assumeRole: the session region,
	// the role to assume, and the STS role session name.
	region         *string
	roleARN        *string
	sessionName    *string
	// Most recent credential received from input; assigned in Start's loop and
	// read by refreshCredential as the fallback source for AssumeRole.
	bootstrapCreds *credentials.Credentials
	// Unbuffered; Start's loop sends the current credential to any receiver.
	output         chan *IAMCredentials
	// Carries bootstrap credentials from SetBootstrapCredential to Start's loop.
	input          chan *credentials.Credentials
}

// NewCredentialHandler returns a handler that assumes role arn in region,
// using name as the STS role session name. The handler is idle until Start is
// called. One bootstrap credential may be supplied before that without
// blocking.
func NewCredentialHandler(region, arn, name *string) CredentialHandler {
	h := new(credentialHandler)
	h.region = region
	h.roleARN = arn
	h.sessionName = name
	h.output = make(chan *IAMCredentials)
	// One slot so SetBootstrapCredential can be called before Start is running.
	h.input = make(chan *credentials.Credentials, 1)
	return h
}

// Output returns the unbuffered channel on which Start's loop offers the
// current credential. Each receive yields the credential held at that moment.
func (h *credentialHandler) Output() chan *IAMCredentials {
	out := h.output
	return out
}

// InGoodState reports whether the credential currently held by the handler
// was obtained successfully (Code "Success"). It receives from the output
// channel, so it blocks until Start's loop is running and able to send.
func (h *credentialHandler) InGoodState() bool {
	c := <-h.Output()
	// A nil credential can be published by the refresh loop; treat it as "not
	// good" rather than dereferencing it and panicking.
	return c != nil && c.Code == "Success"
}

// SetBootstrapCredential queues bc for Start's loop, which stores it and
// immediately triggers a refresh with it. The input channel holds one pending
// value. A call blocks while an earlier credential is still waiting to be
// consumed.
func (h *credentialHandler) SetBootstrapCredential(bc *credentials.Credentials) {
	in := h.input
	in <- bc
}

// Start runs the handler's event loop. The loop has no exit, so Start never
// returns. It:
//   - stores each bootstrap credential from SetBootstrapCredential and
//     refreshes immediately with it;
//   - serves the current credential to anyone receiving from Output;
//   - every REFRESH_INTERVAL, refreshes using the current session credential
//     (falling back to the bootstrap credential);
//   - installs refresh results. A failed refresh keeps the current credential
//     while it is unexpired and only reports Failure once it has expired.
//
// The current credential is never nil, so receivers of Output can always
// read its fields.
//
// NOTE(review): h.bootstrapCreds is assigned here while refresh goroutines
// read it; a SetBootstrapCredential that lands during a refresh is a race.
func (h *credentialHandler) Start() {
	c := &IAMCredentials{Code: "Failure"}
	updateChan := make(chan *IAMCredentials)

	ticker := time.NewTicker(REFRESH_INTERVAL)
	defer ticker.Stop()

	jww.INFO.Printf("Starting credential handler, awaiting bootstrap")

	for {
		select {
		// Read and update bootstrap credentials
		case h.bootstrapCreds = <-h.input:
			go h.refreshCredential(nil, updateChan)
		// HTTP handler requests credential
		case h.output <- c:
		// Time to refresh credentials
		case <-ticker.C:
			go h.refreshCredential(c.rawCredentials, updateChan)
		// Updated credentials arrive
		case up := <-updateChan:
			switch {
			case up != nil:
				c = up
			case !c.Expiration.After(time.Now()):
				// Refresh failed and what we hold has expired (or was never
				// valid): report failure instead of serving stale keys.
				c = &IAMCredentials{Code: "Failure"}
			}
			// Otherwise the refresh failed but the current credential is still
			// valid: keep serving it so a later refresh can retry before expiry.
		}
	}
}

// refreshCredential attempts to assume the role and reports the outcome on
// out. It tries creds (the current session credential) first and then falls
// back to the bootstrap credential.
//   - The first successful IAMCredentials is sent.
//   - If every available source fails, nil is sent.
//   - If neither source is set, a warning is logged and nothing is sent.
//
// NOTE(review): Start assigns h.bootstrapCreds on its own goroutine while this
// function runs on another. A SetBootstrapCredential that lands during a
// refresh is a data race; passing the bootstrap credential in as an argument
// would remove it.
func (h *credentialHandler) refreshCredential(creds *credentials.Credentials, out chan *IAMCredentials) {
	jww.INFO.Printf("Attempting to obtain credentials")

	bootstrap := h.bootstrapCreds

	if creds == nil && bootstrap == nil {
		jww.WARN.Printf("No session or bootstrap credentials available")
		return
	}

	deliver := func(fresh *IAMCredentials) {
		jww.INFO.Printf("Successfully obtained credentials")
		out <- fresh
	}

	if creds != nil {
		jww.DEBUG.Printf("Attempting to use session credentials")
		fresh, err := h.assumeRole(creds)
		if err == nil {
			deliver(fresh)
			return
		}
		jww.WARN.Printf("Failed to obtain with session credentials: %s", err)
	}

	if bootstrap != nil {
		jww.DEBUG.Printf("Attempting to use bootstrap credentials")
		fresh, err := h.assumeRole(bootstrap)
		if err == nil {
			deliver(fresh)
			return
		}
		jww.WARN.Printf("Failed to obtain with bootstrap credentials: %s", err)
	}

	jww.ERROR.Printf("Failed to obtain credentials")
	out <- nil
}

// assumeRole authenticates with creds and calls STS AssumeRole for
// h.roleARN / h.sessionName in h.region. It converts the temporary
// credentials in the response into an IAMCredentials with Code "Success".
// The result also carries the temporary keys as static credentials in
// rawCredentials, which the refresh loop passes back in on the next refresh.
//
// It returns an error if the AWS session cannot be configured or the STS call
// fails.
func (h *credentialHandler) assumeRole(creds *credentials.Credentials) (*IAMCredentials, error) {
	// NewSession replaces the deprecated session.New. It reports
	// configuration errors here instead of deferring them to the request.
	ses, err := session.NewSession(&aws.Config{
		Region:      h.region,
		Credentials: creds,
	})
	if err != nil {
		return nil, err
	}

	assumed, err := sts.New(ses).AssumeRole(&sts.AssumeRoleInput{
		RoleArn:         h.roleARN,
		RoleSessionName: h.sessionName,
	})
	if err != nil {
		return nil, err
	}

	// NOTE(review): relies on STS always populating these fields on success.
	sc := assumed.Credentials
	return &IAMCredentials{
		Code:            "Success",
		Type:            "AWS-HMAC",
		AccessKeyId:     *sc.AccessKeyId,
		SecretAccessKey: *sc.SecretAccessKey,
		Token:           *sc.SessionToken,
		LastUpdated:     time.Now().UTC().Round(time.Second),
		Expiration:      *sc.Expiration,
		rawCredentials: credentials.NewStaticCredentials(
			*sc.AccessKeyId,
			*sc.SecretAccessKey,
			*sc.SessionToken,
		),
	}, nil
}