From ba81beb7f1ac828a95ecac70d42b471dc68f80b8 Mon Sep 17 00:00:00 2001 From: Yuki Iwai Date: Thu, 17 Oct 2024 04:35:20 +0900 Subject: [PATCH] Implement TrainingRuntime ReplicatedJob validation Signed-off-by: Yuki Iwai --- .../core/clustertrainingruntime_test.go | 16 ++--- pkg/runtime.v2/core/trainingruntime.go | 3 - pkg/runtime.v2/core/trainingruntime_test.go | 16 ++--- pkg/runtime.v2/framework/core/framework.go | 5 +- .../framework/core/framework_test.go | 23 ++++--- pkg/runtime.v2/framework/interface.go | 2 +- pkg/runtime.v2/framework/plugins/mpi/mpi.go | 3 +- .../framework/plugins/torch/torch.go | 3 +- pkg/runtime.v2/runtime_test.go | 6 +- pkg/util.v2/testing/wrapper.go | 49 ++++++++------- .../clustertrainingruntime_webhook.go | 8 ++- pkg/webhook.v2/trainingruntime_webhook.go | 24 +++++++- .../trainingruntime_webhook_test.go | 60 +++++++++++++++++++ .../webhook.v2/clustertrainingruntime_test.go | 25 ++++++++ .../webhook.v2/trainingruntime_test.go | 26 ++++++++ 15 files changed, 203 insertions(+), 66 deletions(-) create mode 100644 pkg/webhook.v2/trainingruntime_webhook_test.go diff --git a/pkg/runtime.v2/core/clustertrainingruntime_test.go b/pkg/runtime.v2/core/clustertrainingruntime_test.go index 1831d2ec62..5665c10fe5 100644 --- a/pkg/runtime.v2/core/clustertrainingruntime_test.go +++ b/pkg/runtime.v2/core/clustertrainingruntime_test.go @@ -35,7 +35,7 @@ import ( ) func TestClusterTrainingRuntimeNewObjects(t *testing.T) { - baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(t, "test-runtime"). + baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime"). Clone() cases := map[string]struct { @@ -45,17 +45,17 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { wantError error }{ "succeeded to build JobSet and PodGroup": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). Obj(), clusterTrainingRuntime: baseRuntime.RuntimeSpec( - testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). PodGroupPolicySchedulingTimeout(120). MLPolicyNumNodes(20). @@ -68,7 +68,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { Obj(), ).Obj(), wantObjs: []client.Object{ - testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). ContainerImage(ptr.To("test:trainjob")). JobCompletionMode(batchv1.IndexedCompletion). @@ -80,7 +80,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). Obj(), - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). MinMember(40). SchedulingTimeout(120). 
@@ -91,11 +91,11 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) { }, }, "missing trainingRuntime resource": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). diff --git a/pkg/runtime.v2/core/trainingruntime.go b/pkg/runtime.v2/core/trainingruntime.go index 41dea87105..1597bbf0a8 100644 --- a/pkg/runtime.v2/core/trainingruntime.go +++ b/pkg/runtime.v2/core/trainingruntime.go @@ -107,9 +107,6 @@ func (r *TrainingRuntime) buildObjects( runtime.WithPodGroupPolicy(podGroupPolicy), } for idx, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs { - if rJob.Replicas == 0 { - jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas = 1 - } replicas := jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas * ptr.Deref(rJob.Template.Spec.Completions, 1) opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, replicas, rJob.Template.Spec.Template.Spec)) } diff --git a/pkg/runtime.v2/core/trainingruntime_test.go b/pkg/runtime.v2/core/trainingruntime_test.go index ca31d8f133..a3bd63efa6 100644 --- a/pkg/runtime.v2/core/trainingruntime_test.go +++ b/pkg/runtime.v2/core/trainingruntime_test.go @@ -35,7 +35,7 @@ import ( ) func TestTrainingRuntimeNewObjects(t *testing.T) { - baseRuntime := testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test-runtime"). + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime"). Clone() cases := map[string]struct { @@ -45,13 +45,13 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { wantError error }{ "succeeded to build JobSet and PodGroup": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). SpecLabel("conflictLabel", "override"). SpecAnnotation("conflictAnnotation", "override"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). @@ -60,7 +60,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Label("conflictLabel", "overridden"). Annotation("conflictAnnotation", "overridden"). RuntimeSpec( - testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec). + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). ContainerImage("test:runtime"). PodGroupPolicySchedulingTimeout(120). MLPolicyNumNodes(20). @@ -73,7 +73,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { Obj(), ).Obj(), wantObjs: []client.Object{ - testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Label("conflictLabel", "override"). Annotation("conflictAnnotation", "override"). PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job"). @@ -87,7 +87,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). Obj(), - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). 
+ testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid"). MinMember(40). SchedulingTimeout(120). @@ -98,11 +98,11 @@ func TestTrainingRuntimeNewObjects(t *testing.T) { }, }, "missing trainingRuntime resource": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("test:trainjob"). Obj(), ). diff --git a/pkg/runtime.v2/framework/core/framework.go b/pkg/runtime.v2/framework/core/framework.go index ae9f007493..8997afe467 100644 --- a/pkg/runtime.v2/framework/core/framework.go +++ b/pkg/runtime.v2/framework/core/framework.go @@ -89,7 +89,7 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo return nil } -func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { var aggregatedWarnings admission.Warnings var aggregatedErrors field.ErrorList for _, plugin := range f.customValidationPlugins { @@ -101,9 +101,6 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (ad aggregatedErrors = append(aggregatedErrors, errs...) } } - if len(aggregatedErrors) == 0 { - return aggregatedWarnings, nil - } return aggregatedWarnings, aggregatedErrors } diff --git a/pkg/runtime.v2/framework/core/framework_test.go b/pkg/runtime.v2/framework/core/framework_test.go index 4f5085781f..c3b630d923 100644 --- a/pkg/runtime.v2/framework/core/framework_test.go +++ b/pkg/runtime.v2/framework/core/framework_test.go @@ -287,24 +287,21 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) { func TestRunCustomValidationPlugins(t *testing.T) { cases := map[string]struct { - trainJob *kubeflowv2.TrainJob registry fwkplugins.Registry - oldObj client.Object - newObj client.Object + oldObj *kubeflowv2.TrainJob + newObj *kubeflowv2.TrainJob wantWarnings admission.Warnings wantError field.ErrorList }{ // Need to implement more detail testing after we implement custom validator in any plugins. 
"there are not any custom validations": { - trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, registry: fwkplugins.NewRegistry(), - oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), - newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), }, "an empty registry": { - trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}}, - oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), - newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(), + oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), + newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(), }, } for name, tc := range cases { @@ -329,7 +326,7 @@ func TestRunCustomValidationPlugins(t *testing.T) { } func TestRunComponentBuilderPlugins(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). ResourceRequests(0, corev1.ResourceList{ corev1.ResourceCPU: resource.MustParse("2"), corev1.ResourceMemory: resource.MustParse("4Gi"), @@ -354,10 +351,10 @@ func TestRunComponentBuilderPlugins(t *testing.T) { wantObjs []client.Object }{ "coscheduling and jobset are performed": { - trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job"). UID("uid"). Trainer( - testingutil.MakeTrainJobTrainerWrapper(t). + testingutil.MakeTrainJobTrainerWrapper(). ContainerImage("foo:bar"). Obj(), ). @@ -396,7 +393,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) { }, registry: fwkplugins.NewRegistry(), wantObjs: []client.Object{ - testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job"). + testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job"). SchedulingTimeout(300). MinMember(20). 
MinResources(corev1.ResourceList{ diff --git a/pkg/runtime.v2/framework/interface.go b/pkg/runtime.v2/framework/interface.go index 886c1ab39c..00d613ec0a 100644 --- a/pkg/runtime.v2/framework/interface.go +++ b/pkg/runtime.v2/framework/interface.go @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface { type CustomValidationPlugin interface { Plugin - Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) + Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) } type ComponentBuilderPlugin interface { diff --git a/pkg/runtime.v2/framework/plugins/mpi/mpi.go b/pkg/runtime.v2/framework/plugins/mpi/mpi.go index b85a265195..9e79f6c5a1 100644 --- a/pkg/runtime.v2/framework/plugins/mpi/mpi.go +++ b/pkg/runtime.v2/framework/plugins/mpi/mpi.go @@ -23,6 +23,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" ) @@ -55,6 +56,6 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error { } // TODO: Need to implement validations for MPIJob. -func (m *MPI) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { return nil, nil } diff --git a/pkg/runtime.v2/framework/plugins/torch/torch.go b/pkg/runtime.v2/framework/plugins/torch/torch.go index 1a0306be98..b9b7f10cb9 100644 --- a/pkg/runtime.v2/framework/plugins/torch/torch.go +++ b/pkg/runtime.v2/framework/plugins/torch/torch.go @@ -23,6 +23,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime "github.com/kubeflow/training-operator/pkg/runtime.v2" "github.com/kubeflow/training-operator/pkg/runtime.v2/framework" ) @@ -51,6 +52,6 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error { } // TODO: Need to implement validations for TorchJob. -func (t *Torch) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) { +func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) { return nil, nil } diff --git a/pkg/runtime.v2/runtime_test.go b/pkg/runtime.v2/runtime_test.go index bcb6b5efd9..95e366ac22 100644 --- a/pkg/runtime.v2/runtime_test.go +++ b/pkg/runtime.v2/runtime_test.go @@ -32,7 +32,7 @@ import ( ) func TestNewInfo(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Clone() cases := map[string]struct { @@ -159,7 +159,7 @@ func TestNewInfo(t *testing.T) { } func TestUpdate(t *testing.T) { - jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job"). + jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job"). Clone() cases := map[string]struct { @@ -172,7 +172,7 @@ func TestUpdate(t *testing.T) { info: &Info{ Obj: jobSetBase.Obj(), }, - obj: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job"). + obj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Obj(), wantInfo: &Info{ Obj: jobSetBase.Obj(), diff --git a/pkg/util.v2/testing/wrapper.go b/pkg/util.v2/testing/wrapper.go index 10441c040b..df3f8cbfce 100644 --- a/pkg/util.v2/testing/wrapper.go +++ b/pkg/util.v2/testing/wrapper.go @@ -17,8 +17,6 @@ limitations under the License. package testing import ( - "testing" - batchv1 "k8s.io/api/batch/v1" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -35,8 +33,7 @@ type JobSetWrapper struct { jobsetv1alpha2.JobSet } -func MakeJobSetWrapper(t *testing.T, namespace, name string) *JobSetWrapper { - t.Helper() +func MakeJobSetWrapper(namespace, name string) *JobSetWrapper { return &JobSetWrapper{ JobSet: jobsetv1alpha2.JobSet{ TypeMeta: metav1.TypeMeta{ @@ -170,6 +167,13 @@ func (j *JobSetWrapper) Annotation(key, value string) *JobSetWrapper { return j } +func (j *JobSetWrapper) Replicas(replicas int32) *JobSetWrapper { + for idx := range j.Spec.ReplicatedJobs { + j.Spec.ReplicatedJobs[idx].Replicas = replicas + } + return j +} + func (j *JobSetWrapper) Clone() *JobSetWrapper { return &JobSetWrapper{ JobSet: *j.JobSet.DeepCopy(), @@ -184,8 +188,7 @@ type TrainJobWrapper struct { kubeflowv2.TrainJob } -func MakeTrainJobWrapper(t *testing.T, namespace, name string) *TrainJobWrapper { - t.Helper() +func MakeTrainJobWrapper(namespace, name string) *TrainJobWrapper { return &TrainJobWrapper{ TrainJob: kubeflowv2.TrainJob{ TypeMeta: metav1.TypeMeta{ @@ -226,8 +229,7 @@ type TrainJobTrainerWrapper struct { kubeflowv2.Trainer } -func MakeTrainJobTrainerWrapper(t *testing.T) *TrainJobTrainerWrapper { - t.Helper() +func MakeTrainJobTrainerWrapper() *TrainJobTrainerWrapper { return &TrainJobTrainerWrapper{ Trainer: kubeflowv2.Trainer{}, } @@ -264,8 +266,7 @@ type TrainingRuntimeWrapper struct { kubeflowv2.TrainingRuntime } -func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingRuntimeWrapper { - t.Helper() +func MakeTrainingRuntimeWrapper(namespace, name string) *TrainingRuntimeWrapper { return &TrainingRuntimeWrapper{ TrainingRuntime: kubeflowv2.TrainingRuntime{ TypeMeta: metav1.TypeMeta{ @@ -281,7 +282,8 @@ func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingR Spec: jobsetv1alpha2.JobSetSpec{ ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{ { - Name: "Coordinator", + Name: "Coordinator", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -295,7 +297,8 @@ func MakeTrainingRuntimeWrapper(t *testing.T, namespace, name string) *TrainingR }, }, { - Name: "Worker", + Name: "Worker", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -351,8 +354,7 @@ type ClusterTrainingRuntimeWrapper struct { kubeflowv2.ClusterTrainingRuntime } -func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTrainingRuntimeWrapper { - t.Helper() +func MakeClusterTrainingRuntimeWrapper(name string) *ClusterTrainingRuntimeWrapper { return &ClusterTrainingRuntimeWrapper{ ClusterTrainingRuntime: kubeflowv2.ClusterTrainingRuntime{ TypeMeta: metav1.TypeMeta{ @@ -367,7 +369,8 @@ func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) *ClusterTraini Spec: jobsetv1alpha2.JobSetSpec{ ReplicatedJobs: []jobsetv1alpha2.ReplicatedJob{ { - Name: "Coordinator", + Name: "Coordinator", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -381,7 +384,8 @@ func MakeClusterTrainingRuntimeWrapper(t *testing.T, name string) 
*ClusterTraini }, }, { - Name: "Worker", + Name: "Worker", + Replicas: 1, Template: batchv1.JobTemplateSpec{ Spec: batchv1.JobSpec{ Template: corev1.PodTemplateSpec{ @@ -421,13 +425,19 @@ type TrainingRuntimeSpecWrapper struct { kubeflowv2.TrainingRuntimeSpec } -func MakeTrainingRuntimeSpecWrapper(t *testing.T, spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeSpecWrapper { - t.Helper() +func MakeTrainingRuntimeSpecWrapper(spec kubeflowv2.TrainingRuntimeSpec) *TrainingRuntimeSpecWrapper { return &TrainingRuntimeSpecWrapper{ TrainingRuntimeSpec: spec, } } +func (s *TrainingRuntimeSpecWrapper) Replicas(replicas int32) *TrainingRuntimeSpecWrapper { + for idx := range s.Template.Spec.ReplicatedJobs { + s.Template.Spec.ReplicatedJobs[idx].Replicas = replicas + } + return s +} + func (s *TrainingRuntimeSpecWrapper) ContainerImage(image string) *TrainingRuntimeSpecWrapper { for i, rJob := range s.Template.Spec.ReplicatedJobs { for j := range rJob.Template.Spec.Template.Spec.Containers { @@ -475,8 +485,7 @@ type SchedulerPluginsPodGroupWrapper struct { schedulerpluginsv1alpha1.PodGroup } -func MakeSchedulerPluginsPodGroup(t *testing.T, namespace, name string) *SchedulerPluginsPodGroupWrapper { - t.Helper() +func MakeSchedulerPluginsPodGroup(namespace, name string) *SchedulerPluginsPodGroupWrapper { return &SchedulerPluginsPodGroupWrapper{ PodGroup: schedulerpluginsv1alpha1.PodGroup{ TypeMeta: metav1.TypeMeta{ diff --git a/pkg/webhook.v2/clustertrainingruntime_webhook.go b/pkg/webhook.v2/clustertrainingruntime_webhook.go index a679fe66dd..c98d71a15b 100644 --- a/pkg/webhook.v2/clustertrainingruntime_webhook.go +++ b/pkg/webhook.v2/clustertrainingruntime_webhook.go @@ -20,6 +20,7 @@ import ( "context" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" @@ -43,8 +44,11 @@ func setupWebhookForClusterTrainingRuntime(mgr ctrl.Manager, run map[string]runt var _ webhook.CustomValidator = (*ClusterTrainingRuntimeWebhook)(nil) -func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *ClusterTrainingRuntimeWebhook) ValidateCreate(ctx context.Context, obj apiruntime.Object) (admission.Warnings, error) { + clTrainingRuntime := obj.(*kubeflowv2.ClusterTrainingRuntime) + log := ctrl.LoggerFrom(ctx).WithName("clustertrainingruntime-webhook") + log.V(5).Info("Validating create", "clusterTrainingRuntime", klog.KObj(clTrainingRuntime)) + return nil, validateReplicatedJobs(clTrainingRuntime.Spec.Template.Spec.ReplicatedJobs).ToAggregate() } func (w *ClusterTrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { diff --git a/pkg/webhook.v2/trainingruntime_webhook.go b/pkg/webhook.v2/trainingruntime_webhook.go index 5597e5238f..fa6a7186db 100644 --- a/pkg/webhook.v2/trainingruntime_webhook.go +++ b/pkg/webhook.v2/trainingruntime_webhook.go @@ -20,9 +20,12 @@ import ( "context" apiruntime "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/util/validation/field" + "k8s.io/klog/v2" ctrl "sigs.k8s.io/controller-runtime" "sigs.k8s.io/controller-runtime/pkg/webhook" "sigs.k8s.io/controller-runtime/pkg/webhook/admission" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" runtime 
"github.com/kubeflow/training-operator/pkg/runtime.v2" @@ -43,8 +46,25 @@ func setupWebhookForTrainingRuntime(mgr ctrl.Manager, run map[string]runtime.Run var _ webhook.CustomValidator = (*TrainingRuntimeWebhook)(nil) -func (w *TrainingRuntimeWebhook) ValidateCreate(context.Context, apiruntime.Object) (admission.Warnings, error) { - return nil, nil +func (w *TrainingRuntimeWebhook) ValidateCreate(ctx context.Context, obj apiruntime.Object) (admission.Warnings, error) { + trainingRuntime := obj.(*kubeflowv2.TrainingRuntime) + log := ctrl.LoggerFrom(ctx).WithName("trainingruntime-webhook") + log.V(5).Info("Validating create", "trainingRuntime", klog.KObj(trainingRuntime)) + return nil, validateReplicatedJobs(trainingRuntime.Spec.Template.Spec.ReplicatedJobs).ToAggregate() +} + +func validateReplicatedJobs(rJobs []jobsetv1alpha2.ReplicatedJob) field.ErrorList { + rJobsPath := field.NewPath("spec"). + Child("template"). + Child("spec"). + Child("replicatedJobs") + var allErrs field.ErrorList + for idx, rJob := range rJobs { + if rJob.Replicas != 1 { + allErrs = append(allErrs, field.Invalid(rJobsPath.Index(idx).Child("replicas"), rJob.Replicas, "always must be 1")) + } + } + return allErrs } func (w *TrainingRuntimeWebhook) ValidateUpdate(context.Context, apiruntime.Object, apiruntime.Object) (admission.Warnings, error) { diff --git a/pkg/webhook.v2/trainingruntime_webhook_test.go b/pkg/webhook.v2/trainingruntime_webhook_test.go new file mode 100644 index 0000000000..fbbf9a6a35 --- /dev/null +++ b/pkg/webhook.v2/trainingruntime_webhook_test.go @@ -0,0 +1,60 @@ +/* +Copyright 2024 The Kubeflow Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package webhookv2 + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "k8s.io/apimachinery/pkg/util/validation/field" + jobsetv1alpha2 "sigs.k8s.io/jobset/api/jobset/v1alpha2" + + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" +) + +func TestValidateReplicatedJobs(t *testing.T) { + cases := map[string]struct { + rJobs []jobsetv1alpha2.ReplicatedJob + wantError field.ErrorList + }{ + "valid replicatedJobs": { + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(1). + Obj().Spec.ReplicatedJobs, + }, + "invalid replicas": { + rJobs: testingutil.MakeJobSetWrapper("ns", "valid"). + Replicas(2). 
+ Obj().Spec.ReplicatedJobs, + wantError: field.ErrorList{ + field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(0).Child("replicas"), + "2", ""), + field.Invalid(field.NewPath("spec").Child("template").Child("spec").Child("replicatedJobs").Index(1).Child("replicas"), + "2", ""), + }, + }, + } + for name, tc := range cases { + t.Run(name, func(t *testing.T) { + gotErr := validateReplicatedJobs(tc.rJobs) + if diff := cmp.Diff(tc.wantError, gotErr, cmpopts.IgnoreFields(field.Error{}, "Detail", "BadValue")); len(diff) != 0 { + t.Errorf("validateReplicatedJobs() mismatch (-want,+got):\n%s", diff) + } + }) + } +} diff --git a/test/integration/webhook.v2/clustertrainingruntime_test.go b/test/integration/webhook.v2/clustertrainingruntime_test.go index 9ba9be69af..a2519c8ff8 100644 --- a/test/integration/webhook.v2/clustertrainingruntime_test.go +++ b/test/integration/webhook.v2/clustertrainingruntime_test.go @@ -22,9 +22,13 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" "github.com/kubeflow/training-operator/test/integration/framework" ) +const clTrainingRuntimeName = "test-clustertrainingruntime" + var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace @@ -49,4 +53,25 @@ var _ = ginkgo.Describe("ClusterTrainingRuntime Webhook", ginkgo.Ordered, func() { } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.ClusterTrainingRuntime{})).To(gomega.Succeed()) + }) + + ginkgo.When("Creating ClusterTrainingRuntime", func() { + ginkgo.DescribeTable("Validate ClusterTrainingRuntime on creation", func(runtime func() *kubeflowv2.ClusterTrainingRuntime) { + gomega.Expect(k8sClient.Create(ctx, runtime())).Should(gomega.Succeed()) + }, + ginkgo.Entry("Should succeed to create ClusterTrainingRuntime", + func() *kubeflowv2.ClusterTrainingRuntime { + baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(clTrainingRuntimeName) + return baseRuntime. + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). + Replicas(1). + Obj()).
+ Obj() + }), + ) + }) }) diff --git a/test/integration/webhook.v2/trainingruntime_test.go b/test/integration/webhook.v2/trainingruntime_test.go index dc2add14ce..7599e04759 100644 --- a/test/integration/webhook.v2/trainingruntime_test.go +++ b/test/integration/webhook.v2/trainingruntime_test.go @@ -21,10 +21,15 @@ import ( "github.com/onsi/gomega" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "sigs.k8s.io/controller-runtime/pkg/client" + kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1" + testingutil "github.com/kubeflow/training-operator/pkg/util.v2/testing" "github.com/kubeflow/training-operator/test/integration/framework" ) +const trainingRuntimeName = "test-trainingruntime" + var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { var ns *corev1.Namespace @@ -49,4 +54,25 @@ var _ = ginkgo.Describe("TrainingRuntime Webhook", ginkgo.Ordered, func() { } gomega.Expect(k8sClient.Create(ctx, ns)).To(gomega.Succeed()) }) + + ginkgo.AfterEach(func() { + gomega.Expect(k8sClient.DeleteAllOf(ctx, &kubeflowv2.TrainingRuntime{}, client.InNamespace(ns.Name))).To(gomega.Succeed()) + }) + + ginkgo.When("Creating TrainingRuntime", func() { + ginkgo.DescribeTable("Validate TrainingRuntime on creation", func(runtime func() *kubeflowv2.TrainingRuntime) { + gomega.Expect(k8sClient.Create(ctx, runtime())).Should(gomega.Succeed()) + }, + ginkgo.Entry("Should succeed to create TrainingRuntime", + func() *kubeflowv2.TrainingRuntime { + baseRuntime := testingutil.MakeTrainingRuntimeWrapper(ns.Name, trainingRuntimeName).Clone() + return baseRuntime. + RuntimeSpec( + testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec). + Replicas(1). + Obj()). + Obj() + }), + ) + }) })