[pull] main from kubeflow:main #140

Closed
wants to merge 8 commits into from
244 changes: 125 additions & 119 deletions clients/python/poetry.lock

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions clients/python/pyproject.toml
@@ -26,7 +26,7 @@ nest-asyncio = "^1.6.0"
# necessary for modern type annotations using pydantic on 3.9
eval-type-backport = "^0.2.0"

huggingface-hub = { version = ">=0.20.1,<0.26.0", optional = true }
huggingface-hub = { version = ">=0.20.1,<0.27.0", optional = true }

[tool.poetry.extras]
hf = ["huggingface-hub"]
@@ -44,7 +44,7 @@ sphinx-autobuild = ">=2021.3.14,<2025.0.0"
pytest = ">=7.4.2,<9.0.0"
coverage = { extras = ["toml"], version = "^7.3.2" }
pytest-cov = ">=4.1,<6.0"
ruff = ">=0.5.2,<0.7.0"
ruff = ">=0.5.2,<0.8.0"
mypy = "^1.7.0"
pytest-asyncio = ">=0.23.7,<0.25.0"
requests = "^2.32.2"
4 changes: 2 additions & 2 deletions csi/GET_STARTED.md
@@ -26,7 +26,7 @@ We assume all [prerequisites](#prerequisites) are satisfied at this point.

3. Set up a local deployment of *KServe* using the provided *KServe quick installation* script
```bash
curl -s "https://raw.githubusercontent.com/kserve/kserve/release-0.12/hack/quick_install.sh" | bash
curl -s "https://raw.githubusercontent.com/kserve/kserve/release-0.14/hack/quick_install.sh" | bash
```

4. Install *model registry* in the local cluster
@@ -257,5 +257,5 @@ EOF
If you do not have DNS, you can still curl with the ingress gateway external IP using the HOST Header.
```bash
SERVICE_HOSTNAME=$(kubectl get inferenceservice iris-model -n kserve-test -o jsonpath='{.status.url}' | cut -d "/" -f 3)
curl -v -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/iris-v1:predict" -d @/tmp/iris-input.json
curl -v -H "Host: ${SERVICE_HOSTNAME}" -H "Content-Type: application/json" "http://${INGRESS_HOST}:${INGRESS_PORT}/v1/models/iris-model:predict" -d @/tmp/iris-input.json
```
9 changes: 6 additions & 3 deletions csi/README.md
@@ -28,7 +28,7 @@ sequenceDiagram
U->>+MR: Register ML Model
MR-->>-U: Indexed Model
U->>U: Create InferenceService CR
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model>/<version>
Note right of U: The InferenceService should<br/>point to the model registry<br/>indexed model, e.g.,:<br/> model-registry://<model-registry-url>/<model>/<version>
KC->>KC: React to InferenceService creation
KC->>+MD: Create Model Deployment
MD->>+MRSI: Initialization (Download Model)
@@ -66,14 +66,17 @@ Which will create the executable under `bin/mr-storage-initializer`.

You can run `main.go` (without building the executable) by running:
```bash
./bin/mr-storage-initializer "model-registry://model/version" "./"
./bin/mr-storage-initializer "model-registry://model-registry-url/model/version" "./"
```

or directly running the `main.go` skipping the previous step:
```bash
make SOURCE_URI=model-registry://model/version DEST_PATH=./ run
make SOURCE_URI=model-registry://model-registry-url/model/version DEST_PATH=./ run
```

> [!NOTE]
> `model-registry-url` is optional; if it is not provided, the value of the `MODEL_REGISTRY_BASE_URL` environment variable is used.

> [!NOTE]
> A Model Registry service should be up and running at `localhost:8080`.

2 changes: 1 addition & 1 deletion csi/go.mod
@@ -4,7 +4,7 @@ go 1.21

require (
github.com/kserve/kserve v0.13.1
github.com/kubeflow/model-registry v0.2.8-alpha
github.com/kubeflow/model-registry v0.2.9
)

require (
4 changes: 2 additions & 2 deletions csi/go.sum
@@ -112,8 +112,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kserve/kserve v0.13.1 h1:MRszrN5pf1nNzBBoyTeBsoIYcbWvuve5G1pBwdKj9dI=
github.com/kserve/kserve v0.13.1/go.mod h1:l6fHejIVM3RYO9cD9Q0gQ4eriCz3lQaFIdcT05rMUbs=
github.com/kubeflow/model-registry v0.2.8-alpha h1:M1rdHHTlQ/QKZv3Wi9EZU+2heAe3YbaOLjb8uq9FggI=
github.com/kubeflow/model-registry v0.2.8-alpha/go.mod h1:luHw6hNjX7oim7PFiYxOEkSf3EGqZ/qZ5cUji5fKlA4=
github.com/kubeflow/model-registry v0.2.9 h1:+uOUeJwo9yzUObfqdgmQmGhrWWVi0mKGK8hK5UFBnFM=
github.com/kubeflow/model-registry v0.2.9/go.mod h1:QXBKUJhldJj6mB81+OBUIUYtajUzQw72zIoPbCSDriM=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
6 changes: 5 additions & 1 deletion csi/main.go
@@ -4,6 +4,7 @@ import (
"log"
"os"

"github.com/kubeflow/model-registry/csi/pkg/modelregistry"
"github.com/kubeflow/model-registry/csi/pkg/storage"
"github.com/kubeflow/model-registry/pkg/openapi"
)
@@ -38,7 +39,10 @@ func main() {
cfg := openapi.NewConfiguration()
cfg.Host = baseUrl
cfg.Scheme = scheme
provider, err := storage.NewModelRegistryProvider(cfg)

apiClient := modelregistry.NewAPIClient(cfg, sourceUri)

provider, err := storage.NewModelRegistryProvider(apiClient)
if err != nil {
log.Fatalf("Error initializing model registry provider: %v", err)
}
5 changes: 5 additions & 0 deletions csi/pkg/constants/constants.go
@@ -0,0 +1,5 @@
package constants

import kserve "github.com/kserve/kserve/pkg/agent/storage"

const MR kserve.Protocol = "model-registry://"
41 changes: 41 additions & 0 deletions csi/pkg/modelregistry/api_client.go
@@ -0,0 +1,41 @@
package modelregistry

import (
"context"
"log"
"strings"

"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

func NewAPIClient(cfg *openapi.Configuration, storageUri string) *openapi.APIClient {
client := openapi.NewAPIClient(cfg)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) < 2 {
return client
}

newCfg := openapi.NewConfiguration()
newCfg.Host = tokens[0]
newCfg.Scheme = cfg.Scheme

newClient := openapi.NewAPIClient(newCfg)

if len(tokens) == 2 {
// Check if the model registry service is available
_, _, err := newClient.ModelRegistryServiceAPI.GetRegisteredModels(context.Background()).Execute()
if err != nil {
log.Printf("Falling back to base url %s for model registry service", cfg.Host)

return client
}
}

return newClient
}
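
The host selection in `NewAPIClient` can be read as a three-way decision on the number of URI segments. With three segments the first is taken as the host. With two the first segment is ambiguous (either a host or a model name), so it is probed and the configured host is the fallback. With one the configured host is used. Below is a rough sketch of that decision with the network call replaced by an injected function. `chooseHost`, the `reachable` callback and the example hosts are hypothetical and stand in for the `GetRegisteredModels` probe.

```go
package main

import (
	"fmt"
	"strings"
)

// chooseHost returns the registry host to use for storageURI.
// reachable is a hypothetical stand-in for probing the registry API.
func chooseHost(storageURI, defaultHost string, reachable func(host string) bool) string {
	rest := strings.TrimPrefix(storageURI, "model-registry://")
	tokens := strings.SplitN(rest, "/", 3)

	switch len(tokens) {
	case 3:
		// {host}/{model}/{version}: the host is explicit, no probe needed.
		return tokens[0]
	case 2:
		// Either {host}/{model} or {model}/{version}: probe to disambiguate.
		if reachable(tokens[0]) {
			return tokens[0]
		}
		return defaultHost
	default:
		// {model} only: nothing to override.
		return defaultHost
	}
}

func main() {
	// Pretend only this host answers the probe.
	probe := func(host string) bool { return host == "mr.example:8080" }

	fmt.Println(chooseHost("model-registry://iris", "localhost:8080", probe))
	fmt.Println(chooseHost("model-registry://iris/v1", "localhost:8080", probe))
	fmt.Println(chooseHost("model-registry://mr.example:8080/iris", "localhost:8080", probe))
	fmt.Println(chooseHost("model-registry://mr.example:8080/iris/v1", "localhost:8080", probe))
}
```

Injecting the probe keeps the decision testable without a running registry. One consequence of the two-segment rule is that a URI such as `{model}/{version}` costs a failed request before the fallback is used.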
161 changes: 114 additions & 47 deletions csi/pkg/storage/modelregistry_provider.go
@@ -2,79 +2,81 @@ package storage

import (
"context"
"errors"
"fmt"
"log"
"regexp"
"strings"

kserve "github.com/kserve/kserve/pkg/agent/storage"
"github.com/kubeflow/model-registry/csi/pkg/constants"
"github.com/kubeflow/model-registry/pkg/openapi"
)

const MR kserve.Protocol = "model-registry://"
var (
_ kserve.Provider = (*ModelRegistryProvider)(nil)
ErrInvalidMRURI = errors.New("invalid model registry URI, use like model-registry://{dnsName}/{registeredModelName}/{versionName}")
ErrNoVersionAssociated = errors.New("no versions associated to registered model")
ErrNoArtifactAssociated = errors.New("no artifacts associated to model version")
ErrNoModelArtifact = errors.New("no model artifact found for model version")
ErrModelArtifactEmptyURI = errors.New("model artifact has empty URI")
ErrNoStorageURI = errors.New("there is no storageUri supplied")
ErrNoProtocolInSTorageURI = errors.New("there is no protocol specified for the storageUri")
ErrProtocolNotSupported = errors.New("protocol not supported for storageUri")
ErrFetchingModelVersion = errors.New("error fetching model version")
ErrFetchingModelVersions = errors.New("error fetching model versions")
)

type ModelRegistryProvider struct {
Client *openapi.APIClient
Providers map[kserve.Protocol]kserve.Provider
}

func NewModelRegistryProvider(cfg *openapi.Configuration) (*ModelRegistryProvider, error) {
client := openapi.NewAPIClient(cfg)

func NewModelRegistryProvider(client *openapi.APIClient) (*ModelRegistryProvider, error) {
return &ModelRegistryProvider{
Client: client,
Providers: map[kserve.Protocol]kserve.Provider{},
}, nil
}

var _ kserve.Provider = (*ModelRegistryProvider)(nil)

// storageUri formatted like model-registry://{registeredModelName}/{versionName}
// storageUri formatted like model-registry://{modelRegistryUrl}/{registeredModelName}/{versionName}
func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s", modelName, storageUri, modelDir)

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(MR))
tokens := strings.SplitN(mrUri, "/", 2)
log.Printf("Download model indexed in model registry: modelName=%s, storageUri=%s, modelDir=%s",
modelName,
storageUri,
modelDir,
)

if len(tokens) == 0 || len(tokens) > 2 {
return fmt.Errorf("invalid model registry URI, use like model-registry://{registeredModelName}/{versionName}")
registeredModelName, versionName, err := p.parseModelVersion(storageUri)
if err != nil {
return err
}

registeredModelName := tokens[0]
var versionName *string
if len(tokens) == 2 {
versionName = &tokens[1]
}
log.Printf("Parsed storageUri=%s as: modelRegistryUrl=%s, registeredModelName=%s, versionName=%v",
storageUri,
p.Client.GetConfig().Host,
registeredModelName,
versionName,
)

log.Printf("Fetching model: registeredModelName=%s, versionName=%v", registeredModelName, versionName)

// Fetch the registered model
model, _, err := p.Client.ModelRegistryServiceAPI.FindRegisteredModel(context.Background()).Name(registeredModelName).Execute()
if err != nil {
return err
}

// Fetch model version by name or latest if not specified
var version *openapi.ModelVersion
if versionName != nil {
version, _, err = p.Client.ModelRegistryServiceAPI.FindModelVersion(context.Background()).Name(*versionName).ParentResourceId(*model.Id).Execute()
if err != nil {
return err
}
} else {
versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return err
}
log.Printf("Fetching model version: model=%v", model)

if versions.Size == 0 {
return fmt.Errorf("no versions associated to registered model %s", registeredModelName)
}
version = &versions.Items[0]
// Fetch model version by name or latest if not specified
version, err := p.fetchModelVersion(versionName, registeredModelName, model)
if err != nil {
return err
}

log.Printf("Fetching model artifacts: version=%v", version)

artifacts, _, err := p.Client.ModelRegistryServiceAPI.GetModelVersionArtifacts(context.Background(), *version.Id).
OrderBy(openapi.ORDERBYFIELD_CREATE_TIME).
SortOrder(openapi.SORTORDER_DESC).
@@ -84,20 +86,20 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
}

if artifacts.Size == 0 {
return fmt.Errorf("no artifacts associated to model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoArtifactAssociated, *version.Id)
}

modelArtifact := artifacts.Items[0].ModelArtifact
if modelArtifact == nil {
return fmt.Errorf("no model artifact found for model version %s", *version.Id)
return fmt.Errorf("%w %s", ErrNoModelArtifact, *version.Id)
}

// Call appropriate kserve provider based on the indexed model artifact URI
if modelArtifact.Uri == nil {
return fmt.Errorf("model artifact %s has empty URI", *modelArtifact.Id)
return fmt.Errorf("%w %s", ErrModelArtifactEmptyURI, *modelArtifact.Id)
}

protocol, err := extractProtocol(*modelArtifact.Uri)
protocol, err := p.extractProtocol(*modelArtifact.Uri)
if err != nil {
return err
}
@@ -110,19 +112,84 @@ func (p *ModelRegistryProvider) DownloadModel(modelDir string, modelName string,
return provider.DownloadModel(modelDir, "", *modelArtifact.Uri)
}

func extractProtocol(storageURI string) (kserve.Protocol, error) {
// Possible URIs:
// (1) model-registry://{modelName}
// (2) model-registry://{modelName}/{modelVersion}
// (3) model-registry://{modelRegistryUrl}/{modelName}
// (4) model-registry://{modelRegistryUrl}/{modelName}/{modelVersion}
func (p *ModelRegistryProvider) parseModelVersion(storageUri string) (string, *string, error) {
var versionName *string

// Parse the URI to retrieve the needed information to query model registry (modelArtifact)
mrUri := strings.TrimPrefix(storageUri, string(constants.MR))

tokens := strings.SplitN(mrUri, "/", 3)

if len(tokens) == 0 || len(tokens) > 3 {
return "", nil, ErrInvalidMRURI
}

// Check if the first token is the host and remove it so that we reduce cases (3) and (4) to (1) and (2)
if len(tokens) >= 2 && p.Client.GetConfig().Host == tokens[0] {
tokens = tokens[1:]
}

registeredModelName := tokens[0]

if len(tokens) == 2 {
versionName = &tokens[1]
}

return registeredModelName, versionName, nil
}

func (p *ModelRegistryProvider) fetchModelVersion(
versionName *string,
registeredModelName string,
model *openapi.RegisteredModel,
) (*openapi.ModelVersion, error) {
if versionName != nil {
version, _, err := p.Client.ModelRegistryServiceAPI.
FindModelVersion(context.Background()).
Name(*versionName).
ParentResourceId(*model.Id).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersion, err)
}

return version, nil
}

versions, _, err := p.Client.ModelRegistryServiceAPI.GetRegisteredModelVersions(context.Background(), *model.Id).
// OrderBy(openapi.ORDERBYFIELD_CREATE_TIME). not supported
SortOrder(openapi.SORTORDER_DESC).
Execute()
if err != nil {
return nil, fmt.Errorf("%w: %w", ErrFetchingModelVersions, err)
}

if versions.Size == 0 {
return nil, fmt.Errorf("%w %s", ErrNoVersionAssociated, registeredModelName)
}

return &versions.Items[0], nil
}

func (*ModelRegistryProvider) extractProtocol(storageURI string) (kserve.Protocol, error) {
if storageURI == "" {
return "", fmt.Errorf("there is no storageUri supplied")
return "", ErrNoStorageURI
}

if !regexp.MustCompile("\\w+?://").MatchString(storageURI) {
return "", fmt.Errorf("there is no protocol specified for the storageUri")
if !regexp.MustCompile(`\w+?://`).MatchString(storageURI) {
return "", ErrNoProtocolInSTorageURI
}

for _, prefix := range kserve.SupportedProtocols {
if strings.HasPrefix(storageURI, string(prefix)) {
return prefix, nil
}
}
return "", fmt.Errorf("protocol not supported for storageUri")

return "", ErrProtocolNotSupported
}
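
Moving from ad-hoc `fmt.Errorf` strings to package-level sentinel errors wrapped with `%w` lets callers branch on the failure with `errors.Is` instead of comparing message text. The following small sketch shows the pattern in isolation. `latestVersion` is a hypothetical helper and only the sentinel name and message are taken from the diff.

```go
package main

import (
	"errors"
	"fmt"
)

// ErrNoVersionAssociated mirrors the sentinel declared in the provider.
var ErrNoVersionAssociated = errors.New("no versions associated to registered model")

// latestVersion is a hypothetical helper: it returns the first entry of
// versions or wraps the sentinel with the model name for context.
func latestVersion(model string, versions []string) (string, error) {
	if len(versions) == 0 {
		return "", fmt.Errorf("%w %s", ErrNoVersionAssociated, model)
	}
	return versions[0], nil
}

func main() {
	_, err := latestVersion("iris", nil)

	// The message keeps the added context.
	fmt.Println(err)

	// The sentinel is still detectable through the wrapping.
	if errors.Is(err, ErrNoVersionAssociated) {
		fmt.Println("caller can handle the missing-version case explicitly")
	}
}
```

The `"%w: %w"` form used in `fetchModelVersion` wraps two errors at once, which `fmt.Errorf` supports from Go 1.20 onward, so it fits the `go 1.21` directive in `csi/go.mod`.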