From 885868b77349c3c1a0a790e9e444cf8ae8c5453a Mon Sep 17 00:00:00 2001
From: Terry Cain
Date: Tue, 26 Sep 2023 23:32:37 +0100
Subject: [PATCH] Add best case role ARN cache to lookup role ARNs with paths

Signed-off-by: Terry Cain
---
 pkg/token/rolecache.go      | 147 ++++++++++++++++++++++++++++++++
 pkg/token/rolecache_test.go | 164 ++++++++++++++++++++++++++++++++++++
 pkg/token/token.go          |  10 +++
 3 files changed, 321 insertions(+)
 create mode 100644 pkg/token/rolecache.go
 create mode 100644 pkg/token/rolecache_test.go

diff --git a/pkg/token/rolecache.go b/pkg/token/rolecache.go
new file mode 100644
index 000000000..52761a1e4
--- /dev/null
+++ b/pkg/token/rolecache.go
@@ -0,0 +1,147 @@
+package token
+
+import (
+	"errors"
+	"sync"
+	"time"
+
+	"github.com/aws/aws-sdk-go/aws/awserr"
+	"github.com/aws/aws-sdk-go/aws/session"
+	"github.com/aws/aws-sdk-go/service/iam"
+	"github.com/sirupsen/logrus"
+)
+
+// roleCacheTTL is how long a role listing is treated as fresh before the next
+// lookup triggers another IAM ListRoles call.
+const roleCacheTTL = 5 * time.Minute
+
+// IAMClient is the subset of the IAM API that RoleCache depends on, so that
+// tests can substitute a fake for the real client.
+type IAMClient interface {
+	ListRolesPages(input *iam.ListRolesInput, fn func(*iam.ListRolesOutput, bool) bool) error
+}
+
+// RoleCache is a populate-on-first-use cache mapping unique IAM role IDs to
+// full role ARNs, for lookup during verify operations. This solves the issue
+// of assume-role ARNs not including the assumed role's path.
+type RoleCache struct {
+	awsClient IAMClient
+
+	// mutex guards the fields below it.
+	mutex       sync.RWMutex
+	searchRoles bool
+	lastUpdate  time.Time
+	idToFullARN map[string]string
+}
+
+// NewRoleCache creates a RoleCache backed by an IAM client built from the
+// default AWS session. If the session cannot be created a warning is logged
+// and the returned cache has lookups disabled.
+func NewRoleCache() *RoleCache {
+	// Declared as the interface type so a failed session leaves a true nil
+	// interface rather than an interface wrapping a nil *iam.IAM.
+	var iamClient IAMClient
+
+	sess, err := session.NewSessionWithOptions(session.Options{SharedConfigState: session.SharedConfigEnable})
+	if err != nil {
+		logrus.WithError(err).Warn("failed to instantiate AWS session, disabling full role ARN lookup")
+	} else {
+		iamClient = iam.New(sess)
+	}
+
+	// lastUpdate starts out older than roleCacheTTL so the first lookup
+	// populates the cache.
+	return &RoleCache{
+		awsClient:   iamClient,
+		searchRoles: iamClient != nil,
+		lastUpdate:  time.Now().Add(-2 * roleCacheTTL),
+		idToFullARN: make(map[string]string),
+	}
+}
+
+// updateRoles refreshes idToFullARN from IAM ListRoles.
+//
+// The write lock is held for the whole refresh and staleness is re-checked
+// once it is held, so callers that queued up behind an in-flight refresh
+// return without issuing another ListRoles call.
+//
+// lastUpdate is advanced before the call is made, so a failed refresh is not
+// retried until roleCacheTTL has elapsed. Missing credentials, AccessDenied
+// and non-AWS errors disable lookups permanently; any other AWS error is
+// treated as transient and leaves the previous map in place.
+func (r *RoleCache) updateRoles() {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if !r.searchRoles || r.awsClient == nil || time.Since(r.lastUpdate) <= roleCacheTTL {
+		return
+	}
+	r.lastUpdate = time.Now()
+
+	newARNMap := make(map[string]string)
+	err := r.awsClient.ListRolesPages(&iam.ListRolesInput{}, func(output *iam.ListRolesOutput, _ bool) bool {
+		for _, role := range output.Roles {
+			// The SDK models every field as a pointer; skip incomplete
+			// entries rather than panicking on a nil dereference.
+			if role == nil || role.RoleId == nil || role.Arn == nil {
+				continue
+			}
+			newARNMap[*role.RoleId] = *role.Arn
+		}
+		return true
+	})
+	if err == nil {
+		r.idToFullARN = newARNMap
+		return
+	}
+
+	var aerr awserr.Error
+	if !errors.As(err, &aerr) {
+		// Non aws error
+		logrus.WithError(err).Error("failed to list IAM roles")
+		r.searchRoles = false
+		return
+	}
+
+	// If we don't have credentials or access to ListRoles then stop searching roles in future
+	switch aerr.Code() {
+	case "NoCredentialProviders":
+		logrus.WithError(aerr).Error("no credentials found to list IAM roles with, disabling role cache")
+		r.searchRoles = false
+	case "AccessDenied":
+		logrus.WithError(aerr).Error("no access to IAM list role, disabling role cache")
+		r.searchRoles = false
+	default:
+		// Treat as transient error
+		logrus.WithError(aerr).Error("transient IAM role list failure")
+	}
+}
+
+// CheckRoleID takes a unique role ID and returns the full ARN of that role if
+// it is cached, refreshing the cache first when it is older than roleCacheTTL.
+// It is safe to call on a nil *RoleCache, which never finds a match.
+func (r *RoleCache) CheckRoleID(roleID string) (string, bool) {
+	if r == nil {
+		return "", false
+	}
+
+	// searchRoles and lastUpdate are written by updateRoles under the write
+	// lock, so they must be read under the lock as well.
+	r.mutex.RLock()
+	enabled := r.searchRoles
+	stale := time.Since(r.lastUpdate) > roleCacheTTL
+	r.mutex.RUnlock()
+
+	if !enabled {
+		return "", false
+	}
+	if stale {
+		r.updateRoles()
+	}
+
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
+
+	roleARN, exists := r.idToFullARN[roleID]
+	return roleARN, exists
+}
diff --git a/pkg/token/rolecache_test.go b/pkg/token/rolecache_test.go
new file mode 100644
index 000000000..c733c31b4
--- /dev/null
+++ b/pkg/token/rolecache_test.go
@@ -0,0 +1,164 @@
+package token
+
+import (
+	"errors"
+	"testing"
+
+	"github.com/aws/aws-sdk-go/aws"
+	"github.com/aws/aws-sdk-go/aws/awserr"
+	"github.com/aws/aws-sdk-go/service/iam"
+)
+
+// mockIAM is an IAMClient that replays canned pages (or a canned error) and
+// counts how many times it was called.
+type mockIAM struct {
+	Called int
+	Pages  []*iam.ListRolesOutput
+	Err    error
+}
+
+func (m *mockIAM) ListRolesPages(input *iam.ListRolesInput, fn func(*iam.ListRolesOutput, bool) bool) error {
+	m.Called++
+
+	if m.Err != nil {
+		return m.Err
+	}
+
+	lastItem := len(m.Pages) - 1
+	for index, page := range m.Pages {
+		// Like the real client, stop paging once the callback returns false.
+		if !fn(page, index == lastItem) {
+			break
+		}
+	}
+
+	return nil
+}
+
+// getMockedRoleCache returns a RoleCache wired to a fresh mockIAM with lookups
+// force-enabled, regardless of whether an AWS session could be created.
+func getMockedRoleCache() (*RoleCache, *mockIAM) {
+	r := NewRoleCache()
+	m := &mockIAM{}
+	r.awsClient = m
+	r.searchRoles = true
+
+	return r, m
+}
+
+func TestRoleCache_SuccessfulLookup(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Pages = append(mock.Pages, &iam.ListRolesOutput{Roles: []*iam.Role{
+		{RoleId: aws.String("someid1"), Arn: aws.String("somearn1")},
+	}})
+
+	lookupARN1, exists1 := r.CheckRoleID("someid1")
+	if !exists1 {
+		t.Fatal("ARN lookup 1 should have found an ARN")
+	}
+	if lookupARN1 != "somearn1" {
+		t.Fatalf("ARN lookup 1 expected %s, got %s", "somearn1", lookupARN1)
+	}
+	if mock.Called != 1 {
+		t.Fatal("IAM Mock called counter incorrect")
+	}
+
+	_, exists2 := r.CheckRoleID("someid2")
+	if exists2 {
+		t.Fatal("ARN lookup 2 should have not found an ARN")
+	}
+	if mock.Called != 1 {
+		t.Fatalf("IAM Mock called erroneously, counter should be %d but got %d", 1, mock.Called)
+	}
+}
+
+func TestRoleCache_SkipsIncompleteRoles(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Pages = append(mock.Pages, &iam.ListRolesOutput{Roles: []*iam.Role{
+		nil,
+		{RoleId: aws.String("noarn")},
+		{Arn: aws.String("noid")},
+		{RoleId: aws.String("someid1"), Arn: aws.String("somearn1")},
+	}})
+
+	if _, exists := r.CheckRoleID("noarn"); exists {
+		t.Fatal("role without an ARN should not have been cached")
+	}
+	if arn, exists := r.CheckRoleID("someid1"); !exists || arn != "somearn1" {
+		t.Fatalf("expected %s, got %q (found: %t)", "somearn1", arn, exists)
+	}
+}
+
+func TestRoleCache_NilCache(t *testing.T) {
+	var r *RoleCache
+
+	if _, exists := r.CheckRoleID("someid1"); exists {
+		t.Fatal("nil cache should never find an ARN")
+	}
+}
+
+func TestRoleCache_AccessDenied(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Err = awserr.New("AccessDenied", "access denied", errors.New("some access denied error"))
+	_, exists := r.CheckRoleID("someid1")
+	if exists {
+		t.Fatal("ARN lookup should have not found an ARN")
+	}
+	if mock.Called != 1 {
+		t.Fatal("Mock lookup was not called")
+	}
+	if r.searchRoles {
+		t.Fatal("Role searching should have been permanently disabled")
+	}
+}
+
+func TestRoleCache_NoCredentialProviders(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Err = awserr.New("NoCredentialProviders", "no creds", errors.New("no creds"))
+	_, exists := r.CheckRoleID("someid1")
+	if exists {
+		t.Fatal("ARN lookup should have not found an ARN")
+	}
+	if mock.Called != 1 {
+		t.Fatal("Mock lookup was not called")
+	}
+	if r.searchRoles {
+		t.Fatal("Role searching should have been permanently disabled")
+	}
+}
+
+func TestRoleCache_TransientError(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Err = awserr.New("TransientError", "random error", errors.New("random error"))
+	_, exists := r.CheckRoleID("someid1")
+	if exists {
+		t.Fatal("ARN lookup should have not found an ARN")
+	}
+	if mock.Called != 1 {
+		t.Fatal("Mock lookup was not called")
+	}
+	if !r.searchRoles {
+		t.Fatal("Role searching should not have been permanently disabled")
+	}
+}
+
+func TestRoleCache_NonAWSError(t *testing.T) {
+	r, mock := getMockedRoleCache()
+
+	mock.Err = errors.New("non aws error")
+	_, exists := r.CheckRoleID("someid1")
+	if exists {
+		t.Fatal("ARN lookup should have not found an ARN")
+	}
+	if mock.Called != 1 {
+		t.Fatal("Mock lookup was not called")
+	}
+	if r.searchRoles {
+		t.Fatal("Role searching should have been permanently disabled")
+	}
+}
diff --git a/pkg/token/token.go b/pkg/token/token.go
index 99b286074..da04d40c1 100644
--- a/pkg/token/token.go
+++ b/pkg/token/token.go
@@ -390,6 +390,7 @@ type tokenVerifier struct {
 	client            *http.Client
 	clusterID         string
 	validSTShostnames map[string]bool
+	roleCache         *RoleCache
 }
 
 func stsHostsForPartition(partitionID, region string) map[string]bool {
@@ -462,6 +463,7 @@ func NewVerifier(clusterID, partitionID, region string) Verifier {
 		},
 		clusterID:         clusterID,
 		validSTShostnames: stsHostsForPartition(partitionID, region),
+		roleCache:         NewRoleCache(),
 	}
 }
 
@@ -625,6 +627,14 @@ func (v tokenVerifier) Verify(token string) (*Identity, error) {
 			callerIdentity.GetCallerIdentityResponse.GetCallerIdentityResult.UserID)}
 	}
 
+	// STS get-caller-identity can return an assume-role ARN which does not include the IAM role path.
+	// "AROA" is the unique-ID prefix IAM assigns to roles.
+	if strings.HasPrefix(id.UserID, "AROA") {
+		if roleARN, exists := v.roleCache.CheckRoleID(id.UserID); exists {
+			id.CanonicalARN = roleARN
+		}
+	}
+
 	return id, nil
 }
 