Skip to content

Commit

Permalink
Fix Unit tests in AWS cloud-prepare
Browse files Browse the repository at this point in the history
Signed-off-by: Aswin Suryanarayanan <[email protected]>
  • Loading branch information
aswinsuryan committed Jun 21, 2024
1 parent 924b671 commit 4071086
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 18 deletions.
6 changes: 5 additions & 1 deletion pkg/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,14 @@ func (ac *awsCloud) setSuffixes(vpcID string) error {
}

publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*"))
if err != nil || len(publicSubnets) == 0 {
if err != nil {
return errors.Wrapf(err, "unable to find the public subnet")
}

if len(publicSubnets) == 0 {
return errors.New("no public subnet found")
}

pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region))
re := regexp.MustCompile(pattern)

Expand Down
6 changes: 6 additions & 0 deletions pkg/aws/aws_cloud_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ func testOpenPorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)

retError = t.cloud.OpenPorts([]api.PortSpec{
{
Expand Down Expand Up @@ -87,6 +89,7 @@ func testOpenPorts() {
When("authorize security group ingress validation fails", func() {
BeforeEach(func() {
t.expectDescribeVpcs(vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectValidateAuthorizeSecurityGroupIngress(errors.New("mock error"))
})

Expand Down Expand Up @@ -114,6 +117,9 @@ func testClosePorts() {

JustBeforeEach(func() {
t.expectDescribeVpcs(t.vpcID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

retError = t.cloud.ClosePorts(reporter.Stdout())
})
Expand Down
69 changes: 52 additions & 17 deletions pkg/aws/aws_suite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,23 @@ import (
)

const (
infraID = "test-infra"
region = "test-region"
vpcID = "test-vpc"
workerGroupID = "worker-group"
masterGroupID = "master-group"
gatewayGroupID = "gateway-group"
internalTraffic = "Internal Submariner traffic"
availabilityZone1 = "availability-zone-1"
availabilityZone2 = "availability-zone-2"
subnetID1 = "subnet-1"
subnetID2 = "subnet-2"
instanceImageID = "test-image"
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
infraID = "test-infra"
region = "test-region"
vpcID = "test-vpc"
workerGroupID = "worker-group"
masterGroupID = "master-group"
gatewayGroupID = "gateway-group"
internalTraffic = "Internal Submariner traffic"
availabilityZone1 = "availability-zone-1"
availabilityZone2 = "availability-zone-2"
subnetID1 = "subnet-1"
subnetID2 = "subnet-2"
instanceImageID = "test-image"
masterSGName = infraID + "-master-sg"
workerSGName = infraID + "-worker-sg"
gatewaySGName = infraID + "-submariner-gw-sg"
clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID
clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID
)

var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic)
Expand All @@ -62,6 +63,7 @@ func TestAWS(t *testing.T) {
type fakeAWSClientBase struct {
awsClient *fake.MockInterface
vpcID string
subnets []types.Subnet
describeSubnetsErr error
authorizeSecurityGroupIngressErr error
createTagsErr error
Expand All @@ -71,6 +73,7 @@ type fakeAWSClientBase struct {
func (f *fakeAWSClientBase) beforeEach() {
f.awsClient = fake.NewMockInterface(GinkgoT())
f.vpcID = vpcID
f.subnets = []types.Subnet{newSubnet(availabilityZone1, subnetID1), newSubnet(availabilityZone2, subnetID2)}
f.describeSubnetsErr = nil
f.authorizeSecurityGroupIngressErr = nil
f.createTagsErr = nil
Expand Down Expand Up @@ -110,6 +113,25 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) {
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}

func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) {
var vpcs []types.Vpc
if vpcID != "" {
vpcs = []types.Vpc{
{
VpcId: ptr.To(vpcID),
},
}
}

f.awsClient.EXPECT().DescribeVpcs(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:Name"),
Values: []string{infraID + "-vpc"},
}, {
Name: ptr.To(clusterFilterTagNameSigs),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe()
}

func (f *fakeAWSClientBase) expectValidateAuthorizeSecurityGroupIngress(authErr error) *mock.Call {
return f.awsClient.EXPECT().AuthorizeSecurityGroupIngress(mock.Anything,
mock.MatchedBy((&authorizeSecurityGroupIngressInputMatcher{ec2.AuthorizeSecurityGroupIngressInput{
Expand Down Expand Up @@ -144,7 +166,7 @@ func (f *fakeAWSClientBase) expectValidateRevokeSecurityGroupIngress(retErr erro
func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:Name"),
Values: []string{infraID + "-public-" + region + "*"},
Values: []string{infraID + "*-public-" + region + "*"},
}, {
Name: ptr.To("vpc-id"),
Values: []string{f.vpcID},
Expand All @@ -154,6 +176,19 @@ func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subn
}}}).Matches))).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).Maybe()
}

func (f *fakeAWSClientBase) expectDescribePublicSubnetsSigs(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:Name"),
Values: []string{infraID + "*-public-" + region + "*"},
}, {
Name: ptr.To("vpc-id"),
Values: []string{f.vpcID},
}, {
Name: ptr.To(clusterFilterTagNameSigs),
Values: []string{"owned"},
}}}).Matches))).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).Maybe()
}

func (f *fakeAWSClientBase) expectDescribeGatewaySubnets(retSubnets ...types.Subnet) {
f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{
Name: ptr.To("tag:submariner.io/gateway"),
Expand Down
3 changes: 3 additions & 0 deletions pkg/aws/ocpgwdeployer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver {
t.expectDescribeSecurityGroups(gatewaySGName, t.gatewayGroupID)
t.expectDescribeInstances(instanceImageID)
t.expectDescribeSecurityGroups(workerSGName, workerGroupID)
t.expectDescribePublicSubnets(t.subnets...)
t.expectDescribeVpcsSigs(t.vpcID)
t.expectDescribePublicSubnetsSigs(t.subnets...)

var err error

Expand Down

0 comments on commit 4071086

Please sign in to comment.