package httputils

import (
	"context"
	"errors"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"
	"time"

	assert "github.com/stretchr/testify/require"
	"go.skia.org/infra/go/testutils/unittest"
)

// TestResponse2xxOnly checks that an unwrapped client passes every status code
// through untouched, while a client wrapped by Response2xxOnly reports an error
// for every status tested here except 200.
func TestResponse2xxOnly(t *testing.T) {
	unittest.SmallTest(t)

	// The server replies with whatever status code is requested via ?code=.
	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		raw := r.URL.Query().Get("code")
		code, err := strconv.Atoi(raw)
		if err != nil {
			// This handler runs on a server goroutine. require's FailNow
			// (runtime.Goexit) must only be called from the test goroutine, so
			// record the failure with t.Errorf and answer with a 400 instead.
			t.Errorf("bad code parameter %q: %s", raw, err)
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		w.WriteHeader(code)
	}))
	defer s.Close()

	// test requests the given status from the server through c. It asserts
	// either that an error is returned or that the status is passed through.
	test := func(c *http.Client, code int, expectError bool) {
		resp, err := c.Get(s.URL + "/get?code=" + strconv.Itoa(code))
		if expectError {
			assert.Error(t, err)
			return
		}
		assert.NoError(t, err)
		// Deferred so the body is still drained/closed if the status
		// assertion below fails and aborts this goroutine.
		defer ReadAndClose(resp.Body)
		assert.Equal(t, code, resp.StatusCode)
	}

	c := s.Client()
	test(c, http.StatusSwitchingProtocols, false)
	test(c, http.StatusOK, false)
	test(c, http.StatusNotModified, false)
	test(c, http.StatusNotFound, false)
	test(c, http.StatusServiceUnavailable, false)

	c = Response2xxOnly(c)
	test(c, http.StatusSwitchingProtocols, true)
	test(c, http.StatusOK, false)
	test(c, http.StatusNotModified, true)
	test(c, http.StatusNotFound, true)
	test(c, http.StatusServiceUnavailable, true)
}

// mockRoundTripErr is the error MockRoundTripper.RoundTrip returns whenever
// the scripted response code is 0.
var mockRoundTripErr = errors.New("Can not round trip on a one-way street.")

// MockRoundTripper is an http.RoundTripper that replays a scripted sequence of
// status codes instead of performing any network I/O.
type MockRoundTripper struct {
	// responseCodes holds the outcome of each successive RoundTrip call. Once a
	// single entry remains, it is reused for every later call. A 0 entry makes
	// RoundTrip fail with mockRoundTripErr. It must be non-empty before the
	// first RoundTrip call; otherwise RoundTrip panics on the index.
	responseCodes []int
}

// RoundTrip ignores req and answers with the next scripted code. It returns
// mockRoundTripErr for a 0 code; otherwise it returns a recorder-built
// response carrying that status and no body.
func (m *MockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	next := m.responseCodes[0]
	// Consume the head but never the final element, so the last code repeats.
	if rest := m.responseCodes[1:]; len(rest) > 0 {
		m.responseCodes = rest
	}
	if next != 0 {
		rec := httptest.NewRecorder()
		rec.WriteHeader(next)
		return rec.Result(), nil
	}
	return nil, mockRoundTripErr
}

// TestBackoffTransport drives BackOffTransport with scripted outcomes from
// MockRoundTripper. Each case asserts that the final outcome is returned and
// that retried cases take at least the minimum backoff delay. The cases check
// that:
//   - 101/200/304/404 are returned without retrying;
//   - server errors (500/502/503) and transport errors are retried;
//   - retries give up once maxElapsedTime is exhausted.
func TestBackoffTransport(t *testing.T) {
	unittest.LargeTest(t) // BackoffTransport sleeps between requests.
	// Use a fail-faster config so the test doesn't take so long.
	maxInterval := 600 * time.Millisecond
	config := &BackOffConfig{
		initialInterval: INITIAL_INTERVAL,
		maxInterval:     maxInterval,
		// Tests below expect at least three retries.
		maxElapsedTime:      3 * maxInterval,
		randomizationFactor: RANDOMIZATION_FACTOR,
		backOffMultiplier:   BACKOFF_MULTIPLIER,
	}
	wrapped := &MockRoundTripper{}
	bt := NewConfiguredBackOffTransport(config, wrapped)

	// There's not much we can assert about timing other than there being a delay
	// of at least INITIAL_INTERVAL * (1 - RANDOMIZATION_FACTOR) after the first
	// attempt. The bound is the same for every case, so compute it once.
	minDur := time.Duration(float64(INITIAL_INTERVAL) * (1 - RANDOMIZATION_FACTOR))

	// test takes a slice of response codes for the server to respond with (the
	// last being repeated). It verifies that the result from BackoffTransport
	// matches the final value in codes. A 0 code means the RoundTripper returns
	// an error.
	test := func(codes []int) {
		wrapped.responseCodes = codes
		r, err := http.NewRequest("GET", "http://example.com/foo", nil)
		assert.NoError(t, err)
		start := time.Now()
		resp, err := bt.RoundTrip(r)
		dur := time.Since(start)
		expected := codes[len(codes)-1]
		if expected == 0 {
			assert.Equal(t, mockRoundTripErr, err)
		} else {
			assert.NoError(t, err)
			assert.Equal(t, expected, resp.StatusCode)
			ReadAndClose(resp.Body)
		}
		if len(codes) > 1 {
			// %s renders durations readably (e.g. "150ms") instead of raw nanoseconds.
			assert.Truef(t, dur >= minDur, "For codes %v, expected duration to be at least %s, but was %s", codes, minDur, dur)
		}
	}
	// No retries.
	test([]int{http.StatusOK})
	test([]int{http.StatusSwitchingProtocols})
	test([]int{http.StatusNotModified})
	test([]int{http.StatusNotFound})
	// Some retries before non-retriable status code.
	test([]int{http.StatusServiceUnavailable, http.StatusOK})
	test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusNotFound})
	test([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusBadGateway, http.StatusNotModified})
	// Retries exhausted for server error.
	test([]int{http.StatusInternalServerError})
	// Retry transport error.
	test([]int{0, http.StatusOK})
	test([]int{0, 0, http.StatusOK})
	// Retries exhausted for transport error.
	test([]int{http.StatusInternalServerError, 0})
}

// RoundTripperFunc adapts an ordinary function to the http.RoundTripper
// interface, in the same spirit as http.HandlerFunc does for http.Handler.
type RoundTripperFunc func(req *http.Request) (*http.Response, error)

// Compile-time check that RoundTripperFunc satisfies http.RoundTripper.
var _ http.RoundTripper = RoundTripperFunc(nil)

// RoundTrip satisfies http.RoundTripper by invoking the wrapped function with
// req and handing back its response and error unchanged.
func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}

// TestBackoffTransportWithContext checks two things. First, BackOffTransport
// starts no further attempts once the request's context has been canceled.
// Second, it returns whatever the wrapped RoundTripper produced on the attempt
// during which cancellation happened.
func TestBackoffTransportWithContext(t *testing.T) {
	unittest.LargeTest(t) // BackoffTransport sleeps between requests.

	// A fail-faster config keeps the test short. maxElapsedTime is generous
	// enough that it should never be reached by these cases.
	maxInterval := 600 * time.Millisecond
	config := &BackOffConfig{
		initialInterval:     INITIAL_INTERVAL,
		maxInterval:         maxInterval,
		maxElapsedTime:      10 * maxInterval,
		randomizationFactor: RANDOMIZATION_FACTOR,
		backOffMultiplier:   BACKOFF_MULTIPLIER,
	}

	// run sends one request through a fresh BackOffTransport. codes follows the
	// MockRoundTripper convention: 0 means a transport error, and the last code
	// repeats. The context is canceled inside attempt number cancelAfter
	// (0-based), so len(codes) must exceed cancelAfter.
	run := func(codes []int, cancelAfter int) {
		mock := &MockRoundTripper{responseCodes: codes}
		ctx, cancel := context.WithCancel(context.Background())
		defer cancel()

		attempts := 0
		rt := RoundTripperFunc(func(req *http.Request) (*http.Response, error) {
			if attempts == cancelAfter {
				cancel()
			}
			attempts++
			return mock.RoundTrip(req)
		})
		transport := NewConfiguredBackOffTransport(config, rt)

		req, err := http.NewRequest("GET", "http://example.com/foo", nil)
		assert.NoError(t, err)
		resp, err := transport.RoundTrip(req.WithContext(ctx))

		// No attempt may start after the one that canceled the context.
		assert.Equal(t, cancelAfter, attempts-1)

		// The result must be the one produced by the canceling attempt.
		want := codes[cancelAfter]
		if want == 0 {
			assert.Equal(t, mockRoundTripErr, err)
			return
		}
		assert.NoError(t, err)
		assert.Equal(t, want, resp.StatusCode)
		ReadAndClose(resp.Body)
	}

	// Succeeds on the first attempt; nothing to retry.
	run([]int{http.StatusOK}, 0)
	// Retriable status, but the context is canceled, so no retry.
	run([]int{http.StatusServiceUnavailable}, 0)
	// Canceled during the first attempt, so the second code is never served.
	run([]int{http.StatusServiceUnavailable, http.StatusInternalServerError}, 0)
	// One or more retries happen before the context is canceled.
	run([]int{http.StatusServiceUnavailable, http.StatusOK}, 1)
	run([]int{http.StatusServiceUnavailable, http.StatusInternalServerError}, 1)
	run([]int{http.StatusServiceUnavailable, http.StatusInternalServerError, http.StatusBadGateway}, 2)

	// Transport error with the context canceled, so no retry.
	run([]int{0}, 0)
	// Transport errors retried until the context is canceled.
	run([]int{0, 0}, 1)
	run([]int{0, http.StatusOK}, 1)
	run([]int{0, http.StatusInternalServerError}, 1)
	run([]int{http.StatusInternalServerError, 0}, 1)
}

// TestForceHTTPS checks three cases for HealthzAndHTTPS:
//   - Without the wrapper, a request tagged with scheme "http" (via
//     SCHEME_AT_LOAD_BALANCER_HEADER) is served as-is.
//   - With the wrapper, the same request gets a 301 redirect to the https URL.
//   - A "GoogleHC/1.0" user agent gets an empty 200 with no redirect.
func TestForceHTTPS(t *testing.T) {
	unittest.SmallTest(t)
	const body = "Hello World!"
	var h http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// ServeHTTP is called directly on the test goroutine below, so using
		// require inside this handler is safe.
		_, err := io.WriteString(w, body)
		assert.NoError(t, err)
	})

	// Test w/o ForceHTTPS in place: the inner handler serves the request.
	r := httptest.NewRequest("GET", "http://example.com/foo", nil)
	r.Header.Set(SCHEME_AT_LOAD_BALANCER_HEADER, "http")
	w := httptest.NewRecorder()
	h.ServeHTTP(w, r)
	res := w.Result()
	assert.Equal(t, http.StatusOK, res.StatusCode)
	assert.Equal(t, "", res.Header.Get("Location"))
	b, err := ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Equal(t, body, string(b))

	// Add in ForceHTTPS behavior: the same request is now redirected.
	h = HealthzAndHTTPS(h)
	w = httptest.NewRecorder()
	h.ServeHTTP(w, r)
	res = w.Result()
	assert.Equal(t, http.StatusMovedPermanently, res.StatusCode)
	assert.Equal(t, "https://example.com/foo", res.Header.Get("Location"))

	// Test the healthcheck handling: an empty 200, with the inner handler's
	// body absent and no redirect.
	r = httptest.NewRequest("GET", "http://example.com/", nil)
	r.Header.Set("User-Agent", "GoogleHC/1.0")
	w = httptest.NewRecorder()
	h.ServeHTTP(w, r)
	res = w.Result()
	assert.Equal(t, http.StatusOK, res.StatusCode)
	assert.Equal(t, "", res.Header.Get("Location"))
	b, err = ioutil.ReadAll(res.Body)
	assert.NoError(t, err)
	assert.Len(t, b, 0)
}
