diff --git a/pkg/controller.v1beta1/experiment/manifest/generator.go b/pkg/controller.v1beta1/experiment/manifest/generator.go index a41cc2f0c89..c3d0a2a14bc 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator.go @@ -17,6 +17,7 @@ limitations under the License. package manifest import ( + "errors" "fmt" "regexp" "strings" @@ -33,6 +34,15 @@ import ( "github.com/kubeflow/katib/pkg/util/v1beta1/katibconfig" ) +var ( + errConfigMapNotFound = errors.New("configMap not found") + errConvertStringToUnstructuredFailed = errors.New("failed to convert string to unstructured") + errConvertUnstructuredToStringFailed = errors.New("failed to convert unstructured to string") + errParamNotFoundInParameterAssignment = errors.New("unable to find non-meta parameter from TrialParameters in ParameterAssignment") + errParamNotFoundInTrialParameters = errors.New("unable to find parameter from ParameterAssignment in TrialParameters") + errTrialTemplateNotFound = errors.New("unable to find trial template in ConfigMap") +) + // Generator is the type for manifests Generator. type Generator interface { InjectClient(c client.Client) @@ -86,7 +96,7 @@ func (g *DefaultGenerator) GetRunSpecWithHyperParameters(experiment *experiments // Convert Trial template to unstructured runSpec, err := util.ConvertStringToUnstructured(replacedTemplate) if err != nil { - return nil, fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return nil, fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err) } // Set name and namespace for Run Spec @@ -108,7 +118,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi if trialSpec == nil { trialSpec, err = util.ConvertStringToUnstructured(trialTemplate) if err != nil { - return "", fmt.Errorf("ConvertStringToUnstructured failed: %v", err) + return "", fmt.Errorf("%w: %w", errConvertStringToUnstructuredFailed, err) } } @@ -131,7 +141,7 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi nonMetaParamCount += 1 continue } else { - return "", fmt.Errorf("Unable to find parameter: %v in parameter assignment %v", param.Reference, assignmentsMap) + return "", fmt.Errorf("%w: parameter: %v, parameter assignment: %v", errParamNotFoundInParameterAssignment, param.Reference, assignmentsMap) } } metaRefKey = sub[1] @@ -172,9 +182,10 @@ func (g *DefaultGenerator) applyParameters(experiment *experimentsv1beta1.Experi } } - // Number of parameters must be equal + // Number of assignment parameters must be equal to the number of non-meta trial parameters + // i.e. all parameters in ParameterAssignment must be in TrialParameters if len(assignments) != nonMetaParamCount { - return "", fmt.Errorf("Number of TrialAssignment: %v != number of nonMetaTrialParameters in TrialSpec: %v", len(assignments), nonMetaParamCount) + return "", fmt.Errorf("%w: parameter assignments: %v, non-meta trial parameter count: %v", errParamNotFoundInTrialParameters, assignments, nonMetaParamCount) } // Replacing placeholders with parameter values @@ -194,7 +205,7 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim if trialSource.TrialSpec != nil { trialTemplateString, err = util.ConvertUnstructuredToString(trialSource.TrialSpec) if err != nil { - return "", fmt.Errorf("ConvertUnstructuredToString failed: %v", err) + return "", fmt.Errorf("%w: %w", errConvertUnstructuredToStringFailed, err) } } else { configMapNS := trialSource.ConfigMap.ConfigMapNamespace @@ -202,12 +213,12 @@ func (g *DefaultGenerator) GetTrialTemplate(instance *experimentsv1beta1.Experim templatePath := trialSource.ConfigMap.TemplatePath configMap, err := g.client.GetConfigMap(configMapName, configMapNS) if err != nil { - return "", fmt.Errorf("GetConfigMap failed: %v", err) + return "", fmt.Errorf("%w: %w", errConfigMapNotFound, err) } var ok bool trialTemplateString, ok = configMap[templatePath] if !ok { - return "", fmt.Errorf("TemplatePath: %v not found in configMap: %v", templatePath, configMap) + return "", fmt.Errorf("%w: TemplatePath: %v, ConfigMap: %v", errTrialTemplateNotFound, templatePath, configMap) } } diff --git a/pkg/controller.v1beta1/experiment/manifest/generator_test.go b/pkg/controller.v1beta1/experiment/manifest/generator_test.go index dabd2631063..57e2789a2ac 100644 --- a/pkg/controller.v1beta1/experiment/manifest/generator_test.go +++ b/pkg/controller.v1beta1/experiment/manifest/generator_test.go @@ -17,11 +17,11 @@ limitations under the License. package manifest import ( - "errors" "math" - "reflect" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "go.uber.org/mock/gomock" batchv1 "k8s.io/api/batch/v1" v1 "k8s.io/api/core/v1" @@ -88,23 +88,18 @@ func TestGetRunSpecWithHP(t *testing.T) { t.Errorf("ConvertObjectToUnstructured failed: %v", err) } - tcs := []struct { - instance *experimentsv1beta1.Experiment - parameterAssignments []commonapiv1beta1.ParameterAssignment - expectedRunSpec *unstructured.Unstructured - err bool - testDescription string + cases := map[string]struct { + instance *experimentsv1beta1.Experiment + parameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run - { - instance: newFakeInstance(), - parameterAssignments: newFakeParameterAssignment(), - expectedRunSpec: expectedRunSpec, - err: false, - testDescription: "Run with valid parameters", + "Run with valid parameters": { + instance: newFakeInstance(), + parameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid JSON in unstructured - { + "Invalid JSON in Unstructured Trial template": { instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() trialSpec := i.Spec.TrialTemplate.TrialSource.TrialSpec @@ -114,48 +109,45 @@ func TestGetRunSpecWithHP(t *testing.T) { return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid JSON in Trial template", + wantError: errConvertUnstructuredToStringFailed, }, - // len(parameterAssignment) != len(trialParameters) - { + "Non-meta parameter from TrialParameters not found in ParameterAssignment": { instance: newFakeInstance(), parameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa = pa[1:] + pa[0] = commonapiv1beta1.ParameterAssignment{ + Name: "invalid-name", + Value: "invalid-value", + } return pa }(), - err: true, - testDescription: "Number of parameter assignments is not equal to number of Trial parameters", + wantError: errParamNotFoundInParameterAssignment, }, - // Parameter from assignments not found in Trial parameters - { + // case in which the lengths of trial parameters and parameter assignments are different + "Parameter from ParameterAssignment not found in TrialParameters": { instance: newFakeInstance(), parameterAssignments: func() []commonapiv1beta1.ParameterAssignment { pa := newFakeParameterAssignment() - pa[0] = commonapiv1beta1.ParameterAssignment{ - Name: "invalid-name", - Value: "invalid-value", - } + pa = append(pa, commonapiv1beta1.ParameterAssignment{ + Name: "extra-name", + Value: "extra-value", + }) return pa }(), - err: true, - testDescription: "Trial parameters don't have parameter from assignments", + wantError: errParamNotFoundInTrialParameters, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) - - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + }) } } @@ -204,25 +196,6 @@ spec: - --momentum=${trialParameters.momentum} - --invalidParameter={'num_layers': 2, 'input_sizes': [32, 32, 3]}` - validGetConfigMap1 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - map[string]string{templatePath: trialSpec}, nil) - - invalidConfigMapName := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - nil, errors.New("Unable to get ConfigMap")) - - validGetConfigMap3 := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - map[string]string{templatePath: trialSpec}, nil) - - invalidTemplate := c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( - map[string]string{templatePath: invalidTrialSpec}, nil) - - gomock.InOrder( - validGetConfigMap1, - invalidConfigMapName, - validGetConfigMap3, - invalidTemplate, - ) - // We can't compare structures, because in ConfigMap trialSpec is a string and creationTimestamp was not added expectedStr := `apiVersion: batch/v1 kind: Job @@ -244,19 +217,23 @@ spec: - "--momentum=0.9"` expectedRunSpec, err := util.ConvertStringToUnstructured(expectedStr) - if err != nil { - t.Errorf("ConvertStringToUnstructured failed: %v", err) + if diff := cmp.Diff(nil, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("ConvertStringToUnstructured failed (-want,+got):\n%s", diff) } - tcs := []struct { - instance *experimentsv1beta1.Experiment - parameterAssignments []commonapiv1beta1.ParameterAssignment - err bool - testDescription string + cases := map[string]struct { + mockConfigMapGetter func() *gomock.Call + instance *experimentsv1beta1.Experiment + parameterAssignments []commonapiv1beta1.ParameterAssignment + wantRunSpecWithHyperParameters *unstructured.Unstructured + wantError error }{ - // Valid run - // validGetConfigMap1 case - { + "Run with valid parameters": { + mockConfigMapGetter: func() *gomock.Call { + return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( + map[string]string{templatePath: trialSpec}, nil, + ) + }, instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -268,13 +245,15 @@ spec: } return i }(), - parameterAssignments: newFakeParameterAssignment(), - err: false, - testDescription: "Run with valid parameters", + parameterAssignments: newFakeParameterAssignment(), + wantRunSpecWithHyperParameters: expectedRunSpec, }, - // Invalid ConfigMap name - // invalidConfigMapName case - { + "Invalid ConfigMap name": { + mockConfigMapGetter: func() *gomock.Call { + return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( + nil, errConfigMapNotFound, + ) + }, instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -285,12 +264,14 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid ConfigMap name", + wantError: errConfigMapNotFound, }, - // Invalid template path in ConfigMap name - // validGetConfigMap3 case - { + "Invalid template path in ConfigMap name": { + mockConfigMapGetter: func() *gomock.Call { + return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( + map[string]string{templatePath: trialSpec}, nil, + ) + }, instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -303,14 +284,16 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid template path in ConfigMap", + wantError: errTrialTemplateNotFound, }, - // Invalid Trial template spec in ConfigMap // Trial template is a string in ConfigMap // Because of that, user can specify not valid unstructured template - // invalidTemplate case - { + "Invalid trial spec in ConfigMap": { + mockConfigMapGetter: func() *gomock.Call { + return c.EXPECT().GetConfigMap(gomock.Any(), gomock.Any()).Return( + map[string]string{templatePath: invalidTrialSpec}, nil, + ) + }, instance: func() *experimentsv1beta1.Experiment { i := newFakeInstance() i.Spec.TrialTemplate.TrialSource = experimentsv1beta1.TrialSource{ @@ -323,22 +306,21 @@ spec: return i }(), parameterAssignments: newFakeParameterAssignment(), - err: true, - testDescription: "Invalid Trial spec in ConfigMap", + wantError: errConvertStringToUnstructuredFailed, }, } - for _, tc := range tcs { - actualRunSpec, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(expectedRunSpec, actualRunSpec) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, expectedRunSpec.Object, actualRunSpec.Object) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + tc.mockConfigMapGetter() + got, err := p.GetRunSpecWithHyperParameters(tc.instance, "trial-name", "trial-namespace", tc.parameterAssignments) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantRunSpecWithHyperParameters, got); len(diff) != 0 { + t.Errorf("Unexpected run spec from GetRunSpecWithHyperParameters (-want,+got):\n%s", diff) } - } + }) } } diff --git a/pkg/controller.v1beta1/trial/util/job_util_test.go b/pkg/controller.v1beta1/trial/util/job_util_test.go index d2a018967c3..c1908144df9 100644 --- a/pkg/controller.v1beta1/trial/util/job_util_test.go +++ b/pkg/controller.v1beta1/trial/util/job_util_test.go @@ -17,7 +17,6 @@ limitations under the License. package util import ( - "reflect" "testing" batchv1 "k8s.io/api/batch/v1" @@ -25,6 +24,8 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" trialsv1beta1 "github.com/kubeflow/katib/pkg/apis/controller/trials/v1beta1" "github.com/kubeflow/katib/pkg/controller.v1beta1/util" ) @@ -39,14 +40,13 @@ func TestGetDeployedJobStatus(t *testing.T) { successCondition := "status.conditions.#(type==\"Complete\")#|#(status==\"True\")#" failureCondition := "status.conditions.#(type==\"Failed\")#|#(status==\"True\")#" - tcs := []struct { - trial *trialsv1beta1.Trial - deployedJob *unstructured.Unstructured - expectedTrialJobStatus *TrialJobStatus - err bool - testDescription string + cases := map[string]struct { + trial *trialsv1beta1.Trial + deployedJob *unstructured.Unstructured + wantTrialJobStatus *TrialJobStatus + wantError error }{ - { + "Job status is running": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -54,28 +54,24 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobRunning, } }(), - err: false, - testDescription: "Job status is running", }, - { + "Job status is succeeded, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is succeeded, reason and message must be returned", }, - { + "Job status is failed, reason and message must be returned": { trial: newFakeTrial(successCondition, failureCondition), deployedJob: func() *unstructured.Unstructured { job := newFakeJob() @@ -83,41 +79,35 @@ func TestGetDeployedJobStatus(t *testing.T) { job.Status.Conditions[1].Status = corev1.ConditionFalse return newFakeDeployedJob(job) }(), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobFailed, Message: testMessage, Reason: testReason, } }(), - err: false, - testDescription: "Job status is failed, reason and message must be returned", }, - { + "Job status is succeeded because status.succeeded = 1": { trial: newFakeTrial("status.[@this].#(succeeded==1)", failureCondition), deployedJob: newFakeDeployedJob(newFakeJob()), - expectedTrialJobStatus: func() *TrialJobStatus { + wantTrialJobStatus: func() *TrialJobStatus { return &TrialJobStatus{ Condition: JobSucceeded, } }(), - err: false, - testDescription: "Job status is succeeded because status.succeeded = 1", }, } - for _, tc := range tcs { - actualTrialJobStatus, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) - - if tc.err && err == nil { - t.Errorf("Case: %v failed. Expected err, got nil", tc.testDescription) - } else if !tc.err { - if err != nil { - t.Errorf("Case: %v failed. Expected nil, got %v", tc.testDescription, err) - } else if !reflect.DeepEqual(tc.expectedTrialJobStatus, actualTrialJobStatus) { - t.Errorf("Case: %v failed. Expected %v\n got %v", tc.testDescription, tc.expectedTrialJobStatus, actualTrialJobStatus) + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := GetDeployedJobStatus(tc.trial, tc.deployedJob) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from GetDeployedJobStatus() (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantTrialJobStatus, got); len(diff) != 0 { + t.Errorf("Unexpected trial job status from GetDeployedJobStatus() (-want,+got):\n%s", diff) } - } + }) } } @@ -154,6 +144,7 @@ func newFakeJob() *batchv1.Job { }, } } + func newFakeDeployedJob(job interface{}) *unstructured.Unstructured { jobUnstructured, _ := util.ConvertObjectToUnstructured(job) diff --git a/pkg/suggestion/v1beta1/goptuna/converter_test.go b/pkg/suggestion/v1beta1/goptuna/converter_test.go index 1e3189a773d..d92b0f391ca 100644 --- a/pkg/suggestion/v1beta1/goptuna/converter_test.go +++ b/pkg/suggestion/v1beta1/goptuna/converter_test.go @@ -17,48 +17,44 @@ limitations under the License. package suggestion_goptuna_v1beta1 import ( - "reflect" "testing" "github.com/c-bata/goptuna" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" api_v1_beta1 "github.com/kubeflow/katib/pkg/apis/manager/v1beta1" ) func Test_toGoptunaDirection(t *testing.T) { - for _, tt := range []struct { - name string + for name, tc := range map[string]struct { objectiveType api_v1_beta1.ObjectiveType - expected goptuna.StudyDirection + wantDirection goptuna.StudyDirection }{ - { - name: "minimize", + "minimize": { objectiveType: api_v1_beta1.ObjectiveType_MINIMIZE, - expected: goptuna.StudyDirectionMinimize, + wantDirection: goptuna.StudyDirectionMinimize, }, - { - name: "maximize", + "maximize": { objectiveType: api_v1_beta1.ObjectiveType_MAXIMIZE, - expected: goptuna.StudyDirectionMaximize, + wantDirection: goptuna.StudyDirectionMaximize, }, } { - t.Run(tt.name, func(t *testing.T) { - got := toGoptunaDirection(tt.objectiveType) - if got != tt.expected { - t.Errorf("toGoptunaDirection() got = %v, want %v", got, tt.expected) + t.Run(name, func(t *testing.T) { + got := toGoptunaDirection(tc.objectiveType) + if diff := cmp.Diff(tc.wantDirection, got); len(diff) != 0 { + t.Errorf("Unexpected direction from toGoptunaDirection (-want,+got):\n%s", diff) } }) } } func Test_toGoptunaSearchSpace(t *testing.T) { - tests := []struct { - name string - parameters []*api_v1_beta1.ParameterSpec - want map[string]interface{} - wantErr bool + cases := map[string]struct { + parameters []*api_v1_beta1.ParameterSpec + wantSearchSpace map[string]interface{} + wantError error }{ - { - name: "Double parameter type", + "Double parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -69,16 +65,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.UniformDistribution{ High: 5.5, Low: 1.5, }, }, - wantErr: false, }, - { - name: "Double parameter type with step", + "Double parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-double", @@ -90,17 +84,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-double": goptuna.DiscreteUniformDistribution{ High: 5.5, Low: 1.5, Q: 0.5, }, }, - wantErr: false, }, - { - name: "Int parameter type", + "Int parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -111,16 +103,14 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.IntUniformDistribution{ High: 5, Low: 1, }, }, - wantErr: false, }, - { - name: "Int parameter type with step", + "Int parameter type with step": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-int", @@ -132,17 +122,15 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-int": goptuna.StepIntUniformDistribution{ High: 5, Low: 1, Step: 2, }, }, - wantErr: false, }, - { - name: "Discrete parameter type", + "Discrete parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-discrete", @@ -152,15 +140,13 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-discrete": goptuna.CategoricalDistribution{ Choices: []string{"3", "2", "6"}, }, }, - wantErr: false, }, - { - name: "Categorical parameter type", + "Categorical parameter type": { parameters: []*api_v1_beta1.ParameterSpec{ { Name: "param-categorical", @@ -170,23 +156,21 @@ func Test_toGoptunaSearchSpace(t *testing.T) { }, }, }, - want: map[string]interface{}{ + wantSearchSpace: map[string]interface{}{ "param-categorical": goptuna.CategoricalDistribution{ Choices: []string{"cat1", "cat2", "cat3"}, }, }, - wantErr: false, }, } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := toGoptunaSearchSpace(tt.parameters) - if (err != nil) != tt.wantErr { - t.Errorf("toGoptunaSearchSpace() error = %v, wantErr %v", err, tt.wantErr) - return + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + got, err := toGoptunaSearchSpace(tc.parameters) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from toGoptunaSearchSpace (-want,+got):\n%s", diff) } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("toGoptunaSearchSpace() got = %v, want %v", got, tt.want) + if diff := cmp.Diff(tc.wantSearchSpace, got); len(diff) != 0 { + t.Errorf("Unexpected search space from toGoptunaSearchSpace (-want,+got):\n%s", diff) } }) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook.go b/pkg/webhook/v1beta1/pod/inject_webhook.go index 3932a6bbfd3..96deaf23c1d 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook.go @@ -20,6 +20,7 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" "path/filepath" "strconv" @@ -47,6 +48,13 @@ import ( var log = logf.Log.WithName("injector-webhook") +var ( + errInvalidOwnerAPIVersion = errors.New("invalid owner API version") + errInvalidSuggestionName = errors.New("invalid suggestion name") + errPodNotBelongToKatibJob = errors.New("pod does not belong to Katib Job") + errFailedToGetTrialTemplateJob = errors.New("unable to get Job in the trialTemplate") +) + // SidecarInjector injects metrics collect sidecar to the primary pod. type SidecarInjector struct { client client.Client @@ -266,7 +274,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Get group and version from owner API version gv, err := schema.ParseGroupVersion(owners[i].APIVersion) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %w", errInvalidOwnerAPIVersion, err) } gvk := schema.GroupVersionKind{ Group: gv.Group, @@ -279,7 +287,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // Nested object namespace must be equal to object namespace err = s.client.Get(context.TODO(), apitypes.NamespacedName{Name: owners[i].Name, Namespace: namespace}, nestedJob) if err != nil { - return "", "", err + return "", "", fmt.Errorf("%w: %w", errFailedToGetTrialTemplateJob, err) } // Recursively search for Trial ownership in nested object jobKind, jobName, err = s.getKatibJob(nestedJob, namespace) @@ -292,7 +300,7 @@ func (s *SidecarInjector) getKatibJob(object *unstructured.Unstructured, namespa // If jobKind is empty after the loop, Trial doesn't own the object if jobKind == "" { - return "", "", errors.New("The Pod doesn't belong to Katib Job") + return "", "", errPodNotBelongToKatibJob } return jobKind, jobName, nil @@ -329,7 +337,7 @@ func (s *SidecarInjector) getMetricsCollectorArgs(trial *trialsv1beta1.Trial, me suggestion := &suggestionsv1beta1.Suggestion{} err := s.client.Get(context.TODO(), apitypes.NamespacedName{Name: suggestionName, Namespace: trial.Namespace}, suggestion) if err != nil { - return nil, err + return nil, fmt.Errorf("%w: %w", errInvalidSuggestionName, err) } args = append(args, "-s-earlystop", util.GetEarlyStoppingEndpoint(suggestion)) } diff --git a/pkg/webhook/v1beta1/pod/inject_webhook_test.go b/pkg/webhook/v1beta1/pod/inject_webhook_test.go index d8ac88254e6..8350264cfaa 100644 --- a/pkg/webhook/v1beta1/pod/inject_webhook_test.go +++ b/pkg/webhook/v1beta1/pod/inject_webhook_test.go @@ -76,13 +76,13 @@ func TestWrapWorkerContainer(t *testing.T) { metricsFile := "metric.log" - testCases := map[string]struct { + cases := map[string]struct { trial *trialsv1beta1.Trial pod *v1.Pod metricsFile string pathKind common.FileSystemKind - expectedPod *v1.Pod - err bool + wantPod *v1.Pod + wantError error }{ "Tensorflow container without sh -c": { trial: trial, @@ -100,7 +100,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -115,7 +115,6 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, }, "Tensorflow container with sh -c": { trial: trial, @@ -134,7 +133,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -149,7 +148,6 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, }, "Training pod doesn't have primary container": { trial: trial, @@ -163,7 +161,16 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, pathKind: common.FileKind, - err: true, + wantPod: &v1.Pod{ + Spec: v1.PodSpec{ + Containers: []v1.Container{ + { + Name: "not-primary-container", + }, + }, + }, + }, + wantError: errPrimaryContainerNotFound, }, "Container with early stopping command": { trial: func() *trialsv1beta1.Trial { @@ -191,7 +198,7 @@ func TestWrapWorkerContainer(t *testing.T) { }, metricsFile: metricsFile, pathKind: common.FileKind, - expectedPod: &v1.Pod{ + wantPod: &v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -210,21 +217,17 @@ func TestWrapWorkerContainer(t *testing.T) { }, }, }, - err: false, }, } - for name, tc := range testCases { + for name, tc := range cases { t.Run(name, func(t *testing.T) { err := wrapWorkerContainer(tc.trial, tc.pod, tc.trial.Namespace, tc.metricsFile, tc.pathKind) - if tc.err && err == nil { - t.Errorf("Expected error, got nil") - } else if !tc.err { - if err != nil { - t.Errorf("Expected nil, got error: %v", err) - } else if diff := cmp.Diff(tc.expectedPod.Spec.Containers, tc.pod.Spec.Containers); len(diff) != 0 { - t.Errorf("Unexpected pod (-want,got):\n%s", diff) - } + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from wrapWorkerContainer (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantPod.Spec.Containers, tc.pod.Spec.Containers); len(diff) != 0 { + t.Errorf("Unexpected pod from wrapWorkerContainer (-want,+got):\n%s", diff) } }) } @@ -314,14 +317,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, } - testCases := map[string]struct { + cases := map[string]struct { trial *trialsv1beta1.Trial metricNames string mCSpec common.MetricsCollectorSpec earlyStoppingRules []string katibConfig configv1beta1.MetricsCollectorConfig - expectedArgs []string - err bool + wantArgs []string + wantError error }{ "StdOut MC": { trial: testTrial, @@ -334,7 +337,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { katibConfig: configv1beta1.MetricsCollectorConfig{ WaitAllProcesses: &waitAllProcessesValue, }, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -365,7 +368,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -390,7 +393,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -413,7 +416,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -430,7 +433,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -451,7 +454,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -468,7 +471,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, }, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -485,7 +488,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, earlyStoppingRules: earlyStoppingRules, katibConfig: configv1beta1.MetricsCollectorConfig{}, - expectedArgs: []string{ + wantArgs: []string{ "-t", testTrialName, "-m", testMetricName, "-o-type", string(testObjective), @@ -510,7 +513,7 @@ func TestGetMetricsCollectorArgs(t *testing.T) { }, earlyStoppingRules: earlyStoppingRules, katibConfig: configv1beta1.MetricsCollectorConfig{}, - err: true, + wantError: errInvalidSuggestionName, }, } @@ -521,17 +524,14 @@ func TestGetMetricsCollectorArgs(t *testing.T) { return c.Get(context.TODO(), types.NamespacedName{Namespace: testNamespace, Name: testSuggestionName}, testSuggestion) }, timeout).ShouldNot(gomega.HaveOccurred()) - for name, tc := range testCases { + for name, tc := range cases { t.Run(name, func(t *testing.T) { - args, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules) - if !tc.err && err != nil { - t.Errorf("Expected nil, got %v", err) - } else if tc.err && err == nil { - t.Error("Expected err, got nil") - } else if !tc.err { - if diff := cmp.Diff(tc.expectedArgs, args); len(diff) != 0 { - t.Errorf("Unexpected Args (-want,got):\n%s", diff) - } + got, err := si.getMetricsCollectorArgs(tc.trial, tc.metricNames, tc.mCSpec, tc.katibConfig, tc.earlyStoppingRules) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getMetricsCollectorArgs (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantArgs, got); len(diff) != 0 { + t.Errorf("Unexpected args from getMetricsCollectorArgs (-want,+got):\n%s", diff) } }) } @@ -573,13 +573,13 @@ func TestNeedWrapWorkerContainer(t *testing.T) { func TestMutateMetricsCollectorVolume(t *testing.T) { testCases := map[string]struct { pod v1.Pod - expectedPod v1.Pod - JobKind string - MountPath string - SidecarContainerName string - PrimaryContainerName string + wantPod v1.Pod + jobKind string + mountPath string + sidecarContainerName string + primaryContainerName string pathKind common.FileSystemKind - err bool + wantError error }{ "Valid case": { pod: v1.Pod{ @@ -597,7 +597,7 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { }, }, }, - expectedPod: v1.Pod{ + wantPod: v1.Pod{ Spec: v1.PodSpec{ Containers: []v1.Container{ { @@ -632,9 +632,9 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { }, }, }, - MountPath: common.DefaultFilePath, - SidecarContainerName: "metrics-collector", - PrimaryContainerName: "train-job", + mountPath: common.DefaultFilePath, + sidecarContainerName: "metrics-collector", + primaryContainerName: "train-job", pathKind: common.FileKind, }, } @@ -643,14 +643,15 @@ func TestMutateMetricsCollectorVolume(t *testing.T) { t.Run(name, func(t *testing.T) { err := mutateMetricsCollectorVolume( &tc.pod, - tc.MountPath, - tc.SidecarContainerName, - tc.PrimaryContainerName, + tc.mountPath, + tc.sidecarContainerName, + tc.primaryContainerName, tc.pathKind) - if err != nil { - t.Errorf("mutateMetricsCollectorVolume failed: %v", err) - } else if diff := cmp.Diff(tc.expectedPod, tc.pod); len(diff) != 0 { - t.Errorf("Unexpected mutated pod result (-want,got):\n%s", diff) + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from mutateMetricsCollectorVolume (-want,+got):\n%s", diff) + } + if diff := cmp.Diff(tc.wantPod, tc.pod); len(diff) != 0 { + t.Errorf("Unexpected pod from mutateMetricsCollectorVolume (-want,+got):\n%s", diff) } }) } @@ -719,13 +720,13 @@ func TestGetKatibJob(t *testing.T) { deployName := "deploy-name" jobName := "job-name" - testCases := map[string]struct { - pod *v1.Pod - job *batchv1.Job - deployment *appsv1.Deployment - expectedJobKind string - expectedJobName string - err bool + cases := map[string]struct { + pod *v1.Pod + job *batchv1.Job + deployment *appsv1.Deployment + wantJobKind string + wantJobName string + wantError error }{ "Valid run with ownership sequence: Trial -> Job -> Pod": { pod: &v1.Pod{ @@ -768,9 +769,8 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - expectedJobKind: "Job", - expectedJobName: jobName + "-1", - err: false, + wantJobKind: "Job", + wantJobName: jobName + "-1", }, "Valid run with ownership sequence: Trial -> Deployment -> Pod, Job -> Pod": { pod: &v1.Pod{ @@ -846,9 +846,8 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - expectedJobKind: "Deployment", - expectedJobName: deployName + "-2", - err: false, + wantJobKind: "Deployment", + wantJobName: deployName + "-2", }, "Run for not Trial's pod with ownership sequence: Job -> Pod": { pod: &v1.Pod{ @@ -883,7 +882,7 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, + wantError: errPodNotBelongToKatibJob, }, "Run when Pod owns Job that doesn't exists": { pod: &v1.Pod{ @@ -899,7 +898,7 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, + wantError: errFailedToGetTrialTemplateJob, }, "Run when Pod owns Job with invalid API version": { pod: &v1.Pod{ @@ -915,13 +914,13 @@ func TestGetKatibJob(t *testing.T) { }, }, }, - err: true, + wantError: errInvalidOwnerAPIVersion, }, } - for name, tc := range testCases { - // Create Job if it is needed + for name, tc := range cases { t.Run(name, func(t *testing.T) { + // Create Job if it is needed if tc.job != nil { jobUnstr, err := util.ConvertObjectToUnstructured(tc.job) gvk := schema.GroupVersionKind{ @@ -954,13 +953,12 @@ func TestGetKatibJob(t *testing.T) { object, _ := util.ConvertObjectToUnstructured(tc.pod) jobKind, jobName, err := si.getKatibJob(object, namespace) - if !tc.err && err != nil { - t.Errorf("Error %v", err) - } else if !tc.err && (tc.expectedJobKind != jobKind || tc.expectedJobName != jobName) { - t.Errorf("Expected jobKind %v, got %v, Expected jobName %v, got %v", - tc.expectedJobKind, jobKind, tc.expectedJobName, jobName) - } else if tc.err && err == nil { - t.Errorf("Expected error got nil") + if diff := cmp.Diff(tc.wantError, err, cmpopts.EquateErrors()); len(diff) != 0 { + t.Errorf("Unexpected error from getKatibJob (-want,+got):\n%s", diff) + } + if tc.wantError == nil && (tc.wantJobKind != jobKind || tc.wantJobName != jobName) { + t.Errorf("Unexpected error from getKatibJob, expected jobKind %v, got %v, expected jobName %v, got %v", + tc.wantJobKind, jobKind, tc.wantJobName, jobName) } }) } diff --git a/pkg/webhook/v1beta1/pod/utils.go b/pkg/webhook/v1beta1/pod/utils.go index 7dad82553cf..a3dc66e1cc8 100644 --- a/pkg/webhook/v1beta1/pod/utils.go +++ b/pkg/webhook/v1beta1/pod/utils.go @@ -18,6 +18,7 @@ package pod import ( "context" + "errors" "fmt" "path/filepath" "strings" @@ -35,6 +36,8 @@ import ( mccommon "github.com/kubeflow/katib/pkg/metricscollector/v1beta1/common" ) +var errPrimaryContainerNotFound = errors.New("unable to find primary container in mutated pod containers") + func isPrimaryPod(podLabels, primaryLabels map[string]string) bool { for primaryKey, primaryValue := range primaryLabels { @@ -190,8 +193,7 @@ func wrapWorkerContainer(trial *trialsv1beta1.Trial, pod *v1.Pod, namespace, c.Command = command c.Args = []string{argsStr} } else { - return fmt.Errorf("Unable to find primary container %v in mutated pod containers %v", - trial.Spec.PrimaryContainerName, pod.Spec.Containers) + return fmt.Errorf("%w: primary container: %v, mutated pod containers: %v", errPrimaryContainerNotFound, trial.Spec.PrimaryContainerName, pod.Spec.Containers) } return nil }