Skip to content

Commit

Permalink
Implements pagination support for GetAll style endpoints
Browse files Browse the repository at this point in the history
Signed-off-by: Alex Creasy <[email protected]>
  • Loading branch information
alexcreasy committed Sep 25, 2024
1 parent 68b5dc7 commit c42f302
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 26 deletions.
2 changes: 1 addition & 1 deletion clients/ui/bff/internal/api/model_versions_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func (app *App) GetAllModelArtifactsByModelVersionHandler(w http.ResponseWriter,
return
}

data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId))
data, err := app.modelRegistryClient.GetModelArtifactsByModelVersion(client, ps.ByName(ModelVersionId), r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand Down
15 changes: 6 additions & 9 deletions clients/ui/bff/internal/api/registered_models_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,25 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/kubeflow/model-registry/ui/bff/internal/validation"
"net/http"

"github.com/julienschmidt/httprouter"
"github.com/kubeflow/model-registry/pkg/openapi"
)

type RegisteredModelEnvelope Envelope[*openapi.RegisteredModel, None]
type RegisteredModelListEnvelope Envelope[*openapi.RegisteredModelList, None]
type RegisteredModelUpdateEnvelope Envelope[*openapi.RegisteredModelUpdate, None]

func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (ederign) implement pagination
func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client)
modelList, err := app.modelRegistryClient.GetAllRegisteredModels(client, r.URL.Query())
if err != nil {
app.serverErrorResponse(w, r, err)
return
Expand All @@ -40,7 +38,7 @@ func (app *App) GetAllRegisteredModelsHandler(w http.ResponseWriter, r *http.Req
}
}

func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
func (app *App) CreateRegisteredModelHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
Expand Down Expand Up @@ -173,14 +171,13 @@ func (app *App) UpdateRegisteredModelHandler(w http.ResponseWriter, r *http.Requ
}

func (app *App) GetAllModelVersionsForRegisteredModelHandler(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//TODO (acreasy) implement pagination
client, ok := r.Context().Value(httpClientKey).(integrations.HTTPClientInterface)
if !ok {
app.serverErrorResponse(w, r, errors.New("REST client not found"))
return
}

versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId))
versionList, err := app.modelRegistryClient.GetAllModelVersions(client, ps.ByName(RegisteredModelId), r.URL.Query())

if err != nil {
app.serverErrorResponse(w, r, err)
Expand Down
38 changes: 38 additions & 0 deletions clients/ui/bff/internal/data/helpers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package data

import (
"fmt"
"net/url"
)

func FilterPageValues(values url.Values) url.Values {
result := url.Values{}

if v := values.Get("pageSize"); v != "" {
result.Set("pageSize", v)
}
if v := values.Get("orderBy"); v != "" {
result.Set("orderBy", v)
}
if v := values.Get("sortOrder"); v != "" {
result.Set("sortOrder", v)
}
if v := values.Get("nextPageToken"); v != "" {
result.Set("nextPageToken", v)
}

return result
}

func UrlWithParams(url string, values url.Values) string {
queryString := values.Encode()
if queryString == "" {
return url
}
return fmt.Sprintf("%s?%s", url, queryString)
}

func UrlWithPageParams(url string, values url.Values) string {
pageValues := FilterPageValues(values)
return UrlWithParams(url, pageValues)
}
6 changes: 3 additions & 3 deletions clients/ui/bff/internal/data/model_version.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ type ModelVersionInterface interface {
GetModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersion, error)
CreateModelVersion(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.ModelVersion, error)
UpdateModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error)
GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error)
CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error)
}

Expand Down Expand Up @@ -79,14 +79,14 @@ func (v ModelVersion) UpdateModelVersion(client integrations.HTTPClientInterface
return &model, nil
}

func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (v ModelVersion) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelArtifactList, error) {
path, err := url.JoinPath(modelVersionPath, id, artifactsByModelVersionPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))
if err != nil {
return nil, fmt.Errorf("error fetching model version artifacts: %w", err)
}
Expand Down
28 changes: 27 additions & 1 deletion clients/ui/bff/internal/data/model_version_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"github.com/brianvoe/gofakeit/v7"
"github.com/kubeflow/model-registry/ui/bff/internal/mocks"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -105,7 +106,7 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, path, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1")
actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", nil)
assert.NoError(t, err)

assert.NotNil(t, actual)
Expand All @@ -115,6 +116,31 @@ func TestGetModelArtifactsByModelVersion(t *testing.T) {
assert.Equal(t, len(expected.Items), len(actual.Items))
}

func TestGetModelArtifactsByModelVersionWithPageParams(t *testing.T) {
gofakeit.Seed(0)

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelArtifactList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

modelVersion := ModelVersion{}

path, err := url.JoinPath(modelVersionPath, "1", artifactsByModelVersionPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient := new(mocks.MockHTTPClient)
mockClient.On(http.MethodGet, reqUrl, mock.Anything).Return(mockData, nil)

actual, err := modelVersion.GetModelArtifactsByModelVersion(mockClient, "1", pageValues)
assert.NoError(t, err)

assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateModelArtifactByModelVersion(t *testing.T) {
gofakeit.Seed(0)

Expand Down
12 changes: 6 additions & 6 deletions clients/ui/bff/internal/data/registered_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,21 @@ const registeredModelPath = "/registered_models"
const versionsPath = "/versions"

type RegisteredModelInterface interface {
GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error)
GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error)
CreateRegisteredModel(client integrations.HTTPClientInterface, jsonData []byte) (*openapi.RegisteredModel, error)
GetRegisteredModel(client integrations.HTTPClientInterface, id string) (*openapi.RegisteredModel, error)
UpdateRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.RegisteredModel, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error)
GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error)
CreateModelVersionForRegisteredModel(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelVersion, error)
}

type RegisteredModel struct {
RegisteredModelInterface
}

func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m RegisteredModel) GetAllRegisteredModels(client integrations.HTTPClientInterface, pageValues url.Values) (*openapi.RegisteredModelList, error) {
responseData, err := client.GET(UrlWithPageParams(registeredModelPath, pageValues))

responseData, err := client.GET(registeredModelPath)
if err != nil {
return nil, fmt.Errorf("error fetching registered models: %w", err)
}
Expand Down Expand Up @@ -94,14 +94,14 @@ func (m RegisteredModel) UpdateRegisteredModel(client integrations.HTTPClientInt
return &model, nil
}

func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m RegisteredModel) GetAllModelVersions(client integrations.HTTPClientInterface, id string, pageValues url.Values) (*openapi.ModelVersionList, error) {
path, err := url.JoinPath(registeredModelPath, id, versionsPath)

if err != nil {
return nil, err
}

responseData, err := client.GET(path)
responseData, err := client.GET(UrlWithPageParams(path, pageValues))

if err != nil {
return nil, fmt.Errorf("error fetching model versions: %w", err)
Expand Down
52 changes: 50 additions & 2 deletions clients/ui/bff/internal/data/registered_model_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package data

import (
"encoding/json"
"fmt"
"github.com/brianvoe/gofakeit/v7"
"github.com/kubeflow/model-registry/ui/bff/internal/mocks"
"github.com/stretchr/testify/assert"
Expand All @@ -24,7 +25,7 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", registeredModelPath).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient)
actual, err := registeredModel.GetAllRegisteredModels(mockClient, nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.Equal(t, expected.NextPageToken, actual.NextPageToken)
Expand All @@ -35,6 +36,28 @@ func TestGetAllRegisteredModels(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllRegisteredModelsWithPageParams(t *testing.T) {
gofakeit.Seed(0)

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockRegisteredModelList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

reqUrl := fmt.Sprintf("%s?%s", registeredModelPath, pageValues.Encode())

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllRegisteredModels(mockClient, pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)
mockClient.AssertExpectations(t)
}

func TestCreateRegisteredModel(t *testing.T) {
gofakeit.Seed(0)

Expand Down Expand Up @@ -125,7 +148,7 @@ func TestGetAllModelVersions(t *testing.T) {
assert.NoError(t, err)
mockClient.On("GET", path).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1")
actual, err := registeredModel.GetAllModelVersions(mockClient, "1", nil)
assert.NoError(t, err)
assert.NotNil(t, actual)
assert.NoError(t, err)
Expand All @@ -138,6 +161,31 @@ func TestGetAllModelVersions(t *testing.T) {
mockClient.AssertExpectations(t)
}

func TestGetAllModelVersionsWithPageParams(t *testing.T) {
gofakeit.Seed(0)

pageValues := mocks.GenerateMockPageValues()
expected := mocks.GenerateMockModelVersionList()

mockData, err := json.Marshal(expected)
assert.NoError(t, err)

registeredModel := RegisteredModel{}

mockClient := new(mocks.MockHTTPClient)
path, err := url.JoinPath(registeredModelPath, "1", versionsPath)
assert.NoError(t, err)
reqUrl := fmt.Sprintf("%s?%s", path, pageValues.Encode())

mockClient.On("GET", reqUrl).Return(mockData, nil)

actual, err := registeredModel.GetAllModelVersions(mockClient, "1", pageValues)
assert.NoError(t, err)
assert.NotNil(t, actual)

mockClient.AssertExpectations(t)
}

func TestCreateModelVersionForRegisteredModel(t *testing.T) {
gofakeit.Seed(0)

Expand Down
10 changes: 6 additions & 4 deletions clients/ui/bff/internal/mocks/model_registry_client_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,18 @@ import (
"github.com/kubeflow/model-registry/ui/bff/internal/integrations"
"github.com/stretchr/testify/mock"
"log/slog"
"net/url"
)

type ModelRegistryClientMock struct {
mock.Mock
}

func NewModelRegistryClient(logger *slog.Logger) (*ModelRegistryClientMock, error) {
func NewModelRegistryClient(_ *slog.Logger) (*ModelRegistryClientMock, error) {
return &ModelRegistryClientMock{}, nil
}

func (m *ModelRegistryClientMock) GetAllRegisteredModels(client integrations.HTTPClientInterface) (*openapi.RegisteredModelList, error) {
func (m *ModelRegistryClientMock) GetAllRegisteredModels(_ integrations.HTTPClientInterface, _ url.Values) (*openapi.RegisteredModelList, error) {
mockData := GetRegisteredModelListMock()
return &mockData, nil
}
Expand Down Expand Up @@ -50,7 +51,7 @@ func (m *ModelRegistryClientMock) UpdateModelVersion(client integrations.HTTPCli
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetAllModelVersions(client integrations.HTTPClientInterface, id string) (*openapi.ModelVersionList, error) {
func (m *ModelRegistryClientMock) GetAllModelVersions(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelVersionList, error) {
mockData := GetModelVersionListMock()
return &mockData, nil
}
Expand All @@ -60,10 +61,11 @@ func (m *ModelRegistryClientMock) CreateModelVersionForRegisteredModel(client in
return &mockData, nil
}

func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(client integrations.HTTPClientInterface, id string) (*openapi.ModelArtifactList, error) {
func (m *ModelRegistryClientMock) GetModelArtifactsByModelVersion(_ integrations.HTTPClientInterface, _ string, _ url.Values) (*openapi.ModelArtifactList, error) {
mockData := GetModelArtifactListMock()
return &mockData, nil
}

func (m *ModelRegistryClientMock) CreateModelArtifactByModelVersion(client integrations.HTTPClientInterface, id string, jsonData []byte) (*openapi.ModelArtifact, error) {
mockData := GetModelArtifactMocks()[0]
return &mockData, nil
Expand Down
13 changes: 13 additions & 0 deletions clients/ui/bff/internal/mocks/types_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"fmt"
"github.com/brianvoe/gofakeit/v7"
"github.com/kubeflow/model-registry/pkg/openapi"
"net/url"
"strconv"
)

func GenerateMockRegisteredModelList() openapi.RegisteredModelList {
Expand Down Expand Up @@ -125,6 +127,17 @@ func GenerateMockModelArtifactList() openapi.ModelArtifactList {
}
}

func GenerateMockPageValues() url.Values {
pageValues := url.Values{}

pageValues.Add("pageSize", strconv.Itoa(gofakeit.Number(1, 100)))
pageValues.Add("orderBy", gofakeit.RandomString([]string{"CREATE_TIME", "LAST_UPDATE_TIME", "ID"}))
pageValues.Add("sortOrder", gofakeit.RandomString([]string{"ASC", "DESC"}))
pageValues.Add("nextPageToken", gofakeit.UUID())

return pageValues
}

func randomEpochTime() *string {
return stringToPointer(fmt.Sprintf("%d", gofakeit.Date().UnixMilli()))
}
Expand Down

0 comments on commit c42f302

Please sign in to comment.