diff options
Diffstat (limited to 'cmd')
-rw-r--r-- | cmd/client/oauth2.go | 51 | ||||
-rw-r--r-- | cmd/web/server.go | 15 |
2 files changed, 59 insertions, 7 deletions
diff --git a/cmd/client/oauth2.go b/cmd/client/oauth2.go index 6667c5a..1ccdaaa 100644 --- a/cmd/client/oauth2.go +++ b/cmd/client/oauth2.go | |||
@@ -5,6 +5,7 @@ import ( | |||
5 | "encoding/json" | 5 | "encoding/json" |
6 | "fmt" | 6 | "fmt" |
7 | "net/http" | 7 | "net/http" |
8 | "net/url" | ||
8 | "strings" | 9 | "strings" |
9 | "time" | 10 | "time" |
10 | 11 | ||
@@ -23,6 +24,36 @@ type Oauth2PKCEDeviceClient struct { | |||
23 | interval time.Duration | 24 | interval time.Duration |
24 | } | 25 | } |
25 | 26 | ||
27 | func (c *Oauth2PKCEDeviceClient) discoverHost(ctx context.Context) (*models.OauthDiscoveryMetadata, error) { | ||
28 | u := &url.URL{ | ||
29 | Scheme: "https", | ||
30 | Host: c.Host, | ||
31 | Path: models.Oauth2MetadataPath, | ||
32 | } | ||
33 | |||
34 | req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) | ||
35 | if err != nil { | ||
36 | return nil, err | ||
37 | } | ||
38 | |||
39 | res, err := http.DefaultClient.Do(req) | ||
40 | if err != nil { | ||
41 | return nil, err | ||
42 | } | ||
43 | defer res.Body.Close() | ||
44 | |||
45 | if res.StatusCode != http.StatusOK { | ||
46 | return nil, fmt.Errorf("Oauth2 discovery request failed with code %d", res.StatusCode) | ||
47 | } | ||
48 | |||
49 | var resp models.OauthDiscoveryMetadata | ||
50 | if err := json.NewDecoder(res.Body).Decode(&resp); err != nil { | ||
51 | return nil, err | ||
52 | } | ||
53 | |||
54 | return &resp, nil | ||
55 | } | ||
56 | |||
26 | func (c *Oauth2PKCEDeviceClient) Authorize(ctx context.Context) (*models.DeviceAuthorizationResponse, error) { | 57 | func (c *Oauth2PKCEDeviceClient) Authorize(ctx context.Context) (*models.DeviceAuthorizationResponse, error) { |
27 | challenge, err := models.NewPKCEChallenge() | 58 | challenge, err := models.NewPKCEChallenge() |
28 | if err != nil { | 59 | if err != nil { |
@@ -40,8 +71,12 @@ func (c *Oauth2PKCEDeviceClient) Authorize(ctx context.Context) (*models.DeviceA | |||
40 | return nil, err | 71 | return nil, err |
41 | } | 72 | } |
42 | 73 | ||
43 | url := fmt.Sprintf("https://%s/auth/device", c.Host) | 74 | md, err := c.discoverHost(ctx) |
44 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(values.Encode())) | 75 | if err != nil { |
76 | return nil, err | ||
77 | } | ||
78 | |||
79 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, md.DeviceAuthorizationEndpoint, strings.NewReader(values.Encode())) | ||
45 | if err != nil { | 80 | if err != nil { |
46 | return nil, err | 81 | return nil, err |
47 | } | 82 | } |
@@ -53,7 +88,7 @@ func (c *Oauth2PKCEDeviceClient) Authorize(ctx context.Context) (*models.DeviceA | |||
53 | } | 88 | } |
54 | defer res.Body.Close() | 89 | defer res.Body.Close() |
55 | 90 | ||
56 | if res.StatusCode != 200 { | 91 | if res.StatusCode != http.StatusOK { |
57 | var resError models.Oauth2Error | 92 | var resError models.Oauth2Error |
58 | if err := json.NewDecoder(res.Body).Decode(&resError); err != nil { | 93 | if err := json.NewDecoder(res.Body).Decode(&resError); err != nil { |
59 | return nil, err | 94 | return nil, err |
@@ -85,8 +120,12 @@ func (c *Oauth2PKCEDeviceClient) fetchToken(ctx context.Context, deviceCode stri | |||
85 | return nil, err | 120 | return nil, err |
86 | } | 121 | } |
87 | 122 | ||
88 | url := fmt.Sprintf("https://%s/auth/token", c.Host) | 123 | md, err := c.discoverHost(ctx) |
89 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, strings.NewReader(values.Encode())) | 124 | if err != nil { |
125 | return nil, err | ||
126 | } | ||
127 | |||
128 | req, err := http.NewRequestWithContext(ctx, http.MethodPost, md.TokenEndpoint, strings.NewReader(values.Encode())) | ||
90 | if err != nil { | 129 | if err != nil { |
91 | return nil, err | 130 | return nil, err |
92 | } | 131 | } |
@@ -98,7 +137,7 @@ func (c *Oauth2PKCEDeviceClient) fetchToken(ctx context.Context, deviceCode stri | |||
98 | } | 137 | } |
99 | defer res.Body.Close() | 138 | defer res.Body.Close() |
100 | 139 | ||
101 | if res.StatusCode != 200 { | 140 | if res.StatusCode != http.StatusOK { |
102 | var resError models.Oauth2Error | 141 | var resError models.Oauth2Error |
103 | if err := json.NewDecoder(res.Body).Decode(&resError); err != nil { | 142 | if err := json.NewDecoder(res.Body).Decode(&resError); err != nil { |
104 | return nil, err | 143 | return nil, err |
diff --git a/cmd/web/server.go b/cmd/web/server.go index e930257..4867970 100644 --- a/cmd/web/server.go +++ b/cmd/web/server.go | |||
@@ -171,15 +171,25 @@ func webMain(cfg app.Config, embeddedTemplates, embeddedClients fs.FS, appVersio | |||
171 | Webauthn: wauthn, | 171 | Webauthn: wauthn, |
172 | } | 172 | } |
173 | 173 | ||
174 | // TODO: Clean up this hack and expose these to echo | ||
175 | hostname := fmt.Sprintf("https://%s", cfg.Hostnames[0]) | ||
176 | if strings.HasPrefix(cfg.Hostnames[0], "dev.") { | ||
177 | hostname += ":8070" | ||
178 | } | ||
179 | |||
174 | o2dc := &controllers.OAuth2DeviceController[*app.Session]{ | 180 | o2dc := &controllers.OAuth2DeviceController[*app.Session]{ |
175 | Logger: s.Logger, | 181 | Logger: s.Logger, |
176 | AuthSessions: authSessionStore, | 182 | AuthSessions: authSessionStore, |
177 | OauthClients: oauthClientStore, | 183 | OauthClients: oauthClientStore, |
178 | Hostname: fmt.Sprintf("https://%s", cfg.Hostnames[0]), // TODO | 184 | Hostname: hostname, |
179 | PollSeconds: cfg.OauthDevicePollSecs, | 185 | PollSeconds: cfg.OauthDevicePollSecs, |
180 | SessionExpiration: cfg.OauthSessionTimeout, | 186 | SessionExpiration: cfg.OauthSessionTimeout, |
181 | } | 187 | } |
182 | 188 | ||
189 | od := controllers.Oauth2DiscoveryController{ | ||
190 | Hostname: hostname, | ||
191 | } | ||
192 | |||
183 | ph := &controllers.ProxyHandler{ | 193 | ph := &controllers.ProxyHandler{ |
184 | Logger: s.Logger, | 194 | Logger: s.Logger, |
185 | Users: userStore, | 195 | Users: userStore, |
@@ -260,5 +270,8 @@ func webMain(cfg app.Config, embeddedTemplates, embeddedClients fs.FS, appVersio | |||
260 | pg.GET("/:host/:port", ph.Handle) | 270 | pg.GET("/:host/:port", ph.Handle) |
261 | } | 271 | } |
262 | 272 | ||
273 | s.GET(models.Oauth2MetadataPath, od.Handle) | ||
274 | s.GET(models.Oauth2MetadataCompatPath, od.Handle) | ||
275 | |||
263 | s.RunForever(!cfg.DisableBackgroundJobs) | 276 | s.RunForever(!cfg.DisableBackgroundJobs) |
264 | } | 277 | } |