aboutsummaryrefslogtreecommitdiff
path: root/httputil/transport.go
blob: fdad3b47809911024dd52510843e1d2b353145e9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
// Copyright 2015 The Go Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd.

// This file implements an http.RoundTripper that authenticates requests
// issued against the api.github.com endpoint and can optionally set a custom
// User-Agent header on every request.

package httputil

import (
	"net/http"
	"net/url"
)

// AuthTransport is an implementation of http.RoundTripper that authenticates
// with the GitHub API.
//
// GitHub credentials are only attached to requests whose URL scheme is
// "https" and whose URL host is exactly "api.github.com"; all other requests
// are forwarded without them.
//
// When both a token and client credentials are set, the latter is preferred.
type AuthTransport struct {
	// UserAgent, if non-empty, is set as the User-Agent header of every
	// request sent through the transport, regardless of its host.
	UserAgent string

	// GithubToken, if non-empty, is sent in the Authorization header as
	// "token <GithubToken>". It is ignored when both GithubClientID and
	// GithubClientSecret are set.
	GithubToken string

	// GithubClientID is sent as the client_id query parameter. It is only
	// used when GithubClientSecret is also non-empty.
	GithubClientID string

	// GithubClientSecret is sent as the client_secret query parameter. It is
	// only used when GithubClientID is also non-empty.
	GithubClientSecret string

	// Base is the transport that actually sends requests. If nil,
	// http.DefaultTransport is used.
	Base http.RoundTripper
}

// RoundTrip implements the http.RoundTripper interface.
//
// The caller's request is never modified. Whenever a header or the query
// string must change, a copy is made with copyRequest and the copy is sent to
// the base transport; otherwise req is forwarded as is.
//
// If UserAgent is non-empty it replaces the User-Agent header of every
// request. GitHub credentials are only attached when the URL scheme is
// "https" and URL.Host is exactly "api.github.com" (so a host written with an
// explicit port, such as "api.github.com:443", does not match):
//
//   - when GithubClientID and GithubClientSecret are both non-empty, they are
//     appended, query-escaped, as the client_id and client_secret parameters;
//   - otherwise, when GithubToken is non-empty, it is sent in the
//     Authorization header as "token <GithubToken>".
func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	var reqCopy *http.Request
	if t.UserAgent != "" {
		reqCopy = copyRequest(req)
		reqCopy.Header.Set("User-Agent", t.UserAgent)
	}
	if req.URL.Host == "api.github.com" && req.URL.Scheme == "https" {
		switch {
		case t.GithubClientID != "" && t.GithubClientSecret != "":
			if reqCopy == nil {
				reqCopy = copyRequest(req)
			}
			// Escape the credentials so that characters such as '&', '=',
			// '+', '#' or '%' cannot corrupt the query string or inject
			// additional parameters.
			//
			// NOTE(review): GitHub has announced the deprecation of passing
			// client_id/client_secret as query parameters in favour of HTTP
			// Basic auth; confirm this path is still accepted by the API.
			creds := "client_id=" + url.QueryEscape(t.GithubClientID) +
				"&client_secret=" + url.QueryEscape(t.GithubClientSecret)
			if reqCopy.URL.RawQuery == "" {
				reqCopy.URL.RawQuery = creds
			} else {
				reqCopy.URL.RawQuery += "&" + creds
			}
		case t.GithubToken != "":
			if reqCopy == nil {
				reqCopy = copyRequest(req)
			}
			reqCopy.Header.Set("Authorization", "token "+t.GithubToken)
		}
	}
	if reqCopy != nil {
		return t.base().RoundTrip(reqCopy)
	}
	return t.base().RoundTrip(req)
}

// CancelRequest asks the base transport to cancel req, provided the base
// transport has a CancelRequest(*http.Request) method. If it does not, the
// call is a no-op.
//
// NOTE(review): RoundTrip sends a modified copy of the request (not req
// itself) whenever UserAgent or GitHub credentials apply, yet it is req that
// is forwarded here. A base transport that tracks in-flight requests by
// pointer may therefore not find anything to cancel. Confirm, and prefer
// cancelling through the request's context where available.
func (t *AuthTransport) CancelRequest(req *http.Request) {
	c, ok := t.base().(interface {
		CancelRequest(req *http.Request)
	})
	if !ok {
		return
	}
	c.CancelRequest(req)
}

// base returns the transport that requests are delegated to. It is t.Base
// when set, and http.DefaultTransport otherwise.
func (t *AuthTransport) base() http.RoundTripper {
	if t.Base == nil {
		return http.DefaultTransport
	}
	return t.Base
}

// copyRequest returns a copy of req whose URL and Header (including each
// header's value slice) can be modified without affecting req. All other
// fields are copied shallowly, so the Body and any other reference-typed
// fields are shared with the original request.
func copyRequest(req *http.Request) *http.Request {
	headers := make(http.Header, len(req.Header))
	for name, values := range req.Header {
		headers[name] = append([]string(nil), values...)
	}

	u := new(url.URL)
	*u = *req.URL

	dup := *req
	dup.URL = u
	dup.Header = headers
	return &dup
}