diff options
Diffstat (limited to 'main.go')
-rw-r--r-- | main.go | 180 |
1 files changed, 170 insertions, 10 deletions
@@ -2,13 +2,23 @@ package main | |||
2 | 2 | ||
3 | import ( | 3 | import ( |
4 | "bytes" | 4 | "bytes" |
5 | "context" | ||
6 | "flag" | ||
5 | "fmt" | 7 | "fmt" |
6 | "log" | 8 | "log" |
9 | "net/http" | ||
7 | "os" | 10 | "os" |
8 | 11 | ||
9 | "code.crute.us/mcrute/ses-smtpd-proxy/smtpd" | 12 | "code.crute.us/mcrute/ses-smtpd-proxy/smtpd" |
13 | "github.com/aws/aws-sdk-go/aws" | ||
14 | "github.com/aws/aws-sdk-go/aws/credentials" | ||
10 | "github.com/aws/aws-sdk-go/aws/session" | 15 | "github.com/aws/aws-sdk-go/aws/session" |
11 | "github.com/aws/aws-sdk-go/service/ses" | 16 | "github.com/aws/aws-sdk-go/service/ses" |
17 | "github.com/hashicorp/vault/api" | ||
18 | "github.com/hashicorp/vault/api/auth/approle" | ||
19 | "github.com/prometheus/client_golang/prometheus" | ||
20 | "github.com/prometheus/client_golang/prometheus/promauto" | ||
21 | "github.com/prometheus/client_golang/prometheus/promhttp" | ||
12 | ) | 22 | ) |
13 | 23 | ||
14 | const ( | 24 | const ( |
@@ -18,6 +28,34 @@ const ( | |||
18 | 28 | ||
19 | var sesClient *ses.SES | 29 | var sesClient *ses.SES |
20 | 30 | ||
31 | var ( | ||
32 | emailSent = promauto.NewCounter(prometheus.CounterOpts{ | ||
33 | Namespace: "smtpd", | ||
34 | Name: "email_send_success_total", | ||
35 | Help: "Total number of successfuly sent emails", | ||
36 | }) | ||
37 | emailError = promauto.NewCounterVec(prometheus.CounterOpts{ | ||
38 | Namespace: "smtpd", | ||
39 | Name: "email_send_fail_total", | ||
40 | Help: "Total number emails that failed to send", | ||
41 | }, []string{"type"}) | ||
42 | sesError = promauto.NewCounter(prometheus.CounterOpts{ | ||
43 | Namespace: "smtpd", | ||
44 | Name: "ses_error_total", | ||
45 | Help: "Total number errors with SES", | ||
46 | }) | ||
47 | credentialRenewalSuccess = promauto.NewCounter(prometheus.CounterOpts{ | ||
48 | Namespace: "smtpd", | ||
49 | Name: "credential_renewal_success_total", | ||
50 | Help: "Total number successful credential renewals", | ||
51 | }) | ||
52 | credentialRenewalError = promauto.NewCounter(prometheus.CounterOpts{ | ||
53 | Namespace: "smtpd", | ||
54 | Name: "credential_renewal_error_total", | ||
55 | Help: "Total number errors during credential renewal", | ||
56 | }) | ||
57 | ) | ||
58 | |||
21 | type Envelope struct { | 59 | type Envelope struct { |
22 | from string | 60 | from string |
23 | rcpts []*string | 61 | rcpts []*string |
@@ -32,6 +70,7 @@ func (e *Envelope) AddRecipient(rcpt smtpd.MailAddress) error { | |||
32 | 70 | ||
33 | func (e *Envelope) BeginData() error { | 71 | func (e *Envelope) BeginData() error { |
34 | if len(e.rcpts) == 0 { | 72 | if len(e.rcpts) == 0 { |
73 | emailError.With(prometheus.Labels{"type": "no valid recipients"}).Inc() | ||
35 | return smtpd.SMTPError("554 5.5.1 Error: no valid recipients") | 74 | return smtpd.SMTPError("554 5.5.1 Error: no valid recipients") |
36 | } | 75 | } |
37 | return nil | 76 | return nil |
@@ -40,7 +79,8 @@ func (e *Envelope) BeginData() error { | |||
40 | func (e *Envelope) Write(line []byte) error { | 79 | func (e *Envelope) Write(line []byte) error { |
41 | e.b.Write(line) | 80 | e.b.Write(line) |
42 | if e.b.Len() > SesSizeLimit { // SES limitation | 81 | if e.b.Len() > SesSizeLimit { // SES limitation |
43 | log.Printf("message size %d exceeds SES limit of %d\n", e.b.Len(), SesSizeLimit) | 82 | emailError.With(prometheus.Labels{"type": "minimum message size exceed"}).Inc() |
83 | log.Printf("message size %d exceeds SES limit of %d", e.b.Len(), SesSizeLimit) | ||
44 | return smtpd.SMTPError("554 5.5.1 Error: maximum message size exceeded") | 84 | return smtpd.SMTPError("554 5.5.1 Error: maximum message size exceeded") |
45 | } | 85 | } |
46 | return nil | 86 | return nil |
@@ -52,10 +92,10 @@ func (e *Envelope) logMessageSend() { | |||
52 | dr[i] = *e.rcpts[i] | 92 | dr[i] = *e.rcpts[i] |
53 | } | 93 | } |
54 | log.Printf("sending message from %+v to %+v", e.from, dr) | 94 | log.Printf("sending message from %+v to %+v", e.from, dr) |
95 | emailSent.Inc() | ||
55 | } | 96 | } |
56 | 97 | ||
57 | func (e *Envelope) Close() error { | 98 | func (e *Envelope) Close() error { |
58 | e.logMessageSend() | ||
59 | r := &ses.SendRawEmailInput{ | 99 | r := &ses.SendRawEmailInput{ |
60 | Source: &e.from, | 100 | Source: &e.from, |
61 | Destinations: e.rcpts, | 101 | Destinations: e.rcpts, |
@@ -64,30 +104,150 @@ func (e *Envelope) Close() error { | |||
64 | _, err := sesClient.SendRawEmail(r) | 104 | _, err := sesClient.SendRawEmail(r) |
65 | if err != nil { | 105 | if err != nil { |
66 | log.Printf("ERROR: ses: %v", err) | 106 | log.Printf("ERROR: ses: %v", err) |
67 | return smtpd.SMTPError(fmt.Sprintf("451 4.5.1 Temporary server error. Please try again later: %v", err)) | 107 | emailError.With(prometheus.Labels{"type": "ses error"}).Inc() |
108 | sesError.Inc() | ||
109 | return smtpd.SMTPError("451 4.5.1 Temporary server error. Please try again later") | ||
68 | } | 110 | } |
111 | e.logMessageSend() | ||
69 | return err | 112 | return err |
70 | } | 113 | } |
71 | 114 | ||
115 | func renewSecret(vc *api.Client, s *api.Secret) error { | ||
116 | w, err := vc.NewLifetimeWatcher(&api.LifetimeWatcherInput{Secret: s}) | ||
117 | if err != nil { | ||
118 | return err | ||
119 | } | ||
120 | go w.Start() | ||
121 | |||
122 | go func() { | ||
123 | for { | ||
124 | select { | ||
125 | case err := <-w.DoneCh(): | ||
126 | if err != nil { | ||
127 | credentialRenewalError.Inc() | ||
128 | log.Fatalf("Error renewing credential: %s", err) | ||
129 | } | ||
130 | case renewal := <-w.RenewCh(): | ||
131 | credentialRenewalSuccess.Inc() | ||
132 | log.Printf("Successfully renewed: %#v", renewal) | ||
133 | } | ||
134 | } | ||
135 | }() | ||
136 | |||
137 | return nil | ||
138 | } | ||
139 | |||
140 | func getVaultSecret(path string) (credentials.Value, error) { | ||
141 | var r credentials.Value | ||
142 | |||
143 | vc, err := api.NewClient(api.DefaultConfig()) | ||
144 | if err != nil { | ||
145 | return r, err | ||
146 | } | ||
147 | |||
148 | // Use AppRole if it's in the environment, otherwise assume VAULT_TOKEN | ||
149 | // was provided in the environment. | ||
150 | if roleID := os.Getenv("VAULT_APPROLE_ROLE_ID"); roleID != "" { | ||
151 | appRoleAuth, err := approle.NewAppRoleAuth(roleID, &approle.SecretID{ | ||
152 | FromEnv: "VAULT_APPROLE_SECRET_ID", | ||
153 | }) | ||
154 | if err != nil { | ||
155 | return r, fmt.Errorf("unable to initialize AppRole auth method: %w", err) | ||
156 | } | ||
157 | if loginSecret, err := vc.Auth().Login(context.Background(), appRoleAuth); err != nil { | ||
158 | return r, fmt.Errorf("unable to login to AppRole auth method: %w", err) | ||
159 | } else { | ||
160 | if err := renewSecret(vc, loginSecret); err != nil { | ||
161 | return r, err | ||
162 | } | ||
163 | } | ||
164 | } | ||
165 | |||
166 | secret, err := vc.Logical().Read(path) | ||
167 | if err != nil { | ||
168 | return r, err | ||
169 | } | ||
170 | if secret == nil { | ||
171 | return r, fmt.Errorf("Vault returned no AWS secret") | ||
172 | } | ||
173 | |||
174 | keyId, ok := secret.Data["access_key"] | ||
175 | if !ok { | ||
176 | return r, fmt.Errorf("Vault secret had no access_key") | ||
177 | } | ||
178 | |||
179 | secretKey, ok := secret.Data["secret_key"] | ||
180 | if !ok { | ||
181 | return r, fmt.Errorf("Vault secret had no secret_key") | ||
182 | } | ||
183 | |||
184 | r.AccessKeyID = keyId.(string) | ||
185 | r.SecretAccessKey = secretKey.(string) | ||
186 | |||
187 | return r, renewSecret(vc, secret) | ||
188 | } | ||
189 | |||
190 | func makeSesClient(enableVault bool, vaultPath string) (*ses.SES, error) { | ||
191 | var err error | ||
192 | var s *session.Session | ||
193 | |||
194 | if enableVault { | ||
195 | cred, err := getVaultSecret(vaultPath) | ||
196 | if err != nil { | ||
197 | return nil, err | ||
198 | } | ||
199 | |||
200 | s, err = session.NewSession(&aws.Config{ | ||
201 | Credentials: credentials.NewStaticCredentialsFromCreds(cred), | ||
202 | }) | ||
203 | } else { | ||
204 | s, err = session.NewSession() | ||
205 | } | ||
206 | if err != nil { | ||
207 | return nil, err | ||
208 | } | ||
209 | |||
210 | return ses.New(s), nil | ||
211 | } | ||
212 | |||
72 | func main() { | 213 | func main() { |
73 | sesClient = ses.New(session.Must(session.NewSession())) | 214 | var err error |
74 | addr := DefaultAddr | 215 | |
216 | disablePrometheus := flag.Bool("disable-prometheus", false, "Disables prometheus metrics server") | ||
217 | prometheusBind := flag.String("prometheus-bind", ":2501", "Address/port on which to bind Prometheus server") | ||
218 | enableVault := flag.Bool("enable-vault", false, "Enable fetching AWS IAM credentials from a Vault server") | ||
219 | vaultPath := flag.String("vault-path", "", "Full path to Vault credential (ex: \"aws/creds/my-mail-user\")") | ||
220 | |||
221 | flag.Parse() | ||
222 | |||
223 | sesClient, err = makeSesClient(*enableVault, *vaultPath) | ||
224 | if err != nil { | ||
225 | log.Fatalf("Error creating AWS session: %s", err) | ||
226 | } | ||
75 | 227 | ||
76 | if len(os.Args) == 2 { | 228 | addr := DefaultAddr |
77 | addr = os.Args[1] | 229 | if flag.Arg(0) != "" { |
78 | } else if len(os.Args) > 2 { | 230 | addr = flag.Arg(0) |
231 | } else if flag.NArg() > 1 { | ||
79 | log.Fatalf("usage: %s [listen_host:port]", os.Args[0]) | 232 | log.Fatalf("usage: %s [listen_host:port]", os.Args[0]) |
80 | } | 233 | } |
81 | 234 | ||
235 | if !*disablePrometheus { | ||
236 | sm := http.NewServeMux() | ||
237 | ps := &http.Server{Addr: *prometheusBind, Handler: sm} | ||
238 | sm.Handle("/metrics", promhttp.Handler()) | ||
239 | go ps.ListenAndServe() | ||
240 | } | ||
241 | |||
82 | s := &smtpd.Server{ | 242 | s := &smtpd.Server{ |
83 | Addr: addr, | 243 | Addr: addr, |
84 | OnNewMail: func(c smtpd.Connection, from smtpd.MailAddress) (smtpd.Envelope, error) { | 244 | OnNewMail: func(c smtpd.Connection, from smtpd.MailAddress) (smtpd.Envelope, error) { |
85 | return &Envelope{from: from.Email()}, nil | 245 | return &Envelope{from: from.Email()}, nil |
86 | }, | 246 | }, |
87 | } | 247 | } |
248 | |||
88 | log.Printf("ListenAndServe on %s", addr) | 249 | log.Printf("ListenAndServe on %s", addr) |
89 | err := s.ListenAndServe() | 250 | if err := s.ListenAndServe(); err != nil { |
90 | if err != nil { | ||
91 | log.Fatalf("ListenAndServe: %v", err) | 251 | log.Fatalf("ListenAndServe: %v", err) |
92 | } | 252 | } |
93 | } | 253 | } |