// Source blob: 1a27f96be28a10a29e3c90fd9550e7c2eb8232b4

package atomic_miss_cache
import (
"context"
"errors"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
// backingCache is a minimal map-based store used by the tests in this file as
// the backing cache handed to New. Values are held as *string, keyed by the
// cache key.
type backingCache struct {
	// cache maps each key to the stored string pointer.
	cache map[string]*string
}
// Get looks up key in the underlying map. It returns the stored *string (as a
// Value) when the key is present, and (nil, ErrNoSuchEntry) otherwise. The
// context is unused.
func (c *backingCache) Get(ctx context.Context, key string) (Value, error) {
	if stored, found := c.cache[key]; found {
		return stored, nil
	}
	return nil, ErrNoSuchEntry
}
// Set stores val under key in the underlying map and always returns nil. val
// must hold a *string; any other dynamic type makes the type assertion panic.
// The context is unused.
func (c *backingCache) Set(ctx context.Context, key string, val Value) error {
	// Unchecked assertion: this test helper only ever stores *string values.
	str := val.(*string)
	c.cache[key] = str
	return nil
}
// Delete removes key from the underlying map. Removing an absent key is a
// no-op (built-in delete semantics), and the returned error is always nil.
// The context is unused.
func (c *backingCache) Delete(ctx context.Context, key string) error {
	delete(c.cache, key)
	return nil
}
func TestAtomicMissCache(t *testing.T) {
ctx := context.Background()
// Basic tests for the cache.
test := func(c *AtomicMissCache) {
k := "key"
got, err := c.Get(ctx, k)
require.Equal(t, ErrNoSuchEntry, err)
require.Nil(t, got)
val1Str := "hello"
val1 := &val1Str
require.NoError(t, c.Set(ctx, k, val1))
got, err = c.Get(ctx, k)
require.NoError(t, err)
require.True(t, val1 == got)
// We shouldn't set the new value in SetIfUnset.
val2Str := "world"
val2 := &val2Str
got, err = c.SetIfUnset(ctx, k, func(ctx context.Context) (Value, error) {
return val2, nil
})
require.NoError(t, err)
require.True(t, val1 == got)
got, err = c.Get(ctx, k)
require.NoError(t, err)
require.True(t, val1 == got)
require.False(t, val2 == got)
// Ensure that we do actually SetIfUnset.
got, err = c.SetIfUnset(ctx, "key2", func(ctx context.Context) (Value, error) {
return val2, nil
})
require.NoError(t, err)
require.True(t, val2 == got)
got, err = c.Get(ctx, "key2")
require.NoError(t, err)
require.True(t, val2 == got)
// Ensure that the error is passed back.
got, err = c.SetIfUnset(ctx, "key3", func(ctx context.Context) (Value, error) {
return val2, errors.New("fail")
})
require.EqualError(t, err, "fail")
require.Nil(t, got)
got, err = c.Get(ctx, "key3")
require.Equal(t, ErrNoSuchEntry, err)
require.Nil(t, got)
}
// The cache should work with or without a backing cache.
c1 := New(nil)
test(c1)
bc := &backingCache{
cache: map[string]*string{},
}
c2 := New(bc)
test(c2)
// Test that we read and write through to the backing cache.
val3Str := "hi"
val3 := &val3Str
key4 := "key4"
require.NoError(t, bc.Set(ctx, key4, val3))
got, err := c2.Get(ctx, key4)
require.NoError(t, err)
require.True(t, val3 == got)
val4Str := "blah"
val4 := &val4Str
require.NoError(t, c2.Set(ctx, key4, val4))
got, err = c2.Get(ctx, key4)
require.NoError(t, err)
require.True(t, val4 == got)
got, err = bc.Get(ctx, key4)
require.NoError(t, err)
require.True(t, val4 == got)
// The backing cache has a value but the cache itself doesn't. Ensure
// that we pull it from the backing cache and don't overwrite it.
key5 := "key5"
require.NoError(t, bc.Set(ctx, key5, val3))
got, err = c2.SetIfUnset(ctx, key5, func(ctx context.Context) (Value, error) {
return val4, nil
})
require.NoError(t, err)
require.True(t, val3 == got)
got, err = c2.Get(ctx, key5)
require.NoError(t, err)
require.True(t, val3 == got)
// Delete an entry.
require.NoError(t, c2.Delete(ctx, key5))
got, err = c2.Get(ctx, key5)
require.Equal(t, ErrNoSuchEntry, err)
require.Nil(t, got)
got, err = bc.Get(ctx, key5)
require.Equal(t, ErrNoSuchEntry, err)
require.Nil(t, got)
}
func TestAtomicMissCacheLocking(t *testing.T) {
ctx := context.Background()
wait := make(chan struct{})
c := New(nil)
k1 := "k1"
v1 := &struct{}{}
go func() {
// Wait for the below func to start, indicating that the cache
// entry should be locked.
<-wait
got, err := c.Get(ctx, k1)
require.NoError(t, err)
require.True(t, v1 == got)
}()
got, err := c.SetIfUnset(ctx, k1, func(ctx context.Context) (Value, error) {
// Signal the other goroutine to start.
wait <- struct{}{}
return v1, nil
})
require.NoError(t, err)
require.True(t, v1 == got)
}
// miniEntry is a small cache value used by TestAtomicMissCacheForEach.
type miniEntry struct {
	// val is the integer the entry was created with.
	val int
}
// cacheLen returns the number of times ForEach invokes its callback on c,
// i.e. the number of entries ForEach visits.
func cacheLen(c *AtomicMissCache) int {
	var count int
	visit := func(context.Context, string, Value) {
		count++
	}
	c.ForEach(context.Background(), visit)
	return count
}
// TestAtomicMissCacheForEach stores ten entries keyed "0".."9", checks that
// ForEach visits all of them, then runs Cleanup with a callback that returns
// true for keys >= 5 and checks that only the five keys below 5 are still
// visited afterwards.
func TestAtomicMissCacheForEach(t *testing.T) {
	const (
		numEntries = 10
		threshold  = 5
	)
	ctx := context.Background()
	c := New(nil)

	// index parses a cache key back into the integer it was created from.
	index := func(key string) int {
		i, err := strconv.Atoi(key)
		require.NoError(t, err)
		return i
	}

	for i := 0; i < numEntries; i++ {
		require.NoError(t, c.Set(ctx, strconv.Itoa(i), &miniEntry{val: i}))
	}

	// Every entry is visited before cleanup.
	seen := map[string]bool{}
	c.ForEach(ctx, func(_ context.Context, key string, _ Value) {
		seen[key] = true
	})
	require.Equal(t, numEntries, len(seen))

	// Run Cleanup with a callback that reports true for keys at or above
	// the threshold.
	atOrAboveThreshold := func(_ context.Context, key string, _ Value) bool {
		return index(key) >= threshold
	}
	require.NoError(t, c.Cleanup(ctx, atOrAboveThreshold))
	require.Equal(t, numEntries-threshold, cacheLen(c))

	// Only keys below the threshold are visited after cleanup.
	seen = map[string]bool{}
	c.ForEach(ctx, func(_ context.Context, key string, _ Value) {
		seen[key] = true
		require.False(t, index(key) >= threshold)
	})
	require.Equal(t, numEntries-threshold, len(seen))
}