aboutsummaryrefslogtreecommitdiff
path: root/main.go
diff options
context:
space:
mode:
Diffstat (limited to 'main.go')
-rw-r--r--main.go180
1 files changed, 170 insertions, 10 deletions
diff --git a/main.go b/main.go
index a1a2625..6c153fa 100644
--- a/main.go
+++ b/main.go
@@ -2,13 +2,23 @@ package main
2 2
3import ( 3import (
4 "bytes" 4 "bytes"
5 "context"
6 "flag"
5 "fmt" 7 "fmt"
6 "log" 8 "log"
9 "net/http"
7 "os" 10 "os"
8 11
9 "code.crute.us/mcrute/ses-smtpd-proxy/smtpd" 12 "code.crute.us/mcrute/ses-smtpd-proxy/smtpd"
13 "github.com/aws/aws-sdk-go/aws"
14 "github.com/aws/aws-sdk-go/aws/credentials"
10 "github.com/aws/aws-sdk-go/aws/session" 15 "github.com/aws/aws-sdk-go/aws/session"
11 "github.com/aws/aws-sdk-go/service/ses" 16 "github.com/aws/aws-sdk-go/service/ses"
17 "github.com/hashicorp/vault/api"
18 "github.com/hashicorp/vault/api/auth/approle"
19 "github.com/prometheus/client_golang/prometheus"
20 "github.com/prometheus/client_golang/prometheus/promauto"
21 "github.com/prometheus/client_golang/prometheus/promhttp"
12) 22)
13 23
14const ( 24const (
@@ -18,6 +28,34 @@ const (
18 28
19var sesClient *ses.SES 29var sesClient *ses.SES
20 30
31var (
32 emailSent = promauto.NewCounter(prometheus.CounterOpts{
33 Namespace: "smtpd",
34 Name: "email_send_success_total",
35 Help: "Total number of successfuly sent emails",
36 })
37 emailError = promauto.NewCounterVec(prometheus.CounterOpts{
38 Namespace: "smtpd",
39 Name: "email_send_fail_total",
40 Help: "Total number emails that failed to send",
41 }, []string{"type"})
42 sesError = promauto.NewCounter(prometheus.CounterOpts{
43 Namespace: "smtpd",
44 Name: "ses_error_total",
45 Help: "Total number errors with SES",
46 })
47 credentialRenewalSuccess = promauto.NewCounter(prometheus.CounterOpts{
48 Namespace: "smtpd",
49 Name: "credential_renewal_success_total",
50 Help: "Total number successful credential renewals",
51 })
52 credentialRenewalError = promauto.NewCounter(prometheus.CounterOpts{
53 Namespace: "smtpd",
54 Name: "credential_renewal_error_total",
55 Help: "Total number errors during credential renewal",
56 })
57)
58
21type Envelope struct { 59type Envelope struct {
22 from string 60 from string
23 rcpts []*string 61 rcpts []*string
@@ -32,6 +70,7 @@ func (e *Envelope) AddRecipient(rcpt smtpd.MailAddress) error {
32 70
33func (e *Envelope) BeginData() error { 71func (e *Envelope) BeginData() error {
34 if len(e.rcpts) == 0 { 72 if len(e.rcpts) == 0 {
73 emailError.With(prometheus.Labels{"type": "no valid recipients"}).Inc()
35 return smtpd.SMTPError("554 5.5.1 Error: no valid recipients") 74 return smtpd.SMTPError("554 5.5.1 Error: no valid recipients")
36 } 75 }
37 return nil 76 return nil
@@ -40,7 +79,8 @@ func (e *Envelope) BeginData() error {
40func (e *Envelope) Write(line []byte) error { 79func (e *Envelope) Write(line []byte) error {
41 e.b.Write(line) 80 e.b.Write(line)
42 if e.b.Len() > SesSizeLimit { // SES limitation 81 if e.b.Len() > SesSizeLimit { // SES limitation
43 log.Printf("message size %d exceeds SES limit of %d\n", e.b.Len(), SesSizeLimit) 82 emailError.With(prometheus.Labels{"type": "minimum message size exceed"}).Inc()
83 log.Printf("message size %d exceeds SES limit of %d", e.b.Len(), SesSizeLimit)
44 return smtpd.SMTPError("554 5.5.1 Error: maximum message size exceeded") 84 return smtpd.SMTPError("554 5.5.1 Error: maximum message size exceeded")
45 } 85 }
46 return nil 86 return nil
@@ -52,10 +92,10 @@ func (e *Envelope) logMessageSend() {
52 dr[i] = *e.rcpts[i] 92 dr[i] = *e.rcpts[i]
53 } 93 }
54 log.Printf("sending message from %+v to %+v", e.from, dr) 94 log.Printf("sending message from %+v to %+v", e.from, dr)
95 emailSent.Inc()
55} 96}
56 97
57func (e *Envelope) Close() error { 98func (e *Envelope) Close() error {
58 e.logMessageSend()
59 r := &ses.SendRawEmailInput{ 99 r := &ses.SendRawEmailInput{
60 Source: &e.from, 100 Source: &e.from,
61 Destinations: e.rcpts, 101 Destinations: e.rcpts,
@@ -64,30 +104,150 @@ func (e *Envelope) Close() error {
64 _, err := sesClient.SendRawEmail(r) 104 _, err := sesClient.SendRawEmail(r)
65 if err != nil { 105 if err != nil {
66 log.Printf("ERROR: ses: %v", err) 106 log.Printf("ERROR: ses: %v", err)
67 return smtpd.SMTPError(fmt.Sprintf("451 4.5.1 Temporary server error. Please try again later: %v", err)) 107 emailError.With(prometheus.Labels{"type": "ses error"}).Inc()
108 sesError.Inc()
109 return smtpd.SMTPError("451 4.5.1 Temporary server error. Please try again later")
68 } 110 }
111 e.logMessageSend()
69 return err 112 return err
70} 113}
71 114
115func renewSecret(vc *api.Client, s *api.Secret) error {
116 w, err := vc.NewLifetimeWatcher(&api.LifetimeWatcherInput{Secret: s})
117 if err != nil {
118 return err
119 }
120 go w.Start()
121
122 go func() {
123 for {
124 select {
125 case err := <-w.DoneCh():
126 if err != nil {
127 credentialRenewalError.Inc()
128 log.Fatalf("Error renewing credential: %s", err)
129 }
130 case renewal := <-w.RenewCh():
131 credentialRenewalSuccess.Inc()
132 log.Printf("Successfully renewed: %#v", renewal)
133 }
134 }
135 }()
136
137 return nil
138}
139
140func getVaultSecret(path string) (credentials.Value, error) {
141 var r credentials.Value
142
143 vc, err := api.NewClient(api.DefaultConfig())
144 if err != nil {
145 return r, err
146 }
147
148 // Use AppRole if it's in the environment, otherwise assume VAULT_TOKEN
149 // was provided in the environment.
150 if roleID := os.Getenv("VAULT_APPROLE_ROLE_ID"); roleID != "" {
151 appRoleAuth, err := approle.NewAppRoleAuth(roleID, &approle.SecretID{
152 FromEnv: "VAULT_APPROLE_SECRET_ID",
153 })
154 if err != nil {
155 return r, fmt.Errorf("unable to initialize AppRole auth method: %w", err)
156 }
157 if loginSecret, err := vc.Auth().Login(context.Background(), appRoleAuth); err != nil {
158 return r, fmt.Errorf("unable to login to AppRole auth method: %w", err)
159 } else {
160 if err := renewSecret(vc, loginSecret); err != nil {
161 return r, err
162 }
163 }
164 }
165
166 secret, err := vc.Logical().Read(path)
167 if err != nil {
168 return r, err
169 }
170 if secret == nil {
171 return r, fmt.Errorf("Vault returned no AWS secret")
172 }
173
174 keyId, ok := secret.Data["access_key"]
175 if !ok {
176 return r, fmt.Errorf("Vault secret had no access_key")
177 }
178
179 secretKey, ok := secret.Data["secret_key"]
180 if !ok {
181 return r, fmt.Errorf("Vault secret had no secret_key")
182 }
183
184 r.AccessKeyID = keyId.(string)
185 r.SecretAccessKey = secretKey.(string)
186
187 return r, renewSecret(vc, secret)
188}
189
190func makeSesClient(enableVault bool, vaultPath string) (*ses.SES, error) {
191 var err error
192 var s *session.Session
193
194 if enableVault {
195 cred, err := getVaultSecret(vaultPath)
196 if err != nil {
197 return nil, err
198 }
199
200 s, err = session.NewSession(&aws.Config{
201 Credentials: credentials.NewStaticCredentialsFromCreds(cred),
202 })
203 } else {
204 s, err = session.NewSession()
205 }
206 if err != nil {
207 return nil, err
208 }
209
210 return ses.New(s), nil
211}
212
72func main() { 213func main() {
73 sesClient = ses.New(session.Must(session.NewSession())) 214 var err error
74 addr := DefaultAddr 215
216 disablePrometheus := flag.Bool("disable-prometheus", false, "Disables prometheus metrics server")
217 prometheusBind := flag.String("prometheus-bind", ":2501", "Address/port on which to bind Prometheus server")
218 enableVault := flag.Bool("enable-vault", false, "Enable fetching AWS IAM credentials from a Vault server")
219 vaultPath := flag.String("vault-path", "", "Full path to Vault credential (ex: \"aws/creds/my-mail-user\")")
220
221 flag.Parse()
222
223 sesClient, err = makeSesClient(*enableVault, *vaultPath)
224 if err != nil {
225 log.Fatalf("Error creating AWS session: %s", err)
226 }
75 227
76 if len(os.Args) == 2 { 228 addr := DefaultAddr
77 addr = os.Args[1] 229 if flag.Arg(0) != "" {
78 } else if len(os.Args) > 2 { 230 addr = flag.Arg(0)
231 } else if flag.NArg() > 1 {
79 log.Fatalf("usage: %s [listen_host:port]", os.Args[0]) 232 log.Fatalf("usage: %s [listen_host:port]", os.Args[0])
80 } 233 }
81 234
235 if !*disablePrometheus {
236 sm := http.NewServeMux()
237 ps := &http.Server{Addr: *prometheusBind, Handler: sm}
238 sm.Handle("/metrics", promhttp.Handler())
239 go ps.ListenAndServe()
240 }
241
82 s := &smtpd.Server{ 242 s := &smtpd.Server{
83 Addr: addr, 243 Addr: addr,
84 OnNewMail: func(c smtpd.Connection, from smtpd.MailAddress) (smtpd.Envelope, error) { 244 OnNewMail: func(c smtpd.Connection, from smtpd.MailAddress) (smtpd.Envelope, error) {
85 return &Envelope{from: from.Email()}, nil 245 return &Envelope{from: from.Email()}, nil
86 }, 246 },
87 } 247 }
248
88 log.Printf("ListenAndServe on %s", addr) 249 log.Printf("ListenAndServe on %s", addr)
89 err := s.ListenAndServe() 250 if err := s.ListenAndServe(); err != nil {
90 if err != nil {
91 log.Fatalf("ListenAndServe: %v", err) 251 log.Fatalf("ListenAndServe: %v", err)
92 } 252 }
93} 253}