diff --git a/fastcache.go b/fastcache.go index 8a3268e..460824a 100644 --- a/fastcache.go +++ b/fastcache.go @@ -68,7 +68,13 @@ type Options struct { // Cache based on uri+querystring. IncludeQueryString bool + // Compression options. Compression CompressionsOptions + + // QueryArgsTransformerHook is a hook that can be used to transform the query string + // before it is used to generate the cache key. This can be used to selectively + // include / exclude query parameters from the cache key. + QueryArgsTransformerHook func(*fasthttp.Args) } // Item represents the cache entry for a single endpoint with the actual cache @@ -126,14 +132,7 @@ func (f *FastCache) Cached(h fastglue.FastRequestHandler, o *Options, group stri o.Compression.MinLength = 500 } - var hash [16]byte - // If IncludeQueryString option is set then cache based on uri + md5(query_string) - if o.IncludeQueryString { - hash = md5.Sum(r.RequestCtx.URI().FullURI()) - } else { - hash = md5.Sum(r.RequestCtx.URI().Path()) - } - uri := hex.EncodeToString(hash[:]) + uri := f.makeURI(r, o) // Fetch etag + cached bytes from the store. blob, err := f.s.Get(namespace, group, uri) @@ -245,6 +244,44 @@ func (f *FastCache) DelGroup(namespace string, group ...string) error { return f.s.DelGroup(namespace, group...) } +func (f *FastCache) makeURI(r *fastglue.Request, o *Options) string { + var hash [16]byte + + // lexicographically sort the query string. + r.RequestCtx.QueryArgs().Sort(func(x, y []byte) int { + return bytes.Compare(x, y) + }) + + // If IncludeQueryString option is set then cache based on uri + md5(query_string) + if o.IncludeQueryString { + id := r.RequestCtx.URI().FullURI() + + // Check if we need to include only specific query params. + if o.QueryArgsTransformerHook != nil { + // Acquire a copy so as to not modify the request. + uriRaw := fasthttp.AcquireURI() + r.RequestCtx.URI().CopyTo(uriRaw) + + q := uriRaw.QueryArgs() + + // Call the hook to transform the query string. + o.QueryArgsTransformerHook(q) + + // Get the new URI. + id = uriRaw.FullURI() + + // Release the borrowed URI. + fasthttp.ReleaseURI(uriRaw) + } + + hash = md5.Sum(id) + } else { + hash = md5.Sum(r.RequestCtx.URI().Path()) + } + + return hex.EncodeToString(hash[:]) +} + // cache caches a response body. func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Options) error { // ETag?. @@ -258,14 +295,7 @@ func (f *FastCache) cache(r *fastglue.Request, namespace, group string, o *Optio } // Write cache to the store (etag, content type, response body). - var hash [16]byte - // If IncludeQueryString option is set then cache based on uri + md5(query_string) - if o.IncludeQueryString { - hash = md5.Sum(r.RequestCtx.URI().FullURI()) - } else { - hash = md5.Sum(r.RequestCtx.URI().Path()) - } - uri := hex.EncodeToString(hash[:]) + uri := f.makeURI(r, o) var blob []byte if !o.NoBlob { diff --git a/fastcache_test.go b/fastcache_test.go index d2d2248..53b489c 100644 --- a/fastcache_test.go +++ b/fastcache_test.go @@ -3,6 +3,7 @@ package fastcache_test import ( "bytes" "compress/gzip" + "fmt" "io" "log" "net/http" @@ -70,6 +71,49 @@ func init() { Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), } + includeQS = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: true, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + } + + includeQSNoEtag = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: false, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + } + + includeQSSpecific = &fastcache.Options{ + NamespaceKey: namespaceKey, + ETag: true, + TTL: time.Second * 5, + Logger: log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile), + IncludeQueryString: true, + QueryArgsTransformerHook: func(args *fasthttp.Args) { + // Copy the keys to delete, and delete them later. This is to + // avoid borking the VisitAll() iterator. + mp := map[string]struct{}{ + "foo": {}, + } + + delKeys := [][]byte{} + args.VisitAll(func(k, v []byte) { + if _, ok := mp[string(k)]; !ok { + delKeys = append(delKeys, k) + } + }) + + // Delete the keys. + for _, k := range delKeys { + args.DelBytes(k) + } + }, + } + fc = fastcache.New(cachestore.New("CACHE:", redis.NewClient(&redis.Options{ Addr: rd.Addr(), }))) @@ -102,6 +146,19 @@ func init() { return r.SendBytes(200, "text/plain", content) }, cfgDefault, group)) + srv.GET("/include-qs", fc.Cached(func(r *fastglue.Request) error { + return r.SendBytes(200, "text/plain", content) + }, includeQS, group)) + + srv.GET("/include-qs-no-etag", fc.Cached(func(r *fastglue.Request) error { + out := time.Now() + return r.SendBytes(200, "text/plain", []byte(fmt.Sprintf("%v", out))) + }, includeQSNoEtag, group)) + + srv.GET("/include-qs-specific", fc.Cached(func(r *fastglue.Request) error { + return r.SendBytes(200, "text/plain", content) + }, includeQSSpecific, group)) + // Start the server go func() { s := &fasthttp.Server{ @@ -177,6 +234,7 @@ func TestCache(t *testing.T) { if r.StatusCode != 200 { t.Fatalf("expected 200 but got %v", r.StatusCode) } + r, b = getReq(srvRoot+"/cached", r.Header.Get("Etag"), false, t) if r.StatusCode != 200 { t.Fatalf("expected 200 but got '%v'", r.StatusCode) @@ -213,6 +271,103 @@ func TestCache(t *testing.T) { } } +func TestQueryString(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs?foo=bar", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ = getReq(srvRoot+"/include-qs?foo=bar", r.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } +} + +func TestQueryStringLexicographical(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs?foo=bar&baz=qux", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ = getReq(srvRoot+"/include-qs?baz=qux&foo=bar", r.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } +} + +func TestQueryStringWithoutEtag(t *testing.T) { + // First request should be 200. + r, b := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r.StatusCode) + } + + // Second should be 200 but with same response. + r2, b2 := getReq(srvRoot+"/include-qs-no-etag?foo=bar", "", false, t) + if r2.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r2.StatusCode) + } + + if !bytes.Equal(b, b2) { + t.Fatalf("expected '%v' in body but got %v", b, b2) + } + + // Third should be 200 but with different response. + r3, b3 := getReq(srvRoot+"/include-qs-no-etag?foo=baz", "", false, t) + if r3.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r3.StatusCode) + } + + // time should be different + if bytes.Equal(b, b3) { + t.Fatalf("expected both to be different (should not be %v), but got %v", b, b3) + } +} + +func TestQueryStringSpecific(t *testing.T) { + // First request should be 200. + r1, b := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", "", false, t) + if r1.StatusCode != 200 { + t.Fatalf("expected 200 but got %v", r1.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } + + // Second should be 304. + r, _ := getReq(srvRoot+"/include-qs-specific?foo=bar&baz=qux", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } + + // Third should be 304 as foo=bar + r, _ = getReq(srvRoot+"/include-qs-specific?loo=mar&foo=bar&baz=qux&quux=quuz", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 304 { + t.Fatalf("expected 304 but got '%v'", r.StatusCode) + } + + // Fourth should be 200 as foo=rab + r, b = getReq(srvRoot+"/include-qs-specific?foo=rab&baz=qux&quux=quuz", r1.Header.Get("Etag"), false, t) + if r.StatusCode != 200 { + t.Fatalf("expected 200 but got '%v'", r.StatusCode) + } + if !bytes.Equal(b, content) { + t.Fatalf("expected 'ok' in body but got %v", b) + } +} + func TestNoCache(t *testing.T) { // All requests should return 200. for n := 0; n < 3; n++ { diff --git a/stores/goredis/redis.go b/stores/goredis/redis.go index 5da0da2..baff0cc 100644 --- a/stores/goredis/redis.go +++ b/stores/goredis/redis.go @@ -57,6 +57,7 @@ func (s *Store) Get(namespace, group, uri string) (fastcache.Item, error) { var ( out fastcache.Item ) + // Get content_type, etag, blob in that order. cmd := s.cn.HMGet(s.ctx, s.key(namespace, group), s.field(keyCtype, uri), s.field(keyEtag, uri), s.field(keyCompression, uri), s.field(keyBlob, uri)) if err := cmd.Err(); err != nil {