From 3d7f26e200d1edefe68eac3b761acde57e244e42 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Mon, 15 Nov 2021 23:08:29 -0800 Subject: echo: add many new things --- echo/controller/content_type_negotiator.go | 59 ++++++++++++++ echo/cookie.go | 15 ++++ echo/echo_default.go | 2 + echo/error_handler.go | 120 +++++++++++++++++++++------- echo/go.mod | 1 + echo/go.sum | 2 + echo/middleware/cache_headers_middleware.go | 20 ++++- echo/middleware/vary.go | 24 ++++++ echo/url_builder.go | 49 ++++++++++++ 9 files changed, 259 insertions(+), 33 deletions(-) create mode 100644 echo/controller/content_type_negotiator.go create mode 100644 echo/cookie.go create mode 100644 echo/middleware/vary.go create mode 100644 echo/url_builder.go diff --git a/echo/controller/content_type_negotiator.go b/echo/controller/content_type_negotiator.go new file mode 100644 index 0000000..273a118 --- /dev/null +++ b/echo/controller/content_type_negotiator.go @@ -0,0 +1,59 @@ +package controller + +import ( + "errors" + "net/http" + "sync" + + "github.com/elnormous/contenttype" + "github.com/labstack/echo/v4" +) + +type ContentTypeNegotiatingHandler struct { + Handlers map[string]echo.HandlerFunc + DefaultHandler echo.HandlerFunc + mediaTypes []contenttype.MediaType + once sync.Once +} + +func errorIsNotAcceptable(err error) bool { + return errors.Is(err, contenttype.ErrNoAcceptableTypeFound) || + errors.Is(err, contenttype.ErrNoAvailableTypeGiven) +} + +func errorIsBadRequest(err error) bool { + return errors.Is(err, contenttype.ErrInvalidMediaType) || + errors.Is(err, contenttype.ErrInvalidMediaRange) || + errors.Is(err, contenttype.ErrInvalidParameter) || + errors.Is(err, contenttype.ErrInvalidExtensionParameter) || + errors.Is(err, contenttype.ErrInvalidWeight) +} + +func (h *ContentTypeNegotiatingHandler) Handle(c echo.Context) error { + h.once.Do(func() { + h.mediaTypes = []contenttype.MediaType{} + for k, _ := range h.Handlers { + h.mediaTypes = append(h.mediaTypes, contenttype.NewMediaType(k)) + } + }) + + handler := h.DefaultHandler + ct, _, err := contenttype.GetAcceptableMediaType(c.Request(), h.mediaTypes) + if err == nil { + handler = h.Handlers[ct.String()] + } else if errorIsNotAcceptable(err) { + return echo.NewHTTPError(http.StatusNotAcceptable) + } else if errorIsBadRequest(err) { + return echo.NewHTTPError(http.StatusBadRequest) + } + + // If negotiation failed but it wasn't an error and there is no default + // handler then the request is still not acceptable + if handler == nil { + return echo.NewHTTPError(http.StatusNotAcceptable) + } + + // Don't force each handler to do this itself to eliminate redundant code + c.Response().Header().Set("Content-Type", ct.String()) + return handler(c) +} diff --git a/echo/cookie.go b/echo/cookie.go new file mode 100644 index 0000000..9f4f26a --- /dev/null +++ b/echo/cookie.go @@ -0,0 +1,15 @@ +package echo + +import ( + "time" + + "github.com/labstack/echo/v4" +) + +func DeleteAllCookies(c echo.Context) { + for _, k := range c.Request().Cookies() { + k.Expires = time.Unix(0, 0) + k.MaxAge = -1 + c.SetCookie(k) + } +} diff --git a/echo/echo_default.go b/echo/echo_default.go index 92a5cfd..569d686 100644 --- a/echo/echo_default.go +++ b/echo/echo_default.go @@ -173,6 +173,8 @@ func NewDefaultEchoWithConfig(c EchoConfig) (*EchoWrapper, error) { // Only install template handlers if the path and glob are set if templates != nil && c.TemplateGlob != nil { + // TODO: Should assert the presence of required templates: 404.tpl + // 40x.tpl 50x.tpl header.tpl footer.tpl e.HTTPErrorHandler = ErrorHandler(templates, c.TemplateFunctions) tr, err := NewTemplateRenderer(templates, *c.TemplateGlob, c.TemplateFunctions) diff --git a/echo/error_handler.go b/echo/error_handler.go index bfbcfb6..bb4102b 100644 --- a/echo/error_handler.go +++ b/echo/error_handler.go @@ -7,19 +7,94 @@ import ( "io/fs" "net/http" + "github.com/elnormous/contenttype" "github.com/labstack/echo/v4" ) -// Copied from echo and tweaked to make our errors nicer +// TODO: This should allow plugging in other content types +// TODO: This should also be refactored into something prettier func ErrorHandler(templates fs.FS, funcs template.FuncMap) func(error, echo.Context) { + handleHtml := func(c echo.Context, he *echo.HTTPError) error { + t, err := template.New("").Funcs(funcs).ParseFS( + templates, + "404.tpl", + "40x.tpl", + "50x.tpl", + "header.tpl", + "footer.tpl", + ) + if err != nil { + return err + } + + path := "50x.tpl" + if he.Code == 404 { + path = "404.tpl" + } else if he.Code >= 400 && he.Code <= 499 { + path = "40x.tpl" + } + + buf := bytes.Buffer{} + if err = t.ExecuteTemplate(&buf, path, nil); err != nil { + err = c.String(he.Code, fmt.Sprintf("%s", he.Message)) + } + + return c.HTMLBlob(he.Code, buf.Bytes()) + } + + handlePlain := func(c echo.Context, he *echo.HTTPError) error { + return c.String(he.Code, fmt.Sprintf("%s", he.Message)) + } + + handleJson := func(c echo.Context, he *echo.HTTPError) error { + code := he.Code + message := he.Message + if m, ok := he.Message.(string); ok { + if c.Echo().Debug { + message = echo.Map{"message": m, "error": he.Error()} + } else { + message = echo.Map{"message": m} + } + } + + return c.JSON(code, message) + } + + errorWhileErroring := func(c echo.Context, err interface{}) { + c.Echo().Logger.Error(err) + if c.Echo().Debug { + c.JSON(http.StatusInternalServerError, &echo.HTTPError{ + Code: http.StatusInternalServerError, + Message: fmt.Sprintf("Error while processing error page. %w", err), + }) + } else { + c.JSON(http.StatusInternalServerError, &echo.HTTPError{ + Code: http.StatusInternalServerError, + Message: http.StatusText(http.StatusInternalServerError), + }) + } + } + + handlers := map[string]func(echo.Context, *echo.HTTPError) error{ + "text/plain": handlePlain, + "text/html": handleHtml, + "text/json": handleJson, + "application/json": handleJson, + } + + // This is hand maintained here because order is important for negotiation, + // especially in the case of */* + hIndex := []contenttype.MediaType{ + contenttype.NewMediaType("text/json"), + contenttype.NewMediaType("application/json"), + contenttype.NewMediaType("text/plain"), + contenttype.NewMediaType("text/html"), + } + return func(err error, c echo.Context) { defer func() { if r := recover(); r != nil { - if c.Echo().Debug { - c.String(http.StatusInternalServerError, fmt.Sprintf("Error while processing error page. %s", r)) - } else { - c.String(http.StatusInternalServerError, "Error while processing error page.") - } + errorWhileErroring(c, r) } }() @@ -37,41 +112,26 @@ func ErrorHandler(templates fs.FS, funcs template.FuncMap) func(error, echo.Cont } } - t, err := template.New("").Funcs(funcs).ParseFS( - templates, - "404.tpl", - "40x.tpl", - "50x.tpl", - "header.tpl", - "footer.tpl", - ) + ct, _, err := contenttype.GetAcceptableMediaType(c.Request(), hIndex) if err != nil { - he = &echo.HTTPError{ - Code: http.StatusInternalServerError, - Message: http.StatusText(http.StatusInternalServerError), - } + c.Echo().Logger.Error("Error negotiating content type in error handler, using json") + ct = contenttype.NewMediaType("text/json") } - path := "50x.tpl" - if he.Code == 404 { - path = "404.tpl" - } else if he.Code >= 400 && he.Code <= 499 { - path = "40x.tpl" + handle, ok := handlers[ct.String()] + if !ok { + c.Echo().Logger.Errorf("Error handler content type %s is unknown", ct.String()) + handle = handleJson } - // Send response if !c.Response().Committed { if c.Request().Method == http.MethodHead { // Issue #608 err = c.NoContent(he.Code) } else { - buf := bytes.Buffer{} - if err = t.ExecuteTemplate(&buf, path, nil); err != nil { - err = c.String(he.Code, fmt.Sprintf("%s", he.Message)) - } - c.HTMLBlob(he.Code, buf.Bytes()) + err = handle(c, he) } if err != nil { - c.Echo().Logger.Error(err) + errorWhileErroring(c, err) } } } diff --git a/echo/go.mod b/echo/go.mod index 942b4e5..44b7e1d 100644 --- a/echo/go.mod +++ b/echo/go.mod @@ -6,6 +6,7 @@ replace code.crute.us/mcrute/golib => ../ require ( code.crute.us/mcrute/golib v0.1.1 + github.com/elnormous/contenttype v1.0.0 github.com/labstack/echo/v4 v4.6.1 github.com/labstack/gommon v0.3.1 github.com/prometheus/client_golang v1.11.0 diff --git a/echo/go.sum b/echo/go.sum index 4f3a541..7bcd71c 100644 --- a/echo/go.sum +++ b/echo/go.sum @@ -53,6 +53,8 @@ github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGX github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/elnormous/contenttype v1.0.0 h1:cTLou7K7uQMsPEmRiTJosAznsPcYuoBmXMrFAf86t2A= +github.com/elnormous/contenttype v1.0.0/go.mod h1:ngVcyGGU8pnn4QJ5sL4StrNgc/wmXZXy5IQSBuHOFPg= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= diff --git a/echo/middleware/cache_headers_middleware.go b/echo/middleware/cache_headers_middleware.go index f168dfd..73df9b2 100644 --- a/echo/middleware/cache_headers_middleware.go +++ b/echo/middleware/cache_headers_middleware.go @@ -10,19 +10,33 @@ import ( ) var ( + CacheNeverMiddleware = CacheHeadersMiddleware(0) CacheOneHourMiddleware = CacheHeadersMiddleware(1 * time.Hour) CacheOneDayMiddleware = CacheHeadersMiddleware(1 * gltime.Day) CacheOneMonthMiddleware = CacheHeadersMiddleware(30 * gltime.Day) ) +func setHeaderMissing(c echo.Context, name string, value string) { + h := c.Response().Header() + if v := h.Get(name); v == "" { + h.Set(name, value) + } +} + func CacheHeadersMiddleware(d time.Duration) echo.MiddlewareFunc { ds := int(d.Seconds()) return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - c.Response().Header().Set("Vary", "Accept-Encoding") - c.Response().Header().Set("Cache-Control", fmt.Sprintf("public, max-age=%d", ds)) - c.Response().Header().Set("Expires", time.Now().Add(d).Format(time.RFC1123)) + c.Response().Header().Add(echo.HeaderVary, "Accept-Encoding") + + if ds == 0 { + setHeaderMissing(c, "Cache-Control", "private, max-age=0, no-cache, no-store") + setHeaderMissing(c, "Expires", time.Now().Add(-time.Hour).Format(time.RFC1123)) + } else { + setHeaderMissing(c, "Cache-Control", fmt.Sprintf("public, max-age=%d", ds)) + setHeaderMissing(c, "Expires", time.Now().Add(d).Format(time.RFC1123)) + } return next(c) } } diff --git a/echo/middleware/vary.go b/echo/middleware/vary.go new file mode 100644 index 0000000..8f87d29 --- /dev/null +++ b/echo/middleware/vary.go @@ -0,0 +1,24 @@ +package middleware + +import ( + "github.com/labstack/echo/v4" +) + +type VaryConfig struct { + Vary []string +} + +func VaryCookie() echo.MiddlewareFunc { + return VaryWithConfig(VaryConfig{Vary: []string{"Cookie"}}) +} + +func VaryWithConfig(cfg VaryConfig) echo.MiddlewareFunc { + return func(next echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + for _, v := range cfg.Vary { + c.Response().Header().Add(echo.HeaderVary, v) + } + return next(c) + } + } +} diff --git a/echo/url_builder.go b/echo/url_builder.go new file mode 100644 index 0000000..955ecb5 --- /dev/null +++ b/echo/url_builder.go @@ -0,0 +1,49 @@ +package echo + +import ( + "net/url" + "path" + + "github.com/labstack/echo/v4" +) + +// URLBuilder is used to build URLs with optional querystring arguments. This +// is used to build URLs to REST resources within handlers. +// +// This exists because the default Echo reversing logic requires handlers to +// hold references to other handlers to be able to build reverse URLs. This is +// a bad solution to an ugly problem but as the router currently stands there's +// not much that can be done about it. In the future this should go away and be +// replaced by something like named routes in echo. +type URLBuilder struct { + c echo.Context + u *url.URL + q url.Values +} + +func URLFor(c echo.Context, parts ...string) *URLBuilder { + u := &url.URL{ + Scheme: "http", + Host: c.Request().Host, + Path: path.Join(parts...), + } + if c.Request().TLS != nil { + u.Scheme = "https" + } + return &URLBuilder{c, u, nil} +} + +func (b *URLBuilder) Query(k, v string) *URLBuilder { + if b.q == nil { + b.q = url.Values{} + } + b.q.Add(k, v) + return b +} + +func (b *URLBuilder) String() string { + if b.q != nil { + b.u.RawQuery = b.q.Encode() + } + return b.u.String() +} -- cgit v1.2.3