Skip to content

Commit

Permalink
impl(generator): support protobuf wrapper types as query params
Browse files Browse the repository at this point in the history
  • Loading branch information
scotthart committed Jul 18, 2024
1 parent eb2cacc commit 6fb2cf6
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 20 deletions.
78 changes: 59 additions & 19 deletions generator/internal/http_option_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,50 @@ void SetHttpDerivedMethodVars(
absl::visit(HttpInfoVisitor(method, method_vars), parsed_http_info);
}

// Request fields not appering in the path may not wind up as part of the json
absl::optional<QueryParameterInfo> DetermineQueryParameterInfo(
google::protobuf::FieldDescriptor const& field) {
static auto* const kSupportedWellKnownValueTypes = [] {
auto foo = std::make_unique<
std::unordered_map<std::string, protobuf::FieldDescriptor::CppType>>();
foo->emplace("google.protobuf.BoolValue",
protobuf::FieldDescriptor::CPPTYPE_BOOL);
foo->emplace("google.protobuf.DoubleValue",
protobuf::FieldDescriptor::CPPTYPE_DOUBLE);
foo->emplace("google.protobuf.FloatValue",
protobuf::FieldDescriptor::CPPTYPE_FLOAT);
foo->emplace("google.protobuf.Int32Value",
protobuf::FieldDescriptor::CPPTYPE_INT32);
foo->emplace("google.protobuf.Int64Value",
protobuf::FieldDescriptor::CPPTYPE_INT64);
foo->emplace("google.protobuf.StringValue",
protobuf::FieldDescriptor::CPPTYPE_STRING);
foo->emplace("google.protobuf.UInt32Value",
protobuf::FieldDescriptor::CPPTYPE_UINT32);
foo->emplace("google.protobuf.UInt64Value",
protobuf::FieldDescriptor::CPPTYPE_UINT64);
return foo.release();
}();

absl::optional<QueryParameterInfo> param_info;
// Only attempt to make non-repeated, simple fields query parameters.
if (!field.is_repeated() && !field.options().deprecated()) {
if (field.cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
param_info = QueryParameterInfo{
field.cpp_type(), absl::StrCat("request.", field.name(), "()")};
} else {
// But also consider protobuf well known types that wrap simple types.
auto iter = kSupportedWellKnownValueTypes->find(
field.message_type()->full_name());
if (iter != kSupportedWellKnownValueTypes->end()) {
param_info = QueryParameterInfo{
iter->second, absl::StrCat("request.", field.name(), "().value()")};
}
}
}
return param_info;
}

// Request fields not appearing in the path may not wind up as part of the json
// request body, so per https://cloud.google.com/apis/design/standard_methods,
// for HTTP transcoding we need to turn the request fields into query
// parameters.
Expand All @@ -211,35 +254,32 @@ void SetHttpQueryParameters(
: method(method), method_vars(method_vars) {}
void FormatQueryParameterCode(
std::vector<std::string> const& param_field_names) {
std::vector<std::pair<std::string, protobuf::FieldDescriptor::CppType>>
std::vector<std::pair<std::string, QueryParameterInfo>>
remaining_request_fields;
auto const* request = method.input_type();
for (int i = 0; i < request->field_count(); ++i) {
auto const* field = request->field(i);
// Only attempt to make non-repeated, simple fields query parameters.
if (!field->is_repeated() && !field->options().deprecated() &&
field->cpp_type() != protobuf::FieldDescriptor::CPPTYPE_MESSAGE) {
if (!internal::Contains(param_field_names, field->name())) {
remaining_request_fields.emplace_back(field->name(),
field->cpp_type());
}
auto param_info = DetermineQueryParameterInfo(*field);
if (param_info &&
!internal::Contains(param_field_names, field->name())) {
remaining_request_fields.emplace_back(field->name(), *param_info);
}
}
auto format = [](auto* out, auto const& i) {
if (i.second == protobuf::FieldDescriptor::CPPTYPE_STRING) {
out->append(absl::StrFormat("std::make_pair(\"%s\", request.%s())",
i.first, i.first));
if (i.second.cpp_type == protobuf::FieldDescriptor::CPPTYPE_STRING) {
out->append(absl::StrCat("std::make_pair(\"", i.first, "\", ",
i.second.request_field_accessor, ")"));
return;
}
if (i.second == protobuf::FieldDescriptor::CPPTYPE_BOOL) {
out->append(absl::StrFormat(
R"""(std::make_pair("%s", request.%s() ? "1" : "0"))""", i.first,
i.first));
if (i.second.cpp_type == protobuf::FieldDescriptor::CPPTYPE_BOOL) {
out->append(absl::StrCat("std::make_pair(\"", i.first, "\", ",
i.second.request_field_accessor,
R"""( ? "1" : "0"))"""));
return;
}
out->append(absl::StrFormat(
"std::make_pair(\"%s\", std::to_string(request.%s()))", i.first,
i.first));
out->append(absl::StrCat("std::make_pair(\"", i.first,
"\", std::to_string(",
i.second.request_field_accessor, "))"));
};
if (remaining_request_fields.empty()) {
method_vars["method_http_query_parameters"] = "";
Expand Down
15 changes: 15 additions & 0 deletions generator/internal/http_option_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,21 @@ void SetHttpDerivedMethodVars(
google::protobuf::MethodDescriptor const& method,
VarsDictionary& method_vars);

struct QueryParameterInfo {
protobuf::FieldDescriptor::CppType cpp_type;
// A code fragment the generator emits to access the value of the field.
std::string request_field_accessor;
};

/**
* Determine if a field is a query parameter candidate, such that it's a
* non-repeated field that is also not an aggregate type. This includes numeric,
* bool, and string native protobuf data types, as well as, protobuf "Well Known
* Types" that wrap those data types.
*/
absl::optional<QueryParameterInfo> DetermineQueryParameterInfo(
google::protobuf::FieldDescriptor const& field);

/**
* Sets the "method_http_query_parameters" value in method_vars based on the
* parsed_http_info.
Expand Down
118 changes: 117 additions & 1 deletion generator/internal/http_option_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,42 @@ syntax = "proto3";
package google.protobuf;
// Leading comments about message Empty.
message Empty {}
message DoubleValue {
// The double value.
double value = 1;
}
message FloatValue {
// The float value.
float value = 1;
}
message Int64Value {
// The int64 value.
int64 value = 1;
}
message UInt64Value {
// The uint64 value.
uint64 value = 1;
}
message Int32Value {
// The int32 value.
int32 value = 1;
}
message UInt32Value {
// The uint32 value.
uint32 value = 1;
}
message BoolValue {
// The bool value.
bool value = 1;
}
message StringValue {
// The string value.
string value = 1;
}
message BytesValue {
// The bytes value.
bytes value = 1;
}
)""";

char const* const kServiceProto =
Expand Down Expand Up @@ -275,6 +311,65 @@ char const* const kServiceProtoWithoutVersion =
" }\n"
"}\n";

char const* const kBigQueryServiceProto = R"""(
syntax = "proto3";
package my.package.v1;
import "google/api/annotations.proto";
import "google/api/client.proto";
import "google/api/http.proto";
import "google/protobuf/well_known.proto";
service BigQueryLikeService {
// RPC to get the results of a query job.
rpc GetQueryResults(GetQueryResultsRequest)
returns (GetQueryResultsResponse) {
option (google.api.http) = {
get: "/bigquery/v2/projects/{project_id=*}/queries/{job_id=*}"
};
}
}
// Request object of GetQueryResults.
message GetQueryResultsRequest {
// Required. Project ID of the query job.
string project_id = 1;
// Required. Job ID of the query job.
string job_id = 2;
// Zero-based index of the starting row.
google.protobuf.UInt64Value start_index = 3;
// Page token, returned by a previous call, to request the next page of
// results.
google.protobuf.StringValue page_token = 4;
// Maximum number of results to read.
google.protobuf.UInt32Value max_results = 5;
// The geographic location of the job.
google.protobuf.BoolValue include_location = 7;
// Double field.
google.protobuf.DoubleValue double_value = 8;
// Float field.
google.protobuf.FloatValue float_value = 9;
// Int32 field.
google.protobuf.Int32Value int32_value = 10;
// Int64 field.
google.protobuf.Int64Value int64_value = 11;
// Non supported message type that is not a query param.
google.protobuf.Empty non_supported_type = 12;
}
// Response object of GetQueryResults.
message GetQueryResultsResponse {}
)""";

struct MethodVarsTestValues {
MethodVarsTestValues(std::string m, std::string k, std::string v)
: method(std::move(m)),
Expand All @@ -298,7 +393,9 @@ class HttpOptionUtilsTest
{std::string("google/protobuf/well_known.proto"), kWellKnownProto},
{std::string("google/foo/v1/service.proto"), kServiceProto},
{std::string("google/foo/v1/service_without_version.proto"),
kServiceProtoWithoutVersion}}),
kServiceProtoWithoutVersion},
{std::string("google/foo/v1/big_query_service.proto"),
kBigQueryServiceProto}}),
source_tree_db_(&source_tree_),
merged_db_(&simple_db_, &source_tree_db_),
pool_(&merged_db_, &collector_) {
Expand Down Expand Up @@ -553,6 +650,25 @@ TEST_F(HttpOptionUtilsTest, SetHttpGetQueryParametersGetPaginated) {
std::make_pair("include_foo", request.include_foo() ? "1" : "0")}))"""));
}

TEST_F(HttpOptionUtilsTest,
SetHttpGetQueryParametersGetWellKnownTypesPaginated) {
FileDescriptor const* service_file_descriptor =
pool_.FindFileByName("google/foo/v1/big_query_service.proto");
MethodDescriptor const* method =
service_file_descriptor->service(0)->method(0);
VarsDictionary vars;
SetHttpQueryParameters(ParseHttpExtension(*method), *method, vars);
EXPECT_THAT(vars.at("method_http_query_parameters"), Eq(R"""(,
rest_internal::TrimEmptyQueryParameters({std::make_pair("start_index", std::to_string(request.start_index().value())),
std::make_pair("page_token", request.page_token().value()),
std::make_pair("max_results", std::to_string(request.max_results().value())),
std::make_pair("include_location", request.include_location().value() ? "1" : "0"),
std::make_pair("double_value", std::to_string(request.double_value().value())),
std::make_pair("float_value", std::to_string(request.float_value().value())),
std::make_pair("int32_value", std::to_string(request.int32_value().value())),
std::make_pair("int64_value", std::to_string(request.int64_value().value()))}))"""));
}

TEST_F(HttpOptionUtilsTest, HasHttpAnnotationRoutingHeaderSuccess) {
FileDescriptor const* service_file_descriptor =
pool_.FindFileByName("google/foo/v1/service.proto");
Expand Down

0 comments on commit 6fb2cf6

Please sign in to comment.