aboutsummaryrefslogtreecommitdiff
path: root/echo/middleware/recover_middleware.go
blob: af4259af2598f5e634e8ccd1d07105160efcc7aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
package middleware

import (
	"bytes"
	"fmt"
	"net/http"
	"runtime"
	"text/template"

	"github.com/labstack/echo/v4"
	"github.com/labstack/echo/v4/middleware"
	"github.com/labstack/gommon/log"
)

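// This is mostly copied from the upstream Echo Recover middleware
// (github.com/labstack/echo/v4/middleware) and modified to:
//
// * Capture up to 1 MB of stack trace by default
// * Render the error and stack trace as an HTML page in the response when
//   Echo is in debug mode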

// RecoverConfig defines the config for Recover middleware.
// Zero-valued optional fields are replaced with the values from
// DefaultRecoverConfig by RecoverWithConfig.
type RecoverConfig struct {
	// Skipper defines a function to skip middleware.
	// Optional. Default value middleware.DefaultSkipper.
	Skipper middleware.Skipper

	// Size in bytes of the buffer used to capture the stack trace; longer
	// traces are truncated to this size.
	// Optional. Default value 1 MB (zero means use the default).
	StackSize int `yaml:"stack_size"`

	// DisableStackAll disables formatting stack traces of all other goroutines
	// into buffer after the trace for the current goroutine.
	// Optional. Default value false.
	DisableStackAll bool `yaml:"disable_stack_all"`

	// DisablePrintStack disables printing stack trace.
	// Optional. Default value false.
	DisablePrintStack bool `yaml:"disable_print_stack"`

	// LogLevel is the log level used when printing the stack trace:
	// log.DEBUG, log.INFO, log.WARN and log.ERROR map to the logger method
	// of the same name, log.OFF suppresses the message, and any other value
	// logs via Print.
	// Optional. Default value 0 (Print).
	LogLevel log.Lvl
}

// DefaultRecoverConfig is the default Recover middleware config: the default
// skipper, a 1 MB stack buffer, traces of all goroutines included, stack
// printing enabled, and LogLevel 0 (logged via Print).
var DefaultRecoverConfig = RecoverConfig{
	Skipper:           middleware.DefaultSkipper,
	StackSize:         1 << 20, // 1 MB
	DisableStackAll:   false,
	DisablePrintStack: false,
	LogLevel:          0,
}

// defaultErrorTemplate is the template source for the HTML error page shown
// when Echo runs in debug mode; its data value is the message to display.
var defaultErrorTemplate = `<html>
		<head>
			<title>Error</title>
			<style type="text/css">
				pre { border: 1px solid black; padding: 1em; background: #f0f0f0; font-size: 1.5em; }
			</style>
		</head>
		<body><h1>Error</h1><pre>{{ . }}</pre></body>
	</html>
	`

// Recover builds the panic-recovery middleware from DefaultRecoverConfig.
// A panic raised anywhere further down the handler chain is caught and passed
// to the centralized HTTPErrorHandler, or rendered as a debug page when Echo
// is in debug mode; see RecoverWithConfig for the details.
func Recover() echo.MiddlewareFunc {
	cfg := DefaultRecoverConfig
	return RecoverWithConfig(cfg)
}

// RecoverWithConfig returns a Recover middleware with config.
// See: `Recover()`.
//
// Unset fields fall back to DefaultRecoverConfig: a nil Skipper uses
// DefaultRecoverConfig.Skipper and a non-positive StackSize uses
// DefaultRecoverConfig.StackSize.
//
// When a handler further down the chain panics, the middleware:
//   - re-panics if the value is http.ErrAbortHandler, leaving net/http to
//     abort the response as that sentinel intends;
//   - otherwise converts the value to an error (non-error values are wrapped
//     with fmt.Errorf), captures at most StackSize bytes of stack trace and,
//     unless DisablePrintStack is set, logs both at LogLevel;
//   - when Echo is in debug mode, responds 500 with an HTML page showing the
//     HTML-escaped error and stack; otherwise passes the error to the
//     centralized HTTPErrorHandler via c.Error.
//
// It panics if the built-in error page template fails to parse.
func RecoverWithConfig(config RecoverConfig) echo.MiddlewareFunc {
	// Defaults
	if config.Skipper == nil {
		config.Skipper = DefaultRecoverConfig.Skipper
	}
	// A negative size would make the make() call below panic while handling a
	// recovered panic, so treat it like the zero value.
	if config.StackSize <= 0 {
		config.StackSize = DefaultRecoverConfig.StackSize
	}

	errorTemplate, err := template.New("").Parse(defaultErrorTemplate)
	if err != nil {
		panic("RecoverWithConfig: error parsing html template")
	}

	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			if config.Skipper(c) {
				return next(c)
			}

			defer func() {
				r := recover()
				if r == nil {
					return
				}
				// net/http uses ErrAbortHandler to abort a handler without
				// logging a stack trace; turning it into a 500 would defeat
				// that, so let it keep propagating.
				if r == http.ErrAbortHandler {
					panic(r)
				}

				err, ok := r.(error)
				if !ok {
					err = fmt.Errorf("%v", r)
				}
				buf := make([]byte, config.StackSize)
				// With all=true runtime.Stack also appends the traces of every
				// other goroutine after the current one.
				stack := buf[:runtime.Stack(buf, !config.DisableStackAll)]

				if !config.DisablePrintStack {
					msg := fmt.Sprintf("[PANIC RECOVER] %v %s\n", err, stack)
					switch config.LogLevel {
					case log.DEBUG:
						c.Logger().Debug(msg)
					case log.INFO:
						c.Logger().Info(msg)
					case log.WARN:
						c.Logger().Warn(msg)
					case log.ERROR:
						c.Logger().Error(msg)
					case log.OFF:
						// None.
					default:
						c.Logger().Print(msg)
					}
				}

				if !c.Echo().Debug {
					c.Error(err)
					return
				}

				msg := fmt.Sprintf("%v\n\n%s", err, stack)
				// text/template performs no HTML escaping and the panic value is
				// arbitrary (it may carry request-derived text), so escape the
				// message before it is interpolated into the page.
				var page bytes.Buffer
				if renderErr := errorTemplate.Execute(&page, template.HTMLEscapeString(msg)); renderErr != nil {
					c.Logger().Errorf("Error rendering HTML error page: %s", renderErr)
					// Plain text needs no escaping.
					if writeErr := c.String(http.StatusInternalServerError, msg); writeErr != nil {
						c.Logger().Errorf("Error writing plain-text error response: %s", writeErr)
					}
					return
				}
				if writeErr := c.HTMLBlob(http.StatusInternalServerError, page.Bytes()); writeErr != nil {
					c.Logger().Errorf("Error writing HTML error page: %s", writeErr)
				}
			}()
			return next(c)
		}
	}
}