diff --git a/hashmail_server.go b/hashmail_server.go index 5f14ff0..f7e7b74 100644 --- a/hashmail_server.go +++ b/hashmail_server.go @@ -8,6 +8,7 @@ import ( "sync" "time" + "github.com/btcsuite/btclog/v2" "github.com/lightninglabs/lightning-node-connect/hashmailrpc" "github.com/lightningnetwork/lnd/tlv" "github.com/prometheus/client_golang/prometheus" @@ -104,8 +105,8 @@ func (r *readStream) ReadNextMsg(ctx context.Context) ([]byte, error) { // ReturnStream gives up the read stream by passing it back up through the // payment stream. -func (r *readStream) ReturnStream() { - log.Debugf("Returning read stream %x", r.parentStream.id[:]) +func (r *readStream) ReturnStream(ctx context.Context) { + log.DebugS(ctx, "Returning read stream") r.parentStream.ReturnReadStream(r) } @@ -193,7 +194,7 @@ type stream struct { } // newStream creates a new stream independent of any given stream ID. -func newStream(id streamID, limiter *rate.Limiter, +func newStream(ctx context.Context, id streamID, limiter *rate.Limiter, equivAuth func(auth *hashmailrpc.CipherBoxAuth) error, onStale func() error, staleTimeout time.Duration) *stream { @@ -210,7 +211,7 @@ func newStream(id streamID, limiter *rate.Limiter, id: id, equivAuth: equivAuth, limiter: limiter, - status: newStreamStatus(onStale, staleTimeout), + status: newStreamStatus(ctx, onStale, staleTimeout), readBytesChan: make(chan []byte), readErrChan: make(chan error, 1), quit: make(chan struct{}), @@ -305,8 +306,8 @@ func (s *stream) ReturnWriteStream(w *writeStream) { // RequestReadStream attempts to request the read stream from the main backing // stream. If we're unable to obtain it before the timeout, then an error is // returned. -func (s *stream) RequestReadStream() (*readStream, error) { - log.Tracef("HashMailStream(%x): requesting read stream", s.id[:]) +func (s *stream) RequestReadStream(ctx context.Context) (*readStream, error) { + log.TraceS(ctx, "Requested read stream") select { case r := <-s.readStreamChan: @@ -320,8 +321,8 @@ func (s *stream) RequestReadStream() (*readStream, error) { // RequestWriteStream attempts to request the read stream from the main backing // stream. If we're unable to obtain it before the timeout, then an error is // returned. -func (s *stream) RequestWriteStream() (*writeStream, error) { - log.Tracef("HashMailStream(%x): requesting write stream", s.id[:]) +func (s *stream) RequestWriteStream(ctx context.Context) (*writeStream, error) { + log.TraceS(ctx, "Requesting write stream") select { case w := <-s.writeStreamChan: @@ -389,8 +390,10 @@ func (h *hashMailServer) Stop() { } // tearDownStaleStream can be used to tear down a stale mailbox stream. -func (h *hashMailServer) tearDownStaleStream(id streamID) error { - log.Debugf("Tearing down stale HashMail stream: id=%x", id) +func (h *hashMailServer) tearDownStaleStream(ctx context.Context, + id streamID) error { + + log.DebugS(ctx, "Tearing down stale HashMail stream") h.Lock() defer h.Unlock() @@ -428,7 +431,7 @@ func (h *hashMailServer) ValidateStreamAuth(ctx context.Context, } // InitStream attempts to initialize a new stream given a valid descriptor. -func (h *hashMailServer) InitStream( +func (h *hashMailServer) InitStream(ctx context.Context, init *hashmailrpc.CipherBoxAuth) (*hashmailrpc.CipherInitResp, error) { h.Lock() @@ -436,7 +439,7 @@ func (h *hashMailServer) InitStream( streamID := newStreamID(init.Desc.StreamId) - log.Debugf("Creating new HashMail Stream: %x", streamID) + log.DebugS(ctx, "Creating new HashMail Stream") // The stream is already active, and we only allow a single session for // a given stream to exist. @@ -452,10 +455,11 @@ func (h *hashMailServer) InitStream( rate.Every(h.cfg.msgRate), h.cfg.msgBurstAllowance, ) freshStream := newStream( - streamID, limiter, func(auth *hashmailrpc.CipherBoxAuth) error { + ctx, streamID, limiter, + func(auth *hashmailrpc.CipherBoxAuth) error { return nil }, func() error { - return h.tearDownStaleStream(streamID) + return h.tearDownStaleStream(ctx, streamID) }, h.cfg.staleTimeout, ) @@ -470,7 +474,9 @@ func (h *hashMailServer) InitStream( // LookUpReadStream attempts to loop up a new stream. If the stream is found, then // the stream is marked as being active. Otherwise, an error is returned. -func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error) { +func (h *hashMailServer) LookUpReadStream(ctx context.Context, + streamID []byte) (*readStream, error) { + h.RLock() defer h.RUnlock() @@ -479,12 +485,13 @@ func (h *hashMailServer) LookUpReadStream(streamID []byte) (*readStream, error) return nil, fmt.Errorf("stream not found") } - return stream.RequestReadStream() + return stream.RequestReadStream(ctx) } // LookUpWriteStream attempts to loop up a new stream. If the stream is found, // then the stream is marked as being active. Otherwise, an error is returned. -func (h *hashMailServer) LookUpWriteStream(streamID []byte) (*writeStream, error) { +func (h *hashMailServer) LookUpWriteStream(ctx context.Context, + streamID []byte) (*writeStream, error) { h.RLock() defer h.RUnlock() @@ -494,7 +501,7 @@ func (h *hashMailServer) LookUpWriteStream(streamID []byte) (*writeStream, error return nil, fmt.Errorf("stream not found") } - return stream.RequestWriteStream() + return stream.RequestWriteStream(ctx) } // TearDownStream attempts to tear down a stream which renders both sides of @@ -523,8 +530,7 @@ func (h *hashMailServer) TearDownStream(ctx context.Context, streamID []byte, return err } - log.Debugf("Tearing down HashMail stream: id=%x, auth=%v", - auth.Desc.StreamId, auth.Auth) + log.DebugS(ctx, "Tearing down HashMail stream", "auth", auth.Auth) // At this point we know the auth was valid, so we'll tear down the // stream. @@ -568,16 +574,17 @@ func (h *hashMailServer) NewCipherBox(ctx context.Context, return nil, err } - log.Debugf("New HashMail stream init: id=%x, auth=%v", - init.Desc.StreamId, init.Auth) + ctxl := btclog.WithCtx(ctx, btclog.Hex("stream_id", init.Desc.StreamId)) - if err := h.ValidateStreamAuth(ctx, init); err != nil { - log.Debugf("Stream creation validation failed (id=%x): %v", - init.Desc.StreamId, err) + log.DebugS(ctxl, "New HashMail stream init", "auth", init.Auth) + + if err := h.ValidateStreamAuth(ctxl, init); err != nil { + log.DebugS(ctxl, "Stream creation validation failed", + "err", err) return nil, err } - resp, err := h.InitStream(init) + resp, err := h.InitStream(ctxl, init) if err != nil { return nil, err } @@ -597,8 +604,9 @@ func (h *hashMailServer) DelCipherBox(ctx context.Context, return nil, err } - log.Debugf("New HashMail stream deletion: id=%x, auth=%v", - auth.Desc.StreamId, auth.Auth) + ctxl := btclog.WithCtx(ctx, btclog.Hex("stream_id", auth.Desc.StreamId)) + + log.DebugS(ctxl, "New HashMail stream deletion", "auth", auth.Auth) if err := h.TearDownStream(ctx, auth.Desc.StreamId, auth); err != nil { return nil, err @@ -610,7 +618,7 @@ func (h *hashMailServer) DelCipherBox(ctx context.Context, // SendStream implements the client streaming call to utilize the write end of // a stream to send a message to the read end. func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamServer) error { - log.Debugf("New HashMail write stream pending...") + log.Debug("New HashMail write stream pending...") // We'll need to receive the first message in order to determine if // this stream exists or not @@ -621,6 +629,9 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe return err } + ctx := btclog.WithCtx(readStream.Context(), + btclog.Hex("stream_id", cipherBox.Desc.StreamId)) + switch { case cipherBox.Desc == nil: return fmt.Errorf("cipher box descriptor required") @@ -629,12 +640,11 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe return fmt.Errorf("stream_id required") } - log.Debugf("New HashMail write stream: id=%x", - cipherBox.Desc.StreamId) + log.DebugS(ctx, "New HashMail write stream") // Now that we have the first message, we can attempt to look up the // given stream. - writeStream, err := h.LookUpWriteStream(cipherBox.Desc.StreamId) + writeStream, err := h.LookUpWriteStream(ctx, cipherBox.Desc.StreamId) if err != nil { return err } @@ -643,13 +653,12 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe // write inactive if the client hangs up on their end. defer writeStream.ReturnStream() - log.Tracef("Sending msg_len=%v to stream_id=%x", len(cipherBox.Msg), - cipherBox.Desc.StreamId) + log.TraceS(ctx, "Sending message to stream", + "msg_len", len(cipherBox.Msg)) // We'll send the first message into the stream, then enter our loop // below to continue to read from the stream and send it to the read // end. - ctx := readStream.Context() if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil { return err } @@ -659,7 +668,7 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe // exit before shutting down. select { case <-ctx.Done(): - log.Debugf("SendStream: Context done, exiting") + log.DebugS(ctx, "SendStream: Context done, exiting") return nil case <-h.quit: return fmt.Errorf("server shutting down") @@ -669,13 +678,13 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe cipherBox, err := readStream.Recv() if err != nil { - log.Debugf("SendStream: Exiting write stream RPC "+ - "stream read: %v", err) + log.DebugS(ctx, "SendStream: Exiting write stream RPC "+ + "stream read", err) return err } - log.Tracef("Sending msg_len=%v to stream_id=%x", - len(cipherBox.Msg), cipherBox.Desc.StreamId) + log.TraceS(ctx, "Sending message to stream", + "msg_len", len(cipherBox.Msg)) if err := writeStream.WriteMsg(ctx, cipherBox.Msg); err != nil { return err @@ -689,25 +698,28 @@ func (h *hashMailServer) SendStream(readStream hashmailrpc.HashMail_SendStreamSe func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc, reader hashmailrpc.HashMail_RecvStreamServer) error { + ctx := btclog.WithCtx(reader.Context(), + btclog.Hex("stream_id", desc.StreamId)) + // First, we'll attempt to locate the stream. We allow any single // entity that knows of the full stream ID to access the read end. - readStream, err := h.LookUpReadStream(desc.StreamId) + readStream, err := h.LookUpReadStream(ctx, desc.StreamId) if err != nil { return err } - log.Debugf("New HashMail read stream: id=%x", desc.StreamId) + log.DebugS(ctx, "New HashMail read stream") // If the reader hangs up, then we'll mark the stream as inactive so // another can take its place. - defer readStream.ReturnStream() + defer readStream.ReturnStream(ctx) for { // Check to see if the stream has been closed or if we need to - // exit before shutting down. + // exit before shutting d[own. select { case <-reader.Context().Done(): - log.Debugf("Read stream context done.") + log.DebugS(ctx, "Read stream context done.") return nil case <-h.quit: return fmt.Errorf("server shutting down") @@ -717,12 +729,11 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc, nextMsg, err := readStream.ReadNextMsg(reader.Context()) if err != nil { - log.Debugf("Got error an read stream read: %v", err) + log.ErrorS(ctx, "Got error on read stream read", err) return err } - log.Tracef("Read %v bytes for HashMail stream_id=%x", - len(nextMsg), desc.StreamId) + log.TraceS(ctx, "Read bytes", "msg_len", len(nextMsg)) // In order not to duplicate metric data, we only record this // read if its streamID is odd. We use the base stream ID as the @@ -742,8 +753,8 @@ func (h *hashMailServer) RecvStream(desc *hashmailrpc.CipherBoxDesc, Msg: nextMsg, }) if err != nil { - log.Debugf("Got error when sending on read stream: %v", - err) + log.DebugS(ctx, "Got error when sending on read stream", + "err", err) return err } } @@ -767,7 +778,7 @@ type streamStatus struct { } // newStreamStatus constructs a new streamStatus instance. -func newStreamStatus(onStale func() error, +func newStreamStatus(ctx context.Context, onStale func() error, staleTimeout time.Duration) *streamStatus { if staleTimeout < 0 { @@ -778,7 +789,7 @@ func newStreamStatus(onStale func() error, staleTimer := time.AfterFunc(staleTimeout, func() { if err := onStale(); err != nil { - log.Errorf("error in onStale callback: %v", err) + log.ErrorS(ctx, "Error from onStale callback", err) } })