[attest] Add rate limiting and caching of verification results
The server is hitting quota limits. Requests block if the rate limit
is exceeded.
Bug: b/432211966
Change-Id: I8c390bdcd6449dbcb2eaaaef6ce0a291b27740c8
Reviewed-on: https://skia-review.googlesource.com/c/buildbot/+/1029496
Auto-Submit: Eric Boren <borenet@google.com>
Reviewed-by: Kaylee Lubick <kjlubick@google.com>
diff --git a/attest/go/attest/BUILD.bazel b/attest/go/attest/BUILD.bazel
index 24c85f1..868fcac 100644
--- a/attest/go/attest/BUILD.bazel
+++ b/attest/go/attest/BUILD.bazel
@@ -8,9 +8,11 @@
deps = [
"//attest/go/attestation",
"//attest/go/types",
+ "//go/cache/local",
"//go/common",
"//go/httputils",
"//go/sklog",
+ "@org_golang_x_time//rate",
],
)
diff --git a/attest/go/attest/main.go b/attest/go/attest/main.go
index e1dcfa9..bc2de72 100644
--- a/attest/go/attest/main.go
+++ b/attest/go/attest/main.go
@@ -4,21 +4,26 @@
"context"
"flag"
"net/http"
+ "time"
"go.skia.org/infra/attest/go/attestation"
"go.skia.org/infra/attest/go/types"
+ local_cache "go.skia.org/infra/go/cache/local"
"go.skia.org/infra/go/common"
"go.skia.org/infra/go/httputils"
"go.skia.org/infra/go/sklog"
+ "golang.org/x/time/rate"
)
var (
// Flags.
- attestor = flag.String("attestor", "", "Fully-qualified resource name of the attestor (e.g., 'projects/my-project/attestors/my-attestor')")
- host = flag.String("host", "localhost", "HTTP service host")
- port = flag.String("port", ":8000", "HTTP service port (e.g., ':8000')")
- promPort = flag.String("prom_port", ":20000", "Metrics service address (e.g., ':10110')")
- local = flag.Bool("local", false, "Running locally if true. As opposed to in production.")
+ attestor = flag.String("attestor", "", "Fully-qualified resource name of the attestor (e.g., 'projects/my-project/attestors/my-attestor')")
+ cacheSize = flag.Int("cache_size", 10000, "Maximum number of verification results to store in the in-memory cache.")
+ maxRequestsPerMinute = flag.Int("max_requests_per_minute", 50, "Per-minute rate limit on calls to attestation APIs.")
+ host = flag.String("host", "localhost", "HTTP service host")
+ port = flag.String("port", ":8000", "HTTP service port (e.g., ':8000')")
+ promPort = flag.String("prom_port", ":20000", "Metrics service address (e.g., ':10110')")
+ local = flag.Bool("local", false, "Running locally if true. As opposed to in production.")
)
func main() {
@@ -34,10 +39,32 @@
}
ctx := context.Background()
- client, err := attestation.NewClient(ctx, *attestor)
+
+ var client types.Client
+ var err error
+ client, err = attestation.NewClient(ctx, *attestor)
if err != nil {
sklog.Fatal(err)
}
+
+ if *maxRequestsPerMinute > 0 {
+ // Our quota is based on requests per minute, but rate.Limiter uses a
+ // per-second limit. Set the maximum burst to be our per-minute limit
+ // (so that we can use our entire per-minute quota immediately if
+ // necessary) and compute the per-second limit.
+ perSecondLimit := (float64(*maxRequestsPerMinute) / float64(time.Minute)) * float64(time.Second)
+ rl := rate.NewLimiter(rate.Limit(perSecondLimit), *maxRequestsPerMinute)
+ client = types.WithRateLimiter(client, rl)
+ }
+
+ if *cacheSize > 0 {
+ cache, err := local_cache.New(*cacheSize)
+ if err != nil {
+ sklog.Fatal(err)
+ }
+ client = types.WithCache(client, cache)
+ }
+
server := types.NewServer(client)
h := httputils.LoggingRequestResponse(server)
diff --git a/attest/go/types/BUILD.bazel b/attest/go/types/BUILD.bazel
index 2032f48..dc72b56 100644
--- a/attest/go/types/BUILD.bazel
+++ b/attest/go/types/BUILD.bazel
@@ -7,9 +7,11 @@
importpath = "go.skia.org/infra/attest/go/types",
visibility = ["//visibility:public"],
deps = [
+ "//go/cache",
"//go/skerr",
"//go/sklog",
"//go/util",
+ "@org_golang_x_time//rate",
],
)
@@ -19,6 +21,7 @@
embed = [":types"],
deps = [
"//attest/go/types/mocks",
+ "//go/cache/mock",
"//go/testutils",
"@com_github_stretchr_testify//require",
],
diff --git a/attest/go/types/types.go b/attest/go/types/types.go
index 339cd03..f71c2b7 100644
--- a/attest/go/types/types.go
+++ b/attest/go/types/types.go
@@ -8,9 +8,11 @@
"net/http"
"regexp"
+ "go.skia.org/infra/go/cache"
"go.skia.org/infra/go/skerr"
"go.skia.org/infra/go/sklog"
"go.skia.org/infra/go/util"
+ "golang.org/x/time/rate"
)
const (
@@ -55,6 +57,22 @@
Verify(ctx context.Context, imageID string) (bool, error)
}
+// VerifyFunc is a function which finds and validates the attestation for the
+// given Docker image ID. It returns true if any attestation exists with a valid
+// signature and false if no such attestation exists, or an error if any of the
+// required API calls failed.
+//
+// VerifyFunc is an adapter which allows the use of ordinary functions as
+// Client implementations.
+type VerifyFunc func(ctx context.Context, imageID string) (bool, error)
+
+// Verify implements Client.
+func (f VerifyFunc) Verify(ctx context.Context, imageID string) (bool, error) {
+ return f(ctx, imageID)
+}
+
+var _ Client = VerifyFunc(nil)
+
// HttpClient implements Client by communicating with the attest service.
type HttpClient struct {
host string
@@ -120,6 +138,51 @@
var _ Client = &HttpClient{}
+// cache.Cache uses strings for keys and values.
+const (
+ cachedValueTrue = "true"
+ cachedValueFalse = "false"
+)
+
+// WithCache returns a Client which uses the given cache.
+func WithCache(wrapped Client, cache cache.Cache) VerifyFunc {
+ return func(ctx context.Context, imageID string) (bool, error) {
+ cachedValue, err := cache.GetValue(ctx, imageID)
+ if err != nil {
+ return false, skerr.Wrapf(err, "failed to retrieve cached value for %s", imageID)
+ }
+ switch cachedValue {
+ case cachedValueTrue:
+ return true, nil
+ case cachedValueFalse:
+ return false, nil
+ default:
+ verified, err := wrapped.Verify(ctx, imageID)
+ if err != nil {
+ return false, skerr.Wrap(err)
+ }
+ cachedValue = cachedValueFalse
+ if verified {
+ cachedValue = cachedValueTrue
+ }
+ if err := cache.SetValue(ctx, imageID, cachedValue); err != nil {
+ return false, skerr.Wrapf(err, "failed to set cached value for %s", imageID)
+ }
+ return verified, nil
+ }
+ }
+}
+
+// WithRateLimiter returns a Client which uses the given rate.Limiter.
+func WithRateLimiter(wrapped Client, lim *rate.Limiter) VerifyFunc {
+ return func(ctx context.Context, imageID string) (bool, error) {
+ if err := lim.Wait(ctx); err != nil {
+ return false, skerr.Wrap(err)
+ }
+ return wrapped.Verify(ctx, imageID)
+ }
+}
+
// Server wraps a Client and serves HTTP requests.
type Server struct {
wrappedClient Client
diff --git a/attest/go/types/types_test.go b/attest/go/types/types_test.go
index c431510..1f483e6 100644
--- a/attest/go/types/types_test.go
+++ b/attest/go/types/types_test.go
@@ -1,11 +1,13 @@
package types
import (
+ "context"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
"go.skia.org/infra/attest/go/types/mocks"
+ cache_mocks "go.skia.org/infra/go/cache/mock"
"go.skia.org/infra/go/testutils"
)
@@ -85,3 +87,67 @@
require.True(t, IsErrBadImageFormat(err))
require.False(t, verified)
}
+
+func TestClientWithCache_NotCached_Verified(t *testing.T) {
+ ctx := context.Background()
+ mockClient := &mocks.Client{}
+ mockCache := &cache_mocks.Cache{}
+ client := WithCache(mockClient, mockCache)
+
+ const imageID = "fake-image-id"
+ mockCache.On("GetValue", testutils.AnyContext, imageID).Return("", nil).Once()
+ mockClient.On("Verify", testutils.AnyContext, imageID).Return(true, nil).Once()
+ mockCache.On("SetValue", testutils.AnyContext, imageID, cachedValueTrue).Return(nil).Once()
+ verified, err := client.Verify(ctx, imageID)
+ require.NoError(t, err)
+ require.True(t, verified)
+ mockCache.AssertExpectations(t)
+ mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_Cached_Verified(t *testing.T) {
+ ctx := context.Background()
+ mockClient := &mocks.Client{}
+ mockCache := &cache_mocks.Cache{}
+ client := WithCache(mockClient, mockCache)
+
+ const imageID = "fake-image-id"
+ mockCache.On("GetValue", testutils.AnyContext, imageID).Return(cachedValueTrue, nil).Once()
+ verified, err := client.Verify(ctx, imageID)
+ require.NoError(t, err)
+ require.True(t, verified)
+ mockCache.AssertExpectations(t)
+ mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_NotCached_NotVerified(t *testing.T) {
+ ctx := context.Background()
+ mockClient := &mocks.Client{}
+ mockCache := &cache_mocks.Cache{}
+ client := WithCache(mockClient, mockCache)
+
+ const imageID = "fake-image-id"
+ mockCache.On("GetValue", testutils.AnyContext, imageID).Return("", nil).Once()
+ mockClient.On("Verify", testutils.AnyContext, imageID).Return(false, nil).Once()
+ mockCache.On("SetValue", testutils.AnyContext, imageID, cachedValueFalse).Return(nil).Once()
+ verified, err := client.Verify(ctx, imageID)
+ require.NoError(t, err)
+ require.False(t, verified)
+ mockCache.AssertExpectations(t)
+ mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_Cached_NotVerified(t *testing.T) {
+ ctx := context.Background()
+ mockClient := &mocks.Client{}
+ mockCache := &cache_mocks.Cache{}
+ client := WithCache(mockClient, mockCache)
+
+ const imageID = "fake-image-id"
+ mockCache.On("GetValue", testutils.AnyContext, imageID).Return(cachedValueFalse, nil).Once()
+ verified, err := client.Verify(ctx, imageID)
+ require.NoError(t, err)
+ require.False(t, verified)
+ mockCache.AssertExpectations(t)
+ mockClient.AssertExpectations(t)
+}
diff --git a/kube/cmd/k8s-config-presubmit/main.go b/kube/cmd/k8s-config-presubmit/main.go
index a008d9d..209b60d 100644
--- a/kube/cmd/k8s-config-presubmit/main.go
+++ b/kube/cmd/k8s-config-presubmit/main.go
@@ -360,10 +360,10 @@
for image := range images {
verified, err := client.Verify(ctx, image)
if err != nil {
- logf(ctx, "Failed to verify image %s: %s", image, err)
+ logf(ctx, "Failed to verify image %s: %s\n", image, err)
return false
} else if !verified {
- logf(ctx, "Image %s does not have a valid attestation", image)
+ logf(ctx, "Image %s does not have a valid attestation\n", image)
ok = false
}
}