Skip to content

Commit

Permalink
fix: create the workspace service for custom models (#745)
Browse files Browse the repository at this point in the history
**Reason for Change**:
<!-- What does this PR improve or fix in Kaito? Why is it needed? -->

This change updates the `ensureService` function in the workspace
controller to create the workspace service for custom models instead of
only doing it for preset models.

The logic today only goes into the service creation code if there is a
preset, but it is only using the preset information to see if the model
supports distributed inference. To add support for custom models, we can
just default the distributed inference support to false and always
create the service preset or not.

It seems fair to me to only support distributed inference on the preset
models as that requires more setup. With this change custom models can
at least have basic support without any manual Kubernetes resource
creation.

Fixes #744.

**Requirements**

- [X] added unit tests and e2e tests (if applicable).

**Issue Fixed**:
<!-- If this PR fixes GitHub issue 4321, add "Fixes #4321" to the next
line. -->

Fixes #744 

**Notes for Reviewers**:
  • Loading branch information
zawachte authored Dec 2, 2024
1 parent 21056a1 commit 1b440d5
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
27 changes: 27 additions & 0 deletions pkg/utils/test/testUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,33 @@ var (
},
},
}
MockWorkspaceCustomModel = &v1alpha1.Workspace{
ObjectMeta: metav1.ObjectMeta{
Name: "testCustomWorkspace",
Namespace: "kaito",
},
Resource: v1alpha1.ResourceSpec{
Count: &gpuNodeCount,
InstanceType: "Standard_NC12s_v3",
LabelSelector: &metav1.LabelSelector{
MatchLabels: map[string]string{
"apps": "test",
},
},
},
Inference: &v1alpha1.InferenceSpec{
Template: &corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
Containers: []corev1.Container{
{
Name: "test-container",
Image: "fake.kaito.com/kaito-image:0.0.1",
},
},
},
},
},
}
)

var (
Expand Down
21 changes: 13 additions & 8 deletions pkg/workspace/controllers/workspace_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -632,19 +632,24 @@ func (c *WorkspaceReconciler) ensureService(ctx context.Context, wObj *kaitov1al
return nil
}

supportsDistributedInference := false
if presetName := getPresetName(wObj); presetName != "" {
model := plugin.KaitoModelRegister.MustGet(presetName)
serviceObj := manifests.GenerateServiceManifest(ctx, wObj, serviceType, model.SupportDistributedInference())
if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil {
supportsDistributedInference = model.SupportDistributedInference()
}

serviceObj := manifests.GenerateServiceManifest(ctx, wObj, serviceType, supportsDistributedInference)
if err := resources.CreateResource(ctx, serviceObj, c.Client); err != nil {
return err
}

if supportsDistributedInference {
headlessService := manifests.GenerateHeadlessServiceManifest(ctx, wObj)
if err := resources.CreateResource(ctx, headlessService, c.Client); err != nil {
return err
}
if model.SupportDistributedInference() {
headlessService := manifests.GenerateHeadlessServiceManifest(ctx, wObj)
if err := resources.CreateResource(ctx, headlessService, c.Client); err != nil {
return err
}
}
}

return nil
}

Expand Down
14 changes: 13 additions & 1 deletion pkg/workspace/controllers/workspace_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -548,26 +548,38 @@ func TestEnsureService(t *testing.T) {
testcases := map[string]struct {
callMocks func(c *test.MockClient)
expectedError error
workspace *kaitov1alpha1.Workspace
}{
"Existing service is found for workspace": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
workspace: test.MockWorkspaceDistributedModel,
},
"Service creation fails": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(errors.New("cannot create service"))
},
expectedError: errors.New("cannot create service"),
workspace: test.MockWorkspaceDistributedModel,
},
"Successfully creates a new service": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
workspace: test.MockWorkspaceDistributedModel,
},
"Successfully creates a new service for a custom model": {
callMocks: func(c *test.MockClient) {
c.On("Get", mock.IsType(context.Background()), mock.Anything, mock.IsType(&corev1.Service{}), mock.Anything).Return(test.NotFoundError())
c.On("Create", mock.IsType(context.Background()), mock.IsType(&corev1.Service{}), mock.Anything).Return(nil)
},
expectedError: nil,
workspace: test.MockWorkspaceCustomModel,
},
}

Expand All @@ -582,7 +594,7 @@ func TestEnsureService(t *testing.T) {
}
ctx := context.Background()

err := reconciler.ensureService(ctx, test.MockWorkspaceDistributedModel)
err := reconciler.ensureService(ctx, tc.workspace)
if tc.expectedError == nil {
assert.Check(t, err == nil, "Not expected to return error")
} else {
Expand Down

0 comments on commit 1b440d5

Please sign in to comment.