diff --git a/pkg/driver/node.go b/pkg/driver/node.go index 3af4a3657a..e64c8ef59f 100644 --- a/pkg/driver/node.go +++ b/pkg/driver/node.go @@ -103,7 +103,7 @@ func newNodeService(driverOptions *DriverOptions) nodeService { // Remove taint from node to indicate driver startup success // This is done at the last possible moment to prevent race conditions or false positive removals - go removeTaintInBackground(cloud.DefaultKubernetesAPIClient) + go removeTaintInBackground(cloud.DefaultKubernetesAPIClient, removeNotReadyTaint) return nodeService{ metadata: metadata, @@ -854,12 +854,12 @@ type JSONPatch struct { } // removeTaintInBackground is a goroutine that retries removeNotReadyTaint with exponential backoff -func removeTaintInBackground(k8sClient cloud.KubernetesAPIClient) { +func removeTaintInBackground(k8sClient cloud.KubernetesAPIClient, removalFunc func(cloud.KubernetesAPIClient) error) { backoffErr := wait.ExponentialBackoff(taintRemovalBackoff, func() (bool, error) { - err := removeNotReadyTaint(k8sClient) + err := removalFunc(k8sClient) if err != nil { klog.ErrorS(err, "Unexpected failure when attempting to remove node taint(s)") - return false, err + return false, nil } return true, nil }) diff --git a/pkg/driver/node_test.go b/pkg/driver/node_test.go index 1767fd965f..fe707a54d2 100644 --- a/pkg/driver/node_test.go +++ b/pkg/driver/node_test.go @@ -34,6 +34,7 @@ import ( "github.com/golang/mock/gomock" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/cloud" "github.com/kubernetes-sigs/aws-ebs-csi-driver/pkg/driver/internal" + "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" corev1 "k8s.io/api/core/v1" @@ -2405,6 +2406,21 @@ func TestRemoveNotReadyTaint(t *testing.T) { } } +func TestRemoveTaintInBackground(t *testing.T) { + mockRemovalCount := 0 + mockRemovalFunc := func(_ cloud.KubernetesAPIClient) error { + mockRemovalCount += 1 + if mockRemovalCount == 3 { + return nil + } else { + return fmt.Errorf("Taint removal failed!") + } + } + + removeTaintInBackground(nil, mockRemovalFunc) + assert.Equal(t, mockRemovalCount, 3) +} + func getNodeMock(mockCtl *gomock.Controller, nodeName string, returnNode *corev1.Node, returnError error) (kubernetes.Interface, *MockNodeInterface) { mockClient := NewMockKubernetesClient(mockCtl) mockCoreV1 := NewMockCoreV1Interface(mockCtl)