diff --git a/upper/services/tunnel/tunnel.go b/upper/services/tunnel/tunnel.go index 65f10d7..a67ccb3 100644 --- a/upper/services/tunnel/tunnel.go +++ b/upper/services/tunnel/tunnel.go @@ -1,6 +1,7 @@ package tunnel import ( + "encoding/binary" "encoding/hex" "io" "net" @@ -14,7 +15,7 @@ import ( type Tunnel struct { l *link.Link in chan []byte - out chan []byte + out chan *head.Packet outcache []byte peerip net.IP src uint16 @@ -26,7 +27,7 @@ func Create(me *link.Me, peer string) (s Tunnel, err error) { s.l, err = me.Connect(peer) if err == nil { s.in = make(chan []byte, 4) - s.out = make(chan []byte, 4) + s.out = make(chan *head.Packet, 4) s.peerip = net.ParseIP(peer) } else { logrus.Errorln("[tunnel] create err:", err) @@ -62,7 +63,16 @@ func (s *Tunnel) Read(p []byte) (int, error) { if s.outcache != nil { d = s.outcache } else { - d = <-s.out + pkt := <-s.out + if pkt == nil { + return 0, io.EOF + } + defer pkt.Put() + if len(pkt.Data) < 4 { + logrus.Warnln("[tunnel] unexpected packet data len", len(pkt.Data), "content", pkt.Data) + return 0, io.EOF + } + d = pkt.Data[4:] } if d != nil { if len(p) >= len(d) { @@ -79,9 +89,12 @@ func (s *Tunnel) Read(p []byte) (int, error) { func (s *Tunnel) Stop() { s.l.Close() close(s.in) + close(s.out) } func (s *Tunnel) handleWrite() { + seq := uint32(0) + buf := make([]byte, s.mtu) for b := range s.in { end := 64 endl := "..." @@ -95,27 +108,45 @@ func (s *Tunnel) handleWrite() { break } logrus.Debugln("[tunnel] writing", len(b), "bytes...") - for len(b) > int(s.mtu) { - logrus.Infoln("[tunnel] split buffer") - _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b[:s.mtu]), false) + for len(b) > int(s.mtu)-4 { + logrus.Infoln("[tunnel] seq", seq, "split buffer") + binary.LittleEndian.PutUint32(buf[:4], seq) + seq++ + copy(buf[4:], b[:s.mtu-4]) + _, err := s.l.WriteAndPut( + head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf), false, + ) if err != nil { - logrus.Errorln("[tunnel] write err:", err) + logrus.Errorln("[tunnel] seq", seq-1, "write err:", err) return } - logrus.Debugln("[tunnel] write succeeded") - b = b[s.mtu:] + logrus.Debugln("[tunnel] seq", seq-1, "write succeeded") + b = b[s.mtu-4:] } - _, err := s.l.WriteAndPut(head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, b), false) + binary.LittleEndian.PutUint32(buf[:4], seq) + seq++ + copy(buf[4:], b) + _, err := s.l.WriteAndPut( + head.NewPacket(head.ProtoData, s.src, s.peerip, s.dest, buf[:len(b)+4]), false, + ) if err != nil { - logrus.Errorln("[tunnel] write err:", err) + logrus.Errorln("[tunnel] seq", seq-1, "write err:", err) break } - logrus.Debugln("[tunnel] write succeeded") + logrus.Debugln("[tunnel] seq", seq-1, "write succeeded") } } func (s *Tunnel) handleRead() { + seq := uint32(0) + seqmap := make(map[uint32]*head.Packet) for { + if p, ok := seqmap[seq]; ok { + logrus.Debugln("[tunnel] dispatch cached seq", seq) + delete(seqmap, seq) + seq++ + s.out <- p + } p := s.l.Read() if p == nil { logrus.Errorln("[tunnel] read recv nil") @@ -128,7 +159,14 @@ func (s *Tunnel) handleRead() { endl = "." } logrus.Debugln("[tunnel] read recv", hex.EncodeToString(p.Data[:end]), endl) - s.out <- p.Data - p.Put() + recvseq := binary.LittleEndian.Uint32(p.Data[:4]) + if recvseq == seq { + logrus.Debugln("[tunnel] dispatch seq", seq) + seq++ + s.out <- p + continue + } + seqmap[recvseq] = p + logrus.Debugln("[tunnel] cache seq", recvseq) } } diff --git a/upper/services/tunnel/tunnel_test.go b/upper/services/tunnel/tunnel_test.go index 24a3c83..486f7f1 100644 --- a/upper/services/tunnel/tunnel_test.go +++ b/upper/services/tunnel/tunnel_test.go @@ -1,6 +1,7 @@ package tunnel import ( + "bytes" "crypto/rand" "encoding/hex" "io" @@ -92,26 +93,47 @@ func TestTunnel(t *testing.T) { rand.Read(sendb) tunnme.Write(sendb) buf = make([]byte, 4096) - tunnpeer.Read(buf) + _, err = io.ReadFull(&tunnpeer, buf) + if err != nil { + t.Fatal(err) + } if string(sendb) != string(buf) { t.Fatal("error: recv 4096 bytes data") } sendb = make([]byte, 65535) - rand.Read(sendb) - n, _ := tunnme.Write(sendb) - t.Log("write", n, "bytes") buf = make([]byte, 65535) - n, _ = io.ReadFull(&tunnpeer, buf) - t.Log("read", n, "bytes") - if string(sendb) != string(buf) { - t.Fatal("error: recv 65535 bytes data") - t.Log("expect", hex.EncodeToString(sendb)) - t.Log("got", hex.EncodeToString(buf)) + for i := 0; i < 32; i++ { + rand.Read(sendb) + n, _ := tunnme.Write(sendb) + t.Log("loop", i, "write", n, "bytes") + n, err = io.ReadFull(&tunnpeer, buf) + if err != nil { + t.Fatal(err) + } + t.Log("loop", i, "read", n, "bytes") + if string(sendb) != string(buf) { + t.Fatal("loop", i, "error: recv 65535 bytes data") + } } - tunnme.Stop() - tunnpeer.Stop() + rand.Read(sendb) + tunnme.Write(sendb) + rd := bytes.NewBuffer(nil) + + tm := time.AfterFunc(time.Second*5, func() { + tunnme.Stop() + tunnpeer.Stop() + }) + defer tm.Stop() + + _, err = io.CopyBuffer(rd, &tunnpeer, make([]byte, 200)) + if err != nil { + t.Fatal(err) + } + if string(sendb) != rd.String() { + t.Fatal("error: recv fragmented 4096 bytes data") + } } // logFormat specialize for go-cqhttp