From c42f30206924e9ce33a7548cbed42491fa6ebc88 Mon Sep 17 00:00:00 2001 From: Alex Creasy Date: Tue, 24 Sep 2024 16:42:17 +0100 Subject: [PATCH] Implements pagination support for GetAll style endpoints Signed-off-by: Alex Creasy --- .../internal/api/model_versions_handler.go | 2 +- .../internal/api/registered_models_handler.go | 15 +++--- clients/ui/bff/internal/data/helpers.go | 38 ++++++++++++++ clients/ui/bff/internal/data/model_version.go | 6 +-- .../bff/internal/data/model_version_test.go | 28 +++++++++- .../ui/bff/internal/data/registered_model.go | 12 ++--- .../internal/data/registered_model_test.go | 52 ++++++++++++++++++- .../mocks/model_registry_client_mock.go | 10 ++-- clients/ui/bff/internal/mocks/types_mock.go | 13 +++++ 9 files changed, 150 insertions(+), 26 deletions(-) create mode 100644 clients/ui/bff/internal/data/helpers.go diff --git a/clients/ui/bff/internal/api/model_versions_handler.go b/clients/ui/bff/internal/api/model_versions_handler.go index 7b9792674..88b29f8ae 100644 --- a/clients/ui/bff/internal/api/model_versions_handler.go +++ b/clients/ui/bff/internal/api/model_versions_handler.go @@ -157,7 +157,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter, return } - data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId)) + data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId), r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) return diff --git a/clients/ui/bff/internal/api/registered_models_handler.go b/clients/ui/bff/internal/api/registered_models_handler.go index ffff7cdc4..e75d121f5 100644 --- a/clients/ui/bff/internal/api/registered_models_handler.go +++ b/clients/ui/bff/internal/api/registered_models_handler.go @@ -4,27 +4,25 @@ import ( "encoding/json" "errors" "fmt" + "github.com/julienschmidt/httprouter" + "github.com/kubeflow/model-registry/pkg/openapi" "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/kubeflow/model-registry/ui/bff/internal/validation" "net/http" - - "github.com/julienschmidt/httprouter" - "github.com/kubeflow/model-registry/pkg/openapi" ) type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None] type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None] type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None] -func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - //TODO (ederign) implement pagination +func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return } - modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client) + modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client, r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) return @@ -40,7 +38,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req } } -func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { +func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) @@ -173,14 +171,13 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ } func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { - //TODO (acreasy) implement pagination client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface) if !ok { app.serverErrorResponse(w, r, errors.New("REST client not found")) return } - versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId)) + versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query()) if err != nil { app.serverErrorResponse(w, r, err) diff --git a/clients/ui/bff/internal/data/helpers.go b/clients/ui/bff/internal/data/helpers.go new file mode 100644 index 000000000..8c19f2c26 --- /dev/null +++ b/clients/ui/bff/internal/data/helpers.go @@ -0,0 +1,38 @@ +package data + +import ( + "fmt" + "net/url" +) + +func FilterPageValues(values url.Values) url.Values { + result := url.Values{} + + if v := values.Get("pageSize"); v != "" { + result.Set("pageSize", v) + } + if v := values.Get("orderBy"); v != "" { + result.Set("orderBy", v) + } + if v := values.Get("sortOrder"); v != "" { + result.Set("sortOrder", v) + } + if v := values.Get("nextPageToken"); v != "" { + result.Set("nextPageToken", v) + } + + return result +} + +func UrlWithParams(url string, values url.Values) string { + queryString := values.Encode() + if queryString == "" { + return url + } + return fmt.Sprintf("%s?%s", url, queryString) +} + +func UrlWithPageParams(url string, values url.Values) string { + pageValues := FilterPageValues(values) + return UrlWithParams(url, pageValues) +} diff --git a/clients/ui/bff/internal/data/model_version.go b/clients/ui/bff/internal/data/model_version.go index c84c08efe..c149a177c 100644 --- a/clients/ui/bff/internal/data/model_version.go +++ b/clients/ui/bff/internal/data/model_version.go @@ -16,7 +16,7 @@ type ModelVersionInterface interface { GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error) CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error) UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) - GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) + GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) } @@ -79,14 +79,14 @@ func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface return &model, nil } -func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { +func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) { path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath) if err != nil { return nil, err } - responseData, err := client.GET(path) + responseData, err := client.GET(UrlWithPageParams(path, pageValues)) if err != nil { return nil, fmt.Errorf("error fetching model version artifacts: %w", err) } diff --git a/clients/ui/bff/internal/data/model_version_test.go b/clients/ui/bff/internal/data/model_version_test.go index b2da96371..3741dbf7b 100644 --- a/clients/ui/bff/internal/data/model_version_test.go +++ b/clients/ui/bff/internal/data/model_version_test.go @@ -2,6 +2,7 @@ package data import ( "encoding/json" + "fmt" "github.com/brianvoe/gofakeit/v7" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/stretchr/testify/assert" @@ -105,7 +106,7 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) { mockClient := new(mocks.MockHTTPClient) mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil) - actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1") + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", nil) assert.NoError(t, err) assert.NotNil(t, actual) @@ -115,6 +116,31 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) { assert.Equal(t, len(expected.Items), len(actual.Items)) } +func TestGetModelArtifactsByModelVersionWithPageParams(t *testing.T) { + gofakeit.Seed(0) + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockModelArtifactList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + modelVersion := ModelVersion{} + + path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath) + assert.NoError(t, err) + reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode()) + + mockClient := new(mocks.MockHTTPClient) + mockClient.On(http.MethodGet, reqUrl, mock.Anything).Return(mockData, nil) + + actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", pageValues) + assert.NoError(t, err) + + assert.NotNil(t, actual) + mockClient.AssertExpectations(t) +} + func TestCreateModelArtifactByModelVersion(t *testing.T) { gofakeit.Seed(0) diff --git a/clients/ui/bff/internal/data/registered_model.go b/clients/ui/bff/internal/data/registered_model.go index 10acd55c6..bc082feb0 100644 --- a/clients/ui/bff/internal/data/registered_model.go +++ b/clients/ui/bff/internal/data/registered_model.go @@ -13,11 +13,11 @@ const registeredModelPath = "/registered_models" const versionsPath = "/versions" type RegisteredModelInterface interface { - GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) + GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error) GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error) UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error) - GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) + GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error) } @@ -25,9 +25,9 @@ type RegisteredModel struct { RegisteredModelInterface } -func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { +func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) { + responseData, err := client.GET(UrlWithPageParams(registeredModelPath, pageValues)) - responseData, err := client.GET(registeredModelPath) if err != nil { return nil, fmt.Errorf("error fetching registered models: %w", err) } @@ -94,14 +94,14 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt return &model, nil } -func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { +func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) { path, err := url.JoinPath(registeredModelPath, id, versionsPath) if err != nil { return nil, err } - responseData, err := client.GET(path) + responseData, err := client.GET(UrlWithPageParams(path, pageValues)) if err != nil { return nil, fmt.Errorf("error fetching model versions: %w", err) diff --git a/clients/ui/bff/internal/data/registered_model_test.go b/clients/ui/bff/internal/data/registered_model_test.go index 0af73f160..58bf7a65c 100644 --- a/clients/ui/bff/internal/data/registered_model_test.go +++ b/clients/ui/bff/internal/data/registered_model_test.go @@ -2,6 +2,7 @@ package data import ( "encoding/json" + "fmt" "github.com/brianvoe/gofakeit/v7" "github.com/kubeflow/model-registry/ui/bff/internal/mocks" "github.com/stretchr/testify/assert" @@ -24,7 +25,7 @@ func TestGetAllRegisteredModels(t *testing.T) { mockClient := new(mocks.MockHTTPClient) mockClient.On("GET", registeredModelPath).Return(mockData, nil) - actual, err := registeredModel.GetAllRegisteredModels(mockClient) + actual, err := registeredModel.GetAllRegisteredModels(mockClient, nil) assert.NoError(t, err) assert.NotNil(t, actual) assert.Equal(t, expected.NextPageToken, actual.NextPageToken) @@ -35,6 +36,28 @@ func TestGetAllRegisteredModels(t *testing.T) { mockClient.AssertExpectations(t) } +func TestGetAllRegisteredModelsWithPageParams(t *testing.T) { + gofakeit.Seed(0) + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockRegisteredModelList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + reqUrl := fmt.Sprintf("%s?%s", registeredModelPath, pageValues.Encode()) + + registeredModel := RegisteredModel{} + + mockClient := new(mocks.MockHTTPClient) + mockClient.On("GET", reqUrl).Return(mockData, nil) + + actual, err := registeredModel.GetAllRegisteredModels(mockClient, pageValues) + assert.NoError(t, err) + assert.NotNil(t, actual) + mockClient.AssertExpectations(t) +} + func TestCreateRegisteredModel(t *testing.T) { gofakeit.Seed(0) @@ -125,7 +148,7 @@ func TestGetAllModelVersions(t *testing.T) { assert.NoError(t, err) mockClient.On("GET", path).Return(mockData, nil) - actual, err := registeredModel.GetAllModelVersions(mockClient, "1") + actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil) assert.NoError(t, err) assert.NotNil(t, actual) assert.NoError(t, err) @@ -138,6 +161,31 @@ func TestGetAllModelVersions(t *testing.T) { mockClient.AssertExpectations(t) } +func TestGetAllModelVersionsWithPageParams(t *testing.T) { + gofakeit.Seed(0) + + pageValues := mocks.GenerateMockPageValues() + expected := mocks.GenerateMockModelVersionList() + + mockData, err := json.Marshal(expected) + assert.NoError(t, err) + + registeredModel := RegisteredModel{} + + mockClient := new(mocks.MockHTTPClient) + path, err := url.JoinPath(registeredModelPath, "1", versionsPath) + assert.NoError(t, err) + reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode()) + + mockClient.On("GET", reqUrl).Return(mockData, nil) + + actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues) + assert.NoError(t, err) + assert.NotNil(t, actual) + + mockClient.AssertExpectations(t) +} + func TestCreateModelVersionForRegisteredModel(t *testing.T) { gofakeit.Seed(0) diff --git a/clients/ui/bff/internal/mocks/model_registry_client_mock.go b/clients/ui/bff/internal/mocks/model_registry_client_mock.go index b65aca0a1..8b397939c 100644 --- a/clients/ui/bff/internal/mocks/model_registry_client_mock.go +++ b/clients/ui/bff/internal/mocks/model_registry_client_mock.go @@ -5,17 +5,18 @@ import ( "github.com/kubeflow/model-registry/ui/bff/internal/integrations" "github.com/stretchr/testify/mock" "log/slog" + "net/url" ) type ModelRegistryClientMock struct { mock.Mock } -func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) { +func NewModelRegistryClient(_ *slog.Logger) (*ModelRegistryClientMock, error) { return &ModelRegistryClientMock{}, nil } -func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) { +func (m *ModelRegistryClientMock) GetAllRegisteredModels(_ integrations.HTTPClientInterface, _ url.Values) (*openapi.RegisteredModelList, error) { mockData := GetRegisteredModelListMock() return &mockData, nil } @@ -50,7 +51,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPCli return &mockData, nil } -func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) { +func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) { mockData := GetModelVersionListMock() return &mockData, nil } @@ -60,10 +61,11 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in return &mockData, nil } -func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) { +func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelArtifactList, error) { mockData := GetModelArtifactListMock() return &mockData, nil } + func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) { mockData := GetModelArtifactMocks()[0] return &mockData, nil diff --git a/clients/ui/bff/internal/mocks/types_mock.go b/clients/ui/bff/internal/mocks/types_mock.go index f1d81a9a7..096d3dfe0 100644 --- a/clients/ui/bff/internal/mocks/types_mock.go +++ b/clients/ui/bff/internal/mocks/types_mock.go @@ -4,6 +4,8 @@ import ( "fmt" "github.com/brianvoe/gofakeit/v7" "github.com/kubeflow/model-registry/pkg/openapi" + "net/url" + "strconv" ) func GenerateMockRegisteredModelList() openapi.RegisteredModelList { @@ -125,6 +127,17 @@ func GenerateMockModelArtifactList() openapi.ModelArtifactList { } } +func GenerateMockPageValues() url.Values { + pageValues := url.Values{} + + pageValues.Add("pageSize", strconv.Itoa(gofakeit.Number(1, 100))) + pageValues.Add("orderBy", gofakeit.RandomString([]string{"CREATE_TIME", "LAST_UPDATE_TIME", "ID"})) + pageValues.Add("sortOrder", gofakeit.RandomString([]string{"ASC", "DESC"})) + pageValues.Add("nextPageToken", gofakeit.UUID()) + + return pageValues +} + func randomEpochTime() *string { return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli())) }