Skip to content

Commit

Permalink
reactor: modify storage inteaface and add contxt.Context
Browse files Browse the repository at this point in the history
  • Loading branch information
li-jin-gou committed Jan 16, 2023
1 parent 5406560 commit 7488c13
Show file tree
Hide file tree
Showing 15 changed files with 137 additions and 127 deletions.
10 changes: 5 additions & 5 deletions app.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,23 @@ type Map map[string]interface{}
type Storage interface {
// Get gets the value for the given key.
// `nil, nil` is returned when the key does not exist
Get(key string) ([]byte, error)
Get(ctx context.Context, key string) ([]byte, error)

// Set stores the given value for the given key along
// with an expiration value, 0 means no expiration.
// Empty key or value will be ignored without an error.
Set(key string, val []byte, exp time.Duration) error
Set(ctx context.Context, key string, val []byte, exp time.Duration) error

// Delete deletes the value for the given key.
// It returns no error if the storage does not contain the key,
Delete(key string) error
Delete(ctx context.Context, key string) error

// Reset resets the storage and delete all keys.
Reset() error
Reset(ctx context.Context) error

// Close closes the storage and will stop any running garbage
// collectors and open connections.
Close() error
Close(ctx context.Context) error
}

// ErrorHandler defines a function that will process all errors
Expand Down
9 changes: 5 additions & 4 deletions internal/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package memory

import (
"context"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -31,7 +32,7 @@ func New() *Storage {
}

// Get value by key
func (s *Storage) Get(key string) interface{} {
func (s *Storage) Get(_ context.Context, key string) interface{} {
s.RLock()
v, ok := s.data[key]
s.RUnlock()
Expand All @@ -42,7 +43,7 @@ func (s *Storage) Get(key string) interface{} {
}

// Set key with value
func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
func (s *Storage) Set(_ context.Context, key string, val interface{}, ttl time.Duration) {
var exp uint32
if ttl > 0 {
exp = uint32(ttl.Seconds()) + atomic.LoadUint32(&utils.Timestamp)
Expand All @@ -54,14 +55,14 @@ func (s *Storage) Set(key string, val interface{}, ttl time.Duration) {
}

// Delete key by key
func (s *Storage) Delete(key string) {
func (s *Storage) Delete(_ context.Context, key string) {
s.Lock()
delete(s.data, key)
s.Unlock()
}

// Reset all keys
func (s *Storage) Reset() {
func (s *Storage) Reset(_ context.Context) {
nd := make(map[string]item)
s.Lock()
s.data = nd
Expand Down
13 changes: 7 additions & 6 deletions internal/storage/memory/memory.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package memory

import (
"context"
"sync"
"sync/atomic"
"time"
Expand Down Expand Up @@ -44,7 +45,7 @@ func New(config ...Config) *Storage {
}

// Get value by key
func (s *Storage) Get(key string) ([]byte, error) {
func (s *Storage) Get(_ context.Context, key string) ([]byte, error) {
if len(key) <= 0 {
return nil, nil
}
Expand All @@ -59,7 +60,7 @@ func (s *Storage) Get(key string) ([]byte, error) {
}

// Set key with value
func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
func (s *Storage) Set(_ context.Context, key string, val []byte, exp time.Duration) error {
// Ain't Nobody Got Time For That
if len(key) <= 0 || len(val) <= 0 {
return nil
Expand All @@ -78,7 +79,7 @@ func (s *Storage) Set(key string, val []byte, exp time.Duration) error {
}

// Delete key by key
func (s *Storage) Delete(key string) error {
func (s *Storage) Delete(_ context.Context, key string) error {
// Ain't Nobody Got Time For That
if len(key) <= 0 {
return nil
Expand All @@ -90,7 +91,7 @@ func (s *Storage) Delete(key string) error {
}

// Reset all keys
func (s *Storage) Reset() error {
func (s *Storage) Reset(_ context.Context) error {
ndb := make(map[string]entry)
s.mux.Lock()
s.db = ndb
Expand All @@ -99,7 +100,7 @@ func (s *Storage) Reset() error {
}

// Close the memory storage
func (s *Storage) Close() error {
func (s *Storage) Close(_ context.Context) error {
s.done <- struct{}{}
return nil
}
Expand Down Expand Up @@ -137,7 +138,7 @@ func (s *Storage) gc() {
}
}

// Return database client
// Conn Return database client
func (s *Storage) Conn() map[string]entry {
return s.db
}
41 changes: 21 additions & 20 deletions internal/storage/memory/memory_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package memory

import (
"context"
"testing"
"time"

Expand All @@ -16,7 +17,7 @@ func Test_Storage_Memory_Set(t *testing.T) {
val = []byte("doe")
)

err := testStore.Set(key, val, 0)
err := testStore.Set(context.TODO(), key, val, 0)
utils.AssertEqual(t, nil, err)
}

Expand All @@ -27,10 +28,10 @@ func Test_Storage_Memory_Set_Override(t *testing.T) {
val = []byte("doe")
)

err := testStore.Set(key, val, 0)
err := testStore.Set(context.TODO(), key, val, 0)
utils.AssertEqual(t, nil, err)

err = testStore.Set(key, val, 0)
err = testStore.Set(context.TODO(), key, val, 0)
utils.AssertEqual(t, nil, err)
}

Expand All @@ -41,10 +42,10 @@ func Test_Storage_Memory_Get(t *testing.T) {
val = []byte("doe")
)

err := testStore.Set(key, val, 0)
err := testStore.Set(context.TODO(), key, val, 0)
utils.AssertEqual(t, nil, err)

result, err := testStore.Get(key)
result, err := testStore.Get(context.TODO(), key)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, val, result)
}
Expand All @@ -57,7 +58,7 @@ func Test_Storage_Memory_Set_Expiration(t *testing.T) {
exp = 1 * time.Second
)

err := testStore.Set(key, val, exp)
err := testStore.Set(context.TODO(), key, val, exp)
utils.AssertEqual(t, nil, err)

time.Sleep(1100 * time.Millisecond)
Expand All @@ -68,15 +69,15 @@ func Test_Storage_Memory_Get_Expired(t *testing.T) {
key = "john"
)

result, err := testStore.Get(key)
result, err := testStore.Get(context.TODO(), key)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(result) == 0)
}

func Test_Storage_Memory_Get_NotExist(t *testing.T) {
t.Parallel()

result, err := testStore.Get("notexist")
result, err := testStore.Get(context.TODO(), "notexist")
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(result) == 0)
}
Expand All @@ -88,13 +89,13 @@ func Test_Storage_Memory_Delete(t *testing.T) {
val = []byte("doe")
)

err := testStore.Set(key, val, 0)
err := testStore.Set(context.TODO(), key, val, 0)
utils.AssertEqual(t, nil, err)

err = testStore.Delete(key)
err = testStore.Delete(context.TODO(), key)
utils.AssertEqual(t, nil, err)

result, err := testStore.Get(key)
result, err := testStore.Get(context.TODO(), key)
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(result) == 0)
}
Expand All @@ -105,27 +106,27 @@ func Test_Storage_Memory_Reset(t *testing.T) {
val = []byte("doe")
)

err := testStore.Set("john1", val, 0)
err := testStore.Set(context.TODO(), "john1", val, 0)
utils.AssertEqual(t, nil, err)

err = testStore.Set("john2", val, 0)
err = testStore.Set(context.TODO(), "john2", val, 0)
utils.AssertEqual(t, nil, err)

err = testStore.Reset()
err = testStore.Reset(context.TODO())
utils.AssertEqual(t, nil, err)

result, err := testStore.Get("john1")
result, err := testStore.Get(context.TODO(), "john1")
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(result) == 0)

result, err = testStore.Get("john2")
result, err = testStore.Get(context.TODO(), "john2")
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, true, len(result) == 0)
}

func Test_Storage_Memory_Close(t *testing.T) {
t.Parallel()
utils.AssertEqual(t, nil, testStore.Close())
utils.AssertEqual(t, nil, testStore.Close(context.TODO()))
}

func Test_Storage_Memory_Conn(t *testing.T) {
Expand All @@ -149,13 +150,13 @@ func Benchmark_Storage_Memory(b *testing.B) {
b.ResetTimer()
for n := 0; n < b.N; n++ {
for _, key := range keys {
d.Set(key, value, ttl)
d.Set(context.TODO(), key, value, ttl)
}
for _, key := range keys {
_, _ = d.Get(key)
_, _ = d.Get(context.TODO(), key)
}
for _, key := range keys {
d.Delete(key)
d.Delete(context.TODO(), key)
}
}
})
Expand Down
21 changes: 11 additions & 10 deletions middleware/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package cache

import (
"context"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -80,11 +81,11 @@ func New(config ...Config) fiber.Handler {
}()

// Delete key from both manager and storage
deleteKey := func(dkey string) {
manager.delete(dkey)
deleteKey := func(ctx context.Context, dkey string) {
manager.delete(ctx, dkey)
// External storage saves body data with different key
if cfg.Storage != nil {
manager.delete(dkey + "_body")
manager.delete(ctx, dkey+"_body")
}
}

Expand Down Expand Up @@ -113,7 +114,7 @@ func New(config ...Config) fiber.Handler {
key := cfg.KeyGenerator(c) + "_" + c.Method()

// Get entry from pool
e := manager.get(key)
e := manager.get(c.Context(), key)

// Lock entry
mux.Lock()
Expand All @@ -123,7 +124,7 @@ func New(config ...Config) fiber.Handler {

// Check if entry is expired
if e.exp != 0 && ts >= e.exp {
deleteKey(key)
deleteKey(c.Context(), key)
if cfg.MaxBytes > 0 {
_, size := heap.remove(e.heapidx)
storedBytes -= size
Expand All @@ -132,7 +133,7 @@ func New(config ...Config) fiber.Handler {
// Separate body value to avoid msgp serialization
// We can store raw bytes with Storage 👍
if cfg.Storage != nil {
e.body = manager.getRaw(key + "_body")
e.body = manager.getRaw(c.Context(), key+"_body")
}
// Set response headers from cache
c.Response().SetBodyRaw(e.body)
Expand Down Expand Up @@ -189,7 +190,7 @@ func New(config ...Config) fiber.Handler {
if cfg.MaxBytes > 0 {
for storedBytes+bodySize > cfg.MaxBytes {
key, size := heap.removeFirst()
deleteKey(key)
deleteKey(c.Context(), key)
storedBytes -= size
}
}
Expand Down Expand Up @@ -231,14 +232,14 @@ func New(config ...Config) fiber.Handler {

// For external Storage we store raw body separated
if cfg.Storage != nil {
manager.setRaw(key+"_body", e.body, expiration)
manager.setRaw(c.Context(), key+"_body", e.body, expiration)
// avoid body msgp encoding
e.body = nil
manager.set(key, e, expiration)
manager.set(c.Context(), key, e, expiration)
manager.release(e)
} else {
// Store entry in memory
manager.set(key, e, expiration)
manager.set(c.Context(), key, e, expiration)
}

c.Set(cfg.CacheHeader, cacheMiss)
Expand Down
Loading

0 comments on commit 7488c13

Please sign in to comment.