aboutsummaryrefslogtreecommitdiff
path: root/tls/ocsp.go
blob: b80764f8f1f991ecac8a7dd662094a44a81c0b26 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
package tls

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"net/http"
	"sync"
	"time"

	"golang.org/x/crypto/ocsp"
)

// OcspError reports a failure to obtain or refresh an OCSP staple.
type OcspError struct {
	// Err is the underlying failure.
	Err error

	// AtBoot distinguishes start-up failures from background refresh
	// failures. Run always reports false here; presumably true is meant for
	// errors raised during initialisation — confirm with callers.
	AtBoot bool
}

// Error returns the message of the wrapped error. Like the wrapped error's
// own method, it panics if Err is nil.
func (oe OcspError) Error() string {
	return oe.Unwrap().Error()
}

// Unwrap exposes the wrapped error to errors.Is / errors.As / errors.Unwrap.
func (oe OcspError) Unwrap() error {
	return oe.Err
}

// OcspLogger is the minimal logging surface OcspErrorLogger depends on.
type OcspLogger interface {
	// Info logs an informational message built from args.
	Info(args ...interface{})

	// Errorf logs an error message; format is presumably Printf-style —
	// confirm against the concrete logger in use.
	Errorf(format string, args ...interface{})
}

// OcspErrorLogger returns a long-running task that logs every OcspError
// received on c through l.Errorf until ctx is cancelled.
//
// The returned function registers itself on wg (Add/Done) and always returns
// nil. If c is closed, the task stops reading from it and simply waits for
// ctx to be cancelled instead of spinning on the closed channel.
func OcspErrorLogger(l OcspLogger, c <-chan OcspError) func(context.Context, *sync.WaitGroup) error {
	return func(ctx context.Context, wg *sync.WaitGroup) error {
		// NOTE(review): Add runs inside the task, so a caller doing
		// `go task(ctx, wg)` followed by wg.Wait() can Wait before this Add
		// happens; consider having the caller Add before starting the task.
		wg.Add(1)
		defer wg.Done()

		// Local copy so disabling a closed channel does not mutate state
		// shared with other invocations of this closure.
		errs := c
		for {
			select {
			case err, ok := <-errs:
				if !ok {
					// Receiving from a closed channel is always ready; a nil
					// channel never is, so this case is effectively disabled.
					errs = nil
					continue
				}
				// %w is only understood by fmt.Errorf; any other formatter
				// renders it as "%!w(...)", so use %v here.
				l.Errorf("Error in OCSP stapling: %v", errors.Unwrap(err))
			case <-ctx.Done():
				l.Info("Shutting down OCSP logger")
				return nil
			}
		}
	}
}

// OcspManager owns a TLS certificate loaded from disk and keeps an OCSP
// staple attached to it.
//
// Init loads the key pair and fetches the first staple, Run refreshes the
// staple in the background, and GetCertificate hands the result to TLS
// handshakes. The embedded RWMutex guards cert and ocspRes.
type OcspManager struct {
	// CertPath is the certificate file passed to tls.LoadX509KeyPair.
	CertPath string

	// KeyPath is the private key file passed to tls.LoadX509KeyPair.
	KeyPath string

	// Errors, when non-nil, receives failures of background refreshes made
	// by Run; something should be draining it.
	Errors chan<- OcspError

	// cert is the currently served certificate; stapleCert sets its
	// OCSPStaple.
	cert *tls.Certificate

	// ocspRes is the parsed form of the staple most recently installed.
	ocspRes *ocsp.Response

	sync.RWMutex
}

// loadCert reads the key pair from m.CertPath / m.KeyPath and, on success,
// replaces the stored certificate under the write lock. No OCSP staple is
// attached here; on error the previously stored certificate is kept.
func (m *OcspManager) loadCert() error {
	pair, err := tls.LoadX509KeyPair(m.CertPath, m.KeyPath)
	if err != nil {
		return err
	}

	m.Lock()
	defer m.Unlock()
	m.cert = &pair

	return nil
}

// stapleCert fetches a fresh OCSP response for the stored certificate and
// installs it as that certificate's staple.
//
// The network request to the (untrusted) OCSP responder runs without holding
// any lock. The new staple is attached to a copy of the certificate which is
// then swapped in, because GetCertificate may already have handed the old
// *tls.Certificate to in-flight handshakes that read it without the lock.
//
// It returns an error if no certificate is stored, if fetching the response
// fails, or if the certificate was replaced while the response was fetched
// (the response would belong to the old certificate).
func (m *OcspManager) stapleCert() error {
	// Snapshot the pointer and release the lock immediately so every return
	// path below leaves the mutex unlocked.
	m.RLock()
	cert := m.cert
	m.RUnlock()

	if cert == nil {
		return errors.New("OCSP manager has no certificate stored")
	}

	raw, ocspRes, err := GetOcspResponse(cert)
	if err != nil {
		return err
	}

	// Copy-on-write: never mutate a certificate that may be in use.
	stapled := *cert
	stapled.OCSPStaple = raw

	m.Lock()
	defer m.Unlock()
	if m.cert != cert {
		return errors.New("certificate changed while fetching OCSP response")
	}
	m.cert = &stapled
	m.ocspRes = ocspRes

	return nil
}

// Init loads the certificate from disk and fetches its first OCSP staple,
// returning the first error encountered. loadCert and stapleCert take the
// manager's lock themselves, so Init does no locking of its own.
func (m *OcspManager) Init() error {
	if err := m.loadCert(); err != nil {
		return err
	}
	return m.stapleCert()
}

// Run keeps the OCSP staple fresh until ctx is cancelled and always returns
// nil.
//
// The next refresh is scheduled one hour before the current response's
// NextUpdate (see ocspRefreshDelay for the edge cases). A failed refresh is
// reported on m.Errors when it is non-nil and retried after one hour. If no
// response has been fetched yet, Run fetches one immediately. The task
// registers itself on wg via Add/Done.
func (m *OcspManager) Run(ctx context.Context, wg *sync.WaitGroup) error {
	// NOTE(review): Add runs inside the task, so a caller doing
	// `go m.Run(ctx, wg)` followed by wg.Wait() can Wait before this Add
	// happens; consider having the caller Add before starting the task.
	wg.Add(1)
	defer wg.Done()

	nextDelay := func() time.Duration {
		m.RLock()
		defer m.RUnlock()
		return ocspRefreshDelay(m.ocspRes, time.Now())
	}

	t := time.NewTimer(nextDelay())
	defer t.Stop()

	for {
		select {
		case <-t.C:
			if err := m.stapleCert(); err != nil {
				if m.Errors != nil {
					// Don't let a reader that has gone away block shutdown.
					select {
					case m.Errors <- OcspError{Err: err, AtBoot: false}:
					case <-ctx.Done():
						return nil
					}
				}
				t.Reset(time.Hour)
				continue
			}
			t.Reset(nextDelay())
		case <-ctx.Done():
			return nil
		}
	}
}

// ocspRefreshDelay returns how long to wait, as of now, before fetching a
// replacement for res:
//
//   - res == nil (nothing fetched yet): 0, i.e. fetch immediately.
//   - res.NextUpdate is the zero time (nextUpdate is optional in OCSP
//     responses): one hour. A zero time would otherwise produce a hugely
//     negative delay and a tight re-fetch loop.
//   - otherwise: one hour before NextUpdate, but never less than one minute,
//     so short-lived or already-stale responses cannot busy-loop the
//     responder.
func ocspRefreshDelay(res *ocsp.Response, now time.Time) time.Duration {
	const (
		refreshMargin = time.Hour
		unknownDelay  = time.Hour
		minDelay      = time.Minute
	)

	if res == nil {
		return 0
	}
	if res.NextUpdate.IsZero() {
		return unknownDelay
	}
	if d := res.NextUpdate.Sub(now) - refreshMargin; d > minDelay {
		return d
	}
	return minDelay
}

// GetCertificate returns the managed certificate, including its current OCSP
// staple, for every handshake. Its signature matches
// tls.Config.GetCertificate; the ClientHelloInfo is ignored.
//
// TODO: support dynamic certificates for LE via tls.Config.GetCertificate
// (cache these).
func (m *OcspManager) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	m.RLock()
	current := m.cert
	m.RUnlock()

	if current == nil {
		return nil, fmt.Errorf("OCSP manager has no certificate stored")
	}
	return current, nil
}

// ocspHTTPTimeout bounds each HTTP request GetOcspResponse makes, so an
// unresponsive CA endpoint cannot stall the caller (e.g. the refresh loop)
// indefinitely.
const ocspHTTPTimeout = 30 * time.Second

// ocspMaxBodySize caps how many bytes of an HTTP response body are read.
const ocspMaxBodySize = 1024 * 1024

// GetOcspResponse fetches and validates an OCSP response for the leaf
// certificate of chain.
//
// chain.Certificate must hold DER certificates ordered from the leaf towards
// the root (leaf, issuer, ...). If only the leaf is present, its issuer is
// downloaded from the leaf's first IssuingCertificateURL. The OCSP request is
// POSTed to the leaf's first OCSPServer.
//
// It returns the raw DER response (suitable for tls.Certificate.OCSPStaple)
// and its parsed form. It fails if:
//   - the chain is empty or unparsable;
//   - the leaf names no OCSP server;
//   - an HTTP request fails or does not return 200 OK;
//   - the response cannot be parsed or verified against the issuer, or has no
//     status for the leaf's serial number;
//   - the status is not Good;
//   - NextUpdate is later than the leaf's NotAfter.
func GetOcspResponse(chain *tls.Certificate) ([]byte, *ocsp.Response, error) {
	if chain == nil {
		return nil, nil, fmt.Errorf("no certificates found in bundle")
	}

	certs := make([]*x509.Certificate, 0, len(chain.Certificate))
	for _, der := range chain.Certificate {
		cert, err := x509.ParseCertificate(der)
		if err != nil {
			return nil, nil, err
		}
		certs = append(certs, cert)
	}
	if len(certs) == 0 {
		return nil, nil, fmt.Errorf("no certificates found in bundle")
	}

	// Without an OCSP responder in the leaf there is nothing to staple.
	leaf := certs[0]
	if len(leaf.OCSPServer) == 0 {
		return nil, nil, fmt.Errorf("no OCSP server specified in certificate")
	}

	// A nil Transport means http.DefaultTransport, so connections are still
	// pooled across calls; only the timeout differs from http.DefaultClient.
	client := &http.Client{Timeout: ocspHTTPTimeout}

	issuer, err := ocspIssuer(client, certs)
	if err != nil {
		return nil, nil, err
	}

	req, err := ocsp.CreateRequest(leaf, issuer, nil)
	if err != nil {
		return nil, nil, fmt.Errorf("creating OCSP request: %w", err)
	}

	httpRes, err := client.Post(leaf.OCSPServer[0], "application/ocsp-request", bytes.NewReader(req))
	if err != nil {
		return nil, nil, fmt.Errorf("making OCSP request: %w", err)
	}
	defer httpRes.Body.Close()

	rawRes, err := readOKBody(httpRes)
	if err != nil {
		return nil, nil, fmt.Errorf("reading OCSP response: %w", err)
	}

	// ParseResponseForCert, unlike ParseResponse, picks the status whose
	// serial number matches leaf and errors if there is none. A response about
	// some other certificate is therefore never stapled.
	res, err := ocsp.ParseResponseForCert(rawRes, leaf, issuer)
	if err != nil {
		return nil, nil, fmt.Errorf("parsing OCSP response: %w", err)
	}

	if res.Status != ocsp.Good {
		return nil, nil, fmt.Errorf("invalid: OCSP response was not of Good status")
	}

	// A response claiming validity beyond the certificate's own expiry is
	// inconsistent and rejected.
	if res.NextUpdate.After(leaf.NotAfter) {
		return nil, nil, fmt.Errorf("invalid: OCSP response valid after certificate expiration")
	}

	return rawRes, res, nil
}

// ocspIssuer returns the issuer of certs[0]. It uses certs[1] when the bundle
// includes it, and otherwise downloads the certificate at the leaf's first
// IssuingCertificateURL. certs must be non-empty.
func ocspIssuer(client *http.Client, certs []*x509.Certificate) (*x509.Certificate, error) {
	if len(certs) > 1 {
		// NOTE(review): assumes the bundle is ordered so certs[1] issued
		// certs[0]; this is not verified here.
		return certs[1], nil
	}

	leaf := certs[0]
	if len(leaf.IssuingCertificateURL) == 0 {
		return nil, fmt.Errorf("no URL to issuing certificate")
	}

	resp, err := client.Get(leaf.IssuingCertificateURL[0])
	if err != nil {
		return nil, fmt.Errorf("getting issuer certificate: %w", err)
	}
	defer resp.Body.Close()

	der, err := readOKBody(resp)
	if err != nil {
		return nil, fmt.Errorf("reading issuer certificate: %w", err)
	}

	// NOTE(review): expects a DER-encoded certificate; a PEM-encoded body
	// will fail to parse here.
	issuer, err := x509.ParseCertificate(der)
	if err != nil {
		return nil, fmt.Errorf("parsing issuer certificate: %w", err)
	}
	return issuer, nil
}

// readOKBody returns at most ocspMaxBodySize bytes of resp's body. It fails
// if the server did not answer 200 OK, so an error page is not mistaken for
// DER data.
func readOKBody(resp *http.Response) ([]byte, error) {
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected HTTP status %s", resp.Status)
	}
	return io.ReadAll(io.LimitReader(resp.Body, ocspMaxBodySize))
}