diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index da74bf35f0..a8ba82709b 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -146,7 +146,7 @@ def _check_state(self): if self.committed is not None: raise ValueError("Batch already committed") - def commit(self, return_commit_stats=False, request_options=None): + def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None): """Commit mutations to the database. :type return_commit_stats: bool @@ -189,6 +189,7 @@ def commit(self, return_commit_stats=False, request_options=None): single_use_transaction=txn_options, return_commit_stats=return_commit_stats, request_options=request_options, + max_commit_delay=max_commit_delay, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): method = functools.partial( diff --git a/google/cloud/spanner_v1/database.py b/google/cloud/spanner_v1/database.py index 1a651a66f5..2449b4c0ec 100644 --- a/google/cloud/spanner_v1/database.py +++ b/google/cloud/spanner_v1/database.py @@ -721,7 +721,7 @@ def snapshot(self, **kw): """ return SnapshotCheckout(self, **kw) - def batch(self, request_options=None): + def batch(self, request_options=None, max_commit_delay=None): """Return an object which wraps a batch. The wrapper *must* be used as a context manager, with the batch @@ -737,7 +737,7 @@ def batch(self, request_options=None): :rtype: :class:`~google.cloud.spanner_v1.database.BatchCheckout` :returns: new wrapper """ - return BatchCheckout(self, request_options) + return BatchCheckout(self, request_options, max_commit_delay) def mutation_groups(self): """Return an object which wraps a mutation_group. @@ -1037,7 +1037,7 @@ class BatchCheckout(object): message :class:`~google.cloud.spanner_v1.types.RequestOptions`. """ - def __init__(self, database, request_options=None): + def __init__(self, database, request_options=None, max_commit_delay=None): self._database = database self._session = self._batch = None if request_options is None: @@ -1046,6 +1046,7 @@ def __init__(self, database, request_options=None): self._request_options = RequestOptions(request_options) else: self._request_options = request_options + self._max_commit_delay = max_commit_delay def __enter__(self): """Begin ``with`` block.""" @@ -1062,6 +1063,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._batch.commit( return_commit_stats=self._database.log_commit_stats, request_options=self._request_options, + max_commit_delay=self._max_commit_delay, ) finally: if self._database.log_commit_stats and self._batch.commit_stats: diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index b25af53805..29772662ee 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -372,6 +372,7 @@ def run_in_transaction(self, func, *args, **kw): """ deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS) commit_request_options = kw.pop("commit_request_options", None) + max_commit_delay = kw.pop("max_commit_delay", None) transaction_tag = kw.pop("transaction_tag", None) attempts = 0 @@ -400,6 +401,7 @@ def run_in_transaction(self, func, *args, **kw): txn.commit( return_commit_stats=self._database.log_commit_stats, request_options=commit_request_options, + max_commit_delay=max_commit_delay, ) except Aborted as exc: del self._transaction diff --git a/google/cloud/spanner_v1/transaction.py b/google/cloud/spanner_v1/transaction.py index d564d0d488..4d4bf1f824 100644 --- a/google/cloud/spanner_v1/transaction.py +++ b/google/cloud/spanner_v1/transaction.py @@ -180,7 +180,7 @@ def rollback(self): self.rolled_back = True del self._session._transaction - def commit(self, return_commit_stats=False, request_options=None): + def commit(self, return_commit_stats=False, request_options=None, max_commit_delay=None): """Commit mutations to the database. :type return_commit_stats: bool @@ -229,6 +229,7 @@ def commit(self, return_commit_stats=False, request_options=None): transaction_id=self._transaction_id, return_commit_stats=return_commit_stats, request_options=request_options, + max_commit_delay=max_commit_delay, ) with trace_call("CloudSpanner.Commit", self._session, trace_attributes): method = functools.partial( diff --git a/tests/system/test_database_api.py b/tests/system/test_database_api.py index 052e628188..6484663fbe 100644 --- a/tests/system/test_database_api.py +++ b/tests/system/test_database_api.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import datetime import time import uuid @@ -819,3 +819,38 @@ def _transaction_read(transaction): with pytest.raises(exceptions.InvalidArgument): shared_database.run_in_transaction(_transaction_read) + + +def test_db_batch_insert_w_max_commit_delay(shared_database): + _helpers.retry_has_all_dll(shared_database.reload)() + sd = _sample_data + + with shared_database.batch(max_commit_delay=datetime.timedelta(milliseconds=100)) as batch: + batch.delete(sd.TABLE, sd.ALL) + batch.insert(sd.TABLE, sd.COLUMNS, sd.ROW_DATA) + + with shared_database.snapshot(read_timestamp=batch.committed) as snapshot: + from_snap = list(snapshot.read(sd.TABLE, sd.COLUMNS, sd.ALL)) + + sd._check_rows_data(from_snap) + + +def test_db_run_in_transaction_w_max_commit_delay(shared_database): + _helpers.retry_has_all_dll(shared_database.reload)() + sd = _sample_data + + with shared_database.batch() as batch: + batch.delete(sd.TABLE, sd.ALL) + + def _unit_of_work(transaction, test): + rows = list(transaction.read(test.TABLE, test.COLUMNS, sd.ALL)) + assert rows == [] + + transaction.insert_or_update(test.TABLE, test.COLUMNS, test.ROW_DATA) + + shared_database.run_in_transaction(_unit_of_work, test=sd, max_commit_delay=datetime.timedelta(milliseconds=100)) + + with shared_database.snapshot() as after: + rows = list(after.execute_sql(sd.SQL)) + + sd._check_rows_data(rows)