diff --git a/ezcache.go b/ezcache.go index 24a9ea9..51b52ed 100644 --- a/ezcache.go +++ b/ezcache.go @@ -1,15 +1,17 @@ package ezcache import ( - "errors" + "sync" "time" ) type ezc[K comparable, V any] struct { - fetch Fetcher[K, V] - exp time.Duration - cache map[K]V - setTime map[K]time.Time + fetch Fetcher[K, V] + exp time.Duration + cache map[K]V + setTime map[K]time.Time + lock sync.RWMutex + timeLock sync.RWMutex } func New[K comparable, V any](fetcher Fetcher[K, V], exp time.Duration) (Cache[K, V], error) { @@ -28,11 +30,33 @@ func New[K comparable, V any](fetcher Fetcher[K, V], exp time.Duration) (Cache[K return c, nil } -var errUnimpl = errors.New("unimplemented") - func (c *ezc[K, V]) Get(key K) (V, error) { - var val V - return val, errUnimpl + c.timeLock.RLock() + setTime, ok := c.setTime[key] + c.timeLock.RUnlock() + if ok && time.Since(setTime) <= c.exp { + c.lock.RLock() + val, ok := c.cache[key] + c.lock.RUnlock() + if ok { + return val, nil + } + } + + val, err := c.fetch(key) + if err != nil { + return val, err + } + + c.lock.Lock() + defer c.lock.Unlock() + c.cache[key] = val + + c.timeLock.Lock() + defer c.timeLock.Unlock() + c.setTime[key] = time.Now() + + return val, nil } func (c *ezc[K, V]) SetFetcher(f Fetcher[K, V]) error { diff --git a/ezcache_test.go b/ezcache_test.go index 127bc65..b713a75 100644 --- a/ezcache_test.go +++ b/ezcache_test.go @@ -42,3 +42,43 @@ func TestNewBadExpiry(tt *testing.T) { }) } } + +func TestGetHappy(t *testing.T) { + var hit bool + cache, _ := ezcache.New(func(key uint8) (string, error) { hit = true; return fetcher(key) }, 5*time.Second) + + val, err := cache.Get(4) + assert.NoError(t, err) + assert.Equal(t, "4", val) + assert.True(t, hit) + + hit = false + val, err = cache.Get(4) + assert.NoError(t, err) + assert.Equal(t, "4", val) + assert.False(t, hit) +} + +func TestGetExpire(t *testing.T) { + var hit bool + cache, _ := ezcache.New(func(key uint8) (string, error) { hit = true; return fetcher(key) }, 1) + + val, err := cache.Get(4) + assert.NoError(t, err) + assert.Equal(t, "4", val) + assert.True(t, hit) + + hit = false + time.Sleep(2) + val, err = cache.Get(4) + assert.NoError(t, err) + assert.Equal(t, "4", val) + assert.True(t, hit) +} + +func TestGetError(t *testing.T) { + cache, _ := ezcache.New(func(k uint8) (byte, error) { return 0, fmt.Errorf("Nope for %d", k) }, 1) + + _, err := cache.Get(4) + assert.ErrorContains(t, err, "Nope for 4") +}