blob: 2a59345aa23d284c64c4a005c0667b589a864951 [file] [log] [blame]
package mockhttpclient
import (
"bytes"
"errors"
"fmt"
"net/http"
"net/url"
"runtime"
"strings"
"github.com/go-chi/chi/v5"
expect "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.skia.org/infra/go/util"
)
// muxClient implements http.RoundTripper by handing each request directly to a chi.Router
// instead of sending it over the network.
type muxClient struct {
// router is the chi.Router whose ServeHTTP method handles every request given to RoundTrip.
router chi.Router
}
// muxClientNotFoundHandler reports client requests that no route of the router matched. It
// replies with TEST_FAILED_STATUS_CODE and a body naming the unmatched URL; muxClient.RoundTrip
// turns a response carrying that status code into an error whose text is the body.
func muxClientNotFoundHandler(w http.ResponseWriter, r *http.Request) {
	msg := fmt.Sprintf("No matching handler for %s", r.URL.String())
	http.Error(w, msg, TEST_FAILED_STATUS_CODE)
}
// SchemeMatcher returns a middleware that forwards a request to the next handler only when the
// scheme of the request URL equals the given scheme. Every other request is answered with a 404.
func SchemeMatcher(scheme string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		checkScheme := func(w http.ResponseWriter, r *http.Request) {
			if r.URL.Scheme != scheme {
				http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
				return
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(checkScheme)
	}
}
// HostMatcher returns a middleware that forwards a request to the next handler only when the
// request is addressed to the given host. Every other request is answered with a 404.
//
// The request's Host field is compared first. For client requests that field is optional, so
// when it is empty the host of the request URL is compared instead; a request built without
// http.NewRequest therefore still matches on its URL. A non-empty Host always takes precedence
// over the URL.
func HostMatcher(host string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			matched := r.Host == host
			// Fall back to the URL only when Host was left unset, so that every request
			// which matched on Host before still matches.
			if !matched && r.Host == "" && r.URL != nil {
				matched = r.URL.Host == host
			}
			if !matched {
				http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// QueryMatcher returns a middleware that forwards a request to the next handler only when the
// query string contains every given key with the given value. The arguments are alternating keys
// and values, so QueryMatcher("name", "foo", "size", "42") matches
// "example.com/hello?name=foo&size=42" but not "example.com/hello?name=bar&length=123".
//
// A key must be present to match, even when the expected value is empty, and only the first
// value of a repeated key is compared. Requests that do not match are answered with a 404; a
// query string that cannot be parsed is answered with a 500.
//
// QueryMatcher panics if it is called with an odd number of arguments.
func QueryMatcher(pairs ...string) func(http.Handler) http.Handler {
	if len(pairs)%2 != 0 {
		panic("the number of arguments must be even")
	}
	return func(next http.Handler) http.Handler {
		checkQuery := func(w http.ResponseWriter, r *http.Request) {
			query, err := url.ParseQuery(r.URL.RawQuery)
			if err != nil {
				http.Error(w, fmt.Sprintf("error parsing query: %s", err), http.StatusInternalServerError)
				return
			}
			for idx := 0; idx < len(pairs); idx += 2 {
				key, want := pairs[idx], pairs[idx+1]
				// Has is checked separately because Get returns "" for a missing key,
				// which would otherwise match an expected empty value.
				if !query.Has(key) || query.Get(key) != want {
					http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
					return
				}
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(checkQuery)
	}
}
// NewMuxClient returns an http.Client whose transport hands every request to the given
// chi.Router instead of sending it over the network. It also installs its own NotFound handler
// on r, so a request that matches no route makes the client call return an error.
//
// NewMuxClient is more flexible than using httptest.NewServer because the returned client can
// accept requests for any scheme or host. It is more flexible than URLMock because it allows
// handling arbitrary URLs with arbitrary handlers. However, it is more difficult to use than
// URLMock when the same request URL should be handled differently on subsequent requests.
//
// TODO(benjaminwagner): NewMuxClient does not currently support streaming responses, but does
// support streaming requests.
//
// Examples:
//
//	// Mock out a URL to always respond with the same body.
//	r := chi.NewRouter()
//	r.With(SchemeMatcher("https"), HostMatcher("www.google.com")).
//		Get("/", MockGetDialogue([]byte("Here's a response.")).ServeHTTP)
//	client := NewMuxClient(r)
//	res, _ := client.Get("https://www.google.com")
//	respBody, _ := io.ReadAll(res.Body) // respBody == []byte("Here's a response.")
//
//	// Check that the client uses the correct ID in the request.
//	r.With(HostMatcher("example.com"), QueryMatcher("name", "foo", "size", "42")).
//		Post("/add/{id:[a-zA-Z0-9]+}", func(w http.ResponseWriter, r *http.Request) {
//			t := MuxSafeT(t)
//			assert.Equal(t, chi.URLParam(r, "id"), r.URL.Query().Get("name"))
//		})
func NewMuxClient(r chi.Router) *http.Client {
	// Unmatched requests must surface as errors rather than as ordinary 404 responses.
	r.NotFound(http.HandlerFunc(muxClientNotFoundHandler))
	return &http.Client{
		Transport: &muxClient{router: r},
	}
}
// responseWriter implements http.ResponseWriter for handlers and provides an http.Response for
// clients. Everything a handler writes is buffered in memory.
type responseWriter struct {
// resp is the response being assembled: Header exposes its header map, WriteHeader sets its
// status code, and muxClient.RoundTrip returns a pointer to it.
resp http.Response
// body collects the bytes passed to Write.
body bytes.Buffer
}
// newResponseWriter returns a responseWriter whose response starts with status 200, an empty
// header map, an unknown content length (-1), and a Body backed by the writer's own buffer.
func newResponseWriter() *responseWriter {
	rw := new(responseWriter)
	rw.resp = http.Response{
		StatusCode:    http.StatusOK,
		Header:        http.Header{},
		Body:          &respBodyCloser{&rw.body},
		ContentLength: -1,
	}
	return rw
}
// Header implements http.ResponseWriter. It returns the header map of the response being built,
// so changes made by the handler appear on the response handed to the client.
func (w *responseWriter) Header() http.Header {
	headers := w.resp.Header
	return headers
}
// Write implements http.ResponseWriter. It appends data to the in-memory response body and
// returns whatever the underlying buffer reports.
func (w *responseWriter) Write(data []byte) (int, error) {
	written, err := w.body.Write(data)
	return written, err
}
// WriteHeader implements http.ResponseWriter. It records code as the status of the response;
// each call overwrites the status recorded by any earlier call.
func (w *responseWriter) WriteHeader(code int) {
	resp := &w.resp
	resp.StatusCode = code
}
// RoundTrip implements http.RoundTripper by serving req with the chi.Router instead of sending
// it over the network.
//
// The request must have a URL with a scheme and a host, and a non-nil Header; otherwise an error
// is returned. An empty Method is replaced with GET and an empty URL path with "/". The request
// body, if any, is closed before RoundTrip returns.
//
// An error is returned instead of a response when the handler replies with
// TEST_FAILED_STATUS_CODE (the error text is the response body) or panics with a
// muxClientFailNowValue (see muxSafeT.FailNow). Any other panic is propagated unchanged.
func (m *muxClient) RoundTrip(req *http.Request) (resp *http.Response, err error) {
	// Registered first so that it runs last: the body is closed on every exit path.
	defer func() {
		if body := req.Body; body != nil {
			util.Close(body)
		}
	}()

	invalid := req.URL == nil || req.URL.Scheme == "" || req.URL.Host == "" || req.Header == nil
	if invalid {
		return nil, fmt.Errorf("invalid request; URL: %#v Header: %#v", req.URL, req.Header)
	}
	if req.Method == "" {
		req.Method = http.MethodGet
	}
	if req.URL.Path == "" {
		req.URL.Path = "/"
	}

	rec := newResponseWriter()

	// Translate a muxClientFailNowValue panic raised by the handler into the returned error.
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		failure, isTestFailure := recovered.(muxClientFailNowValue)
		if !isTestFailure {
			// Not ours: let the panic continue up the stack.
			panic(recovered)
		}
		var loc string
		if failure.file != "" {
			loc = fmt.Sprintf("at %s:%d ", failure.file, failure.line)
		}
		err = fmt.Errorf("Test failed %swhile handling HTTP request for %s", loc, req.URL.String())
	}()

	m.router.ServeHTTP(rec, req)
	if rec.resp.StatusCode == TEST_FAILED_STATUS_CODE {
		return nil, errors.New(rec.body.String())
	}
	rec.resp.Request = req
	return &rec.resp, nil
}
// muxSafeT implements assert.TestingT (aka require.TestingT) but allows muxClient to translate
// FailNow into a regular error. This is necessary because some users of http.Client behave badly
// when runtime.Goexit() (called by testing.T.FailNow) is called within muxClient.RoundTrip.
//
// Its FailNow method panics with a muxClientFailNowValue, which muxClient.RoundTrip recovers and
// converts into the error it returns.
type muxSafeT struct {
// The embedded TestingT is the wrapped test handle; every method that muxSafeT does not
// define itself is promoted from it.
expect.TestingT
}
// MuxSafeT wraps *testing.T so that the assert and require packages can be used inside handler
// functions of the chi.Router passed to NewMuxClient. This is necessary because some users of
// http.Client behave badly when runtime.Goexit() (called by testing.T.FailNow) occurs during a
// request.
//
// The documentation for testing.T.FailNow states "FailNow must be called from the goroutine running
// the test or benchmark function, not from other goroutines created during the test," so if the
// http.Client returned from NewMuxClient is used by a different goroutine, you should use
// MuxSafeT to ensure the test doesn't hang.
func MuxSafeT(orig expect.TestingT) require.TestingT {
	var wrapped require.TestingT = muxSafeT{TestingT: orig}
	return wrapped
}
// muxClientFailNowValue is the panic value with which muxSafeT.FailNow tells muxClient.RoundTrip
// that the test has failed. It records the file and line where the failure occurred.
type muxClientFailNowValue struct {
// file is the base name of the source file of the failure; it is empty when the caller could
// not be determined.
file string
// line is the line number within file; it is zero when the caller could not be determined.
line int
}
// FailNow implements the FailNow method of require.TestingT. Instead of ending the goroutine it
// panics with a muxClientFailNowValue, which muxClient.RoundTrip recovers and turns into an
// error. The value carries the base file name and line of the caller three frames up, or is left
// empty when that caller cannot be determined.
func (muxSafeT) FailNow() {
	var failure muxClientFailNowValue
	// 3 frames up seems to give the correct spot.
	if _, file, line, ok := runtime.Caller(3); ok {
		// Keep only the part after the last slash.
		if idx := strings.LastIndex(file, "/"); idx >= 0 {
			file = file[idx+1:]
		}
		failure.file = file
		failure.line = line
	}
	panic(failure)
}