[attest] Add rate limiting and caching of verification results

The server is hitting quota limits. Requests block if the rate limit
is exceeded.

Bug: b/432211966
Change-Id: I8c390bdcd6449dbcb2eaaaef6ce0a291b27740c8
Reviewed-on: https://skia-review.googlesource.com/c/buildbot/+/1029496
Auto-Submit: Eric Boren <borenet@google.com>
Reviewed-by: Kaylee Lubick <kjlubick@google.com>
diff --git a/attest/go/attest/BUILD.bazel b/attest/go/attest/BUILD.bazel
index 24c85f1..868fcac 100644
--- a/attest/go/attest/BUILD.bazel
+++ b/attest/go/attest/BUILD.bazel
@@ -8,9 +8,11 @@
     deps = [
         "//attest/go/attestation",
         "//attest/go/types",
+        "//go/cache/local",
         "//go/common",
         "//go/httputils",
         "//go/sklog",
+        "@org_golang_x_time//rate",
     ],
 )
 
diff --git a/attest/go/attest/main.go b/attest/go/attest/main.go
index e1dcfa9..bc2de72 100644
--- a/attest/go/attest/main.go
+++ b/attest/go/attest/main.go
@@ -4,21 +4,26 @@
 	"context"
 	"flag"
 	"net/http"
+	"time"
 
 	"go.skia.org/infra/attest/go/attestation"
 	"go.skia.org/infra/attest/go/types"
+	local_cache "go.skia.org/infra/go/cache/local"
 	"go.skia.org/infra/go/common"
 	"go.skia.org/infra/go/httputils"
 	"go.skia.org/infra/go/sklog"
+	"golang.org/x/time/rate"
 )
 
 var (
 	// Flags.
-	attestor = flag.String("attestor", "", "Fully-qualified resource name of the attestor (e.g., 'projects/my-project/attestors/my-attestor')")
-	host     = flag.String("host", "localhost", "HTTP service host")
-	port     = flag.String("port", ":8000", "HTTP service port (e.g., ':8000')")
-	promPort = flag.String("prom_port", ":20000", "Metrics service address (e.g., ':10110')")
-	local    = flag.Bool("local", false, "Running locally if true. As opposed to in production.")
+	attestor             = flag.String("attestor", "", "Fully-qualified resource name of the attestor (e.g., 'projects/my-project/attestors/my-attestor')")
+	cacheSize            = flag.Int("cache_size", 10000, "Maximum number of verification results to store in the in-memory cache.")
+	maxRequestsPerMinute = flag.Int("max_requests_per_minute", 50, "Per-minute rate limit on calls to attestation APIs.")
+	host                 = flag.String("host", "localhost", "HTTP service host")
+	port                 = flag.String("port", ":8000", "HTTP service port (e.g., ':8000')")
+	promPort             = flag.String("prom_port", ":20000", "Metrics service address (e.g., ':10110')")
+	local                = flag.Bool("local", false, "Running locally if true. As opposed to in production.")
 )
 
 func main() {
@@ -34,10 +39,32 @@
 	}
 
 	ctx := context.Background()
-	client, err := attestation.NewClient(ctx, *attestor)
+
+	var client types.Client
+	var err error
+	client, err = attestation.NewClient(ctx, *attestor)
 	if err != nil {
 		sklog.Fatal(err)
 	}
+
+	if *maxRequestsPerMinute > 0 {
+		// Our quota is based on requests per minute, but rate.Limiter uses a
+		// per-second limit. Set the maximum burst to be our per-minute limit
+		// (so that we can use our entire per-minute quota immediately if
+		// necessary) and compute the per-second limit.
+		perSecondLimit := (float64(*maxRequestsPerMinute) / float64(time.Minute)) * float64(time.Second)
+		rl := rate.NewLimiter(rate.Limit(perSecondLimit), *maxRequestsPerMinute)
+		client = types.WithRateLimiter(client, rl)
+	}
+
+	if *cacheSize > 0 {
+		cache, err := local_cache.New(*cacheSize)
+		if err != nil {
+			sklog.Fatal(err)
+		}
+		client = types.WithCache(client, cache)
+	}
+
 	server := types.NewServer(client)
 
 	h := httputils.LoggingRequestResponse(server)
diff --git a/attest/go/types/BUILD.bazel b/attest/go/types/BUILD.bazel
index 2032f48..dc72b56 100644
--- a/attest/go/types/BUILD.bazel
+++ b/attest/go/types/BUILD.bazel
@@ -7,9 +7,11 @@
     importpath = "go.skia.org/infra/attest/go/types",
     visibility = ["//visibility:public"],
     deps = [
+        "//go/cache",
         "//go/skerr",
         "//go/sklog",
         "//go/util",
+        "@org_golang_x_time//rate",
     ],
 )
 
@@ -19,6 +21,7 @@
     embed = [":types"],
     deps = [
         "//attest/go/types/mocks",
+        "//go/cache/mock",
         "//go/testutils",
         "@com_github_stretchr_testify//require",
     ],
diff --git a/attest/go/types/types.go b/attest/go/types/types.go
index 339cd03..f71c2b7 100644
--- a/attest/go/types/types.go
+++ b/attest/go/types/types.go
@@ -8,9 +8,11 @@
 	"net/http"
 	"regexp"
 
+	"go.skia.org/infra/go/cache"
 	"go.skia.org/infra/go/skerr"
 	"go.skia.org/infra/go/sklog"
 	"go.skia.org/infra/go/util"
+	"golang.org/x/time/rate"
 )
 
 const (
@@ -55,6 +57,22 @@
 	Verify(ctx context.Context, imageID string) (bool, error)
 }
 
+// VerifyFunc is a function which finds and validates the attestation for the
+// given Docker image ID. It returns true if any attestation exists with a valid
+// signature and false if no such attestation exists, or an error if any of the
+// required API calls failed.
+//
+// VerifyFunc is an adapter which allows the use of ordinary functions as
+// Client implementations.
+type VerifyFunc func(ctx context.Context, imageID string) (bool, error)
+
+// Verify implements Client.
+func (f VerifyFunc) Verify(ctx context.Context, imageID string) (bool, error) {
+	return f(ctx, imageID)
+}
+
+var _ Client = VerifyFunc(nil)
+
 // HttpClient implements Client by communicating with the attest service.
 type HttpClient struct {
 	host string
@@ -120,6 +138,51 @@
 
 var _ Client = &HttpClient{}
 
+// cache.Cache uses strings for keys and values.
+const (
+	cachedValueTrue  = "true"
+	cachedValueFalse = "false"
+)
+
+// WithCache returns a Client which uses the given cache.
+func WithCache(wrapped Client, cache cache.Cache) VerifyFunc {
+	return func(ctx context.Context, imageID string) (bool, error) {
+		cachedValue, err := cache.GetValue(ctx, imageID)
+		if err != nil {
+			return false, skerr.Wrapf(err, "failed to retrieve cached value for %s", imageID)
+		}
+		switch cachedValue {
+		case cachedValueTrue:
+			return true, nil
+		case cachedValueFalse:
+			return false, nil
+		default:
+			verified, err := wrapped.Verify(ctx, imageID)
+			if err != nil {
+				return false, skerr.Wrap(err)
+			}
+			cachedValue = cachedValueFalse
+			if verified {
+				cachedValue = cachedValueTrue
+			}
+			if err := cache.SetValue(ctx, imageID, cachedValue); err != nil {
+				return false, skerr.Wrapf(err, "failed to set cached value for %s", imageID)
+			}
+			return verified, nil
+		}
+	}
+}
+
+// WithRateLimiter returns a Client which uses the given rate.Limiter.
+func WithRateLimiter(wrapped Client, lim *rate.Limiter) VerifyFunc {
+	return func(ctx context.Context, imageID string) (bool, error) {
+		if err := lim.Wait(ctx); err != nil {
+			return false, skerr.Wrap(err)
+		}
+		return wrapped.Verify(ctx, imageID)
+	}
+}
+
 // Server wraps a Client and serves HTTP requests.
 type Server struct {
 	wrappedClient Client
diff --git a/attest/go/types/types_test.go b/attest/go/types/types_test.go
index c431510..1f483e6 100644
--- a/attest/go/types/types_test.go
+++ b/attest/go/types/types_test.go
@@ -1,11 +1,13 @@
 package types
 
 import (
+	"context"
 	"net/http/httptest"
 	"testing"
 
 	"github.com/stretchr/testify/require"
 	"go.skia.org/infra/attest/go/types/mocks"
+	cache_mocks "go.skia.org/infra/go/cache/mock"
 	"go.skia.org/infra/go/testutils"
 )
 
@@ -85,3 +87,67 @@
 	require.True(t, IsErrBadImageFormat(err))
 	require.False(t, verified)
 }
+
+func TestClientWithCache_NotCached_Verified(t *testing.T) {
+	ctx := context.Background()
+	mockClient := &mocks.Client{}
+	mockCache := &cache_mocks.Cache{}
+	client := WithCache(mockClient, mockCache)
+
+	const imageID = "fake-image-id"
+	mockCache.On("GetValue", testutils.AnyContext, imageID).Return("", nil).Once()
+	mockClient.On("Verify", testutils.AnyContext, imageID).Return(true, nil).Once()
+	mockCache.On("SetValue", testutils.AnyContext, imageID, cachedValueTrue).Return(nil).Once()
+	verified, err := client.Verify(ctx, imageID)
+	require.NoError(t, err)
+	require.True(t, verified)
+	mockCache.AssertExpectations(t)
+	mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_Cached_Verified(t *testing.T) {
+	ctx := context.Background()
+	mockClient := &mocks.Client{}
+	mockCache := &cache_mocks.Cache{}
+	client := WithCache(mockClient, mockCache)
+
+	const imageID = "fake-image-id"
+	mockCache.On("GetValue", testutils.AnyContext, imageID).Return(cachedValueTrue, nil).Once()
+	verified, err := client.Verify(ctx, imageID)
+	require.NoError(t, err)
+	require.True(t, verified)
+	mockCache.AssertExpectations(t)
+	mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_NotCached_NotVerified(t *testing.T) {
+	ctx := context.Background()
+	mockClient := &mocks.Client{}
+	mockCache := &cache_mocks.Cache{}
+	client := WithCache(mockClient, mockCache)
+
+	const imageID = "fake-image-id"
+	mockCache.On("GetValue", testutils.AnyContext, imageID).Return("", nil).Once()
+	mockClient.On("Verify", testutils.AnyContext, imageID).Return(false, nil).Once()
+	mockCache.On("SetValue", testutils.AnyContext, imageID, cachedValueFalse).Return(nil).Once()
+	verified, err := client.Verify(ctx, imageID)
+	require.NoError(t, err)
+	require.False(t, verified)
+	mockCache.AssertExpectations(t)
+	mockClient.AssertExpectations(t)
+}
+
+func TestClientWithCache_Cached_NotVerified(t *testing.T) {
+	ctx := context.Background()
+	mockClient := &mocks.Client{}
+	mockCache := &cache_mocks.Cache{}
+	client := WithCache(mockClient, mockCache)
+
+	const imageID = "fake-image-id"
+	mockCache.On("GetValue", testutils.AnyContext, imageID).Return(cachedValueFalse, nil).Once()
+	verified, err := client.Verify(ctx, imageID)
+	require.NoError(t, err)
+	require.False(t, verified)
+	mockCache.AssertExpectations(t)
+	mockClient.AssertExpectations(t)
+}
diff --git a/kube/cmd/k8s-config-presubmit/main.go b/kube/cmd/k8s-config-presubmit/main.go
index a008d9d..209b60d 100644
--- a/kube/cmd/k8s-config-presubmit/main.go
+++ b/kube/cmd/k8s-config-presubmit/main.go
@@ -360,10 +360,10 @@
 	for image := range images {
 		verified, err := client.Verify(ctx, image)
 		if err != nil {
-			logf(ctx, "Failed to verify image %s: %s", image, err)
+			logf(ctx, "Failed to verify image %s: %s\n", image, err)
 			return false
 		} else if !verified {
-			logf(ctx, "Image %s does not have a valid attestation", image)
+			logf(ctx, "Image %s does not have a valid attestation\n", image)
 			ok = false
 		}
 	}