blob: 95413b4f6cb340c95bf2b4b02db004bb426776aa [file] [log] [blame]
package httputils
import (
func TestResponse2xxOnly(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
code, err := strconv.Atoi(r.URL.Query().Get("code"))
require.NoError(t, err)
defer s.Close()
test := func(c *http.Client, code int, expectError bool) {
resp, err := c.Get(s.URL + "/get?code=" + strconv.Itoa(code))
if expectError {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, code, resp.StatusCode)
c := s.Client()
test(c, http.StatusSwitchingProtocols, false)
test(c, http.StatusOK, false)
test(c, http.StatusNotModified, false)
test(c, http.StatusNotFound, false)
test(c, http.StatusServiceUnavailable, false)
c = Response2xxOnly(c)
test(c, http.StatusSwitchingProtocols, true)
test(c, http.StatusOK, false)
test(c, http.StatusNotModified, true)
test(c, http.StatusNotFound, true)
test(c, http.StatusServiceUnavailable, true)
var (
	// mockRoundTripErr is the error MockRoundTripper.RoundTrip returns for a scripted 0 code.
	mockRoundTripErr = errors.New("Can not round trip on a one-way street.")
)

// MockRoundTripper is an http.RoundTripper test double that replays a scripted
// sequence of outcomes without performing any network I/O.
type MockRoundTripper struct {
	// responseCodes gives the expected response for subsequent requests. The last response code is
	// repeated for subsequent requests. 0 means return mockRoundTripErr. You must set this field to a
	// non-empty slice before RoundTrip is called.
	responseCodes []int
}

// RoundTrip consumes the next scripted code, keeping the final code once only one remains.
// For a 0 code it returns (nil, mockRoundTripErr). Otherwise it returns a fresh response
// whose StatusCode is that code. The request itself is ignored.
//
// It panics if responseCodes is empty. Non-zero codes must be valid HTTP status codes.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	code := m.responseCodes[0]
	if len(m.responseCodes) > 1 {
		m.responseCodes = m.responseCodes[1:]
	}
	if code == 0 {
		return nil, mockRoundTripErr
	}
	rec := httptest.NewRecorder()
	// A fresh recorder reports 200 OK unless told otherwise, so the scripted code must be
	// written explicitly for callers to observe it.
	rec.WriteHeader(code)
	return rec.Result(), nil
}
// TestBackoffTransport exercises the transport built by NewConfiguredBackOffTransport
// against a scripted MockRoundTripper. Each case checks that the outcome matches the last
// scripted code: mockRoundTripErr for 0, otherwise that status code. Cases with more than
// one scripted code also assert a lower bound on the elapsed time.
// NOTE(review): this view appears to be missing lines. Absent are the closing braces for
// the config literal, the test closure and its if/else, and the test calls under the
// "No retries." and "Retries exhausted for server error." comments. Brace placement decides
// whether the elapsed-time check also runs for error outcomes; confirm against the full file.
func TestBackoffTransport(t *testing.T) {
// Use a fail-faster config so the test doesn't take so long.
maxInterval := 600 * time.Millisecond
config := &BackOffConfig{
initialInterval: INITIAL_INTERVAL,
maxInterval: maxInterval,
// Tests below expect at least three retries.
maxElapsedTime: 3 * maxInterval,
randomizationFactor: RANDOMIZATION_FACTOR,
backOffMultiplier: BACKOFF_MULTIPLIER,
// A single mock is shared by every case; test reprograms its responseCodes each time.
wrapped := &MockRoundTripper{}
bt := NewConfiguredBackOffTransport(config, wrapped)
// test takes a slice of response codes for the server to respond with (the last being repeated)
// and verifies that the response code from BackoffTransport is equal to the final value in codes.
// A 0 code means the RoundTripper returns an error.
test := func(codes []int) {
wrapped.responseCodes = codes
// NOTE(review): empty URL. The wrapped MockRoundTripper never reads the request;
// confirm the backoff transport does not either.
r, err := http.NewRequest("GET", "", nil)
require.NoError(t, err)
// Time the whole call so any delay between attempts is captured.
now := time.Now()
resp, err := bt.RoundTrip(r)
dur := time.Since(now)
expected := codes[len(codes)-1]
if expected == 0 {
require.Equal(t, mockRoundTripErr, err)
} else {
require.NoError(t, err)
require.Equal(t, codes[len(codes)-1], resp.StatusCode)
if len(codes) > 1 {
// There's not much we can assert other than there's a delay of at least
// (INITIAL_INTERVAL * (1 - RANDOMIZATION_FACTOR)) after the first attempt.
minDur := time.Duration(float64(INITIAL_INTERVAL) * (1 - RANDOMIZATION_FACTOR))
require.Truef(t, dur >= minDur, "For codes %v, expected duration to be at least %d, but was %d", codes, minDur, dur)
// No retries.
// NOTE(review): no test call follows this comment in the visible code; confirm one exists.
// Some retries before non-retriable status code.
test([]int{http.StatusServiceUnavailable, http.StatusOK})
test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusNotFound})
test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusBadGateway, http.StatusNotModified})
// Retries exhausted for server error.
// NOTE(review): no test call follows this comment in the visible code; confirm one exists.
// Retry transport error.
test([]int{0, http.StatusOK})
test([]int{0, 0, http.StatusOK})
// Retries exhausted for transport error.
// The final 0 is repeated, so every retry after the 500 is a transport error.
test([]int{http.StatusInternalServerError, 0})
// RoundTripperFunc lets an ordinary function be used wherever an http.RoundTripper is
// expected.
type RoundTripperFunc func(req *http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by invoking the wrapped function with r and
// returning its results unchanged.
func (fn RoundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) {
	resp, err := fn(r)
	return resp, err
}
// TestBackoffTransportWithContext checks that canceling the request context stops the
// backoff transport from retrying. The context is canceled during the request with index
// cancelAfter, and no further calls to the wrapped transport are expected after that. The
// returned result must be the one scripted for that request.
// NOTE(review): in this view, the body of `wrapped` never calls cancel() and never
// increments callCount, yet the assertion callCount-1 == cancelAfter requires both. Several
// closing braces are also absent. Confirm against the full file.
func TestBackoffTransportWithContext(t *testing.T) {
// Use a fail-faster config so the test doesn't take so long.
maxInterval := 600 * time.Millisecond
config := &BackOffConfig{
initialInterval: INITIAL_INTERVAL,
maxInterval: maxInterval,
// We should never reach this deadline.
maxElapsedTime: 10 * maxInterval,
randomizationFactor: RANDOMIZATION_FACTOR,
backOffMultiplier: BACKOFF_MULTIPLIER,
// Test canceling the context after the nth request. See MockRoundTripper docs for codes;
// len(codes) > cancelAfter. Request context will be canceled during the request with index
// cancelAfter. Asserts that the number of retries agrees with cancelAfter.
test := func(codes []int, cancelAfter int) {
// A fresh mock per case, scripted with this case's codes.
mock := MockRoundTripper{
responseCodes: codes,
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Number of calls made to wrapped; the assertion below expects cancelAfter+1 of them.
callCount := 0
// wrapped injects the cancellation, then delegates to the mock.
wrapped := func(req *http.Request) (*http.Response, error) {
// NOTE(review): the body of this if (presumably cancel()) and the callCount
// increment are not visible here.
if cancelAfter == callCount {
return mock.RoundTrip(req)
bt := NewConfiguredBackOffTransport(config, RoundTripperFunc(wrapped))
req, err := http.NewRequestWithContext(ctx, "GET", "", nil)
require.NoError(t, err)
resp, err := bt.RoundTrip(req)
// We expect no calls after the context is canceled.
require.Equal(t, cancelAfter, callCount-1)
// We expect the result to be the result of the call when the context is canceled.
expected := codes[cancelAfter]
if expected == 0 {
require.Equal(t, mockRoundTripErr, err)
} else {
require.NoError(t, err)
require.Equal(t, expected, resp.StatusCode)
// No retries needed.
test([]int{http.StatusOK}, 0)
// Context is canceled, so no retry.
test([]int{http.StatusServiceUnavailable}, 0)
// Second request should never happen.
test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError}, 0)
// Some retries before context canceled.
test([]int{http.StatusServiceUnavailable, http.StatusOK}, 1)
test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError}, 1)
test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusBadGateway}, 2)
// Transport error; context is canceled, so no retry.
test([]int{0}, 0)
// Transport error; some retries before context is canceled.
test([]int{0, 0}, 1)
test([]int{0, http.StatusOK}, 1)
test([]int{0, http.StatusInternalServerError}, 1)
test([]int{http.StatusInternalServerError, 0}, 1)
// TestForceHTTPS exercises a plain handler and then the same handler wrapped by
// HealthzAndHTTPS:
//   - the plain handler returns 200 with its 12-byte body and no Location header;
//   - the wrapped handler answers the same request with 301;
//   - the wrapped handler answers a request whose User-Agent is "GoogleHC/1.0" (the health
//     check) with 200 and an empty body, i.e. without the inner handler's output.
// NOTE(review): as shown, the request targets are "" and no scheme/forwarding header is set
// before the 301 check. The expected Location on the 301 is also "". httptest.NewRequest
// panics on a target it cannot parse, and an empty target is not a valid request URI. The
// URL literals and setup lines look stripped from this view; confirm against the full file.
func TestForceHTTPS(t *testing.T) {
// Inner handler: always writes the 12-byte body "Hello World!".
var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
_, err := io.WriteString(w, "Hello World!")
require.NoError(t, err)
// Test w/o ForceHTTPS in place.
r := httptest.NewRequest("GET", "", nil)
w := httptest.NewRecorder()
h.ServeHTTP(w, r)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, "", w.Result().Header.Get("Location"))
b, err := io.ReadAll(w.Result().Body)
require.NoError(t, err)
require.Len(t, b, 12)
// Add in ForceHTTPS behavior.
h = HealthzAndHTTPS(h)
// Reuse the same request with a fresh recorder.
w = httptest.NewRecorder()
h.ServeHTTP(w, r)
require.Equal(t, 301, w.Result().StatusCode)
require.Equal(t, "", w.Result().Header.Get("Location"))
// Test the healthcheck handling.
r = httptest.NewRequest("GET", "", nil)
r.Header.Set("User-Agent", "GoogleHC/1.0")
w = httptest.NewRecorder()
h.ServeHTTP(w, r)
require.Equal(t, 200, w.Result().StatusCode)
require.Equal(t, "", w.Result().Header.Get("Location"))
b, err = io.ReadAll(w.Result().Body)
require.NoError(t, err)
// An empty body means the inner handler's output was not written.
require.Len(t, b, 0)
func TestGetWithContextSunnyDay(t *testing.T) {
content := []byte("something")
m := mockhttpclient.NewURLMock()
resp := mockhttpclient.MockGetDialogue(content)
m.Mock("", resp)
r, err := GetWithContext(context.Background(), m.Client(), "")
require.NoError(t, err)
msg, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, content, msg)
require.NoError(t, r.Body.Close())
func TestGetWithContextCanceled(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
_, err := GetWithContext(ctx, http.DefaultClient, "")
require.Error(t, err)
assert.Contains(t, err.Error(), "canceled")
func TestPostWithContextSunnyDay(t *testing.T) {
const mimeType = "text/plain"
const input = "something"
output := []byte("different")
m := mockhttpclient.NewURLMock()
resp := mockhttpclient.MockPostDialogue(mimeType, []byte(input), output)
m.Mock("", resp)
r, err := PostWithContext(context.Background(), m.Client(), "", mimeType, strings.NewReader(input))
require.NoError(t, err)
msg, err := io.ReadAll(r.Body)
require.NoError(t, err)
assert.Equal(t, output, msg)
require.NoError(t, r.Body.Close())
func TestPostWithContextCancelled(t *testing.T) {
const mimeType = "text/plain"
const input = "something"
ctx, cancel := context.WithCancel(context.Background())
_, err := PostWithContext(ctx, http.DefaultClient, "", mimeType, strings.NewReader(input))
require.Error(t, err)
assert.Contains(t, err.Error(), "canceled")
func TestCrossOriginResourcePolicy_Success(t *testing.T) {
w := httptest.NewRecorder()
var h http.Handler
h = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
h = CrossOriginResourcePolicy(h)
r := httptest.NewRequest("GET", "/", nil)
h.ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "cross-origin", w.Header().Get("Cross-Origin-Resource-Policy"))