Skip to content

Commit

Permalink
Add configurable Now time for signature generation
Browse files Browse the repository at this point in the history
This behavior and test is prep for an AWS SDK update, to ensure that the
generated token (signature) matches when we update the SDK.

Signed-off-by: Micah Hausler <[email protected]>
  • Loading branch information
micahhausler committed Aug 20, 2024
1 parent efea3dd commit 8a57c73
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 1 deletion.
16 changes: 15 additions & 1 deletion pkg/token/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
v4 "github.com/aws/aws-sdk-go/aws/signer/v4"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/prometheus/client_golang/prometheus"
Expand Down Expand Up @@ -198,13 +199,15 @@ type Generator interface {
type generator struct {
forwardSessionName bool
cache bool
nowFunc func() time.Time
}

// NewGenerator creates a Generator and returns it.
func NewGenerator(forwardSessionName bool, cache bool) (Generator, error) {
return generator{
forwardSessionName: forwardSessionName,
cache: cache,
nowFunc: time.Now,
}, nil
}

Expand Down Expand Up @@ -332,12 +335,23 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) {
return g.GetWithSTS(options.ClusterID, stsAPI)
}

func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler {
return request.NamedHandler{
Name: "v4.SignRequestHandler", Fn: func(req *request.Request) {
v4.SignSDKRequestWithCurrentTime(req, nowFunc)
},
}
}

// GetWithSTS returns a token valid for clusterID using the given STS client.
func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) {
// generate an sts:GetCallerIdentity request and add our custom cluster ID header
request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{})
request.HTTPRequest.Header.Add(clusterIDHeader, clusterID)

// override the Sign handler so we can control the now time for testing.
request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc))

// Sign the request. The expires parameter (sets the x-amz-expires header) is
// currently ignored by STS, and the token expires 15 minutes after the x-amz-date
// timestamp regardless. We set it to 60 seconds for backwards compatibility (the
Expand All @@ -350,7 +364,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token,
}

// Set token expiration to 1 minute before the presigned URL expires for some cushion
tokenExpiration := time.Now().Local().Add(presignedURLExpiration - 1*time.Minute)
tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute)
// TODO: this may need to be a constant-time base64 encoding
return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil
}
Expand Down
62 changes: 62 additions & 0 deletions pkg/token/token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/google/go-cmp/cmp"
"github.com/prometheus/client_golang/prometheus"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
Expand Down Expand Up @@ -582,3 +586,61 @@ func Test_getDefaultHostNameForRegion(t *testing.T) {
})
}
}

func TestGetWithSTS(t *testing.T) {
clusterID := "test-cluster"

cases := []struct {
name string
creds *credentials.Credentials
nowTime time.Time
want Token
wantErr error
}{
{
"Non-zero time",
// Example non-real credentials
func() *credentials.Credentials {
decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=")
decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==")
return credentials.NewStaticCredentials(
string(decodedAkid),
string(decodedSk),
"",
)
}(),
time.Unix(1682640000, 0),
Token{
Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTAmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0JTNCeC1rOHMtYXdzLWlkJlgtQW16LVNpZ25hdHVyZT00ZDdhYmZkZTk2NzI1ZWI4YTc3MzgyNDg0MTZlNGI1ZDA4ZDlkYmQ3MThiNGY2ZGQ2OTBmOGZiNzUwMTMyOWQ1",
Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14),
},
nil,
},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
svc := sts.New(session.Must(session.NewSession(
&aws.Config{
Credentials: tc.creds,
Region: aws.String("us-west-2"),
STSRegionalEndpoint: endpoints.RegionalSTSEndpoint,
},
)))

gen := &generator{
forwardSessionName: false,
cache: false,
nowFunc: func() time.Time { return tc.nowTime },
}

got, err := gen.GetWithSTS(clusterID, svc)
if diff := cmp.Diff(err, tc.wantErr); diff != "" {
t.Errorf("Unexpected error: %s", diff)
}
if diff := cmp.Diff(tc.want, got); diff != "" {
t.Errorf("Got unexpected token: %s", diff)
}
})
}
}

0 comments on commit 8a57c73

Please sign in to comment.