From 846431661493b82d4ffb3b28ab84d93dc21a9df2 Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 11:05:38 -0500 Subject: [PATCH 1/3] Refactored token filecache The token filecache used to use a private global function for creating a filelock, and overrode it in tests with a hand-crafted mocks for filesystem and environment variable operations. This change adds adds injectability to the filecache's filesystem and file lock using afero. This change also will simplify future changes when updating the AWS SDK with new credential interfaces. Signed-off-by: Micah Hausler --- go.mod | 2 +- pkg/{token => filecache}/filecache.go | 173 ++++----- pkg/{token => filecache}/filecache_test.go | 411 ++++++++++++--------- pkg/token/token.go | 5 +- tests/integration/go.mod | 1 + tests/integration/go.sum | 2 + 6 files changed, 333 insertions(+), 261 deletions(-) rename pkg/{token => filecache}/filecache.go (71%) rename pkg/{token => filecache}/filecache_test.go (55%) diff --git a/go.mod b/go.mod index eb704fd2c..3d866c974 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/manifoldco/promptui v0.9.0 github.com/prometheus/client_golang v1.19.1 github.com/sirupsen/logrus v1.9.3 + github.com/spf13/afero v1.11.0 github.com/spf13/cobra v1.8.1 github.com/spf13/viper v1.18.2 golang.org/x/time v0.5.0 @@ -61,7 +62,6 @@ require ( github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect diff --git a/pkg/token/filecache.go b/pkg/filecache/filecache.go similarity index 71% rename from pkg/token/filecache.go rename to pkg/filecache/filecache.go index e1a0c2a84..41597edaa 100644 --- a/pkg/token/filecache.go +++ b/pkg/filecache/filecache.go @@ -1,4 +1,4 @@ -package token +package filecache import ( "context" @@ -12,68 +12,22 @@ import ( "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" + "github.com/spf13/afero" "gopkg.in/yaml.v2" ) // env variable name for custom credential cache file location const cacheFileNameEnv = "AWS_IAM_AUTHENTICATOR_CACHE_FILE" -// A mockable filesystem interface -var f filesystem = osFS{} - -type filesystem interface { - Stat(filename string) (os.FileInfo, error) - ReadFile(filename string) ([]byte, error) - WriteFile(filename string, data []byte, perm os.FileMode) error - MkdirAll(path string, perm os.FileMode) error -} - -// default os based implementation -type osFS struct{} - -func (osFS) Stat(filename string) (os.FileInfo, error) { - return os.Stat(filename) -} - -func (osFS) ReadFile(filename string) ([]byte, error) { - return os.ReadFile(filename) -} - -func (osFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - return os.WriteFile(filename, data, perm) -} - -func (osFS) MkdirAll(path string, perm os.FileMode) error { - return os.MkdirAll(path, perm) -} - -// A mockable environment interface -var e environment = osEnv{} - -type environment interface { - Getenv(key string) string - LookupEnv(key string) (string, bool) -} - -// default os based implementation -type osEnv struct{} - -func (osEnv) Getenv(key string) string { - return os.Getenv(key) -} - -func (osEnv) LookupEnv(key string) (string, bool) { - return os.LookupEnv(key) -} - -// A mockable flock interface -type filelock interface { +// FileLocker is a subset of the methods exposed by *flock.Flock +type FileLocker interface { Unlock() error TryLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) TryRLockContext(ctx context.Context, retryDelay time.Duration) (bool, error) } -var newFlock = func(filename string) filelock { +// NewFileLocker returns a *flock.Flock that satisfies FileLocker +func NewFileLocker(filename string) FileLocker { return flock.New(filename) } @@ -135,11 +89,11 @@ func (c *cachedCredential) IsExpired() bool { // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. -func readCacheWhileLocked(filename string) (cache cacheFile, err error) { +func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ map[string]map[string]map[string]cachedCredential{}, } - data, err := f.ReadFile(filename) + data, err := afero.ReadFile(fs, filename) if err != nil { err = fmt.Errorf("unable to open file %s: %v", filename, err) return @@ -155,45 +109,86 @@ func readCacheWhileLocked(filename string) (cache cacheFile, err error) { // writeCacheWhileLocked writes the contents of the credential cache using the // yaml marshaled form of the passed cacheFile object. This method must be // called while an exclusive lock is held on the filename. -func writeCacheWhileLocked(filename string, cache cacheFile) error { +func writeCacheWhileLocked(fs afero.Fs, filename string, cache cacheFile) error { data, err := yaml.Marshal(cache) if err == nil { // write privately owned by the user - err = f.WriteFile(filename, data, 0600) + err = afero.WriteFile(fs, filename, data, 0600) } return err } -// FileCacheProvider is a Provider implementation that wraps an underlying Provider +type FileCacheOpt func(*FileCacheProvider) + +// WithFs returns a FileCacheOpt that sets the cache's filesystem +func WithFs(fs afero.Fs) FileCacheOpt { + return func(p *FileCacheProvider) { + p.fs = fs + } +} + +// WithFilename returns a FileCacheOpt that sets the cache's file +func WithFilename(filename string) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filename = filename + } +} + +// WithFileLockCreator returns a FileCacheOpt that sets the cache's FileLocker +// creation function +func WithFileLockerCreator(f func(string) FileLocker) FileCacheOpt { + return func(p *FileCacheProvider) { + p.filelockCreator = f + } +} + +// FileCacheProvider is a credentials.Provider implementation that wraps an underlying Provider // (contained in Credentials) and provides caching support for credentials for the // specified clusterID, profile, and roleARN (contained in cacheKey) type FileCacheProvider struct { + fs afero.Fs + filelockCreator func(string) FileLocker + filename string credentials *credentials.Credentials // the underlying implementation that has the *real* Provider cacheKey cacheKey // cache key parameters used to create Provider cachedCredential cachedCredential // the cached credential, if it exists } +var _ credentials.Provider = &FileCacheProvider{} + // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials) (FileCacheProvider, error) { +func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { if creds == nil { - return FileCacheProvider{}, errors.New("no underlying Credentials object provided") + return nil, errors.New("no underlying Credentials object provided") + } + + resp := &FileCacheProvider{ + fs: afero.NewOsFs(), + filelockCreator: NewFileLocker, + filename: defaultCacheFilename(), + credentials: creds, + cacheKey: cacheKey{clusterID, profile, roleARN}, + cachedCredential: cachedCredential{}, } - filename := CacheFilename() - cacheKey := cacheKey{clusterID, profile, roleARN} - cachedCredential := cachedCredential{} + + // override defaults + for _, opt := range opts { + opt(resp) + } + // ensure path to cache file exists - _ = f.MkdirAll(filepath.Dir(filename), 0700) - if info, err := f.Stat(filename); err == nil { + _ = resp.fs.MkdirAll(filepath.Dir(resp.filename), 0700) + if info, err := resp.fs.Stat(resp.filename); err == nil { if info.Mode()&0077 != 0 { // cache file has secret credentials and should only be accessible to the user, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("cache file %s is not private", filename) + return nil, fmt.Errorf("cache file %s is not private", resp.filename) } // do file locking on cache to prevent inconsistent reads - lock := newFlock(filename) + lock := resp.filelockCreator(resp.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -201,30 +196,26 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials ok, err := lock.TryRLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // unable to lock the cache, something is wrong, refuse to use it. - return FileCacheProvider{}, fmt.Errorf("unable to read lock file %s: %v", filename, err) + return nil, fmt.Errorf("unable to read lock file %s: %v", resp.filename, err) } - cache, err := readCacheWhileLocked(filename) + cache, err := readCacheWhileLocked(resp.fs, resp.filename) if err != nil { // can't read or parse cache, refuse to use it. - return FileCacheProvider{}, err + return nil, err } - cachedCredential = cache.Get(cacheKey) + resp.cachedCredential = cache.Get(resp.cacheKey) } else { if errors.Is(err, fs.ErrNotExist) { // cache file is missing. maybe this is the very first run? continue to use cache. - _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", filename) + _, _ = fmt.Fprintf(os.Stderr, "Cache file %s does not exist.\n", resp.filename) } else { - return FileCacheProvider{}, fmt.Errorf("couldn't stat cache file: %w", err) + return nil, fmt.Errorf("couldn't stat cache file: %w", err) } } - return FileCacheProvider{ - creds, - cacheKey, - cachedCredential, - }, nil + return resp, nil } // Retrieve() implements the Provider interface, returning the cached credential if is not expired, @@ -243,9 +234,9 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { } if expiration, err := f.credentials.ExpiresAt(); err == nil { // underlying provider supports Expirer interface, so we can cache - filename := CacheFilename() + // do file locking on cache to prevent inconsistent writes - lock := newFlock(filename) + lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock ctx, cancel := context.WithTimeout(context.TODO(), time.Second) @@ -253,7 +244,7 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) return credential, nil } f.cachedCredential = cachedCredential{ @@ -262,12 +253,12 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { nil, } // don't really care about read error. Either read the cache, or we create a new cache. - cache, _ := readCacheWhileLocked(filename) + cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) - err = writeCacheWhileLocked(filename, cache) + err = writeCacheWhileLocked(f.fs, f.filename, cache) if err != nil { // can't write cache, but still return the credential - _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", filename, err) + _, _ = fmt.Fprintf(os.Stderr, "Unable to update credential cache %s: %v\n", f.filename, err) err = nil } else { _, _ = fmt.Fprintf(os.Stderr, "Updated cached credential\n") @@ -292,23 +283,23 @@ func (f *FileCacheProvider) ExpiresAt() time.Time { return f.cachedCredential.Expiration } -// CacheFilename returns the name of the credential cache file, which can either be +// defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml -func CacheFilename() string { - if filename, ok := e.LookupEnv(cacheFileNameEnv); ok { +func defaultCacheFilename() string { + if filename := os.Getenv(cacheFileNameEnv); filename != "" { return filename } else { - return filepath.Join(UserHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") + return filepath.Join(userHomeDir(), ".kube", "cache", "aws-iam-authenticator", "credentials.yaml") } } -// UserHomeDir returns the home directory for the user the process is +// userHomeDir returns the home directory for the user the process is // running under. -func UserHomeDir() string { +func userHomeDir() string { if runtime.GOOS == "windows" { // Windows - return e.Getenv("USERPROFILE") + return os.Getenv("USERPROFILE") } // *nix - return e.Getenv("HOME") + return os.Getenv("HOME") } diff --git a/pkg/token/filecache_test.go b/pkg/filecache/filecache_test.go similarity index 55% rename from pkg/token/filecache_test.go rename to pkg/filecache/filecache_test.go index d69c75937..60b4a8771 100644 --- a/pkg/token/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,21 +1,32 @@ -package token +package filecache import ( "bytes" "context" "errors" - "github.com/aws/aws-sdk-go/aws/credentials" + "fmt" + "io/fs" "os" "testing" "time" + + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/spf13/afero" +) + +const ( + testFilename = "/test.yaml" ) +// stubProvider implements credentials.Provider with configurable response values type stubProvider struct { creds credentials.Value expired bool err error } +var _ credentials.Provider = &stubProvider{} + func (s *stubProvider) Retrieve() (credentials.Value, error) { s.expired = false s.creds.ProviderName = "stubProvider" @@ -26,89 +37,54 @@ func (s *stubProvider) IsExpired() bool { return s.expired } +// stubProviderExpirer implements credentials.Expirer with configurable expiration type stubProviderExpirer struct { stubProvider expiration time.Time } +var _ credentials.Expirer = &stubProviderExpirer{} + func (s *stubProviderExpirer) ExpiresAt() time.Time { return s.expiration } +// testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string size int64 - mode os.FileMode + mode fs.FileMode modTime time.Time } +var _ fs.FileInfo = &testFileInfo{} + func (fs *testFileInfo) Name() string { return fs.name } func (fs *testFileInfo) Size() int64 { return fs.size } -func (fs *testFileInfo) Mode() os.FileMode { return fs.mode } +func (fs *testFileInfo) Mode() fs.FileMode { return fs.mode } func (fs *testFileInfo) ModTime() time.Time { return fs.modTime } func (fs *testFileInfo) IsDir() bool { return fs.Mode().IsDir() } func (fs *testFileInfo) Sys() interface{} { return nil } +// testFs wraps afero.Fs with an overridable Stat() method type testFS struct { - filename string - fileinfo testFileInfo - data []byte + afero.Fs + + fileinfo fs.FileInfo err error - perm os.FileMode } -func (t *testFS) Stat(filename string) (os.FileInfo, error) { - t.filename = filename - if t.err == nil { - return &t.fileinfo, nil - } else { +func (t *testFS) Stat(filename string) (fs.FileInfo, error) { + if t.err != nil { return nil, t.err } + if t.fileinfo != nil { + return t.fileinfo, nil + } + return t.Fs.Stat(filename) } -func (t *testFS) ReadFile(filename string) ([]byte, error) { - t.filename = filename - return t.data, t.err -} - -func (t *testFS) WriteFile(filename string, data []byte, perm os.FileMode) error { - t.filename = filename - t.data = data - t.perm = perm - return t.err -} - -func (t *testFS) MkdirAll(path string, perm os.FileMode) error { - t.filename = path - t.perm = perm - return t.err -} - -func (t *testFS) reset() { - t.filename = "" - t.fileinfo = testFileInfo{} - t.data = []byte{} - t.err = nil - t.perm = 0600 -} - -type testEnv struct { - values map[string]string -} - -func (e *testEnv) Getenv(key string) string { - return e.values[key] -} - -func (e *testEnv) LookupEnv(key string) (string, bool) { - value, ok := e.values[key] - return value, ok -} - -func (e *testEnv) reset() { - e.values = map[string]string{} -} - +// testFileLock implements FileLocker with configurable response options type testFilelock struct { ctx context.Context retryDelay time.Duration @@ -116,6 +92,8 @@ type testFilelock struct { err error } +var _ FileLocker = &testFilelock{} + func (l *testFilelock) Unlock() error { return nil } @@ -132,28 +110,12 @@ func (l *testFilelock) TryRLockContext(ctx context.Context, retryDelay time.Dura return l.success, l.err } -func (l *testFilelock) reset() { - l.ctx = context.TODO() - l.retryDelay = 0 - l.success = true - l.err = nil -} - -func getMocks() (tf *testFS, te *testEnv, testFlock *testFilelock) { - tf = &testFS{} - tf.reset() - f = tf - te = &testEnv{} - te.reset() - e = te - testFlock = &testFilelock{} - testFlock.reset() - newFlock = func(filename string) filelock { - return testFlock - } - return +// getMocks returns a mocked filesystem and FileLocker +func getMocks() (*testFS, *testFilelock) { + return &testFS{Fs: afero.NewMemMapFs()}, &testFilelock{context.TODO(), 0, true, nil} } +// makeCredential returns a dummy AWS crdential func makeCredential() credentials.Value { return credentials.Value{ AccessKeyID: "AKID", @@ -163,7 +125,9 @@ func makeCredential() credentials.Value { } } -func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c *credentials.Credentials) { +// validateFileCacheProvider ensures that the cache provider is properly initialized +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { + t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -181,21 +145,37 @@ func validateFileCacheProvider(t *testing.T, p FileCacheProvider, err error, c * } } +// testSetEnv sets an env var, and returns a cleanup func +func testSetEnv(t *testing.T, key, value string) func() { + t.Helper() + old := os.Getenv(key) + os.Setenv(key, value) + return func() { + if old == "" { + os.Unsetenv(key) + } else { + os.Setenv(key, old) + } + } +} + func TestCacheFilename(t *testing.T) { - _, te, _ := getMocks() - te.values["HOME"] = "homedir" // unix - te.values["USERPROFILE"] = "homedir" // windows + c1 := testSetEnv(t, "HOME", "homedir") + defer c1() + c2 := testSetEnv(t, "USERPROFILE", "homedir") + defer c2() - filename := CacheFilename() + filename := defaultCacheFilename() expected := "homedir/.kube/cache/aws-iam-authenticator/credentials.yaml" if filename != expected { t.Errorf("Incorrect default cacheFilename, expected %s, got %s", expected, filename) } - te.values["AWS_IAM_AUTHENTICATOR_CACHE_FILE"] = "special.yaml" - filename = CacheFilename() + c3 := testSetEnv(t, "AWS_IAM_AUTHENTICATOR_CACHE_FILE", "special.yaml") + defer c3() + filename = defaultCacheFilename() expected = "special.yaml" if filename != expected { t.Errorf("Incorrect custom cacheFilename, expected %s, got %s", @@ -206,85 +186,133 @@ func TestCacheFilename(t *testing.T) { func TestNewFileCacheProvider_Missing(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, tfl := getMocks() - // missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + })) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing cache file should result in expired cached credential") } - tf.err = nil } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() + tfs, _ := getMocks() + // afero.MemMapFs always returns tempfile FileInfo, + // so we manually set the response to the Stat() call + tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - tf.fileinfo.mode = 0777 - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + ) if err == nil { t.Errorf("Expected error due to public permissions") } - if tf.filename != CacheFilename() { - t.Errorf("unexpected file checked, expected %s, got %s", - CacheFilename(), tf.filename) + wantMsg := fmt.Sprintf("cache file %s is not private", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } } func TestNewFileCacheProvider_Unlockable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, testFlock := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // unable to lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to lock failure") } - testFlock.success = true - testFlock.err = nil } func TestNewFileCacheProvider_Unreadable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to read existing cache - tf.err = errors.New("read failure") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + tfl.err = fmt.Errorf("open %s: permission denied", testFilename) + tfl.success = false + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to read failure") + return + } + wantMsg := fmt.Sprintf("unable to read lock file %s: open %s: permission denied", testFilename, testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) } - tf.err = nil } func TestNewFileCacheProvider_Unparseable(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // unable to parse yaml - tf.data = []byte("invalid: yaml: file") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + afero.WriteFile( + tfs, + testFilename, + []byte("invalid: yaml: file"), + 0700) + return tfl + }), + ) if err == nil { t.Errorf("Expected error due to bad yaml") } + wantMsg := fmt.Sprintf("unable to parse file %s: yaml: mapping values are not allowed in this context", testFilename) + if err.Error() != wantMsg { + t.Errorf("Incorrect error, wanted '%s', got '%s'", wantMsg, err.Error()) + } } func TestNewFileCacheProvider_Empty(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - _, _, _ = getMocks() + tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("empty cache file should result in expired cached credential") @@ -294,13 +322,24 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse existing cluster without matching arn - tf.data = []byte(`clusters: + tfs, tfl := getMocks() + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: CLUSTER: -`) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + ARN2: {} +`), + 0700) + // successfully parse existing cluster without matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) if !p.cachedCredential.IsExpired() { t.Errorf("missing arn in cache file should result in expired cached credential") @@ -310,10 +349,7 @@ func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { func TestNewFileCacheProvider_ExistingARN(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) - tf, _, _ := getMocks() - - // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -324,11 +360,27 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { providername: JKL expiration: 2018-01-02T03:04:56.789Z `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + tfs.Create(testFilename) + + // successfully parse cluster with matching arn + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + }), + ) + if err != nil { + t.Errorf("Unexpected error: %v", err) + return + } validateFileCacheProvider(t, p, err, c) if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { - t.Errorf("cached credential not extracted correctly") + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } // fiddle with clock p.cachedCredential.currentTime = func() time.Time { @@ -353,11 +405,17 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { creds: providerCredential, }) - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) credential, err := p.Retrieve() @@ -370,6 +428,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { } } +// makeExpirerCredentials returns an expiring credential func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { providerCredential = makeCredential() expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) @@ -385,17 +444,23 @@ func makeExpirerCredentials() (providerCredential credentials.Value, expiration func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, testFlock := getMocks() + tfs, tfl := getMocks() + // don't create the empty cache file, create it in the filelock creator - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + })) validateFileCacheProvider(t, p, err, c) // retrieve credential, which will fetch from underlying Provider // fail to get write lock - testFlock.success = false - testFlock.err = errors.New("lock stuck, needs wd-40") + tfl.success = false + tfl.err = errors.New("lock stuck, needs wd-40") + credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -409,16 +474,19 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providerCredential, expiration, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - // retrieve credential, which will fetch from underlying Provider - // fail to write cache - tf.err = errors.New("can't write cache") credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) @@ -427,14 +495,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { t.Errorf("Cache did not return provider credential, got %v, expected %v", credential, providerCredential) } - if tf.filename != CacheFilename() { - t.Errorf("Wrote to wrong file, expected %v, got %v", - CacheFilename(), tf.filename) - } - if tf.perm != 0600 { - t.Errorf("Wrote with wrong permissions, expected %o, got %o", - 0600, tf.perm) - } + expectedData := []byte(`clusters: CLUSTER: PROFILE: @@ -446,22 +507,31 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { providername: stubProvider expiration: ` + expiration.Format(time.RFC3339Nano) + ` `) - if bytes.Compare(tf.data, expectedData) != 0 { + got, err := afero.ReadFile(tfs, testFilename) + if err != nil { + t.Errorf("unexpected error reading generated file: %v", err) + } + if !bytes.Equal(got, expectedData) { t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, tf.data) + expectedData, got) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { providerCredential, _, c := makeExpirerCredentials() - tf, _, _ := getMocks() - - // initialize from missing cache file - tf.err = os.ErrNotExist - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + tfs, tfl := getMocks() + // don't create the file, let the FileLocker create it + + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + return tfl + }), + ) validateFileCacheProvider(t, p, err, c) - tf.err = nil // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -478,11 +548,13 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { c := credentials.NewCredentials(&stubProvider{}) + currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - tf, _, _ := getMocks() + tfs, tfl := getMocks() + tfs.Create(testFilename) // successfully parse cluster with matching arn - tf.data = []byte(`clusters: + content := []byte(`clusters: CLUSTER: PROFILE: ARN: @@ -491,15 +563,20 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { secretaccesskey: DEF sessiontoken: GHI providername: JKL - expiration: 2018-01-02T03:04:56.789Z + expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c) + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + WithFilename(testFilename), + WithFs(tfs), + WithFileLockerCreator(func(string) FileLocker { + tfs.Create(testFilename) + afero.WriteFile(tfs, testFilename, content, 0700) + return tfl + })) validateFileCacheProvider(t, p, err, c) // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } + p.cachedCredential.currentTime = func() time.Time { return currentTime } credential, err := p.Retrieve() if err != nil { diff --git a/pkg/token/token.go b/pkg/token/token.go index 16ab8d92b..d9d7fd2e8 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -44,6 +44,7 @@ import ( clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -247,8 +248,8 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { - sess.Config.Credentials = credentials.NewCredentials(&cacheProvider) + if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } diff --git a/tests/integration/go.mod b/tests/integration/go.mod index 451b89c6d..ee7a84140 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -72,6 +72,7 @@ require ( github.com/prometheus/common v0.55.0 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect + github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cobra v1.8.1 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/stoewer/go-strcase v1.2.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index f4a756a0e..c85dc3777 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -172,6 +172,8 @@ github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/soheilhy/cmux v0.1.5 h1:jjzc5WVemNEDTLwv9tlmemhC73tI08BNOIGwBOo10Js= github.com/soheilhy/cmux v0.1.5/go.mod h1:T7TcVDs9LWfQgPlPsdngu6I6QIoyIFZDDC6sNE1GqG0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= From f4ca66151e93c8993c4dc30a091a0fa8ac28219e Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 14:36:13 -0500 Subject: [PATCH 2/3] Update filecache to use AWS SDK Go V2 with wrappers This changes updates filecache's internal types to use the AWS SDK Go v2's types, while preserving the external interface used by /pkg/token. This will simplify the future project-wide change for AWS SDK Go v2. Signed-off-by: Micah Hausler --- go.mod | 2 + go.sum | 4 + pkg/filecache/converter.go | 55 +++++++ pkg/filecache/filecache.go | 82 ++++----- pkg/filecache/filecache_test.go | 284 ++++++++++++++++---------------- pkg/token/token.go | 6 +- tests/integration/go.mod | 2 + tests/integration/go.sum | 4 + 8 files changed, 246 insertions(+), 193 deletions(-) create mode 100644 pkg/filecache/converter.go diff --git a/go.mod b/go.mod index 3d866c974..58736482f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -25,6 +26,7 @@ require ( ) require ( + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 2a3bcb3a0..9d1fc5538 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go new file mode 100644 index 000000000..ec2f16bde --- /dev/null +++ b/pkg/filecache/converter.go @@ -0,0 +1,55 @@ +package filecache + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type v2 struct { + creds *credentials.Credentials +} + +var _ aws.CredentialsProvider = &v2{} + +func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { + val, err := p.creds.GetWithContext(ctx) + if err != nil { + return aws.Credentials{}, err + } + resp := aws.Credentials{ + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + Source: val.ProviderName, + CanExpire: false, + // Don't have account ID + } + + if expiration, err := p.creds.ExpiresAt(); err != nil { + resp.CanExpire = true + resp.Expires = expiration + } + return resp, nil +} + +// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider +func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { + return V1CredentialToV2Provider(credentials.NewCredentials(p)) +} + +// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider +func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { + return &v2{creds: c} +} + +// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value +func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { + return credentials.Value{ + AccessKeyID: cred.AccessKeyID, + SecretAccessKey: cred.SecretAccessKey, + SessionToken: cred.SessionToken, + ProviderName: cred.Source, + } +} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 41597edaa..64092b9f4 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker { // cacheFile is a map of clusterID/roleARNs to cached credentials type cacheFile struct { // a map of clusterIDs/profiles/roleARNs to cachedCredentials - ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"` + ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"` } // a utility type for dealing with compound cache keys @@ -44,19 +45,19 @@ type cacheKey struct { roleARN string } -func (c *cacheFile) Put(key cacheKey, credential cachedCredential) { +func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; !ok { // first use of this cluster id - c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{} + c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{} } if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok { // first use of this profile - c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{} + c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{} } c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential } -func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { +func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; ok { if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { // we at least have this cluster and profile combo in the map, if no matching roleARN, map will @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { return } -// cachedCredential is a single cached credential entry, along with expiration time -type cachedCredential struct { - Credential credentials.Value - Expiration time.Time - // If set will be used by IsExpired to determine the current time. - // Defaults to time.Now if CurrentTime is not set. Available for testing - // to be able to mock out the current time. - currentTime func() time.Time -} - -// IsExpired determines if the cached credential has expired -func (c *cachedCredential) IsExpired() bool { - curTime := c.currentTime - if curTime == nil { - curTime = time.Now - } - return c.Expiration.Before(curTime()) -} - // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ - map[string]map[string]map[string]cachedCredential{}, + map[string]map[string]map[string]aws.Credentials{}, } data, err := afero.ReadFile(fs, filename) if err != nil { @@ -149,9 +131,9 @@ type FileCacheProvider struct { fs afero.Fs filelockCreator func(string) FileLocker filename string - credentials *credentials.Credentials // the underlying implementation that has the *real* Provider - cacheKey cacheKey // cache key parameters used to create Provider - cachedCredential cachedCredential // the cached credential, if it exists + provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider + cacheKey cacheKey // cache key parameters used to create Provider + cachedCredential aws.Credentials // the cached credential, if it exists } var _ credentials.Provider = &FileCacheProvider{} @@ -160,8 +142,8 @@ var _ credentials.Provider = &FileCacheProvider{} // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { - if creds == nil { +func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) { + if provider == nil { return nil, errors.New("no underlying Credentials object provided") } @@ -169,9 +151,9 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials fs: afero.NewOsFs(), filelockCreator: NewFileLocker, filename: defaultCacheFilename(), - credentials: creds, + provider: provider, cacheKey: cacheKey{clusterID, profile, roleARN}, - cachedCredential: cachedCredential{}, + cachedCredential: aws.Credentials{}, } // override defaults @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials // otherwise fetching the credential from the underlying Provider and caching the results on disk // with an expiration time. func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - if !f.cachedCredential.IsExpired() { + return f.RetrieveWithContext(context.Background()) +} + +// Retrieve() implements the Provider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying Provider and caching the results on disk +// with an expiration time. +func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return f.cachedCredential.Credential, nil + return V2CredentialToV1Value(f.cachedCredential), nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider - credential, err := f.credentials.Get() + credential, err := f.provider.Retrieve(ctx) if err != nil { - return credential, err + return V2CredentialToV1Value(credential), err } - if expiration, err := f.credentials.ExpiresAt(); err == nil { - // underlying provider supports Expirer interface, so we can cache + + if credential.CanExpire { + // Credential supports expiration, so we can cache // do file locking on cache to prevent inconsistent writes lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return credential, nil - } - f.cachedCredential = cachedCredential{ - credential, - expiration, - nil, + return V2CredentialToV1Value(credential), nil } + f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache. cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return credential, err + return V2CredentialToV1Value(credential), err } } // IsExpired() implements the Provider interface, deferring to the cached credential first, // but fall back to the underlying Provider if it is expired. func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.IsExpired() && f.credentials.IsExpired() + return f.cachedCredential.CanExpire && f.cachedCredential.Expired() } // ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expiration + return f.cachedCredential.Expires } // defaultCacheFilename returns the name of the credential cache file, which can either be diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index 60b4a8771..f2db98556 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,7 +1,6 @@ package filecache import ( - "bytes" "context" "errors" "fmt" @@ -10,7 +9,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/google/go-cmp/cmp" "github.com/spf13/afero" ) @@ -20,35 +20,17 @@ const ( // stubProvider implements credentials.Provider with configurable response values type stubProvider struct { - creds credentials.Value - expired bool - err error + creds aws.Credentials + err error } -var _ credentials.Provider = &stubProvider{} +var _ aws.CredentialsProvider = &stubProvider{} -func (s *stubProvider) Retrieve() (credentials.Value, error) { - s.expired = false - s.creds.ProviderName = "stubProvider" +func (s *stubProvider) Retrieve(_ context.Context) (aws.Credentials, error) { + s.creds.Source = "stubProvider" return s.creds, s.err } -func (s *stubProvider) IsExpired() bool { - return s.expired -} - -// stubProviderExpirer implements credentials.Expirer with configurable expiration -type stubProviderExpirer struct { - stubProvider - expiration time.Time -} - -var _ credentials.Expirer = &stubProviderExpirer{} - -func (s *stubProviderExpirer) ExpiresAt() time.Time { - return s.expiration -} - // testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string @@ -116,22 +98,34 @@ func getMocks() (*testFS, *testFilelock) { } // makeCredential returns a dummy AWS crdential -func makeCredential() credentials.Value { - return credentials.Value{ +func makeCredential() aws.Credentials { + return aws.Credentials{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN", - ProviderName: "stubProvider", + Source: "stubProvider", + CanExpire: false, + } +} + +func makeExpiringCredential(e time.Time) aws.Credentials { + return aws.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "TOKEN", + Source: "stubProvider", + CanExpire: true, + Expires: e, } } // validateFileCacheProvider ensures that the cache provider is properly initialized -func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c aws.CredentialsProvider) { t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } - if p.credentials != c { + if p.provider != c { t.Errorf("Credentials not copied") } if p.cacheKey.clusterID != "CLUSTER" { @@ -184,24 +178,24 @@ func TestCacheFilename(t *testing.T) { } func TestNewFileCacheProvider_Missing(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { return tfl })) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing cache file should result in empty cached credential") } } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, _ := getMocks() // afero.MemMapFs always returns tempfile FileInfo, @@ -209,7 +203,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), ) @@ -223,7 +217,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { } func TestNewFileCacheProvider_Unlockable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) @@ -232,7 +226,7 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -245,14 +239,14 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { } func TestNewFileCacheProvider_Unreadable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) tfl.err = fmt.Errorf("open %s: permission denied", testFilename) tfl.success = false - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -270,12 +264,12 @@ func TestNewFileCacheProvider_Unreadable(t *testing.T) { } func TestNewFileCacheProvider_Unparseable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -297,12 +291,12 @@ func TestNewFileCacheProvider_Unparseable(t *testing.T) { } func TestNewFileCacheProvider_Empty(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -313,58 +307,60 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("empty cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("empty cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - afero.WriteFile( - tfs, - testFilename, - []byte(`clusters: - CLUSTER: - ARN2: {} -`), - 0700) + tfs.Create(testFilename) + // successfully parse existing cluster without matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { - tfs.Create(testFilename) + + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: + CLUSTER: + PROFILE2: {} +`), + 0700) return tfl }), ) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing arn in cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing profile in cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingARN(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} + expiry := time.Now().Add(time.Hour * 6) content := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + expires: ` + expiry.Format(time.RFC3339Nano) + ` `) tfs, tfl := getMocks() tfs.Create(testFilename) // successfully parse cluster with matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -377,38 +373,31 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || - p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.AccessKeyID != "ABC" || p.cachedCredential.SecretAccessKey != "DEF" || + p.cachedCredential.SessionToken != "GHI" || p.cachedCredential.Source != "JKL" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - if p.cachedCredential.IsExpired() { + + if p.cachedCredential.Expired() { t.Errorf("Cached credential should not be expired") } - if p.IsExpired() { - t.Errorf("Cache credential should not be expired") - } - expectedExpiration := time.Date(2018, 01, 02, 03, 04, 56, 789000000, time.UTC) - if p.ExpiresAt() != expectedExpiration { + + if p.ExpiresAt() != p.cachedCredential.Expires { t.Errorf("Credential expiration time is not correct, expected %v, got %v", - expectedExpiration, p.ExpiresAt()) + p.cachedCredential.Expires, p.ExpiresAt()) } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { - providerCredential := makeCredential() - c := credentials.NewCredentials(&stubProvider{ - creds: providerCredential, - }) + provider := &stubProvider{ + creds: makeCredential(), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -416,45 +405,37 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken { t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + credential, provider.creds) } } -// makeExpirerCredentials returns an expiring credential -func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { - providerCredential = makeCredential() - expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) - c = credentials.NewCredentials(&stubProviderExpirer{ - stubProvider{ - creds: providerCredential, - }, - expiration, - }) - return -} - func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { tfs.Create(testFilename) return tfl })) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // fail to get write lock @@ -465,19 +446,22 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || + credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { - providerCredential, expiration, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -485,45 +469,50 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } expectedData := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: AKID - secretaccesskey: SECRET - sessiontoken: TOKEN - providername: stubProvider - expiration: ` + expiration.Format(time.RFC3339Nano) + ` + accesskeyid: AKID + secretaccesskey: SECRET + sessiontoken: TOKEN + source: stubProvider + canexpire: true + expires: ` + expires.Format(time.RFC3339Nano) + ` + accountid: "" `) got, err := afero.ReadFile(tfs, testFilename) if err != nil { t.Errorf("unexpected error reading generated file: %v", err) } - if !bytes.Equal(got, expectedData) { - t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, got) + if diff := cmp.Diff(got, expectedData); diff != "" { + t.Errorf("Wrong data written to cache, %s", diff) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -531,7 +520,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -540,15 +529,17 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) + provider := &stubProvider{} + currentTime := time.Now() tfs, tfl := getMocks() tfs.Create(testFilename) @@ -559,13 +550,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { PROFILE: ARN: credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + canexpire: true + expires: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -573,10 +565,7 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { afero.WriteFile(tfs, testFilename, content, 0700) return tfl })) - validateFileCacheProvider(t, p, err, c) - - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { return currentTime } + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { @@ -586,4 +575,11 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { t.Errorf("cached credential not returned") } + + if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { + t.Errorf("unexpected expiration time: got %s, wanted %s", + p.ExpiresAt().Format(time.RFC3339Nano), + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), + ) + } } diff --git a/pkg/token/token.go b/pkg/token/token.go index d9d7fd2e8..716a8cb12 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -248,7 +248,11 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + if cacheProvider, err := filecache.NewFileCacheProvider( + options.ClusterID, + profile, + options.AssumeRoleARN, + filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) diff --git a/tests/integration/go.mod b/tests/integration/go.mod index ee7a84140..6781a0caf 100644 --- a/tests/integration/go.mod +++ b/tests/integration/go.mod @@ -18,6 +18,8 @@ require ( github.com/NYTimes/gziphandler v1.1.1 // indirect github.com/antlr4-go/antlr/v4 v4.13.0 // indirect github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a // indirect + github.com/aws/aws-sdk-go-v2 v1.30.4 // indirect + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect diff --git a/tests/integration/go.sum b/tests/integration/go.sum index c85dc3777..4794685e6 100644 --- a/tests/integration/go.sum +++ b/tests/integration/go.sum @@ -12,6 +12,10 @@ github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a h1:idn718Q4 github.com/asaskevich/govalidator v0.0.0-20190424111038-f61b66f89f4a/go.mod h1:lB+ZfQJz7igIIfQNfa7Ml4HSf2uFQQRzpGGRXenZAgY= github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= From 51bf911b467700e2910625aeb9a59836f85faa9a Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Mon, 26 Aug 2024 12:53:24 -0500 Subject: [PATCH 3/3] WIP: Update to aws-sdk-go-v2 --- cmd/aws-iam-authenticator/verify.go | 16 ++- go.mod | 15 ++- go.sum | 28 ++++- pkg/arn/arn.go | 2 +- pkg/ec2provider/ec2provider.go | 143 +++++++++--------------- pkg/ec2provider/ec2provider_test.go | 79 ++++--------- pkg/ec2provider/source_headers.go | 66 +++++++++++ pkg/filecache/converter.go | 55 --------- pkg/filecache/filecache.go | 35 ++---- pkg/filecache/filecache_test.go | 29 ++--- pkg/server/server.go | 44 +++++--- pkg/token/cluster_id_header.go | 45 ++++++++ pkg/token/token.go | 167 ++++++++++++++-------------- pkg/token/token_test.go | 42 ++++--- 14 files changed, 390 insertions(+), 376 deletions(-) create mode 100644 pkg/ec2provider/source_headers.go delete mode 100644 pkg/filecache/converter.go create mode 100644 pkg/token/cluster_id_header.go diff --git a/cmd/aws-iam-authenticator/verify.go b/cmd/aws-iam-authenticator/verify.go index 9bb37f4f4..839b19ef8 100644 --- a/cmd/aws-iam-authenticator/verify.go +++ b/cmd/aws-iam-authenticator/verify.go @@ -25,9 +25,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/token" - "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -54,14 +54,18 @@ var verifyCmd = &cobra.Command{ os.Exit(1) } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + cfg, err := config.LoadDefaultConfig(cmd.Context()) + if err != nil { + fmt.Printf("Error constructing aws config: %v", err) + os.Exit(1) + } + client := imds.NewFromConfig(cfg) + resp, err := client.GetRegion(cmd.Context(), nil) if err != nil { fmt.Printf("[Warn] Region not found in instance metadata, err: %v", err) } - id, err := token.NewVerifier(clusterID, partition, instanceRegion).Verify(tok) + id, err := token.NewVerifier(clusterID, partition, resp.Region).Verify(tok) if err != nil { fmt.Fprintf(os.Stderr, "could not verify token: %v\n", err) os.Exit(1) diff --git a/go.mod b/go.mod index 58736482f..8fbd7f9a3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,12 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 github.com/aws/aws-sdk-go-v2 v1.30.4 + github.com/aws/aws-sdk-go-v2/config v1.27.30 + github.com/aws/aws-sdk-go-v2/credentials v1.17.29 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 + github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 + github.com/aws/smithy-go v1.20.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -26,13 +32,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.1 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect diff --git a/go.sum b/go.sum index 9d1fc5538..05290a6e2 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,30 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/aws-sdk-go-v2/config v1.27.30 h1:AQF3/+rOgeJBQP3iI4vojlPib5X6eeOYoa/af7OxAYg= +github.com/aws/aws-sdk-go-v2/config v1.27.30/go.mod h1:yxqvuubha9Vw8stEgNiStO+yZpP68Wm9hLmcm+R/Qk4= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29 h1:CwGsupsXIlAFYuDVHv1nnK0wnxO0wZ/g1L8DSK/xiIw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29/go.mod h1:BPJ/yXV92ZVq6G8uYvbU0gSl8q94UB63nMT5ctNO38g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJCkMC0lMy6FaCD51jm6ayE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 h1:fWhkSvaQqa5eWiRwBw10FUnk1YatAQ9We4GdGxKiCtg= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0/go.mod h1:ISODge3zgdwOEa4Ou6WM9PKbxJWJ15DYKnr2bfmCAIA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5/go.mod h1:ZeDX1SnKsVlejeuz41GiajjZpRSWR7/42q/EyA/QEiM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 h1:SKvPgvdvmiTWoi0GAJ7AsJfOz3ngVkD/ERbs5pUnHNI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5/go.mod h1:20sz31hv/WsPa3HhU3hfrIet2kxM4Pe0r20eBZ20Tac= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wINh+4UK+k/0Yo/q8= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -21,8 +45,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.1 h1:S+9bSbua1z3FgCnV0KKOSSZ3mDthb5NyEPL5gEpCvyk= -github.com/emicklei/go-restful/v3 v3.11.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= diff --git a/pkg/arn/arn.go b/pkg/arn/arn.go index e9b73b587..22900c96d 100644 --- a/pkg/arn/arn.go +++ b/pkg/arn/arn.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go/aws/endpoints" ) diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index d760f0bda..5d18ec720 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -1,25 +1,24 @@ package ec2provider import ( + "context" "errors" "fmt" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/sirupsen/logrus" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/httputil" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" + "github.com/sirupsen/logrus" ) const ( @@ -36,11 +35,6 @@ const ( // Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has // already become 100 then it will not respect this limit maxWaitIntervalForBatch = 200 - - // Headers for STS request for source ARN - headerSourceArn = "x-amz-source-arn" - // Headers for STS request for source account - headerSourceAccount = "x-amz-source-account" ) // Get a node name from instance ID @@ -60,13 +54,13 @@ type ec2Requests struct { } type ec2ProviderImpl struct { - ec2 ec2iface.EC2API + ec2 ec2.DescribeInstancesAPIClient privateDNSCache ec2PrivateDNSCache ec2Requests ec2Requests instanceIdsChannel chan string } -func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { +func New(roleARN, sourceARN, region string, qps int, burst int) (EC2Provider, error) { dnsCache := ec2PrivateDNSCache{ cache: make(map[string]string), lock: sync.RWMutex{}, @@ -75,50 +69,56 @@ func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { set: make(map[string]bool), lock: sync.RWMutex{}, } + cfg, err := newConfig(roleARN, sourceARN, region, qps, burst) + if err != nil { + return nil, err + } + return &ec2ProviderImpl{ - ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)), + ec2: ec2.NewFromConfig(cfg), privateDNSCache: dnsCache, ec2Requests: ec2Requests, instanceIdsChannel: make(chan string, maxChannelSize), - } + }, nil } -// Initial credentials loaded from SDK's default credential chain, such as -// the environment, shared credentials (~/.aws/credentials), or EC2 Instance -// Role. - -func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session { - sess := session.Must(session.NewSession()) - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) - if aws.StringValue(sess.Config.Region) == "" { - sess.Config.Region = aws.String(region) +func newConfig(roleARN, sourceArn, region string, qps, burst int) (aws.Config, error) { + rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) + if err != nil { + logrus.Errorf("error creating rate limited client %s", err) + return aws.Config{}, err + } + loadOpts := []func(*config.LoadOptions) error{ + config.WithRegion(region), + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), + config.WithHTTPClient(rateLimitedClient), } - if roleARN != "" { logrus.WithFields(logrus.Fields{ "roleARN": roleARN, }).Infof("Using assumed role for EC2 API") - rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) - + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) if err != nil { - logrus.Errorf("Getting error = %s while creating rate limited client ", err) + logrus.Errorf("error loading AWS config %s", err) + return aws.Config{}, err } - - stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN) - ap := &stscreds.AssumeRoleProvider{ - Client: stsClient, - RoleARN: roleARN, - Duration: time.Duration(60) * time.Minute, + stsOpts := []func(*sts.Options){} + if sourceArn != "" { + stsOpts = append(stsOpts, WithSourceHeaders(sourceArn)) } - sess.Config.Credentials = credentials.NewCredentials(ap) + stsCli := sts.NewFromConfig(cfg, stsOpts...) + creds := stscreds.NewAssumeRoleProvider(stsCli, roleARN, + func(o *stscreds.AssumeRoleOptions) { + o.Duration = time.Duration(60) * time.Minute + }) + loadOpts = append(loadOpts, config.WithCredentialsProvider(creds)) } - return sess + return config.LoadDefaultConfig(context.Background(), loadOpts...) } func (p *ec2ProviderImpl) setPrivateDNSNameCache(id string, privateDNSName string) { @@ -197,8 +197,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { logrus.Infof("Calling ec2:DescribeInstances for the InstanceId = %s ", id) metrics.Get().EC2DescribeInstanceCallCount.Inc() // Look up instance from EC2 API - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice([]string{id}), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: []string{id}, }) if err != nil { p.unsetRequestInFlightForInstanceId(id) @@ -206,8 +206,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { } for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - if aws.StringValue(instance.InstanceId) == id { - privateDNSName = aws.StringValue(instance.PrivateDnsName) + if aws.ToString(instance.InstanceId) == id { + privateDNSName = aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) p.unsetRequestInFlightForInstanceId(id) } @@ -258,8 +258,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Look up instance from EC2 API logrus.Infof("Making Batch Query to DescribeInstances for %v instances ", len(instanceIdList)) metrics.Get().EC2DescribeInstanceCallCount.Inc() - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice(instanceIdList), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: instanceIdList, }) if err != nil { logrus.Errorf("Batch call failed querying private DNS from EC2 API for nodes [%s] : with error = []%s ", instanceIdList, err.Error()) @@ -272,8 +272,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Adding the result to privateDNSChache as well as removing from the requestQueueMap. for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - id := aws.StringValue(instance.InstanceId) - privateDNSName := aws.StringValue(instance.PrivateDnsName) + id := aws.ToString(instance.InstanceId) + privateDNSName := aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) } } @@ -284,40 +284,3 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string p.unsetRequestInFlightForInstanceId(id) } } - -func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS { - // parse both source account and source arn from the sourceARN, and add them as headers to the STS client - if sourceARN != "" { - sourceAcct, err := getSourceAccount(sourceARN) - if err != nil { - panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err)) - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - headerSourceArn: sourceARN, - } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) - } - return stsClient -} - -// getSourceAccount constructs source acct and return them for use -func getSourceAccount(roleARN string) (string, error) { - // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) - // arn:partition:service:region:account-id:resource-type/resource-id - // IAM format, region is always blank - // arn:aws:iam::account:role/role-name-with-path - if !arn.IsARN(roleARN) { - return "", fmt.Errorf("incorrect ARN format for role %s", roleARN) - } - - parsedArn, err := arn.Parse(roleARN) - if err != nil { - return "", err - } - - return parsedArn.AccountID, nil -} diff --git a/pkg/ec2provider/ec2provider_test.go b/pkg/ec2provider/ec2provider_test.go index 912d73c8c..21ebb94aa 100644 --- a/pkg/ec2provider/ec2provider_test.go +++ b/pkg/ec2provider/ec2provider_test.go @@ -1,14 +1,15 @@ package ec2provider import ( + "context" "strconv" "sync" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/prometheus/client_golang/prometheus" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -18,25 +19,24 @@ const ( ) type mockEc2Client struct { - ec2iface.EC2API - Reservations []*ec2.Reservation + Reservations []ec2types.Reservation } -func (c *mockEc2Client) DescribeInstances(in *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { +func (c *mockEc2Client) DescribeInstances(ctx context.Context, in *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { // simulate the time it takes for aws to return time.Sleep(DescribeDelay * time.Millisecond) - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for _, res := range c.Reservations { - var reservation ec2.Reservation + var reservation ec2types.Reservation for _, inst := range res.Instances { for _, id := range in.InstanceIds { - if aws.StringValue(id) == aws.StringValue(inst.InstanceId) { + if id == aws.ToString(inst.InstanceId) { reservation.Instances = append(reservation.Instances, inst) } } } if len(reservation.Instances) > 0 { - reservations = append(reservations, &reservation) + reservations = append(reservations, reservation) } } return &ec2.DescribeInstancesOutput{ @@ -76,12 +76,12 @@ func TestGetPrivateDNSName(t *testing.T) { } } -func prepareSingleInstanceOutput() []*ec2.Reservation { - reservations := []*ec2.Reservation{ +func prepareSingleInstanceOutput() []ec2types.Reservation { + reservations := []ec2types.Reservation{ { Groups: nil, - Instances: []*ec2.Instance{ - &ec2.Instance{ + Instances: []ec2types.Instance{ + ec2types.Instance{ InstanceId: aws.String("ec2-1"), PrivateDnsName: aws.String("ec2-dns-1"), }, @@ -125,20 +125,20 @@ func getPrivateDNSName(ec2provider *ec2ProviderImpl, instanceString string, dnsS } } -func prepare100InstanceOutput() []*ec2.Reservation { +func prepare100InstanceOutput() []ec2types.Reservation { - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for i := 1; i < 101; i++ { instanceString := "ec2-" + strconv.Itoa(i) dnsString := "ec2-dns-" + strconv.Itoa(i) - instance := &ec2.Instance{ + instance := ec2types.Instance{ InstanceId: aws.String(instanceString), PrivateDnsName: aws.String(dnsString), } - var instances []*ec2.Instance + var instances []ec2types.Instance instances = append(instances, instance) - res1 := &ec2.Reservation{ + res1 := ec2types.Reservation{ Groups: nil, Instances: instances, OwnerId: nil, @@ -150,44 +150,3 @@ func prepare100InstanceOutput() []*ec2.Reservation { return reservations } - -func TestGetSourceAcctAndArn(t *testing.T) { - type args struct { - roleARN string - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "corect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876:role/test-cluster", - }, - want: "123456789876", - wantErr: false, - }, - { - name: "incorect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876", - }, - want: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getSourceAccount(tt.args.roleARN) - if (err != nil) != tt.wantErr { - t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/ec2provider/source_headers.go b/pkg/ec2provider/source_headers.go new file mode 100644 index 000000000..f2473a689 --- /dev/null +++ b/pkg/ec2provider/source_headers.go @@ -0,0 +1,66 @@ +package ec2provider + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +const ( + // Headers for STS request for source ARN + headerSourceArn = "x-amz-source-arn" + // Headers for STS request for source account + headerSourceAccount = "x-amz-source-account" +) + +type withSourceHeaders struct { + sourceARN string +} + +// implements middleware.BuildMiddleware, which runs AFTER a request has been +// serialized and can operate on the transport request +var _ smithymiddleware.BuildMiddleware = (*withSourceHeaders)(nil) + +func (*withSourceHeaders) ID() string { + return "withSourceHeaders" +} + +func (m *withSourceHeaders) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) ( + out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + + if arn.IsARN(m.sourceARN) { + req.Header.Set(headerSourceArn, m.sourceARN) + } + + if parsedArn, err := arn.Parse(m.sourceARN); err == nil && parsedArn.AccountID != "" { + req.Header.Set(headerSourceAccount, parsedArn.AccountID) + } + + return next.HandleBuild(ctx, in) +} + +// WithSourceHeaders adds the x-amz-source-arn and x-amz-source-account headers to the request. +// These can be referenced in an IAM role trust policy document with the condition keys +// aws:SourceArn and aws:SourceAccount for sts:AssumeRole calls +// +// If the sourceARN is invalid, the source arn header is skipped. If the ARN is valid but doesn't +// contain an account ID, the source account header is skipped +func WithSourceHeaders(sourceARN string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error { + return s.Build.Add(&withSourceHeaders{ + sourceARN: sourceARN, + }, smithymiddleware.After) + }) + } +} diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go deleted file mode 100644 index ec2f16bde..000000000 --- a/pkg/filecache/converter.go +++ /dev/null @@ -1,55 +0,0 @@ -package filecache - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws/credentials" -) - -type v2 struct { - creds *credentials.Credentials -} - -var _ aws.CredentialsProvider = &v2{} - -func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { - val, err := p.creds.GetWithContext(ctx) - if err != nil { - return aws.Credentials{}, err - } - resp := aws.Credentials{ - AccessKeyID: val.AccessKeyID, - SecretAccessKey: val.SecretAccessKey, - SessionToken: val.SessionToken, - Source: val.ProviderName, - CanExpire: false, - // Don't have account ID - } - - if expiration, err := p.creds.ExpiresAt(); err != nil { - resp.CanExpire = true - resp.Expires = expiration - } - return resp, nil -} - -// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider -func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { - return V1CredentialToV2Provider(credentials.NewCredentials(p)) -} - -// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider -func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { - return &v2{creds: c} -} - -// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value -func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { - return credentials.Value{ - AccessKeyID: cred.AccessKeyID, - SecretAccessKey: cred.SecretAccessKey, - SessionToken: cred.SessionToken, - ProviderName: cred.Source, - } -} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 64092b9f4..3d8624a2c 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -11,7 +11,6 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" "gopkg.in/yaml.v2" @@ -136,7 +135,7 @@ type FileCacheProvider struct { cachedCredential aws.Credentials // the cached credential, if it exists } -var _ credentials.Provider = &FileCacheProvider{} +var _ aws.CredentialsProvider = &FileCacheProvider{} // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. @@ -200,26 +199,19 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.Crede return resp, nil } -// Retrieve() implements the Provider interface, returning the cached credential if is not expired, -// otherwise fetching the credential from the underlying Provider and caching the results on disk +// Retrieve() implements the aws.CredentialsProvider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying CredentialProvider and caching the results on disk // with an expiration time. -func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - return f.RetrieveWithContext(context.Background()) -} - -// Retrieve() implements the Provider interface, returning the cached credential if is not expired, -// otherwise fetching the credential from the underlying Provider and caching the results on disk -// with an expiration time. -func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +func (f *FileCacheProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return V2CredentialToV1Value(f.cachedCredential), nil + return f.cachedCredential, nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider credential, err := f.provider.Retrieve(ctx) if err != nil { - return V2CredentialToV1Value(credential), err + return credential, err } if credential.CanExpire { @@ -235,7 +227,7 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return V2CredentialToV1Value(credential), nil + return credential, nil } f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache. @@ -254,21 +246,10 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return V2CredentialToV1Value(credential), err + return credential, err } } -// IsExpired() implements the Provider interface, deferring to the cached credential first, -// but fall back to the underlying Provider if it is expired. -func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.CanExpire && f.cachedCredential.Expired() -} - -// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential -func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expires -} - // defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml func defaultCacheFilename() string { diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index f2db98556..7c8fbaafd 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -383,10 +383,6 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Cached credential should not be expired") } - if p.ExpiresAt() != p.cachedCredential.Expires { - t.Errorf("Credential expiration time is not correct, expected %v, got %v", - p.cachedCredential.Expires, p.ExpiresAt()) - } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { @@ -407,7 +403,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { ) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -442,12 +438,12 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || - credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + credential.SessionToken != "TOKEN" || credential.Source != "stubProvider" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } @@ -471,14 +467,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { ) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != provider.creds.AccessKeyID || credential.SecretAccessKey != provider.creds.SecretAccessKey || credential.SessionToken != provider.creds.SessionToken || - credential.ProviderName != provider.creds.Source { + credential.Source != provider.creds.Source { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } @@ -525,14 +521,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, // but write to disk (code coverage) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != provider.creds.AccessKeyID || credential.SecretAccessKey != provider.creds.SecretAccessKey || credential.SessionToken != provider.creds.SessionToken || - credential.ProviderName != provider.creds.Source { + credential.Source != provider.creds.Source { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } @@ -567,19 +563,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { })) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != "ABC" || credential.SecretAccessKey != "DEF" || - credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { + credential.SessionToken != "GHI" || credential.Source != "JKL" || + !credential.Expires.Equal(currentTime.Add(time.Hour*6)) { t.Errorf("cached credential not returned") } - if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { - t.Errorf("unexpected expiration time: got %s, wanted %s", - p.ExpiresAt().Format(time.RFC3339Nano), - currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), - ) - } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 045f948c2..ce47f0f7d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,8 +28,6 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/errutil" @@ -42,7 +40,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" "sigs.k8s.io/aws-iam-authenticator/pkg/token" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" authenticationv1beta1 "k8s.io/api/authentication/v1beta1" @@ -88,7 +88,7 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { backendMapper, err := BuildMapperChain(cfg, cfg.BackendMode) if err != nil { - logrus.Fatalf("failed to build mapper chain: %v", err) + logrus.WithError(err).Fatal("failed to build mapper chain") } for _, mapping := range c.RoleMappings { @@ -144,7 +144,11 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { logrus.Infof("listening on %s", listener.Addr()) logrus.Infof("reconfigure your apiserver with `--authentication-token-webhook-config-file=%s` to enable (assuming default hostPath mounts)", c.GenerateKubeconfigPath) - internalHandler := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + internalHandler, err := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + if err != nil { + logrus.WithError(err).Fatal("Failed to create handlers") + } + c.httpServer = http.Server{ ErrorLog: log.New(errLog, "", 0), Handler: internalHandler, @@ -191,23 +195,35 @@ type healthzHandler struct{} func (m *healthzHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "ok") } -func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) *handler { + +func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) (*handler, error) { if c.ServerEC2DescribeInstancesRoleARN != "" { _, err := awsarn.Parse(c.ServerEC2DescribeInstancesRoleARN) if err != nil { - panic(fmt.Sprintf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN)) + logrus.WithError(err).Errorf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN) + return nil, err } } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + ctx := context.Background() + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + logrus.WithError(err).Error("EC2 instance metadata not configured") + } + cli := imds.NewFromConfig(cfg) + resp, err := cli.GetRegion(ctx, nil) + if err != nil { + logrus.WithError(err).Error("region not found in instance metadata.") + } + + ec2Prov, err := ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, resp.Region, ec2DescribeQps, ec2DescribeBurst) if err != nil { - logrus.WithError(err).Errorln("Region not found in instance metadata.") + logrus.WithError(err).Errorln("error initializing EC2 provider") + return nil, err } h := &handler{ - verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion), - ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), + verifier: token.NewVerifier(c.ClusterID, c.PartitionID, resp.Region), + ec2Provider: ec2Prov, clusterID: c.ClusterID, backendMapper: backendMapper, scrubbedAccounts: c.Config.ScrubbedAWSAccounts, @@ -226,7 +242,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2 fileutil.StartLoadDynamicFile(c.DynamicBackendModePath, h, stopCh) } - return h + return h, nil } func BuildMapperChain(cfg config.Config, modes []string) (BackendMapper, error) { diff --git a/pkg/token/cluster_id_header.go b/pkg/token/cluster_id_header.go new file mode 100644 index 000000000..1f6a43d32 --- /dev/null +++ b/pkg/token/cluster_id_header.go @@ -0,0 +1,45 @@ +package token + +import ( + "context" + "fmt" + + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +type withClusterIDHeader struct { + clusterID string +} + +// implements middleware.BuildMiddleware, which runs AFTER a request has been +// serialized and can operate on the transport request +var _ smithymiddleware.BuildMiddleware = (*withClusterIDHeader)(nil) + +func (*withClusterIDHeader) ID() string { + return "withClusterIDHeader" +} + +func (m *withClusterIDHeader) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) ( + out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + req.Header.Set(clusterIDHeader, m.clusterID) + return next.HandleBuild(ctx, in) +} + +// WithClusterIDHeader adds the clusterID header to the request befor signing +func WithClusterIDHeader(clusterID string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error { + return s.Build.Add(&withClusterIDHeader{ + clusterID: clusterID, + }, smithymiddleware.After) + }) + } +} diff --git a/pkg/token/token.go b/pkg/token/token.go index 716a8cb12..9ad001f1c 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -17,6 +17,7 @@ limitations under the License. package token import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -28,24 +29,24 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "sigs.k8s.io/aws-iam-authenticator/pkg" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" + "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + smithymiddleware "github.com/aws/smithy-go/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" - "sigs.k8s.io/aws-iam-authenticator/pkg" - "sigs.k8s.io/aws-iam-authenticator/pkg/arn" - "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" - "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) // Identity is returned on successful Verify() results. It contains a parsed @@ -180,12 +181,16 @@ type getCallerIdentityWrapper struct { } `json:"GetCallerIdentityResponse"` } +type GCIPresigner interface { + PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) +} + // Generator provides new tokens for the AWS IAM Authenticator. type Generator interface { // Get a token using the provided options - GetWithOptions(options *GetTokenOptions) (Token, error) - // GetWithSTS returns a token valid for clusterID using the given STS client. - GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) + GetWithOptions(*GetTokenOptions) (Token, error) + // Presign returns a Token using the given STS client + Presign(GCIPresigner) (Token, error) // FormatJSON returns the client auth formatted json for the ExecCredential auth FormatJSON(Token) string } @@ -220,23 +225,14 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { if options.ClusterID == "" { return Token{}, fmt.Errorf("ClusterID is required") } - - // create a session with the "base" credentials available - // (from environment variable, profile files, EC2 metadata, etc) - sess, err := session.NewSessionWithOptions(session.Options{ - AssumeRoleTokenProvider: StdinStderrTokenProvider, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return Token{}, fmt.Errorf("could not create session: %v", err) + loadOpts := []func(*config.LoadOptions) error{ + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), } - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) if options.Region != "" { - sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)) + loadOpts = append(loadOpts, config.WithRegion(options.Region)) } if g.cache { @@ -244,90 +240,91 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { var profile string if v := os.Getenv("AWS_PROFILE"); len(v) > 0 { profile = v - } else { - profile = session.DefaultSharedConfigProfile } - // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider( + // Create a new config to get the default cred chain + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) + } + // create a caching Provider wrapper around the Credentials + cacheProvider, err := filecache.NewFileCacheProvider( options.ClusterID, profile, options.AssumeRoleARN, - filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { - sess.Config.Credentials = credentials.NewCredentials(cacheProvider) + cfg.Credentials, + ) + if err == nil { + loadOpts = append(loadOpts, config.WithCredentialsProvider(cacheProvider)) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } } - // use an STS client based on the direct credentials - stsAPI := sts.New(sess) - - // if a roleARN was specified, replace the STS client with one that uses - // temporary credentials from that role. - if options.AssumeRoleARN != "" { - var sessionSetters []func(*stscreds.AssumeRoleProvider) + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) + if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) - if options.AssumeRoleExternalID != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.ExternalID = &options.AssumeRoleExternalID - }) - } + } + if options.AssumeRoleARN != "" { + var sessionName = options.SessionName if g.forwardSessionName { - // If the current session is already a federated identity, carry through - // this session name onto the new session to provide better debugging - // capabilities - resp, err := stsAPI.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + stsSvc := sts.NewFromConfig(cfg) + gciResp, err := stsSvc.GetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } - - userIDParts := strings.Split(*resp.UserId, ":") + userIDParts := strings.Split(aws.ToString(gciResp.UserId), ":") if len(userIDParts) == 2 { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = userIDParts[1] - }) + sessionName = userIDParts[1] } - } else if options.SessionName != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = options.SessionName - }) } - // create STS-based credentials that will assume the given role - creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...) + creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), options.AssumeRoleARN, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = sessionName + o.ExternalID = &options.AssumeRoleExternalID + // TODO: Can we get the serial number from the client? + // o.SerialNumber = aws.String("myTokenSerialNumber") + o.TokenProvider = stscreds.StdinTokenProvider + }) - // create an STS API interface that uses the assumed role's temporary credentials - stsAPI = sts.New(sess, &aws.Config{Credentials: creds}) + cfg.Credentials = aws.NewCredentialsCache(creds) } - return g.GetWithSTS(options.ClusterID, stsAPI) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(options.ClusterID)) + presigner := timedPresigner{v4.NewSigner(), g.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + return g.Presign(presignClient) } -func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { - return request.NamedHandler{ - Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { - v4.SignSDKRequestWithCurrentTime(req, nowFunc) - }, - } +// timedPresigner exists to wrap PresignHTTP() with a specific time +// and set the x-amz-expires header in the query url +type timedPresigner struct { + signer *v4.Signer + timeFunc func() time.Time } -// GetWithSTS returns a token valid for clusterID using the given STS client. -func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { - // generate an sts:GetCallerIdentity request and add our custom cluster ID header - request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) - - // override the Sign handler so we can control the now time for testing. - request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) +func (p *timedPresigner) PresignHTTP( + ctx context.Context, credentials aws.Credentials, r *http.Request, + payloadHash string, service string, region string, _ time.Time, + optFns ...func(*v4.SignerOptions), +) (url string, signedHeader http.Header, err error) { + query := r.URL.Query() + query.Set("X-Amz-Expires", strconv.Itoa(requestPresignParam)) + r.URL.RawQuery = query.Encode() + return p.signer.PresignHTTP(ctx, credentials, r, payloadHash, service, region, p.timeFunc(), optFns...) +} +// Presign returns a token valid for clusterID using the given STS client. +func (g generator) Presign(presigner GCIPresigner) (Token, error) { // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date - // timestamp regardless. We set it to 60 seconds for backwards compatibility (the - // parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between - // 0 and 60 on the server side). + // timestamp regardless. // https://github.com/aws/aws-sdk-go/issues/2167 - presignedURLString, err := request.Presign(requestPresignParam * time.Second) + + req, err := presigner.PresignGetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } @@ -335,7 +332,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // Set token expiration to 1 minute before the presigned URL expires for some cushion tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding - return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil + return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(req.URL)), tokenExpiration}, nil } // FormatJSON formats the json to support ExecCredential authentication diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index a8e997c86..ce43cca45 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -2,6 +2,7 @@ package token import ( "bytes" + "context" "encoding/base64" "encoding/json" "errors" @@ -14,11 +15,12 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -587,12 +589,12 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { } } -func TestGetWithSTS(t *testing.T) { +func TestPresign(t *testing.T) { clusterID := "test-cluster" cases := []struct { name string - creds *credentials.Credentials + creds aws.CredentialsProvider nowTime time.Time want Token wantErr error @@ -600,10 +602,10 @@ func TestGetWithSTS(t *testing.T) { { "Non-zero time", // Example non-real credentials - func() *credentials.Credentials { + func() credentials.StaticCredentialsProvider { decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") - return credentials.NewStaticCredentials( + return credentials.NewStaticCredentialsProvider( string(decodedAkid), string(decodedSk), "", @@ -620,13 +622,15 @@ func TestGetWithSTS(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - svc := sts.New(session.Must(session.NewSession( - &aws.Config{ - Credentials: tc.creds, - Region: aws.String("us-west-2"), - STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, - }, - ))) + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion("us-west-2"), + config.WithCredentialsProvider(tc.creds), + ) + if err != nil { + t.Errorf("unexpected error initialzing config: %v", err) + return + } gen := &generator{ forwardSessionName: false, @@ -634,7 +638,13 @@ func TestGetWithSTS(t *testing.T) { nowFunc: func() time.Time { return tc.nowTime }, } - got, err := gen.GetWithSTS(clusterID, svc) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(clusterID)) + presigner := timedPresigner{v4.NewSigner(), gen.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + + got, err := gen.Presign(presignClient) if diff := cmp.Diff(err, tc.wantErr); diff != "" { t.Errorf("Unexpected error: %s", diff) }