Skip to content

Commit

Permalink
Decouple MLMD types setup from Go core library (kubeflow#267)
Browse files Browse the repository at this point in the history
* implement Go core changes

* add test coverage

* refactoring

* minor message change

* nit: renamed function
  • Loading branch information
tarilabs authored Jan 26, 2024
1 parent 46d1f1c commit 655e4da
Show file tree
Hide file tree
Showing 5 changed files with 243 additions and 144 deletions.
5 changes: 5 additions & 0 deletions cmd/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"time"

"github.com/golang/glog"
"github.com/opendatahub-io/model-registry/internal/mlmdtypes"
"github.com/opendatahub-io/model-registry/internal/server/openapi"
"github.com/opendatahub-io/model-registry/pkg/core"
"github.com/spf13/cobra"
Expand Down Expand Up @@ -48,6 +49,10 @@ func runProxyServer(cmd *cobra.Command, args []string) error {
defer conn.Close()
glog.Infof("connected to MLMD server")

_, err = mlmdtypes.CreateMLMDTypes(conn)
if err != nil {
return fmt.Errorf("error creating MLMD types: %v", err)
}
service, err := core.NewModelRegistryService(conn)
if err != nil {
return fmt.Errorf("error creating core service: %v", err)
Expand Down
5 changes: 5 additions & 0 deletions internal/apiutils/api_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ func ZeroIfNil[T any](input *T) T {
return *new(T)
}

// of returns a pointer to the provided literal/const input
func Of[E any](e E) *E {
return &e
}

func BuildListOption(pageSize string, orderBy model.OrderByField, sortOrder model.SortOrder, nextPageToken string) (api.ListOptions, error) {
var pageSizeInt32 *int32
if pageSize != "" {
Expand Down
144 changes: 144 additions & 0 deletions internal/mlmdtypes/mlmdtypes.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
package mlmdtypes

import (
"context"
"fmt"

"github.com/opendatahub-io/model-registry/internal/apiutils"
"github.com/opendatahub-io/model-registry/internal/constants"
"github.com/opendatahub-io/model-registry/internal/ml_metadata/proto"
"google.golang.org/grpc"
)

var (
registeredModelTypeName = apiutils.Of(constants.RegisteredModelTypeName)
modelVersionTypeName = apiutils.Of(constants.ModelVersionTypeName)
modelArtifactTypeName = apiutils.Of(constants.ModelArtifactTypeName)
servingEnvironmentTypeName = apiutils.Of(constants.ServingEnvironmentTypeName)
inferenceServiceTypeName = apiutils.Of(constants.InferenceServiceTypeName)
serveModelTypeName = apiutils.Of(constants.ServeModelTypeName)
canAddFields = apiutils.Of(true)
)

// Utility method that created the necessary Model Registry's logical-model types
// as the necessary MLMD's Context, Artifact, Execution types etc. in the underlying MLMD service
func CreateMLMDTypes(cc grpc.ClientConnInterface) (map[string]int64, error) {
client := proto.NewMetadataStoreServiceClient(cc)

registeredModelReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: registeredModelTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"state": proto.PropertyType_STRING,
},
},
}

modelVersionReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: modelVersionTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_name": proto.PropertyType_STRING,
"version": proto.PropertyType_STRING,
"author": proto.PropertyType_STRING,
"state": proto.PropertyType_STRING,
},
},
}

modelArtifactReq := proto.PutArtifactTypeRequest{
CanAddFields: canAddFields,
ArtifactType: &proto.ArtifactType{
Name: modelArtifactTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_format_name": proto.PropertyType_STRING,
"model_format_version": proto.PropertyType_STRING,
"storage_key": proto.PropertyType_STRING,
"storage_path": proto.PropertyType_STRING,
"service_account_name": proto.PropertyType_STRING,
},
},
}

servingEnvironmentReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: servingEnvironmentTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
},
},
}

inferenceServiceReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: inferenceServiceTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_version_id": proto.PropertyType_INT,
"registered_model_id": proto.PropertyType_INT,
// same information tracked using ParentContext association
"serving_environment_id": proto.PropertyType_INT,
"runtime": proto.PropertyType_STRING,
"desired_state": proto.PropertyType_STRING,
},
},
}

serveModelReq := proto.PutExecutionTypeRequest{
CanAddFields: canAddFields,
ExecutionType: &proto.ExecutionType{
Name: serveModelTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_version_id": proto.PropertyType_INT,
},
},
}

registeredModelResp, err := client.PutContextType(context.Background(), &registeredModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *registeredModelTypeName, err)
}

modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *modelVersionTypeName, err)
}

modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", *modelArtifactTypeName, err)
}

servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *servingEnvironmentTypeName, err)
}

inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *inferenceServiceTypeName, err)
}

serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq)
if err != nil {
return nil, fmt.Errorf("error setting up execution type %s: %v", *serveModelTypeName, err)
}

typesMap := map[string]int64{
constants.RegisteredModelTypeName: registeredModelResp.GetTypeId(),
constants.ModelVersionTypeName: modelVersionResp.GetTypeId(),
constants.ModelArtifactTypeName: modelArtifactResp.GetTypeId(),
constants.ServingEnvironmentTypeName: servingEnvironmentResp.GetTypeId(),
constants.InferenceServiceTypeName: inferenceServiceResp.GetTypeId(),
constants.ServeModelTypeName: serveModelResp.GetTypeId(),
}
return typesMap, nil
}
172 changes: 55 additions & 117 deletions pkg/core/core.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,12 @@ import (
)

var (
registeredModelTypeName = of(constants.RegisteredModelTypeName)
modelVersionTypeName = of(constants.ModelVersionTypeName)
modelArtifactTypeName = of(constants.ModelArtifactTypeName)
servingEnvironmentTypeName = of(constants.ServingEnvironmentTypeName)
inferenceServiceTypeName = of(constants.InferenceServiceTypeName)
serveModelTypeName = of(constants.ServeModelTypeName)
canAddFields = of(true)
registeredModelTypeName = apiutils.Of(constants.RegisteredModelTypeName)
modelVersionTypeName = apiutils.Of(constants.ModelVersionTypeName)
modelArtifactTypeName = apiutils.Of(constants.ModelArtifactTypeName)
servingEnvironmentTypeName = apiutils.Of(constants.ServingEnvironmentTypeName)
inferenceServiceTypeName = apiutils.Of(constants.InferenceServiceTypeName)
serveModelTypeName = apiutils.Of(constants.ServeModelTypeName)
)

// ModelRegistryService is the core library of the model registry
Expand All @@ -36,137 +35,81 @@ type ModelRegistryService struct {
}

// NewModelRegistryService creates a new instance of the ModelRegistryService, initializing it with the provided gRPC client connection.
// It sets up the necessary context and artifact types in the MLMD service.
// It _assumes_ the necessary MLMD's Context, Artifact, Execution types etc. are already setup in the underlying MLMD service.
//
// Parameters:
// - cc: A gRPC client connection to the underlying MLMD service
func NewModelRegistryService(cc grpc.ClientConnInterface) (api.ModelRegistryApi, error) {
typesMap, err := BuildTypesMap(cc)
if err != nil { // early return in case type Ids cannot be retrieved
return nil, err
}

client := proto.NewMetadataStoreServiceClient(cc)

// Setup the needed Type instances if not existing already
return &ModelRegistryService{
mlmdClient: client,
typesMap: typesMap,
openapiConv: &generated.OpenAPIConverterImpl{},
mapper: mapper.NewMapper(typesMap),
}, nil
}

registeredModelReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: registeredModelTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"state": proto.PropertyType_STRING,
},
},
}
func BuildTypesMap(cc grpc.ClientConnInterface) (map[string]int64, error) {
client := proto.NewMetadataStoreServiceClient(cc)

modelVersionReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: modelVersionTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_name": proto.PropertyType_STRING,
"version": proto.PropertyType_STRING,
"author": proto.PropertyType_STRING,
"state": proto.PropertyType_STRING,
},
},
registeredModelContextTypeReq := proto.GetContextTypeRequest{
TypeName: registeredModelTypeName,
}

modelArtifactReq := proto.PutArtifactTypeRequest{
CanAddFields: canAddFields,
ArtifactType: &proto.ArtifactType{
Name: modelArtifactTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_format_name": proto.PropertyType_STRING,
"model_format_version": proto.PropertyType_STRING,
"storage_key": proto.PropertyType_STRING,
"storage_path": proto.PropertyType_STRING,
"service_account_name": proto.PropertyType_STRING,
},
},
registeredModelResp, err := client.GetContextType(context.Background(), &registeredModelContextTypeReq)
if err != nil {
return nil, fmt.Errorf("error getting context type %s: %v", *registeredModelTypeName, err)
}

servingEnvironmentReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: servingEnvironmentTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
},
},
modelVersionContextTypeReq := proto.GetContextTypeRequest{
TypeName: modelVersionTypeName,
}

inferenceServiceReq := proto.PutContextTypeRequest{
CanAddFields: canAddFields,
ContextType: &proto.ContextType{
Name: inferenceServiceTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_version_id": proto.PropertyType_INT,
"registered_model_id": proto.PropertyType_INT,
// same information tracked using ParentContext association
"serving_environment_id": proto.PropertyType_INT,
"runtime": proto.PropertyType_STRING,
"desired_state": proto.PropertyType_STRING,
},
},
modelVersionResp, err := client.GetContextType(context.Background(), &modelVersionContextTypeReq)
if err != nil {
return nil, fmt.Errorf("error getting context type %s: %v", *modelVersionTypeName, err)
}

serveModelReq := proto.PutExecutionTypeRequest{
CanAddFields: canAddFields,
ExecutionType: &proto.ExecutionType{
Name: serveModelTypeName,
Properties: map[string]proto.PropertyType{
"description": proto.PropertyType_STRING,
"model_version_id": proto.PropertyType_INT,
},
},
modelArtifactArtifactTypeReq := proto.GetArtifactTypeRequest{
TypeName: modelArtifactTypeName,
}

registeredModelResp, err := client.PutContextType(context.Background(), &registeredModelReq)
modelArtifactResp, err := client.GetArtifactType(context.Background(), &modelArtifactArtifactTypeReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *registeredModelTypeName, err)
return nil, fmt.Errorf("error getting artifact type %s: %v", *modelArtifactTypeName, err)
}

modelVersionResp, err := client.PutContextType(context.Background(), &modelVersionReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *modelVersionTypeName, err)
servingEnvironmentContextTypeReq := proto.GetContextTypeRequest{
TypeName: servingEnvironmentTypeName,
}

modelArtifactResp, err := client.PutArtifactType(context.Background(), &modelArtifactReq)
servingEnvironmentResp, err := client.GetContextType(context.Background(), &servingEnvironmentContextTypeReq)
if err != nil {
return nil, fmt.Errorf("error setting up artifact type %s: %v", *modelArtifactTypeName, err)
return nil, fmt.Errorf("error getting context type %s: %v", *servingEnvironmentTypeName, err)
}

servingEnvironmentResp, err := client.PutContextType(context.Background(), &servingEnvironmentReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *servingEnvironmentTypeName, err)
inferenceServiceContextTypeReq := proto.GetContextTypeRequest{
TypeName: inferenceServiceTypeName,
}

inferenceServiceResp, err := client.PutContextType(context.Background(), &inferenceServiceReq)
inferenceServiceResp, err := client.GetContextType(context.Background(), &inferenceServiceContextTypeReq)
if err != nil {
return nil, fmt.Errorf("error setting up context type %s: %v", *inferenceServiceTypeName, err)
return nil, fmt.Errorf("error getting context type %s: %v", *inferenceServiceTypeName, err)
}

serveModelResp, err := client.PutExecutionType(context.Background(), &serveModelReq)
serveModelExecutionReq := proto.GetExecutionTypeRequest{
TypeName: serveModelTypeName,
}
serveModelResp, err := client.GetExecutionType(context.Background(), &serveModelExecutionReq)
if err != nil {
return nil, fmt.Errorf("error setting up execution type %s: %v", *serveModelTypeName, err)
return nil, fmt.Errorf("error getting execution type %s: %v", *serveModelTypeName, err)
}

typesMap := map[string]int64{
constants.RegisteredModelTypeName: registeredModelResp.GetTypeId(),
constants.ModelVersionTypeName: modelVersionResp.GetTypeId(),
constants.ModelArtifactTypeName: modelArtifactResp.GetTypeId(),
constants.ServingEnvironmentTypeName: servingEnvironmentResp.GetTypeId(),
constants.InferenceServiceTypeName: inferenceServiceResp.GetTypeId(),
constants.ServeModelTypeName: serveModelResp.GetTypeId(),
}
return &ModelRegistryService{
mlmdClient: client,
typesMap: typesMap,
openapiConv: &generated.OpenAPIConverterImpl{},
mapper: mapper.NewMapper(typesMap),
}, nil
constants.RegisteredModelTypeName: registeredModelResp.ContextType.GetId(),
constants.ModelVersionTypeName: modelVersionResp.ContextType.GetId(),
constants.ModelArtifactTypeName: modelArtifactResp.ArtifactType.GetId(),
constants.ServingEnvironmentTypeName: servingEnvironmentResp.ContextType.GetId(),
constants.InferenceServiceTypeName: inferenceServiceResp.ContextType.GetId(),
constants.ServeModelTypeName: serveModelResp.ExecutionType.GetId(),
}
return typesMap, nil
}

// REGISTERED MODELS
Expand Down Expand Up @@ -1420,8 +1363,3 @@ func (serv *ModelRegistryService) GetServeModels(listOptions api.ListOptions, in
}
return &toReturn, nil
}

// of returns a pointer to the provided literal/const input
func of[E any](e E) *E {
return &e
}
Loading

0 comments on commit 655e4da

Please sign in to comment.