Skip to content

Commit

Permalink
fix(tunnel): add seq to prevent order mismatch
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 11, 2024
1 parent b1ef267 commit c0bd86d
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 26 deletions.
66 changes: 52 additions & 14 deletions upper/services/tunnel/tunnel.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tunnel

import (
"encoding/binary"
"encoding/hex"
"io"
"net"
Expand All @@ -14,7 +15,7 @@ import (
type Tunnel struct {
l *link.Link
in chan []byte
out chan []byte
out chan *head.Packet
outcache []byte
peerip net.IP
src uint16
Expand All @@ -26,7 +27,7 @@ func Create(me *link.Me, peer string) (s Tunnel, err error) {
s.l, err = me.Connect(peer)
if err == nil {
s.in = make(chan []byte, 4)
s.out = make(chan []byte, 4)
s.out = make(chan *head.Packet, 4)
s.peerip = net.ParseIP(peer)
} else {
logrus.Errorln("[tunnel] create err:", err)
Expand Down Expand Up @@ -62,7 +63,16 @@ func (s *Tunnel) Read(p []byte) (int, error) {
if s.outcache != nil {
d = s.outcache
} else {
d = <-s.out
pkt := <-s.out
if pkt == nil {
return 0, io.EOF
}
defer pkt.Put()
if len(pkt.Data) < 4 {
logrus.Warnln("[tunnel] unexpected packet data len", len(pkt.Data), "content", pkt.Data)
return 0, io.EOF
}
d = pkt.Data[4:]
}
if d != nil {
if len(p) >= len(d) {
Expand All @@ -79,9 +89,12 @@ func (s *Tunnel) Read(p []byte) (int, error) {
func (s *Tunnel) Stop() {
s.l.Close()
close(s.in)
close(s.out)
}

func (s *Tunnel) handleWrite() {
seq := uint32(0)
buf := make([]byte, s.mtu)
for b := range s.in {
end := 64
endl := "..."
Expand All @@ -95,27 +108,45 @@ func (s *Tunnel) handleWrite() {
break
}
logrus.Debugln("[tunnel] writing", len(b), "bytes...")
for len(b) > int(s.mtu) {
logrus.Infoln("[tunnel] split buffer")
_, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false)
for len(b) > int(s.mtu)-4 {
logrus.Infoln("[tunnel] seq", seq, "split buffer")
binary.LittleEndian.PutUint32(buf[:4], seq)
seq++
copy(buf[4:], b[:s.mtu-4])
_, err := s.l.WriteAndPut(
head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf), false,
)
if err != nil {
logrus.Errorln("[tunnel] write err:", err)
logrus.Errorln("[tunnel] seq", seq-1, "write err:", err)
return
}
logrus.Debugln("[tunnel] write succeeded")
b = b[s.mtu:]
logrus.Debugln("[tunnel] seq", seq-1, "write succeeded")
b = b[s.mtu-4:]
}
_, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false)
binary.LittleEndian.PutUint32(buf[:4], seq)
seq++
copy(buf[4:], b)
_, err := s.l.WriteAndPut(
head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf[:len(b)+4]), false,
)
if err != nil {
logrus.Errorln("[tunnel] write err:", err)
logrus.Errorln("[tunnel] seq", seq-1, "write err:", err)
break
}
logrus.Debugln("[tunnel] write succeeded")
logrus.Debugln("[tunnel] seq", seq-1, "write succeeded")
}
}

func (s *Tunnel) handleRead() {
seq := uint32(0)
seqmap := make(map[uint32]*head.Packet)
for {
if p, ok := seqmap[seq]; ok {
logrus.Debugln("[tunnel] dispatch cached seq", seq)
delete(seqmap, seq)
seq++
s.out <- p
}
p := s.l.Read()
if p == nil {
logrus.Errorln("[tunnel] read recv nil")
Expand All @@ -128,7 +159,14 @@ func (s *Tunnel) handleRead() {
endl = "."
}
logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Data[:end]), endl)
s.out <- p.Data
p.Put()
recvseq := binary.LittleEndian.Uint32(p.Data[:4])
if recvseq == seq {
logrus.Debugln("[tunnel] dispatch seq", seq)
seq++
s.out <- p
continue
}
seqmap[recvseq] = p
logrus.Debugln("[tunnel] cache seq", recvseq)
}
}
46 changes: 34 additions & 12 deletions upper/services/tunnel/tunnel_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tunnel

import (
"bytes"
"crypto/rand"
"encoding/hex"
"io"
Expand Down Expand Up @@ -92,26 +93,47 @@ func TestTunnel(t *testing.T) {
rand.Read(sendb)
tunnme.Write(sendb)
buf = make([]byte, 4096)
tunnpeer.Read(buf)
_, err = io.ReadFull(&tunnpeer, buf)
if err != nil {
t.Fatal(err)
}
if string(sendb) != string(buf) {
t.Fatal("error: recv 4096 bytes data")
}

sendb = make([]byte, 65535)
rand.Read(sendb)
n, _ := tunnme.Write(sendb)
t.Log("write", n, "bytes")
buf = make([]byte, 65535)
n, _ = io.ReadFull(&tunnpeer, buf)
t.Log("read", n, "bytes")
if string(sendb) != string(buf) {
t.Fatal("error: recv 65535 bytes data")
t.Log("expect", hex.EncodeToString(sendb))
t.Log("got", hex.EncodeToString(buf))
for i := 0; i < 32; i++ {
rand.Read(sendb)
n, _ := tunnme.Write(sendb)
t.Log("loop", i, "write", n, "bytes")
n, err = io.ReadFull(&tunnpeer, buf)
if err != nil {
t.Fatal(err)
}
t.Log("loop", i, "read", n, "bytes")
if string(sendb) != string(buf) {
t.Fatal("loop", i, "error: recv 65535 bytes data")
}
}

tunnme.Stop()
tunnpeer.Stop()
rand.Read(sendb)
tunnme.Write(sendb)
rd := bytes.NewBuffer(nil)

tm := time.AfterFunc(time.Second*5, func() {
tunnme.Stop()
tunnpeer.Stop()
})
defer tm.Stop()

_, err = io.CopyBuffer(rd, &tunnpeer, make([]byte, 200))
if err != nil {
t.Fatal(err)
}
if string(sendb) != rd.String() {
t.Fatal("error: recv fragmented 4096 bytes data")
}
}

// logFormat specialize for go-cqhttp
Expand Down

0 comments on commit c0bd86d

Please sign in to comment.