Skip to content

Commit

Permalink
impl(generator): support protobuf wrapper types as query params (goog…
Browse files Browse the repository at this point in the history
  • Loading branch information
scotthart authored and cuiy0006 committed Jul 22, 2024
1 parent bb8bb1b commit 9c67282
Show file tree
Hide file tree
Showing 105 changed files with 509 additions and 285 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ DefaultGoldenKitchenSinkRestStub::GenerateIdToken(
absl::StrCat("/", rest_internal::DetermineApiVersion("v1", options), "/token:generate"),
rest_internal::TrimEmptyQueryParameters({std::make_pair("name", request.name()),
std::make_pair("audience", request.audience()),
std::make_pair("include_email", request.include_email() ? "1" : "0")}));
std::make_pair("include_email", (request.include_email() ? "1" : "0"))}));
}

StatusOr<google::test::admin::database::v1::WriteLogEntriesResponse>
Expand Down
4 changes: 2 additions & 2 deletions generator/internal/descriptor_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1281,7 +1281,7 @@ INSTANTIATE_TEST_SUITE_P(
MethodVarsTestValues("my.service.v1.Service.Method1",
"method_http_query_parameters", R"""(,
rest_internal::TrimEmptyQueryParameters({std::make_pair("number", std::to_string(request.number())),
std::make_pair("toggle", request.toggle() ? "1" : "0"),
std::make_pair("toggle", (request.toggle() ? "1" : "0")),
std::make_pair("title", request.title()),
std::make_pair("parent", request.parent())}))"""),
// Method2
Expand Down Expand Up @@ -1310,7 +1310,7 @@ INSTANTIATE_TEST_SUITE_P(
"method_http_query_parameters", R"""(,
rest_internal::TrimEmptyQueryParameters({std::make_pair("number", std::to_string(request.number())),
std::make_pair("name", request.name()),
std::make_pair("toggle", request.toggle() ? "1" : "0"),
std::make_pair("toggle", (request.toggle() ? "1" : "0")),
std::make_pair("title", request.title())}))"""),
// Method3
MethodVarsTestValues("my.service.v1.Service.Method3",
Expand Down
90 changes: 70 additions & 20 deletions generator/internal/http_option_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,52 @@ void SetHttpDerivedMethodVars(
absl::visit(HttpInfoVisitor(method, method_vars), parsed_http_info);
}

// Request fields not appering in the path may not wind up as part of the json
absl::optional<QueryParameterInfo> DetermineQueryParameterInfo(
google::protobuf::FieldDescriptor const& field) {
static auto* const kSupportedWellKnownValueTypes = [] {
auto foo = std::make_unique<
std::unordered_map<std::string, protobuf::FieldDescriptor::CppType>>();
foo->emplace("google.protobuf.BoolValue",
protobuf::FieldDescriptor::CPPTYPE_BOOL);
foo->emplace("google.protobuf.DoubleValue",
protobuf::FieldDescriptor::CPPTYPE_DOUBLE);
foo->emplace("google.protobuf.FloatValue",
protobuf::FieldDescriptor::CPPTYPE_FLOAT);
foo->emplace("google.protobuf.Int32Value",
protobuf::FieldDescriptor::CPPTYPE_INT32);
foo->emplace("google.protobuf.Int64Value",
protobuf::FieldDescriptor::CPPTYPE_INT64);
foo->emplace("google.protobuf.StringValue",
protobuf::FieldDescriptor::CPPTYPE_STRING);
foo->emplace("google.protobuf.UInt32Value",
protobuf::FieldDescriptor::CPPTYPE_UINT32);
foo->emplace("google.protobuf.UInt64Value",
protobuf::FieldDescriptor::CPPTYPE_UINT64);
return foo.release();
}();

absl::optional<QueryParameterInfo> param_info;
// Only attempt to make non-repeated, simple fields query parameters.
if (!field.is_repeated() && !field.options().deprecated()) {
if (field.cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
param_info = QueryParameterInfo{
field.cpp_type(), absl::StrCat("request.", field.name(), "()"),
false};
} else {
// But also consider protobuf well known types that wrap simple types.
auto iter = kSupportedWellKnownValueTypes->find(
field.message_type()->full_name());
if (iter != kSupportedWellKnownValueTypes->end()) {
param_info = QueryParameterInfo{
iter->second, absl::StrCat("request.", field.name(), "().value()"),
true};
}
}
}
return param_info;
}

// Request fields not appearing in the path may not wind up as part of the json
// request body, so per https://cloud.google.com/apis/design/standard_methods,
// for HTTP transcoding we need to turn the request fields into query
// parameters.
Expand All @@ -211,36 +256,41 @@ void SetHttpQueryParameters(
: method(method), method_vars(method_vars) {}
void FormatQueryParameterCode(
std::vector<std::string> const& param_field_names) {
std::vector<std::pair<std::string, protobuf::FieldDescriptor::CppType>>
std::vector<std::pair<std::string, QueryParameterInfo>>
remaining_request_fields;
auto const* request = method.input_type();
for (int i = 0; i < request->field_count(); ++i) {
auto const* field = request->field(i);
// Only attempt to make non-repeated, simple fields query parameters.
if (!field->is_repeated() && !field->options().deprecated() &&
field->cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
if (!internal::Contains(param_field_names, field->name())) {
remaining_request_fields.emplace_back(field->name(),
field->cpp_type());
}
auto param_info = DetermineQueryParameterInfo(*field);
if (param_info &&
!internal::Contains(param_field_names, field->name())) {
remaining_request_fields.emplace_back(field->name(), *param_info);
}
}

auto format = [](auto* out, auto const& i) {
if (i.second == protobuf::FieldDescriptor::CPPTYPE_STRING) {
out->append(absl::StrFormat("std::make_pair(\"%s\", request.%s())",
i.first, i.first));
return;
std::string field_access;
if (i.second.cpp_type == protobuf::FieldDescriptor::CPPTYPE_STRING) {
field_access = i.second.request_field_accessor;
} else if (i.second.cpp_type ==
protobuf::FieldDescriptor::CPPTYPE_BOOL) {
field_access = absl::StrCat("(", i.second.request_field_accessor,
R"""( ? "1" : "0"))""");
} else {
field_access = absl::StrCat("std::to_string(",
i.second.request_field_accessor, ")");
}
if (i.second == protobuf::FieldDescriptor::CPPTYPE_BOOL) {

if (i.second.check_presence) {
out->append(absl::StrFormat(
R"""(std::make_pair("%s", request.%s() ? "1" : "0"))""", i.first,
i.first));
return;
R"""(std::make_pair("%s", (request.has_%s() ? %s : "")))""",
i.first, i.first, field_access));
} else {
out->append(absl::StrFormat(R"""(std::make_pair("%s", %s))""",
i.first, field_access));
}
out->append(absl::StrFormat(
"std::make_pair(\"%s\", std::to_string(request.%s()))", i.first,
i.first));
};

if (remaining_request_fields.empty()) {
method_vars["method_http_query_parameters"] = "";
} else {
Expand Down
18 changes: 18 additions & 0 deletions generator/internal/http_option_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ void SetHttpDerivedMethodVars(
google::protobuf::MethodDescriptor const& method,
VarsDictionary& method_vars);

struct QueryParameterInfo {
protobuf::FieldDescriptor::CppType cpp_type;
// A code fragment the generator emits to access the value of the field.
std::string request_field_accessor;
// Check presence for MESSAGE types as their default values may result in
// undesired behavior.
bool check_presence;
};

/**
* Determine if a field is a query parameter candidate, such that it's a
* non-repeated field that is also not an aggregate type. This includes numeric,
* bool, and string native protobuf data types, as well as, protobuf "Well Known
* Types" that wrap those data types.
*/
absl::optional<QueryParameterInfo> DetermineQueryParameterInfo(
google::protobuf::FieldDescriptor const& field);

/**
* Sets the "method_http_query_parameters" value in method_vars based on the
* parsed_http_info.
Expand Down
120 changes: 118 additions & 2 deletions generator/internal/http_option_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,42 @@ syntax = "proto3";
package google.protobuf;
// Leading comments about message Empty.
message Empty {}
message DoubleValue {
// The double value.
double value = 1;
}
message FloatValue {
// The float value.
float value = 1;
}
message Int64Value {
// The int64 value.
int64 value = 1;
}
message UInt64Value {
// The uint64 value.
uint64 value = 1;
}
message Int32Value {
// The int32 value.
int32 value = 1;
}
message UInt32Value {
// The uint32 value.
uint32 value = 1;
}
message BoolValue {
// The bool value.
bool value = 1;
}
message StringValue {
// The string value.
string value = 1;
}
message BytesValue {
// The bytes value.
bytes value = 1;
}
)""";

char const* const kServiceProto =
Expand Down Expand Up @@ -275,6 +311,65 @@ char const* const kServiceProtoWithoutVersion =
" }\n"
"}\n";

char const* const kBigQueryServiceProto = R"""(
syntax = "proto3";
package my.package.v1;
import "google/api/annotations.proto";
import "google/api/client.proto";
import "google/api/http.proto";
import "google/protobuf/well_known.proto";
service BigQueryLikeService {
// RPC to get the results of a query job.
rpc GetQueryResults(GetQueryResultsRequest)
returns (GetQueryResultsResponse) {
option (google.api.http) = {
get: "/bigquery/v2/projects/{project_id=*}/queries/{job_id=*}"
};
}
}
// Request object of GetQueryResults.
message GetQueryResultsRequest {
// Required. Project ID of the query job.
string project_id = 1;
// Required. Job ID of the query job.
string job_id = 2;
// Zero-based index of the starting row.
google.protobuf.UInt64Value start_index = 3;
// Page token, returned by a previous call, to request the next page of
// results.
google.protobuf.StringValue page_token = 4;
// Maximum number of results to read.
google.protobuf.UInt32Value max_results = 5;
// The geographic location of the job.
google.protobuf.BoolValue include_location = 7;
// Double field.
google.protobuf.DoubleValue double_value = 8;
// Float field.
google.protobuf.FloatValue float_value = 9;
// Int32 field.
google.protobuf.Int32Value int32_value = 10;
// Int64 field.
google.protobuf.Int64Value int64_value = 11;
// Non supported message type that is not a query param.
google.protobuf.Empty non_supported_type = 12;
}
// Response object of GetQueryResults.
message GetQueryResultsResponse {}
)""";

struct MethodVarsTestValues {
MethodVarsTestValues(std::string m, std::string k, std::string v)
: method(std::move(m)),
Expand All @@ -298,7 +393,9 @@ class HttpOptionUtilsTest
{std::string("google/protobuf/well_known.proto"), kWellKnownProto},
{std::string("google/foo/v1/service.proto"), kServiceProto},
{std::string("google/foo/v1/service_without_version.proto"),
kServiceProtoWithoutVersion}}),
kServiceProtoWithoutVersion},
{std::string("google/foo/v1/big_query_service.proto"),
kBigQueryServiceProto}}),
source_tree_db_(&source_tree_),
merged_db_(&simple_db_, &source_tree_db_),
pool_(&merged_db_, &collector_) {
Expand Down Expand Up @@ -550,7 +647,26 @@ TEST_F(HttpOptionUtilsTest, SetHttpGetQueryParametersGetPaginated) {
rest_internal::TrimEmptyQueryParameters({std::make_pair("page_size", std::to_string(request.page_size())),
std::make_pair("page_token", request.page_token()),
std::make_pair("name", request.name()),
std::make_pair("include_foo", request.include_foo() ? "1" : "0")}))"""));
std::make_pair("include_foo", (request.include_foo() ? "1" : "0"))}))"""));
}

TEST_F(HttpOptionUtilsTest,
SetHttpGetQueryParametersGetWellKnownTypesPaginated) {
FileDescriptor const* service_file_descriptor =
pool_.FindFileByName("google/foo/v1/big_query_service.proto");
MethodDescriptor const* method =
service_file_descriptor->service(0)->method(0);
VarsDictionary vars;
SetHttpQueryParameters(ParseHttpExtension(*method), *method, vars);
EXPECT_THAT(vars.at("method_http_query_parameters"), Eq(R"""(,
rest_internal::TrimEmptyQueryParameters({std::make_pair("start_index", (request.has_start_index() ? std::to_string(request.start_index().value()) : "")),
std::make_pair("page_token", (request.has_page_token() ? request.page_token().value() : "")),
std::make_pair("max_results", (request.has_max_results() ? std::to_string(request.max_results().value()) : "")),
std::make_pair("include_location", (request.has_include_location() ? (request.include_location().value() ? "1" : "0") : "")),
std::make_pair("double_value", (request.has_double_value() ? std::to_string(request.double_value().value()) : "")),
std::make_pair("float_value", (request.has_float_value() ? std::to_string(request.float_value().value()) : "")),
std::make_pair("int32_value", (request.has_int32_value() ? std::to_string(request.int32_value().value()) : "")),
std::make_pair("int64_value", (request.has_int64_value() ? std::to_string(request.int64_value().value()) : ""))}))"""));
}

TEST_F(HttpOptionUtilsTest, HasHttpAnnotationRoutingHeaderSuccess) {
Expand Down
10 changes: 7 additions & 3 deletions google/cloud/bigquerycontrol/v2/internal/dataset_rest_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Status DefaultDatasetServiceRestStub::DeleteDataset(
"projects", "/", request.project_id(), "/", "datasets", "/",
request.dataset_id()),
rest_internal::TrimEmptyQueryParameters({std::make_pair(
"delete_contents", request.delete_contents() ? "1" : "0")}));
"delete_contents", (request.delete_contents() ? "1" : "0"))}));
}

StatusOr<google::cloud::bigquery::v2::DatasetList>
Expand All @@ -117,8 +117,12 @@ DefaultDatasetServiceRestStub::ListDatasets(
rest_internal::DetermineApiVersion("v2", options), "/",
"projects", "/", request.project_id(), "/", "datasets"),
rest_internal::TrimEmptyQueryParameters(
{std::make_pair("page_token", request.page_token()),
std::make_pair("all", request.all() ? "1" : "0"),
{std::make_pair("max_results",
(request.has_max_results()
? std::to_string(request.max_results().value())
: "")),
std::make_pair("page_token", request.page_token()),
std::make_pair("all", (request.all() ? "1" : "0")),
std::make_pair("filter", request.filter())}));
}

Expand Down
25 changes: 23 additions & 2 deletions google/cloud/bigquerycontrol/v2/internal/job_rest_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,18 @@ DefaultJobServiceRestStub::ListJobs(
rest_internal::DetermineApiVersion("v2", options), "/",
"projects", "/", request.project_id(), "/", "jobs"),
rest_internal::TrimEmptyQueryParameters(
{std::make_pair("all_users", request.all_users() ? "1" : "0"),
{std::make_pair("all_users", (request.all_users() ? "1" : "0")),
std::make_pair("max_results",
(request.has_max_results()
? std::to_string(request.max_results().value())
: "")),
std::make_pair("min_creation_time",
std::to_string(request.min_creation_time())),
std::make_pair(
"max_creation_time",
(request.has_max_creation_time()
? std::to_string(request.max_creation_time().value())
: "")),
std::make_pair("page_token", request.page_token()),
std::make_pair("projection", std::to_string(request.projection())),
std::make_pair("parent_job_id", request.parent_job_id())}));
Expand All @@ -125,7 +134,19 @@ DefaultJobServiceRestStub::GetQueryResults(
"projects", "/", request.project_id(), "/", "queries", "/",
request.job_id()),
rest_internal::TrimEmptyQueryParameters(
{std::make_pair("page_token", request.page_token()),
{std::make_pair("start_index",
(request.has_start_index()
? std::to_string(request.start_index().value())
: "")),
std::make_pair("page_token", request.page_token()),
std::make_pair("max_results",
(request.has_max_results()
? std::to_string(request.max_results().value())
: "")),
std::make_pair("timeout_ms",
(request.has_timeout_ms()
? std::to_string(request.timeout_ms().value())
: "")),
std::make_pair("location", request.location())}));
}

Expand Down
6 changes: 5 additions & 1 deletion google/cloud/bigquerycontrol/v2/internal/model_rest_stub.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,11 @@ DefaultModelServiceRestStub::ListModels(
"projects", "/", request.project_id(), "/", "datasets", "/",
request.dataset_id(), "/", "models"),
rest_internal::TrimEmptyQueryParameters(
{std::make_pair("page_token", request.page_token())}));
{std::make_pair("max_results",
(request.has_max_results()
? std::to_string(request.max_results().value())
: "")),
std::make_pair("page_token", request.page_token())}));
}

StatusOr<google::cloud::bigquery::v2::Model>
Expand Down
Loading

0 comments on commit 9c67282

Please sign in to comment.