From 8a57c73ca45652f82c35664c88c098fa37b0465e Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 15 Aug 2024 11:52:20 -0500 Subject: [PATCH] Add configurable Now time for signature generation This behavior and test is prep for an AWS SDK update, to ensure that the generated token (signature) matches when we update the SDK. Signed-off-by: Micah Hausler --- pkg/token/token.go | 16 ++++++++++- pkg/token/token_test.go | 62 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/pkg/token/token.go b/pkg/token/token.go index 88ab70299..7936f4c88 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -34,6 +34,7 @@ import ( "github.com/aws/aws-sdk-go/aws/endpoints" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/aws/session" + v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/sts" "github.com/aws/aws-sdk-go/service/sts/stsiface" "github.com/prometheus/client_golang/prometheus" @@ -198,6 +199,7 @@ type Generator interface { type generator struct { forwardSessionName bool cache bool + nowFunc func() time.Time } // NewGenerator creates a Generator and returns it. @@ -205,6 +207,7 @@ func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) { return generator{ forwardSessionName: forwardSessionName, cache: cache, + nowFunc: time.Now, }, nil } @@ -332,12 +335,23 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { return g.GetWithSTS(options.ClusterID, stsAPI) } +func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { + return request.NamedHandler{ + Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { + v4.SignSDKRequestWithCurrentTime(req, nowFunc) + }, + } +} + // GetWithSTS returns a token valid for clusterID using the given STS client. func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { // generate an sts:GetCallerIdentity request and add our custom cluster ID header request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) + // override the Sign handler so we can control the now time for testing. + request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) + // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date // timestamp regardless. We set it to 60 seconds for backwards compatibility (the @@ -350,7 +364,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, } // Set token expiration to 1 minute before the presigned URL expires for some cushion - tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute) + tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index 021718815..6cc87eb80 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -14,7 +14,11 @@ import ( "testing" "time" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/endpoints" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -582,3 +586,61 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { }) } } + +func TestGetWithSTS(t *testing.T) { + clusterID := "test-cluster" + + cases := []struct { + name string + creds *credentials.Credentials + nowTime time.Time + want Token + wantErr error + }{ + { + "Non-zero time", + // Example non-real credentials + func() *credentials.Credentials { + decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") + decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") + return credentials.NewStaticCredentials( + string(decodedAkid), + string(decodedSk), + "", + ) + }(), + time.Unix(1682640000, 0), + Token{ + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1", + Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), + }, + nil, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + svc := sts.New(session.Must(session.NewSession( + &aws.Config{ + Credentials: tc.creds, + Region: aws.String("us-west-2"), + STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, + }, + ))) + + gen := &generator{ + forwardSessionName: false, + cache: false, + nowFunc: func() time.Time { return tc.nowTime }, + } + + got, err := gen.GetWithSTS(clusterID, svc) + if diff := cmp.Diff(err, tc.wantErr); diff != "" { + t.Errorf("Unexpected error: %s", diff) + } + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("Got unexpected token: %s", diff) + } + }) + } +}