aboutsummaryrefslogtreecommitdiff
path: root/net
diff options
context:
space:
mode:
authorMike Crute <mike@crute.us>2023-01-07 20:00:42 -0800
committerMike Crute <mike@crute.us>2023-01-07 20:23:30 -0800
commit6cb3c4271d3126697c8917bce9312603fa6607ca (patch)
treed3492fe7a59eb0ac3f7083061008ff2d60d18349 /net
parent85c67dd133605d2b1125a1b0e8c28f991223aa75 (diff)
downloadgolib-6cb3c4271d3126697c8917bce9312603fa6607ca.tar.bz2
golib-6cb3c4271d3126697c8917bce9312603fa6607ca.tar.xz
golib-6cb3c4271d3126697c8917bce9312603fa6607ca.zip
net/http: add accept parserv0.5.2
Diffstat (limited to 'net')
-rw-r--r--net/http/accept.go214
-rw-r--r--net/http/accept_test.go147
2 files changed, 361 insertions, 0 deletions
diff --git a/net/http/accept.go b/net/http/accept.go
new file mode 100644
index 0000000..cabfc48
--- /dev/null
+++ b/net/http/accept.go
@@ -0,0 +1,214 @@
1package http
2
3import (
4 "fmt"
5 "mime"
6 "reflect"
7 "regexp"
8 "sort"
9 "strconv"
10 "strings"
11)
12
13type MediaType struct {
14 Type string
15 Subtype string
16 Parameters map[string]string
17 Weight float64
18 originalQ string
19}
20
21func ParseMediaType(v string) (*MediaType, error) {
22 mt, params, err := mime.ParseMediaType(v)
23 if err != nil {
24 return nil, err
25 }
26
27 majorMinor := strings.Split(mt, "/")
28 if len(majorMinor) != 2 {
29 return nil, fmt.Errorf("Invalid major/minor media type: %s", mt)
30 }
31
32 // No q should be weight 1.0, per spec
33 q := float64(1)
34 sq, ok := params["q"]
35 if ok {
36 delete(params, "q")
37 q, err = parseQ(sq)
38 if err != nil {
39 return nil, err
40 }
41 }
42
43 return &MediaType{
44 Type: majorMinor[0],
45 Subtype: majorMinor[1],
46 Parameters: params,
47 Weight: q,
48 originalQ: sq,
49 }, nil
50}
51
52func (m MediaType) String() string {
53 b := strings.Builder{}
54 b.WriteString(m.Type + "/" + m.Subtype)
55
56 params := []string{}
57 for k, v := range m.Parameters {
58 params = append(params, fmt.Sprintf("%s=%s", k, v))
59 }
60
61 // Keep them in the same order so they're comparable
62 sort.Strings(params)
63
64 // Q should always be last per RFC9110
65 if m.originalQ != "" {
66 params = append(params, fmt.Sprintf("q=%s", m.originalQ))
67 }
68
69 if len(params) > 0 {
70 b.WriteString(";")
71 b.WriteString(strings.Join(params, ";"))
72 }
73
74 return b.String()
75}
76
77func (m MediaType) Specificity() int {
78 s := 0
79
80 if m.Type != "*" {
81 s += 1
82 }
83 if m.Subtype != "*" {
84 s += 1
85 }
86 if m.Parameters != nil {
87 s += len(m.Parameters)
88 }
89
90 return s
91}
92
93func (m MediaType) Satisfies(v MediaType) bool {
94 if m.Equal(v) {
95 return true
96 }
97
98 if m.Type != v.Type && m.Type != "*" && v.Type != "*" {
99 return false
100 }
101
102 if m.Subtype != v.Subtype && m.Subtype != "*" && v.Subtype != "*" {
103 return false
104 }
105
106 return reflect.DeepEqual(m.Parameters, v.Parameters)
107}
108
109func (m MediaType) Equal(v MediaType) bool {
110 return m.Type == v.Type &&
111 m.Subtype == v.Subtype &&
112 reflect.DeepEqual(m.Parameters, v.Parameters)
113}
114
115type AcceptableTypes []*MediaType
116
117func (a AcceptableTypes) Sorted() AcceptableTypes { sort.Stable(sort.Reverse(a)); return a }
118func (a AcceptableTypes) Len() int { return len(a) }
119func (a AcceptableTypes) Less(i, j int) bool { return a[i].Weight < a[j].Weight }
120func (a AcceptableTypes) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
121
122// FindMatch returns the MediaType in the set that has the best match
123// for the passed media types as well as the media type that determined
124// the match. According to the rules of RFC7231, the most specific match
125// wins in order of weight. This function assumes that both values are
126// sorted.
127func (a AcceptableTypes) FindMatch(values AcceptableTypes) (match *MediaType, matcher *MediaType) {
128 if len(a) == 0 || len(values) == 0 {
129 return nil, nil
130 }
131
132 candidates := AcceptableTypes{}
133
134 // Return the highest precedence match for the type. If there is no
135 // exact match then the most specific match with the highest precedence
136 // should win.
137 for _, matcher = range values {
138 for _, match = range a {
139 if match.Equal(*matcher) {
140 return match, matcher
141 }
142
143 if match.Satisfies(*matcher) {
144 candidates = append(candidates, match)
145 }
146 }
147
148 if len(candidates) != 0 {
149 break
150 }
151 }
152
153 if len(candidates) == 0 {
154 return nil, nil
155 }
156
157 // Sort ascending by specificity
158 sort.SliceStable(candidates, func(i, j int) bool {
159 return candidates[i].Specificity() < candidates[j].Specificity()
160 })
161
162 return candidates[len(candidates)-1], matcher
163}
164
165// ParseAccept parses a set of Accept headers and scores them per the
166// rules in RFC7231, returning a slice of headers in descending order of
167// priority. If the Accept header occurs multiple times in the request
168// the result of all headers will be combined and scored together as if
169// they were all in one header line.
170//
171// See: https://tools.ietf.org/html/rfc7231#section-5.3.2
172func ParseAccept(values []string) (AcceptableTypes, error) {
173 all := AcceptableTypes{}
174
175 for _, l := range values {
176 t, err := parseAcceptLine(l)
177 if err != nil {
178 return nil, err
179 }
180 all = append(all, t...)
181 }
182
183 return all.Sorted(), nil
184}
185
186func parseAcceptLine(l string) (AcceptableTypes, error) {
187 out := AcceptableTypes{}
188
189 for _, t := range strings.Split(l, ",") {
190 mt, err := ParseMediaType(t)
191 if err != nil {
192 return nil, err
193 }
194 out = append(out, mt)
195 }
196
197 return out, nil
198}
199
200// https://tools.ietf.org/html/rfc7231#section-5.3.1
201var validateQ = regexp.MustCompile(`(0\.[0-9]{1,3}|1\.0{1,3})$`)
202
203func parseQ(v string) (float64, error) {
204 if !validateQ.Match([]byte(v)) {
205 return 0.0, fmt.Errorf("Invalid format for Q")
206 }
207
208 f, err := strconv.ParseFloat(v, 64)
209 if err != nil {
210 return 0.0, err
211 }
212
213 return f, nil
214}
diff --git a/net/http/accept_test.go b/net/http/accept_test.go
new file mode 100644
index 0000000..4e209c2
--- /dev/null
+++ b/net/http/accept_test.go
@@ -0,0 +1,147 @@
1package http
2
3import (
4 "testing"
5
6 "github.com/stretchr/testify/assert"
7)
8
9var rfcTestLine = "text/*;q=0.3, text/html;q=0.7, text/html;level=1, text/html;level=2;q=0.4, */*;q=0.5"
10
11func TestParseAccept(t *testing.T) {
12 values, err := ParseAccept([]string{
13 "text/*;q=0.3, text/html;q=0.7",
14 "text/html;level=1, text/html;level=2;q=0.4",
15 "*/*;q=0.5",
16 })
17
18 assert.NoError(t, err)
19 assert.Len(t, values, 5)
20
21 expected := []MediaType{
22 MediaType{Type: "text", Subtype: "html", Weight: 1, Parameters: map[string]string{"level": "1"}},
23 MediaType{Type: "text", Subtype: "html", Weight: 0.7, Parameters: map[string]string{}, originalQ: "0.7"},
24 MediaType{Type: "*", Subtype: "*", Weight: 0.5, Parameters: map[string]string{}, originalQ: "0.5"},
25 MediaType{Type: "text", Subtype: "html", Weight: 0.4, Parameters: map[string]string{"level": "2"}, originalQ: "0.4"},
26 MediaType{Type: "text", Subtype: "*", Weight: 0.3, Parameters: map[string]string{}, originalQ: "0.3"},
27 }
28
29 for i, e := range values {
30 assert.Equal(t, *e, expected[i])
31 }
32}
33
34func TestParseQ(t *testing.T) {
35 var v float64
36 var err error
37
38 _, err = parseQ("1.12")
39 assert.ErrorContains(t, err, "Invalid format for Q")
40
41 _, err = parseQ("0.1234")
42 assert.ErrorContains(t, err, "Invalid format for Q")
43
44 v, err = parseQ("1.0")
45 assert.NoError(t, err)
46 assert.Equal(t, 1.0, v)
47
48 v, err = parseQ("0.003")
49 assert.NoError(t, err)
50 assert.Equal(t, 0.003, v)
51}
52
53func TestParseMediaType(t *testing.T) {
54 mt, err := ParseMediaType("text/plain;foo=bar;q=0.3")
55 assert.NoError(t, err)
56 assert.Equal(t, "text", mt.Type)
57 assert.Equal(t, "plain", mt.Subtype)
58 assert.Equal(t, map[string]string{"foo": "bar"}, mt.Parameters)
59 assert.Equal(t, 0.3, mt.Weight)
60
61 mt, err = ParseMediaType("text/plain")
62 assert.NoError(t, err)
63 assert.Equal(t, 1.0, mt.Weight)
64
65 mt, err = ParseMediaType("foo")
66 assert.ErrorContains(t, err, "Invalid major/minor")
67
68 mt, err = ParseMediaType("foo/bar;q=11")
69 assert.ErrorContains(t, err, "Invalid format for Q")
70}
71
72func TestMediaTypeString(t *testing.T) {
73 mt, err := ParseMediaType("text/plain;foo=bar;biz=baz;q=0.3")
74 assert.NoError(t, err)
75 assert.Equal(t, "text/plain;biz=baz;foo=bar;q=0.3", mt.String())
76
77 // Default q of 1 should not leak to String()
78 mt, err = ParseMediaType("text/plain;foo=bar;biz=baz")
79 assert.NoError(t, err)
80 assert.Equal(t, "text/plain;biz=baz;foo=bar", mt.String())
81}
82
83func TestMediaTypeSpecificity(t *testing.T) {
84 mt, err := ParseMediaType("text/plain;foo=bar;biz=baz")
85 assert.NoError(t, err)
86 assert.Equal(t, 4, mt.Specificity())
87
88 mt, err = ParseMediaType("text/*;foo=bar;biz=baz")
89 assert.NoError(t, err)
90 assert.Equal(t, 3, mt.Specificity())
91
92 mt, err = ParseMediaType("text/*")
93 assert.NoError(t, err)
94 assert.Equal(t, 1, mt.Specificity())
95
96 mt, err = ParseMediaType("*/*")
97 assert.NoError(t, err)
98 assert.Equal(t, 0, mt.Specificity())
99}
100
101func TestParseAcceptLine(t *testing.T) {
102 types, err := parseAcceptLine(rfcTestLine)
103
104 assert.NoError(t, err)
105 assert.Len(t, types, 5)
106
107 for i, e := range []MediaType{
108 MediaType{Type: "text", Subtype: "*", Weight: 0.3, Parameters: map[string]string{}, originalQ: "0.3"},
109 MediaType{Type: "text", Subtype: "html", Weight: 0.7, Parameters: map[string]string{}, originalQ: "0.7"},
110 MediaType{Type: "text", Subtype: "html", Weight: 1, Parameters: map[string]string{"level": "1"}},
111 MediaType{Type: "text", Subtype: "html", Weight: 0.4, Parameters: map[string]string{"level": "2"}, originalQ: "0.4"},
112 MediaType{Type: "*", Subtype: "*", Weight: 0.5, Parameters: map[string]string{}, originalQ: "0.5"},
113 } {
114 assert.Equal(t, e, *types[i])
115 }
116}
117
118func TestFindMatch(t *testing.T) {
119 ours, err := ParseAccept([]string{
120 "text/html;level=1, text/html, text/plain, image/jpeg",
121 "text/html;level=2, text/html;level=3",
122 })
123 assert.NoError(t, err)
124
125 theirs, err := ParseAccept([]string{rfcTestLine})
126 assert.NoError(t, err)
127
128 match, matcher := ours.FindMatch(theirs)
129 assert.Equal(t, "text/html;level=1", match.String())
130 assert.Equal(t, "text/html;level=1", matcher.String())
131
132 match, matcher = ours.FindMatch(theirs[1:])
133 assert.Equal(t, "text/html", match.String())
134 assert.Equal(t, "text/html;q=0.7", matcher.String())
135
136 match, matcher = ours.FindMatch(theirs[2:])
137 assert.Equal(t, "image/jpeg", match.String())
138 assert.Equal(t, "*/*;q=0.5", matcher.String())
139
140 match, matcher = ours.FindMatch(theirs[3:])
141 assert.Equal(t, "text/html;level=2", match.String())
142 assert.Equal(t, "text/html;level=2;q=0.4", matcher.String())
143
144 match, matcher = ours.FindMatch(theirs[4:])
145 assert.Equal(t, "text/plain", match.String())
146 assert.Equal(t, "text/*;q=0.3", matcher.String())
147}