diff options
author | Mike Crute <mike@crute.us> | 2021-11-16 14:46:24 -0800 |
---|---|---|
committer | Mike Crute <mike@crute.us> | 2021-11-17 07:56:10 -0800 |
commit | cc58a3da7d647de8520e33dc4356672d2ed1a366 (patch) | |
tree | 1b232a0d51446eb6370cfb13932190d31ce053df /app | |
parent | a42d794a286154a3106551e6e483861af2a9ef16 (diff) | |
download | cloud-identity-broker-cc58a3da7d647de8520e33dc4356672d2ed1a366.tar.bz2 cloud-identity-broker-cc58a3da7d647de8520e33dc4356672d2ed1a366.tar.xz cloud-identity-broker-cc58a3da7d647de8520e33dc4356672d2ed1a366.zip |
Import of source code
Diffstat (limited to 'app')
-rw-r--r-- | app/config.go | 70 | ||||
-rw-r--r-- | app/controllers/api.go | 6 | ||||
-rw-r--r-- | app/controllers/api_account_list.go | 109 | ||||
-rw-r--r-- | app/controllers/api_console_redirect.go | 63 | ||||
-rw-r--r-- | app/controllers/api_credentials.go | 76 | ||||
-rw-r--r-- | app/controllers/api_region_list.go | 61 | ||||
-rw-r--r-- | app/controllers/aws.go | 52 | ||||
-rw-r--r-- | app/controllers/basic.go | 17 | ||||
-rw-r--r-- | app/middleware/auth.go | 212 | ||||
-rw-r--r-- | app/models/account.go | 115 | ||||
-rw-r--r-- | app/models/session_key.go | 202 | ||||
-rw-r--r-- | app/models/user.go | 99 |
12 files changed, 1082 insertions, 0 deletions
diff --git a/app/config.go b/app/config.go new file mode 100644 index 0000000..6565863 --- /dev/null +++ b/app/config.go | |||
@@ -0,0 +1,70 @@ | |||
1 | package app | ||
2 | |||
3 | import ( | ||
4 | "log" | ||
5 | "time" | ||
6 | |||
7 | "code.crute.us/mcrute/golib/cli" | ||
8 | "code.crute.us/mcrute/golib/vault" | ||
9 | "github.com/spf13/cobra" | ||
10 | ) | ||
11 | |||
12 | type GitHubOauthCreds struct { | ||
13 | ClientId string `mapstructure:"client-id"` | ||
14 | ClientSecret string `mapstructure:"client-secret"` | ||
15 | } | ||
16 | |||
17 | type Config struct { | ||
18 | Bind []string | ||
19 | BindTLS []string | ||
20 | Debug bool | ||
21 | TemplateGlob string | ||
22 | TemplatePath string | ||
23 | MongoDbUri string | ||
24 | MongodbVaultPath string | ||
25 | LogFile string | ||
26 | TLSCacheDir string | ||
27 | TrustedIPRanges []string | ||
28 | ManagementIPRanges []string | ||
29 | Hostnames []string | ||
30 | DisableBackgroundJobs bool | ||
31 | RateLimit time.Duration | ||
32 | RateLimitBurst int | ||
33 | IssuerEndpoint string | ||
34 | JWTAudience string | ||
35 | AuthCookieDuration time.Duration | ||
36 | GitHubOauthCreds *GitHubOauthCreds | ||
37 | } | ||
38 | |||
39 | func NewConfigFromCmd(cmd *cobra.Command) Config { | ||
40 | f := cli.TolerantPflagSet{cmd.Flags()} | ||
41 | |||
42 | var githubOauth GitHubOauthCreds | ||
43 | oauthPath := f.MayGetString("github-oauth-vault-path") | ||
44 | err := vault.GetVaultKeyStruct(oauthPath, &githubOauth) | ||
45 | if err != nil { | ||
46 | log.Fatalf("Error getting %s from vault: %w", oauthPath, err) | ||
47 | } | ||
48 | |||
49 | return Config{ | ||
50 | Bind: f.MayGetStringSlice("bind"), | ||
51 | BindTLS: f.MayGetStringSlice("bind-tls"), | ||
52 | Debug: f.MayGetBool("debug"), | ||
53 | TemplateGlob: f.MayGetString("template-glob"), | ||
54 | TemplatePath: f.MayGetString("template-path"), | ||
55 | MongoDbUri: f.MayGetString("mongodb-uri"), | ||
56 | MongodbVaultPath: f.MayGetString("mongodb-vault-path"), | ||
57 | DisableBackgroundJobs: f.MayGetBool("disable-bg-jobs"), | ||
58 | TrustedIPRanges: f.MayGetStringSlice("trusted-ip-ranges"), | ||
59 | ManagementIPRanges: f.MayGetStringSlice("management-ip-ranges"), | ||
60 | Hostnames: f.MayGetStringSlice("hostname"), | ||
61 | LogFile: f.MayGetString("log-file"), | ||
62 | TLSCacheDir: f.MayGetString("tls-cache-dir"), | ||
63 | RateLimit: f.MayGetDuration("rate-limit"), | ||
64 | RateLimitBurst: f.MayGetInt("rate-limit-burst"), | ||
65 | IssuerEndpoint: f.MayGetString("issuer-endpoint"), | ||
66 | JWTAudience: f.MayGetString("jwt-audience"), | ||
67 | AuthCookieDuration: f.MayGetDuration("auth-cookie-duration"), | ||
68 | GitHubOauthCreds: &githubOauth, | ||
69 | } | ||
70 | } | ||
diff --git a/app/controllers/api.go b/app/controllers/api.go new file mode 100644 index 0000000..7beaa4c --- /dev/null +++ b/app/controllers/api.go | |||
@@ -0,0 +1,6 @@ | |||
1 | package controllers | ||
2 | |||
3 | const ( | ||
4 | contentTypeV1 = "application/vnd.broker.v1+json" // Original type | ||
5 | contentTypeV2 = "application/vnd.broker.v2+json" // Start of migration to multi-cloud | ||
6 | ) | ||
diff --git a/app/controllers/api_account_list.go b/app/controllers/api_account_list.go new file mode 100644 index 0000000..f69db6a --- /dev/null +++ b/app/controllers/api_account_list.go | |||
@@ -0,0 +1,109 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | "net/http" | ||
6 | |||
7 | "code.crute.us/mcrute/cloud-identity-broker/app/middleware" | ||
8 | "code.crute.us/mcrute/cloud-identity-broker/app/models" | ||
9 | |||
10 | glecho "code.crute.us/mcrute/golib/echo" | ||
11 | "code.crute.us/mcrute/golib/echo/controller" | ||
12 | "github.com/labstack/echo/v4" | ||
13 | ) | ||
14 | |||
15 | type jsonAccount struct { | ||
16 | Vendor string `json:"vendor,omitempty"` | ||
17 | AccountNumber int `json:"account_number"` | ||
18 | ShortName string `json:"short_name"` | ||
19 | Name string `json:"name"` | ||
20 | ConsoleUrl string `json:"get_console_url,omitempty"` | ||
21 | ConsoleRedirectUrl string `json:"console_redirect_url,omitempty"` | ||
22 | CredentialsUrl string `json:"credentials_url"` | ||
23 | GlobalCredentialsUrl string `json:"global_credential_url,omitempty"` | ||
24 | } | ||
25 | |||
26 | func jsonAccountFromAccount(c echo.Context, a *models.Account) *jsonAccount { | ||
27 | return &jsonAccount{ | ||
28 | AccountNumber: a.AccountNumber, | ||
29 | ShortName: a.ShortName, | ||
30 | Name: a.Name, | ||
31 | ConsoleUrl: glecho.URLFor(c, "/api/account", a.ShortName, "console").String(), | ||
32 | ConsoleRedirectUrl: glecho.URLFor(c, "/api/account", a.ShortName, "console").Query("redirect", "1").String(), | ||
33 | CredentialsUrl: glecho.URLFor(c, "/api/account", a.ShortName, "credentials").String(), | ||
34 | GlobalCredentialsUrl: glecho.URLFor(c, "/api/account", a.ShortName, "credentials/global").String(), | ||
35 | } | ||
36 | } | ||
37 | |||
38 | type APIAccountListHandler struct { | ||
39 | store models.AccountStore | ||
40 | } | ||
41 | |||
42 | func NewAPIAccountListHandler(s models.AccountStore) echo.HandlerFunc { | ||
43 | al := &APIAccountListHandler{store: s} | ||
44 | h := &controller.ContentTypeNegotiatingHandler{ | ||
45 | DefaultHandler: al.HandleV1, | ||
46 | Handlers: map[string]echo.HandlerFunc{ | ||
47 | contentTypeV1: al.HandleV1, | ||
48 | contentTypeV2: al.HandleV2, | ||
49 | }, | ||
50 | } | ||
51 | return h.Handle | ||
52 | } | ||
53 | |||
54 | // getAccountList returns the account list. This does the same work that | ||
55 | // GetContext would do for most AWSAPI handlers but is a little different | ||
56 | // because it deals with lists of accounts. | ||
57 | // | ||
58 | // Authorization of the account is handled within the store. The store will not | ||
59 | // return accounts for which the user does not have access. | ||
60 | func (h *APIAccountListHandler) getAccountList(c echo.Context) ([]*models.Account, error) { | ||
61 | principal, err := middleware.GetAuthorizedPrincipal(c) | ||
62 | if err != nil { | ||
63 | return nil, echo.ErrUnauthorized | ||
64 | } | ||
65 | |||
66 | accounts, err := h.store.ListForUser(context.Background(), principal) | ||
67 | if err != nil { | ||
68 | c.Logger().Errorf("Unable to load account list: %w", err) | ||
69 | return nil, echo.ErrInternalServerError | ||
70 | } | ||
71 | |||
72 | return accounts, nil | ||
73 | } | ||
74 | |||
75 | // HandleV1 returns a list of JSON account objects | ||
76 | func (h *APIAccountListHandler) HandleV1(c echo.Context) error { | ||
77 | accounts, err := h.getAccountList(c) | ||
78 | if err != nil { | ||
79 | return err | ||
80 | } | ||
81 | |||
82 | out := []*jsonAccount{} | ||
83 | for _, a := range accounts { | ||
84 | ja := jsonAccountFromAccount(c, a) | ||
85 | ja.Vendor = "aws" | ||
86 | out = append(out, ja) | ||
87 | } | ||
88 | |||
89 | return c.JSON(http.StatusOK, out) | ||
90 | } | ||
91 | |||
92 | // HandleV2 returns a map of lists of account objects. the key to the map is | ||
93 | // the short name of the cloud provider. | ||
94 | func (h *APIAccountListHandler) HandleV2(c echo.Context) error { | ||
95 | accounts, err := h.getAccountList(c) | ||
96 | if err != nil { | ||
97 | return err | ||
98 | } | ||
99 | |||
100 | out := map[string][]*jsonAccount{ | ||
101 | "aws": []*jsonAccount{}, | ||
102 | } | ||
103 | for _, a := range accounts { | ||
104 | ja := jsonAccountFromAccount(c, a) | ||
105 | out["aws"] = append(out["aws"], ja) | ||
106 | } | ||
107 | |||
108 | return c.JSON(http.StatusOK, out) | ||
109 | } | ||
diff --git a/app/controllers/api_console_redirect.go b/app/controllers/api_console_redirect.go new file mode 100644 index 0000000..701bbf3 --- /dev/null +++ b/app/controllers/api_console_redirect.go | |||
@@ -0,0 +1,63 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "net/http" | ||
5 | |||
6 | "code.crute.us/mcrute/golib/echo/controller" | ||
7 | "github.com/labstack/echo/v4" | ||
8 | "github.com/prometheus/client_golang/prometheus" | ||
9 | "github.com/prometheus/client_golang/prometheus/promauto" | ||
10 | ) | ||
11 | |||
12 | var consoleAllowed = promauto.NewCounterVec(prometheus.CounterOpts{ | ||
13 | Namespace: "aws_access", // Legacy Namespace | ||
14 | Name: "broker_console_access_total", | ||
15 | Help: "Total number of console logins allowed by broker", | ||
16 | }, []string{"account"}) | ||
17 | |||
18 | type jsonConsoleUrl struct { | ||
19 | ConsoleURL string `json:"console_url"` | ||
20 | } | ||
21 | |||
22 | type APIConsoleRedirectHandler struct { | ||
23 | FederationIssuerEndpoint string | ||
24 | *AWSAPI | ||
25 | } | ||
26 | |||
27 | func NewAPIConsoleRedirectHandler(a *AWSAPI, fe string) echo.HandlerFunc { | ||
28 | al := &APIConsoleRedirectHandler{fe, a} | ||
29 | h := &controller.ContentTypeNegotiatingHandler{ | ||
30 | DefaultHandler: al.Handle, | ||
31 | Handlers: map[string]echo.HandlerFunc{ | ||
32 | contentTypeV1: al.Handle, | ||
33 | }, | ||
34 | } | ||
35 | return h.Handle | ||
36 | } | ||
37 | |||
38 | func (h *APIConsoleRedirectHandler) Handle(c echo.Context) error { | ||
39 | rc, err := h.GetContext(c) // Does all authorization checks | ||
40 | if err != nil { | ||
41 | return err | ||
42 | } | ||
43 | |||
44 | u, err := rc.AWS.GetFederationURL(rc.Principal.Username, h.FederationIssuerEndpoint) | ||
45 | if err != nil { | ||
46 | c.Logger().Errorf("Error fetching console URL: %w", err) | ||
47 | return echo.ErrBadRequest | ||
48 | } | ||
49 | |||
50 | c.Logger().Infof( | ||
51 | "Allowing '%s' to access account console '%s'", | ||
52 | rc.Principal.Username, rc.Account.Name, | ||
53 | ) | ||
54 | consoleAllowed.With(prometheus.Labels{ | ||
55 | "account": rc.Account.ShortName, | ||
56 | }).Inc() | ||
57 | |||
58 | if c.QueryParam("redirect") == "1" { | ||
59 | return c.Redirect(http.StatusFound, u) | ||
60 | } else { | ||
61 | return c.JSON(http.StatusOK, &jsonConsoleUrl{u}) | ||
62 | } | ||
63 | } | ||
diff --git a/app/controllers/api_credentials.go b/app/controllers/api_credentials.go new file mode 100644 index 0000000..1cefc07 --- /dev/null +++ b/app/controllers/api_credentials.go | |||
@@ -0,0 +1,76 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "net/http" | ||
5 | "time" | ||
6 | |||
7 | "code.crute.us/mcrute/cloud-identity-broker/cloud/aws" | ||
8 | |||
9 | "code.crute.us/mcrute/golib/echo/controller" | ||
10 | "github.com/labstack/echo/v4" | ||
11 | "github.com/prometheus/client_golang/prometheus" | ||
12 | "github.com/prometheus/client_golang/prometheus/promauto" | ||
13 | ) | ||
14 | |||
15 | var credsAllowed = promauto.NewCounterVec(prometheus.CounterOpts{ | ||
16 | Namespace: "aws_access", // Legacy namespace | ||
17 | Name: "broker_cred_access_total", | ||
18 | Help: "Total number of credential accesses allowed by broker", | ||
19 | }, []string{"account", "region"}) | ||
20 | |||
21 | type jsonCredential struct { | ||
22 | AccessKeyId *string `json:"access_key"` | ||
23 | SecretAccessKey *string `json:"secret_key"` | ||
24 | SessionToken *string `json:"session_token"` | ||
25 | Expiration *time.Time `json:"expiration"` | ||
26 | } | ||
27 | |||
28 | type APICredentialsHandler struct { | ||
29 | *AWSAPI | ||
30 | } | ||
31 | |||
32 | func NewAPICredentialsHandler(a *AWSAPI) echo.HandlerFunc { | ||
33 | al := &APICredentialsHandler{a} | ||
34 | h := &controller.ContentTypeNegotiatingHandler{ | ||
35 | DefaultHandler: al.Handle, | ||
36 | Handlers: map[string]echo.HandlerFunc{ | ||
37 | contentTypeV1: al.Handle, | ||
38 | }, | ||
39 | } | ||
40 | return h.Handle | ||
41 | } | ||
42 | |||
43 | func (h *APICredentialsHandler) Handle(c echo.Context) error { | ||
44 | rc, err := h.GetContext(c) // Does authorization checks | ||
45 | if err != nil { | ||
46 | return err | ||
47 | } | ||
48 | |||
49 | region := c.Param("region") | ||
50 | creds, err := rc.AWS.AssumeRole(rc.Principal.Username, ®ion) | ||
51 | if err != nil { | ||
52 | if aws.IsRegionNotExist(err) { | ||
53 | return echo.NotFoundHandler(c) | ||
54 | } | ||
55 | c.Logger().Errorf("Error retrieving credentials: %w", err) | ||
56 | return echo.ErrInternalServerError | ||
57 | } | ||
58 | |||
59 | c.Logger().Infof( | ||
60 | "Allowing '%s' to access account credential '%s'", | ||
61 | rc.Principal.Username, rc.Account.Name, | ||
62 | ) | ||
63 | credsAllowed.With(prometheus.Labels{ | ||
64 | "account": rc.Account.ShortName, | ||
65 | "region": region, | ||
66 | }).Inc() | ||
67 | |||
68 | c.Response().Header().Set("Expires", creds.Expiration.Add(-5*time.Minute).Format(time.RFC1123)) | ||
69 | |||
70 | return c.JSON(http.StatusOK, &jsonCredential{ | ||
71 | AccessKeyId: creds.AccessKeyId, | ||
72 | SecretAccessKey: creds.SecretAccessKey, | ||
73 | SessionToken: creds.SessionToken, | ||
74 | Expiration: creds.Expiration, | ||
75 | }) | ||
76 | } | ||
diff --git a/app/controllers/api_region_list.go b/app/controllers/api_region_list.go new file mode 100644 index 0000000..5bd1f7e --- /dev/null +++ b/app/controllers/api_region_list.go | |||
@@ -0,0 +1,61 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "net/http" | ||
5 | |||
6 | glecho "code.crute.us/mcrute/golib/echo" | ||
7 | "code.crute.us/mcrute/golib/echo/controller" | ||
8 | "github.com/labstack/echo/v4" | ||
9 | ) | ||
10 | |||
11 | type jsonRegion struct { | ||
12 | Name string `json:"name"` | ||
13 | Enabled bool `json:"enabled"` | ||
14 | Default bool `json:"default"` | ||
15 | CredentialsURL string `json:"credentials_url,omitempty"` | ||
16 | } | ||
17 | |||
18 | type APIRegionListHandler struct { | ||
19 | *AWSAPI | ||
20 | } | ||
21 | |||
22 | func NewAPIRegionListHandler(a *AWSAPI) echo.HandlerFunc { | ||
23 | al := &APIRegionListHandler{a} | ||
24 | h := &controller.ContentTypeNegotiatingHandler{ | ||
25 | DefaultHandler: al.Handle, | ||
26 | Handlers: map[string]echo.HandlerFunc{ | ||
27 | contentTypeV1: al.Handle, | ||
28 | }, | ||
29 | } | ||
30 | return h.Handle | ||
31 | } | ||
32 | |||
33 | func (h *APIRegionListHandler) Handle(c echo.Context) error { | ||
34 | rc, err := h.GetContext(c) // Does authorization checks | ||
35 | if err != nil { | ||
36 | return err | ||
37 | } | ||
38 | |||
39 | regions, err := rc.AWS.GetRegionList() | ||
40 | if err != nil { | ||
41 | c.Logger().Errorf("Failed to load region list: %w", err) | ||
42 | return echo.ErrInternalServerError | ||
43 | } | ||
44 | |||
45 | out := make([]*jsonRegion, len(regions)) | ||
46 | |||
47 | for i, r := range regions { | ||
48 | out[i] = &jsonRegion{ | ||
49 | Name: r.Name, | ||
50 | Enabled: r.Enabled, | ||
51 | Default: rc.Account.DefaultRegion == r.Name, | ||
52 | } | ||
53 | if r.Enabled { | ||
54 | out[i].CredentialsURL = glecho.URLFor(c, | ||
55 | "/api/account", rc.Account.ShortName, "credentials", r.Name, | ||
56 | ).String() | ||
57 | } | ||
58 | } | ||
59 | |||
60 | return c.JSON(http.StatusOK, out) | ||
61 | } | ||
diff --git a/app/controllers/aws.go b/app/controllers/aws.go new file mode 100644 index 0000000..5b1765d --- /dev/null +++ b/app/controllers/aws.go | |||
@@ -0,0 +1,52 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | |||
6 | "code.crute.us/mcrute/cloud-identity-broker/app/middleware" | ||
7 | "code.crute.us/mcrute/cloud-identity-broker/app/models" | ||
8 | "code.crute.us/mcrute/cloud-identity-broker/cloud/aws" | ||
9 | |||
10 | "github.com/labstack/echo/v4" | ||
11 | ) | ||
12 | |||
13 | type requestContext struct { | ||
14 | Account *models.Account | ||
15 | Principal *models.User | ||
16 | AWS aws.AWSClient | ||
17 | } | ||
18 | |||
19 | // AWSAPI is a capability that all handlers talking to the AWS APIs should use. | ||
20 | // This capability does common permission checks and populates a request | ||
21 | // context with user, account, and AWS API information. | ||
22 | type AWSAPI struct { | ||
23 | Store models.AccountStore | ||
24 | } | ||
25 | |||
26 | // GetContext checks that the user is authenticated and is authorized to access | ||
27 | // the requested AWS account. This should be the very first call in any handler | ||
28 | // that will eventually call the AWS APIs. Errors returned from this method are | ||
29 | // echo responses and can be returned directly to the client. | ||
30 | func (h *AWSAPI) GetContext(c echo.Context) (*requestContext, error) { | ||
31 | principal, err := middleware.GetAuthorizedPrincipal(c) | ||
32 | if err != nil { | ||
33 | return nil, echo.ErrUnauthorized | ||
34 | } | ||
35 | |||
36 | account, err := h.Store.GetForUser(context.Background(), c.Param("account"), principal) | ||
37 | if err != nil { | ||
38 | return nil, echo.NotFoundHandler(c) | ||
39 | } | ||
40 | |||
41 | ac, err := aws.NewAWSClientFromAccount(account) | ||
42 | if err != nil { | ||
43 | c.Logger().Errorf("Error building AWS client: %w", err) | ||
44 | return nil, echo.ErrInternalServerError | ||
45 | } | ||
46 | |||
47 | return &requestContext{ | ||
48 | Account: account, | ||
49 | Principal: principal, | ||
50 | AWS: ac, | ||
51 | }, nil | ||
52 | } | ||
diff --git a/app/controllers/basic.go b/app/controllers/basic.go new file mode 100644 index 0000000..eff97e1 --- /dev/null +++ b/app/controllers/basic.go | |||
@@ -0,0 +1,17 @@ | |||
1 | package controllers | ||
2 | |||
3 | import ( | ||
4 | "net/http" | ||
5 | |||
6 | glecho "code.crute.us/mcrute/golib/echo" | ||
7 | "github.com/labstack/echo/v4" | ||
8 | ) | ||
9 | |||
10 | func IndexHandler(c echo.Context) error { | ||
11 | return c.Render(http.StatusOK, "index.tpl", nil) | ||
12 | } | ||
13 | |||
14 | func LogoutHandler(c echo.Context) error { | ||
15 | glecho.DeleteAllCookies(c) | ||
16 | return c.Redirect(http.StatusFound, "/") | ||
17 | } | ||
diff --git a/app/middleware/auth.go b/app/middleware/auth.go new file mode 100644 index 0000000..167d261 --- /dev/null +++ b/app/middleware/auth.go | |||
@@ -0,0 +1,212 @@ | |||
1 | package middleware | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | "fmt" | ||
6 | "net/http" | ||
7 | "strings" | ||
8 | "time" | ||
9 | |||
10 | "code.crute.us/mcrute/cloud-identity-broker/app/models" | ||
11 | "code.crute.us/mcrute/cloud-identity-broker/auth" | ||
12 | "code.crute.us/mcrute/cloud-identity-broker/auth/github" | ||
13 | |||
14 | "github.com/labstack/echo/v4" | ||
15 | ) | ||
16 | |||
17 | // canRegisterUrls is an interface that identifies what about an HTTP router is | ||
18 | // needed by this middleware. This mainly exists to work around the fact that | ||
19 | // our server is actually a golib.EchoWrapper and not an echo.Echo. | ||
20 | type canRegisterUrls interface { | ||
21 | GET(string, echo.HandlerFunc, ...echo.MiddlewareFunc) *echo.Route | ||
22 | } | ||
23 | |||
24 | const ( | ||
25 | authPrincipalContextKey = "broker.AuthorizedPrincipal" | ||
26 | gitHubTokenCookie = "github-token" | ||
27 | gitHubStateCookie = "github-state" | ||
28 | oauthReturnUrl = "/github-auth" | ||
29 | ) | ||
30 | |||
31 | // GetAuthorizedPrincipal returns the user principal object from the request | ||
32 | // context and casts it correctly. Will return error if there is no principal | ||
33 | // or if the principal is of the incorrect type. | ||
34 | // | ||
35 | // Note that use of this function implies that AuthenticationMiddleware is used | ||
36 | // somewhere in the stack before the handler calling this function is | ||
37 | // dispatched. | ||
38 | func GetAuthorizedPrincipal(c echo.Context) (*models.User, error) { | ||
39 | rp := c.Get(authPrincipalContextKey) | ||
40 | if rp == nil { | ||
41 | return nil, fmt.Errorf("No principal set in request") | ||
42 | } | ||
43 | principal, ok := rp.(*models.User) | ||
44 | if !ok { | ||
45 | return nil, fmt.Errorf("Principal in request is not of User type") | ||
46 | } | ||
47 | return principal, nil | ||
48 | } | ||
49 | |||
50 | type AuthenticationMiddleware struct { | ||
51 | Store models.UserStore | ||
52 | JWTManager *auth.JWTManager | ||
53 | GitHub *github.GitHubAuthenticator | ||
54 | CookieDuration time.Duration | ||
55 | } | ||
56 | |||
57 | func (m *AuthenticationMiddleware) redirectToGitHubAuth(c echo.Context) error { | ||
58 | redir, state := m.GitHub.GetAuthRedirect() | ||
59 | |||
60 | c.SetCookie(&http.Cookie{ | ||
61 | Name: gitHubStateCookie, | ||
62 | Value: state, | ||
63 | Path: "/", | ||
64 | Secure: true, | ||
65 | HttpOnly: true, | ||
66 | SameSite: http.SameSiteStrictMode, | ||
67 | }) | ||
68 | |||
69 | return c.Redirect(http.StatusFound, redir) | ||
70 | } | ||
71 | |||
72 | // RegisterUrls registers the URLs required by this middleware and handler with an echo instance. | ||
73 | // | ||
74 | // This is here instead of in the web main because these paths are encoded in | ||
75 | // the configuration for the GitHub application so changing them requires | ||
76 | // addition changes to that configuration. | ||
77 | func (m *AuthenticationMiddleware) RegisterUrls(e canRegisterUrls) { | ||
78 | e.GET(oauthReturnUrl, m.HandleCompleteLogin) | ||
79 | } | ||
80 | |||
81 | // Middleware does user authentication based on either an X-API-Key header, | ||
82 | // Authorization header, or GitHub cookie depending on how the request is | ||
83 | // phrased. | ||
84 | // | ||
85 | // If the request has either an X-API-Key or an Authorization Bearer header | ||
86 | // then that must pass validation with the downstream validation logic. | ||
87 | // Failures through this path are hard failures and the only way to re-try them | ||
88 | // is to authenticate with a new token. The underlying assumption is that only | ||
89 | // programmatic access goes through this path so redirecting to interactive | ||
90 | // authentication is pointless. | ||
91 | // | ||
92 | // In the absence of those headers it's assumed that the user is interactive | ||
93 | // and their auth cookie will be read and validated (by the exact same logic | ||
94 | // that an API key is validated, they're the same format) but the failure case | ||
95 | // here will redirect the user to GitHub for interactive auth. | ||
96 | // | ||
97 | // X-API-Key should be considered deprecated and the Authorization header with | ||
98 | // a type of Bearer should be used instead. This is more in-line with Oauth 2 | ||
99 | // style authentication. However, for now this middleware continues to support | ||
100 | // X-API-Key for to not break legacy API clients. | ||
101 | func (m *AuthenticationMiddleware) Middleware(next echo.HandlerFunc) echo.HandlerFunc { | ||
102 | return func(c echo.Context) error { | ||
103 | token := c.Request().Header.Get("X-API-Key") | ||
104 | if token == "" { | ||
105 | tp := strings.Split(c.Request().Header.Get(echo.HeaderAuthorization), " ") | ||
106 | if len(tp) == 2 && tp[0] == "Bearer" { | ||
107 | token = tp[1] | ||
108 | } | ||
109 | } | ||
110 | |||
111 | // If an API key is specified this is a success or failure path. There | ||
112 | // is no option to authenticate to GitHub as would an interactive user. | ||
113 | if token != "" { | ||
114 | u, err := m.JWTManager.Validate(token) | ||
115 | if err != nil { | ||
116 | c.Logger().Debugf("Error validating JWT: %w", err) | ||
117 | return echo.ErrUnauthorized | ||
118 | } | ||
119 | c.Set(authPrincipalContextKey, u) | ||
120 | return next(c) | ||
121 | } | ||
122 | |||
123 | // Kick them to GitHub auth if they have no auth cookie | ||
124 | authCookie, err := c.Cookie(gitHubTokenCookie) | ||
125 | if err != nil { | ||
126 | return m.redirectToGitHubAuth(c) | ||
127 | } | ||
128 | |||
129 | // If they fail the check them bounce them through logout to remove | ||
130 | // their existing cookies which should then bounce them back through | ||
131 | // GitHub auth, which will eventually land them back here. | ||
132 | u, err := m.JWTManager.Validate(authCookie.Value) | ||
133 | if err != nil { | ||
134 | c.Logger().Debugf("Error validating JWT: %w", err) | ||
135 | return c.Redirect(http.StatusFound, "/logout") | ||
136 | } | ||
137 | |||
138 | c.Set(authPrincipalContextKey, u) | ||
139 | return next(c) | ||
140 | } | ||
141 | } | ||
142 | |||
143 | // HandleCompleteLogin handles the Oauth 2 code flow. It receives the auth code | ||
144 | // and uses that to retrieve the auth token. This sets the user's auth cookie | ||
145 | // to a authenticated JWT. | ||
146 | // | ||
147 | // This is redirected-to by the Oauth authorization server and should never be | ||
148 | // hit directly by a user or script. | ||
149 | func (m *AuthenticationMiddleware) HandleCompleteLogin(c echo.Context) error { | ||
150 | ctx := context.Background() | ||
151 | |||
152 | code, state := c.QueryParam("code"), c.QueryParam("state") | ||
153 | if code == "" || state == "" { | ||
154 | return echo.ErrBadRequest | ||
155 | } | ||
156 | |||
157 | ghState, err := c.Cookie(gitHubStateCookie) | ||
158 | if err != nil || ghState.Value == "" { | ||
159 | return echo.ErrBadRequest | ||
160 | } | ||
161 | |||
162 | if ghState.Value != state { | ||
163 | return echo.ErrBadRequest | ||
164 | } | ||
165 | |||
166 | token, err := m.GitHub.GetTokens(code) | ||
167 | if err != nil { | ||
168 | return echo.ErrBadRequest | ||
169 | } | ||
170 | |||
171 | user, err := m.GitHub.GetUsernameWithToken(token.AccessToken) | ||
172 | if err != nil { | ||
173 | c.Logger().Debugf("Error getting GitHub username with token: %w", err) | ||
174 | return echo.ErrUnauthorized | ||
175 | } | ||
176 | |||
177 | dbUser, err := m.Store.Get(ctx, user) | ||
178 | if err != nil { | ||
179 | c.Logger().Errorf("GitHub user %s does not have access to app", user) | ||
180 | return echo.ErrUnauthorized | ||
181 | } | ||
182 | |||
183 | jwt, sk, err := m.JWTManager.CreateForUser(dbUser) | ||
184 | if err != nil { | ||
185 | return echo.ErrInternalServerError | ||
186 | } | ||
187 | |||
188 | dbUser.AddKey(sk) | ||
189 | dbUser.GCKeys() // This is a convenient place to do it | ||
190 | |||
191 | dbUser.AddToken(&models.AuthToken{ | ||
192 | Kind: "github", | ||
193 | Token: token.AccessToken, | ||
194 | RefreshToken: token.RefreshToken, | ||
195 | }) | ||
196 | |||
197 | if err := m.Store.Put(ctx, dbUser); err != nil { | ||
198 | return echo.ErrInternalServerError | ||
199 | } | ||
200 | |||
201 | c.SetCookie(&http.Cookie{ | ||
202 | Name: gitHubTokenCookie, | ||
203 | Value: jwt, | ||
204 | Path: "/", | ||
205 | MaxAge: int(m.CookieDuration.Seconds()), | ||
206 | Secure: true, | ||
207 | HttpOnly: true, | ||
208 | SameSite: http.SameSiteStrictMode, | ||
209 | }) | ||
210 | |||
211 | return c.Redirect(http.StatusFound, "/") | ||
212 | } | ||
diff --git a/app/models/account.go b/app/models/account.go new file mode 100644 index 0000000..0ae1821 --- /dev/null +++ b/app/models/account.go | |||
@@ -0,0 +1,115 @@ | |||
1 | package models | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | "fmt" | ||
6 | "time" | ||
7 | |||
8 | "code.crute.us/mcrute/golib/db/mongodb" | ||
9 | ) | ||
10 | |||
11 | const accountCol = "accounts" | ||
12 | |||
13 | type AccountStore interface { | ||
14 | List(context.Context) ([]*Account, error) | ||
15 | ListForUser(context.Context, *User) ([]*Account, error) | ||
16 | Get(context.Context, string) (*Account, error) // Error on not found | ||
17 | GetForUser(context.Context, string, *User) (*Account, error) // Error on not found | ||
18 | Put(context.Context, *Account) error | ||
19 | Delete(context.Context, *Account) error | ||
20 | } | ||
21 | |||
22 | type Account struct { | ||
23 | ShortName string `bson:"_id"` | ||
24 | AccountType string | ||
25 | AccountNumber int | ||
26 | Name string | ||
27 | ConsoleSessionDuration time.Duration | ||
28 | VaultMaterial string | ||
29 | DefaultRegion string | ||
30 | Users []string | ||
31 | } | ||
32 | |||
33 | func (a *Account) ConsoleSessionDurationSecs() int64 { | ||
34 | return int64(a.ConsoleSessionDuration.Seconds()) | ||
35 | } | ||
36 | |||
37 | func (a *Account) CanAccess(u *User) bool { | ||
38 | if u.IsAdmin { | ||
39 | return true | ||
40 | } | ||
41 | // Linear search should be fine for now, these lists are pretty small | ||
42 | for _, n := range a.Users { | ||
43 | if n == u.Username { | ||
44 | return true | ||
45 | } | ||
46 | } | ||
47 | return false | ||
48 | } | ||
49 | |||
50 | type MongoDbAccountStore struct { | ||
51 | Db *mongodb.Mongo | ||
52 | } | ||
53 | |||
54 | // List returns all accounts in the system. | ||
55 | func (s *MongoDbAccountStore) List(ctx context.Context) ([]*Account, error) { | ||
56 | var out []*Account | ||
57 | if err := s.Db.FindAll(ctx, accountCol, &out); err != nil { | ||
58 | return nil, err | ||
59 | } | ||
60 | return out, nil | ||
61 | } | ||
62 | |||
63 | // ListForUser returns all accounts for which the user has access. This is the | ||
64 | // authorized version of List. | ||
65 | // | ||
66 | // Note this does not handle the case where a user is an admin but not | ||
67 | // explicitly listed in the allowed users list for an account. For that case | ||
68 | // just use List directly. | ||
69 | func (s *MongoDbAccountStore) ListForUser(ctx context.Context, u *User) ([]*Account, error) { | ||
70 | var out []*Account | ||
71 | filter := mongodb.AnyInTopLevelArray("Users", u.Username) | ||
72 | if err := s.Db.FindAllByFilter(ctx, accountCol, filter, &out); err != nil { | ||
73 | return nil, err | ||
74 | } | ||
75 | return out, nil | ||
76 | } | ||
77 | |||
78 | func (s *MongoDbAccountStore) Get(ctx context.Context, id string) (*Account, error) { | ||
79 | var a Account | ||
80 | if err := s.Db.FindOneById(ctx, accountCol, id, &a); err != nil { | ||
81 | return nil, err | ||
82 | } | ||
83 | return &a, nil | ||
84 | } | ||
85 | |||
86 | // GetForUser returns an account if the user has access to this account, | ||
87 | // otherwise it returns an error. This is the authorized version of Get. | ||
88 | func (s *MongoDbAccountStore) GetForUser(ctx context.Context, id string, u *User) (*Account, error) { | ||
89 | a, err := s.Get(ctx, id) | ||
90 | if err != nil { | ||
91 | return nil, err | ||
92 | } | ||
93 | |||
94 | if !a.CanAccess(u) { | ||
95 | return nil, fmt.Errorf("User does not have access to account") | ||
96 | } | ||
97 | |||
98 | return a, nil | ||
99 | } | ||
100 | |||
101 | func (s *MongoDbAccountStore) Put(ctx context.Context, a *Account) error { | ||
102 | if err := s.Db.ReplaceOneById(ctx, accountCol, a.ShortName, a); err != nil { | ||
103 | return err | ||
104 | } | ||
105 | return nil | ||
106 | } | ||
107 | |||
108 | func (s *MongoDbAccountStore) Delete(ctx context.Context, a *Account) error { | ||
109 | if err := s.Db.DeleteOneById(ctx, accountCol, a.ShortName); err != nil { | ||
110 | return err | ||
111 | } | ||
112 | return nil | ||
113 | } | ||
114 | |||
115 | var _ AccountStore = (*MongoDbAccountStore)(nil) | ||
diff --git a/app/models/session_key.go b/app/models/session_key.go new file mode 100644 index 0000000..64ac7e0 --- /dev/null +++ b/app/models/session_key.go | |||
@@ -0,0 +1,202 @@ | |||
1 | package models | ||
2 | |||
3 | import ( | ||
4 | "crypto" | ||
5 | "crypto/ecdsa" | ||
6 | "crypto/elliptic" | ||
7 | "crypto/rand" | ||
8 | "crypto/x509" | ||
9 | "encoding/base64" | ||
10 | "encoding/hex" | ||
11 | "fmt" | ||
12 | "time" | ||
13 | |||
14 | "go.mongodb.org/mongo-driver/bson" | ||
15 | ) | ||
16 | |||
17 | // SessionKey represents a public and sometimes private key-pair for a user | ||
18 | // that will be stored on the user's record in the user store. These keys are | ||
19 | // used for signing authentication JWTs. | ||
20 | // | ||
21 | // This object is designed to be serialized to BSON. Other serializations can | ||
22 | // be added in the future as needed. | ||
23 | // | ||
24 | // There are two flavors of this record. A record with a private key (which | ||
25 | // implies a public key) is a key that the service generated and is used by the | ||
26 | // service to sign JWTs for the user. The private key is never given to the | ||
27 | // user. The private key is only used in the CreateToken flow, never the Verify | ||
28 | // flow. Currently (as of Nov 2021) the application sets a near-future NotAfter | ||
29 | // date and these get garbage collected. It might be nice to re-use them in the | ||
30 | // future for a while but it's not all that important. | ||
31 | // | ||
32 | // The other flavor of this key will have a public key but no private key. | ||
33 | // These are service keys. Service keys are given to programmatic actors that | ||
34 | // need to be able to mint their own JWTs for authentication to the service. | ||
35 | // For these keys the client will construct their own JWT and sign it with the | ||
36 | // private key and the service will validate the signature with the public key. | ||
37 | // These keys (as of Nov 2021) do not expire, though they can be revoked. | ||
38 | type SessionKey struct { | ||
39 | KeyId string | ||
40 | Description string | ||
41 | Revoked *time.Time | ||
42 | NotAfter *time.Time | ||
43 | NotBefore *time.Time | ||
44 | PublicKey crypto.PublicKey | ||
45 | PrivateKey *ecdsa.PrivateKey | ||
46 | } | ||
47 | |||
48 | func GenerateSessionKey(ttl time.Duration) (*SessionKey, error) { | ||
49 | pk, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) | ||
50 | if err != nil { | ||
51 | return nil, err | ||
52 | } | ||
53 | |||
54 | key := make([]byte, 8) | ||
55 | if _, err := rand.Read(key); err != nil { | ||
56 | return nil, err | ||
57 | } | ||
58 | |||
59 | now := time.Now() | ||
60 | notAfter := now.Add(ttl) | ||
61 | |||
62 | return &SessionKey{ | ||
63 | KeyId: hex.EncodeToString(key), | ||
64 | Revoked: nil, | ||
65 | NotAfter: ¬After, | ||
66 | NotBefore: &now, | ||
67 | PublicKey: pk.Public(), | ||
68 | PrivateKey: pk, | ||
69 | }, nil | ||
70 | } | ||
71 | |||
72 | // IsGarbage checks to determine if a key is garbage that should be collected. | ||
73 | // The definition of garbage is similar to the inversion of the definition of | ||
74 | // vaild but revoked keys are not considered to be garbage since they may be | ||
75 | // useful for auditing later. Also keys that are not yet valid are not garbage. | ||
76 | func (s *SessionKey) IsGarbage() bool { | ||
77 | if s.Revoked != nil { | ||
78 | return false | ||
79 | } | ||
80 | |||
81 | if s.NotBefore != nil && s.NotBefore.Before(time.Now()) { | ||
82 | return false | ||
83 | } | ||
84 | |||
85 | if s.NotAfter != nil && s.NotAfter.After(time.Now()) { | ||
86 | return true | ||
87 | } | ||
88 | |||
89 | return false | ||
90 | } | ||
91 | |||
92 | // IsValid checks the various dates in the SessionKey to verify that they are | ||
93 | // valid and in-range for use. This should be called before trusting this key | ||
94 | // for any use. | ||
95 | func (s *SessionKey) IsValid() bool { | ||
96 | if s.Revoked != nil { | ||
97 | return false | ||
98 | } | ||
99 | |||
100 | if s.NotBefore != nil && s.NotBefore.Before(time.Now()) { | ||
101 | return false | ||
102 | } | ||
103 | |||
104 | if s.NotAfter != nil && s.NotAfter.After(time.Now()) { | ||
105 | return false | ||
106 | } | ||
107 | |||
108 | return true | ||
109 | } | ||
110 | |||
111 | func (s *SessionKey) MarshalBSON() ([]byte, error) { | ||
112 | var err error | ||
113 | var pub, priv []byte | ||
114 | |||
115 | if s.PrivateKey != nil { | ||
116 | priv, err = x509.MarshalECPrivateKey(s.PrivateKey) | ||
117 | if err != nil { | ||
118 | return nil, err | ||
119 | } | ||
120 | } | ||
121 | |||
122 | // If there's a private key and a public key set then just save the private | ||
123 | // key. The private key already contains a copy of the public key. | ||
124 | if s.PublicKey != nil && s.PrivateKey == nil { | ||
125 | pub, err = x509.MarshalPKIXPublicKey(s.PublicKey) | ||
126 | if err != nil { | ||
127 | return nil, err | ||
128 | } | ||
129 | } | ||
130 | |||
131 | return bson.Marshal(struct { | ||
132 | KeyId string | ||
133 | Revoked *time.Time | ||
134 | NotAfter *time.Time | ||
135 | NotBefore *time.Time | ||
136 | PublicKey string | ||
137 | PrivateKey string | ||
138 | }{ | ||
139 | s.KeyId, | ||
140 | s.Revoked, s.NotAfter, s.NotBefore, | ||
141 | base64.StdEncoding.EncodeToString(pub), | ||
142 | base64.StdEncoding.EncodeToString(priv), | ||
143 | }) | ||
144 | } | ||
145 | |||
146 | func (s *SessionKey) UnmarshalBSON(d []byte) error { | ||
147 | v := struct { | ||
148 | KeyId string | ||
149 | Revoked *time.Time | ||
150 | NotAfter *time.Time | ||
151 | NotBefore *time.Time | ||
152 | PublicKey string | ||
153 | PrivateKey string | ||
154 | }{} | ||
155 | if err := bson.Unmarshal(d, &v); err != nil { | ||
156 | return err | ||
157 | } | ||
158 | |||
159 | s.KeyId = v.KeyId | ||
160 | s.Revoked = v.Revoked | ||
161 | s.NotAfter = v.NotAfter | ||
162 | s.NotBefore = v.NotBefore | ||
163 | |||
164 | if v.PrivateKey != "" { | ||
165 | privb, err := base64.StdEncoding.DecodeString(v.PrivateKey) | ||
166 | if err != nil { | ||
167 | return err | ||
168 | } | ||
169 | |||
170 | priv, err := x509.ParseECPrivateKey(privb) | ||
171 | if err != nil { | ||
172 | return err | ||
173 | } | ||
174 | |||
175 | s.PrivateKey = priv | ||
176 | s.PublicKey = priv.Public() | ||
177 | } | ||
178 | |||
179 | // If there was a private key then the public key was already set by | ||
180 | // decoding that private key. No need to do this a second time (also it's | ||
181 | // rather unlikely that both would be set). | ||
182 | if v.PublicKey != "" && s.PublicKey == nil { | ||
183 | pubb, err := base64.StdEncoding.DecodeString(v.PublicKey) | ||
184 | if err != nil { | ||
185 | return err | ||
186 | } | ||
187 | |||
188 | pubp, err := x509.ParsePKIXPublicKey(pubb) | ||
189 | if err != nil { | ||
190 | return err | ||
191 | } | ||
192 | |||
193 | pub, ok := pubp.(*ecdsa.PublicKey) | ||
194 | if !ok { | ||
195 | return fmt.Errorf("Failed to convert public key to *ecdsa.PublicKey") | ||
196 | } | ||
197 | |||
198 | s.PublicKey = pub | ||
199 | } | ||
200 | |||
201 | return nil | ||
202 | } | ||
diff --git a/app/models/user.go b/app/models/user.go new file mode 100644 index 0000000..0cbd92d --- /dev/null +++ b/app/models/user.go | |||
@@ -0,0 +1,99 @@ | |||
1 | package models | ||
2 | |||
3 | import ( | ||
4 | "context" | ||
5 | |||
6 | "code.crute.us/mcrute/golib/db/mongodb" | ||
7 | ) | ||
8 | |||
9 | const userCol = "users" | ||
10 | |||
11 | type UserStore interface { | ||
12 | List(context.Context) ([]*User, error) | ||
13 | Get(context.Context, string) (*User, error) // Error on not found | ||
14 | Put(context.Context, *User) error | ||
15 | Delete(context.Context, *User) error | ||
16 | } | ||
17 | |||
18 | type AuthToken struct { | ||
19 | Kind string | ||
20 | Token string | ||
21 | RefreshToken string | ||
22 | } | ||
23 | |||
24 | type User struct { | ||
25 | Username string `bson:"_id"` | ||
26 | IsAdmin bool | ||
27 | IsService bool | ||
28 | Keys map[string]*SessionKey // kid -> key | ||
29 | AuthTokens map[string]*AuthToken // kind -> token | ||
30 | } | ||
31 | |||
32 | // GCKeys garbage collects keys that are no longer valid | ||
33 | func (u *User) GCKeys() { | ||
34 | for k, v := range u.Keys { | ||
35 | if v.IsGarbage() { | ||
36 | delete(u.Keys, k) | ||
37 | } | ||
38 | } | ||
39 | } | ||
40 | |||
41 | // GetKey returns a key for a key ID. It will only return valid keys. | ||
42 | func (u *User) GetKey(kid string) *SessionKey { | ||
43 | if u.Keys != nil { | ||
44 | if k := u.Keys[kid]; k != nil && k.IsValid() { | ||
45 | return k | ||
46 | } | ||
47 | } | ||
48 | return nil | ||
49 | } | ||
50 | |||
51 | func (u *User) AddKey(k *SessionKey) { | ||
52 | if u.Keys == nil { | ||
53 | u.Keys = map[string]*SessionKey{} | ||
54 | } | ||
55 | u.Keys[k.KeyId] = k | ||
56 | } | ||
57 | |||
58 | func (u *User) AddToken(t *AuthToken) { | ||
59 | if u.AuthTokens == nil { | ||
60 | u.AuthTokens = map[string]*AuthToken{} | ||
61 | } | ||
62 | u.AuthTokens[t.Kind] = t | ||
63 | } | ||
64 | |||
65 | type MongoDbUserStore struct { | ||
66 | Db *mongodb.Mongo | ||
67 | } | ||
68 | |||
69 | func (s *MongoDbUserStore) List(ctx context.Context) ([]*User, error) { | ||
70 | var out []*User | ||
71 | if err := s.Db.FindAll(ctx, userCol, &out); err != nil { | ||
72 | return nil, err | ||
73 | } | ||
74 | return out, nil | ||
75 | } | ||
76 | |||
77 | func (s *MongoDbUserStore) Get(ctx context.Context, username string) (*User, error) { | ||
78 | var u User | ||
79 | if err := s.Db.FindOneById(ctx, userCol, username, &u); err != nil { | ||
80 | return nil, err | ||
81 | } | ||
82 | return &u, nil | ||
83 | } | ||
84 | |||
85 | func (s *MongoDbUserStore) Put(ctx context.Context, u *User) error { | ||
86 | if err := s.Db.ReplaceOneById(ctx, userCol, u.Username, u); err != nil { | ||
87 | return err | ||
88 | } | ||
89 | return nil | ||
90 | } | ||
91 | |||
92 | func (s *MongoDbUserStore) Delete(ctx context.Context, u *User) error { | ||
93 | if err := s.Db.DeleteOneById(ctx, userCol, u.Username); err != nil { | ||
94 | return err | ||
95 | } | ||
96 | return nil | ||
97 | } | ||
98 | |||
99 | var _ UserStore = (*MongoDbUserStore)(nil) | ||