[pinpoint] Refactor swarming logic to backends/
run_benchmark and pinpoint.go are difficult to work with because
of the request object and swarming usage. this change takes care
of the exisiting TODO and moves the swarming calls out of
run_benchmark into its own file under benchmarks, and decouples
away from run_benchmark's request object.
Change-Id: Ie3dcacc8b14f009d7cee5a30cd7e3e2a48a8842a
Reviewed-on: https://skia-review.googlesource.com/c/buildbot/+/817456
Reviewed-by: Hao Wu <haowoo@google.com>
Commit-Queue: Jeff Yoon <jeffyoon@google.com>
diff --git a/.mockery.yaml b/.mockery.yaml
index c4cfc5c..8bfa6e3 100644
--- a/.mockery.yaml
+++ b/.mockery.yaml
@@ -92,6 +92,7 @@
go.skia.org/infra/pinpoint/go/backends:
interfaces:
BuildbucketClient:
+ SwarmingClient:
go.skia.org/infra/pinpoint/go/build_chrome:
interfaces:
BuildChromeClient:
diff --git a/pinpoint/go/backends/BUILD.bazel b/pinpoint/go/backends/BUILD.bazel
index 05bfecd..81dff25 100644
--- a/pinpoint/go/backends/BUILD.bazel
+++ b/pinpoint/go/backends/BUILD.bazel
@@ -7,6 +7,7 @@
"buildbucket.go",
"doc.go",
"gitiles.go",
+ "swarming.go",
"waterfall_map.go",
],
importpath = "go.skia.org/infra/pinpoint/go/backends",
@@ -18,6 +19,7 @@
"//go/gitiles",
"//go/httputils",
"//go/skerr",
+ "//go/swarming",
"@com_github_google_uuid//:uuid",
"@org_chromium_go_luci//buildbucket/proto",
"@org_chromium_go_luci//common/api/swarming/swarming/v1:swarming",
@@ -31,13 +33,20 @@
go_test(
name = "backends_test",
- srcs = ["buildbucket_test.go"],
+ srcs = [
+ "buildbucket_test.go",
+ "swarming_test.go",
+ ],
embed = [":backends"],
deps = [
"//go/buildbucket",
+ "//go/swarming/mocks",
"@com_github_golang_mock//gomock",
"@com_github_smartystreets_goconvey//convey",
+ "@com_github_stretchr_testify//assert",
+ "@com_github_stretchr_testify//mock",
"@org_chromium_go_luci//buildbucket/proto",
+ "@org_chromium_go_luci//common/api/swarming/swarming/v1:swarming",
"@org_chromium_go_luci//common/testing/assertions",
"@org_chromium_go_luci//grpc/appstatus",
"@org_golang_google_protobuf//types/known/structpb",
diff --git a/pinpoint/go/backends/mocks/BUILD.bazel b/pinpoint/go/backends/mocks/BUILD.bazel
index 57a3c10..43df272 100644
--- a/pinpoint/go/backends/mocks/BUILD.bazel
+++ b/pinpoint/go/backends/mocks/BUILD.bazel
@@ -2,7 +2,10 @@
go_library(
name = "mocks",
- srcs = ["BuildbucketClient.go"],
+ srcs = [
+ "BuildbucketClient.go",
+ "SwarmingClient.go",
+ ],
importpath = "go.skia.org/infra/pinpoint/go/backends/mocks",
visibility = ["//visibility:public"],
deps = [
diff --git a/pinpoint/go/backends/mocks/SwarmingClient.go b/pinpoint/go/backends/mocks/SwarmingClient.go
new file mode 100644
index 0000000..414995f
--- /dev/null
+++ b/pinpoint/go/backends/mocks/SwarmingClient.go
@@ -0,0 +1,195 @@
+// Code generated by mockery v0.0.0-dev. DO NOT EDIT.
+
+package mocks
+
+import (
+ context "context"
+
+ mock "github.com/stretchr/testify/mock"
+ swarming "go.chromium.org/luci/common/api/swarming/swarming/v1"
+)
+
+// SwarmingClient is an autogenerated mock type for the SwarmingClient type
+type SwarmingClient struct {
+ mock.Mock
+}
+
+// CancelTasks provides a mock function with given fields: ctx, taskIDs
+func (_m *SwarmingClient) CancelTasks(ctx context.Context, taskIDs []string) error {
+ ret := _m.Called(ctx, taskIDs)
+
+ if len(ret) == 0 {
+ panic("no return value specified for CancelTasks")
+ }
+
+ var r0 error
+ if rf, ok := ret.Get(0).(func(context.Context, []string) error); ok {
+ r0 = rf(ctx, taskIDs)
+ } else {
+ r0 = ret.Error(0)
+ }
+
+ return r0
+}
+
+// GetCASOutput provides a mock function with given fields: ctx, taskID
+func (_m *SwarmingClient) GetCASOutput(ctx context.Context, taskID string) (*swarming.SwarmingRpcsCASReference, error) {
+ ret := _m.Called(ctx, taskID)
+
+ if len(ret) == 0 {
+ panic("no return value specified for GetCASOutput")
+ }
+
+ var r0 *swarming.SwarmingRpcsCASReference
+ var r1 error
+ if rf, ok := ret.Get(0).(func(context.Context, string) (*swarming.SwarmingRpcsCASReference, error)); ok {
+ return rf(ctx, taskID)
+ }
+ if rf, ok := ret.Get(0).(func(context.Context, string) *swarming.SwarmingRpcsCASReference); ok {
+ r0 = rf(ctx, taskID)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).(*swarming.SwarmingRpcsCASReference)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, taskID)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
+}
+
+// GetStates provides a mock function with given fields: ctx, taskIDs
+func (_m *SwarmingClient) GetStates(ctx context.Context, taskIDs []string) ([]string, error) {
+ ret := _m.Called(ctx, taskIDs)
+
+ if len(ret) == 0 {
+ panic("no return value specified for GetStates")
+ }
+
+ var r0 []string
+ var r1 error
+ if rf, ok := ret.Get(0).(func(context.Context, []string) ([]string, error)); ok {
+ return rf(ctx, taskIDs)
+ }
+ if rf, ok := ret.Get(0).(func(context.Context, []string) []string); ok {
+ r0 = rf(ctx, taskIDs)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).([]string)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(context.Context, []string) error); ok {
+ r1 = rf(ctx, taskIDs)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
+}
+
+// GetStatus provides a mock function with given fields: ctx, taskID
+func (_m *SwarmingClient) GetStatus(ctx context.Context, taskID string) (string, error) {
+ ret := _m.Called(ctx, taskID)
+
+ if len(ret) == 0 {
+ panic("no return value specified for GetStatus")
+ }
+
+ var r0 string
+ var r1 error
+ if rf, ok := ret.Get(0).(func(context.Context, string) (string, error)); ok {
+ return rf(ctx, taskID)
+ }
+ if rf, ok := ret.Get(0).(func(context.Context, string) string); ok {
+ r0 = rf(ctx, taskID)
+ } else {
+ r0 = ret.Get(0).(string)
+ }
+
+ if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
+ r1 = rf(ctx, taskID)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
+}
+
+// ListPinpointTasks provides a mock function with given fields: ctx, jobID, buildArtifact
+func (_m *SwarmingClient) ListPinpointTasks(ctx context.Context, jobID string, buildArtifact *swarming.SwarmingRpcsCASReference) ([]string, error) {
+ ret := _m.Called(ctx, jobID, buildArtifact)
+
+ if len(ret) == 0 {
+ panic("no return value specified for ListPinpointTasks")
+ }
+
+ var r0 []string
+ var r1 error
+ if rf, ok := ret.Get(0).(func(context.Context, string, *swarming.SwarmingRpcsCASReference) ([]string, error)); ok {
+ return rf(ctx, jobID, buildArtifact)
+ }
+ if rf, ok := ret.Get(0).(func(context.Context, string, *swarming.SwarmingRpcsCASReference) []string); ok {
+ r0 = rf(ctx, jobID, buildArtifact)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).([]string)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(context.Context, string, *swarming.SwarmingRpcsCASReference) error); ok {
+ r1 = rf(ctx, jobID, buildArtifact)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
+}
+
+// TriggerTask provides a mock function with given fields: ctx, req
+func (_m *SwarmingClient) TriggerTask(ctx context.Context, req *swarming.SwarmingRpcsNewTaskRequest) (*swarming.SwarmingRpcsTaskRequestMetadata, error) {
+ ret := _m.Called(ctx, req)
+
+ if len(ret) == 0 {
+ panic("no return value specified for TriggerTask")
+ }
+
+ var r0 *swarming.SwarmingRpcsTaskRequestMetadata
+ var r1 error
+ if rf, ok := ret.Get(0).(func(context.Context, *swarming.SwarmingRpcsNewTaskRequest) (*swarming.SwarmingRpcsTaskRequestMetadata, error)); ok {
+ return rf(ctx, req)
+ }
+ if rf, ok := ret.Get(0).(func(context.Context, *swarming.SwarmingRpcsNewTaskRequest) *swarming.SwarmingRpcsTaskRequestMetadata); ok {
+ r0 = rf(ctx, req)
+ } else {
+ if ret.Get(0) != nil {
+ r0 = ret.Get(0).(*swarming.SwarmingRpcsTaskRequestMetadata)
+ }
+ }
+
+ if rf, ok := ret.Get(1).(func(context.Context, *swarming.SwarmingRpcsNewTaskRequest) error); ok {
+ r1 = rf(ctx, req)
+ } else {
+ r1 = ret.Error(1)
+ }
+
+ return r0, r1
+}
+
+// NewSwarmingClient creates a new instance of SwarmingClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations.
+// The first argument is typically a *testing.T value.
+func NewSwarmingClient(t interface {
+ mock.TestingT
+ Cleanup(func())
+}) *SwarmingClient {
+ mock := &SwarmingClient{}
+ mock.Mock.Test(t)
+
+ t.Cleanup(func() { mock.AssertExpectations(t) })
+
+ return mock
+}
diff --git a/pinpoint/go/backends/swarming.go b/pinpoint/go/backends/swarming.go
new file mode 100644
index 0000000..f4eaffb
--- /dev/null
+++ b/pinpoint/go/backends/swarming.go
@@ -0,0 +1,138 @@
+package backends
+
+import (
+ "context"
+ "fmt"
+ "time"
+
+ "go.skia.org/infra/go/auth"
+ "go.skia.org/infra/go/httputils"
+ "go.skia.org/infra/go/skerr"
+ "go.skia.org/infra/go/swarming"
+
+ "golang.org/x/oauth2/google"
+
+ swarmingV1 "go.chromium.org/luci/common/api/swarming/swarming/v1"
+)
+
+const (
+ DefaultSwarmingServiceAddress = "chrome-swarming.appspot.com:443"
+)
+
+// SwarmingClient
+type SwarmingClient interface {
+ // CancelTasks tells Swarming to cancel the given tasks.
+ CancelTasks(ctx context.Context, taskIDs []string) error
+
+ // GetCASOutput returns the CAS output of a swarming task.
+ GetCASOutput(ctx context.Context, taskID string) (*swarmingV1.SwarmingRpcsCASReference, error)
+
+ // GetStates returns the state of each task in a list of tasks.
+ GetStates(ctx context.Context, taskIDs []string) ([]string, error)
+
+ // GetStatus gets the current status of a swarming task.
+ GetStatus(ctx context.Context, taskID string) (string, error)
+
+ // ListPinpointTasks lists the Pinpoint swarming tasks.
+ ListPinpointTasks(ctx context.Context, jobID string, buildArtifact *swarmingV1.SwarmingRpcsCASReference) ([]string, error)
+
+ // TriggerTask is a literal wrapper around swarming.ApiClient TriggerTask
+ // TODO(jeffyoon@) remove once run_benchmark is refactored if no longer needed.
+ TriggerTask(ctx context.Context, req *swarmingV1.SwarmingRpcsNewTaskRequest) (*swarmingV1.SwarmingRpcsTaskRequestMetadata, error)
+}
+
+// SwarmingClientImpl
+// TODO(jeffyoon@) make this private once run_benchmark doesn't rely on this in testing.
+type SwarmingClientImpl struct {
+ swarming.ApiClient
+}
+
+func NewSwarmingClient(ctx context.Context, server string) (*SwarmingClientImpl, error) {
+ httpClientTokenSource, err := google.DefaultTokenSource(ctx, auth.ScopeReadOnly)
+ if err != nil {
+ return nil, skerr.Wrapf(err, "Problem setting up default token source")
+ }
+ c := httputils.DefaultClientConfig().WithTokenSource(httpClientTokenSource).With2xxOnly().Client()
+
+ sc, err := swarming.NewApiClient(c, server)
+ if err != nil {
+ return nil, err
+ }
+
+ return &SwarmingClientImpl{
+ ApiClient: sc,
+ }, nil
+}
+
+// CancelTasks tells Swarming to cancel the given tasks.
+func (s *SwarmingClientImpl) CancelTasks(ctx context.Context, taskIDs []string) error {
+ for _, id := range taskIDs {
+ err := s.CancelTask(ctx, id, true)
+ if err != nil {
+ return skerr.Fmt("Could not cancel task %s due to %s", id, err)
+ }
+ }
+ return nil
+}
+
+// GetCASOutput returns the CAS output of a swarming task in the form of a RBE CAS hash.
+// This function assumes the task is finished, or it throws an error.
+func (s *SwarmingClientImpl) GetCASOutput(ctx context.Context, taskID string) (*swarmingV1.SwarmingRpcsCASReference, error) {
+ task, err := s.GetTask(ctx, taskID, false)
+ if err != nil {
+ return nil, fmt.Errorf("error retrieving result of task %s: %s", taskID, err)
+ }
+ if task.State != "COMPLETED" {
+ return nil, fmt.Errorf("cannot get result of task %s because it is %s and not COMPLETED", taskID, task.State)
+ }
+ rbe := &swarmingV1.SwarmingRpcsCASReference{
+ CasInstance: task.CasOutputRoot.CasInstance,
+ Digest: &swarmingV1.SwarmingRpcsDigest{
+ Hash: task.CasOutputRoot.Digest.Hash,
+ SizeBytes: task.CasOutputRoot.Digest.SizeBytes,
+ },
+ }
+
+ return rbe, nil
+}
+
+// func (s *SwarmingClientImpl) GetStates(ctx context.Context, taskIDs []string) ([]string, error) {
+// return s.GetStates(ctx, taskIDs)
+// }
+
+// GetStatus gets the current status of a swarming task.
+func (s *SwarmingClientImpl) GetStatus(ctx context.Context, taskID string) (string, error) {
+ res, err := s.GetTask(ctx, taskID, false)
+ if err != nil {
+ return "", skerr.Fmt("failed to get swarming task ID %s due to err: %v", taskID, err)
+ }
+ return res.State, nil
+}
+
+// ListPinpointTasks lists the Pinpoint swarming tasks of a given job and build identified by Swarming tags.
+func (s *SwarmingClientImpl) ListPinpointTasks(ctx context.Context, jobID string, buildArtifact *swarmingV1.SwarmingRpcsCASReference) ([]string, error) {
+ if jobID == "" {
+ return nil, skerr.Fmt("Cannot list tasks because request is missing JobID")
+ }
+ if buildArtifact == nil || buildArtifact.Digest == nil {
+ return nil, skerr.Fmt("Cannot list tasks because request is missing cas isolate")
+ }
+ start := time.Now().Add(-24 * time.Hour)
+ tags := []string{
+ fmt.Sprintf("pinpoint_job_id:%s", jobID),
+ fmt.Sprintf("build_cas:%s/%d", buildArtifact.Digest.Hash, buildArtifact.Digest.SizeBytes),
+ }
+ tasks, err := s.ListTasks(ctx, start, time.Now(), tags, "")
+ if err != nil {
+ return nil, fmt.Errorf("error retrieving tasks %s", err)
+ }
+ taskIDs := make([]string, len(tasks))
+ for i, t := range tasks {
+ taskIDs[i] = t.TaskId
+ }
+ return taskIDs, nil
+}
+
+// func (s *SwarmingClientImpl) TriggerTask(ctx context.Context, req *swarmingV1.SwarmingRpcsNewTaskRequest) (*swarmingV1.SwarmingRpcsTaskRequestMetadata, error) {
+// return s.TriggerTask(ctx, req)
+// }
diff --git a/pinpoint/go/backends/swarming_test.go b/pinpoint/go/backends/swarming_test.go
new file mode 100644
index 0000000..c0c5a39
--- /dev/null
+++ b/pinpoint/go/backends/swarming_test.go
@@ -0,0 +1,133 @@
+package backends
+
+import (
+ "context"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+ "github.com/stretchr/testify/mock"
+
+ "go.skia.org/infra/go/swarming/mocks"
+
+ swarmingV1 "go.chromium.org/luci/common/api/swarming/swarming/v1"
+)
+
+func TestNewSwarmingClient_Default_SwarmingClient(t *testing.T) {
+ ctx := context.Background()
+ sc, err := NewSwarmingClient(ctx, DefaultSwarmingServiceAddress)
+ assert.NoError(t, err)
+ assert.NotNil(t, sc)
+}
+
+func TestListPinpointTasks_ValidInput_TasksFound(t *testing.T) {
+ ctx := context.Background()
+ mockClient := mocks.NewApiClient(t)
+
+ bA := &swarmingV1.SwarmingRpcsCASReference{
+ CasInstance: "instance",
+ Digest: &swarmingV1.SwarmingRpcsDigest{
+ Hash: "hash",
+ SizeBytes: 0,
+ },
+ }
+
+ mockClient.On("ListTasks", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
+ Return([]*swarmingV1.SwarmingRpcsTaskRequestMetadata{
+ {
+ TaskId: "123",
+ },
+ {
+ TaskId: "456",
+ },
+ }, nil).Once()
+
+ sc := &SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
+ taskIds, err := sc.ListPinpointTasks(ctx, "id", bA)
+ assert.NoError(t, err)
+ assert.Equal(t, []string{"123", "456"}, taskIds)
+}
+
+func TestListPinpointTasks_ValidInput_NoTasksFound(t *testing.T) {
+ ctx := context.Background()
+ mockClient := mocks.NewApiClient(t)
+
+ bA := &swarmingV1.SwarmingRpcsCASReference{
+ CasInstance: "instance",
+ Digest: &swarmingV1.SwarmingRpcsDigest{
+ Hash: "hash",
+ SizeBytes: 0,
+ },
+ }
+
+ mockClient.On("ListTasks", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).
+ Return([]*swarmingV1.SwarmingRpcsTaskRequestMetadata{}, nil).Once()
+
+ sc := &SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
+ taskIds, err := sc.ListPinpointTasks(ctx, "id", bA)
+
+ assert.NoError(t, err)
+ assert.Empty(t, taskIds)
+}
+
+func TestListPinpointTasks_InvalidInputs_Error(t *testing.T) {
+ ctx := context.Background()
+ mockClient := mocks.NewApiClient(t)
+ sc := &SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
+
+ taskIds, err := sc.ListPinpointTasks(ctx, "", &swarmingV1.SwarmingRpcsCASReference{})
+ assert.Nil(t, taskIds)
+ assert.ErrorContains(t, err, "Cannot list tasks because request is missing JobID")
+
+ taskIds, err = sc.ListPinpointTasks(ctx, "id", nil)
+ assert.Nil(t, taskIds)
+ assert.ErrorContains(t, err, "Cannot list tasks because request is missing cas isolate")
+}
+
+func TestGetCASOutput_ValidInput_SwarmingRBECasRef(t *testing.T) {
+ ctx := context.Background()
+ mockClient := mocks.NewApiClient(t)
+ sc := &SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
+
+ mockClient.On("GetTask", ctx, mock.Anything, mock.Anything).
+ Return(&swarmingV1.SwarmingRpcsTaskResult{
+ State: "COMPLETED",
+ CasOutputRoot: &swarmingV1.SwarmingRpcsCASReference{
+ CasInstance: "instance",
+ Digest: &swarmingV1.SwarmingRpcsDigest{
+ Hash: "hash",
+ SizeBytes: 0,
+ },
+ },
+ }, nil).Once()
+
+ rbe, err := sc.GetCASOutput(ctx, "taskId")
+ assert.NoError(t, err)
+ assert.Equal(t, "instance", rbe.CasInstance)
+ assert.Equal(t, "hash", rbe.Digest.Hash)
+ assert.Equal(t, int64(0), rbe.Digest.SizeBytes)
+}
+
+func TestGasCASOutput_IncompleteTask_Error(t *testing.T) {
+ ctx := context.Background()
+ mockClient := mocks.NewApiClient(t)
+ sc := &SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
+
+ mockClient.On("GetTask", ctx, mock.Anything, mock.Anything).
+ Return(&swarmingV1.SwarmingRpcsTaskResult{
+ State: "Not_Completed",
+ }, nil).Once()
+
+ rbe, err := sc.GetCASOutput(ctx, "taskId")
+ assert.Nil(t, rbe)
+ assert.ErrorContains(t, err, "cannot get result of task")
+}
diff --git a/pinpoint/go/pinpoint/BUILD.bazel b/pinpoint/go/pinpoint/BUILD.bazel
index 537c222..6793e22 100644
--- a/pinpoint/go/pinpoint/BUILD.bazel
+++ b/pinpoint/go/pinpoint/BUILD.bazel
@@ -12,6 +12,7 @@
"//go/skerr",
"//go/sklog",
"//go/swarming",
+ "//pinpoint/go/backends",
"//pinpoint/go/bot_configs",
"//pinpoint/go/build_chrome",
"//pinpoint/go/compare",
@@ -36,6 +37,7 @@
"//go/mockhttpclient",
"//go/skerr",
"//go/swarming/mocks",
+ "//pinpoint/go/backends",
"//pinpoint/go/bot_configs",
"//pinpoint/go/build_chrome/mocks",
"//pinpoint/go/compare",
diff --git a/pinpoint/go/pinpoint/pinpoint.go b/pinpoint/go/pinpoint/pinpoint.go
index 91c4ff6..6a6319d 100644
--- a/pinpoint/go/pinpoint/pinpoint.go
+++ b/pinpoint/go/pinpoint/pinpoint.go
@@ -16,6 +16,7 @@
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/go/swarming"
+ "go.skia.org/infra/pinpoint/go/backends"
"go.skia.org/infra/pinpoint/go/bot_configs"
"go.skia.org/infra/pinpoint/go/build_chrome"
"go.skia.org/infra/pinpoint/go/compare"
@@ -63,7 +64,7 @@
// pinpointJobImpl implements the PinpointJob interface.
type pinpointHandlerImpl struct {
- sc swarming.ApiClient
+ sc backends.SwarmingClient
bc build_chrome.BuildChromeClient
mc midpoint.MidpointHandler
}
@@ -109,7 +110,7 @@
}
c := httputils.DefaultClientConfig().WithTokenSource(httpClientTokenSource).With2xxOnly().Client()
- sc, err := swarming.NewApiClient(c, swarmingServiceAddress)
+ sc, err := backends.NewSwarmingClient(ctx, swarmingServiceAddress)
if err != nil {
return nil, skerr.Wrapf(err, "Could not create swarming client")
}
@@ -369,12 +370,12 @@
}
// scheduleRunBenchmark schedules run benchmark tests to swarming and returns the task IDs
-func (c *commitData) scheduleRunBenchmark(ctx context.Context, sc swarming.ApiClient) ([]string, error) {
+func (c *commitData) scheduleRunBenchmark(ctx context.Context, sc backends.SwarmingClient) ([]string, error) {
if c.tests == nil || c.tests.req == nil {
return nil, skerr.Fmt("Cannot schedule benchmark runs without request")
}
// Fetching Pinpoint tasks here can skip scheduling new tasks for faster testing
- tasks, err := run_benchmark.ListPinpointTasks(ctx, sc, *c.tests.req)
+ tasks, err := sc.ListPinpointTasks(ctx, c.tests.req.JobID, c.tests.req.Build)
if err != nil {
return nil, skerr.Wrapf(err, "Could not list tasks prior to run benchmark for request %v", *c.tests.req)
}
@@ -393,14 +394,14 @@
// pollTests checks the test status of every commit in the commitQ
// returns upon finding the first commit with running tasks that all finished
// returns the index of the commit so it is easier to compare left and right neighbors
-func (cdl commitDataList) pollTests(ctx context.Context, sc swarming.ApiClient) (int, *commitData, error) {
+func (cdl commitDataList) pollTests(ctx context.Context, sc backends.SwarmingClient) (int, *commitData, error) {
for i, c := range cdl.commits {
if c.tests == nil {
continue
}
if c.tests.isRunning {
c.tests.isRunning = false
- states, err := run_benchmark.GetStates(ctx, sc, c.tests.tasks)
+ states, err := sc.GetStates(ctx, c.tests.tasks)
if err != nil {
return -1, nil, skerr.Wrapf(err, "failed to retrieve swarming tasks %v", c.tests.tasks)
}
@@ -420,7 +421,7 @@
}
// getTestCAS returns the CAS output addresses from a set of swarming tests
-func (c *commitData) getTestCAS(ctx context.Context, sc swarming.ApiClient) (
+func (c *commitData) getTestCAS(ctx context.Context, sc backends.SwarmingClient) (
[]*swarmingV1.SwarmingRpcsCASReference, error) {
casOutputs := []*swarmingV1.SwarmingRpcsCASReference{}
if c.tests == nil {
@@ -432,7 +433,7 @@
}
for i, s := range c.tests.states {
if s == "COMPLETED" {
- cas, err := run_benchmark.GetCASOutput(ctx, sc, c.tests.tasks[i])
+ cas, err := sc.GetCASOutput(ctx, c.tests.tasks[i])
if err != nil {
return nil, skerr.Wrapf(err, "error retrieving cas outputs")
}
@@ -512,7 +513,7 @@
// updateCommitsByResult takes the compare results and determines the next
// steps in the workflow. Changes are made to CommitDataList depending
// on what the compare verdict is.
-func (cdl *commitDataList) updateCommitsByResult(ctx context.Context, sc swarming.ApiClient, mh midpoint.MidpointHandler, res *compare.CompareResults,
+func (cdl *commitDataList) updateCommitsByResult(ctx context.Context, sc backends.SwarmingClient, mh midpoint.MidpointHandler, res *compare.CompareResults,
left, right int) (*midpoint.Commit, error) {
if left < 0 || right >= len(cdl.commits) {
return nil, skerr.Fmt("cannot update commitDataList with left %d and right %d index out of bounds", left, right)
@@ -554,7 +555,7 @@
}
// runMoreTestsIfNeeded adds more run_benchmark tasks to the left and right commit
-func (cdl *commitDataList) runMoreTestsIfNeeded(ctx context.Context, sc swarming.ApiClient, left, right int) error {
+func (cdl *commitDataList) runMoreTestsIfNeeded(ctx context.Context, sc backends.SwarmingClient, left, right int) error {
c := cdl.commits[left]
tasks, err := c.scheduleRunBenchmark(ctx, sc)
if err != nil {
diff --git a/pinpoint/go/pinpoint/pinpoint_test.go b/pinpoint/go/pinpoint/pinpoint_test.go
index e4a0ba5..624b4f1 100644
--- a/pinpoint/go/pinpoint/pinpoint_test.go
+++ b/pinpoint/go/pinpoint/pinpoint_test.go
@@ -12,6 +12,7 @@
"go.skia.org/infra/go/mockhttpclient"
"go.skia.org/infra/go/skerr"
swarmingMocks "go.skia.org/infra/go/swarming/mocks"
+ "go.skia.org/infra/pinpoint/go/backends"
"go.skia.org/infra/pinpoint/go/bot_configs"
"go.skia.org/infra/pinpoint/go/build_chrome/mocks"
"go.skia.org/infra/pinpoint/go/compare"
@@ -348,7 +349,10 @@
TaskId: "new_task",
}, nil).Times(interval)
- tasks, err := c.scheduleRunBenchmark(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ tasks, err := c.scheduleRunBenchmark(ctx, sc)
So(err, ShouldBeNil)
So(tasks[0], ShouldEqual, "new_task")
So(len(tasks), ShouldEqual, interval)
@@ -357,7 +361,10 @@
Convey(`Error`, t, func() {
Convey(`When no tests started`, func() {
c := &commitData{}
- tasks, err := c.scheduleRunBenchmark(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ tasks, err := c.scheduleRunBenchmark(ctx, sc)
So(err, ShouldErrLike, "Cannot schedule benchmark runs without request")
So(tasks, ShouldBeNil)
})
@@ -365,7 +372,10 @@
c := &commitData{
tests: &testMetadata{},
}
- tasks, err := c.scheduleRunBenchmark(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ tasks, err := c.scheduleRunBenchmark(ctx, sc)
So(err, ShouldErrLike, "Cannot schedule benchmark runs without request")
So(tasks, ShouldBeNil)
})
@@ -393,7 +403,10 @@
msc.On("ListTasks", ctx, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil, nil).Once()
msc.On("TriggerTask", ctx, mock.Anything).Return(nil, skerr.Fmt(errMsg)).Once()
- tasks, err := c.scheduleRunBenchmark(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ tasks, err := c.scheduleRunBenchmark(ctx, sc)
So(err, ShouldErrLike, errMsg)
So(tasks, ShouldBeNil)
})
@@ -403,12 +416,16 @@
func TestPollTests(t *testing.T) {
ctx := context.Background()
msc := swarmingMocks.NewApiClient(t)
+
Convey(`OK`, t, func() {
Convey(`When no tasks to poll`, func() {
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
cdl := commitDataList{
commits: []*commitData{},
}
- idx, c, err := cdl.pollTests(ctx, msc)
+ idx, c, err := cdl.pollTests(ctx, sc)
So(err, ShouldBeNil)
So(idx, ShouldEqual, -1)
So(c, ShouldBeNil)
@@ -420,7 +437,7 @@
},
},
}
- idx, c, err = cdl.pollTests(ctx, msc)
+ idx, c, err = cdl.pollTests(ctx, sc)
So(err, ShouldBeNil)
So(idx, ShouldEqual, -1)
So(c, ShouldBeNil)
@@ -436,7 +453,10 @@
},
},
}
- idx, c, err := cdl.pollTests(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ idx, c, err := cdl.pollTests(ctx, sc)
So(err, ShouldBeNil)
So(idx, ShouldEqual, -1)
So(c, ShouldBeNil)
@@ -462,7 +482,10 @@
"COMPLETED",
}, nil).Once()
- idx, c, err := cdl.pollTests(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ idx, c, err := cdl.pollTests(ctx, sc)
So(err, ShouldBeNil)
So(idx, ShouldEqual, -1)
So(c, ShouldBeNil)
@@ -488,7 +511,10 @@
"COMPLETED",
}, nil).Once()
- idx, c, err := cdl.pollTests(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ idx, c, err := cdl.pollTests(ctx, sc)
So(err, ShouldBeNil)
So(idx, ShouldEqual, 0)
So(c, ShouldNotBeNil)
@@ -511,7 +537,10 @@
}
msc.On("GetStates", ctx, mock.Anything).Return(nil, skerr.Fmt("some error")).Once()
- idx, c, err := cdl.pollTests(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ idx, c, err := cdl.pollTests(ctx, sc)
So(err, ShouldErrLike, "failed to retrieve swarming tasks")
So(idx, ShouldEqual, -1)
So(c, ShouldBeNil)
@@ -521,6 +550,7 @@
func TestGetTestCAS(t *testing.T) {
ctx := context.Background()
msc := swarmingMocks.NewApiClient(t)
+
Convey(`OK`, t, func() {
c := &commitData{
tests: &testMetadata{
@@ -541,7 +571,10 @@
}, nil,
).Times(len(c.tests.tasks))
- cas, err := c.getTestCAS(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ cas, err := c.getTestCAS(ctx, sc)
So(err, ShouldBeNil)
So(cas, ShouldNotBeNil)
So(len(cas), ShouldEqual, len(c.tests.tasks))
@@ -553,7 +586,10 @@
Convey(`Error`, t, func() {
Convey(`When there are no tests`, func() {
c := &commitData{}
- cas, err := c.getTestCAS(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ cas, err := c.getTestCAS(ctx, sc)
So(err, ShouldErrLike, "cannot get cas output of non-existent swarming tasks")
So(cas, ShouldBeNil)
So(c.tests, ShouldBeNil)
@@ -565,7 +601,10 @@
states: []string{"COMPLETED"},
},
}
- cas, err := c.getTestCAS(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ cas, err := c.getTestCAS(ctx, sc)
So(err, ShouldErrLike, "mismatching number of swarming states")
So(cas, ShouldBeNil)
So(c.tests.casOutputs, ShouldBeNil)
@@ -581,7 +620,10 @@
skerr.Fmt("some error"),
).Once()
- cas, err := c.getTestCAS(ctx, msc)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ cas, err := c.getTestCAS(ctx, sc)
So(err, ShouldErrLike, "error retrieving cas outputs")
So(cas, ShouldBeNil)
})
@@ -745,14 +787,20 @@
Convey(`When index out of bounds`, func() {
left, right := -1, 0
res := &compare.CompareResults{}
- mid, err := cdl.updateCommitsByResult(ctx, msc, mmh, res, left, right)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ mid, err := cdl.updateCommitsByResult(ctx, sc, mmh, res, left, right)
So(err, ShouldErrLike, "index out of bounds")
So(mid, ShouldBeNil)
})
Convey(`When left >= right`, func() {
left, right := 1, 0
res := &compare.CompareResults{}
- mid, err := cdl.updateCommitsByResult(ctx, msc, mmh, res, left, right)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ mid, err := cdl.updateCommitsByResult(ctx, sc, mmh, res, left, right)
So(err, ShouldErrLike, fmt.Sprintf("left %d index >= right %d", left, right))
So(mid, ShouldBeNil)
})
@@ -834,7 +882,10 @@
TaskId: "new_right_task",
}, nil).Times(interval)
- err = cdl.runMoreTestsIfNeeded(ctx, msc, left, right)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: msc,
+ }
+ err = cdl.runMoreTestsIfNeeded(ctx, sc, left, right)
So(err, ShouldBeNil)
So(len(lcommit.tests.tasks), ShouldEqual, interval+2)
So(lcommit.tests.tasks[0], ShouldEqual, "old_left_task_1")
diff --git a/pinpoint/go/run_benchmark/BUILD.bazel b/pinpoint/go/run_benchmark/BUILD.bazel
index 2968037..dd58ff9 100644
--- a/pinpoint/go/run_benchmark/BUILD.bazel
+++ b/pinpoint/go/run_benchmark/BUILD.bazel
@@ -10,10 +10,10 @@
importpath = "go.skia.org/infra/pinpoint/go/run_benchmark",
visibility = ["//visibility:public"],
deps = [
- "//cabe/go/backends",
"//go/skerr",
"//go/swarming",
"//go/util",
+ "//pinpoint/go/backends",
"//pinpoint/go/bot_configs",
"@org_chromium_go_luci//common/api/swarming/swarming/v1:swarming",
],
@@ -30,6 +30,7 @@
"//go/skerr",
"//go/swarming",
"//go/swarming/mocks",
+ "//pinpoint/go/backends",
"//pinpoint/go/bot_configs",
"@com_github_smartystreets_goconvey//convey",
"@com_github_stretchr_testify//assert",
diff --git a/pinpoint/go/run_benchmark/run_benchmark.go b/pinpoint/go/run_benchmark/run_benchmark.go
index 812c167..294bdb7 100644
--- a/pinpoint/go/run_benchmark/run_benchmark.go
+++ b/pinpoint/go/run_benchmark/run_benchmark.go
@@ -10,20 +10,17 @@
"context"
"fmt"
"slices"
- "time"
swarmingV1 "go.chromium.org/luci/common/api/swarming/swarming/v1"
- "go.skia.org/infra/cabe/go/backends"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/swarming"
+ "go.skia.org/infra/pinpoint/go/backends"
"go.skia.org/infra/pinpoint/go/bot_configs"
)
// A RunBenchmarkRequest defines the request arguments of the performance test
// to swarming.
type RunBenchmarkRequest struct {
- // the Swarming client
- Client swarming.ApiClient
// the Pinpoint job id
JobID string
// the swarming instance and cas digest hash and bytes location for the build
@@ -64,103 +61,6 @@
// NullFields: omitted
}
-var runningStates = []string{
- swarming.TASK_STATE_PENDING,
- swarming.TASK_STATE_RUNNING,
-}
-
-// DialSwarming dials a swarming API client.
-// TODO(sunxiaodi@) migrate swarming components to backends/ folder
-func DialSwarming(ctx context.Context) (swarming.ApiClient, error) {
- return backends.DialSwarming(ctx)
-}
-
-// ListPinpointTasks lists the Pinpoint swarming tasks of a given
-// job and build.
-func ListPinpointTasks(ctx context.Context, client swarming.ApiClient, req RunBenchmarkRequest) ([]string, error) {
- if req.JobID == "" {
- return nil, skerr.Fmt("Cannot list tasks because request is missing JobID")
- }
- if req.Build == nil || req.Build.Digest == nil {
- return nil, skerr.Fmt("Cannot list tasks because request is missing cas isolate")
- }
- start := time.Now().Add(-24 * time.Hour)
- tags := []string{
- fmt.Sprintf("pinpoint_job_id:%s", req.JobID),
- fmt.Sprintf("build_cas:%s/%d", req.Build.Digest.Hash, req.Build.Digest.SizeBytes),
- }
- tasks, err := client.ListTasks(ctx, start, time.Now(), tags, "")
- if err != nil {
- return nil, fmt.Errorf("error retrieving tasks %s", err)
- }
- taskIDs := make([]string, len(tasks))
- for i, t := range tasks {
- taskIDs[i] = t.TaskId
- }
- return taskIDs, nil
-}
-
-// GetStatus gets the current status of a swarming task.
-func GetStatus(ctx context.Context, client swarming.ApiClient, taskID string) (string, error) {
- res, err := client.GetTask(ctx, taskID, false)
- if err != nil {
- return "", skerr.Fmt("failed to get swarming task ID %s due to err: %v", taskID, err)
- }
- return res.State, nil
-}
-
-// GetStates returns the state of each task in a list of tasks.
-func GetStates(ctx context.Context, client swarming.ApiClient, taskIDs []string) ([]string, error) {
- return client.GetStates(ctx, taskIDs)
-}
-
-// IsTaskStateFinished checks if a swarming task state is finished
-func IsTaskStateFinished(state string) (bool, error) {
- if !slices.Contains(swarming.TASK_STATES, state) {
- return false, skerr.Fmt("Not a valid swarming task state %s", state)
- }
- return !slices.Contains(runningStates, state), nil
-}
-
-// IsTaskStateSuccess checks if a swarming task is successful or not. Makes no assumptions
-// about whether it is still running
-func IsTaskStateSuccess(state string) bool {
- return state == swarming.TASK_STATE_COMPLETED
-}
-
-func CancelTasks(ctx context.Context, client swarming.ApiClient, taskIDs []string) error {
- for _, id := range taskIDs {
- err := client.CancelTask(ctx, id, true)
- if err != nil {
- return skerr.Fmt("Could not cancel task %s due to %s", id, err)
- }
- }
- return nil
-}
-
-// GetCASOutput returns the CAS output of a swarming task in the
-// form of a RBE CAS hash.
-// GetCASOutput assumes the task is finished, or it throws an error.
-func GetCASOutput(ctx context.Context, client swarming.ApiClient, taskID string) (
- *swarmingV1.SwarmingRpcsCASReference, error) {
- task, err := client.GetTask(ctx, taskID, false)
- if err != nil {
- return nil, fmt.Errorf("error retrieving result of task %s: %s", taskID, err)
- }
- if task.State != "COMPLETED" {
- return nil, fmt.Errorf("cannot get result of task %s because it is %s and not COMPLETED", taskID, task.State)
- }
- rbe := &swarmingV1.SwarmingRpcsCASReference{
- CasInstance: task.CasOutputRoot.CasInstance,
- Digest: &swarmingV1.SwarmingRpcsDigest{
- Hash: task.CasOutputRoot.Digest.Hash,
- SizeBytes: task.CasOutputRoot.Digest.SizeBytes,
- },
- }
-
- return rbe, nil
-}
-
func createSwarmingReq(req RunBenchmarkRequest) (
*swarmingV1.SwarmingRpcsNewTaskRequest, error) {
// TODO(b/318863812): add mapping from device + benchmark to the specific run test
@@ -211,14 +111,32 @@
return &swarmingReq, nil
}
+var runningStates = []string{
+ swarming.TASK_STATE_PENDING,
+ swarming.TASK_STATE_RUNNING,
+}
+
+// IsTaskStateFinished checks if a swarming task state is finished
+func IsTaskStateFinished(state string) (bool, error) {
+ if !slices.Contains(swarming.TASK_STATES, state) {
+ return false, skerr.Fmt("Not a valid swarming task state %s", state)
+ }
+ return !slices.Contains(runningStates, state), nil
+}
+
+// IsTaskStateSuccess checks if a swarming task state is finished
+func IsTaskStateSuccess(state string) bool {
+ return state == swarming.TASK_STATE_COMPLETED
+}
+
// Run schedules a swarming task to run the RunBenchmarkRequest.
-func Run(ctx context.Context, client swarming.ApiClient, req RunBenchmarkRequest) (string, error) {
+func Run(ctx context.Context, sc backends.SwarmingClient, req RunBenchmarkRequest) (string, error) {
swarmingReq, err := createSwarmingReq(req)
if err != nil {
return "", skerr.Wrapf(err, "Could not create run test request")
}
- metadataResp, err := client.TriggerTask(ctx, swarmingReq)
+ metadataResp, err := sc.TriggerTask(ctx, swarmingReq)
if err != nil {
return "", skerr.Fmt("trigger task %v\ncaused error: %s", req, err)
}
diff --git a/pinpoint/go/run_benchmark/run_benchmark_test.go b/pinpoint/go/run_benchmark/run_benchmark_test.go
index 2de0c25..f825764 100644
--- a/pinpoint/go/run_benchmark/run_benchmark_test.go
+++ b/pinpoint/go/run_benchmark/run_benchmark_test.go
@@ -12,6 +12,7 @@
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/swarming"
"go.skia.org/infra/go/swarming/mocks"
+ "go.skia.org/infra/pinpoint/go/backends"
"go.skia.org/infra/pinpoint/go/bot_configs"
)
@@ -30,104 +31,12 @@
}
var expectedErr = skerr.Fmt("some error")
-func TestListPinpointTasks(t *testing.T) {
- ctx := context.Background()
- mockClient := mocks.NewApiClient(t)
-
- Convey(`OK`, t, func() {
- Convey(`Tasks found`, func() {
- mockClient.On("ListTasks", ctx, mock.Anything, mock.Anything,
- mock.Anything, mock.Anything).
- Return([]*swarmingV1.SwarmingRpcsTaskRequestMetadata{
- {
- TaskId: "123",
- },
- {
- TaskId: "456",
- },
- }, nil).Once()
- taskIds, err := ListPinpointTasks(ctx, mockClient, req)
- So(err, ShouldBeNil)
- So(taskIds, ShouldEqual, []string{"123", "456"})
- })
- Convey(`No tasks found`, func() {
- mockClient.On("ListTasks", ctx, mock.Anything, mock.Anything,
- mock.Anything, mock.Anything).
- Return([]*swarmingV1.SwarmingRpcsTaskRequestMetadata{}, nil).Once()
- taskIds, err := ListPinpointTasks(ctx, mockClient, req)
- So(err, ShouldBeNil)
- So(taskIds, ShouldBeEmpty)
- })
- })
- Convey(`Return error`, t, func() {
- Convey(`Missing inputs`, func() {
- req := RunBenchmarkRequest{}
- taskIds, err := ListPinpointTasks(ctx, mockClient, req)
- So(taskIds, ShouldBeNil)
- So(err, ShouldErrLike, "Cannot list tasks because request is missing JobID")
- req.JobID = "1"
- taskIds, err = ListPinpointTasks(ctx, mockClient, req)
- So(taskIds, ShouldBeNil)
- So(err, ShouldErrLike, "Cannot list tasks because request is missing cas isolate")
- })
- Convey(`Client failure`, func() {
- mockClient.On("ListTasks", ctx, mock.Anything, mock.Anything,
- mock.Anything, mock.Anything).
- Return([]*swarmingV1.SwarmingRpcsTaskRequestMetadata{}, expectedErr).Once()
- taskIds, err := ListPinpointTasks(ctx, mockClient, req)
- So(taskIds, ShouldBeNil)
- So(err, ShouldErrLike, expectedErr)
- })
- })
-}
-
-func TestGetCasOutput(t *testing.T) {
- ctx := context.Background()
- mockClient := mocks.NewApiClient(t)
-
- Convey(`OK`, t, func() {
- Convey(`CAS found`, func() {
- mockClient.On("GetTask", ctx, mock.Anything, mock.Anything).
- Return(&swarmingV1.SwarmingRpcsTaskResult{
- State: "COMPLETED",
- CasOutputRoot: &swarmingV1.SwarmingRpcsCASReference{
- CasInstance: "instance",
- Digest: &swarmingV1.SwarmingRpcsDigest{
- Hash: "hash",
- SizeBytes: 0,
- },
- },
- }, nil).Once()
- rbe, err := GetCASOutput(ctx, mockClient, "taskId")
- So(err, ShouldBeNil)
- So(rbe.CasInstance, ShouldEqual, "instance")
- So(rbe.Digest.Hash, ShouldEqual, "hash")
- So(rbe.Digest.SizeBytes, ShouldEqual, 0)
- })
- })
- Convey(`Return error`, t, func() {
- Convey(`Task not completed`, func() {
- mockClient.On("GetTask", ctx, mock.Anything, mock.Anything).
- Return(&swarmingV1.SwarmingRpcsTaskResult{
- State: "Not_Completed",
- }, nil).Once()
- rbe, err := GetCASOutput(ctx, mockClient, "taskId")
- So(err, ShouldErrLike, "cannot get result of task")
- So(rbe, ShouldBeNil)
- })
- Convey(`Client failure`, func() {
- mockClient.On("GetTask", ctx, mock.Anything, mock.Anything).
- Return(nil, expectedErr).Once()
- taskIds, err := GetCASOutput(ctx, mockClient, "taskId")
- So(taskIds, ShouldBeNil)
- So(err, ShouldErrLike, expectedErr)
- })
- })
-}
-
func TestRun(t *testing.T) {
ctx := context.Background()
mockClient := mocks.NewApiClient(t)
+ sc := &backends.SwarmingClientImpl{
+ ApiClient: mockClient,
+ }
Convey(`OK`, t, func() {
cfg, err := bot_configs.GetBotConfig("linux-perf", true)
@@ -137,7 +46,7 @@
Return(&swarmingV1.SwarmingRpcsTaskRequestMetadata{
TaskId: "123",
}, nil).Once()
- taskId, err := Run(ctx, mockClient, req)
+ taskId, err := Run(ctx, sc, req)
So(err, ShouldBeNil)
So(taskId, ShouldEqual, "123")
})
@@ -149,7 +58,7 @@
Return(&swarmingV1.SwarmingRpcsTaskRequestMetadata{
TaskId: "123",
}, expectedErr).Once()
- taskId, err := Run(ctx, mockClient, req)
+ taskId, err := Run(ctx, sc, req)
So(taskId, ShouldBeEmpty)
So(err, ShouldErrLike, expectedErr)
})
diff --git a/pinpoint/go/workflows/internal/BUILD.bazel b/pinpoint/go/workflows/internal/BUILD.bazel
index ca8b3ed..1fd88ec 100644
--- a/pinpoint/go/workflows/internal/BUILD.bazel
+++ b/pinpoint/go/workflows/internal/BUILD.bazel
@@ -11,6 +11,7 @@
visibility = ["//pinpoint/go/workflows:__subpackages__"],
deps = [
"//go/skerr",
+ "//pinpoint/go/backends",
"//pinpoint/go/build_chrome",
"//pinpoint/go/run_benchmark",
"//pinpoint/go/workflows",
diff --git a/pinpoint/go/workflows/internal/run_benchmark.go b/pinpoint/go/workflows/internal/run_benchmark.go
index ed4c603..aa12df2 100644
--- a/pinpoint/go/workflows/internal/run_benchmark.go
+++ b/pinpoint/go/workflows/internal/run_benchmark.go
@@ -7,6 +7,7 @@
swarmingV1 "go.chromium.org/luci/common/api/swarming/swarming/v1"
"go.skia.org/infra/go/skerr"
+ "go.skia.org/infra/pinpoint/go/backends"
"go.skia.org/infra/pinpoint/go/run_benchmark"
"go.skia.org/infra/pinpoint/go/workflows"
"go.temporal.io/sdk/activity"
@@ -75,7 +76,7 @@
func (rba *RunBenchmarkActivity) ScheduleTaskActivity(ctx context.Context, params workflows.RunBenchmarkParams) (string, error) {
logger := activity.GetLogger(ctx)
- sc, err := run_benchmark.DialSwarming(ctx)
+ sc, err := backends.NewSwarmingClient(ctx, backends.DefaultSwarmingServiceAddress)
if err != nil {
logger.Error("Failed to connect to swarming client:", err)
return "", skerr.Wrap(err)
@@ -89,7 +90,7 @@
func (rba *RunBenchmarkActivity) WaitTaskFinishedActivity(ctx context.Context, taskID string) (string, error) {
logger := activity.GetLogger(ctx)
- sc, err := run_benchmark.DialSwarming(ctx)
+ sc, err := backends.NewSwarmingClient(ctx, backends.DefaultSwarmingServiceAddress)
if err != nil {
logger.Error("Failed to connect to swarming client:", err)
return "", skerr.Wrap(err)
@@ -102,7 +103,7 @@
case <-ctx.Done():
return "", ctx.Err()
default:
- state, err := run_benchmark.GetStatus(ctx, sc, taskID)
+ state, err := sc.GetStatus(ctx, taskID)
if err != nil {
logger.Error("Failed to get task status:", err, "remaining retries:", failureRetries)
failureRetries -= 1
@@ -127,13 +128,13 @@
func (rba *RunBenchmarkActivity) RetrieveCASActivity(ctx context.Context, taskID string) (*swarmingV1.SwarmingRpcsCASReference, error) {
logger := activity.GetLogger(ctx)
- sc, err := run_benchmark.DialSwarming(ctx)
+ sc, err := backends.NewSwarmingClient(ctx, backends.DefaultSwarmingServiceAddress)
if err != nil {
logger.Error("Failed to connect to swarming client:", err)
return nil, skerr.Wrap(err)
}
- cas, err := run_benchmark.GetCASOutput(ctx, sc, taskID)
+ cas, err := sc.GetCASOutput(ctx, taskID)
if err != nil {
logger.Error("Failed to retrieve CAS:", err)
return nil, err