Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

unix socket proxy support added #77

Merged
merged 7 commits into from Feb 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 223 additions & 23 deletions pkg/services/forwarder/ports.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,21 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
"sort"
"strconv"
"strings"
"sync"
"time"

"github.com/containers/gvisor-tap-vsock/pkg/types"
"github.com/google/tcpproxy"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
Expand Down Expand Up @@ -50,31 +55,174 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote
return errors.New("proxy already running")
}

split := strings.Split(remote, ":")
if len(split) != 2 {
return errors.New("invalid remote addr")
}
port, err := strconv.Atoi(split[1])
if err != nil {
return err
}
address := tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(net.ParseIP(split[0]).To4()),
Port: uint16(port),
}

switch protocol {
case types.UNIX:
// parse URI for remote
remoteURI, err := url.Parse(remote)
if err != nil {
return fmt.Errorf("failed to parse remote uri :%s : %w", remote, err)
}

// build the address from remoteURI
remoteAddr := fmt.Sprintf("%s:%s", remoteURI.Hostname(), remoteURI.Port())

// dialFn opens remote connection for the proxy
var dialFn func(ctx context.Context, network, addr string) (conn net.Conn, e error)

// dialFn is set based on the protocol provided by remoteURI.Scheme
switch remoteURI.Scheme {
case "ssh-tunnel": // unix-to-unix proxy (over SSH)
// query string to map for the remoteURI contains ssh config info
remoteQuery := remoteURI.Query()

// username
sshuser := firstValueOrEmpty(remoteQuery["user"])
if sshuser == "" {
return fmt.Errorf("user not provided for unix-ssh connection")
}

// key
sshkeypath := firstValueOrEmpty(remoteQuery["key"])
if sshkeypath == "" {
return fmt.Errorf("key not provided for unix-ssh connection")
}

sshkeyBytes, err := ioutil.ReadFile(sshkeypath)
if err != nil {
return fmt.Errorf("failed to read ssh key: %s: %w", sshkeypath, err)
}

// passphrase
passphrase := firstValueOrEmpty(remoteQuery["passphrase"])

var sshsigner ssh.Signer

if passphrase == "" {
sshsigner, err = ssh.ParsePrivateKey(sshkeyBytes)
} else {
sshsigner, err = ssh.ParsePrivateKeyWithPassphrase(sshkeyBytes, []byte(passphrase))
}

// parse private key error?
if err != nil {
return fmt.Errorf("failed to parse ssh key: %s: %w", sshkeypath, err)
}

// default ssh port if not set
if remoteURI.Port() == "" {
remoteAddr = fmt.Sprintf("%s:%s", remoteURI.Hostname(), "22")
}

// build address
address, err := tcpipAddress(1, remoteAddr)
if err != nil {
return err
}

// check the remoteURI path provided for nonsense
if remoteURI.Path == "" || remoteURI.Path == "/" {
return fmt.Errorf("remote uri must contain a path to a socket file")
}

// captured and used by dialFn
var tcpConn *gonet.TCPConn
var sshClient *ssh.Client
This conversation was marked as resolved.
Show resolved Hide resolved
var connLock sync.Mutex

// handles getting underlying ssh connection, having this outside of
// dialFn limits connLock to only the parts it's needed for in a way
// that doesn't get racy.
sshConnFn := func(ctx context.Context, network, addr string) (client *ssh.Client, err error) {
connLock.Lock()
defer connLock.Unlock()

// check underlying tcpConn to see if it's closed
if tcpConn != nil {
if _, err := tcpConn.Read(make([]byte, 0)); err == io.EOF {
tcpConn = nil // set back to nil to force reconnect
}
}

// connect or reconnect to ssh
if tcpConn == nil || sshClient == nil {
// underlying connection to endpoint for the ssh client
tcpConn, err := gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber)
if err != nil {
return sshClient, err
}

// ssh client config that uses key authentication
config := &ssh.ClientConfig{
User: sshuser,
Auth: []ssh.AuthMethod{
ssh.PublicKeys(sshsigner),
},
// #nosec G106
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyAlgorithms: []string{
ssh.KeyAlgoRSA,
ssh.KeyAlgoDSA,
ssh.KeyAlgoECDSA256,
ssh.KeyAlgoECDSA384,
ssh.KeyAlgoECDSA521,
ssh.KeyAlgoED25519,
},
Timeout: 5 * time.Second,
}

// get an sshConn using the underlying gonet.TCPConn
sshConn, chans, reqs, err := ssh.NewClientConn(tcpConn, addr, config)
if err != nil {
return sshClient, err
}

// build an ssh client using sshConn
sshClient = ssh.NewClient(sshConn, chans, reqs)
}

return sshClient, err
}

// the dialFn for unix-to-unix over SSH
dialFn = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
// check or create new ssh connection
sshClient, err = sshConnFn(ctx, network, addr)
if err != nil {
return nil, err
}

// connection using sshclient's dialer
return sshClient.Dial("unix", remoteURI.Path)
}

case "tcp": // unix-to-tcp proxy
// build address
address, err := tcpipAddress(1, remoteAddr)
if err != nil {
return err
}

dialFn = func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber)
}

default:
return fmt.Errorf("remote protocol for unix forwarder is not implemented: %s", remoteURI.Scheme)
}

// build the tcp proxy
var p tcpproxy.Proxy
p.ListenFunc = func(_, socketPath string) (net.Listener, error) {
// remove existing socket file
if err := os.Remove(socketPath); err != nil && !os.IsNotExist(err) {
return nil, err
}

return net.Listen("unix", socketPath) // override tcp to use unix socket
}
p.AddRoute(local, &tcpproxy.DialProxy{
Addr: remote,
DialContext: func(ctx context.Context, network, addr string) (conn net.Conn, e error) {
return gonet.DialContextTCP(ctx, f.stack, address, ipv4.ProtocolNumber)
},
Addr: remoteAddr,
DialContext: dialFn,
})
if err := p.Start(); err != nil {
return err
Expand All @@ -91,7 +239,13 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote
Remote: remote,
underlying: &p,
}

case types.UDP:
address, err := tcpipAddress(1, remote)
if err != nil {
return err
}

addr, err := net.ResolveUDPAddr("udp", local)
if err != nil {
return err
Expand All @@ -114,6 +268,11 @@ func (f *PortsForwarder) Expose(protocol types.TransportProtocol, local, remote
underlying: p,
}
case types.TCP:
address, err := tcpipAddress(1, remote)
if err != nil {
return err
}

var p tcpproxy.Proxy
p.AddRoute(local, &tcpproxy.DialProxy{
Addr: remote,
Expand Down Expand Up @@ -186,12 +345,21 @@ func (f *PortsForwarder) Mux() http.Handler {
if req.Protocol == "" {
req.Protocol = types.TCP
}
remote, err := remote(req, r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return

// contains unparsed remote field
remoteAddr := req.Remote

// TCP and UDP rely on remote() to preparse the remote field
if req.Protocol != types.UNIX {
var err error
remoteAddr, err = remote(req, r.RemoteAddr)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
}
if err := f.Expose(req.Protocol, req.Local, remote); err != nil {

if err := f.Expose(req.Protocol, req.Local, remoteAddr); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
Expand Down Expand Up @@ -234,3 +402,35 @@ func remote(req types.ExposeRequest, ip string) (string, error) {
}
return req.Remote, nil
}

// helper function for parsed URL query strings
func firstValueOrEmpty(x []string) string {
if len(x) > 0 {
return x[0]
}
return ""
}

// helper function to build tcpip address
func tcpipAddress(nicID tcpip.NICID, remote string) (address tcpip.FullAddress, err error) {

// build the address manual way
split := strings.Split(remote, ":")
if len(split) != 2 {
return address, errors.New("invalid remote addr")
}

port, err := strconv.Atoi(split[1])
if err != nil {
return address, err

}

address = tcpip.FullAddress{
NIC: nicID,
Addr: tcpip.Address(net.ParseIP(split[0]).To4()),
Port: uint16(port),
}

return address, err
}
42 changes: 41 additions & 1 deletion test/port_forwarding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package e2e

import (
"context"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -191,7 +192,7 @@ var _ = Describe("port forwarding", func() {

unix2tcpfwdsock, _ := filepath.Abs(filepath.Join(tmpDir, "podman-unix-to-unix-forwarding.sock"))

out, err := sshExec(`curl http://gateway.containers.internal/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2tcpfwdsock + `","remote":"192.168.127.2:8080"}'`)
out, err := sshExec(`curl http://gateway.containers.internal/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2tcpfwdsock + `","remote":"tcp://192.168.127.2:8080"}'`)
Expect(string(out)).Should(Equal(""))
Expect(err).ShouldNot(HaveOccurred())

Expand All @@ -215,4 +216,43 @@ var _ = Describe("port forwarding", func() {
g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
}).Should(Succeed())
})

It("should expose and reach rootless podman API using unix to unix forwarding over ssh", func() {
if runtime.GOOS == "windows" {
Skip("AF_UNIX not supported on Windows")
}

unix2unixfwdsock, _ := filepath.Abs(filepath.Join(tmpDir, "podman-unix-to-unix-forwarding.sock"))

remoteuri := fmt.Sprintf(`ssh-tunnel://%s:%d%s?user=root&key=%s`, "192.168.127.2", 22, podmanSock, privateKeyFile)
_, err := sshExec(`curl http://192.168.127.1/services/forwarder/expose -X POST -d'{"protocol":"unix","local":"` + unix2unixfwdsock + `","remote":"` + remoteuri + `"}'`)
Expect(err).ShouldNot(HaveOccurred())

Eventually(func(g Gomega) {
sockfile, err := os.Stat(unix2unixfwdsock)
g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(sockfile.Mode().Type().String()).To(Equal(os.ModeSocket.String()))
}).Should(Succeed())

httpClient := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return net.Dial("unix", unix2unixfwdsock)
},
},
}

Eventually(func(g Gomega) {
resp, err := httpClient.Get("http://host/_ping")
g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(resp.StatusCode).To(Equal(http.StatusOK))
g.Expect(resp.ContentLength).To(Equal(int64(2)))

reply := make([]byte, resp.ContentLength)
_, err = io.ReadAtLeast(resp.Body, reply, len(reply))

g.Expect(err).ShouldNot(HaveOccurred())
g.Expect(string(reply)).To(Equal("OK"))
}).Should(Succeed())
})
})