From 5a4bafccc63ca2793130af04dce3b27268f58b1d Mon Sep 17 00:00:00 2001 From: wei liu Date: Wed, 8 Jan 2025 19:40:57 +0800 Subject: [PATCH] enhance: Add resource group api on Restful (#38885) issue: #38739 enable to manage resource group by Restful - CreateResourceGroup - DescribeResourceGroup - UpdateResourceGroup - ListResourceGroup - DropResourceGroup - TransferReplica Signed-off-by: Wei Liu --- .../distributed/proxy/httpserver/constant.go | 2 + .../proxy/httpserver/handler_v2.go | 8 + .../proxy/httpserver/request_v2.go | 96 ++++++++++ .../httpserver/resource_group_handler_v2.go | 175 ++++++++++++++++++ .../resource_group_handler_v2_test.go | 167 +++++++++++++++++ 5 files changed, 448 insertions(+) create mode 100644 internal/distributed/proxy/httpserver/resource_group_handler_v2.go create mode 100644 internal/distributed/proxy/httpserver/resource_group_handler_v2_test.go diff --git a/internal/distributed/proxy/httpserver/constant.go b/internal/distributed/proxy/httpserver/constant.go index 1aece6d66af4f..f797b8fdf4862 100644 --- a/internal/distributed/proxy/httpserver/constant.go +++ b/internal/distributed/proxy/httpserver/constant.go @@ -36,6 +36,7 @@ const ( ImportJobCategory = "/jobs/import/" PrivilegeGroupCategory = "/privilege_groups/" CollectionFieldCategory = "/collections/fields/" + ResourceGroupCategory = "/resource_groups/" ListAction = "list" HasAction = "has" @@ -70,6 +71,7 @@ const ( GetProgressAction = "get_progress" // deprecated, keep it for compatibility, use `/v2/vectordb/jobs/import/describe` instead AddPrivilegesToGroupAction = "add_privileges_to_group" RemovePrivilegesFromGroupAction = "remove_privileges_from_group" + TransferReplicaAction = "transfer_replica" ) const ( diff --git a/internal/distributed/proxy/httpserver/handler_v2.go b/internal/distributed/proxy/httpserver/handler_v2.go index 2bcc0318f420a..2fc9c9dc60a4f 100644 --- a/internal/distributed/proxy/httpserver/handler_v2.go +++ b/internal/distributed/proxy/httpserver/handler_v2.go @@ -189,6 +189,14 @@ func (h *HandlersV2) RegisterRoutesToV2(router gin.IRouter) { router.POST(ImportJobCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ImportReq{} }, wrapperTraceLog(h.createImportJob)))) router.POST(ImportJobCategory+GetProgressAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) router.POST(ImportJobCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &JobIDReq{} }, wrapperTraceLog(h.getImportJobProcess)))) + + // resource group + router.POST(ResourceGroupCategory+CreateAction, timeoutMiddleware(wrapperPost(func() any { return &ResourceGroupReq{} }, wrapperTraceLog(h.createResourceGroup)))) + router.POST(ResourceGroupCategory+DropAction, timeoutMiddleware(wrapperPost(func() any { return &ResourceGroupReq{} }, wrapperTraceLog(h.dropResourceGroup)))) + router.POST(ResourceGroupCategory+AlterAction, timeoutMiddleware(wrapperPost(func() any { return &UpdateResourceGroupReq{} }, wrapperTraceLog(h.updateResourceGroup)))) + router.POST(ResourceGroupCategory+DescribeAction, timeoutMiddleware(wrapperPost(func() any { return &ResourceGroupReq{} }, wrapperTraceLog(h.describeResourceGroup)))) + router.POST(ResourceGroupCategory+ListAction, timeoutMiddleware(wrapperPost(func() any { return &EmptyReq{} }, wrapperTraceLog(h.listResourceGroups)))) + router.POST(ResourceGroupCategory+TransferReplicaAction, timeoutMiddleware(wrapperPost(func() any { return &TransferReplicaReq{} }, wrapperTraceLog(h.transferReplica)))) } type ( diff --git a/internal/distributed/proxy/httpserver/request_v2.go b/internal/distributed/proxy/httpserver/request_v2.go index 4231c7acd8b2f..409c8fb799212 100644 --- a/internal/distributed/proxy/httpserver/request_v2.go +++ b/internal/distributed/proxy/httpserver/request_v2.go @@ -542,3 +542,99 @@ func wrapperReturnDefault() gin.H { func wrapperReturnDefaultWithCost(cost int) gin.H { return gin.H{HTTPReturnCode: merr.Code(nil), HTTPReturnData: gin.H{}, HTTPReturnCost: cost} } + +type ResourceGroupNodeFilter struct { + NodeLabels map[string]string `json:"node_labels" binding:"required"` +} + +func (req *ResourceGroupNodeFilter) GetNodeLabels() map[string]string { + return req.NodeLabels +} + +type ResourceGroupLimit struct { + NodeNum int32 `json:"node_num" binding:"required"` +} + +func (req *ResourceGroupLimit) GetNodeNum() int32 { + return req.NodeNum +} + +type ResourceGroupTransfer struct { + ResourceGroup string `json:"resource_group" binding:"required"` +} + +func (req *ResourceGroupTransfer) GetResourceGroup() string { + return req.ResourceGroup +} + +type ResourceGroupConfig struct { + Requests *ResourceGroupLimit `json:"requests" binding:"required"` + Limits *ResourceGroupLimit `json:"limits" binding:"required"` + TransferFrom []*ResourceGroupTransfer `json:"transfer_from"` + TransferTo []*ResourceGroupTransfer `json:"transfer_to"` + NodeFilter *ResourceGroupNodeFilter `json:"node_filter"` +} + +func (req *ResourceGroupConfig) GetRequests() *ResourceGroupLimit { + return req.Requests +} + +func (req *ResourceGroupConfig) GetLimits() *ResourceGroupLimit { + return req.Limits +} + +func (req *ResourceGroupConfig) GetTransferFrom() []*ResourceGroupTransfer { + return req.TransferFrom +} + +func (req *ResourceGroupConfig) GetTransferTo() []*ResourceGroupTransfer { + return req.TransferTo +} + +func (req *ResourceGroupConfig) GetNodeFilter() *ResourceGroupNodeFilter { + return req.NodeFilter +} + +type ResourceGroupReq struct { + Name string `json:"name" binding:"required"` + Config *ResourceGroupConfig `json:"config"` +} + +func (req *ResourceGroupReq) GetName() string { + return req.Name +} + +func (req *ResourceGroupReq) GetConfig() *ResourceGroupConfig { + return req.Config +} + +type UpdateResourceGroupReq struct { + ResourceGroups map[string]*ResourceGroupConfig `json:"resource_groups" binding:"required"` +} + +func (req *UpdateResourceGroupReq) GetResourceGroups() map[string]*ResourceGroupConfig { + return req.ResourceGroups +} + +type TransferReplicaReq struct { + SourceRgName string `json:"sourceRgName" binding:"required"` + TargetRgName string `json:"targetRgName" binding:"required"` + CollectionName string `json:"collectionName" binding:"required"` + ReplicaNum int64 `json:"replicaNum" binding:"required"` +} + +func (req *TransferReplicaReq) GetSourceRgName() string { + return req.SourceRgName +} + +func (req *TransferReplicaReq) GetTargetRgName() string { + return req.TargetRgName +} + +func (req *TransferReplicaReq) GetCollectionName() string { + return req.CollectionName +} + +func (req *TransferReplicaReq) GetReplicaNum() int64 { + return req.ReplicaNum +} diff --git a/internal/distributed/proxy/httpserver/resource_group_handler_v2.go b/internal/distributed/proxy/httpserver/resource_group_handler_v2.go new file mode 100644 index 0000000000000..33a618e259cbe --- /dev/null +++ b/internal/distributed/proxy/httpserver/resource_group_handler_v2.go @@ -0,0 +1,175 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpserver + +import ( + "context" + "net/http" + + "github.com/gin-gonic/gin" + "github.com/samber/lo" + + "github.com/milvus-io/milvus-proto/go-api/v2/commonpb" + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus-proto/go-api/v2/rgpb" + "github.com/milvus-io/milvus/pkg/log" +) + +func toKVPair(data map[string]string) []*commonpb.KeyValuePair { + if len(data) == 0 { + return nil + } + + var pairs []*commonpb.KeyValuePair + for k, v := range data { + pairs = append(pairs, &commonpb.KeyValuePair{ + Key: k, + Value: v, + }) + } + return pairs +} + +func toPbRGNodeFilter(nodeFilter *ResourceGroupNodeFilter) *rgpb.ResourceGroupNodeFilter { + if nodeFilter == nil { + return nil + } + + return &rgpb.ResourceGroupNodeFilter{ + NodeLabels: toKVPair(nodeFilter.GetNodeLabels()), + } +} + +func toPbResourceGroupConfig(config *ResourceGroupConfig) *rgpb.ResourceGroupConfig { + if config == nil { + return nil + } + + return &rgpb.ResourceGroupConfig{ + Requests: &rgpb.ResourceGroupLimit{ + NodeNum: config.GetRequests().GetNodeNum(), + }, + Limits: &rgpb.ResourceGroupLimit{ + NodeNum: config.GetLimits().GetNodeNum(), + }, + TransferFrom: lo.Map(config.GetTransferFrom(), func(t *ResourceGroupTransfer, _ int) *rgpb.ResourceGroupTransfer { + return &rgpb.ResourceGroupTransfer{ + ResourceGroup: t.GetResourceGroup(), + } + }), + TransferTo: lo.Map(config.GetTransferTo(), func(t *ResourceGroupTransfer, _ int) *rgpb.ResourceGroupTransfer { + return &rgpb.ResourceGroupTransfer{ + ResourceGroup: t.GetResourceGroup(), + } + }), + NodeFilter: toPbRGNodeFilter(config.GetNodeFilter()), + } +} + +func (h *HandlersV2) createResourceGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*ResourceGroupReq) + req := &milvuspb.CreateResourceGroupRequest{ + ResourceGroup: httpReq.GetName(), + Config: toPbResourceGroupConfig(httpReq.GetConfig()), + } + log.Info("f1") + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/CreateResourceGroup", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.CreateResourceGroup(reqCtx, req.(*milvuspb.CreateResourceGroupRequest)) + }) + log.Info("f2") + if err == nil { + log.Info("f3") + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + log.Info("f4") + return resp, err +} + +func (h *HandlersV2) dropResourceGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*ResourceGroupReq) + req := &milvuspb.DropResourceGroupRequest{ + ResourceGroup: httpReq.GetName(), + } + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DropResourceGroup", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DropResourceGroup(reqCtx, req.(*milvuspb.DropResourceGroupRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) listResourceGroups(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + req := &milvuspb.ListResourceGroupsRequest{} + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/ListResourceGroups", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.ListResourceGroups(reqCtx, req.(*milvuspb.ListResourceGroupsRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnList(resp.(*milvuspb.ListResourceGroupsResponse).GetResourceGroups())) + } + return resp, err +} + +func (h *HandlersV2) describeResourceGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*ResourceGroupReq) + req := &milvuspb.DescribeResourceGroupRequest{ + ResourceGroup: httpReq.GetName(), + } + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/DescribeResourceGroup", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.DescribeResourceGroup(reqCtx, req.(*milvuspb.DescribeResourceGroupRequest)) + }) + if err == nil { + response := resp.(*milvuspb.DescribeResourceGroupResponse) + HTTPReturn(c, http.StatusOK, gin.H{ + "resource_group": response.GetResourceGroup(), + }) + } + return resp, err +} + +func (h *HandlersV2) updateResourceGroup(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*UpdateResourceGroupReq) + req := &milvuspb.UpdateResourceGroupsRequest{ + ResourceGroups: lo.MapValues(httpReq.GetResourceGroups(), func(config *ResourceGroupConfig, name string) *rgpb.ResourceGroupConfig { + return toPbResourceGroupConfig(config) + }), + } + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/UpdateResourceGroups", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.UpdateResourceGroups(reqCtx, req.(*milvuspb.UpdateResourceGroupsRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} + +func (h *HandlersV2) transferReplica(ctx context.Context, c *gin.Context, anyReq any, dbName string) (interface{}, error) { + httpReq := anyReq.(*TransferReplicaReq) + req := &milvuspb.TransferReplicaRequest{ + SourceResourceGroup: httpReq.GetSourceRgName(), + TargetResourceGroup: httpReq.GetTargetRgName(), + CollectionName: httpReq.GetCollectionName(), + NumReplica: httpReq.GetReplicaNum(), + } + resp, err := wrapperProxyWithLimit(ctx, c, req, h.checkAuth, false, "/milvus.proto.milvus.MilvusService/TransferMaster", true, h.proxy, func(reqCtx context.Context, req any) (interface{}, error) { + return h.proxy.TransferReplica(reqCtx, req.(*milvuspb.TransferReplicaRequest)) + }) + if err == nil { + HTTPReturn(c, http.StatusOK, wrapperReturnDefault()) + } + return resp, err +} diff --git a/internal/distributed/proxy/httpserver/resource_group_handler_v2_test.go b/internal/distributed/proxy/httpserver/resource_group_handler_v2_test.go new file mode 100644 index 0000000000000..da6980fbb61ab --- /dev/null +++ b/internal/distributed/proxy/httpserver/resource_group_handler_v2_test.go @@ -0,0 +1,167 @@ +// Licensed to the LF AI & Data foundation under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package httpserver + +import ( + "bytes" + "fmt" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + + "github.com/milvus-io/milvus-proto/go-api/v2/milvuspb" + "github.com/milvus-io/milvus/internal/json" + "github.com/milvus-io/milvus/internal/mocks" + "github.com/milvus-io/milvus/pkg/util/merr" + "github.com/milvus-io/milvus/pkg/util/paramtable" +) + +func validateRequestBodyTestCases(t *testing.T, testEngine *gin.Engine, queryTestCases []requestBodyTestCase, allowInt64 bool) { + for i, testcase := range queryTestCases { + t.Run(testcase.path, func(t *testing.T) { + bodyReader := bytes.NewReader(testcase.requestBody) + req := httptest.NewRequest(http.MethodPost, testcase.path, bodyReader) + if allowInt64 { + req.Header.Set(HTTPHeaderAllowInt64, "true") + } + w := httptest.NewRecorder() + testEngine.ServeHTTP(w, req) + assert.Equal(t, http.StatusOK, w.Code, "case %d: ", i, string(testcase.requestBody)) + returnBody := &ReturnErrMsg{} + err := json.Unmarshal(w.Body.Bytes(), returnBody) + assert.Nil(t, err, "case %d: ", i) + assert.Equal(t, testcase.errCode, returnBody.Code, "case %d: ", i, string(testcase.requestBody)) + if testcase.errCode != 0 { + assert.Equal(t, testcase.errMsg, returnBody.Message, "case %d: ", i, string(testcase.requestBody)) + } + fmt.Println(w.Body.String()) + }) + } +} + +func TestResourceGroupHandlerV2(t *testing.T) { + // init params + paramtable.Init() + + // disable rate limit + paramtable.Get().Save(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key, "false") + defer paramtable.Get().Reset(paramtable.Get().QuotaConfig.QuotaAndLimitsEnabled.Key) + + // mock proxy + mockProxy := mocks.NewMockProxy(t) + mockProxy.EXPECT().CreateResourceGroup(mock.Anything, mock.Anything).Return(merr.Success(), nil) + mockProxy.EXPECT().DescribeResourceGroup(mock.Anything, mock.Anything).Return(&milvuspb.DescribeResourceGroupResponse{}, nil) + mockProxy.EXPECT().DropResourceGroup(mock.Anything, mock.Anything).Return(merr.Success(), nil) + mockProxy.EXPECT().ListResourceGroups(mock.Anything, mock.Anything).Return(&milvuspb.ListResourceGroupsResponse{}, nil) + mockProxy.EXPECT().UpdateResourceGroups(mock.Anything, mock.Anything).Return(merr.Success(), nil) + mockProxy.EXPECT().TransferReplica(mock.Anything, mock.Anything).Return(merr.Success(), nil) + + // setup test server + testServer := initHTTPServerV2(mockProxy, false) + + // setup test case + testCases := make([]requestBodyTestCase, 0) + + // test case: create resource group with empty request body + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, CreateAction), + requestBody: []byte(`{}`), + errCode: 1802, + errMsg: "missing required parameters, error: Key: 'ResourceGroupReq.Name' Error:Field validation for 'Name' failed on the 'required' tag", + }) + + // test case: create resource group with name + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, CreateAction), + requestBody: []byte(`{"name":"test"}`), + }) + + // test case: create resource group with name and config + requestBodyInJSON := `{"name":"test","config":{"requests":{"node_num":1},"limits":{"node_num":1},"transfer_from":[{"resource_group":"__default_resource_group"}],"transfer_to":[{"resource_group":"__default_resource_group"}]}}` + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, CreateAction), + requestBody: []byte(requestBodyInJSON), + }) + + // test case: describe resource group with empty request body + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, DescribeAction), + requestBody: []byte(`{}`), + errCode: 1802, + errMsg: "missing required parameters, error: Key: 'ResourceGroupReq.Name' Error:Field validation for 'Name' failed on the 'required' tag", + }) + // test case: describe resource group with name + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, DescribeAction), + requestBody: []byte(`{"name":"test"}`), + }) + + // test case: drop resource group with empty request body + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, DropAction), + requestBody: []byte(`{}`), + errCode: 1802, + errMsg: "missing required parameters, error: Key: 'ResourceGroupReq.Name' Error:Field validation for 'Name' failed on the 'required' tag", + }) + // test case: drop resource group with name + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, DropAction), + requestBody: []byte(`{"name":"test"}`), + }) + + // test case: list resource groups + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, ListAction), + requestBody: []byte(`{}`), + }) + + // test case: update resource group with empty request body + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, AlterAction), + requestBody: []byte(`{}`), + errCode: 1802, + errMsg: "missing required parameters, error: Key: 'UpdateResourceGroupReq.ResourceGroups' Error:Field validation for 'ResourceGroups' failed on the 'required' tag", + }) + + // test case: update resource group with resource groups + requestBodyInJSON = `{"resource_groups":{"test":{"requests":{"node_num":1},"limits":{"node_num":1},"transfer_from":[{"resource_group":"__default_resource_group"}],"transfer_to":[{"resource_group":"__default_resource_group"}]}}}` + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, AlterAction), + requestBody: []byte(requestBodyInJSON), + }) + + // test case: transfer replica with empty request body + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, TransferReplicaAction), + requestBody: []byte(`{}`), + errCode: 1802, + errMsg: "missing required parameters, error: Key: 'TransferReplicaReq.SourceRgName' Error:Field validation for 'SourceRgName' failed on the 'required' tag\nKey: 'TransferReplicaReq.TargetRgName' Error:Field validation for 'TargetRgName' failed on the 'required' tag\nKey: 'TransferReplicaReq.CollectionName' Error:Field validation for 'CollectionName' failed on the 'required' tag\nKey: 'TransferReplicaReq.ReplicaNum' Error:Field validation for 'ReplicaNum' failed on the 'required' tag", + }) + + // test case: transfer replica + requestBodyInJSON = `{"sourceRgName":"rg1","targetRgName":"rg2","collectionName":"hello_milvus","replicaNum":1}` + testCases = append(testCases, requestBodyTestCase{ + path: versionalV2(ResourceGroupCategory, TransferReplicaAction), + requestBody: []byte(requestBodyInJSON), + }) + + // verify test case + validateRequestBodyTestCases(t, testServer, testCases, false) +}