Skip to content

Commit

Permalink
Adds an API endpoint to get all ModelArtifacts for a given ModelVersion
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Sep 17, 2024
1 parent 0e5f6e2 commit 7536055
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 17 deletions.
5 changes: 5 additions & 0 deletions clients/ui/bff/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ make docker-build
| PATCH /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id} | UpdateModelVersionHandler | Update a ModelVersion entity by ID |
| GET /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | GetAllModelVersionsForRegisteredModelHandler | Get all ModelVersion entities by RegisteredModel ID |
| POST /v1/model_registry/{model_registry_id}/registered_models/{registered_model_id}/versions | CreateModelVersionForRegisteredModelHandler | Create a ModelVersion entity for a specific RegisteredModel |
| GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts | GetAllModelArtifactsByModelVersionHandler | Get all ModelArtifact entities by ModelVersion ID |

### Sample local calls
```
Expand Down Expand Up @@ -184,4 +185,8 @@ curl -i -X POST "http://localhost:4000/api/v1/model_registry/model-registry/regi
"state": "LIVE",
"author": "alex"
}}'
```
```
# GET /api/v1/model_registry/{model_registry_id}/model_versions/{model_version_id}/artifacts
curl -i http://localhost:4000/api/v1/model_registry/model-registry/model_versions/1/artifacts
```
28 changes: 15 additions & 13 deletions clients/ui/bff/api/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import (
)

const (
Version = "1.0.0"
PathPrefix = "/api/v1"
ModelRegistryId = "model_registry_id"
RegisteredModelId = "registered_model_id"
ModelVersionId = "model_version_id"
HealthCheckPath = PathPrefix + "/healthcheck"
ModelRegistryListPath = PathPrefix + "/model_registry"
ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId
RegisteredModelListPath = ModelRegistryPath + "/registered_models"
RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId
RegisteredModelVersionsPath = RegisteredModelPath + "/versions"
ModelVersionListPath = ModelRegistryPath + "/model_versions"
ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId
Version = "1.0.0"
PathPrefix = "/api/v1"
ModelRegistryId = "model_registry_id"
RegisteredModelId = "registered_model_id"
ModelVersionId = "model_version_id"
HealthCheckPath = PathPrefix + "/healthcheck"
ModelRegistryListPath = PathPrefix + "/model_registry"
ModelRegistryPath = ModelRegistryListPath + "/:" + ModelRegistryId
RegisteredModelListPath = ModelRegistryPath + "/registered_models"
RegisteredModelPath = RegisteredModelListPath + "/:" + RegisteredModelId
RegisteredModelVersionsPath = RegisteredModelPath + "/versions"
ModelVersionListPath = ModelRegistryPath + "/model_versions"
ModelVersionPath = ModelVersionListPath + "/:" + ModelVersionId
ModelVersionArtifactListPath = ModelVersionPath + "/artifacts"
)

type App struct {
Expand Down Expand Up @@ -89,6 +90,7 @@ func (app *App) Routes() http.Handler {
router.GET(ModelVersionPath, app.AttachRESTClient(app.GetModelVersionHandler))
router.POST(ModelVersionListPath, app.AttachRESTClient(app.CreateModelVersionHandler))
router.PATCH(ModelVersionPath, app.AttachRESTClient(app.UpdateModelVersionHandler))
router.GET(ModelVersionArtifactListPath, app.AttachRESTClient(app.GetAllModelArtifactsByModelVersionHandler))

// Kubernetes client routes
router.GET(ModelRegistryListPath, app.ModelRegistryHandler)
Expand Down
24 changes: 24 additions & 0 deletions clients/ui/bff/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

type ModelVersionEnvelope Envelope[*openapi.ModelVersion, None]
type ModelVersionListEnvelope Envelope[*openapi.ModelVersionList, None]
type ModelArtifactListEnvelope Envelope[*openapi.ModelArtifactList, None]

func (app *App) GetModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
Expand Down Expand Up @@ -148,3 +149,26 @@ func (app *App) UpdateModelVersionHandler(w http.ResponseWriter, r *http.Request
return
}
}

func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId))
if err != nil {
app.serverErrorResponse(w, r, err)
return
}

result := ModelArtifactListEnvelope{
Data: data,
}

err = app.WriteJSON(w, http.StatusOK, result, nil)
if err != nil {
app.serverErrorResponse(w, r, err)
}
}
14 changes: 14 additions & 0 deletions clients/ui/bff/api/model_versions_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,17 @@ func TestUpdateModelVersionHandler(t *testing.T) {
assert.Equal(t, http.StatusOK, rs.StatusCode)
assert.Equal(t, expected.Data.Name, actual.Data.Name)
}

func TestGetAllModelArtifactsByModelVersionHandler(t *testing.T) {
data := mocks.GetModelArtifactListMock()
expected := ModelArtifactListEnvelope{Data: &data}

actual, rs, err := setupApiTest[ModelArtifactListEnvelope](http.MethodGet, "/api/v1/model_registry/model-registry/model_versions/1/artifacts", nil)
assert.NoError(t, err)

assert.Equal(t, http.StatusOK, rs.StatusCode)
assert.Equal(t, expected.Data.Size, actual.Data.Size)
assert.Equal(t, expected.Data.PageSize, actual.Data.PageSize)
assert.Equal(t, expected.Data.NextPageToken, actual.Data.NextPageToken)
assert.Equal(t, len(expected.Data.Items), len(actual.Data.Items))
}
22 changes: 22 additions & 0 deletions clients/ui/bff/data/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@ import (
)

const modelVersionPath = "/model_versions"
const artifactsByModelVersionPath = "/artifacts"

type ModelVersionInterface interface {
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error)
}

type ModelVersion struct {
Expand Down Expand Up @@ -75,3 +77,23 @@ func (m ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface

return &model, nil
}

func (m ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
if err != nil {
return nil, fmt.Errorf("error fetching model version artifacts: %w", err)
}

var model openapi.ModelArtifactList
if err := json.Unmarshal(responseData, &model); err != nil {
return nil, fmt.Errorf("error decoding response data: %w", err)
}

return &model, nil
}
26 changes: 26 additions & 0 deletions clients/ui/bff/data/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,29 @@ func TestUpdateModelVersion(t *testing.T) {

mockClient.AssertExpectations(t)
}

func TestGetModelArtifactsByModelVersion(t *testing.T) {
gofakeit.Seed(0)

expected := mocks.GenerateMockModelArtifactList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath)
assert.NoError(t, err)

mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1")
assert.NoError(t, err)

assert.NotNil(t, actual)
assert.Equal(t, expected.Size, actual.Size)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
assert.Equal(t, expected.PageSize, actual.PageSize)
assert.Equal(t, len(expected.Items), len(actual.Items))
}
5 changes: 5 additions & 0 deletions clients/ui/bff/internals/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,8 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in
mockData := GetModelVersionMocks()[0]
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
mockData := GetModelArtifactListMock()
return &mockData, nil
}
62 changes: 62 additions & 0 deletions clients/ui/bff/internals/mocks/static_data_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,65 @@ func GetModelVersionListMock() openapi.ModelVersionList {
Size: 2,
}
}

func GetModelArtifactMocks() []openapi.ModelArtifact {
artifact1 := openapi.ModelArtifact{
ArtifactType: "TYPE_ONE",
CustomProperties: newCustomProperties(),
Description: stringToPointer("This artifact can do more than you would expect"),
ExternalId: stringToPointer("1000001"),
Uri: stringToPointer("http://localhost/artifacts/1"),
State: stateToPointer(openapi.ARTIFACTSTATE_LIVE),
Name: stringToPointer("Artifact One"),
Id: stringToPointer("1"),
CreateTimeSinceEpoch: stringToPointer("1725282249921"),
LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"),
ModelFormatName: stringToPointer("ONNX"),
StorageKey: stringToPointer("key1"),
StoragePath: stringToPointer("/artifacts/1"),
ModelFormatVersion: stringToPointer("1.0.0"),
ServiceAccountName: stringToPointer("service-1"),
}

artifact2 := openapi.ModelArtifact{
ArtifactType: "TYPE_TWO",
CustomProperties: newCustomProperties(),
Description: stringToPointer("This artifact can do more than you would expect, but less than you would hope"),
ExternalId: stringToPointer("1000002"),
Uri: stringToPointer("http://localhost/artifacts/2"),
State: stateToPointer(openapi.ARTIFACTSTATE_PENDING),
Name: stringToPointer("Artifact Two"),
Id: stringToPointer("2"),
CreateTimeSinceEpoch: stringToPointer("1725282249921"),
LastUpdateTimeSinceEpoch: stringToPointer("1725282249921"),
ModelFormatName: stringToPointer("TensorFlow"),
StorageKey: stringToPointer("key2"),
StoragePath: stringToPointer("/artifacts/2"),
ModelFormatVersion: stringToPointer("1.0.0"),
ServiceAccountName: stringToPointer("service-2"),
}

return []openapi.ModelArtifact{artifact1, artifact2}
}

func GetModelArtifactListMock() openapi.ModelArtifactList {
return openapi.ModelArtifactList{
NextPageToken: "abcdefgh",
PageSize: 2,
Items: GetModelArtifactMocks(),
Size: 2,
}
}

func newCustomProperties() *map[string]openapi.MetadataValue {
result := map[string]openapi.MetadataValue{
"my-label9": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: "property9",
MetadataType: "string",
},
},
}

return &result
}
68 changes: 64 additions & 4 deletions clients/ui/bff/internals/mocks/types_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ func GenerateMockRegisteredModel() openapi.RegisteredModel {
ExternalId: stringToPointer(gofakeit.UUID()),
Name: gofakeit.Name(),
Id: stringToPointer(gofakeit.UUID()),
CreateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())),
LastUpdateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())),
CreateTimeSinceEpoch: randomEpochTime(),
LastUpdateTimeSinceEpoch: randomEpochTime(),
Owner: stringToPointer(gofakeit.Name()),
State: stateToPointer(openapi.RegisteredModelState(gofakeit.RandomString([]string{string(openapi.REGISTEREDMODELSTATE_LIVE), string(openapi.REGISTEREDMODELSTATE_ARCHIVED)}))),
}
Expand All @@ -57,8 +57,8 @@ func GenerateMockModelVersion() openapi.ModelVersion {
ExternalId: stringToPointer(gofakeit.UUID()),
Name: gofakeit.Name(),
Id: stringToPointer(gofakeit.UUID()),
CreateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())),
LastUpdateTimeSinceEpoch: stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())),
CreateTimeSinceEpoch: randomEpochTime(),
LastUpdateTimeSinceEpoch: randomEpochTime(),
Author: stringToPointer(gofakeit.Name()),
State: stateToPointer(openapi.ModelVersionState(gofakeit.RandomString([]string{string(openapi.MODELVERSIONSTATE_LIVE), string(openapi.MODELVERSIONSTATE_ARCHIVED)}))),
}
Expand All @@ -81,6 +81,66 @@ func GenerateMockModelVersionList() openapi.ModelVersionList {
}
}

func GenerateMockModelArtifact() openapi.ModelArtifact {
artifact := openapi.ModelArtifact{
ArtifactType: gofakeit.Word(),
CustomProperties: &map[string]openapi.MetadataValue{
"example_key": {
MetadataStringValue: &openapi.MetadataStringValue{
StringValue: gofakeit.Sentence(3),
MetadataType: "string",
},
},
},
Description: stringToPointer(gofakeit.Sentence(5)),
ExternalId: stringToPointer(gofakeit.UUID()),
Uri: stringToPointer(gofakeit.URL()),
State: randomArtifactState(),
Name: stringToPointer(gofakeit.Name()),
Id: stringToPointer(gofakeit.UUID()),
CreateTimeSinceEpoch: randomEpochTime(),
LastUpdateTimeSinceEpoch: randomEpochTime(),
ModelFormatName: stringToPointer(gofakeit.Name()),
StorageKey: stringToPointer(gofakeit.Word()),
StoragePath: stringToPointer("/" + gofakeit.Word() + "/" + gofakeit.Word()),
ModelFormatVersion: stringToPointer(gofakeit.AppVersion()),
ServiceAccountName: stringToPointer(gofakeit.Username()),
}
return artifact
}

func GenerateMockModelArtifactList() openapi.ModelArtifactList {
var artifacts []openapi.ModelArtifact

for i := 0; i < 2; i++ {
artifact := GenerateMockModelArtifact()
artifacts = append(artifacts, artifact)
}

return openapi.ModelArtifactList{
NextPageToken: gofakeit.UUID(),
PageSize: int32(gofakeit.Number(1, 20)),
Size: int32(len(artifacts)),
Items: artifacts,
}
}

func randomEpochTime() *string {
return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli()))
}

func randomArtifactState() *openapi.ArtifactState {
return stateToPointer(openapi.ArtifactState(gofakeit.RandomString([]string{
string(openapi.ARTIFACTSTATE_LIVE),
string(openapi.ARTIFACTSTATE_DELETED),
string(openapi.ARTIFACTSTATE_ABANDONED),
string(openapi.ARTIFACTSTATE_MARKED_FOR_DELETION),
string(openapi.ARTIFACTSTATE_PENDING),
string(openapi.ARTIFACTSTATE_REFERENCE),
string(openapi.ARTIFACTSTATE_UNKNOWN),
})))
}

func stateToPointer[T any](s T) *T {
return &s
}
Expand Down

0 comments on commit 7536055

Please sign in to comment.