diff options
Diffstat (limited to 'tls/ocsp.go')
-rw-r--r-- | tls/ocsp.go | 219 |
1 files changed, 219 insertions, 0 deletions
diff --git a/tls/ocsp.go b/tls/ocsp.go new file mode 100644 index 0000000..94f7f68 --- /dev/null +++ b/tls/ocsp.go | |||
@@ -0,0 +1,219 @@ | |||
1 | package tls | ||
2 | |||
3 | import ( | ||
4 | "bytes" | ||
5 | "context" | ||
6 | "crypto/tls" | ||
7 | "crypto/x509" | ||
8 | "errors" | ||
9 | "fmt" | ||
10 | "io" | ||
11 | "net/http" | ||
12 | "sync" | ||
13 | "time" | ||
14 | |||
15 | "golang.org/x/crypto/ocsp" | ||
16 | ) | ||
17 | |||
18 | type OcspError struct { | ||
19 | Err error | ||
20 | AtBoot bool | ||
21 | } | ||
22 | |||
23 | func (e OcspError) Error() string { | ||
24 | return e.Err.Error() | ||
25 | } | ||
26 | |||
27 | func (e OcspError) Unwrap() error { | ||
28 | return e.Err | ||
29 | } | ||
30 | |||
31 | type OcspLogger interface { | ||
32 | Info(...interface{}) | ||
33 | Errorf(string, ...interface{}) | ||
34 | } | ||
35 | |||
36 | func OcspErrorLogger(l OcspLogger, c <-chan OcspError) func(context.Context, *sync.WaitGroup) error { | ||
37 | return func(ctx context.Context, wg *sync.WaitGroup) error { | ||
38 | wg.Add(1) | ||
39 | defer wg.Done() | ||
40 | |||
41 | for { | ||
42 | select { | ||
43 | case err := <-c: | ||
44 | l.Errorf("Error in OCSP stapling: %w", errors.Unwrap(err)) | ||
45 | case <-ctx.Done(): | ||
46 | l.Info("Shutting down OCSP logger") | ||
47 | return nil | ||
48 | } | ||
49 | } | ||
50 | } | ||
51 | } | ||
52 | |||
53 | type OcspManager struct { | ||
54 | CertPath, KeyPath string | ||
55 | Errors chan<- OcspError | ||
56 | cert *tls.Certificate | ||
57 | ocspRes *ocsp.Response | ||
58 | sync.RWMutex | ||
59 | } | ||
60 | |||
61 | func (m *OcspManager) loadCert() error { | ||
62 | cert, err := tls.LoadX509KeyPair(m.CertPath, m.KeyPath) | ||
63 | if err != nil { | ||
64 | return err | ||
65 | } | ||
66 | m.Lock() | ||
67 | m.cert = &cert | ||
68 | m.Unlock() | ||
69 | |||
70 | return nil | ||
71 | } | ||
72 | |||
73 | func (m *OcspManager) stapleCert() error { | ||
74 | // This makes a network request to an unknown server so don't hold the full | ||
75 | // lock while this is happening | ||
76 | m.RLock() | ||
77 | raw, ocspRes, err := GetOcspResponse(m.cert) | ||
78 | if err != nil { | ||
79 | return err | ||
80 | } | ||
81 | m.RUnlock() | ||
82 | |||
83 | m.Lock() | ||
84 | m.cert.OCSPStaple = raw | ||
85 | m.ocspRes = ocspRes | ||
86 | m.Unlock() | ||
87 | |||
88 | return nil | ||
89 | } | ||
90 | |||
91 | func (m *OcspManager) Init() error { | ||
92 | // All functions called here will handle locking themselves | ||
93 | if err := m.loadCert(); err != nil { | ||
94 | return err | ||
95 | } | ||
96 | |||
97 | if err := m.stapleCert(); err != nil { | ||
98 | return err | ||
99 | } | ||
100 | |||
101 | return nil | ||
102 | } | ||
103 | |||
104 | func (m *OcspManager) Run(ctx context.Context, wg *sync.WaitGroup) error { | ||
105 | wg.Add(1) | ||
106 | defer wg.Done() | ||
107 | |||
108 | t := time.NewTimer(m.ocspRes.NextUpdate.Sub(time.Now()) - time.Hour) | ||
109 | |||
110 | for { | ||
111 | select { | ||
112 | case <-t.C: | ||
113 | if err := m.stapleCert(); err != nil { | ||
114 | if m.Errors != nil { | ||
115 | m.Errors <- OcspError{err, false} | ||
116 | } | ||
117 | t.Reset(time.Hour) | ||
118 | continue | ||
119 | } | ||
120 | // We own this object and only we write it, no need to lock | ||
121 | t.Reset(m.ocspRes.NextUpdate.Sub(time.Now()) - time.Hour) | ||
122 | case <-ctx.Done(): | ||
123 | return nil | ||
124 | } | ||
125 | } | ||
126 | } | ||
127 | |||
128 | func (m *OcspManager) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) { | ||
129 | m.RLock() | ||
130 | defer m.RUnlock() | ||
131 | |||
132 | if m.cert != nil { | ||
133 | return m.cert, nil | ||
134 | } | ||
135 | |||
136 | return nil, fmt.Errorf("OCSP manager has no certificate stored") | ||
137 | } | ||
138 | |||
139 | func GetOcspResponse(chain *tls.Certificate) ([]byte, *ocsp.Response, error) { | ||
140 | var certs []*x509.Certificate | ||
141 | for _, c := range chain.Certificate { | ||
142 | cert, err := x509.ParseCertificate(c) | ||
143 | if err != nil { | ||
144 | return nil, nil, err | ||
145 | } | ||
146 | certs = append(certs, cert) | ||
147 | } | ||
148 | if len(certs) == 0 { | ||
149 | return nil, nil, fmt.Errorf("no certificates found in bundle") | ||
150 | } | ||
151 | |||
152 | // We expect the certificate slice to be ordered downwards the chain. | ||
153 | // SRV CRT -> CA. We need to pull the leaf and issuer certs out of it, | ||
154 | // which should always be the first two certificates. If there's no | ||
155 | // OCSP server listed in the leaf cert, there's nothing to do. And if | ||
156 | // we have only one certificate so far, we need to get the issuer cert. | ||
157 | leaf := certs[0] | ||
158 | if len(leaf.OCSPServer) == 0 { | ||
159 | return nil, nil, fmt.Errorf("no OCSP server specified in certificate") | ||
160 | } | ||
161 | |||
162 | if len(certs) == 1 { | ||
163 | if len(leaf.IssuingCertificateURL) == 0 { | ||
164 | return nil, nil, fmt.Errorf("no URL to issuing certificate") | ||
165 | } | ||
166 | |||
167 | resp, err := http.Get(leaf.IssuingCertificateURL[0]) | ||
168 | if err != nil { | ||
169 | return nil, nil, fmt.Errorf("getting issuer certificate: %w", err) | ||
170 | } | ||
171 | defer resp.Body.Close() | ||
172 | |||
173 | issuerBytes, err := io.ReadAll(io.LimitReader(resp.Body, 1024*1024)) | ||
174 | if err != nil { | ||
175 | return nil, nil, fmt.Errorf("reading issuer certificate: %w", err) | ||
176 | } | ||
177 | |||
178 | issuer, err := x509.ParseCertificate(issuerBytes) | ||
179 | if err != nil { | ||
180 | return nil, nil, fmt.Errorf("parsing issuer certificate: %w", err) | ||
181 | } | ||
182 | |||
183 | certs = append(certs, issuer) | ||
184 | } | ||
185 | |||
186 | issuer := certs[1] | ||
187 | |||
188 | req, err := ocsp.CreateRequest(leaf, issuer, nil) | ||
189 | if err != nil { | ||
190 | return nil, nil, fmt.Errorf("creating OCSP request: %w", err) | ||
191 | } | ||
192 | |||
193 | httpRes, err := http.Post(leaf.OCSPServer[0], "application/ocsp-request", bytes.NewReader(req)) | ||
194 | if err != nil { | ||
195 | return nil, nil, fmt.Errorf("making OCSP request: %w", err) | ||
196 | } | ||
197 | defer httpRes.Body.Close() | ||
198 | |||
199 | rawRes, err := io.ReadAll(io.LimitReader(httpRes.Body, 1024*1024)) | ||
200 | if err != nil { | ||
201 | return nil, nil, fmt.Errorf("reading OCSP response: %w", err) | ||
202 | } | ||
203 | |||
204 | res, err := ocsp.ParseResponse(rawRes, issuer) | ||
205 | if err != nil { | ||
206 | return nil, nil, fmt.Errorf("parsing OCSP response: %w", err) | ||
207 | } | ||
208 | |||
209 | if res.Status != ocsp.Good { | ||
210 | return nil, nil, fmt.Errorf("invalid: OCSP response was not of Good status") | ||
211 | } | ||
212 | |||
213 | // This is invalid, the response expires after the certificate | ||
214 | if res.NextUpdate.After(leaf.NotAfter) { | ||
215 | return nil, nil, fmt.Errorf("invalid: OCSP response valid after certificate expiration") | ||
216 | } | ||
217 | |||
218 | return rawRes, res, nil | ||
219 | } | ||