diff --git a/README.md b/README.md index 31a90d7fb..3126bb715 100644 --- a/README.md +++ b/README.md @@ -50,20 +50,20 @@ If you are using Maven without the BOM, add this to your dependencies: If you are using Gradle 5.x or later, add this to your dependencies: ```Groovy -implementation platform('com.google.cloud:libraries-bom:26.27.0') +implementation platform('com.google.cloud:libraries-bom:26.28.0') implementation 'com.google.cloud:google-cloud-datastore' ``` If you are using Gradle without BOM, add this to your dependencies: ```Groovy -implementation 'com.google.cloud:google-cloud-datastore:2.17.5' +implementation 'com.google.cloud:google-cloud-datastore:2.17.6' ``` If you are using SBT, add this to your dependencies: ```Scala -libraryDependencies += "com.google.cloud" % "google-cloud-datastore" % "2.17.5" +libraryDependencies += "com.google.cloud" % "google-cloud-datastore" % "2.17.6" ``` @@ -380,7 +380,7 @@ Java is a registered trademark of Oracle and/or its affiliates. [kokoro-badge-link-5]: http://storage.googleapis.com/cloud-devrel-public/java/badges/java-datastore/java11.html [stability-image]: https://img.shields.io/badge/stability-stable-green [maven-version-image]: https://img.shields.io/maven-central/v/com.google.cloud/google-cloud-datastore.svg -[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-datastore/2.17.5 +[maven-version-link]: https://central.sonatype.com/artifact/com.google.cloud/google-cloud-datastore/2.17.6 [authentication]: https://github.com/googleapis/google-cloud-java#authentication [auth-scopes]: https://developers.google.com/identity/protocols/oauth2/scopes [predefined-iam-roles]: https://cloud.google.com/iam/docs/understanding-roles#predefined_roles diff --git a/datastore-v1-proto-client/pom.xml b/datastore-v1-proto-client/pom.xml index ebde7c645..1d1bb9e6a 100644 --- a/datastore-v1-proto-client/pom.xml +++ b/datastore-v1-proto-client/pom.xml @@ -95,6 +95,13 @@ test + + com.google.testparameterinjector + test-parameter-injector + 1.14 + test + + com.google.truth truth diff --git a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java index b0b47c505..6e41b7d93 100644 --- a/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java +++ b/datastore-v1-proto-client/src/main/java/com/google/datastore/v1/client/RemoteRpc.java @@ -53,9 +53,37 @@ class RemoteRpc { private final HttpRequestInitializer initializer; private final String url; private final AtomicInteger rpcCount = new AtomicInteger(0); - // Not final - so it can be set/reset in Unittests - private static boolean enableE2EChecksum = - Boolean.parseBoolean(System.getenv("GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_CHECKSUM")); + private static final String E2E_REQUEST_CHECKSUM_FLAG = + "GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_REQUEST_CHECKSUM"; + private static final String E2E_RESPONSE_CHECKSUM_FLAG = + "GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_RESPONSE_CHECKSUM"; + // By default request checksum is enabled. + // Not final - so it can be set/reset in Unittests. + private static boolean enableE2ERequestChecksum = + System.getenv(E2E_REQUEST_CHECKSUM_FLAG) == null + || Boolean.parseBoolean(System.getenv(E2E_REQUEST_CHECKSUM_FLAG)); + + private static boolean enableE2EResponseChecksum = + Boolean.parseBoolean(System.getenv(E2E_RESPONSE_CHECKSUM_FLAG)); + + // Deprecated env var for enabling both request and response checksum. + private static final String E2E_CHECKSUM_FLAG_DEPRECATED = + "GOOGLE_CLOUD_DATASTORE_HTTP_ENABLE_E2E_CHECKSUM"; + + static { + if (System.getenv(E2E_CHECKSUM_FLAG_DEPRECATED) != null + && System.getenv(E2E_REQUEST_CHECKSUM_FLAG) == null + && System.getenv(E2E_RESPONSE_CHECKSUM_FLAG) == null) { + logger.warning( + String.format( + "%s environment variable is deprecated. " + + "Please switch to using %s and/or %s to enable/disable " + + "request and/or response checksum features.", + E2E_CHECKSUM_FLAG_DEPRECATED, E2E_REQUEST_CHECKSUM_FLAG, E2E_RESPONSE_CHECKSUM_FLAG)); + enableE2ERequestChecksum = Boolean.parseBoolean(System.getenv(E2E_CHECKSUM_FLAG_DEPRECATED)); + enableE2EResponseChecksum = enableE2ERequestChecksum; + } + } RemoteRpc(HttpRequestFactory client, HttpRequestInitializer initializer, String url) { this.client = client; @@ -113,7 +141,7 @@ public InputStream call( } } InputStream inputStream = httpResponse.getContent(); - return enableE2EChecksum && EndToEndChecksumHandler.hasChecksumHeader(httpResponse) + return enableE2EResponseChecksum && EndToEndChecksumHandler.hasChecksumHeader(httpResponse) ? new ChecksumEnforcingInputStream(inputStream, httpResponse) : inputStream; } catch (SocketTimeoutException e) { @@ -138,7 +166,7 @@ void setHeaders( builder.append(databaseId); } httpRequest.getHeaders().put(X_GOOG_REQUEST_PARAMS_HEADER, builder.toString()); - if (enableE2EChecksum && request != null) { + if (enableE2ERequestChecksum && request != null) { String checksum = EndToEndChecksumHandler.computeChecksum(request.toByteArray()); if (checksum != null) { httpRequest @@ -154,8 +182,10 @@ HttpRequestFactory getClient() { } @VisibleForTesting - static void setSystemEnvE2EChecksum(boolean enableE2EChecksum) { - RemoteRpc.enableE2EChecksum = enableE2EChecksum; + static void setSystemEnvE2EChecksum( + boolean enableE2ERequestChecksum, boolean enableE2EResponseChecksum) { + RemoteRpc.enableE2ERequestChecksum = enableE2ERequestChecksum; + RemoteRpc.enableE2EResponseChecksum = enableE2EResponseChecksum; } void resetRpcCount() { diff --git a/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java b/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java index 28e3f20b8..d6f2a507f 100644 --- a/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java +++ b/datastore-v1-proto-client/src/test/java/com/google/datastore/v1/client/RemoteRpcTest.java @@ -16,8 +16,7 @@ package com.google.datastore.v1.client; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import com.google.api.client.http.HttpRequest; import com.google.api.client.http.HttpTransport; @@ -31,17 +30,23 @@ import com.google.protobuf.MessageLite; import com.google.rpc.Code; import com.google.rpc.Status; +import com.google.testing.junit.testparameterinjector.TestParameter; +import com.google.testing.junit.testparameterinjector.TestParameterInjector; import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; import java.util.zip.GZIPOutputStream; import org.junit.Test; import org.junit.runner.RunWith; -import org.junit.runners.JUnit4; /** Test for {@link RemoteRpc}. */ -@RunWith(JUnit4.class) +@RunWith(TestParameterInjector.class) public class RemoteRpcTest { private static final String METHOD_NAME = "methodName"; @@ -157,48 +162,88 @@ public void testGzip() throws IOException, DatastoreException { } @Test - public void testHttpHeaders_expectE2eChecksumHeader() throws IOException { - // Enable E2E-Checksum system env variable - RemoteRpc.setSystemEnvE2EChecksum(true); + public void testE2EChecksum(@TestParameter boolean reqEnabled, @TestParameter boolean respEnabled) + throws IOException, DatastoreException { + RemoteRpc.setSystemEnvE2EChecksum(reqEnabled, respEnabled); String projectId = "project-id"; MessageLite request = RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build(); - RemoteRpc rpc = - newRemoteRpc( - new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); - HttpRequest httpRequest = - rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request)); - rpc.setHeaders(request, httpRequest, projectId, ""); - assertNotNull( - httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER)); - // Expect to find e2e-checksum header - String header = - httpRequest - .getHeaders() - .getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER); - assertEquals(9, header.length()); + + // Always return invalid response checksum to check that it will raise an exception only when + // response checksum verification is enabled. + List respHeaders = + Collections.singletonList( + new MyHeader( + EndToEndChecksumHandler.HTTP_RESPONSE_CHECKSUM_HEADER, "invalid_checksum")); + + Set expectedRequestHeaders = new HashSet<>(); + expectedRequestHeaders.add(MyHeader.anyValue(RemoteRpc.API_FORMAT_VERSION_HEADER)); + + if (reqEnabled) { + expectedRequestHeaders.add( + new MyHeader( + EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER, + EndToEndChecksumHandler.computeChecksum(request.toByteArray()))); + } else { + expectedRequestHeaders.add( + MyHeader.anyValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER).mustNotExist()); + } + + InjectedTestValues testVals = + new InjectedTestValues( + gzip(newBeginTransactionResponse()), + new byte[1], + true, + respHeaders, + expectedRequestHeaders); + RemoteRpc rpc = newRemoteRpc(testVals); + + InputStream stream = rpc.call("someMethod", request, projectId, ""); + byte[] buf = new byte[1000]; + if (respEnabled) { + // Must throw an IOException when verifying response checksum because we provided an invalid + // checksum in the response header. + assertThrows( + IOException.class, + () -> { + while (stream.read(buf, 0, 1000) != -1) { + // Do nothing with the bytes read. + } + }); + } else { + // Must not raise an exception even with invalid response checksum because we did not enable + // response checksum verification. + while (stream.read(buf, 0, 1000) != -1) { + // Do nothing with the bytes read. + } + } } @Test - public void testHttpHeaders_doNotExpectE2eChecksumHeader() throws IOException { - // disable E2E-Checksum system env variable - RemoteRpc.setSystemEnvE2EChecksum(false); + public void testE2EChecksum_validResponseChecksum() throws IOException, DatastoreException { + RemoteRpc.setSystemEnvE2EChecksum(false, true); String projectId = "project-id"; MessageLite request = RollbackRequest.newBuilder().setTransaction(ByteString.copyFromUtf8(projectId)).build(); - RemoteRpc rpc = - newRemoteRpc( - new InjectedTestValues(gzip(newBeginTransactionResponse()), new byte[1], true)); - HttpRequest httpRequest = - rpc.getClient().buildPostRequest(rpc.resolveURL("blah"), new ProtoHttpContent(request)); - rpc.setHeaders(request, httpRequest, projectId, ""); - assertNotNull( - httpRequest.getHeaders().getFirstHeaderStringValue(RemoteRpc.API_FORMAT_VERSION_HEADER)); - // Do not expect to find e2e-checksum header - assertNull( - httpRequest - .getHeaders() - .getFirstHeaderStringValue(EndToEndChecksumHandler.HTTP_REQUEST_CHECKSUM_HEADER)); + + BeginTransactionResponse response = newBeginTransactionResponse(); + + List respHeaders = + Collections.singletonList( + new MyHeader( + EndToEndChecksumHandler.HTTP_RESPONSE_CHECKSUM_HEADER, + EndToEndChecksumHandler.computeChecksum(response.toByteArray()))); + + InjectedTestValues testVals = + new InjectedTestValues(gzip(response), new byte[1], true, respHeaders); + RemoteRpc rpc = newRemoteRpc(testVals); + + InputStream stream = rpc.call("someMethod", request, projectId, ""); + byte[] buf = new byte[1000]; + // Must not raise an exception. + while (stream.read(buf, 0, 1000) != -1) { + // Do nothing with the bytes read. + } } @Test @@ -258,12 +303,38 @@ private static class InjectedTestValues { private final InputStream inputStream; private final int contentLength; private final boolean isGzip; + private final List responseHeaders; + private final Set expectedRequestHeaders; public InjectedTestValues(byte[] messageBytes, byte[] additionalBytes, boolean isGzip) { + this( + messageBytes, + additionalBytes, + isGzip, + new ArrayList(), + new HashSet()); + } + + public InjectedTestValues( + byte[] messageBytes, + byte[] additionalBytes, + boolean isGzip, + List responseHeaders) { + this(messageBytes, additionalBytes, isGzip, responseHeaders, new HashSet()); + } + + public InjectedTestValues( + byte[] messageBytes, + byte[] additionalBytes, + boolean isGzip, + List responseHeaders, + Set expectedRequestHeaders) { byte[] allBytes = concat(messageBytes, additionalBytes); this.inputStream = new ByteArrayInputStream(allBytes); this.contentLength = allBytes.length; this.isGzip = isGzip; + this.responseHeaders = responseHeaders; + this.expectedRequestHeaders = expectedRequestHeaders; } private static byte[] concat(byte[] a, byte[] b) { @@ -289,6 +360,45 @@ protected LowLevelHttpRequest buildRequest(String method, String url) throws IOE } } + private static class MyHeader { + private final String key; + private final String value; + private final boolean ignoreValue; + private boolean mustExist; + + public static MyHeader anyValue(String key) { + return new MyHeader(key, "", true); + } + + public MyHeader(String key, String value) { + this(key, value, false); + } + + private MyHeader(String key, String value, boolean ignoreValue) { + this.key = key.toLowerCase(); + this.value = value; + this.ignoreValue = ignoreValue; + this.mustExist = true; + } + + public MyHeader mustNotExist() { + mustExist = false; + return this; + } + + public boolean matches(MyHeader h) { + return key.equals(h.key) && (h.ignoreValue || ignoreValue || value.equals(h.value)); + } + + public String toString() { + String mustExistString = mustExist ? "" : "must not exist: "; + if (ignoreValue) { + return String.format("%s\"%s\": ANY", mustExistString, key); + } + return String.format("%s\"%s\": \"%s\"", mustExistString, key, value); + } + } + /** * {@link LowLevelHttpRequest} that allows injection of the returned {@link LowLevelHttpResponse}. */ @@ -296,17 +406,57 @@ private static class MyLowLevelHttpRequest extends LowLevelHttpRequest { private final InjectedTestValues injectedTestValues; + private final List requestHeaders = new ArrayList<>(); + public MyLowLevelHttpRequest(InjectedTestValues injectedTestValues) { this.injectedTestValues = injectedTestValues; } @Override public void addHeader(String name, String value) throws IOException { - // Do nothing. + requestHeaders.add(new MyHeader(name, value)); + } + + private void assertHeaders() { + if (injectedTestValues.expectedRequestHeaders.isEmpty()) { + return; + } + + Set mustExist = new HashSet<>(); + List mustNotExist = new ArrayList<>(); + for (MyHeader header : injectedTestValues.expectedRequestHeaders) { + if (header.mustExist) { + mustExist.add(header); + } else { + mustNotExist.add(header); + } + } + + for (MyHeader h : requestHeaders) { + mustExist.removeIf(expected -> expected.matches(h)); + } + + if (!mustExist.isEmpty()) { + throw new RuntimeException( + "These request headers were expected but missing:\n" + + mustExist + + "\nThese headers were seen:\n" + + requestHeaders); + } + + for (MyHeader notExpected : mustNotExist) { + for (MyHeader h : requestHeaders) { + if (h.matches(notExpected)) { + throw new RuntimeException( + "Expected header " + notExpected.toString() + " but found: " + h.toString()); + } + } + } } @Override public LowLevelHttpResponse execute() throws IOException { + assertHeaders(); return new MyLowLevelHttpResponse(injectedTestValues); } } @@ -357,17 +507,17 @@ public String getReasonPhrase() throws IOException { @Override public int getHeaderCount() throws IOException { - return 0; + return injectedTestValues.responseHeaders.size(); } @Override public String getHeaderName(int index) throws IOException { - return null; + return injectedTestValues.responseHeaders.get(index).key; } @Override public String getHeaderValue(int index) throws IOException { - return null; + return injectedTestValues.responseHeaders.get(index).value; } } }