From d10e2375e82d73ac8dce1297e2f4c5bd30050b09 Mon Sep 17 00:00:00 2001 From: nightfury1204 Date: Tue, 29 Aug 2023 23:01:20 +0100 Subject: [PATCH] Add 0.0.0.0/0 edge case fix --- .../lambda/formation/handler/sg_ingress.go | 10 +++---- .../formation/handler/sg_ingress_test.go | 26 ++++++++++++++++--- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/provider/aws/lambda/formation/handler/sg_ingress.go b/provider/aws/lambda/formation/handler/sg_ingress.go index d4662b899..ed596c487 100644 --- a/provider/aws/lambda/formation/handler/sg_ingress.go +++ b/provider/aws/lambda/formation/handler/sg_ingress.go @@ -58,7 +58,6 @@ func sgIngressApply(req Request) error { GroupIds: []*string{sgId}, }) if err != nil { - fmt.Println("1") return err } @@ -139,7 +138,6 @@ func sgIngressApply(req Request) error { }, }) if err != nil { - fmt.Println("3") return err } } @@ -147,11 +145,15 @@ func sgIngressApply(req Request) error { var deleteIps []*ec2.IpRange for i := range prevIps { if prevIps[i].CidrIp != nil { - if val, _ := prevIpMap[*prevIps[i].CidrIp]; val == 1 { + if val := prevIpMap[*prevIps[i].CidrIp]; val == 1 { + if *prevIps[i].CidrIp == "0.0.0.0/0" && len(req.ResourceProperties["Ips"].(string)) == 0 { + continue + } deleteIps = append(deleteIps, prevIps[i]) } } } + if len(deleteIps) > 0 { _, err = EC2(req).RevokeSecurityGroupIngress(&ec2.RevokeSecurityGroupIngressInput{ GroupId: sgId, @@ -165,7 +167,6 @@ func sgIngressApply(req Request) error { }, }) if err != nil { - fmt.Println("2") return err } } @@ -207,7 +208,6 @@ func sgIngressApply(req Request) error { }, }) if err != nil { - fmt.Println("4") return err } } diff --git a/provider/aws/lambda/formation/handler/sg_ingress_test.go b/provider/aws/lambda/formation/handler/sg_ingress_test.go index b83df083a..70f8b72d6 100644 --- a/provider/aws/lambda/formation/handler/sg_ingress_test.go +++ b/provider/aws/lambda/formation/handler/sg_ingress_test.go @@ -5,6 +5,8 @@ import ( "testing" ) +const sgID = "sg-062c937b3e674adc3" + func TestSGIngress(t *testing.T) { if os.Getenv("MANUAL_TEST") != "true" { t.Skip() @@ -12,7 +14,7 @@ func TestSGIngress(t *testing.T) { err := sgIngressApply(Request{ ResourceProperties: map[string]interface{}{ - "SecurityGroupID": "sg-019eb79e77bba7daa", + "SecurityGroupID": sgID, "Ips": "", "SgIDs": "", }, @@ -29,7 +31,7 @@ func TestSGIngress2(t *testing.T) { err := sgIngressApply(Request{ ResourceProperties: map[string]interface{}{ - "SecurityGroupID": "sg-019eb79e77bba7daa", + "SecurityGroupID": sgID, "Ips": "10.0.0.0/16", "SgIDs": "", }, @@ -46,7 +48,7 @@ func TestSGIngress4(t *testing.T) { err := sgIngressApply(Request{ ResourceProperties: map[string]interface{}{ - "SecurityGroupID": "sg-019eb79e77bba7daa", + "SecurityGroupID": sgID, "Ips": "10.0.0.0/16,173.0.0.0/8", "SgIDs": "", }, @@ -62,7 +64,7 @@ func TestSGIngress3(t *testing.T) { } err := sgIngressApply(Request{ ResourceProperties: map[string]interface{}{ - "SecurityGroupID": "sg-019eb79e77bba7daa", + "SecurityGroupID": sgID, "Ips": "10.0.0.0/16", "SgIDs": "sg-0cb8ec0e8c9505ffa", //sg-0e1a179a4c9307a55", }, @@ -71,3 +73,19 @@ func TestSGIngress3(t *testing.T) { t.Fatal(err) } } + +func TestSGIngress5(t *testing.T) { + if os.Getenv("MANUAL_TEST") != "true" { + t.Skip() + } + err := sgIngressApply(Request{ + ResourceProperties: map[string]interface{}{ + "SecurityGroupID": sgID, + "Ips": "0.0.0.0/0,10.0.0.0/16", + "SgIDs": "", //sg-0e1a179a4c9307a55", + }, + }) + if err != nil { + t.Fatal(err) + } +}