blob: d117b6f72ca19012528a2da34af1e13a19d61876 [file] [log] [blame]
package remote_db
import (
"fmt"
"net/http"
"net/http/httptest"
"os"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/gorilla/mux"
assert "github.com/stretchr/testify/require"
"go.skia.org/infra/go/deepequal"
"go.skia.org/infra/go/testutils"
"go.skia.org/infra/task_scheduler/go/db"
memory "go.skia.org/infra/task_scheduler/go/db/memory"
"go.skia.org/infra/task_scheduler/go/db/pubsub"
"go.skia.org/infra/task_scheduler/go/types"
)
func TestMain(m *testing.M) {
db.AssertDeepEqual = deepequal.AssertDeepEqual
os.Exit(m.Run())
}
// clientWithBackdoor allows us to test the client/server pair as a db.DB, using
// the generic DB test utils. All method calls supported by RemoteDB use the
// client/server implementation; other methods have "backdoor" access to the
// underlying DB to allow the tests to modify the DB.
type clientWithBackdoor struct {
// *client; implements the methods being tested.
db.RemoteDB
// The DB passed to NewServer.
backdoor db.DB
// The test HTTP server listening on the loopback address.
httpserver *httptest.Server
}
func (b *clientWithBackdoor) Close() error {
b.httpserver.Close()
return nil
}
func (b *clientWithBackdoor) AssignId(task *types.Task) error {
return b.backdoor.AssignId(task)
}
func (b *clientWithBackdoor) PutTask(task *types.Task) error {
return b.backdoor.PutTask(task)
}
func (b *clientWithBackdoor) PutTasks(tasks []*types.Task) error {
return b.backdoor.PutTasks(tasks)
}
func (b *clientWithBackdoor) PutTasksInChunks(tasks []*types.Task) error {
return b.PutTasks(tasks)
}
func (b *clientWithBackdoor) PutJob(job *types.Job) error {
return b.backdoor.PutJob(job)
}
func (b *clientWithBackdoor) PutJobs(jobs []*types.Job) error {
return b.backdoor.PutJobs(jobs)
}
func (b *clientWithBackdoor) PutJobsInChunks(jobs []*types.Job) error {
return b.PutJobs(jobs)
}
type reqCountingTransport struct {
count int
countMtx sync.RWMutex
rt http.RoundTripper
}
func (t *reqCountingTransport) Inc() {
t.countMtx.Lock()
defer t.countMtx.Unlock()
t.count++
}
func (t *reqCountingTransport) Get() int {
t.countMtx.RLock()
defer t.countMtx.RUnlock()
return t.count
}
func (t *reqCountingTransport) Reset() {
t.countMtx.Lock()
defer t.countMtx.Unlock()
t.count = 0
}
func (t *reqCountingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
t.Inc()
return t.rt.RoundTrip(req)
}
func newReqCountingTransport(rt http.RoundTripper) http.RoundTripper {
return &reqCountingTransport{
rt: rt,
}
}
// makeDB sets up a client/server pair wrapped in a clientWithBackdoor.
func makeDB(t *testing.T) db.DBCloser {
serverLabel := fmt.Sprintf("remote-db-test-%s", uuid.New())
mod, err := pubsub.NewModifiedData(pubsub.TOPIC_SET_PRODUCTION, serverLabel, nil)
assert.NoError(t, err)
baseDB := memory.NewInMemoryDB(mod)
r := mux.NewRouter()
err = RegisterServer(baseDB, r.PathPrefix("/db").Subrouter())
assert.NoError(t, err)
ts := httptest.NewServer(r)
clientLabel := fmt.Sprintf("remote-db-test-%s", uuid.New())
dbclient, err := NewClient(ts.URL+"/db/", pubsub.TOPIC_SET_PRODUCTION, clientLabel, nil)
assert.NoError(t, err)
dbclient.(*client).client.Transport = newReqCountingTransport(dbclient.(*client).client.Transport)
return &clientWithBackdoor{
RemoteDB: dbclient,
backdoor: baseDB,
httpserver: ts,
}
}
func TestRemoteDBTaskDB(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestTaskDB(t, d)
}
func TestRemoteDBTaskDBConcurrentUpdate(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestTaskDBConcurrentUpdate(t, d)
}
func TestRemoteDBTaskDBUpdateTasksWithRetries(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestUpdateTasksWithRetries(t, d)
}
func TestRemoteDBTaskDBGetTasksFromDateRangeByRepo(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestTaskDBGetTasksFromDateRangeByRepo(t, d)
}
func TestRemoteDBTaskDBGetTasksFromWindow(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestTaskDBGetTasksFromWindow(t, d)
}
func TestRemoteDBJobDB(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestJobDB(t, d)
}
func TestRemoteDBJobDBConcurrentUpdate(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestJobDBConcurrentUpdate(t, d)
}
func TestRemoteDBCommentDB(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
db.TestCommentDB(t, d)
}
func TestRemoteDBGetTasksFromDateRange(t *testing.T) {
testutils.LargeTest(t)
d := makeDB(t)
defer testutils.AssertCloses(t, d)
tp := d.(*clientWithBackdoor).RemoteDB.(*client).client.Transport.(*reqCountingTransport)
timeStart := time.Now().Add(-3 * MAX_TASK_TIME_RANGE)
t1 := types.MakeTestTask(timeStart.Add(time.Nanosecond), []string{"a", "b"})
assert.NoError(t, d.PutTask(t1))
t2 := types.MakeTestTask(t1.Created.Add(MAX_TASK_TIME_RANGE), []string{"c"})
assert.NoError(t, d.PutTask(t2))
t3 := types.MakeTestTask(t2.Created.Add(MAX_TASK_TIME_RANGE), []string{"d"})
assert.NoError(t, d.PutTask(t3))
// Request time ranges, and ensure that we get back the correct number
// of tasks and made the correct number of HTTP requests.
test := func(start, end time.Time, expectTasks, expectReqs int) {
tp.Reset()
tasks, err := d.GetTasksFromDateRange(start, end, "")
assert.NoError(t, err)
assert.Equal(t, expectTasks, len(tasks))
assert.Equal(t, expectReqs, tp.Get())
}
test(timeStart, t1.Created.Add(time.Nanosecond), 1, 1)
test(timeStart, t2.Created.Add(time.Nanosecond), 2, 2)
test(timeStart, t3.Created.Add(time.Nanosecond), 3, 3)
test(timeStart, timeStart.Add(MAX_TASK_TIME_RANGE), 1, 1)
test(timeStart, timeStart.Add(MAX_TASK_TIME_RANGE).Add(time.Nanosecond), 1, 2)
}