diff options
Diffstat (limited to 'https/tls_config.go')
-rw-r--r-- | https/tls_config.go | 110 |
1 files changed, 77 insertions, 33 deletions
diff --git a/https/tls_config.go b/https/tls_config.go index 4b29862..44e57e9 100644 --- a/https/tls_config.go +++ b/https/tls_config.go | |||
@@ -20,12 +20,20 @@ import ( | |||
20 | "io/ioutil" | 20 | "io/ioutil" |
21 | "net/http" | 21 | "net/http" |
22 | 22 | ||
23 | "github.com/go-kit/kit/log" | ||
24 | "github.com/go-kit/kit/log/level" | ||
23 | "github.com/pkg/errors" | 25 | "github.com/pkg/errors" |
26 | config_util "github.com/prometheus/common/config" | ||
24 | "gopkg.in/yaml.v2" | 27 | "gopkg.in/yaml.v2" |
25 | ) | 28 | ) |
26 | 29 | ||
30 | var ( | ||
31 | errNoTLSConfig = errors.New("TLS config is not present") | ||
32 | ) | ||
33 | |||
27 | type Config struct { | 34 | type Config struct { |
28 | TLSConfig TLSStruct `yaml:"tls_config"` | 35 | TLSConfig TLSStruct `yaml:"tls_config"` |
36 | Users map[string]config_util.Secret `yaml:"basic_auth_users"` | ||
29 | } | 37 | } |
30 | 38 | ||
31 | type TLSStruct struct { | 39 | type TLSStruct struct { |
@@ -35,13 +43,18 @@ type TLSStruct struct { | |||
35 | ClientCAs string `yaml:"client_ca_file"` | 43 | ClientCAs string `yaml:"client_ca_file"` |
36 | } | 44 | } |
37 | 45 | ||
38 | func getTLSConfig(configPath string) (*tls.Config, error) { | 46 | func getConfig(configPath string) (*Config, error) { |
39 | content, err := ioutil.ReadFile(configPath) | 47 | content, err := ioutil.ReadFile(configPath) |
40 | if err != nil { | 48 | if err != nil { |
41 | return nil, err | 49 | return nil, err |
42 | } | 50 | } |
43 | c := &Config{} | 51 | c := &Config{} |
44 | err = yaml.Unmarshal(content, c) | 52 | err = yaml.UnmarshalStrict(content, c) |
53 | return c, err | ||
54 | } | ||
55 | |||
56 | func getTLSConfig(configPath string) (*tls.Config, error) { | ||
57 | c, err := getConfig(configPath) | ||
45 | if err != nil { | 58 | if err != nil { |
46 | return nil, err | 59 | return nil, err |
47 | } | 60 | } |
@@ -50,14 +63,18 @@ func getTLSConfig(configPath string) (*tls.Config, error) { | |||
50 | 63 | ||
51 | // ConfigToTLSConfig generates the golang tls.Config from the TLSStruct config. | 64 | // ConfigToTLSConfig generates the golang tls.Config from the TLSStruct config. |
52 | func ConfigToTLSConfig(c *TLSStruct) (*tls.Config, error) { | 65 | func ConfigToTLSConfig(c *TLSStruct) (*tls.Config, error) { |
53 | cfg := &tls.Config{ | 66 | if c.TLSCertPath == "" && c.TLSKeyPath == "" && c.ClientAuth == "" && c.ClientCAs == "" { |
54 | MinVersion: tls.VersionTLS12, | 67 | return nil, errNoTLSConfig |
55 | } | 68 | } |
56 | if len(c.TLSCertPath) == 0 { | 69 | |
57 | return nil, errors.New("missing TLSCertPath") | 70 | if c.TLSCertPath == "" { |
71 | return nil, errors.New("missing cert_file") | ||
58 | } | 72 | } |
59 | if len(c.TLSKeyPath) == 0 { | 73 | if c.TLSKeyPath == "" { |
60 | return nil, errors.New("missing TLSKeyPath") | 74 | return nil, errors.New("missing key_file") |
75 | } | ||
76 | cfg := &tls.Config{ | ||
77 | MinVersion: tls.VersionTLS12, | ||
61 | } | 78 | } |
62 | loadCert := func() (*tls.Certificate, error) { | 79 | loadCert := func() (*tls.Certificate, error) { |
63 | cert, err := tls.LoadX509KeyPair(c.TLSCertPath, c.TLSKeyPath) | 80 | cert, err := tls.LoadX509KeyPair(c.TLSCertPath, c.TLSKeyPath) |
@@ -74,7 +91,7 @@ func ConfigToTLSConfig(c *TLSStruct) (*tls.Config, error) { | |||
74 | return loadCert() | 91 | return loadCert() |
75 | } | 92 | } |
76 | 93 | ||
77 | if len(c.ClientCAs) > 0 { | 94 | if c.ClientCAs != "" { |
78 | clientCAPool := x509.NewCertPool() | 95 | clientCAPool := x509.NewCertPool() |
79 | clientCAFile, err := ioutil.ReadFile(c.ClientCAs) | 96 | clientCAFile, err := ioutil.ReadFile(c.ClientCAs) |
80 | if err != nil { | 97 | if err != nil { |
@@ -83,40 +100,67 @@ func ConfigToTLSConfig(c *TLSStruct) (*tls.Config, error) { | |||
83 | clientCAPool.AppendCertsFromPEM(clientCAFile) | 100 | clientCAPool.AppendCertsFromPEM(clientCAFile) |
84 | cfg.ClientCAs = clientCAPool | 101 | cfg.ClientCAs = clientCAPool |
85 | } | 102 | } |
86 | if len(c.ClientAuth) > 0 { | 103 | |
87 | switch s := (c.ClientAuth); s { | 104 | switch c.ClientAuth { |
88 | case "NoClientCert": | 105 | case "RequestClientCert": |
89 | cfg.ClientAuth = tls.NoClientCert | 106 | cfg.ClientAuth = tls.RequestClientCert |
90 | case "RequestClientCert": | 107 | case "RequireClientCert": |
91 | cfg.ClientAuth = tls.RequestClientCert | 108 | cfg.ClientAuth = tls.RequireAnyClientCert |
92 | case "RequireClientCert": | 109 | case "VerifyClientCertIfGiven": |
93 | cfg.ClientAuth = tls.RequireAnyClientCert | 110 | cfg.ClientAuth = tls.VerifyClientCertIfGiven |
94 | case "VerifyClientCertIfGiven": | 111 | case "RequireAndVerifyClientCert": |
95 | cfg.ClientAuth = tls.VerifyClientCertIfGiven | 112 | cfg.ClientAuth = tls.RequireAndVerifyClientCert |
96 | case "RequireAndVerifyClientCert": | 113 | case "", "NoClientCert": |
97 | cfg.ClientAuth = tls.RequireAndVerifyClientCert | 114 | cfg.ClientAuth = tls.NoClientCert |
98 | case "": | 115 | default: |
99 | cfg.ClientAuth = tls.NoClientCert | 116 | return nil, errors.New("Invalid ClientAuth: " + c.ClientAuth) |
100 | default: | ||
101 | return nil, errors.New("Invalid ClientAuth: " + s) | ||
102 | } | ||
103 | } | 117 | } |
104 | if len(c.ClientCAs) > 0 && cfg.ClientAuth == tls.NoClientCert { | 118 | |
119 | if c.ClientCAs != "" && cfg.ClientAuth == tls.NoClientCert { | ||
105 | return nil, errors.New("Client CA's have been configured without a Client Auth Policy") | 120 | return nil, errors.New("Client CA's have been configured without a Client Auth Policy") |
106 | } | 121 | } |
122 | |||
107 | return cfg, nil | 123 | return cfg, nil |
108 | } | 124 | } |
109 | 125 | ||
110 | // Listen starts the server on the given address. If tlsConfigPath isn't empty the server connection will be started using TLS. | 126 | // Listen starts the server on the given address. If tlsConfigPath isn't empty the server connection will be started using TLS. |
111 | func Listen(server *http.Server, tlsConfigPath string) error { | 127 | func Listen(server *http.Server, tlsConfigPath string, logger log.Logger) error { |
112 | if (tlsConfigPath) == "" { | 128 | if tlsConfigPath == "" { |
129 | level.Info(logger).Log("msg", "TLS is disabled and it cannot be enabled on the fly.") | ||
113 | return server.ListenAndServe() | 130 | return server.ListenAndServe() |
114 | } | 131 | } |
115 | var err error | 132 | |
116 | server.TLSConfig, err = getTLSConfig(tlsConfigPath) | 133 | if err := validateUsers(tlsConfigPath); err != nil { |
117 | if err != nil { | ||
118 | return err | 134 | return err |
119 | } | 135 | } |
136 | |||
137 | // Setup basic authentication. | ||
138 | var handler http.Handler = http.DefaultServeMux | ||
139 | if server.Handler != nil { | ||
140 | handler = server.Handler | ||
141 | } | ||
142 | server.Handler = &userAuthRoundtrip{ | ||
143 | tlsConfigPath: tlsConfigPath, | ||
144 | logger: logger, | ||
145 | handler: handler, | ||
146 | } | ||
147 | |||
148 | config, err := getTLSConfig(tlsConfigPath) | ||
149 | switch err { | ||
150 | case nil: | ||
151 | // Valid TLS config. | ||
152 | level.Info(logger).Log("msg", "TLS is enabled and it cannot be disabled on the fly.") | ||
153 | case errNoTLSConfig: | ||
154 | // No TLS config, back to plain HTTP. | ||
155 | level.Info(logger).Log("msg", "TLS is disabled and it cannot be enabled on the fly.") | ||
156 | return server.ListenAndServe() | ||
157 | default: | ||
158 | // Invalid TLS config. | ||
159 | return err | ||
160 | } | ||
161 | |||
162 | server.TLSConfig = config | ||
163 | |||
120 | // Set the GetConfigForClient method of the HTTPS server so that the config | 164 | // Set the GetConfigForClient method of the HTTPS server so that the config |
121 | // and certs are reloaded on new connections. | 165 | // and certs are reloaded on new connections. |
122 | server.TLSConfig.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { | 166 | server.TLSConfig.GetConfigForClient = func(*tls.ClientHelloInfo) (*tls.Config, error) { |