Skip to content

Commit

Permalink
Adds an API endpoint to create a ModelArtifact for a given ModelVersion
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Sep 17, 2024
1 parent 7536055 commit 1df0ad5
Show file tree
Hide file tree
Showing 9 changed files with 176 additions and 2 deletions.
19 changes: 19 additions & 0 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ make docker-build
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID |
| POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel |
| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID |
| POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | CreateModelArtifactByModelVersion | Create a ModelArtifact entity for a specific ModelVersion |

### Sample local calls
```
Expand Down Expand Up @@ -189,4 +190,22 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts
curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts
```
```
# POST /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts
curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts" \
-H "Content-Type: application/json" \
-d '{ "data": {
"customProperties": {
"my-label9": {
"metadataType": "MetadataStringValue",
"string_value": "val"
}
},
"description": "New description",
"externalId": "9927",
"name": "ModelArtifact One",
"state": "LIVE",
"artifactType": "TYPE_ONE"
}}'
```
4 changes: 4 additions & 0 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ const (
ModelRegistryId = "model_registry_id"
RegisteredModelId = "registered_model_id"
ModelVersionId = "model_version_id"
ModelArtifactId = "model_artifact_id"
HealthCheckPath = PathPrefix + "/healthcheck"
ModelRegistryListPath = PathPrefix + "/model_registry"
ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId
Expand All @@ -27,6 +28,8 @@ const (
ModelVersionListPath = ModelRegistryPath + "/model_versions"
ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId
ModelVersionArtifactListPath = ModelVersionPath + "/artifacts"
ModelArtifactListPath = ModelRegistryPath + "/model_artifacts"
ModelArtifactPath = ModelArtifactListPath + "/:" + ModelArtifactId
)

type App struct {
Expand Down Expand Up @@ -91,6 +94,7 @@ func (app *App) Routes() http.Handler {
router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler))
router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler))
router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler))
router.POST(ModelVersionArtifactListPath, app.AttachRESTClient(app.CreateModelArtifactByModelVersionHandler))

// Kubernetes client routes
router.GET(ModelRegistryListPath, app.ModelRegistryHandler)
Expand Down
58 changes: 58 additions & 0 deletions clients/ui/bff/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
type ModelVersionEnvelope Envelope[*openapi.ModelVersion, None]
type ModelVersionListEnvelope Envelope[*openapi.ModelVersionList, None]
type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None]
type ModelArtifactEnvelope Envelope[*openapi.ModelArtifact, None]

func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
Expand Down Expand Up @@ -172,3 +173,60 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter,
app.serverErrorResponse(w, r, err)
}
}

func (app *App) CreateModelArtifactByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

var envelope ModelArtifactEnvelope
if err := json.NewDecoder(r.Body).Decode(&envelope); err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error decoding JSON:: %v", err.Error()))
return
}

data := *envelope.Data

if err := validation.ValidateModelArtifact(data); err != nil {
app.badRequestResponse(w, r, fmt.Errorf("validation error:: %v", err.Error()))
return
}

jsonData, err := json.Marshal(data)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error marshaling ModelVersion to JSON: %w", err))
return
}

createdArtifact, err := app.modelRegistryClient.CreateModelArtifactByModelVersion(client, ps.ByName(ModelVersionId), jsonData)
if err != nil {
var httpErr *integrations.HTTPError
if errors.As(err, &httpErr) {
app.errorResponse(w, r, httpErr)
} else {
app.serverErrorResponse(w, r, err)
}
return
}

if createdArtifact == nil {
app.serverErrorResponse(w, r, fmt.Errorf("created ModelArtifact is nil"))
return
}

response := ModelArtifactEnvelope{
Data: createdArtifact,
}

w.Header().Set("Location", ParseURLTemplate(ModelArtifactPath, map[string]string{
ModelRegistryId: ps.ByName(ModelRegistryId),
ModelArtifactId: createdArtifact.GetId(),
}))
err = app.WriteJSON(w, http.StatusCreated, response, nil)
if err != nil {
app.serverErrorResponse(w, r, fmt.Errorf("error writing JSON"))
return
}
}
18 changes: 18 additions & 0 deletions clients/ui/bff/api/model_versions_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,21 @@ func TestGetAllModelArtifactsByModelVersionHandler(t *testing.T) {
assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken)
assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items))
}

func TestCreateModelArtifactByModelVersionHandler(t *testing.T) {
data := mocks.GetModelArtifactMocks()[0]
expected := ModelArtifactEnvelope{Data: &data}

artifact := openapi.ModelArtifact{
Name: openapi.PtrString("Artifact One"),
ArtifactType: "ARTIFACT_TYPE_ONE",
}
body := ModelArtifactEnvelope{Data: &artifact}

actual, rs, err := setupApiTest[ModelArtifactEnvelope](http.MethodPost, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", body)
assert.NoError(t, err)

assert.Equal(t, http.StatusCreated, rs.StatusCode)
assert.Equal(t, expected.Data.GetArtifactType(), actual.Data.GetArtifactType())
assert.Equal(t, rs.Header.Get("Location"), "/api/v1/model_registry/model-registry/model_artifacts/1")
}
24 changes: 22 additions & 2 deletions clients/ui/bff/data/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type ModelVersionInterface interface {
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error)
CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error)
}

type ModelVersion struct {
Expand Down Expand Up @@ -58,7 +59,7 @@ func (v ModelVersion) CreateModelVersion(client integrations.HTTPClientInterface
return &model, nil
}

func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) {
func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) {
path, err := url.JoinPath(modelVersionPath, id)

if err != nil {
Expand All @@ -78,7 +79,7 @@ func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface
return &model, nil
}

func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)

if err != nil {
Expand All @@ -97,3 +98,22 @@ func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPCl

return &model, nil
}

func (v ModelVersion) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)
if err != nil {
return nil, err
}

responseData, err := client.POST(path, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("error posting model artifact: %w", err)
}

var model openapi.ModelArtifact
if err := json.Unmarshal(responseData, &model); err != nil {
return nil, fmt.Errorf("error decoding response data: %w", err)
}

return &model, nil
}
26 changes: 26 additions & 0 deletions clients/ui/bff/data/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,29 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
assert.Equal(t, expected.PageSize, actual.PageSize)
assert.Equal(t, len(expected.Items), len(actual.Items))
}

func TestCreateModelArtifactByModelVersion(t *testing.T) {
gofakeit.Seed(0)

expected := mocks.GenerateMockModelArtifact()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath)
assert.NoError(t, err)

mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodPost, path, mock.Anything).Return(mockData, nil)

jsonInnput, err := json.Marshal(expected)
assert.NoError(t, err)

actual, err := modelVersion.CreateModelArtifactByModelVersion(mockClient, "1", jsonInnput)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.ArtifactType, actual.ArtifactType)
}
4 changes: 4 additions & 0 deletions clients/ui/bff/internals/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integra
mockData := GetModelArtifactListMock()
return &mockData, nil
}
func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) {
mockData := GetModelArtifactMocks()[0]
return &mockData, nil
}
8 changes: 8 additions & 0 deletions clients/ui/bff/validation/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ func ValidateModelVersion(input openapi.ModelVersion) error {
// Add more field validations as required
return nil
}

func ValidateModelArtifact(input openapi.ModelArtifact) error {
if input.GetName() == "" {
return errors.New("name cannot be empty")
}
// Add more field validations as required
return nil
}
17 changes: 17 additions & 0 deletions clients/ui/bff/validation/validation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,20 @@ func TestValidateModelVersion(t *testing.T) {

validateTestSpecs(t, specs, ValidateModelVersion)
}

func TestValidateModel(t *testing.T) {
specs := []testSpec[openapi.ModelArtifact]{
{
name: "Empty name",
input: openapi.ModelArtifact{Name: openapi.PtrString("")},
wantErr: true,
},
{
name: "Valid name",
input: openapi.ModelArtifact{Name: openapi.PtrString("ValidName")},
wantErr: false,
},
}

validateTestSpecs(t, specs, ValidateModelArtifact)
}

0 comments on commit 1df0ad5

Please sign in to comment.