package middleware import ( "bytes" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "net/url" "reflect" "strings" "time" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) const ( HeaderReportTo = "ReportTo" cspExtendContextKey = "__echomw_csp__csp_extend" cspReplaceContextKey = "__echomw_csp__csp_replace" ) func ReplaceCSP(c echo.Context, csp *ContentSecurityPolicyConfig) { c.Set(cspReplaceContextKey, csp) } func ExtendCSP(c echo.Context, csp *ContentSecurityPolicyConfig) { c.Set(cspExtendContextKey, csp) } type ContentSecurityPolicyConfig struct { Skipper middleware.Skipper ReportOnly bool DefaultSrc []CSPDirective `csp:"default-src"` ChildSrc []CSPDirective `csp:"child-src"` ConnectSrc []CSPDirective `csp:"connect-src"` FontSrc []CSPDirective `csp:"font-src"` FrameSrc []CSPDirective `csp:"frame-src"` ImageSrc []CSPDirective `csp:"img-src"` ManifestSrc []CSPDirective `csp:"manifest-src"` MediaSrc []CSPDirective `csp:"media-src"` ObjectSrc []CSPDirective `csp:"object-src"` ScriptSrc []CSPDirective `csp:"script-src"` StyleSrc []CSPDirective `csp:"style-src"` BaseUri []CSPDirective `csp:"base-uri"` Sandbox []CSPSandbox `csp:"sandbox"` FormAction []CSPDirective `csp:"form-action"` UpgradeInsecureRequests bool `csp:"upgrade-insecure-requests"` BlockAllMixedContent bool `csp:"block-all-mixed-content"` FrameAncestors []CSPDirective `csp:"frame-ancestors"` ReportUri []*url.URL `csp:"report-uri"` // deprecated ReportTo []CSPReportTo `csp:"report-to"` // experimental PrefetchSrc []CSPDirective `csp:"prefetch-src"` // experimental ScriptSrcElem []CSPDirective `csp:"script-src-elem"` // experimental StyleSrcElem []CSPDirective `csp:"style-src-elem"` // experimental StyleSrcAttr []CSPDirective `csp:"script-src-attr"` // experimental WorkerSrc []CSPDirective `csp:"worker-src"` // experimental NavigateTo []CSPDirective `csp:"navigate-to"` // experimental RequireTrustedTypesFor []CSPDirective `csp:"require-trusted-types-for"` // experimental } func mergeField[T any](a []T, b []T) []T { if b != nil { v := make([]T, len(a)+len(b)) copy(v, a) copy(v[len(a):], b) return v } else { return a } } // ExtendSimple returns a copy of the current policy extended with some // other policy. Boolean fields are not merged and will retain the value // of the base configuration. func (c *ContentSecurityPolicyConfig) ExtendSimple(o *ContentSecurityPolicyConfig) *ContentSecurityPolicyConfig { return &ContentSecurityPolicyConfig{ Skipper: c.Skipper, ReportOnly: c.ReportOnly, UpgradeInsecureRequests: c.UpgradeInsecureRequests, BlockAllMixedContent: c.BlockAllMixedContent, DefaultSrc: mergeField(c.DefaultSrc, o.DefaultSrc), ChildSrc: mergeField(c.ChildSrc, o.ChildSrc), ConnectSrc: mergeField(c.ConnectSrc, o.ConnectSrc), FontSrc: mergeField(c.FontSrc, o.FontSrc), FrameSrc: mergeField(c.FrameSrc, o.FrameSrc), ImageSrc: mergeField(c.ImageSrc, o.ImageSrc), ManifestSrc: mergeField(c.ManifestSrc, o.ManifestSrc), MediaSrc: mergeField(c.MediaSrc, o.MediaSrc), ObjectSrc: mergeField(c.ObjectSrc, o.ObjectSrc), ScriptSrc: mergeField(c.ScriptSrc, o.ScriptSrc), StyleSrc: mergeField(c.StyleSrc, o.StyleSrc), BaseUri: mergeField(c.BaseUri, o.BaseUri), Sandbox: mergeField(c.Sandbox, o.Sandbox), FormAction: mergeField(c.FormAction, o.FormAction), FrameAncestors: mergeField(c.FrameAncestors, o.FrameAncestors), ReportUri: mergeField(c.ReportUri, o.ReportUri), ReportTo: mergeField(c.ReportTo, o.ReportTo), PrefetchSrc: mergeField(c.PrefetchSrc, o.PrefetchSrc), ScriptSrcElem: mergeField(c.ScriptSrcElem, o.ScriptSrcElem), StyleSrcElem: mergeField(c.StyleSrcElem, o.StyleSrcElem), StyleSrcAttr: mergeField(c.StyleSrcAttr, o.StyleSrcAttr), WorkerSrc: mergeField(c.WorkerSrc, o.WorkerSrc), NavigateTo: mergeField(c.NavigateTo, o.NavigateTo), RequireTrustedTypesFor: mergeField(c.RequireTrustedTypesFor, o.RequireTrustedTypesFor), } } func (c *ContentSecurityPolicyConfig) String() string { st := reflect.TypeOf(*c) sv := reflect.ValueOf(*c) lines := []string{} for i := 0; i < st.NumField(); i++ { cspTag := st.Field(i).Tag.Get("csp") if cspTag == "" { continue } v := sv.Field(i) switch v.Kind() { case reflect.Slice: if v.Cap() == 0 { continue } items := make([]string, v.Cap()) for j := 0; j < v.Cap(); j++ { // Call the String() method if there is one to handle things // like *net.URL instances. Otherwise just treat the value as a // string because it probably is (all CSPDirective types are // strings). if str := v.Index(j).MethodByName("String"); str.IsValid() { items[j] = str.Call(nil)[0].String() } else { items[j] = v.Index(j).String() } } lines = append(lines, fmt.Sprintf("%s %s", cspTag, strings.Join(items, " "))) case reflect.Bool: if v.Bool() { lines = append(lines, cspTag) } } } return strings.Join(lines, "; ") + ";" } type CSPReportTo struct { GroupName string MaxAge time.Duration Endpoints []*url.URL } func (r CSPReportTo) MarshalJSON() ([]byte, error) { ep := []map[string]string{} for _, u := range r.Endpoints { ep = append(ep, map[string]string{"url": u.String()}) } return json.Marshal(map[string]interface{}{ "group": r.GroupName, "max_age": r.MaxAge.Seconds(), "endpoints": ep, }) } type CSPDirective string const ( CSPNone CSPDirective = "'none'" CSPSelf = "'self'" CSPUnsafeInline = "'unsafe-inline'" CSPUnsafeEval = "'unsafe-eval'" CSPUnsafeHashes = "'unsafe-hashes'" CSPStrictDynamic = "'strict-dynamic'" CSPReportSample = "'report-sample'" CSPData = "data:" CSPBlob = "blob:" CSPMediastream = "mediastream:" CSPFilesystem = "filesystem:" CSPHttp = "http:" CSPHttps = "https:" ) func CSPHost(s string) CSPDirective { return CSPDirective(s) } func CSPNonce(n string) CSPDirective { return CSPDirective(fmt.Sprintf("'nonce-%s'", n)) } func CSPShaString(size int, h string) CSPDirective { return CSPDirective(fmt.Sprintf("'sha%d-%s'", size, h)) } func CSPSha256FromBytes(d []byte) CSPDirective { s := sha256.Sum256(d) return CSPShaString(256, base64.StdEncoding.EncodeToString(s[:])) } type CSPSandbox string const ( CSPAllowDownloads CSPSandbox = "allow-downloads" CSPAllowDownloadsNoUser = "allow-downloads-without-user-activation" CSPAllowForms = "allow-forms" CSPAllowModals = "allow-modals" CSPAllowOrientationLock = "allow-orientation-lock" CSPAllowPointerLock = "allow-pointer-lock" CSPAllowPopups = "allow-popups" CSPAllowPopupEscape = "allow-popups-to-escape-sandbox" CSPAllowPresentation = "allow-presentation" CSPAllowSameOrigin = "allow-same-origin" CSPAllowScripts = "allow-scripts" CSPAllowStorageAccessByUser = "allow-storage-access-by-user-activation" CSPAllowTopActivation = "allow-top-navigation" CSPAllowNavigationByUser = "allow-top-navigation-by-user-activation" ) var DefaultContentSecurityPolicyConfig = ContentSecurityPolicyConfig{ Skipper: middleware.DefaultSkipper, DefaultSrc: []CSPDirective{CSPSelf, CSPData}, } func ContentSecurityPolicy() echo.MiddlewareFunc { return ContentSecurityPolicyWithConfig(DefaultContentSecurityPolicyConfig) } func ContentSecurityPolicyWithConfig(config ContentSecurityPolicyConfig) echo.MiddlewareFunc { if config.Skipper == nil { config.Skipper = DefaultContentSecurityPolicyConfig.Skipper } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if config.Skipper(c) { return next(c) } // This has to hook after the template runs but before the headers // are written because some template helper functions want to modify // the CSP state and if it renders too early that won't work. c.Response().Before(func() { liveConfig := &config if replace, ok := c.Get(cspReplaceContextKey).(*ContentSecurityPolicyConfig); ok { liveConfig = replace } else if extend, ok := c.Get(cspExtendContextKey).(*ContentSecurityPolicyConfig); ok { liveConfig = config.ExtendSimple(extend) } h := c.Response().Header() if liveConfig.ReportOnly { h.Set(echo.HeaderContentSecurityPolicyReportOnly, liveConfig.String()) } else { h.Set(echo.HeaderContentSecurityPolicy, liveConfig.String()) } if liveConfig.ReportTo != nil { rt := bytes.Buffer{} je := json.NewEncoder(&rt) for _, r := range liveConfig.ReportTo { _ = je.Encode(r) rt.WriteString(", ") } h.Set(HeaderReportTo, rt.String()) } }) return next(c) } } }