diff --git a/go.mod b/go.mod index a6bceb4..c603c5b 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/caarlos0/env/v11 v11.2.2 github.com/gleich/lumber/v3 v3.0.2 github.com/go-chi/chi/v5 v5.1.0 + github.com/gorilla/websocket v1.5.3 github.com/joho/godotenv v1.5.1 github.com/minio/minio-go/v7 v7.0.81 github.com/prometheus/client_golang v1.20.5 diff --git a/go.sum b/go.sum index 81b3346..f306e47 100644 --- a/go.sum +++ b/go.sum @@ -28,6 +28,8 @@ github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg= +github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= diff --git a/internal/cache/cache.go b/internal/cache/cache.go index fd96669..dfdd9cf 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -14,18 +14,22 @@ import ( "github.com/gleich/lcp-v2/internal/metrics" "github.com/gleich/lcp-v2/internal/secrets" "github.com/gleich/lumber/v3" + "github.com/gorilla/websocket" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) type Cache[T any] struct { - name string - mutex sync.RWMutex - data T - updated time.Time - updateCounter prometheus.Counter - requestCounter prometheus.Counter - filePath string + name string + dataMutex sync.RWMutex + data T + updated time.Time + updateCounter prometheus.Counter + requestCounter prometheus.Counter + filePath string + wsConnPool map[*websocket.Conn]bool + wsConnPoolMutex sync.Mutex + wsUpgrader websocket.Upgrader } func NewCache[T any](name string, data T) *Cache[T] { @@ -39,7 +43,9 @@ func NewCache[T any](name string, data T) *Cache[T] { Name: fmt.Sprintf("cache_%s_requests", name), Help: fmt.Sprintf(`The total number of times the cache "%s" has been requested`, name), }), - filePath: filepath.Join(secrets.SECRETS.CacheFolder, fmt.Sprintf("%s.json", name)), + filePath: filepath.Join(secrets.SECRETS.CacheFolder, fmt.Sprintf("%s.json", name)), + wsConnPool: make(map[*websocket.Conn]bool), + wsUpgrader: websocket.Upgrader{}, } cache.loadFromFile() cache.Update(data) @@ -58,10 +64,10 @@ func (c *Cache[T]) ServeHTTP() http.HandlerFunc { w.WriteHeader(http.StatusUnauthorized) return } - c.mutex.RLock() + c.dataMutex.RLock() w.Header().Set("Content-Type", "application/json") err := json.NewEncoder(w).Encode(cacheData[T]{Data: c.data, Updated: c.updated}) - c.mutex.RUnlock() + c.dataMutex.RUnlock() c.requestCounter.Inc() if err != nil { lumber.Error(err, "failed to write data") @@ -74,7 +80,7 @@ func (c *Cache[T]) ServeHTTP() http.HandlerFunc { // Update the given cache func (c *Cache[T]) Update(data T) { var updated bool - c.mutex.Lock() + c.dataMutex.Lock() old, err := json.Marshal(c.data) if err != nil { lumber.Error(err, "failed to json marshal old data") @@ -90,12 +96,18 @@ func (c *Cache[T]) Update(data T) { c.updated = time.Now() updated = true } - c.mutex.Unlock() + c.dataMutex.Unlock() + if updated { c.updateCounter.Inc() metrics.CacheUpdates.Inc() c.persistToFile() - lumber.Done(strings.ToUpper(c.name), "cache updated") + connectionsUpdated := c.broadcastUpdate() + lumber.Done( + strings.ToUpper(c.name), + "cache updated;", + "broadcasted to", connectionsUpdated, "websocket connections", + ) } } diff --git a/internal/cache/websockets.go b/internal/cache/websockets.go new file mode 100644 index 0000000..4714baa --- /dev/null +++ b/internal/cache/websockets.go @@ -0,0 +1,62 @@ +package cache + +import ( + "net/http" + + "github.com/gleich/lumber/v3" + "github.com/gorilla/websocket" +) + +// Handle websocket connections +func (c *Cache[T]) ServeWS() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + conn, err := c.wsUpgrader.Upgrade(w, r, nil) + if err != nil { + lumber.Error(err, "failed to upgrade connection to websocket") + return + } + c.wsConnPoolMutex.Lock() + c.wsConnPool[conn] = true + c.wsConnPoolMutex.Unlock() + + // sending initial data + c.dataMutex.RLock() + err = conn.WriteJSON(c.data) + c.dataMutex.RUnlock() + if err != nil { + lumber.Error(err, "failed to write initial cache data for", c.name) + c.removeConnection(conn) + return + } + + // spawning goroutine to handle connection + go func() { + defer c.removeConnection(conn) + }() + } +} + +func (c *Cache[T]) broadcastUpdate() int { + c.dataMutex.RLock() + d := c.data + c.dataMutex.RUnlock() + + updatedConnections := 0 + for conn := range c.wsConnPool { + err := conn.WriteJSON(d) + if err != nil { + lumber.Error(err, "failed to broadcast update to client") + c.removeConnection(conn) + } else { + updatedConnections++ + } + } + return updatedConnections +} + +func (c *Cache[T]) removeConnection(conn *websocket.Conn) { + c.wsConnPoolMutex.Lock() + delete(c.wsConnPool, conn) + c.wsConnPoolMutex.Unlock() + conn.Close() +}