Refactor reconnectingmemcached into its own package

Bug: skia:10582
Change-Id: Icb65097dee9ada7a851c002453fae8cdb0dce0fd
Reviewed-on: https://skia-review.googlesource.com/c/buildbot/+/371992
Reviewed-by: Joe Gregorio <jcgregorio@google.com>
Commit-Queue: Kevin Lubick <kjlubick@google.com>
diff --git a/go/reconnectingmemcached/client.go b/go/reconnectingmemcached/client.go
new file mode 100644
index 0000000..060c949
--- /dev/null
+++ b/go/reconnectingmemcached/client.go
@@ -0,0 +1,187 @@
+// Package reconnectingmemcached contains a wrapper around a general memcache client. It provides
+// the ability to automatically reconnect after a certain number of failures. While the connection
+// is down, its APIs quickly return, allowing clients to fallback to some other mechanism.
+// This design decision (instead of, for example, blocking until the connection is restored) is
+// because memcached is used where performance is critical, and it is probably faster for clients
+// to respond to a memcached outage like they would a cache miss.
+package reconnectingmemcached
+
+import (
+	"math/rand"
+	"sync"
+	"time"
+
+	"github.com/bradfitz/gomemcache/memcache"
+
+	"go.skia.org/infra/go/skerr"
+	"go.skia.org/infra/go/sklog"
+)
+
+// Client is a slightly modified version of the interface on *memcache.Client. Most of the methods
+// return a boolean instead of an error. That boolean indicates if the connection is up or down,
+// that is, if the return value is valid or if the calling client should use a fallback.
+type Client interface {
+	// ConnectionAvailable returns true if there is an established connection. If false is returned,
+	// it means the connection is being restored.
+	ConnectionAvailable() bool
+	// GetMulti returns a map filled with items that were in the cache. The boolean means "ok"
+	// and can be false if either there was an error or the connection is currently down.
+	GetMulti(keys []string) (map[string]*memcache.Item, bool)
+	// Ping returns an error if there is no connection or if any instance is down.
+	Ping() error
+	// Set unconditionally sets the item. It returns false if there was an error or the connection
+	// is currently down.
+	Set(i *memcache.Item) bool
+}
+
+// memcachedClient is the (partial) interface of memcache.Client, which is used for testing
+// purposes.
+type memcachedClient interface {
+	Ping() error
+	GetMulti(keys []string) (map[string]*memcache.Item, error)
+	Set(item *memcache.Item) error
+}
+
+type Options struct {
+	// Servers are the addresses of the servers that should be contacted with equal weight.
+	// See bradfitz/gomemcache/memcache.New() for more.
+	Servers []string
+	// Timeout is the socket read/write timeout. The default is 100 milliseconds.
+	Timeout time.Duration
+	// MaxIdleConnections is the maximum number of idle connections. It should be greater than or
+	// equal to the peak parallel requests. The default is 2.
+	MaxIdleConnections int
+
+	// AllowedFailuresBeforeHealing is the number of connection errors that will be tolerated
+	// before autohealing starts. NewClient replaces non-positive values with 10.
+	AllowedFailuresBeforeHealing int
+}
+
+type healingClientImpl struct {
+	opts   Options
+	client memcachedClient // if client is nil, that's a signal we are reconnecting.
+	// clientFactory is used to re-generate the client if it fails. This is due to the fact that
+	// once a *memcache.Client starts returning errors due to a bad connection, it doesn't
+	// heal itself and must be recreated.
+	clientFactory    func(Options) memcachedClient
+	clientMutex      sync.RWMutex
+	numFailures      int
+	recoveryDuration time.Duration
+}
+
+// NewClient returns a Client to talk to memcached instance(s) that will heal and re-generate
+// itself with the options provided. A non-positive AllowedFailuresBeforeHealing defaults to 10.
+func NewClient(opts Options) *healingClientImpl {
+	if opts.AllowedFailuresBeforeHealing <= 0 {
+		opts.AllowedFailuresBeforeHealing = 10
+	}
+	c := memcachedFactory(opts)
+	return &healingClientImpl{
+		opts:             opts,
+		client:           c,
+		clientFactory:    memcachedFactory,
+		recoveryDuration: 10 * time.Second,
+	}
+}
+
+// memcachedFactory returns a "real" implementation of the memcached client.
+func memcachedFactory(opts Options) memcachedClient {
+	c := memcache.New(opts.Servers...)
+	c.Timeout = opts.Timeout                 // defaults handled from memcache client code.
+	c.MaxIdleConns = opts.MaxIdleConnections // defaults handled from memcache client code.
+	return c
+}
+
+// ConnectionAvailable returns true if the client is not nil. nil means it is being healed.
+func (h *healingClientImpl) ConnectionAvailable() bool {
+	h.clientMutex.RLock()
+	defer h.clientMutex.RUnlock()
+	return h.client != nil
+}
+
+// GetMulti passes a call through to the underlying client (if available). If the connection
+// is not available or there is an error, it returns false. Otherwise it returns the value and
+// true.
+func (h *healingClientImpl) GetMulti(keys []string) (map[string]*memcache.Item, bool) {
+	h.clientMutex.RLock()
+	if h.client == nil {
+		// currently reconnecting
+		h.clientMutex.RUnlock()
+		return nil, false
+	}
+	m, err := h.client.GetMulti(keys)
+	h.clientMutex.RUnlock() // need to free up the mutex before calling maybeReload
+	if err != nil {
+		sklog.Errorf("Could not get %d keys from memcached: %s", len(keys), err)
+		h.maybeReload()
+		return nil, false
+	}
+	return m, true
+}
+
+// Ping returns an error if the connection is being restored or any error from the
+// underlying client.
+func (h *healingClientImpl) Ping() error {
+	h.clientMutex.RLock()
+	defer h.clientMutex.RUnlock()
+	if h.client == nil {
+		return skerr.Fmt("Connection down. Reconnecting.")
+	}
+	return skerr.Wrap(h.client.Ping())
+}
+
+// Set passes through to the underlying client (if available). It returns true if the set succeeded
+// or the passed in item is nil. It returns false if there was an error or the connection is down.
+func (h *healingClientImpl) Set(i *memcache.Item) bool {
+	if i == nil {
+		return true // trivially true
+	}
+	h.clientMutex.RLock()
+	if h.client == nil {
+		// currently reconnecting
+		h.clientMutex.RUnlock()
+		return false
+	}
+	err := h.client.Set(i)
+	h.clientMutex.RUnlock() // need to free up the mutex before calling maybeReload
+	if err != nil {
+		sklog.Errorf("Could not set item with key %s to memcached: %s", i.Key, err)
+		h.maybeReload()
+		return false
+	}
+	return true
+}
+
+// maybeReload will add one to the failure count. If that brings the number of failures up to
+// opts.AllowedFailuresBeforeHealing, it drops the client and reconnects in the background.
+func (h *healingClientImpl) maybeReload() {
+	h.clientMutex.Lock()
+	defer h.clientMutex.Unlock()
+	h.numFailures++
+	// Stay connected while under the configured failure limit. The h.client == nil check makes
+	// it so there's only one goroutine in charge of reconnecting.
+	if h.numFailures < h.opts.AllowedFailuresBeforeHealing || h.client == nil {
+		return
+	}
+	sklog.Infof("Initiating memcached reconnection.")
+	h.client = nil
+	go func() { // spin up a background goroutine to heal the connection.
+		for {
+			// wait for a random time between recoveryDuration and 2*recoveryDuration
+			time.Sleep(h.recoveryDuration + time.Duration(float32(h.recoveryDuration)*rand.Float32()))
+			c := h.clientFactory(h.opts)
+			if err := c.Ping(); err != nil {
+				sklog.Warningf("Cannot reconnect to memcached: %s", err)
+				continue // go back to sleep, try again later
+			}
+			h.clientMutex.Lock()
+			h.client = c
+			h.numFailures = 0
+			sklog.Infof("Reconnected to memcached")
+			h.clientMutex.Unlock()
+			return
+		}
+	}()
+}
+
+var _ Client = (*healingClientImpl)(nil)
diff --git a/go/reconnectingmemcached/client_test.go b/go/reconnectingmemcached/client_test.go
new file mode 100644
index 0000000..8524fe4
--- /dev/null
+++ b/go/reconnectingmemcached/client_test.go
@@ -0,0 +1,173 @@
+package reconnectingmemcached
+
+import (
+	"errors"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/bradfitz/gomemcache/memcache"
+	"github.com/stretchr/testify/assert"
+
+	"go.skia.org/infra/go/testutils/unittest"
+)
+
+func TestGetMulti_BooleanRepresentsConnectionState(t *testing.T) {
+	unittest.SmallTest(t)
+
+	hc, fmc := makeClientWithFakeMemcache()
+	_, ok := hc.GetMulti([]string{"whatever"})
+	assert.True(t, ok)
+
+	fmc.isDown = true
+
+	_, ok = hc.GetMulti([]string{"whatever"})
+	assert.False(t, ok)
+}
+
+func TestSet_BooleanRepresentsConnectionState(t *testing.T) {
+	unittest.SmallTest(t)
+
+	hc, fmc := makeClientWithFakeMemcache()
+	assert.True(t, hc.Set(&memcache.Item{}))
+
+	fmc.isDown = true
+
+	assert.False(t, hc.Set(&memcache.Item{}))
+}
+
+func TestPing_ReturnsErrorOnBadConnection(t *testing.T) {
+	unittest.SmallTest(t)
+
+	hc, fmc := makeClientWithFakeMemcache()
+	assert.NoError(t, hc.Ping())
+
+	fmc.isDown = true
+
+	assert.Error(t, hc.Ping())
+}
+
+func TestRecovery_ConnectionReattemptedAfterAFewSeconds(t *testing.T) {
+	unittest.LargeTest(t)
+
+	hc, fmc := makeClientWithFakeMemcache()
+	hc.numFailures = 5
+	hc.recoveryDuration = time.Second
+	fmc.isDown = true
+	// Connection hasn't been detected as down yet
+	assert.True(t, hc.ConnectionAvailable())
+
+	// Inject a few more failures than required to make sure we don't block until healed.
+	const failuresToInject = 10
+	wc := sync.WaitGroup{}
+	wc.Add(failuresToInject)
+	for i := 0; i < failuresToInject; i++ {
+		go func(isSet bool) {
+			defer wc.Done()
+			if isSet {
+				assert.False(t, hc.Set(&memcache.Item{}))
+			} else {
+				_, ok := hc.GetMulti([]string{"whatever"})
+				assert.False(t, ok)
+			}
+		}(i%2 == 0)
+	}
+	wc.Wait()
+	// Connection should be down and healing
+	assert.False(t, hc.ConnectionAvailable())
+	// Things should be returning false
+	assert.False(t, hc.Set(&memcache.Item{}))
+	_, ok := hc.GetMulti([]string{"whatever"})
+	assert.False(t, ok)
+	assert.Error(t, hc.Ping())
+
+	// Wait until we are sure the connection has been restored.
+	time.Sleep(hc.recoveryDuration*2 + time.Second)
+
+	// Connection should be back up
+	assert.True(t, hc.ConnectionAvailable())
+	// Things should be returning true again
+	assert.True(t, hc.Set(&memcache.Item{}))
+	_, ok = hc.GetMulti([]string{"whatever"})
+	assert.True(t, ok)
+	assert.NoError(t, hc.Ping())
+}
+
+func TestRecovery_HealsAfterThirdTry(t *testing.T) {
+	unittest.MediumTest(t)
+
+	const requiredRecoveryAttempts = 3
+	recoveryAttempts := 0
+
+	hc, fmc := makeClientWithFakeMemcache()
+	hc.numFailures = 0
+	hc.clientFactory = func(_ Options) memcachedClient {
+		recoveryAttempts++
+		if recoveryAttempts >= requiredRecoveryAttempts {
+			fmc.recover()
+		}
+		return fmc
+	}
+	hc.recoveryDuration = time.Millisecond
+
+	fmc.isDown = true
+
+	_, ok := hc.GetMulti([]string{"whatever"})
+	assert.False(t, ok)
+
+	time.Sleep(time.Second)
+	assert.True(t, hc.ConnectionAvailable())
+	assert.Equal(t, 3, recoveryAttempts)
+
+}
+
+func makeClientWithFakeMemcache() (*healingClientImpl, *fakeMemcacheClient) {
+	fmc := &fakeMemcacheClient{}
+	return &healingClientImpl{
+		client: fmc,
+		clientFactory: func(_ Options) memcachedClient {
+			// Call recover to signal connection restored and then return the
+			// same client to make it easy to handle assertions.
+			fmc.recover()
+			return fmc
+		},
+	}, fmc
+}
+
+type fakeMemcacheClient struct {
+	isDown bool
+	mutex  sync.RWMutex
+}
+
+func (f *fakeMemcacheClient) Ping() error {
+	f.mutex.RLock()
+	defer f.mutex.RUnlock()
+	if f.isDown {
+		return errors.New("down")
+	}
+	return nil
+}
+
+func (f *fakeMemcacheClient) GetMulti(_ []string) (map[string]*memcache.Item, error) {
+	f.mutex.RLock()
+	defer f.mutex.RUnlock()
+	if f.isDown {
+		return nil, errors.New("down")
+	}
+	return map[string]*memcache.Item{}, nil
+}
+
+func (f *fakeMemcacheClient) Set(_ *memcache.Item) error {
+	f.mutex.RLock()
+	defer f.mutex.RUnlock()
+	if f.isDown {
+		return errors.New("down")
+	}
+	return nil
+}
+
+func (f *fakeMemcacheClient) recover() {
+	f.mutex.Lock()
+	defer f.mutex.Unlock()
+	f.isDown = false
+}
diff --git a/golden/cmd/diffcalculator/diffcalculator.go b/golden/cmd/diffcalculator/diffcalculator.go
index b20d20c..f809b64 100644
--- a/golden/cmd/diffcalculator/diffcalculator.go
+++ b/golden/cmd/diffcalculator/diffcalculator.go
@@ -14,16 +14,16 @@
 	"sync/atomic"
 	"time"
 
-	"github.com/bradfitz/gomemcache/memcache"
-
 	"cloud.google.com/go/pubsub"
 	gstorage "cloud.google.com/go/storage"
+	"github.com/bradfitz/gomemcache/memcache"
 	"github.com/jackc/pgx/v4/pgxpool"
 	"go.opencensus.io/trace"
 
 	"go.skia.org/infra/go/common"
 	"go.skia.org/infra/go/httputils"
 	"go.skia.org/infra/go/metrics2"
+	"go.skia.org/infra/go/reconnectingmemcached"
 	"go.skia.org/infra/go/skerr"
 	"go.skia.org/infra/go/sklog"
 	"go.skia.org/infra/go/util"
@@ -189,10 +189,7 @@
 const failureReconnectLimit = 100
 
 type memcachedDiffCache struct {
-	serverAddress string
-	client        *memcache.Client
-	clientMutex   sync.RWMutex
-	numFailures   int
+	client reconnectingmemcached.Client
 
 	// namespace is the string to add to each key to avoid conflicts with more than one
 	// gold instance.
@@ -200,11 +197,16 @@
 }
 
 func newMemcacheDiffCache(server, namespace string) (*memcachedDiffCache, error) {
-	m := &memcachedDiffCache{serverAddress: server, namespace: namespace}
-	c := memcache.New(server)
-	c.Timeout = time.Second
-	m.client = c
-	return m, c.Ping()
+	m := &memcachedDiffCache{
+		client: reconnectingmemcached.NewClient(reconnectingmemcached.Options{
+			Servers:                      []string{server},
+			Timeout:                      time.Second,
+			MaxIdleConnections:           4,
+			AllowedFailuresBeforeHealing: failureReconnectLimit,
+		}),
+		namespace: namespace,
+	}
+	return m, m.client.Ping()
 }
 
 func key(namespace string, left, right types.Digest) string {
@@ -222,16 +224,8 @@
 		}
 		keys = append(keys, key(m.namespace, left, right))
 	}
-	m.clientMutex.RLock()
-	if m.client == nil { // memcached client unavailable
-		m.clientMutex.RUnlock()
-		return rightDigests
-	}
-	alreadyCalculated, err := m.client.GetMulti(keys)
-	m.clientMutex.RUnlock()
-	if err != nil {
-		sklog.Warningf("Could not read from memcached: %s", err)
-		m.maybeReload()
+	alreadyCalculated, ok := m.client.GetMulti(keys)
+	if !ok {
 		return rightDigests // on an error, assume all need to be queried from DB.
 	}
 	if len(alreadyCalculated) == len(keys) {
@@ -253,51 +247,10 @@
 
 // StoreDiffComputed implements the DiffCache interface.
 func (m *memcachedDiffCache) StoreDiffComputed(_ context.Context, left, right types.Digest) {
-	m.clientMutex.RLock()
-	if m.client == nil { // memcached client unavailable
-		m.clientMutex.RUnlock()
-		return
-	}
-	err := m.client.Set(&memcache.Item{
+	m.client.Set(&memcache.Item{
 		Key:   key(m.namespace, left, right),
 		Value: []byte{0x01},
 	})
-	m.clientMutex.RUnlock()
-	if err != nil {
-		sklog.Warningf("Could not set in memcached: %s", err)
-		m.maybeReload()
-	}
-}
-
-func (m *memcachedDiffCache) maybeReload() {
-	m.clientMutex.Lock()
-	m.numFailures++
-	// We add the m.client == nil check to make it so there's only one goroutine in charge of
-	// reconnecting
-	if m.numFailures < failureReconnectLimit || m.client == nil {
-		m.clientMutex.Unlock()
-		return
-	}
-	m.client = nil
-	m.clientMutex.Unlock()
-	go func() { // spin up a background goroutine to heal the connection.
-		for {
-			// wait 10 seconds + some jitter to re-connect
-			time.Sleep(10*time.Second + time.Duration(float32(10*time.Second)*rand.Float32()))
-			c := memcache.New(m.serverAddress)
-			c.Timeout = time.Second
-			if err := c.Ping(); err != nil {
-				sklog.Warningf("Cannot reconnect to memcached: %s", err)
-				continue // go back to sleep, try again later
-			}
-			m.clientMutex.Lock()
-			m.client = c
-			m.numFailures = 0
-			sklog.Infof("Reconnected to memcached")
-			m.clientMutex.Unlock()
-			return
-		}
-	}()
 }
 
 // noopDiffCache pretends the memcached instance always does not have what we are looking up.