From c753875ac15361949bc006f084d567b29d9c431f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 23 May 2020 19:38:15 +0200 Subject: [PATCH] handle AWS SigV4 correctly with RawPath fixes #3 --- handler.go | 3 + handler_test.go | 148 +++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 137 insertions(+), 14 deletions(-) diff --git a/handler.go b/handler.go index 83addb9..245cfec 100644 --- a/handler.go +++ b/handler.go @@ -8,6 +8,7 @@ import ( "net" "net/http" "net/http/httputil" + "net/url" "regexp" "strings" "time" @@ -139,6 +140,7 @@ func (h *Handler) generateFakeIncomingRequest(signer *v4.Signer, req *http.Reque if err != nil { return nil, err } + fakeReq.URL.RawPath = req.URL.Path // We already validated there there is exactly one Authorization header authorizationHeader := req.Header.Get("authorization") @@ -177,6 +179,7 @@ func (h *Handler) assembleUpstreamReq(signer *v4.Signer, req *http.Request, regi proxyURL := *req.URL proxyURL.Scheme = h.UpstreamScheme proxyURL.Host = upstreamEndpoint + proxyURL.RawPath = req.URL.Path proxyReq, err := http.NewRequest(req.Method, proxyURL.String(), req.Body) if err != nil { return nil, err diff --git a/handler_test.go b/handler_test.go index ce18a47..8bfc3af 100644 --- a/handler_test.go +++ b/handler_test.go @@ -1,13 +1,18 @@ package main import ( + "bytes" "fmt" "net" "net/http" "net/http/httptest" "net/url" + "strings" "testing" + "time" + "github.com/aws/aws-sdk-go/aws/credentials" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/stretchr/testify/assert" ) @@ -15,11 +20,15 @@ import ( // log.SetOutput(ioutil.Discard) // } -func newTestHandler(t *testing.T) *Handler { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +func newTestProxy(t *testing.T) *Handler { + thf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "Hello, client") - })) - // defer ts.Close() + }) + return newTestProxyWithHandler(t, &thf) +} + +func newTestProxyWithHandler(t *testing.T, thf *http.HandlerFunc) *Handler { + ts := httptest.NewServer(thf) tsURL, _ := url.Parse(ts.URL) h, err := NewAwsS3ReverseProxy(Options{ @@ -35,8 +44,52 @@ func newTestHandler(t *testing.T) *Handler { return h } +func signRequest(r *http.Request) { + // delete headers to get clean signature + r.Header.Del("accept-encoding") + r.Header.Del("authorization") + r.Header.Set("X-Amz-Date", "20060102T150405Z") + r.URL.RawPath = r.URL.Path + + // compute the expected signature with valid credentials + body := bytes.NewReader([]byte{}) + signTime, _ := time.Parse("20060102T150405Z", r.Header["X-Amz-Date"][0]) + signer := v4.NewSigner(credentials.NewStaticCredentialsFromCreds(credentials.Value{ + AccessKeyID: "fooooooooooooooo", + SecretAccessKey: "bar", + })) + signer.Sign(r, body, "s3", "eu-test-1", signTime) +} + +func verifySignature(w http.ResponseWriter, r *http.Request) { + // save copy of the received signature + receivedAuthorization := r.Header["Authorization"][0] + + // delete headers to get clean signature + r.Header.Del("accept-encoding") + r.Header.Del("authorization") + + // compute the expected signature with valid credentials + body := bytes.NewReader([]byte{}) + signTime, _ := time.Parse("20060102T150405Z", r.Header["X-Amz-Date"][0]) + signer := v4.NewSigner(credentials.NewStaticCredentialsFromCreds(credentials.Value{ + AccessKeyID: "fooooooooooooooo", + SecretAccessKey: "bar", + })) + signer.Sign(r, body, "s3", "eu-test-1", signTime) + expectedAuthorization := r.Header["Authorization"][0] + + // verify signature + fmt.Fprintln(w, receivedAuthorization, expectedAuthorization) + if receivedAuthorization == expectedAuthorization { + fmt.Fprintln(w, "ok") + } else { + fmt.Fprintln(w, "failed signature check") + } +} + func TestHandlerMissingAmzDate(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) resp := httptest.NewRecorder() @@ -46,7 +99,7 @@ func TestHandlerMissingAmzDate(t *testing.T) { } func TestHandlerMissingAuthorization(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) req.Header.Set("X-Amz-Date", "20060102T150405Z") @@ -57,7 +110,7 @@ func TestHandlerMissingAuthorization(t *testing.T) { } func TestHandlerMissingCredential(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) req.Header.Set("X-Amz-Date", "20060102T150405Z") @@ -69,7 +122,7 @@ func TestHandlerMissingCredential(t *testing.T) { } func TestHandlerInvalidSignature(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) req.Header.Set("X-Amz-Date", "20060102T150405Z") @@ -81,11 +134,10 @@ func TestHandlerInvalidSignature(t *testing.T) { } func TestHandlerValidSignature(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) - req.Header.Set("X-Amz-Date", "20060102T150405Z") - req.Header.Set("Authorization", "AWS4-HMAC-SHA256 Credential=fooooooooooooooo/20060102/eu-test-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a0d5e0c0924c1f9298c5f2a3925e202657bf1e239a1d6856235cbe0702855334") // signature computed manually for this test case + signRequest(req) resp := httptest.NewRecorder() h.ServeHTTP(resp, req) assert.Equal(t, 200, resp.Code) @@ -93,7 +145,7 @@ func TestHandlerValidSignature(t *testing.T) { } func TestHandlerInvalidCredential(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) req.Header.Set("X-Amz-Date", "20060102T150405Z") @@ -105,7 +157,7 @@ func TestHandlerInvalidCredential(t *testing.T) { } func TestHandlerInvalidSourceSubnet(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) _, newNet, _ := net.ParseCIDR("172.27.42.0/24") h.AllowedSourceSubnet = []*net.IPNet{newNet} @@ -119,7 +171,7 @@ func TestHandlerInvalidSourceSubnet(t *testing.T) { } func TestHandlerInvalidAmzDate(t *testing.T) { - h := newTestHandler(t) + h := newTestProxy(t) req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) req.Header.Set("X-Amz-Date", "foobar") @@ -129,3 +181,71 @@ func TestHandlerInvalidAmzDate(t *testing.T) { assert.Equal(t, 400, resp.Code) assert.Contains(t, resp.Body.String(), "error parsing X-Amz-Date foobar") } + +func TestHandlerRawPathEncodingMatchingSignature(t *testing.T) { + thf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + verifySignature(w, r) + }) + h := newTestProxyWithHandler(t, &thf) + + urls := []string{ + "http://foobar.example.com/foo%3Dbar/test.txt", + "http://foobar.example.com/foo=bar/test.txt", + "http://foobar.example.com/foo%3Dbar/test.txt?marker=1000", + "http://foobar.example.com/foo=bar/test.txt?marker=1000", + } + + for _, url := range urls { + req := httptest.NewRequest(http.MethodGet, url, nil) + signRequest(req) + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + assert.Equal(t, 200, resp.Code) + assert.Contains(t, strings.TrimSpace(resp.Body.String()), "ok") + } +} + +func TestHandlerWithQueryArgs(t *testing.T) { + thf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + verifySignature(w, r) + if r.URL.Query().Get("marker") == "1000" { + fmt.Fprintln(w, "marker-ok") + } else { + fmt.Fprintln(w, "marker missing") + } + }) + h := newTestProxyWithHandler(t, &thf) + + urls := []string{ + "http://foobar.example.com/foo%3Dbar/test.txt?marker=1000", + "http://foobar.example.com/foo=bar/test.txt?marker=1000", + } + + for _, url := range urls { + req := httptest.NewRequest(http.MethodGet, url, nil) + signRequest(req) + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + assert.Equal(t, 200, resp.Code) + assert.Contains(t, strings.TrimSpace(resp.Body.String()), "marker-ok") + } +} + +func TestHandlerPassCustomHeaders(t *testing.T) { + thf := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("x-aws-s3-reverse-proxy") == "testing" { + fmt.Fprintln(w, "ok") + } else { + fmt.Fprintln(w, "header missing") + } + }) + h := newTestProxyWithHandler(t, &thf) + + req := httptest.NewRequest(http.MethodGet, "http://foobar.example.com", nil) + signRequest(req) + req.Header.Set("x-aws-s3-reverse-proxy", "testing") + resp := httptest.NewRecorder() + h.ServeHTTP(resp, req) + assert.Equal(t, 200, resp.Code) + assert.Contains(t, strings.TrimSpace(resp.Body.String()), "ok") +}