From 25bea408d66f83b41fae3b58a6a1bb6157fb594b Mon Sep 17 00:00:00 2001 From: tariq-hasan Date: Sun, 17 Mar 2024 19:01:09 -0400 Subject: [PATCH] Introduced error constants and replaced reflect with cmp --- .../experiment/manifest/generator.go | 27 +- .../experiment/manifest/generator_test.go | 139 ++++---- .../suggestion/composer/composer_test.go | 6 +- .../suggestionclient/suggestionclient_test.go | 6 +- .../trial/util/job_util_test.go | 59 ++-- pkg/suggestion/v1beta1/goptuna/converter.go | 4 +- .../v1beta1/goptuna/converter_test.go | 86 ++--- pkg/suggestion/v1beta1/goptuna/sample.go | 6 +- pkg/webhook/v1beta1/pod/inject_webhook.go | 16 +- .../v1beta1/pod/inject_webhook_test.go | 331 ++++++++---------- pkg/webhook/v1beta1/pod/utils.go | 6 +- 11 files changed, 317 insertions(+), 369 deletions(-) diff --git a/pkg/controller.v1beta1/experiment/manifest/generator.go b/pkg/controller.v1beta1/experiment/manifest/generator.go index a41cc2f0c89..5507f4388fd 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator.go @@ -17,6 +17,7 @@ limitations under the License. package manifest import ( + "errors" "fmt" "regexp" "strings" @@ -33,6 +34,15 @@ import ( "github.com/kubeflow/katib/pkg/util/v1beta1/katibconfig" ) +var ( + ErrConfigMapNotFound = errors.New("ConfigMap not found") + ErrConvertStringToUnstructuredFailed = errors.New("ConvertStringToUnstructured failed") + ErrConvertUnstructuredToStringFailed = errors.New("ConvertUnstructuredToString failed") + ErrParamNotFoundInParameterAssignment = errors.New("unable to find non-meta paramater from TrialParameters in ParameterAssignment") + ErrParamNotFoundInTrialParameters = errors.New("unable to find parameter from ParameterAssignment in TrialParameters") + ErrTemplatePathNotFound = errors.New("TemplatePath not found in ConfigMap") +) + // Generator is the type for manifests Generator. type Generator interface { InjectClient(c client.Client) @@ -86,7 +96,7 @@ func (g *DefaultGenerator) GetRunSpecWithHyperParameters(experiment *experiments // Convert Trial template to unstructured runSpec, err := util.ConvertStringToUnstructured(replacedTemplate) if err != nil { - return nil, fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return nil, fmt.Errorf("%w: %s", ErrConvertStringToUnstructuredFailed, err.Error()) } // Set name and namespace for Run Spec @@ -108,7 +118,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi if trialSpec == nil { trialSpec, err = util.ConvertStringToUnstructured(trialTemplate) if err != nil { - return "", fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return "", fmt.Errorf("%w: %s", ErrConvertStringToUnstructuredFailed, err.Error()) } } @@ -131,7 +141,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi nonMetaParamCount += 1 continue } else { - return "", fmt.Errorf("Unable to find parameter: %v in parameter assignment %v", param.Reference, assignmentsMap) + return "", fmt.Errorf("%w: parameter: %v, parameter assignment: %v", ErrParamNotFoundInParameterAssignment, param.Reference, assignmentsMap) } } metaRefKey = sub[1] @@ -172,9 +182,10 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi } } - // Number of parameters must be equal + // Number of assignment parameters must be equal to the number of non-meta trial parameters + // i.e. all parameters in ParameterAssignment must be in TrialParameters if len(assignments) != nonMetaParamCount { - return "", fmt.Errorf("Number of TrialAssignment: %v != number of nonMetaTrialParameters in TrialSpec: %v", len(assignments), nonMetaParamCount) + return "", fmt.Errorf("%w: parameter assignments: %v, non-meta trial parameter count: %v", ErrParamNotFoundInTrialParameters, assignments, nonMetaParamCount) } // Replacing placeholders with parameter values @@ -194,7 +205,7 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim if trialSource.TrialSpec != nil { trialTemplateString, err = util.ConvertUnstructuredToString(trialSource.TrialSpec) if err != nil { - return "", fmt.Errorf("ConvertUnstructuredToString failed: %v", err) + return "", fmt.Errorf("%w: %s", ErrConvertUnstructuredToStringFailed, err.Error()) } } else { configMapNS := trialSource.ConfigMap.ConfigMapNamespace @@ -202,12 +213,12 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim templatePath := trialSource.ConfigMap.TemplatePath configMap, err := g.client.GetConfigMap(configMapName, configMapNS) if err != nil { - return "", fmt.Errorf("GetConfigMap failed: %v", err) + return "", fmt.Errorf("%w: %v", ErrConfigMapNotFound, err) } var ok bool trialTemplateString, ok = configMap[templatePath] if !ok { - return "", fmt.Errorf("TemplatePath: %v not found in configMap: %v", templatePath, configMap) + return "", fmt.Errorf("%w: TemplatePath: %v, ConfigMap: %v", ErrTemplatePathNotFound, templatePath, configMap) } } diff --git a/pkg/controller.v1beta1/experiment/manifest/generator_test.go b/pkg/controller.v1beta1/experiment/manifest/generator_test.go index 3adeb017f74..25f374e6f86 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator_test.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator_test.go @@ -17,12 +17,12 @@ limitations under the License. package manifest import ( - "errors" "math" - "reflect" "testing" "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" batchv1 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -88,23 +88,18 @@ func TestGetRunSpecWithHP(t *testing.T) { t.Errorf("ConvertObjectToUnstructured failed: %v", err) } - tcs := []struct { - Instance *experimentsv1beta1.Experiment - ParameterAssignments []commonapiv1beta1.ParameterAssignment - expectedRunSpec *unstructured.Unstructured - Err bool - testDescription string + cases := map[string]struct { + Instance *experimentsv1beta1.Experiment + ParameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run - { - Instance: newFakeInstance(), - ParameterAssignments: newFakeParameterAssignment(), - expectedRunSpec: expectedRunSpec, - Err: false, - testDescription: "Run with valid parameters", + "Run with valid parameters": { + Instance: newFakeInstance(), + ParameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid JSON in unstructured - { + "Invalid JSON in Unstructured Trial template": { Instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() trialSpec := i.Spec.TrialTemplate.TrialSource.TrialSpec @@ -114,48 +109,45 @@ func TestGetRunSpecWithHP(t *testing.T) { return i }(), ParameterAssignments: newFakeParameterAssignment(), - Err: true, - testDescription: "Invalid JSON in Trial template", + wantError: ErrConvertUnstructuredToStringFailed, }, - // len(parameterAssignment) != len(trialParameters) - { + "Non-meta parameter from TrialParameters not found in ParameterAssignment": { Instance: newFakeInstance(), ParameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa = pa[1:] + pa[0] = commonapiv1beta1.ParameterAssignment{ + Name: "invalid-name", + Value: "invalid-value", + } return pa }(), - Err: true, - testDescription: "Number of parameter assignments is not equal to number of Trial parameters", + wantError: ErrParamNotFoundInParameterAssignment, }, - // Parameter from assignments not found in Trial parameters - { + // case in which the lengths of trial parameters and parameter assignments are different + "Parameter from ParameterAssignment not found in TrialParameters": { Instance: newFakeInstance(), ParameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa[0] = commonapiv1beta1.ParameterAssignment{ - Name: "invalid-name", - Value: "invalid-value", - } + pa = append(pa, commonapiv1beta1.ParameterAssignment{ + Name: "extra-name", + Value: "extra-value", + }) return pa }(), - Err: true, - testDescription: "Trial parameters don't have parameter from assignments", + wantError: ErrParamNotFoundInTrialParameters, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.Instance, "trial-name", "trial-namespace", tc.ParameterAssignments) - - if tc.Err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.Err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := p.GetRunSpecWithHyperParameters(tc.Instance, "trial-name", "trial-namespace", tc.ParameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + }) } } @@ -208,7 +200,7 @@ spec: map[string]string{templatePath: trialSpec}, nil) invalidConfigMapName := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - nil, errors.New("Unable to get ConfigMap")) + nil, ErrConfigMapNotFound) validGetConfigMap3 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( map[string]string{templatePath: trialSpec}, nil) @@ -245,18 +237,17 @@ spec: expectedRunSpec, err := util.ConvertStringToUnstructured(expectedStr) if err != nil { - t.Errorf("ConvertStringToUnstructured failed: %v", err) + t.Errorf("%v: %v", ErrConvertStringToUnstructuredFailed, err) } - tcs := []struct { - Instance *experimentsv1beta1.Experiment - ParameterAssignments []commonapiv1beta1.ParameterAssignment - Err bool - testDescription string + cases := map[string]struct { + Instance *experimentsv1beta1.Experiment + ParameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run // validGetConfigMap1 case - { + "Run with valid parameters": { Instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -268,13 +259,11 @@ spec: } return i }(), - ParameterAssignments: newFakeParameterAssignment(), - Err: false, - testDescription: "Run with valid parameters", + ParameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid ConfigMap name // invalidConfigMapName case - { + "Invalid ConfigMap name": { Instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -285,12 +274,10 @@ spec: return i }(), ParameterAssignments: newFakeParameterAssignment(), - Err: true, - testDescription: "Invalid ConfigMap name", + wantError: ErrConfigMapNotFound, }, - // Invalid template path in ConfigMap name // validGetConfigMap3 case - { + "Invalid template path in ConfigMap name": { Instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -303,14 +290,12 @@ spec: return i }(), ParameterAssignments: newFakeParameterAssignment(), - Err: true, - testDescription: "Invalid template path in ConfigMap", + wantError: ErrTemplatePathNotFound, }, - // Invalid Trial template spec in ConfigMap + // invalidTemplate case // Trial template is a string in ConfigMap // Because of that, user can specify not valid unstructured template - // invalidTemplate case - { + "Invalid trial spec in ConfigMap": { Instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -323,22 +308,20 @@ spec: return i }(), ParameterAssignments: newFakeParameterAssignment(), - Err: true, - testDescription: "Invalid Trial spec in ConfigMap", + wantError: ErrConvertStringToUnstructuredFailed, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.Instance, "trial-name", "trial-namespace", tc.ParameterAssignments) - if tc.Err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.Err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := p.GetRunSpecWithHyperParameters(tc.Instance, "trial-name", "trial-namespace", tc.ParameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + }) } } diff --git a/pkg/controller.v1beta1/suggestion/composer/composer_test.go b/pkg/controller.v1beta1/suggestion/composer/composer_test.go index 9de963030fb..41b92b6b60f 100644 --- a/pkg/controller.v1beta1/suggestion/composer/composer_test.go +++ b/pkg/controller.v1beta1/suggestion/composer/composer_test.go @@ -22,12 +22,12 @@ import ( stdlog "log" "os" "path/filepath" - "reflect" "strings" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/onsi/gomega" "github.com/spf13/viper" appsv1 "k8s.io/api/apps/v1" @@ -631,8 +631,8 @@ func TestDesiredRBAC(t *testing.T) { func metaEqual(expected, actual metav1.ObjectMeta) bool { return expected.Name == actual.Name && expected.Namespace == actual.Namespace && - reflect.DeepEqual(expected.Labels, actual.Labels) && - reflect.DeepEqual(expected.Annotations, actual.Annotations) && + cmp.Equal(expected.Labels, actual.Labels) && + cmp.Equal(expected.Annotations, actual.Annotations) && (len(actual.OwnerReferences) > 0 && expected.OwnerReferences[0].APIVersion == actual.OwnerReferences[0].APIVersion && expected.OwnerReferences[0].Kind == actual.OwnerReferences[0].Kind && diff --git a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go index 1015e5ed157..d11a204bf3b 100644 --- a/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go +++ b/pkg/controller.v1beta1/suggestion/suggestionclient/suggestionclient_test.go @@ -19,11 +19,11 @@ package suggestionclient import ( "errors" "fmt" - "reflect" "testing" "time" "github.com/golang/mock/gomock" + "github.com/google/go-cmp/cmp" "github.com/onsi/gomega" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -578,8 +578,8 @@ func TestConvertTrialObservation(t *testing.T) { } for _, tc := range tcs { actualObservation := convertTrialObservation(tc.strategies, tc.inObservation) - if !reflect.DeepEqual(actualObservation, tc.expectedObservation) { - t.Errorf("Case: %v failed.\nExpected observation: %v \ngot: %v", tc.testDescription, tc.expectedObservation, actualObservation) + if diff := cmp.Diff(actualObservation, tc.expectedObservation); diff != "" { + t.Errorf("Case: %v failed. Diff (-want,+got):\n%s", tc.testDescription, diff) } } } diff --git a/pkg/controller.v1beta1/trial/util/job_util_test.go b/pkg/controller.v1beta1/trial/util/job_util_test.go index d2a018967c3..c1908144df9 100644 --- a/pkg/controller.v1beta1/trial/util/job_util_test.go +++ b/pkg/controller.v1beta1/trial/util/job_util_test.go @@ -17,7 +17,6 @@ limitations under the License. package util import ( - "reflect" "testing" batchv1 "k8s.io/api/batch/v1" @@ -25,6 +24,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" trialsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/trials/v1beta1" "github.com/kubeflow/katib/pkg/controller.v1beta1/util" ) @@ -39,14 +40,13 @@ func TestGetDeployedJobStatus(t *testing.T) { successCondition := "status.conditions.#(type==\"Complete\")#|#(status==\"True\")#" failureCondition := "status.conditions.#(type==\"Failed\")#|#(status==\"True\")#" - tcs := []struct { - trial *trialsv1beta1.Trial - deployedJob *unstructured.Unstructured - expectedTrialJobStatus *TrialJobStatus - err bool - testDescription string + cases := map[string]struct { + trial *trialsv1beta1.Trial + deployedJob *unstructured.Unstructured + wantTrialJobStatus *TrialJobStatus + wantError error }{ - { + "Job status is running": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -54,28 +54,24 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobRunning, } }(), - err: false, - testDescription: "Job status is running", }, - { + "Job status is succeeded, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is succeeded, reason and message must be returned", }, - { + "Job status is failed, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -83,41 +79,35 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobFailed, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is failed, reason and message must be returned", }, - { + "Job status is succeeded because status.succeeded = 1": { trial: newFakeTrial("status.[@this].#(succeeded==1)", failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, } }(), - err: false, - testDescription: "Job status is succeeded because status.succeeded = 1", }, } - for _, tc := range tcs { - actualTrialJobStatus, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) - - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedTrialJobStatus, actualTrialJobStatus) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedTrialJobStatus, actualTrialJobStatus) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetDeployedJobStatus() (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantTrialJobStatus, got); len(diff) != 0 { + t.Errorf("Unexpected trial job status from GetDeployedJobStatus() (-want,+got):\n%s", diff) } - } + }) } } @@ -154,6 +144,7 @@ func newFakeJob() *batchv1.Job { }, } } + func newFakeDeployedJob(job interface{}) *unstructured.Unstructured { jobUnstructured, _ := util.ConvertObjectToUnstructured(job) diff --git a/pkg/suggestion/v1beta1/goptuna/converter.go b/pkg/suggestion/v1beta1/goptuna/converter.go index b7865b5b307..6c338efe0a3 100644 --- a/pkg/suggestion/v1beta1/goptuna/converter.go +++ b/pkg/suggestion/v1beta1/goptuna/converter.go @@ -311,7 +311,7 @@ func toGoptunaParams( } ir := float64(p) internalParams[name] = ir - // externalParams[name] = p is prohibited because of reflect.DeepEqual() will be false + // externalParams[name] = p is prohibited because of cmp.Diff() will not be empty // at findGoptunaTrialIDByParam() function. externalParams[name] = d.ToExternalRepr(ir) case goptuna.StepIntUniformDistribution: @@ -321,7 +321,7 @@ func toGoptunaParams( } ir := float64(p) internalParams[name] = ir - // externalParams[name] = p is prohibited because of reflect.DeepEqual() will be false + // externalParams[name] = p is prohibited because of cmp.Diff() will not be empty // at findGoptunaTrialIDByParam() function. externalParams[name] = d.ToExternalRepr(ir) case goptuna.CategoricalDistribution: diff --git a/pkg/suggestion/v1beta1/goptuna/converter_test.go b/pkg/suggestion/v1beta1/goptuna/converter_test.go index 1e3189a773d..d92b0f391ca 100644 --- a/pkg/suggestion/v1beta1/goptuna/converter_test.go +++ b/pkg/suggestion/v1beta1/goptuna/converter_test.go @@ -17,48 +17,44 @@ limitations under the License. package suggestion_goptuna_v1beta1 import ( - "reflect" "testing" "github.com/c-bata/goptuna" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" api_v1_beta1 "github.com/kubeflow/katib/pkg/apis/manager/v1beta1" ) func Test_toGoptunaDirection(t *testing.T) { - for _, tt := range []struct { - name string + for name, tc := range map[string]struct { objectiveType api_v1_beta1.ObjectiveType - expected goptuna.StudyDirection + wantDirection goptuna.StudyDirection }{ - { - name: "minimize", + "minimize": { objectiveType: api_v1_beta1.ObjectiveType_MINIMIZE, - expected: goptuna.StudyDirectionMinimize, + wantDirection: goptuna.StudyDirectionMinimize, }, - { - name: "maximize", + "maximize": { objectiveType: api_v1_beta1.ObjectiveType_MAXIMIZE, - expected: goptuna.StudyDirectionMaximize, + wantDirection: goptuna.StudyDirectionMaximize, }, } { - t.Run(tt.name, func(t *testing.T) { - got := toGoptunaDirection(tt.objectiveType) - if got != tt.expected { - t.Errorf("toGoptunaDirection() got = %v, want %v", got, tt.expected) + t.Run(name, func(t *testing.T) { + got := toGoptunaDirection(tc.objectiveType) + if diff := cmp.Diff(tc.wantDirection, got); len(diff) != 0 { + t.Errorf("Unexpected direction from toGoptunaDirection (-want,+got):\n%s", diff) } }) } } func Test_toGoptunaSearchSpace(t *testing.T) { - tests := []struct { - name string - parameters []*api_v1_beta1.ParameterSpec - want map[string]interface{} - wantErr bool + cases := map[string]struct { + parameters []*api_v1_beta1.ParameterSpec + wantSearchSpace map[string]interface{} + wantError error }{ - { - name: "Double parameter type", + "Double parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -69,16 +65,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.UniformDistribution{ High: 5.5, Low: 1.5, }, }, - wantErr: false, }, - { - name: "Double parameter type with step", + "Double parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -90,17 +84,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.DiscreteUniformDistribution{ High: 5.5, Low: 1.5, Q: 0.5, }, }, - wantErr: false, }, - { - name: "Int parameter type", + "Int parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -111,16 +103,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.IntUniformDistribution{ High: 5, Low: 1, }, }, - wantErr: false, }, - { - name: "Int parameter type with step", + "Int parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -132,17 +122,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.StepIntUniformDistribution{ High: 5, Low: 1, Step: 2, }, }, - wantErr: false, }, - { - name: "Discrete parameter type", + "Discrete parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-discrete", @@ -152,15 +140,13 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-discrete": goptuna.CategoricalDistribution{ Choices: []string{"3", "2", "6"}, }, }, - wantErr: false, }, - { - name: "Categorical parameter type", + "Categorical parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-categorical", @@ -170,23 +156,21 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-categorical": goptuna.CategoricalDistribution{ Choices: []string{"cat1", "cat2", "cat3"}, }, }, - wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := toGoptunaSearchSpace(tt.parameters) - if (err != nil) != tt.wantErr { - t.Errorf("toGoptunaSearchSpace() error = %v, wantErr %v", err, tt.wantErr) - return + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := toGoptunaSearchSpace(tc.parameters) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from toGoptunaSearchSpace (-want,+got):\n%s", diff) } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("toGoptunaSearchSpace() got = %v, want %v", got, tt.want) + if diff := cmp.Diff(tc.wantSearchSpace, got); len(diff) != 0 { + t.Errorf("Unexpected search space from toGoptunaSearchSpace (-want,+got):\n%s", diff) } }) } diff --git a/pkg/suggestion/v1beta1/goptuna/sample.go b/pkg/suggestion/v1beta1/goptuna/sample.go index 8e484eb8847..08f26deab9e 100644 --- a/pkg/suggestion/v1beta1/goptuna/sample.go +++ b/pkg/suggestion/v1beta1/goptuna/sample.go @@ -18,10 +18,10 @@ package suggestion_goptuna_v1beta1 import ( "fmt" - "reflect" "strconv" "github.com/c-bata/goptuna" + "github.com/google/go-cmp/cmp" api_v1_beta1 "github.com/kubeflow/katib/pkg/apis/manager/v1beta1" ) @@ -119,9 +119,9 @@ func findGoptunaTrialIDByParam(study *goptuna.Study, trialMapping map[string]int continue } - if reflect.DeepEqual(ktrial.Params, trials[i].Params) { + if diff := cmp.Diff(ktrial.Params, trials[i].Params); diff == "" { return trials[i].ID, nil } } - return -1, fmt.Errorf("Same parameter is not found for Trial: %v", ktrial) + return -1, fmt.Errorf("same parameter is not found for Trial: %v", ktrial) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go index f48b4b96a07..2f393dc6763 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "path/filepath" "strconv" @@ -47,6 +48,13 @@ import ( var log = logf.Log.WithName("injector-webhook") +var ( + ErrInvalidOwnerAPIVersion = errors.New("invalid owner API version") + ErrInvaidSuggestionName = errors.New("invalid suggestion name") + ErrPodNotBelongToKatibJob = errors.New("pod does not belong to Katib Job") + ErrNestedObjectNotFound = errors.New("nested object not found in the cluster") +) + // SidecarInjector injects metrics collect sidecar to the primary pod. type SidecarInjector struct { client client.Client @@ -259,7 +267,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Get group and version from owner API version gv, err := schema.ParseGroupVersion(owners[i].APIVersion) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %s", ErrInvalidOwnerAPIVersion, err.Error()) } gvk := schema.GroupVersionKind{ Group: gv.Group, @@ -272,7 +280,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Nested object namespace must be equal to object namespace err = s.client.Get(context.TODO(), apitypes.NamespacedName{Name: owners[i].Name, Namespace: namespace}, nestedJob) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %s", ErrNestedObjectNotFound, err.Error()) } // Recursively search for Trial ownership in nested object jobKind, jobName, err = s.getKatibJob(nestedJob, namespace) @@ -285,7 +293,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // If jobKind is empty after the loop, Trial doesn't own the object if jobKind == "" { - return "", "", errors.New("The Pod doesn't belong to Katib Job") + return "", "", ErrPodNotBelongToKatibJob } return jobKind, jobName, nil @@ -322,7 +330,7 @@ func (s *SidecarInjector) getMetricsCollectorArgs(trial *trialsv1beta1.Trial, me suggestion := &suggestionsv1beta1.Suggestion{} err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: suggestionName, Namespace: trial.Namespace}, suggestion) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %s", ErrInvaidSuggestionName, err.Error()) } args = append(args, "-s-earlystop", util.GetEarlyStoppingEndpoint(suggestion)) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go index cf46d331768..426b401a793 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go @@ -20,11 +20,12 @@ import ( "context" "fmt" "path/filepath" - "reflect" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "github.com/onsi/gomega" appsv1 "k8s.io/api/apps/v1" batchv1 "k8s.io/api/batch/v1" @@ -75,16 +76,15 @@ func TestWrapWorkerContainer(t *testing.T) { metricsFile := "metric.log" - testCases := []struct { - Trial *trialsv1beta1.Trial - Pod *v1.Pod - MetricsFile string - PathKind common.FileSystemKind - ExpectedPod *v1.Pod - Err bool - TestDescription string + cases := map[string]struct { + Trial *trialsv1beta1.Trial + Pod *v1.Pod + MetricsFile string + PathKind common.FileSystemKind + WantPod *v1.Pod + WantError error }{ - { + "Tensorflow container without sh -c": { Trial: trial, Pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -100,7 +100,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, MetricsFile: metricsFile, PathKind: common.FileKind, - ExpectedPod: &v1.Pod{ + WantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -115,10 +115,8 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - Err: false, - TestDescription: "Tensorflow container without sh -c", }, - { + "Tensorflow container with sh -c": { Trial: trial, Pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -135,7 +133,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, MetricsFile: metricsFile, PathKind: common.FileKind, - ExpectedPod: &v1.Pod{ + WantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -150,10 +148,8 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - Err: false, - TestDescription: "Tensorflow container with sh -c", }, - { + "Training pod doesn't have primary container": { Trial: trial, Pod: &v1.Pod{ Spec: v1.PodSpec{ @@ -164,11 +160,10 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - PathKind: common.FileKind, - Err: true, - TestDescription: "Training pod doesn't have primary container", + PathKind: common.FileKind, + WantError: ErrPrimaryContainerNotFound, }, - { + "Container with early stopping command": { Trial: func() *trialsv1beta1.Trial { t := trial.DeepCopy() t.Spec.EarlyStoppingRules = []common.EarlyStoppingRule{ @@ -194,7 +189,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, MetricsFile: metricsFile, PathKind: common.FileKind, - ExpectedPod: &v1.Pod{ + WantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -213,23 +208,19 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - Err: false, - TestDescription: "Container with early stopping command", }, } - for _, c := range testCases { - err := wrapWorkerContainer(c.Trial, c.Pod, c.Trial.Namespace, c.MetricsFile, c.PathKind) - if c.Err && err == nil { - t.Errorf("Case %s failed. Expected error, got nil", c.TestDescription) - } else if !c.Err { - if err != nil { - t.Errorf("Case %s failed. Expected nil, got error: %v", c.TestDescription, err) - } else if !equality.Semantic.DeepEqual(c.Pod.Spec.Containers, c.ExpectedPod.Spec.Containers) { - t.Errorf("Case %s failed. Expected pod: %v, got: %v", - c.TestDescription, c.ExpectedPod.Spec.Containers, c.Pod.Spec.Containers) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + err := wrapWorkerContainer(tc.Trial, tc.Pod, tc.Trial.Namespace, tc.MetricsFile, tc.PathKind) + if diff := cmp.Diff(tc.WantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from wrapWorkerContainer (-want,+got):\n%s", diff) } - } + if err == nil && !equality.Semantic.DeepEqual(tc.Pod.Spec.Containers, tc.WantPod.Spec.Containers) { + t.Errorf("Unexpected error from wrapWorkerContainer, expected pod: %v, got: %v", tc.WantPod.Spec.Containers, tc.Pod.Spec.Containers) + } + }) } } @@ -317,17 +308,16 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, } - testCases := []struct { + cases := map[string]struct { Trial *trialsv1beta1.Trial MetricNames string MCSpec common.MetricsCollectorSpec EarlyStoppingRules []string KatibConfig configv1beta1.MetricsCollectorConfig - ExpectedArgs []string - Name string - Err bool + WantArgs []string + WantError error }{ - { + "StdOut MC": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -338,7 +328,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { KatibConfig: configv1beta1.MetricsCollectorConfig{ WaitAllProcesses: &waitAllProcessesValue, }, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -347,9 +337,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-format", string(common.TextFormat), "-w", "false", }, - Name: "StdOut MC", }, - { + "File MC with Filter": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -370,7 +359,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -379,9 +368,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-f", "{mn1: ([a-b]), mv1: [0-9]};{mn2: ([a-b]), mv2: ([0-9])}", "-format", string(common.TextFormat), }, - Name: "File MC with Filter", }, - { + "File MC with Json Format": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -396,7 +384,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -404,9 +392,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-path", testPath, "-format", string(common.JsonFormat), }, - Name: "File MC with Json Format", }, - { + "Tf Event MC": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -420,16 +407,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, "-path", testPath, }, - Name: "Tf Event MC", }, - { + "Custom MC without Path": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -438,15 +424,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, }, - Name: "Custom MC without Path", }, - { + "Custom MC with Path": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -460,16 +445,15 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, "-path", testPath, }, - Name: "Custom MC with Path", }, - { + "Prometheus MC without Path": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -478,15 +462,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), "-s-db", katibDBAddress, }, - Name: "Prometheus MC without Path", }, - { + "Trial with EarlyStopping rules": { Trial: testTrial, MetricNames: testMetricName, MCSpec: common.MetricsCollectorSpec{ @@ -496,7 +479,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, EarlyStoppingRules: earlyStoppingRules, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - ExpectedArgs: []string{ + WantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -507,9 +490,8 @@ func TestGetMetricsCollectorArgs(t *testing.T) { "-stop-rule", earlyStoppingRules[1], "-s-earlystop", katibEarlyStopAddress, }, - Name: "Trial with EarlyStopping rules", }, - { + "Trial with invalid Experiment label name. Suggestion is not created": { Trial: func() *trialsv1beta1.Trial { trial := testTrial.DeepCopy() trial.ObjectMeta.Labels[consts.LabelExperimentName] = "invalid-name" @@ -522,8 +504,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, EarlyStoppingRules: earlyStoppingRules, KatibConfig: configv1beta1.MetricsCollectorConfig{}, - Name: "Trial with invalid Experiment label name. Suggestion is not created", - Err: true, + WantError: ErrInvaidSuggestionName, }, } @@ -534,16 +515,16 @@ func TestGetMetricsCollectorArgs(t *testing.T) { return c.Get(context.TODO(), types.NamespacedName{Namespace: testNamespace, Name: testSuggestionName}, testSuggestion) }, timeout).ShouldNot(gomega.HaveOccurred()) - for _, tc := range testCases { - args, err := si.getMetricsCollectorArgs(tc.Trial, tc.MetricNames, tc.MCSpec, tc.KatibConfig, tc.EarlyStoppingRules) - - if !tc.Err && err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.Name, err) - } else if tc.Err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.Name) - } else if !tc.Err && !reflect.DeepEqual(tc.ExpectedArgs, args) { - t.Errorf("Case %v failed. ExpectedArgs: %v, got %v", tc.Name, tc.ExpectedArgs, args) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := si.getMetricsCollectorArgs(tc.Trial, tc.MetricNames, tc.MCSpec, tc.KatibConfig, tc.EarlyStoppingRules) + if diff := cmp.Diff(tc.WantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getMetricsCollectorArgs (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.WantArgs, got); len(diff) != 0 { + t.Errorf("Unexpected args from getMetricsCollectorArgs (-want,+got):\n%s", diff) + } + }) } } @@ -581,13 +562,13 @@ func TestNeedWrapWorkerContainer(t *testing.T) { func TestMutateMetricsCollectorVolume(t *testing.T) { tc := struct { Pod v1.Pod - ExpectedPod v1.Pod + WantPod v1.Pod JobKind string MountPath string SidecarContainerName string PrimaryContainerName string PathKind common.FileSystemKind - Err bool + WantError error }{ Pod: v1.Pod{ Spec: v1.PodSpec{ @@ -604,7 +585,7 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { }, }, }, - ExpectedPod: v1.Pod{ + WantPod: v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -651,32 +632,33 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { tc.SidecarContainerName, tc.PrimaryContainerName, tc.PathKind) - if err != nil { - t.Errorf("mutateMetricsCollectorVolume failed: %v", err) - } else if !equality.Semantic.DeepEqual(tc.Pod, tc.ExpectedPod) { - t.Errorf("Expected pod %v, got %v", tc.ExpectedPod, tc.Pod) + if diff := cmp.Diff(tc.WantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from mutateMetricsCollectorVolume (-want,+got):\n%s", diff) + } + if err == nil && !equality.Semantic.DeepEqual(tc.Pod, tc.WantPod) { + t.Errorf("Unexpected error from mutateMetricsCollectorVolume, expected pod: %v, got: %v", tc.WantPod, tc.Pod) } } func TestGetSidecarContainerName(t *testing.T) { - testCases := []struct { - CollectorKind common.CollectorKind - ExpectedCollectorKind string + cases := []struct { + CollectorKind common.CollectorKind + WantCollectorKind string }{ { - CollectorKind: common.StdOutCollector, - ExpectedCollectorKind: mccommon.MetricLoggerCollectorContainerName, + CollectorKind: common.StdOutCollector, + WantCollectorKind: mccommon.MetricLoggerCollectorContainerName, }, { - CollectorKind: common.TfEventCollector, - ExpectedCollectorKind: mccommon.MetricCollectorContainerName, + CollectorKind: common.TfEventCollector, + WantCollectorKind: mccommon.MetricCollectorContainerName, }, } - for _, tc := range testCases { + for _, tc := range cases { collectorKind := getSidecarContainerName(tc.CollectorKind) - if collectorKind != tc.ExpectedCollectorKind { - t.Errorf("Expected Collector Kind: %v, got %v", tc.ExpectedCollectorKind, collectorKind) + if collectorKind != tc.WantCollectorKind { + t.Errorf("Expected Collector Kind: %v, got %v", tc.WantCollectorKind, collectorKind) } } } @@ -719,16 +701,15 @@ func TestGetKatibJob(t *testing.T) { deployName := "deploy-name" jobName := "job-name" - testCases := []struct { - Pod *v1.Pod - Job *batchv1.Job - Deployment *appsv1.Deployment - ExpectedJobKind string - ExpectedJobName string - Err bool - TestDescription string + cases := map[string]struct { + Pod *v1.Pod + Job *batchv1.Job + Deployment *appsv1.Deployment + WantJobKind string + WantJobName string + WantError error }{ - { + "Valid run with ownership sequence: Trial -> Job -> Pod": { Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -769,12 +750,10 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - ExpectedJobKind: "Job", - ExpectedJobName: jobName + "-1", - Err: false, - TestDescription: "Valid run with ownership sequence: Trial -> Job -> Pod", + WantJobKind: "Job", + WantJobName: jobName + "-1", }, - { + "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod": { Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -848,12 +827,10 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - ExpectedJobKind: "Deployment", - ExpectedJobName: deployName + "-2", - Err: false, - TestDescription: "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod", + WantJobKind: "Deployment", + WantJobName: deployName + "-2", }, - { + "Run for not Trial's pod with ownership sequence: Job -> Pod": { Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -886,10 +863,9 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - Err: true, - TestDescription: "Run for not Trial's pod with ownership sequence: Job -> Pod", + WantError: ErrPodNotBelongToKatibJob, }, - { + "Run when Pod owns Job that doesn't exists": { Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -903,10 +879,9 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - Err: true, - TestDescription: "Run when Pod owns Job that doesn't exists", + WantError: ErrNestedObjectNotFound, }, - { + "Run when Pod owns Job with invalid API version": { Pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Name: podName, @@ -920,64 +895,63 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - Err: true, - TestDescription: "Run when Pod owns Job with invalid API version", + WantError: ErrInvalidOwnerAPIVersion, }, } - for _, tc := range testCases { - // Create Job if it is needed - if tc.Job != nil { - jobUnstr, err := util.ConvertObjectToUnstructured(tc.Job) - gvk := schema.GroupVersionKind{ - Group: "batch", - Version: "v1", - Kind: "Job", - } - jobUnstr.SetGroupVersionKind(gvk) - if err != nil { - t.Errorf("ConvertObjectToUnstructured error %v", err) - } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + // Create Job if it is needed + if tc.Job != nil { + jobUnstr, err := util.ConvertObjectToUnstructured(tc.Job) + gvk := schema.GroupVersionKind{ + Group: "batch", + Version: "v1", + Kind: "Job", + } + jobUnstr.SetGroupVersionKind(gvk) + if err != nil { + t.Errorf("ConvertObjectToUnstructured error %v", err) + } - g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred()) + g.Expect(c.Create(context.TODO(), jobUnstr)).NotTo(gomega.HaveOccurred()) - // Wait that Job is created - g.Eventually(func() error { - return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.Job.Name}, jobUnstr) - }, timeout).ShouldNot(gomega.HaveOccurred()) - } + // Wait that Job is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.Job.Name}, jobUnstr) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } - // Create Deployment if it is needed - if tc.Deployment != nil { - g.Expect(c.Create(context.TODO(), tc.Deployment)).NotTo(gomega.HaveOccurred()) + // Create Deployment if it is needed + if tc.Deployment != nil { + g.Expect(c.Create(context.TODO(), tc.Deployment)).NotTo(gomega.HaveOccurred()) - // Wait that Deployment is created - g.Eventually(func() error { - return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.Deployment.Name}, tc.Deployment) - }, timeout).ShouldNot(gomega.HaveOccurred()) - } + // Wait that Deployment is created + g.Eventually(func() error { + return c.Get(context.TODO(), types.NamespacedName{Namespace: namespace, Name: tc.Deployment.Name}, tc.Deployment) + }, timeout).ShouldNot(gomega.HaveOccurred()) + } - object, _ := util.ConvertObjectToUnstructured(tc.Pod) - jobKind, jobName, err := si.getKatibJob(object, namespace) - if !tc.Err && err != nil { - t.Errorf("Case %v failed. Error %v", tc.TestDescription, err) - } else if !tc.Err && (tc.ExpectedJobKind != jobKind || tc.ExpectedJobName != jobName) { - t.Errorf("Case %v failed. Expected jobKind %v, got %v, Expected jobName %v, got %v", - tc.TestDescription, tc.ExpectedJobKind, jobKind, tc.ExpectedJobName, jobName) - } else if tc.Err && err == nil { - t.Errorf("Expected error got nil") - } + object, _ := util.ConvertObjectToUnstructured(tc.Pod) + jobKind, jobName, err := si.getKatibJob(object, namespace) + if diff := cmp.Diff(tc.WantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getKatibJob (-want,+got):\n%s", diff) + } + if tc.WantError == nil && (tc.WantJobKind != jobKind || tc.WantJobName != jobName) { + t.Errorf("Unexpected error from getKatibJob, expected jobKind %v, got %v, expected jobName %v, got %v", + tc.WantJobKind, jobKind, tc.WantJobName, jobName) + } + }) } } func TestIsPrimaryPod(t *testing.T) { - testCases := []struct { + cases := map[string]struct { podLabels map[string]string primaryPodLabels map[string]string isPrimary bool - testDescription string }{ - { + "Pod contains all labels from primary pod labels": { podLabels: map[string]string{ "test-key-1": "test-value-1", "test-key-2": "test-value-2", @@ -987,10 +961,9 @@ func TestIsPrimaryPod(t *testing.T) { "test-key-1": "test-value-1", "test-key-2": "test-value-2", }, - isPrimary: true, - testDescription: "Pod contains all labels from primary pod labels", + isPrimary: true, }, - { + "Pod doesn't contain primary label": { podLabels: map[string]string{ "test-key-1": "test-value-1", }, @@ -998,25 +971,23 @@ func TestIsPrimaryPod(t *testing.T) { "test-key-1": "test-value-1", "test-key-2": "test-value-2", }, - isPrimary: false, - testDescription: "Pod doesn't contain primary label", + isPrimary: false, }, - { + "Pod contains label with incorrect value": { podLabels: map[string]string{ "test-key-1": "invalid", }, primaryPodLabels: map[string]string{ "test-key-1": "test-value-1", }, - isPrimary: false, - testDescription: "Pod contains label with incorrect value", + isPrimary: false, }, } - for _, tc := range testCases { + for name, tc := range cases { isPrimary := isPrimaryPod(tc.podLabels, tc.primaryPodLabels) if isPrimary != tc.isPrimary { - t.Errorf("Case %v. Expected isPrimary %v, got %v", tc.testDescription, tc.isPrimary, isPrimary) + t.Errorf("Case %v. Expected isPrimary %v, got %v", name, tc.isPrimary, isPrimary) } } } @@ -1028,13 +999,12 @@ func TestMutatePodMetadata(t *testing.T) { consts.LabelTrialName: "test-trial", } - testCases := []struct { - pod *v1.Pod - trial *trialsv1beta1.Trial - mutatedPod *v1.Pod - testDescription string + cases := map[string]struct { + pod *v1.Pod + trial *trialsv1beta1.Trial + mutatedPod *v1.Pod }{ - { + "Mutated Pod should contain label from the origin Pod and Trial": { pod: &v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Labels: map[string]string{ @@ -1055,14 +1025,13 @@ func TestMutatePodMetadata(t *testing.T) { Labels: mutatedPodLabels, }, }, - testDescription: "Mutated Pod should contain label from the origin Pod and Trial", }, } - for _, tc := range testCases { + for _, tc := range cases { mutatePodMetadata(tc.pod, tc.trial) - if !reflect.DeepEqual(tc.mutatedPod, tc.pod) { - t.Errorf("Case %v. Expected Pod %v, got %v", tc.testDescription, tc.mutatedPod, tc.pod) + if diff := cmp.Diff(tc.mutatedPod, tc.pod); len(diff) != 0 { + t.Errorf("Unexpected pod from mutatePodMetadata (-want,+got):\n%s", diff) } } } diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go index 6381bf6d895..9349dea8d4c 100644 --- a/pkg/webhook/v1beta1/pod/utils.go +++ b/pkg/webhook/v1beta1/pod/utils.go @@ -18,6 +18,7 @@ package pod import ( "context" + "errors" "fmt" "path/filepath" "strings" @@ -35,6 +36,8 @@ import ( mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" ) +var ErrPrimaryContainerNotFound = errors.New("unable to find primary container in mutated pod containers") + func isPrimaryPod(podLabels, primaryLabels map[string]string) bool { for primaryKey, primaryValue := range primaryLabels { @@ -190,8 +193,7 @@ func wrapWorkerContainer(trial *trialsv1beta1.Trial, pod *v1.Pod, namespace, c.Command = command c.Args = []string{argsStr} } else { - return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v", - trial.Spec.PrimaryContainerName, pod.Spec.Containers) + return fmt.Errorf("%w: primary container: %v, mutated pod containers: %v", ErrPrimaryContainerNotFound, trial.Spec.PrimaryContainerName, pod.Spec.Containers) } return nil }