From 925de3239c9eab61a3a1275a554508f46a172709 Mon Sep 17 00:00:00 2001 From: Mike Crute Date: Wed, 18 Oct 2023 12:45:07 -0700 Subject: echo: allow extending CSP at request-time --- echo/middleware/csp.go | 109 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 94 insertions(+), 15 deletions(-) diff --git a/echo/middleware/csp.go b/echo/middleware/csp.go index 14047eb..644dfe1 100644 --- a/echo/middleware/csp.go +++ b/echo/middleware/csp.go @@ -2,6 +2,8 @@ package middleware import ( "bytes" + "crypto/sha256" + "encoding/base64" "encoding/json" "fmt" "net/url" @@ -13,7 +15,19 @@ import ( "github.com/labstack/echo/v4/middleware" ) -const HeaderReportTo = "ReportTo" +const ( + HeaderReportTo = "ReportTo" + cspExtendContextKey = "__echomw_csp__csp_extend" + cspReplaceContextKey = "__echomw_csp__csp_replace" +) + +func ReplaceCSP(c echo.Context, csp *ContentSecurityPolicyConfig) { + c.Set(cspReplaceContextKey, csp) +} + +func ExtendCSP(c echo.Context, csp *ContentSecurityPolicyConfig) { + c.Set(cspExtendContextKey, csp) +} type ContentSecurityPolicyConfig struct { Skipper middleware.Skipper @@ -46,6 +60,53 @@ type ContentSecurityPolicyConfig struct { RequireTrustedTypesFor []CSPDirective `csp:"require-trusted-types-for"` // experimental } +func mergeField[T any](a []T, b []T) []T { + if b != nil { + v := make([]T, len(a)+len(b)) + copy(v, a) + copy(v[len(a):], b) + return v + } else { + return a + } +} + +// ExtendSimple returns a copy of the current policy extended with some +// other policy. Boolean fields are not merged and will retain the value +// of the base configuration. +func (c *ContentSecurityPolicyConfig) ExtendSimple(o *ContentSecurityPolicyConfig) *ContentSecurityPolicyConfig { + return &ContentSecurityPolicyConfig{ + Skipper: c.Skipper, + ReportOnly: c.ReportOnly, + UpgradeInsecureRequests: c.UpgradeInsecureRequests, + BlockAllMixedContent: c.BlockAllMixedContent, + DefaultSrc: mergeField(c.DefaultSrc, o.DefaultSrc), + ChildSrc: mergeField(c.ChildSrc, o.ChildSrc), + ConnectSrc: mergeField(c.ConnectSrc, o.ConnectSrc), + FontSrc: mergeField(c.FontSrc, o.FontSrc), + FrameSrc: mergeField(c.FrameSrc, o.FrameSrc), + ImageSrc: mergeField(c.ImageSrc, o.ImageSrc), + ManifestSrc: mergeField(c.ManifestSrc, o.ManifestSrc), + MediaSrc: mergeField(c.MediaSrc, o.MediaSrc), + ObjectSrc: mergeField(c.ObjectSrc, o.ObjectSrc), + ScriptSrc: mergeField(c.ScriptSrc, o.ScriptSrc), + StyleSrc: mergeField(c.StyleSrc, o.StyleSrc), + BaseUri: mergeField(c.BaseUri, o.BaseUri), + Sandbox: mergeField(c.Sandbox, o.Sandbox), + FormAction: mergeField(c.FormAction, o.FormAction), + FrameAncestors: mergeField(c.FrameAncestors, o.FrameAncestors), + ReportUri: mergeField(c.ReportUri, o.ReportUri), + ReportTo: mergeField(c.ReportTo, o.ReportTo), + PrefetchSrc: mergeField(c.PrefetchSrc, o.PrefetchSrc), + ScriptSrcElem: mergeField(c.ScriptSrcElem, o.ScriptSrcElem), + StyleSrcElem: mergeField(c.StyleSrcElem, o.StyleSrcElem), + StyleSrcAttr: mergeField(c.StyleSrcAttr, o.StyleSrcAttr), + WorkerSrc: mergeField(c.WorkerSrc, o.WorkerSrc), + NavigateTo: mergeField(c.NavigateTo, o.NavigateTo), + RequireTrustedTypesFor: mergeField(c.RequireTrustedTypesFor, o.RequireTrustedTypesFor), + } +} + func (c *ContentSecurityPolicyConfig) String() string { st := reflect.TypeOf(*c) sv := reflect.ValueOf(*c) @@ -135,6 +196,11 @@ func CSPShaString(size int, h string) CSPDirective { return CSPDirective(fmt.Sprintf("'sha%d-%s'", size, h)) } +func CSPSha256FromBytes(d []byte) CSPDirective { + s := sha256.Sum256(d) + return CSPShaString(256, base64.StdEncoding.EncodeToString(s[:])) +} + type CSPSandbox string const ( @@ -174,22 +240,35 @@ func ContentSecurityPolicyWithConfig(config ContentSecurityPolicyConfig) echo.Mi return next(c) } - h := c.Response().Header() - if config.ReportOnly { - h.Set(echo.HeaderContentSecurityPolicyReportOnly, config.String()) - } else { - h.Set(echo.HeaderContentSecurityPolicy, config.String()) - } + // This has to hook after the template runs but before the headers + // are written because some template helper functions want to modify + // the CSP state and if it renders too early that won't work. + c.Response().Before(func() { + liveConfig := &config - if config.ReportTo != nil { - rt := bytes.Buffer{} - je := json.NewEncoder(&rt) - for _, r := range config.ReportTo { - _ = je.Encode(r) - rt.WriteString(", ") + if replace, ok := c.Get(cspReplaceContextKey).(*ContentSecurityPolicyConfig); ok { + liveConfig = replace + } else if extend, ok := c.Get(cspExtendContextKey).(*ContentSecurityPolicyConfig); ok { + liveConfig = config.ExtendSimple(extend) } - h.Set(HeaderReportTo, rt.String()) - } + + h := c.Response().Header() + if liveConfig.ReportOnly { + h.Set(echo.HeaderContentSecurityPolicyReportOnly, liveConfig.String()) + } else { + h.Set(echo.HeaderContentSecurityPolicy, liveConfig.String()) + } + + if liveConfig.ReportTo != nil { + rt := bytes.Buffer{} + je := json.NewEncoder(&rt) + for _, r := range liveConfig.ReportTo { + _ = je.Encode(r) + rt.WriteString(", ") + } + h.Set(HeaderReportTo, rt.String()) + } + }) return next(c) } -- cgit v1.2.3