aboutsummaryrefslogtreecommitdiff
path: root/crypto/tls/ocsp_manager.go
blob: dac4c5eb1ef51daebf4555d6142a6750f1452eb2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package tls

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"sync"
	"time"

	"golang.org/x/crypto/ocsp"
)

// OcspError wraps an error that occurred while obtaining or refreshing an
// OCSP staple.
//
// Err is the underlying error. AtBoot is a flag supplied by whoever builds
// the value; the background refresh loop in this file sets it to false.
type OcspError struct {
	Err    error
	AtBoot bool
}

// Error implements the error interface by returning the message of the
// wrapped error. A value with a nil Err (for example the zero value read
// from a closed channel) yields a fixed placeholder message instead of
// panicking on a nil dereference.
func (e OcspError) Error() string {
	if e.Err == nil {
		return "unknown OCSP error"
	}
	return e.Err.Error()
}

// Unwrap returns the wrapped error (possibly nil) so that errors.Is,
// errors.As and errors.Unwrap can see through an OcspError.
func (e OcspError) Unwrap() error {
	return e.Err
}

// OcspLogger is the minimal logging interface needed by OcspErrorLogger:
// Info is used for the shutdown notice and Errorf for each reported OCSP
// error.
type OcspLogger interface {
	// Info logs its arguments at informational level.
	Info(...interface{})
	// Errorf logs a formatted message at error level. The first argument is
	// the format string.
	// NOTE(review): presumably fmt.Sprintf-style formatting — confirm against
	// the concrete logger in use.
	Errorf(string, ...interface{})
}

// OcspErrorLogger returns a worker function that logs every OcspError
// received on c through l until its context is cancelled.
//
// The returned function registers itself on wg (Add on entry, Done on exit),
// blocks until ctx is done, logs a shutdown notice and always returns nil.
// If c is closed before ctx is cancelled, the worker stops reading from it
// and simply waits for cancellation.
func OcspErrorLogger(l OcspLogger, c <-chan OcspError) func(context.Context, *sync.WaitGroup) error {
	return func(ctx context.Context, wg *sync.WaitGroup) error {
		wg.Add(1)
		defer wg.Done()

		// Work on a local copy so that disabling a closed channel below does
		// not leak into later invocations of the returned function.
		errs := c

		for {
			select {
			case err, ok := <-errs:
				if !ok {
					// A closed channel is always ready to receive and would
					// make this loop spin logging zero values; a nil channel
					// is never ready, which leaves only ctx.Done().
					errs = nil
					continue
				}
				// %v rather than %w: the wrapping verb is only understood by
				// fmt.Errorf and renders as "%!w(...)" in Sprintf-style
				// formatters.
				l.Errorf("Error in OCSP stapling: %v", errors.Unwrap(err))
			case <-ctx.Done():
				l.Info("Shutting down OCSP logger")
				return nil
			}
		}
	}
}

// OcspManager holds a TLS certificate together with its most recent OCSP
// response and keeps the certificate's OCSP staple up to date.
//
// CertPath and KeyPath are the file paths handed to tls.LoadX509KeyPair by
// loadCert. Errors is an optional channel: when non-nil, Run sends an
// OcspError on it each time a background refresh fails. cert is the loaded
// certificate (with the staple attached by stapleCert) and ocspRes is the
// parsed response whose NextUpdate drives the refresh schedule in Run. The
// embedded RWMutex is taken by the methods of this type when they read or
// replace cert and ocspRes.
type OcspManager struct {
	CertPath, KeyPath string
	Errors            chan<- OcspError
	cert              *tls.Certificate
	ocspRes           *ocsp.Response
	sync.RWMutex
}

// loadCert reads the X.509 key pair found at m.CertPath / m.KeyPath and, on
// success, stores it as the manager's certificate while holding the write
// lock. On failure the error from tls.LoadX509KeyPair is returned unchanged
// and the previously stored certificate is left as it was.
func (m *OcspManager) loadCert() error {
	pair, err := tls.LoadX509KeyPair(m.CertPath, m.KeyPath)
	if err == nil {
		m.Lock()
		m.cert = &pair
		m.Unlock()
	}

	return err
}

// stapleCert fetches a fresh OCSP response for the stored certificate via
// GetOcspResponse and records both the raw staple and the parsed response.
//
// It returns the error from GetOcspResponse unchanged; in that case the
// stored certificate and response are not modified. It expects a certificate
// to have been loaded already (see loadCert).
func (m *OcspManager) stapleCert() error {
	// This makes a network request to an unknown server so don't hold the full
	// lock while this is happening. The read lock must be released on every
	// path: returning with it still held would make every later Lock() call
	// block forever.
	m.RLock()
	raw, ocspRes, err := GetOcspResponse(m.cert)
	m.RUnlock()
	if err != nil {
		return err
	}

	m.Lock()
	defer m.Unlock()

	// Publish an updated copy rather than writing into the existing value:
	// pointers previously handed out by GetCertificate may still be read by
	// in-flight TLS handshakes, which do not hold our lock.
	stapled := *m.cert
	stapled.OCSPStaple = raw
	m.cert = &stapled
	m.ocspRes = ocspRes

	return nil
}

// Init loads the certificate from disk and then fetches its first OCSP
// staple. The first error encountered is returned as-is; stapling is not
// attempted when loading fails.
func (m *OcspManager) Init() error {
	// Both helpers take care of their own locking.
	if err := m.loadCert(); err != nil {
		return err
	}

	return m.stapleCert()
}

// Run keeps the OCSP staple fresh until ctx is cancelled.
//
// It registers itself on wg (Add on entry, Done on exit) and schedules a
// refresh one hour before the current response's NextUpdate. When a refresh
// fails, the error is reported on m.Errors (if that channel is non-nil) and
// the refresh is retried after one hour.
//
// Run returns nil when ctx is cancelled. It returns an error immediately if
// no OCSP response has been stored yet, i.e. Init has not completed
// successfully.
func (m *OcspManager) Run(ctx context.Context, wg *sync.WaitGroup) error {
	wg.Add(1)
	defer wg.Done()

	m.RLock()
	res := m.ocspRes
	m.RUnlock()
	if res == nil {
		return errors.New("OCSP manager has no OCSP response; call Init before Run")
	}

	// delay computes how long to wait before the next refresh. The result is
	// never below one minute: a NextUpdate that is already past (or less than
	// an hour away) would otherwise yield a non-positive duration and turn
	// the loop below into a tight stream of requests to the responder.
	delay := func(next time.Time) time.Duration {
		if next.IsZero() {
			// No NextUpdate in the response; fall back to the retry interval.
			return time.Hour
		}
		if d := time.Until(next) - time.Hour; d > time.Minute {
			return d
		}
		return time.Minute
	}

	t := time.NewTimer(delay(res.NextUpdate))
	defer t.Stop()

	for {
		select {
		case <-t.C:
			if err := m.stapleCert(); err != nil {
				if m.Errors != nil {
					// Do not let a slow or absent receiver keep us from
					// honouring cancellation.
					select {
					case m.Errors <- OcspError{Err: err, AtBoot: false}:
					case <-ctx.Done():
						return nil
					}
				}
				t.Reset(time.Hour)
				continue
			}

			m.RLock()
			res = m.ocspRes
			m.RUnlock()
			t.Reset(delay(res.NextUpdate))
		case <-ctx.Done():
			return nil
		}
	}
}

// GetCertificate returns the certificate currently held by the manager, or
// an error when none has been loaded yet. Its signature matches the
// GetCertificate callback of tls.Config; the ClientHelloInfo argument is not
// inspected.
//
// TODO: TLS.GetCertificate for dynamic certs for LE (cache these)
func (m *OcspManager) GetCertificate(h *tls.ClientHelloInfo) (*tls.Certificate, error) {
	// Snapshot the pointer under the read lock; the checks below work on the
	// local copy only.
	m.RLock()
	current := m.cert
	m.RUnlock()

	if current == nil {
		return nil, fmt.Errorf("OCSP manager has no certificate stored")
	}

	return current, nil
}