diff --git a/pkg/aws/aws.go b/pkg/aws/aws.go index 68ed99d0..8bc5b2cb 100644 --- a/pkg/aws/aws.go +++ b/pkg/aws/aws.go @@ -99,10 +99,14 @@ func (ac *awsCloud) setSuffixes(vpcID string) error { } publicSubnets, err := ac.findPublicSubnets(vpcID, ac.filterByName("{infraID}*-public-{region}*")) - if err != nil || len(publicSubnets) == 0 { + if err != nil { return errors.Wrapf(err, "unable to find the public subnet") } + if len(publicSubnets) == 0 { + return errors.New("no public subnet found") + } + pattern := fmt.Sprintf(`%s.*-subnet-public-%s.*`, regexp.QuoteMeta(ac.infraID), regexp.QuoteMeta(ac.region)) re := regexp.MustCompile(pattern) diff --git a/pkg/aws/aws_cloud_test.go b/pkg/aws/aws_cloud_test.go index b05369ff..e5c82a97 100644 --- a/pkg/aws/aws_cloud_test.go +++ b/pkg/aws/aws_cloud_test.go @@ -40,6 +40,8 @@ func testOpenPorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnets(t.subnets...) retError = t.cloud.OpenPorts([]api.PortSpec{ { @@ -87,6 +89,7 @@ func testOpenPorts() { When("authorize security group ingress validation fails", func() { BeforeEach(func() { t.expectDescribeVpcs(vpcID) + t.expectDescribePublicSubnets(t.subnets...) t.expectValidateAuthorizeSecurityGroupIngress(errors.New("mock error")) }) @@ -114,6 +117,9 @@ func testClosePorts() { JustBeforeEach(func() { t.expectDescribeVpcs(t.vpcID) + t.expectDescribePublicSubnets(t.subnets...) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnetsSigs(t.subnets...) retError = t.cloud.ClosePorts(reporter.Stdout()) }) diff --git a/pkg/aws/aws_suite_test.go b/pkg/aws/aws_suite_test.go index 5f973cbf..d8bbf08f 100644 --- a/pkg/aws/aws_suite_test.go +++ b/pkg/aws/aws_suite_test.go @@ -34,22 +34,23 @@ import ( ) const ( - infraID = "test-infra" - region = "test-region" - vpcID = "test-vpc" - workerGroupID = "worker-group" - masterGroupID = "master-group" - gatewayGroupID = "gateway-group" - internalTraffic = "Internal Submariner traffic" - availabilityZone1 = "availability-zone-1" - availabilityZone2 = "availability-zone-2" - subnetID1 = "subnet-1" - subnetID2 = "subnet-2" - instanceImageID = "test-image" - masterSGName = infraID + "-master-sg" - workerSGName = infraID + "-worker-sg" - gatewaySGName = infraID + "-submariner-gw-sg" - clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID + infraID = "test-infra" + region = "test-region" + vpcID = "test-vpc" + workerGroupID = "worker-group" + masterGroupID = "master-group" + gatewayGroupID = "gateway-group" + internalTraffic = "Internal Submariner traffic" + availabilityZone1 = "availability-zone-1" + availabilityZone2 = "availability-zone-2" + subnetID1 = "subnet-1" + subnetID2 = "subnet-2" + instanceImageID = "test-image" + masterSGName = infraID + "-master-sg" + workerSGName = infraID + "-worker-sg" + gatewaySGName = infraID + "-submariner-gw-sg" + clusterFilterTagName = "tag:kubernetes.io/cluster/" + infraID + clusterFilterTagNameSigs = "tag:sigs.k8s.io/cluster-api-provider-aws/cluster/" + infraID ) var internalTrafficDesc = fmt.Sprintf("Should contain %q", internalTraffic) @@ -62,6 +63,7 @@ func TestAWS(t *testing.T) { type fakeAWSClientBase struct { awsClient *fake.MockInterface vpcID string + subnets []types.Subnet describeSubnetsErr error authorizeSecurityGroupIngressErr error createTagsErr error @@ -71,6 +73,7 @@ type fakeAWSClientBase struct { func (f *fakeAWSClientBase) beforeEach() { f.awsClient = fake.NewMockInterface(GinkgoT()) f.vpcID = vpcID + f.subnets = []types.Subnet{newSubnet(availabilityZone1, subnetID1), newSubnet(availabilityZone2, subnetID2)} f.describeSubnetsErr = nil f.authorizeSecurityGroupIngressErr = nil f.createTagsErr = nil @@ -110,6 +113,25 @@ func (f *fakeAWSClientBase) expectDescribeVpcs(vpcID string) { }}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe() } +func (f *fakeAWSClientBase) expectDescribeVpcsSigs(vpcID string) { + var vpcs []types.Vpc + if vpcID != "" { + vpcs = []types.Vpc{ + { + VpcId: ptr.To(vpcID), + }, + } + } + + f.awsClient.EXPECT().DescribeVpcs(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{ + Name: ptr.To("tag:Name"), + Values: []string{infraID + "-vpc"}, + }, { + Name: ptr.To(clusterFilterTagNameSigs), + Values: []string{"owned"}, + }}}).Matches))).Return(&ec2.DescribeVpcsOutput{Vpcs: vpcs}, nil).Maybe() +} + func (f *fakeAWSClientBase) expectValidateAuthorizeSecurityGroupIngress(authErr error) *mock.Call { return f.awsClient.EXPECT().AuthorizeSecurityGroupIngress(mock.Anything, mock.MatchedBy((&authorizeSecurityGroupIngressInputMatcher{ec2.AuthorizeSecurityGroupIngressInput{ @@ -144,7 +166,7 @@ func (f *fakeAWSClientBase) expectValidateRevokeSecurityGroupIngress(retErr erro func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subnet) { f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{ Name: ptr.To("tag:Name"), - Values: []string{infraID + "-public-" + region + "*"}, + Values: []string{infraID + "*-public-" + region + "*"}, }, { Name: ptr.To("vpc-id"), Values: []string{f.vpcID}, @@ -154,6 +176,19 @@ func (f *fakeAWSClientBase) expectDescribePublicSubnets(retSubnets ...types.Subn }}}).Matches))).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).Maybe() } +func (f *fakeAWSClientBase) expectDescribePublicSubnetsSigs(retSubnets ...types.Subnet) { + f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{ + Name: ptr.To("tag:Name"), + Values: []string{infraID + "*-public-" + region + "*"}, + }, { + Name: ptr.To("vpc-id"), + Values: []string{f.vpcID}, + }, { + Name: ptr.To(clusterFilterTagNameSigs), + Values: []string{"owned"}, + }}}).Matches))).Return(&ec2.DescribeSubnetsOutput{Subnets: retSubnets}, f.describeSubnetsErr).Maybe() +} + func (f *fakeAWSClientBase) expectDescribeGatewaySubnets(retSubnets ...types.Subnet) { f.awsClient.EXPECT().DescribeSubnets(mock.Anything, mock.MatchedBy(((&filtersMatcher{expectedFilters: []types.Filter{{ Name: ptr.To("tag:submariner.io/gateway"), diff --git a/pkg/aws/ocpgwdeployer_test.go b/pkg/aws/ocpgwdeployer_test.go index eb51989a..9c5c6925 100644 --- a/pkg/aws/ocpgwdeployer_test.go +++ b/pkg/aws/ocpgwdeployer_test.go @@ -283,6 +283,9 @@ func newGatewayDeployerTestDriver() *gatewayDeployerTestDriver { t.expectDescribeSecurityGroups(gatewaySGName, t.gatewayGroupID) t.expectDescribeInstances(instanceImageID) t.expectDescribeSecurityGroups(workerSGName, workerGroupID) + t.expectDescribePublicSubnets(t.subnets...) + t.expectDescribeVpcsSigs(t.vpcID) + t.expectDescribePublicSubnetsSigs(t.subnets...) var err error