Skip to content

Commit

Permalink
Merge pull request #285 from christopher-dG/cdg/awsv2
Browse files Browse the repository at this point in the history
Migrate to AWS SDK v2
  • Loading branch information
kzys authored Oct 27, 2021
2 parents b94b92f + a97ae4b commit d0e3d70
Show file tree
Hide file tree
Showing 833 changed files with 88,367 additions and 55,348 deletions.
29 changes: 15 additions & 14 deletions ecr-login/api/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,17 @@
package api

import (
"context"
"encoding/base64"
"fmt"
"net/url"
"regexp"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"

Expand Down Expand Up @@ -103,11 +104,11 @@ type defaultClient struct {
}

type ECRAPI interface {
GetAuthorizationToken(*ecr.GetAuthorizationTokenInput) (*ecr.GetAuthorizationTokenOutput, error)
GetAuthorizationToken(context.Context, *ecr.GetAuthorizationTokenInput, ...func(*ecr.Options)) (*ecr.GetAuthorizationTokenOutput, error)
}

type ECRPublicAPI interface {
GetAuthorizationToken(*ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error)
GetAuthorizationToken(context.Context, *ecrpublic.GetAuthorizationTokenInput, ...func(*ecrpublic.Options)) (*ecrpublic.GetAuthorizationTokenOutput, error)
}

// GetCredentials returns username, password, and proxyEndpoint
Expand Down Expand Up @@ -213,11 +214,11 @@ func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error)
} else {
logrus.WithField("registry", registryID).Debug("Calling ECR.GetAuthorizationToken")
input = &ecr.GetAuthorizationTokenInput{
RegistryIds: []*string{aws.String(registryID)},
RegistryIds: []string{registryID},
}
}

output, err := c.ecrClient.GetAuthorizationToken(input)
output, err := c.ecrClient.GetAuthorizationToken(context.TODO(), input)
if err != nil || output == nil {
if err == nil {
if registryID == "" {
Expand All @@ -232,10 +233,10 @@ func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error)
for _, authData := range output.AuthorizationData {
if authData.ProxyEndpoint != nil && authData.AuthorizationToken != nil {
authEntry := cache.AuthEntry{
AuthorizationToken: aws.StringValue(authData.AuthorizationToken),
AuthorizationToken: aws.ToString(authData.AuthorizationToken),
RequestedAt: time.Now(),
ExpiresAt: aws.TimeValue(authData.ExpiresAt),
ProxyEndpoint: aws.StringValue(authData.ProxyEndpoint),
ExpiresAt: aws.ToTime(authData.ExpiresAt),
ProxyEndpoint: aws.ToString(authData.ProxyEndpoint),
Service: cache.ServiceECR,
}
registry, err := ExtractRegistry(authEntry.ProxyEndpoint)
Expand All @@ -259,22 +260,22 @@ func (c *defaultClient) getAuthorizationToken(registryID string) (*Auth, error)
func (c *defaultClient) getPublicAuthorizationToken() (*Auth, error) {
var input *ecrpublic.GetAuthorizationTokenInput

output, err := c.ecrPublicClient.GetAuthorizationToken(input)
output, err := c.ecrPublicClient.GetAuthorizationToken(context.TODO(), input)
if err != nil {
return nil, errors.Wrap(err, "ecr: failed to get authorization token")
}
if output == nil || output.AuthorizationData == nil {
return nil, fmt.Errorf("ecr: missing AuthorizationData in ECR Public response")
}
authData := output.AuthorizationData
token, err := extractToken(aws.StringValue(authData.AuthorizationToken), ecrPublicEndpoint)
token, err := extractToken(aws.ToString(authData.AuthorizationToken), ecrPublicEndpoint)
if err != nil {
return nil, err
}
authEntry := cache.AuthEntry{
AuthorizationToken: aws.StringValue(authData.AuthorizationToken),
AuthorizationToken: aws.ToString(authData.AuthorizationToken),
RequestedAt: time.Now(),
ExpiresAt: aws.TimeValue(authData.ExpiresAt),
ExpiresAt: aws.ToTime(authData.ExpiresAt),
ProxyEndpoint: ecrPublicEndpoint,
Service: cache.ServiceECRPublic,
}
Expand Down
30 changes: 16 additions & 14 deletions ecr-login/api/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@ import (
"testing"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api/mocks"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/ecr"
ecrtypes "github.com/aws/aws-sdk-go-v2/service/ecr/types"
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
ecrpublictypes "github.com/aws/aws-sdk-go-v2/service/ecrpublic/types"
mock_api "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/api/mocks"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache/mocks"
mock_cache "github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache/mocks"
"github.com/stretchr/testify/assert"
)

Expand Down Expand Up @@ -146,7 +148,7 @@ func TestGetAuthConfigSuccess(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 1, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
Expand Down Expand Up @@ -186,7 +188,7 @@ func TestGetAuthConfigNoMatchAuthorizationToken(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 1, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(proxyEndpointScheme + "notproxy"),
AuthorizationToken: aws.String(base64.StdEncoding.EncodeToString([]byte(expectedUsername + ":" + expectedPassword))),
}},
Expand Down Expand Up @@ -248,7 +250,7 @@ func TestGetAuthConfigSuccessInvalidCacheHit(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 1, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
Expand Down Expand Up @@ -298,7 +300,7 @@ func TestGetAuthConfigBadBase64(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 1, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(proxyEndpointScheme + proxyEndpoint),
AuthorizationToken: aws.String(expectedUsername + ":" + expectedPassword),
}},
Expand Down Expand Up @@ -519,16 +521,16 @@ func TestListCredentialsEmpty(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 0, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
}},
}, nil
}
ecrPublicClient.GetAuthorizationTokenFn = func(_ *ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error) {
ecrPublicClient.GetAuthorizationTokenFn = func(*ecrpublic.GetAuthorizationTokenInput) (*ecrpublic.GetAuthorizationTokenOutput, error) {
return &ecrpublic.GetAuthorizationTokenOutput{
AuthorizationData: &ecrpublic.AuthorizationData{
AuthorizationData: &ecrpublictypes.AuthorizationData{
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String(authorizationToken),
},
Expand Down Expand Up @@ -579,7 +581,7 @@ func TestListCredentialsBadBase64AuthToken(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 0, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String("invalid:token"),
Expand Down Expand Up @@ -619,7 +621,7 @@ func TestListCredentialsInvalidAuthToken(t *testing.T) {
assert.NotNil(t, input, "GetAuthorizationToken input")
assert.Len(t, input.RegistryIds, 0, "GetAuthorizationToken registry IDs len")
return &ecr.GetAuthorizationTokenOutput{
AuthorizationData: []*ecr.AuthorizationData{{
AuthorizationData: []ecrtypes.AuthorizationData{{
ProxyEndpoint: aws.String(testProxyEndpoint),
ExpiresAt: aws.Time(expiresAt),
AuthorizationToken: aws.String("invalidtoken"),
Expand Down
114 changes: 42 additions & 72 deletions ecr-login/api/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,27 @@
package api

import (
"os"
"strconv"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ecr"
"github.com/aws/aws-sdk-go/service/ecrpublic"
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/ecr"
"github.com/aws/aws-sdk-go-v2/service/ecrpublic"
"github.com/aws/smithy-go/middleware"
"github.com/aws/smithy-go/transport/http"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/cache"
"github.com/awslabs/amazon-ecr-credential-helper/ecr-login/version"
)

// Options makes the constructors more configurable
type Options struct {
Session *session.Session
Config *aws.Config
Config aws.Config
CacheDir string
}

// ClientFactory is a factory for creating clients to interact with ECR
type ClientFactory interface {
NewClient(awsSession *session.Session, awsConfig *aws.Config) Client
NewClient(awsConfig aws.Config) Client
NewClientWithOptions(opts Options) Client
NewClientFromRegion(region string) Client
NewClientWithFipsEndpoint(region string) (Client, error)
Expand All @@ -46,92 +44,64 @@ type ClientFactory interface {
// DefaultClientFactory is a default implementation of the ClientFactory
type DefaultClientFactory struct{}

var userAgentHandler = request.NamedHandler{
Name: "ecr-login.UserAgentHandler",
Fn: request.MakeAddToUserAgentHandler("amazon-ecr-credential-helper", version.Version),
}
var userAgentLoadOption = config.WithAPIOptions([]func(*middleware.Stack) error{
http.AddHeaderValue("User-Agent", "amazon-ecr-credential-helper/"+version.Version),
})

// NewClientWithDefaults creates the client and defaults region
func (defaultClientFactory DefaultClientFactory) NewClientWithDefaults() Client {
awsSession := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: loadSharedConfigState(),
}))
awsConfig, err := config.LoadDefaultConfig(context.TODO(), userAgentLoadOption)
if err != nil {
panic(err)
}

awsSession.Handlers.Build.PushBackNamed(userAgentHandler)
awsConfig := awsSession.Config
return defaultClientFactory.NewClientWithOptions(Options{
Session: awsSession,
Config: awsConfig,
})
return defaultClientFactory.NewClientWithOptions(Options{Config: awsConfig})
}

// NewClientWithFipsEndpoint overrides the default ECR service endpoint in a given region to use the FIPS endpoint
func (defaultClientFactory DefaultClientFactory) NewClientWithFipsEndpoint(region string) (Client, error) {
awsSession := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: loadSharedConfigState(),
}))

awsSession.Handlers.Build.PushBackNamed(userAgentHandler)

endpoint, err := getServiceEndpoint("ecr-fips", region)
awsConfig, err := config.LoadDefaultConfig(
context.TODO(),
userAgentLoadOption,
config.WithRegion(region),
config.WithEndpointDiscovery(aws.EndpointDiscoveryEnabled),
)
if err != nil {
return nil, err
}

awsConfig := awsSession.Config.WithEndpoint(endpoint).WithRegion(region)
return defaultClientFactory.NewClientWithOptions(Options{
Session: awsSession,
Config: awsConfig,
}), nil
return defaultClientFactory.NewClientWithOptions(Options{Config: awsConfig}), nil
}

// NewClientFromRegion uses the region to create the client
func (defaultClientFactory DefaultClientFactory) NewClientFromRegion(region string) Client {
awsSession := session.Must(session.NewSessionWithOptions(session.Options{
SharedConfigState: loadSharedConfigState(),
}))
awsSession.Handlers.Build.PushBackNamed(userAgentHandler)
awsConfig := &aws.Config{Region: aws.String(region)}
awsConfig, err := config.LoadDefaultConfig(
context.TODO(),
userAgentLoadOption,
config.WithRegion(region),
)
if err != nil {
panic(err)
}

return defaultClientFactory.NewClientWithOptions(Options{
Session: awsSession,
Config: awsConfig,
Config: awsConfig,
})
}

// NewClient Create new client with AWS Config
func (defaultClientFactory DefaultClientFactory) NewClient(awsSession *session.Session, awsConfig *aws.Config) Client {
return defaultClientFactory.NewClientWithOptions(Options{
Session: awsSession,
Config: awsConfig,
})
func (defaultClientFactory DefaultClientFactory) NewClient(awsConfig aws.Config) Client {
return defaultClientFactory.NewClientWithOptions(Options{Config: awsConfig})
}

// NewClientWithOptions Create new client with Options
func (defaultClientFactory DefaultClientFactory) NewClientWithOptions(opts Options) Client {
// The ECR Public API is only available in us-east-1 today
publicConfig := opts.Config.Copy().WithRegion("us-east-1")
publicConfig := opts.Config.Copy()
publicConfig.Region = "us-east-1"
return &defaultClient{
ecrClient: ecr.New(opts.Session, opts.Config),
ecrPublicClient: ecrpublic.New(opts.Session, publicConfig),
credentialCache: cache.BuildCredentialsCache(opts.Session, aws.StringValue(opts.Config.Region), opts.CacheDir),
}
}

func getServiceEndpoint(service, region string) (string, error) {
resolver := endpoints.DefaultResolver()
endpoint, err := resolver.EndpointFor(service, region, func(opts *endpoints.Options) {
opts.ResolveUnknownService = true
})
return endpoint.URL, err
}

func loadSharedConfigState() session.SharedConfigState {
loadConfig := os.Getenv("AWS_SDK_LOAD_CONFIG")
if loadConfig == "" {
return session.SharedConfigEnable
}
if enable, err := strconv.ParseBool(loadConfig); err == nil && !enable {
return session.SharedConfigDisable
ecrClient: ecr.NewFromConfig(opts.Config),
ecrPublicClient: ecrpublic.NewFromConfig(publicConfig),
credentialCache: cache.BuildCredentialsCache(opts.Config, opts.CacheDir),
}
return session.SharedConfigEnable
}
48 changes: 0 additions & 48 deletions ecr-login/api/factory_test.go

This file was deleted.

Loading

0 comments on commit d0e3d70

Please sign in to comment.