aboutsummaryrefslogtreecommitdiff
path: root/echo/static_file.go
blob: 9723db781631d317b8888990ffa515cf43dae0b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
package echo

import (
	"io"
	"io/fs"
	"net/http"
	"net/url"
	"path/filepath"

	"github.com/labstack/echo/v4"
)

// routeFunc is the route-registration callback StaticFS uses to add its
// routes. Its signature matches the GET methods of *echo.Echo and
// *echo.Group, so either can be passed directly.
type routeFunc func(string, echo.HandlerFunc, ...echo.MiddlewareFunc) *echo.Route

// StaticFS registers routes (via get) that serve files from f under the URL
// prefix, and returns the route created for the wildcard pattern.
//
// root is the directory inside f to serve from. It is joined with the cleaned
// request path, so it should be a valid fs.FS name (slash-separated, not
// rooted). An empty root means the root of f.
//
// Which routes are registered depends on a trailing slash in prefix:
//
//	"/prefix"  => exact route "/prefix" + wildcard route "/prefix/*"
//	"/prefix/" => wildcard route "/prefix/*" only
//	""         => wildcard route "/*" only
//
// Requests that resolve to a directory, or to nothing, get the not-found
// handler. Files are served with http.ServeContent.
func StaticFS(get routeFunc, f fs.FS, prefix, root string, m ...echo.MiddlewareFunc) *echo.Route {
	if root == "" {
		// Default to the root of f. The request path is cleaned against "/"
		// before joining, so it can never climb above root.
		root = "."
	}

	h := func(c echo.Context) error {
		p, err := url.PathUnescape(c.Param("*"))
		if err != nil {
			return err
		}

		// Cleaning "/"+p as a rooted path drops any leading ".." elements,
		// so the result stays inside root. fs.FS names must use forward
		// slashes, so undo filepath's OS separator (needed on Windows,
		// a no-op elsewhere).
		name := filepath.ToSlash(filepath.Join(root, filepath.Clean("/"+p)))
		file, err := f.Open(name)
		if err != nil {
			// The access path does not exist (or is not a valid fs.FS name).
			return echo.NotFoundHandler(c)
		}
		defer file.Close()

		fi, err := file.Stat()
		if err != nil {
			// The access path does not exist
			return echo.NotFoundHandler(c)
		}

		if fi.IsDir() {
			// Directories are never served, with or without a trailing slash:
			// a directory handle is not a file body http.ServeContent can send.
			// TODO: redirect "dir" to "dir/" and serve dir/index.html if present.
			return echo.NotFoundHandler(c)
		}

		// http.ServeContent needs Seek to determine size and serve Range requests.
		rs, ok := file.(io.ReadSeeker)
		if !ok {
			c.Logger().Errorf("File %s is not a io.ReadSeeker", name)
			return echo.ErrInternalServerError
		}

		http.ServeContent(c.Response(), c.Request(), fi.Name(), fi.ModTime(), rs)
		return nil
	}

	if prefix != "" {
		if prefix[len(prefix)-1] == '/' {
			// Intentional trailing slash: only the wildcard route.
			return get(prefix+"*", h, m...)
		}
		// No trailing slash: also match the bare prefix itself.
		get(prefix, h, m...)
	}

	return get(prefix+"/*", h, m...)
}