Skip to content

Commit

Permalink
Implement TrainingRuntime ReplicatedJob validation
Browse files Browse the repository at this point in the history
Signed-off-by: Yuki Iwai <[email protected]>
  • Loading branch information
tenzen-y committed Oct 16, 2024
1 parent 498ff28 commit ba81beb
Show file tree
Hide file tree
Showing 15 changed files with 203 additions and 66 deletions.
16 changes: 8 additions & 8 deletions pkg/runtime.v2/core/clustertrainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
)

func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper(t, "test-runtime").
baseRuntime := testingutil.MakeClusterTrainingRuntimeWrapper("test-runtime").
Clone()

cases := map[string]struct {
Expand All @@ -45,17 +45,17 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
wantError error
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper(t).
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Obj(),
).
Obj(),
clusterTrainingRuntime: baseRuntime.RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec).
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
PodGroupPolicySchedulingTimeout(120).
MLPolicyNumNodes(20).
Expand All @@ -68,7 +68,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
ContainerImage(ptr.To("test:trainjob")).
JobCompletionMode(batchv1.IndexedCompletion).
Expand All @@ -80,7 +80,7 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
}).
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job").
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
MinMember(40).
SchedulingTimeout(120).
Expand All @@ -91,11 +91,11 @@ func TestClusterTrainingRuntimeNewObjects(t *testing.T) {
},
},
"missing trainingRuntime resource": {
trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.ClusterTrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper(t).
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Obj(),
).
Expand Down
3 changes: 0 additions & 3 deletions pkg/runtime.v2/core/trainingruntime.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,6 @@ func (r *TrainingRuntime) buildObjects(
runtime.WithPodGroupPolicy(podGroupPolicy),
}
for idx, rJob := range jobSetTemplateSpec.Spec.ReplicatedJobs {
if rJob.Replicas == 0 {
jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas = 1
}
replicas := jobSetTemplateSpec.Spec.ReplicatedJobs[idx].Replicas * ptr.Deref(rJob.Template.Spec.Completions, 1)
opts = append(opts, runtime.WithPodSpecReplicas(rJob.Name, replicas, rJob.Template.Spec.Template.Spec))
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/runtime.v2/core/trainingruntime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ import (
)

func TestTrainingRuntimeNewObjects(t *testing.T) {
baseRuntime := testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test-runtime").
baseRuntime := testingutil.MakeTrainingRuntimeWrapper(metav1.NamespaceDefault, "test-runtime").
Clone()

cases := map[string]struct {
Expand All @@ -45,13 +45,13 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
wantError error
}{
"succeeded to build JobSet and PodGroup": {
trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
SpecLabel("conflictLabel", "override").
SpecAnnotation("conflictAnnotation", "override").
Trainer(
testingutil.MakeTrainJobTrainerWrapper(t).
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Obj(),
).
Expand All @@ -60,7 +60,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Label("conflictLabel", "overridden").
Annotation("conflictAnnotation", "overridden").
RuntimeSpec(
testingutil.MakeTrainingRuntimeSpecWrapper(t, baseRuntime.Spec).
testingutil.MakeTrainingRuntimeSpecWrapper(baseRuntime.Spec).
ContainerImage("test:runtime").
PodGroupPolicySchedulingTimeout(120).
MLPolicyNumNodes(20).
Expand All @@ -73,7 +73,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
Obj(),
).Obj(),
wantObjs: []client.Object{
testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Label("conflictLabel", "override").
Annotation("conflictAnnotation", "override").
PodLabel(schedulerpluginsv1alpha1.PodGroupLabel, "test-job").
Expand All @@ -87,7 +87,7 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
}).
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
Obj(),
testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job").
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
ControllerReference(kubeflowv2.SchemeGroupVersion.WithKind("TrainJob"), "test-job", "uid").
MinMember(40).
SchedulingTimeout(120).
Expand All @@ -98,11 +98,11 @@ func TestTrainingRuntimeNewObjects(t *testing.T) {
},
},
"missing trainingRuntime resource": {
trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
TrainingRuntimeRef(kubeflowv2.SchemeGroupVersion.WithKind(kubeflowv2.TrainingRuntimeKind), "test-runtime").
Trainer(
testingutil.MakeTrainJobTrainerWrapper(t).
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("test:trainjob").
Obj(),
).
Expand Down
5 changes: 1 addition & 4 deletions pkg/runtime.v2/framework/core/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (f *Framework) RunEnforcePodGroupPolicyPlugins(trainJob *kubeflowv2.TrainJo
return nil
}

func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) {
func (f *Framework) RunCustomValidationPlugins(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
var aggregatedWarnings admission.Warnings
var aggregatedErrors field.ErrorList
for _, plugin := range f.customValidationPlugins {
Expand All @@ -101,9 +101,6 @@ func (f *Framework) RunCustomValidationPlugins(oldObj, newObj client.Object) (ad
aggregatedErrors = append(aggregatedErrors, errs...)
}
}
if len(aggregatedErrors) == 0 {
return aggregatedWarnings, nil
}
return aggregatedWarnings, aggregatedErrors
}

Expand Down
23 changes: 10 additions & 13 deletions pkg/runtime.v2/framework/core/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,24 +287,21 @@ func TestRunEnforcePodGroupPolicyPlugins(t *testing.T) {

func TestRunCustomValidationPlugins(t *testing.T) {
cases := map[string]struct {
trainJob *kubeflowv2.TrainJob
registry fwkplugins.Registry
oldObj client.Object
newObj client.Object
oldObj *kubeflowv2.TrainJob
newObj *kubeflowv2.TrainJob
wantWarnings admission.Warnings
wantError field.ErrorList
}{
// TODO: Add more detailed tests once a custom validator is implemented in any plugin.
"there are not any custom validations": {
trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}},
registry: fwkplugins.NewRegistry(),
oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(),
},
"an empty registry": {
trainJob: &kubeflowv2.TrainJob{ObjectMeta: metav1.ObjectMeta{Name: "test-job", Namespace: metav1.NamespaceDefault}},
oldObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainingRuntimeWrapper(t, metav1.NamespaceDefault, "test").Obj(),
oldObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(),
newObj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test").Obj(),
},
}
for name, tc := range cases {
Expand All @@ -329,7 +326,7 @@ func TestRunCustomValidationPlugins(t *testing.T) {
}

func TestRunComponentBuilderPlugins(t *testing.T) {
jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
ResourceRequests(0, corev1.ResourceList{
corev1.ResourceCPU: resource.MustParse("2"),
corev1.ResourceMemory: resource.MustParse("4Gi"),
Expand All @@ -354,10 +351,10 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
wantObjs []client.Object
}{
"coscheduling and jobset are performed": {
trainJob: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
trainJob: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
UID("uid").
Trainer(
testingutil.MakeTrainJobTrainerWrapper(t).
testingutil.MakeTrainJobTrainerWrapper().
ContainerImage("foo:bar").
Obj(),
).
Expand Down Expand Up @@ -396,7 +393,7 @@ func TestRunComponentBuilderPlugins(t *testing.T) {
},
registry: fwkplugins.NewRegistry(),
wantObjs: []client.Object{
testingutil.MakeSchedulerPluginsPodGroup(t, metav1.NamespaceDefault, "test-job").
testingutil.MakeSchedulerPluginsPodGroup(metav1.NamespaceDefault, "test-job").
SchedulingTimeout(300).
MinMember(20).
MinResources(corev1.ResourceList{
Expand Down
2 changes: 1 addition & 1 deletion pkg/runtime.v2/framework/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ type EnforceMLPolicyPlugin interface {

type CustomValidationPlugin interface {
Plugin
Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList)
Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList)
}

type ComponentBuilderPlugin interface {
Expand Down
3 changes: 2 additions & 1 deletion pkg/runtime.v2/framework/plugins/mpi/mpi.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
)
Expand Down Expand Up @@ -55,6 +56,6 @@ func (m *MPI) EnforceMLPolicy(info *runtime.Info) error {
}

// TODO: Need to implement validations for MPIJob.
func (m *MPI) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) {
func (m *MPI) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
}
3 changes: 2 additions & 1 deletion pkg/runtime.v2/framework/plugins/torch/torch.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"sigs.k8s.io/controller-runtime/pkg/client"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"

kubeflowv2 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v2alpha1"
runtime "github.com/kubeflow/training-operator/pkg/runtime.v2"
"github.com/kubeflow/training-operator/pkg/runtime.v2/framework"
)
Expand Down Expand Up @@ -51,6 +52,6 @@ func (t *Torch) EnforceMLPolicy(info *runtime.Info) error {
}

// TODO: Need to implement validations for TorchJob.
func (t *Torch) Validate(oldObj, newObj client.Object) (admission.Warnings, field.ErrorList) {
func (t *Torch) Validate(oldObj, newObj *kubeflowv2.TrainJob) (admission.Warnings, field.ErrorList) {
return nil, nil
}
6 changes: 3 additions & 3 deletions pkg/runtime.v2/runtime_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import (
)

func TestNewInfo(t *testing.T) {
jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Clone()

cases := map[string]struct {
Expand Down Expand Up @@ -159,7 +159,7 @@ func TestNewInfo(t *testing.T) {
}

func TestUpdate(t *testing.T) {
jobSetBase := testingutil.MakeJobSetWrapper(t, metav1.NamespaceDefault, "test-job").
jobSetBase := testingutil.MakeJobSetWrapper(metav1.NamespaceDefault, "test-job").
Clone()

cases := map[string]struct {
Expand All @@ -172,7 +172,7 @@ func TestUpdate(t *testing.T) {
info: &Info{
Obj: jobSetBase.Obj(),
},
obj: testingutil.MakeTrainJobWrapper(t, metav1.NamespaceDefault, "test-job").
obj: testingutil.MakeTrainJobWrapper(metav1.NamespaceDefault, "test-job").
Obj(),
wantInfo: &Info{
Obj: jobSetBase.Obj(),
Expand Down
Loading

0 comments on commit ba81beb

Please sign in to comment.