Skip to content

Commit

Permalink
Adds an API endpoint to create a ModelVersion for a given RegisteredM…
Browse files Browse the repository at this point in the history
…odel

Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Sep 16, 2024
1 parent 26dbecf commit 63ff3e2
Show file tree
Hide file tree
Showing 9 changed files with 227 additions and 12 deletions.
43 changes: 31 additions & 12 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,19 @@ make docker-build

### Endpoints

| URL Pattern | Handler | Action |
|---------------------------------------------------------------------------------------------|----------------------------------------------|-----------------------------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model_registry/{model_registry_id}/registered_models | GetAllRegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model_registry/{model_registry_id}/registered_models | CreateRegisteredModelHandler | Create a RegisteredModel entity. |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | GetRegisteredModelHandler | Get a RegisteredModel entity by ID |
| PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | UpdateRegisteredModelHandler | Update a RegisteredModel entity by ID |
| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | GetModelVersionHandler | Get a ModelVersion by ID |
| POST /api/v1/model_registry/{model_registry_id}/model_versions | CreateModelVersionHandler | Create a ModelVersion entity |
| PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | UpdateModelVersionHandler | Update a ModelVersion entity by ID |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID |
| URL Pattern | Handler | Action |
|----------------------------------------------------------------------------------------------|----------------------------------------------|-------------------------------------------------------------|
| GET /v1/healthcheck | HealthcheckHandler | Show application information. |
| GET /v1/model_registry | ModelRegistryHandler | Get all model registries, |
| GET /v1/model_registry/{model_registry_id}/registered_models | GetAllRegisteredModelsHandler | Gets a list of all RegisteredModel entities. |
| POST /v1/model_registry/{model_registry_id}/registered_models | CreateRegisteredModelHandler | Create a RegisteredModel entity. |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | GetRegisteredModelHandler | Get a RegisteredModel entity by ID |
| PATCH /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id} | UpdateRegisteredModelHandler | Update a RegisteredModel entity by ID |
| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | GetModelVersionHandler | Get a ModelVersion by ID |
| POST /api/v1/model_registry/{model_registry_id}/model_versions | CreateModelVersionHandler | Create a ModelVersion entity |
| PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | UpdateModelVersionHandler | Update a ModelVersion entity by ID |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID |
| POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel |

### Sample local calls
```
Expand Down Expand Up @@ -165,4 +166,22 @@ curl -i -X PATCH "http://localhost:4000/api/v1/model_registry/model-registry/mod
```
# GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions
curl -i localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions
```
```
# POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/registered_models/1/versions" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
"string_value": "val"
}
},
"description": "New description",
"externalId": "9927",
"name": "ModelVersion One",
"state": "LIVE",
"author": "alex"
}}'
```
1 change: 1 addition & 0 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ func (app *App) Routes() http.Handler {
router.POST(RegisteredModelListPath, app.AttachRESTClient(app.CreateRegisteredModelHandler))
router.PATCH(RegisteredModelPath, app.AttachRESTClient(app.UpdateRegisteredModelHandler))
router.GET(RegisteredModelVersionsPath, app.AttachRESTClient(app.GetAllModelVersionsForRegisteredModelHandler))
router.POST(RegisteredModelVersionsPath, app.AttachRESTClient(app.CreateModelVersionForRegisteredModelHandler))

router.GET(ModelVersionPath, app.AttachRESTClient(app.GetModelVersionHandler))
router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler))
Expand Down
12 changes: 12 additions & 0 deletions clients/ui/bff/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,15 @@ func (app *App) ReadJSON(w http.ResponseWriter, r *http.Request, dst any) error

return nil
}

func ParseURLTemplate(tmpl string, params map[string]string) string {
args := make([]string, len(params)*2)

for k, v := range params {
args = append(args, ":"+k, v)
}

r := strings.NewReplacer(args...)

return r.Replace(tmpl)
}
21 changes: 21 additions & 0 deletions clients/ui/bff/api/helpers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package api

import (
"github.com/stretchr/testify/assert"
"testing"
)

func TestParseURLTemplate(t *testing.T) {
expected := "/v1/model_registry/demo-registry/registered_models/111-222-333/versions"
tmpl := "/v1/model_registry/:model_registry_id/registered_models/:registered_model_id/versions"
params := map[string]string{"model_registry_id": "demo-registry", "registered_model_id": "111-222-333"}

actual := ParseURLTemplate(tmpl, params)

assert.Equal(t, expected, actual)
}

func TestParseURLTemplateWhenEmpty(t *testing.T) {
actual := ParseURLTemplate("", nil)
assert.Empty(t, actual)
}
51 changes: 51 additions & 0 deletions clients/ui/bff/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,54 @@ func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWrit
app.serverErrorResponse(w, r, err)
}
}

func (app *App) CreateModelVersionForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

var envelope ModelVersionEnvelope
if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil {
app.badRequestResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error()))
return
}

data := *envelope.Data

if err := validation.ValidateModelVersion(data); err != nil {
app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error()))
}

jsonData, err := json.Marshal(data)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error marshaling model to JSON: %w", err))
}

createdVersion, err := app.modelRegistryClient.CreateModelVersionForRegisteredModel(client, ps.ByName(RegisteredModelId), jsonData)
if err != nil {
var httpErr *integrations.HTTPError
if errors.As(err, &httpErr) {
app.errorResponse(w, r, httpErr)
} else {
app.serverErrorResponse(w, r, err)
}
return
}

if createdVersion == nil {
app.serverErrorResponse(w, r, fmt.Errorf("created model version is nil"))
return
}

responseBody := ModelVersionEnvelope{
Data: createdVersion,
}

w.Header().Set("Location", ParseURLTemplate(ModelVersionPath, map[string]string{ModelRegistryId: ps.ByName(ModelRegistryId), ModelVersionId: createdVersion.GetId()}))
err = app.WriteJSON(w, http.StatusCreated, responseBody, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
57 changes: 57 additions & 0 deletions clients/ui/bff/api/registered_models_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"encoding/json"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internals/mocks"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -224,3 +225,59 @@ func TestGetAllModelVersionsForRegisteredModelHandler(t *testing.T) {
assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken)
assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items))
}

func TestCreateModelVersionForRegisteredModelHandler(t *testing.T) {
mockMRClient, _ := mocks.NewModelRegistryClient(nil)
mockClient := new(mocks.MockHTTPClient)

testApp := App{
modelRegistryClient: mockMRClient,
}

newVersion := openapi.NewModelVersion("Model One", "1")
reqEnvelope := ModelVersionEnvelope{Data: newVersion}

reqJSON, err := json.Marshal(reqEnvelope)
assert.NoError(t, err)

reqBody := bytes.NewReader(reqJSON)

req, err := http.NewRequest(http.MethodPost, "/api/v1/model_registry/model-registry/registered_models/1/versions", reqBody)
assert.NoError(t, err)

ctx := context.WithValue(req.Context(), httpClientKey, mockClient)
req = req.WithContext(ctx)

rr := httptest.NewRecorder()

ps := httprouter.Params{
httprouter.Param{
Key: ModelRegistryId,
Value: "model-registry",
},
httprouter.Param{
Key: RegisteredModelId,
Value: "1",
},
}

testApp.CreateModelVersionForRegisteredModelHandler(rr, req, ps)
rs := rr.Result()

defer rs.Body.Close()

body, err := io.ReadAll(rs.Body)
assert.NoError(t, err)
var actual ModelVersionEnvelope
err = json.Unmarshal(body, &actual)
assert.NoError(t, err)

assert.Equal(t, http.StatusCreated, rr.Code)

expectedVersion := mocks.GetModelVersionMocks()[0]

expected := ModelVersionEnvelope{Data: &expectedVersion}

assert.Equal(t, expected.Data.Name, actual.Data.Name)
assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_versions/1")
}
21 changes: 21 additions & 0 deletions clients/ui/bff/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type RegisteredModelInterface interface {
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

type RegisteredModel struct {
Expand Down Expand Up @@ -113,3 +114,23 @@ func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInter

return &model, nil
}

func (m RegisteredModel) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
return nil, err
}

responseData, err := client.POST(path, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("error posting model version: %w", err)
}

var model openapi.ModelVersion
if err := json.Unmarshal(responseData, &model); err != nil {
return nil, fmt.Errorf("error decoding response data: %w", err)
}

return &model, nil
}
28 changes: 28 additions & 0 deletions clients/ui/bff/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,31 @@ func TestGetAllModelVersions(t *testing.T) {

mockClient.AssertExpectations(t)
}

func TestCreateModelVersionForRegisteredModel(t *testing.T) {
gofakeit.Seed(0)

expected := mocks.GenerateMockModelVersion()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

registeredModel := RegisteredModel{}

path, err := url.JoinPath(registeredModelPath, "1", versionsPath)
assert.NoError(t, err)
mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodPost, path, mock.Anything).Return(mockData, nil)

jsonInput, err := json.Marshal(expected)
assert.NoError(t, err)

actual, err := registeredModel.CreateModelVersionForRegisteredModel(mockClient, "1", jsonInput)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, *expected.Author, *actual.Author)
assert.Equal(t, expected.RegisteredModelId, actual.RegisteredModelId)

mockClient.AssertExpectations(t)
}
5 changes: 5 additions & 0 deletions clients/ui/bff/internals/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,8 @@ func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPCl
mockData := GetModelVersionListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) {
mockData := GetModelVersionMocks()[0]
return &mockData, nil
}

0 comments on commit 63ff3e2

Please sign in to comment.