diff --git a/src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java b/src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java index 6e0028bc..ccf77374 100644 --- a/src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java +++ b/src/main/java/com/google/cloud/spanner/r2dbc/SpannerStatement.java @@ -24,6 +24,7 @@ import com.google.cloud.spanner.r2dbc.statement.TypedNull; import com.google.cloud.spanner.r2dbc.util.Assert; import com.google.protobuf.Struct; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.Session; import io.r2dbc.spi.Result; @@ -52,6 +53,8 @@ public class SpannerStatement implements Statement { private StatementBindings statementBindings; + private StatementType statementType; + /** * Creates a Spanner statement for a given SQL statement. * @@ -73,6 +76,7 @@ public SpannerStatement( this.transaction = transaction; this.sql = Assert.requireNonNull(sql, "SQL string can not be null"); this.statementBindings = new StatementBindings(); + this.statementType = StatementParser.getStatementType(this.sql); } @Override @@ -108,34 +112,34 @@ public Statement bindNull(int i, Class type) { @Override public Publisher execute() { - Flux structFlux = Flux.fromIterable(this.statementBindings.getBindings()); - StatementType statementType = StatementParser.getStatementType(this.sql); - - if (statementType == StatementType.SELECT) { - return structFlux.flatMap(struct -> runSingleStatement(struct, statementType)); + switch (this.statementType) { + case DML: + return this.client + .executeBatchDml(this.session, this.transaction, this.sql, + this.statementBindings.getBindings(), + this.statementBindings.getTypes()) + .flatMapIterable(ExecuteBatchDmlResponse::getResultSetsList) + .map(resultSet -> new SpannerResult(Flux.empty(), + Mono.just(Math.toIntExact(resultSet.getStats().getRowCountExact())))); + case SELECT: + Flux structFlux = Flux.fromIterable(this.statementBindings.getBindings()); + return structFlux.flatMap(this::runSelectStatement); + default: + throw new UnsupportedOperationException("Unsupported statement type " + this.statementType); } - // DML statements have to be executed sequentially because they need seqNo to be in order - return structFlux.concatMapDelayError(struct -> runSingleStatement(struct, statementType)); } - private Mono runSingleStatement(Struct params, StatementType statementType) { + private Mono runSelectStatement(Struct params) { PartialResultRowExtractor partialResultRowExtractor = new PartialResultRowExtractor(); Flux resultSetFlux = this.client.executeStreamingSql( this.session, this.transaction, this.sql, params, this.statementBindings.getTypes()); - if (statementType == StatementType.SELECT) { - return resultSetFlux - .flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize()) - .transform(result -> Mono.just(new SpannerResult(result, Mono.just(0)))) - .next(); - } else { - return resultSetFlux - .last() - .map(partialResultSet -> Math.toIntExact(partialResultSet.getStats().getRowCountExact())) - .map(rowCount -> new SpannerResult(Flux.empty(), Mono.just(rowCount))); - } + return resultSetFlux + .flatMapIterable(partialResultRowExtractor, getPartialResultSetFetchSize()) + .transform(result -> Mono.just(new SpannerResult(result, Mono.just(0)))) + .next(); } /** diff --git a/src/main/java/com/google/cloud/spanner/r2dbc/client/Client.java b/src/main/java/com/google/cloud/spanner/r2dbc/client/Client.java index 69f5e576..4d247ac5 100644 --- a/src/main/java/com/google/cloud/spanner/r2dbc/client/Client.java +++ b/src/main/java/com/google/cloud/spanner/r2dbc/client/Client.java @@ -19,10 +19,12 @@ import com.google.cloud.spanner.r2dbc.SpannerTransactionContext; import com.google.protobuf.Struct; import com.google.spanner.v1.CommitResponse; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.Session; import com.google.spanner.v1.Transaction; import com.google.spanner.v1.Type; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; import reactor.core.publisher.Flux; @@ -83,6 +85,13 @@ default Flux executeStreamingSql( return executeStreamingSql(session, transaction, sql, null, null); } + /** + * Execute DML batch. + */ + Mono executeBatchDml(Session session, + @Nullable SpannerTransactionContext transactionContext, String sql, + List params, Map types); + /** * Release any resources held by the {@link Client}. * diff --git a/src/main/java/com/google/cloud/spanner/r2dbc/client/GrpcClient.java b/src/main/java/com/google/cloud/spanner/r2dbc/client/GrpcClient.java index 68159721..4b7b5191 100644 --- a/src/main/java/com/google/cloud/spanner/r2dbc/client/GrpcClient.java +++ b/src/main/java/com/google/cloud/spanner/r2dbc/client/GrpcClient.java @@ -27,6 +27,8 @@ import com.google.spanner.v1.CommitResponse; import com.google.spanner.v1.CreateSessionRequest; import com.google.spanner.v1.DeleteSessionRequest; +import com.google.spanner.v1.ExecuteBatchDmlRequest; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.ExecuteSqlRequest; import com.google.spanner.v1.PartialResultSet; import com.google.spanner.v1.RollbackRequest; @@ -42,6 +44,7 @@ import io.grpc.ManagedChannel; import io.grpc.ManagedChannelBuilder; import io.grpc.auth.MoreCallCredentials; +import java.util.List; import java.util.Map; import javax.annotation.Nullable; import reactor.core.publisher.Flux; @@ -159,6 +162,31 @@ public Mono deleteSession(Session session) { }); } + @Override + public Mono executeBatchDml(Session session, + @Nullable SpannerTransactionContext transactionContext, String sql, + List params, Map types) { + + ExecuteBatchDmlRequest.Builder request = ExecuteBatchDmlRequest.newBuilder() + .setSession(session.getName()); + if (transactionContext != null && transactionContext.getTransaction() != null) { + request.setTransaction( + TransactionSelector.newBuilder().setId(transactionContext.getTransaction().getId()) + .build()) + .setSeqno(transactionContext.nextSeqNum()); + + } + for (Struct paramsStruct : params) { + ExecuteBatchDmlRequest.Statement statement = ExecuteBatchDmlRequest.Statement.newBuilder() + .setSql(sql).setParams(paramsStruct).putAllParamTypes(types) + .build(); + request.addStatements(statement); + } + + return ObservableReactiveUtil + .unaryCall(obs -> this.spanner.executeBatchDml(request.build(), obs)); + } + @Override public Flux executeStreamingSql( Session session, @Nullable SpannerTransactionContext transactionContext, String sql, diff --git a/src/test/java/com/google/cloud/spanner/r2dbc/SpannerStatementTest.java b/src/test/java/com/google/cloud/spanner/r2dbc/SpannerStatementTest.java index e49b4bd9..c58c3182 100644 --- a/src/test/java/com/google/cloud/spanner/r2dbc/SpannerStatementTest.java +++ b/src/test/java/com/google/cloud/spanner/r2dbc/SpannerStatementTest.java @@ -28,7 +28,9 @@ import com.google.cloud.spanner.r2dbc.client.Client; import com.google.protobuf.Struct; import com.google.protobuf.Value; +import com.google.spanner.v1.ExecuteBatchDmlResponse; import com.google.spanner.v1.PartialResultSet; +import com.google.spanner.v1.ResultSet; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.ResultSetStats; import com.google.spanner.v1.Session; @@ -37,6 +39,7 @@ import com.google.spanner.v1.Type; import com.google.spanner.v1.TypeCode; import io.r2dbc.spi.Result; +import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -195,7 +198,8 @@ public void readMultiResultSetQueryTest() { when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs); - StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute()) + StepVerifier + .create(Flux.from(new SpannerStatement(this.mockClient, null, null, "SELECT").execute()) .flatMap(r -> Mono.from(r.getRowsUpdated()))) .expectNext(0) .verifyComplete(); @@ -203,15 +207,19 @@ public void readMultiResultSetQueryTest() { @Test public void readDmlQueryTest() { - PartialResultSet p1 = PartialResultSet.newBuilder().setStats( - ResultSetStats.newBuilder().setRowCountExact(555).build() - ).build(); + ResultSet resultSet = ResultSet.newBuilder() + .setStats(ResultSetStats.newBuilder().setRowCountExact(555).build()) + .build(); - Flux inputs = Flux.just(p1); + ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder() + .addResultSets(resultSet) + .build(); - when(this.mockClient.executeStreamingSql(any(), any(), any(), any(), any())).thenReturn(inputs); + when(this.mockClient.executeBatchDml(any(), any(), any(), any(), any())) + .thenReturn(Mono.just(executeBatchDmlResponse)); - StepVerifier.create(Flux.from(new SpannerStatement(this.mockClient, null, null, "").execute()) + StepVerifier.create( + Flux.from(new SpannerStatement(this.mockClient, null, null, "Insert into books").execute()) .flatMap(r -> Mono.from(r.getRowsUpdated()))) .expectNext(555) .verifyComplete(); @@ -221,13 +229,19 @@ public void readDmlQueryTest() { public void noopMapOnUpdateQueriesWhenNoRowsAffected() { Client mockClient = mock(Client.class); String sql = "delete from Books where true"; - PartialResultSet partialResultSet = PartialResultSet.newBuilder() + + ResultSet resultSet = ResultSet.newBuilder() .setMetadata(ResultSetMetadata.getDefaultInstance()) .setStats(ResultSetStats.getDefaultInstance()) .build(); - when(mockClient.executeStreamingSql(TEST_SESSION, null, sql, - Struct.newBuilder().build(), Collections.EMPTY_MAP)) - .thenReturn(Flux.just(partialResultSet)); + + ExecuteBatchDmlResponse executeBatchDmlResponse = ExecuteBatchDmlResponse.newBuilder() + .addResultSets(resultSet) + .build(); + + when(mockClient.executeBatchDml(TEST_SESSION, null, sql, + Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP)) + .thenReturn(Mono.just(executeBatchDmlResponse)); SpannerStatement statement = new SpannerStatement(mockClient, TEST_SESSION, null, sql); @@ -244,7 +258,7 @@ public void noopMapOnUpdateQueriesWhenNoRowsAffected() { .expectNext(0) .verifyComplete(); - verify(mockClient, times(2)).executeStreamingSql(TEST_SESSION, null, sql, - Struct.newBuilder().build(), Collections.EMPTY_MAP); + verify(mockClient, times(1)).executeBatchDml(TEST_SESSION, null, sql, + Arrays.asList(Struct.newBuilder().build()), Collections.EMPTY_MAP); } } diff --git a/src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java b/src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java index 9d7fdcd6..72f41a97 100644 --- a/src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java +++ b/src/test/java/com/google/cloud/spanner/r2dbc/it/SpannerIT.java @@ -65,6 +65,7 @@ import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; /** * Integration test for connecting to a real Spanner instance. @@ -252,25 +253,48 @@ public void testQuerying() { Mono.from(this.connectionFactory.create()) .delayUntil(c -> c.beginTransaction()) - .delayUntil(c -> Flux.from(c.createStatement( - "INSERT BOOKS (UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)" - + " VALUES (@uuid, @title, @author, @category, @fiction, @published, @wps);") - .bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7") - .bind("author", "Douglas Crockford") - .bind("category", 100L) - .bind("title", "JavaScript: The Good Parts") - .bind("fiction", true) - .bind("published", LocalDate.of(2008, 5, 1)) - .bind("wps", 20.8) - .add() - .bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9") - .bind("author", "Joshua Bloch") - .bind("category", 100L) - .bind("title", "Effective Java") - .bind("fiction", false) - .bind("published", LocalDate.of(2018, 1, 6)) - .bind("wps", 15.1) - .execute()).flatMapSequential(r -> Mono.from(r.getRowsUpdated()))) + .delayUntil(c -> + Mono.fromRunnable(() -> + StepVerifier.create(Flux.from(c.createStatement( + "INSERT BOOKS " + + "(UUID, TITLE, AUTHOR, CATEGORY, FICTION, PUBLISHED, WORDS_PER_SENTENCE)" + + " VALUES " + + "(@uuid, @title, @author, @category, @fiction, @published, @wps);") + .bind("uuid", "2b2cbd78-ecd8-430e-b685-fa7910f8a4c7") + .bind("author", "Douglas Crockford") + .bind("category", 100L) + .bind("title", "JavaScript: The Good Parts") + .bind("fiction", true) + .bind("published", LocalDate.of(2008, 5, 1)) + .bind("wps", 20.8) + .add() + .bind("uuid", "df0e3d06-2743-4691-8e51-6d33d90c5cb9") + .bind("author", "Joshua Bloch") + .bind("category", 100L) + .bind("title", "Effective Java") + .bind("fiction", false) + .bind("published", LocalDate.of(2018, 1, 6)) + .bind("wps", 15.1) + .execute()) + .flatMapSequential(r -> Mono.from(r.getRowsUpdated()))) + .expectNext(1).expectNext(1).verifyComplete()) + ) + .delayUntil(c -> c.commitTransaction()) + .block(); + + Mono.from(this.connectionFactory.create()) + .delayUntil(c -> c.beginTransaction()) + .delayUntil(c -> + Mono.fromRunnable(() -> + StepVerifier + .create(Flux.from(c.createStatement( + "UPDATE BOOKS SET CATEGORY = @new_cat WHERE CATEGORY = @old_cat") + .bind("new_cat", 101L) + .bind("old_cat", 100L) + .execute()) + .flatMap(r -> Mono.from(r.getRowsUpdated()))) + .expectNext(2).verifyComplete()) + ) .delayUntil(c -> c.commitTransaction()) .block(); @@ -353,12 +377,13 @@ private int executeDmlQuery(String sql) { Connection connection = Mono.from(connectionFactory.create()).block(); Mono.from(connection.beginTransaction()).block(); - int rowsUpdated = Mono.from(connection.createStatement(sql).execute()) + List rowsUpdatedPerStatement = Flux.from(connection.createStatement(sql).execute()) .flatMap(result -> Mono.from(result.getRowsUpdated())) + .collectList() .block(); Mono.from(connection.commitTransaction()).block(); - return rowsUpdated; + return rowsUpdatedPerStatement.get(0); } /**