// Copyright 2015 The Go Authors. All rights reserved. // // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file or at // https://developers.google.com/open-source/licenses/bsd. // This file implements a http.RoundTripper that authenticates // requests issued against api.github.com endpoint. package httputil import ( "net/http" "net/url" ) // AuthTransport is an implementation of http.RoundTripper that authenticates // with the GitHub API. // // When both a token and client credentials are set, the latter is preferred. type AuthTransport struct { UserAgent string GithubToken string GithubClientID string GithubClientSecret string Base http.RoundTripper } // RoundTrip implements the http.RoundTripper interface. func (t *AuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { var reqCopy *http.Request if t.UserAgent != "" { reqCopy = copyRequest(req) reqCopy.Header.Set("User-Agent", t.UserAgent) } if req.URL.Host == "api.github.com" && req.URL.Scheme == "https" { switch { case t.GithubClientID != "" && t.GithubClientSecret != "": if reqCopy == nil { reqCopy = copyRequest(req) } if reqCopy.URL.RawQuery == "" { reqCopy.URL.RawQuery = "client_id=" + t.GithubClientID + "&client_secret=" + t.GithubClientSecret } else { reqCopy.URL.RawQuery += "&client_id=" + t.GithubClientID + "&client_secret=" + t.GithubClientSecret } case t.GithubToken != "": if reqCopy == nil { reqCopy = copyRequest(req) } reqCopy.Header.Set("Authorization", "token "+t.GithubToken) } } if reqCopy != nil { return t.base().RoundTrip(reqCopy) } return t.base().RoundTrip(req) } // CancelRequest cancels an in-flight request by closing its connection. func (t *AuthTransport) CancelRequest(req *http.Request) { type canceler interface { CancelRequest(req *http.Request) } if cr, ok := t.base().(canceler); ok { cr.CancelRequest(req) } } func (t *AuthTransport) base() http.RoundTripper { if t.Base != nil { return t.Base } return http.DefaultTransport } func copyRequest(req *http.Request) *http.Request { req2 := new(http.Request) *req2 = *req req2.URL = new(url.URL) *req2.URL = *req.URL req2.Header = make(http.Header, len(req.Header)) for k, s := range req.Header { req2.Header[k] = append([]string(nil), s...) } return req2 }