From b608f9e57f602d9c37db8afa086c496000e1aca6 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Mon, 2 Oct 2023 12:40:37 +0200 Subject: [PATCH] Some refactor of internal state machines (#112) --- .../BaseSuspendableCallbackStateMachine.java | 54 +++++++ .../restate/sdk/core/impl/CallbackHandle.java | 38 +++++ .../dev/restate/sdk/core/impl/Entries.java | 48 +++--- .../restate/sdk/core/impl/EntriesQueue.java | 66 -------- .../impl/IncomingEntriesStateMachine.java | 46 ++++++ .../sdk/core/impl/InputChannelState.java | 48 ------ .../sdk/core/impl/InputPublisherState.java | 25 +++ .../sdk/core/impl/InvocationStateMachine.java | 146 ++++++++++-------- ...sher.java => ReadyResultStateMachine.java} | 50 ++---- .../sdk/core/impl/SideEffectAckPublisher.java | 66 -------- .../core/impl/SideEffectAckStateMachine.java | 49 ++++++ .../sdk/core/impl/SuspendableCallback.java | 8 + .../restate/sdk/core/impl/SyscallsImpl.java | 17 +- ...lStateStorage.java => UserStateStore.java} | 4 +- .../java/dev/restate/sdk/core/impl/Util.java | 4 +- 15 files changed, 356 insertions(+), 313 deletions(-) create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/BaseSuspendableCallbackStateMachine.java create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/CallbackHandle.java delete mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/EntriesQueue.java create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/IncomingEntriesStateMachine.java delete mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputChannelState.java create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputPublisherState.java rename sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/{ReadyResultPublisher.java => ReadyResultStateMachine.java} (64%) delete mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckPublisher.java create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckStateMachine.java create mode 100644 sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SuspendableCallback.java rename sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/{LocalStateStorage.java => UserStateStore.java} (92%) diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/BaseSuspendableCallbackStateMachine.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/BaseSuspendableCallbackStateMachine.java new file mode 100644 index 00000000..fe53bda3 --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/BaseSuspendableCallbackStateMachine.java @@ -0,0 +1,54 @@ +package dev.restate.sdk.core.impl; + +import java.util.function.Consumer; + +// Implements the base logic for state machines containing suspensable callbacks. +abstract class BaseSuspendableCallbackStateMachine { + + private final CallbackHandle callbackHandle; + private final InputPublisherState inputPublisherState; + + BaseSuspendableCallbackStateMachine() { + this.callbackHandle = new CallbackHandle<>(); + this.inputPublisherState = new InputPublisherState(); + } + + void abort(Throwable cause) { + this.inputPublisherState.notifyClosed(cause); + } + + public void tryFailCallback() { + callbackHandle.consume( + cb -> { + if (inputPublisherState.isSuspended()) { + cb.onSuspend(); + } else if (inputPublisherState.isClosed()) { + cb.onError(inputPublisherState.getCloseCause()); + } + }); + } + + public void consumeCallback(Consumer consumer) { + this.callbackHandle.consume(consumer); + } + + public void consumeCallbackOrElse(Consumer consumer, Runnable elseRunnable) { + this.callbackHandle.consumeOrElse(consumer, elseRunnable); + } + + public void assertCallbackNotSet(String reason) { + if (!this.callbackHandle.isEmpty()) { + throw new IllegalStateException(reason); + } + } + + void setCallback(CB callback) { + if (inputPublisherState.isSuspended()) { + callback.onSuspend(); + } else if (inputPublisherState.isClosed()) { + callback.onError(inputPublisherState.getCloseCause()); + } else { + callbackHandle.set(callback); + } + } +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/CallbackHandle.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/CallbackHandle.java new file mode 100644 index 00000000..5f4c98a6 --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/CallbackHandle.java @@ -0,0 +1,38 @@ +package dev.restate.sdk.core.impl; + +import java.util.function.Consumer; +import javax.annotation.Nullable; + +/** Handle for callbacks. */ +final class CallbackHandle { + + private @Nullable T cb = null; + + public void set(T t) { + this.cb = t; + } + + public boolean isEmpty() { + return this.cb == null; + } + + public void consume(Consumer consumer) { + if (this.cb != null) { + consumer.accept(pop()); + } + } + + public void consumeOrElse(Consumer consumer, Runnable elseRunnable) { + if (this.cb != null) { + consumer.accept(pop()); + } else { + elseRunnable.run(); + } + } + + private @Nullable T pop() { + T temp = this.cb; + this.cb = null; + return temp; + } +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Entries.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Entries.java index 286d347c..1f543839 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Entries.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Entries.java @@ -19,7 +19,7 @@ void checkEntryHeader(E expected, MessageLite actual) throws ProtocolException { abstract void trace(E expected, Span span); - void updateLocalStateStorage(E expected, LocalStateStorage localStateStorage) {} + void updateUserStateStoreWithEntry(E expected, UserStateStore userStateStore) {} } abstract static class CompletableJournalEntry extends JournalEntry { @@ -32,12 +32,12 @@ ReadyResultInternal parseCompletionResult(CompletionMessage actual) { this.getClass().getName(), actual.getResultCase()); } - E tryCompleteWithLocalStateStorage(E expected, LocalStateStorage localStateStorage) { + E tryCompleteWithUserStateStorage(E expected, UserStateStore userStateStore) { return expected; } - void updateLocalStateStorageWithCompletion( - E expected, CompletionMessage actual, LocalStateStorage localStateStorage) {} + void updateUserStateStorageWithCompletion( + E expected, CompletionMessage actual, UserStateStore userStateStore) {} } static final class PollInputEntry @@ -135,32 +135,30 @@ public ReadyResultInternal parseCompletionResult(CompletionMessage actual) { } @Override - void updateLocalStateStorage( - GetStateEntryMessage expected, LocalStateStorage localStateStorage) { - localStateStorage.set(expected.getKey(), expected.getValue()); + void updateUserStateStoreWithEntry( + GetStateEntryMessage expected, UserStateStore userStateStore) { + userStateStore.set(expected.getKey(), expected.getValue()); } @Override - GetStateEntryMessage tryCompleteWithLocalStateStorage( - GetStateEntryMessage expected, LocalStateStorage localStateStorage) { - LocalStateStorage.State value = localStateStorage.get(expected.getKey()); - if (value instanceof LocalStateStorage.Value) { - return expected.toBuilder().setValue(((LocalStateStorage.Value) value).getValue()).build(); - } else if (value instanceof LocalStateStorage.Empty) { + GetStateEntryMessage tryCompleteWithUserStateStorage( + GetStateEntryMessage expected, UserStateStore userStateStore) { + UserStateStore.State value = userStateStore.get(expected.getKey()); + if (value instanceof UserStateStore.Value) { + return expected.toBuilder().setValue(((UserStateStore.Value) value).getValue()).build(); + } else if (value instanceof UserStateStore.Empty) { return expected.toBuilder().setEmpty(Empty.getDefaultInstance()).build(); } return expected; } @Override - void updateLocalStateStorageWithCompletion( - GetStateEntryMessage expected, - CompletionMessage actual, - LocalStateStorage localStateStorage) { + void updateUserStateStorageWithCompletion( + GetStateEntryMessage expected, CompletionMessage actual, UserStateStore userStateStore) { if (actual.hasEmpty()) { - localStateStorage.clear(expected.getKey()); + userStateStore.clear(expected.getKey()); } else { - localStateStorage.set(expected.getKey(), actual.getValue()); + userStateStore.set(expected.getKey(), actual.getValue()); } } } @@ -184,9 +182,9 @@ void checkEntryHeader(ClearStateEntryMessage expected, MessageLite actual) } @Override - void updateLocalStateStorage( - ClearStateEntryMessage expected, LocalStateStorage localStateStorage) { - localStateStorage.clear(expected.getKey()); + void updateUserStateStoreWithEntry( + ClearStateEntryMessage expected, UserStateStore userStateStore) { + userStateStore.clear(expected.getKey()); } } @@ -209,9 +207,9 @@ void checkEntryHeader(SetStateEntryMessage expected, MessageLite actual) } @Override - void updateLocalStateStorage( - SetStateEntryMessage expected, LocalStateStorage localStateStorage) { - localStateStorage.set(expected.getKey(), expected.getValue()); + void updateUserStateStoreWithEntry( + SetStateEntryMessage expected, UserStateStore userStateStore) { + userStateStore.set(expected.getKey(), expected.getValue()); } } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/EntriesQueue.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/EntriesQueue.java deleted file mode 100644 index 96ca1cce..00000000 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/EntriesQueue.java +++ /dev/null @@ -1,66 +0,0 @@ -package dev.restate.sdk.core.impl; - -import com.google.protobuf.MessageLite; -import dev.restate.sdk.core.syscalls.SyscallCallback; -import java.util.ArrayDeque; -import java.util.Queue; -import java.util.function.Consumer; -import javax.annotation.Nullable; - -class EntriesQueue { - - private final Queue unprocessedMessages; - - @Nullable private SyscallCallback callback; - private boolean closed; - - EntriesQueue() { - this.unprocessedMessages = new ArrayDeque<>(); - - this.closed = false; - } - - void offer(MessageLite msg) { - Util.assertIsEntry(msg); - - if (this.callback != null) { - popCallback().onSuccess(msg); - } else { - this.unprocessedMessages.offer(msg); - } - } - - void read(Consumer msgCallback, Consumer errorCallback) { - if (this.callback != null) { - throw new IllegalStateException("Two concurrent reads were requested."); - } - if (this.closed) { - throw new IllegalStateException("Cannot read when closed"); - } - - MessageLite popped = this.unprocessedMessages.poll(); - if (popped != null) { - msgCallback.accept(popped); - } else { - this.callback = SyscallCallback.of(msgCallback, errorCallback); - } - } - - void abort(Throwable e) { - this.closed = true; - if (this.callback != null) { - popCallback().onCancel(e); - } - } - - boolean isEmpty() { - return this.unprocessedMessages.isEmpty(); - } - - @Nullable - private SyscallCallback popCallback() { - SyscallCallback callback = this.callback; - this.callback = null; - return callback; - } -} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/IncomingEntriesStateMachine.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/IncomingEntriesStateMachine.java new file mode 100644 index 00000000..daaff182 --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/IncomingEntriesStateMachine.java @@ -0,0 +1,46 @@ +package dev.restate.sdk.core.impl; + +import com.google.protobuf.MessageLite; +import java.util.ArrayDeque; +import java.util.Queue; + +class IncomingEntriesStateMachine + extends BaseSuspendableCallbackStateMachine { + + interface OnEntryCallback extends SuspendableCallback { + void onEntry(MessageLite msg); + } + + private final Queue unprocessedMessages; + + IncomingEntriesStateMachine() { + this.unprocessedMessages = new ArrayDeque<>(); + } + + void offer(MessageLite msg) { + Util.assertIsEntry(msg); + this.consumeCallbackOrElse(cb -> cb.onEntry(msg), () -> this.unprocessedMessages.offer(msg)); + } + + void read(OnEntryCallback msgCallback) { + this.assertCallbackNotSet("Two concurrent reads were requested."); + + MessageLite popped = this.unprocessedMessages.poll(); + if (popped != null) { + msgCallback.onEntry(popped); + } else { + this.setCallback(msgCallback); + } + } + + boolean isEmpty() { + return this.unprocessedMessages.isEmpty(); + } + + @Override + void abort(Throwable cause) { + super.abort(cause); + // We can't do anything else if the input stream is closed, so we just fail the callback, if any + this.tryFailCallback(); + } +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputChannelState.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputChannelState.java deleted file mode 100644 index 6057e638..00000000 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputChannelState.java +++ /dev/null @@ -1,48 +0,0 @@ -package dev.restate.sdk.core.impl; - -import dev.restate.sdk.core.SuspendedException; -import javax.annotation.Nullable; - -class InputChannelState { - - interface SuspendableCallback { - - void onSuspend(); - - void onError(Throwable e); - } - - private @Nullable Throwable closeCause = null; - - /** - * @return false if it was already closed. - */ - boolean close(Throwable cause) { - // Guard against multiple requests of transitions to suspended - if (this.closeCause != null) { - return false; - } - closeCause = cause; - return true; - } - - /** Consumes the callback if the state is closed or suspended, otherwise returns it. */ - @Nullable CB handleOrReturn(CB callback) { - if (isSuspended()) { - callback.onSuspend(); - return null; - } else if (isClosed()) { - callback.onError(closeCause); - return null; - } - return callback; - } - - private boolean isSuspended() { - return this.closeCause == SuspendedException.INSTANCE; - } - - boolean isClosed() { - return this.closeCause != null; - } -} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputPublisherState.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputPublisherState.java new file mode 100644 index 00000000..35ec649e --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InputPublisherState.java @@ -0,0 +1,25 @@ +package dev.restate.sdk.core.impl; + +import dev.restate.sdk.core.SuspendedException; +import javax.annotation.Nullable; + +class InputPublisherState { + + private @Nullable Throwable closeCause = null; + + void notifyClosed(Throwable cause) { + closeCause = cause; + } + + boolean isSuspended() { + return this.closeCause == SuspendedException.INSTANCE; + } + + boolean isClosed() { + return this.closeCause != null; + } + + public Throwable getCloseCause() { + return closeCause; + } +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java index b75e7675..14e78ca9 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/InvocationStateMachine.java @@ -47,28 +47,30 @@ private enum State { private ByteString id; private String debugId; private int entriesToReplay; - private LocalStateStorage localStateStorage; + private UserStateStore userStateStore; // Index tracking progress in the journal private int currentJournalIndex; // Buffering of messages and completions - private final SideEffectAckPublisher sideEffectAckPublisher; - private final EntriesQueue entriesQueue; - private final ReadyResultPublisher readyResultPublisher; + private final IncomingEntriesStateMachine incomingEntriesStateMachine; + private final SideEffectAckStateMachine sideEffectAckStateMachine; + private final ReadyResultStateMachine readyResultStateMachine; // Flow sub/pub private Flow.Subscriber outputSubscriber; private Flow.Subscription inputSubscription; - private Consumer afterStartCallback; + private final CallbackHandle> afterStartCallback; public InvocationStateMachine(String serviceName, Span span) { this.serviceName = serviceName; this.span = span; - this.sideEffectAckPublisher = new SideEffectAckPublisher(); - this.entriesQueue = new EntriesQueue(); - this.readyResultPublisher = new ReadyResultPublisher(); + this.incomingEntriesStateMachine = new IncomingEntriesStateMachine(); + this.readyResultStateMachine = new ReadyResultStateMachine(); + this.sideEffectAckStateMachine = new SideEffectAckStateMachine(); + + this.afterStartCallback = new CallbackHandle<>(); } // --- Getters @@ -124,12 +126,12 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) { // If ack, give it to side effect publisher if (completionMessage.getResultCase() == Protocol.CompletionMessage.ResultCase.RESULT_NOT_SET) { - this.sideEffectAckPublisher.tryHandleSideEffectAck(completionMessage.getEntryIndex()); + this.sideEffectAckStateMachine.tryHandleSideEffectAck(completionMessage.getEntryIndex()); } else { - this.readyResultPublisher.offerCompletion((Protocol.CompletionMessage) msg); + this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg); } } else { - this.entriesQueue.offer(msg); + this.incomingEntriesStateMachine.offer(msg); } } @@ -142,14 +144,14 @@ public void onError(Throwable throwable) { @Override public void onComplete() { LOG.trace("Input publisher closed"); - this.sideEffectAckPublisher.abort(SuspendedException.INSTANCE); - this.readyResultPublisher.abort(SuspendedException.INSTANCE); + this.readyResultStateMachine.abort(SuspendedException.INSTANCE); + this.sideEffectAckStateMachine.abort(SuspendedException.INSTANCE); } // --- Init routine to wait for the start message void start(Consumer afterStartCallback) { - this.afterStartCallback = afterStartCallback; + this.afterStartCallback.set(afterStartCallback); this.inputSubscription.request(1); } @@ -166,8 +168,8 @@ void onStart(MessageLite msg) { this.entriesToReplay = startMessage.getKnownEntries(); // Set up the state cache - this.localStateStorage = - new LocalStateStorage( + this.userStateStore = + new UserStateStore( startMessage.getPartialState(), startMessage.getStateMapList().stream() .collect( @@ -188,15 +190,15 @@ void onStart(MessageLite msg) { this.inputSubscription.request(Long.MAX_VALUE); // Now execute the callback after start - Consumer afterStartCallback = this.afterStartCallback; - this.afterStartCallback = null; - afterStartCallback.accept(new InvocationIdImpl(this.debugId)); + this.afterStartCallback.consume(cb -> cb.accept(new InvocationIdImpl(this.debugId))); } void close() { if (this.state != State.CLOSED) { this.transitionState(State.CLOSED); LOG.debug("Closing state machine"); + + // Cancel inputSubscription and complete outputSubscriber if (inputSubscription != null) { this.inputSubscription.cancel(); } @@ -204,9 +206,11 @@ void close() { this.outputSubscriber.onComplete(); this.outputSubscriber = null; } - this.readyResultPublisher.abort(ProtocolException.CLOSED); - this.sideEffectAckPublisher.abort(ProtocolException.CLOSED); - this.entriesQueue.abort(ProtocolException.CLOSED); + + // Unblock any eventual waiting callbacks + this.readyResultStateMachine.abort(ProtocolException.CLOSED); + this.sideEffectAckStateMachine.abort(ProtocolException.CLOSED); + this.incomingEntriesStateMachine.abort(ProtocolException.CLOSED); this.span.end(); } } @@ -215,10 +219,13 @@ void fail(Throwable cause) { if (this.state != State.CLOSED) { this.transitionState(State.CLOSED); LOG.debug("Closing state machine with failure", cause); + + // Cancel inputSubscription and complete outputSubscriber if (inputSubscription != null) { this.inputSubscription.cancel(); } if (this.outputSubscriber != null) { + // Publish ErrorMessage to output subscriber before closing. if (cause instanceof ProtocolException) { this.outputSubscriber.onNext(((ProtocolException) cause).toErrorMessage()); } else if (cause != null) { @@ -231,15 +238,16 @@ void fail(Throwable cause) { this.outputSubscriber.onComplete(); this.outputSubscriber = null; } - this.insideSideEffect = false; - this.readyResultPublisher.abort(cause); - this.sideEffectAckPublisher.abort(cause); - this.entriesQueue.abort(cause); + + // Unblock any eventual waiting callbacks + this.readyResultStateMachine.abort(cause); + this.sideEffectAckStateMachine.abort(cause); + this.incomingEntriesStateMachine.abort(cause); this.span.end(); } } - // --- Methods to implement syscalls + // --- Methods to implement Syscalls @SuppressWarnings("unchecked") void processCompletableJournalEntry( @@ -256,16 +264,19 @@ void processCompletableJournalEntry( journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); if (journalEntry.hasResult((E) actualEntryMessage)) { - journalEntry.updateLocalStateStorage((E) actualEntryMessage, this.localStateStorage); + // Entry is already completed + journalEntry.updateUserStateStoreWithEntry( + (E) actualEntryMessage, this.userStateStore); ReadyResultInternal readyResultInternal = journalEntry.parseEntryResult((E) actualEntryMessage); callback.onSuccess(DeferredResults.completedSingle(entryIndex, readyResultInternal)); } else { - this.readyResultPublisher.offerCompletionParser( + // Entry is not completed yet + this.readyResultStateMachine.offerCompletionParser( entryIndex, completionMessage -> { - journalEntry.updateLocalStateStorageWithCompletion( - (E) actualEntryMessage, completionMessage, this.localStateStorage); + journalEntry.updateUserStateStorageWithCompletion( + (E) actualEntryMessage, completionMessage, this.userStateStore); return journalEntry.parseCompletionResult(completionMessage); }); callback.onSuccess(DeferredResults.single(entryIndex)); @@ -275,7 +286,7 @@ void processCompletableJournalEntry( } else if (this.state == State.PROCESSING) { // Try complete with local storage E entryToWrite = - journalEntry.tryCompleteWithLocalStateStorage(expectedEntryMessage, localStateStorage); + journalEntry.tryCompleteWithUserStateStorage(expectedEntryMessage, userStateStore); if (span.isRecording()) { journalEntry.trace(entryToWrite, span); @@ -294,11 +305,11 @@ void processCompletableJournalEntry( entryIndex, journalEntry.parseEntryResult(entryToWrite))); } else { // Register the completion parser - this.readyResultPublisher.offerCompletionParser( + this.readyResultStateMachine.offerCompletionParser( entryIndex, completionMessage -> { - journalEntry.updateLocalStateStorageWithCompletion( - entryToWrite, completionMessage, this.localStateStorage); + journalEntry.updateUserStateStorageWithCompletion( + entryToWrite, completionMessage, this.userStateStore); return journalEntry.parseCompletionResult(completionMessage); }); @@ -322,7 +333,7 @@ void processJournalEntry( this.readEntry( (entryIndex, actualEntryMessage) -> { journalEntry.checkEntryHeader(expectedEntryMessage, actualEntryMessage); - journalEntry.updateLocalStateStorage((E) actualEntryMessage, this.localStateStorage); + journalEntry.updateUserStateStoreWithEntry((E) actualEntryMessage, this.userStateStore); callback.onSuccess(null); }, callback::onCancel); @@ -335,7 +346,7 @@ void processJournalEntry( this.writeEntry(expectedEntryMessage); // Update local storage - journalEntry.updateLocalStateStorage(expectedEntryMessage, this.localStateStorage); + journalEntry.updateUserStateStoreWithEntry(expectedEntryMessage, this.userStateStore); // Invoke the ok callback callback.onSuccess(null); @@ -346,7 +357,6 @@ void processJournalEntry( } void enterSideEffectBlock( - Consumer traceFn, Consumer entryCallback, Runnable noEntryCallback, Consumer failureCallback) { @@ -363,20 +373,20 @@ void enterSideEffectBlock( }, failureCallback); } else if (this.state == State.PROCESSING) { - this.sideEffectAckPublisher.executeEnterSideEffect( - new SideEffectAckPublisher.OnEnterSideEffectCallback() { + this.sideEffectAckStateMachine.executeEnterSideEffect( + new SideEffectAckStateMachine.OnEnterSideEffectCallback() { @Override public void onEnter() { insideSideEffect = true; if (span.isRecording()) { - traceFn.accept(span); + span.addEvent("Enter SideEffect"); } noEntryCallback.run(); } @Override public void onSuspend() { - writeSuspension(sideEffectAckPublisher.getLastExecutedSideEffect()); + writeSuspension(sideEffectAckStateMachine.getLastExecutedSideEffect()); failureCallback.accept(SuspendedException.INSTANCE); } @@ -393,23 +403,22 @@ public void onError(Throwable e) { void exitSideEffectBlock( Java.SideEffectEntryMessage sideEffectToWrite, - Consumer traceFn, Consumer entryCallback, Consumer failureCallback) { this.insideSideEffect = false; - if (this.state == State.REPLAYING) { + if (this.state == State.CLOSED) { + failureCallback.accept(SuspendedException.INSTANCE); + } else if (this.state == State.REPLAYING) { throw new IllegalStateException( "exitSideEffect has been invoked when the state machine is in replaying mode. " + "This is probably an SDK bug and might be caused by a missing enterSideEffectBlock invocation before exitSideEffectBlock."); - } else if (this.state == State.CLOSED) { - failureCallback.accept(SuspendedException.INSTANCE); } else if (this.state == State.PROCESSING) { if (span.isRecording()) { - traceFn.accept(span); + span.addEvent("Exit SideEffect"); } // Write new entry - this.sideEffectAckPublisher.registerExecutedSideEffect(this.currentJournalIndex); + this.sideEffectAckStateMachine.registerExecutedSideEffect(this.currentJournalIndex); this.writeEntry(sideEffectToWrite); entryCallback.accept(sideEffectToWrite); @@ -442,8 +451,8 @@ void resolveDeferred(DeferredResult deferredToResolve, SyscallCallback void resolveSingleDeferred( ResolvableSingleDeferredResult deferred, SyscallCallback callback) { - this.readyResultPublisher.onNewReadyResult( - new ReadyResultPublisher.OnNewReadyResultCallback() { + this.readyResultStateMachine.onNewReadyResult( + new ReadyResultStateMachine.OnNewReadyResultCallback() { @SuppressWarnings("unchecked") @Override public boolean onNewReadyResult(Map> resultMap) { @@ -481,9 +490,9 @@ public void onError(Throwable e) { * is resolved through {@link ResolvableSingleDeferredResult#resolve(ReadyResultInternal)}, we try * to resolve the tree again. We start by checking if we have enough resolved leafs in the * combinator tree to resolve it. If not, we register a callback to the {@link - * ReadyResultPublisher} to wait on future completions. As soon as the tree is resolved, we record - * in the journal the order of the leafs we've seen so far, and we finish by calling the {@code - * callback}, giving back control to user code. + * ReadyResultStateMachine} to wait on future completions. As soon as the tree is resolved, we + * record in the journal the order of the leafs we've seen so far, and we finish by calling the + * {@code callback}, giving back control to user code. * *

An important property of this algorithm is that we don't write multiple {@link * Java.CombinatorAwaitableEntryMessage} per combinator nodes composing the tree, but we write one @@ -569,8 +578,8 @@ private void resolveCombinatorDeferred( } // Not completed yet, we need to wait on the ReadyResultPublisher - this.readyResultPublisher.onNewReadyResult( - new ReadyResultPublisher.OnNewReadyResultCallback() { + this.readyResultStateMachine.onNewReadyResult( + new ReadyResultStateMachine.OnNewReadyResultCallback() { @SuppressWarnings({"unchecked", "rawtypes"}) @Override public boolean onNewReadyResult(Map> resultMap) { @@ -638,7 +647,7 @@ private void incrementCurrentIndex() { this.currentJournalIndex++; if (currentJournalIndex >= entriesToReplay && this.state == State.REPLAYING) { - if (!this.entriesQueue.isEmpty()) { + if (!this.incomingEntriesStateMachine.isEmpty()) { throw new IllegalStateException("Entries queue should be empty at this point"); } this.transitionState(State.PROCESSING); @@ -652,12 +661,25 @@ private void checkInsideSideEffectGuard() { } void readEntry(BiConsumer msgCallback, Consumer errorCallback) { - this.entriesQueue.read( - msg -> { - incrementCurrentIndex(); - msgCallback.accept(this.currentJournalIndex - 1, msg); - }, - errorCallback); + this.incomingEntriesStateMachine.read( + new IncomingEntriesStateMachine.OnEntryCallback() { + @Override + public void onEntry(MessageLite msg) { + incrementCurrentIndex(); + msgCallback.accept(currentJournalIndex - 1, msg); + } + + @Override + public void onSuspend() { + // This is not expected to happen, so we treat this case as closed + errorCallback.accept(ProtocolException.CLOSED); + } + + @Override + public void onError(Throwable e) { + errorCallback.accept(e); + } + }); } private void writeEntry(MessageLite message) { diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultPublisher.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultStateMachine.java similarity index 64% rename from sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultPublisher.java rename to sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultStateMachine.java index 943ae7f8..aeae788b 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultPublisher.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/ReadyResultStateMachine.java @@ -5,16 +5,16 @@ import java.util.HashMap; import java.util.Map; import java.util.function.Function; -import javax.annotation.Nullable; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; -/** Implements determinism of publishers */ -class ReadyResultPublisher { +/** State machine tracking ready results */ +class ReadyResultStateMachine + extends BaseSuspendableCallbackStateMachine { - private static final Logger LOG = LogManager.getLogger(ReadyResultPublisher.class); + private static final Logger LOG = LogManager.getLogger(ReadyResultStateMachine.class); - interface OnNewReadyResultCallback extends InputChannelState.SuspendableCallback { + interface OnNewReadyResultCallback extends SuspendableCallback { boolean onNewReadyResult(Map> resultMap); } @@ -23,22 +23,13 @@ interface OnNewReadyResultCallback extends InputChannelState.SuspendableCallback completionParsers; private final Map> results; - private @Nullable OnNewReadyResultCallback onNewReadyResultCallback; - - private final InputChannelState state; - - ReadyResultPublisher() { + ReadyResultStateMachine() { this.completions = new HashMap<>(); this.completionParsers = new HashMap<>(); this.results = new HashMap<>(); - this.state = new InputChannelState(); } void offerCompletion(Protocol.CompletionMessage completionMessage) { - if (this.state.isClosed()) { - LOG.warn("Offering a completion when the publisher is closed"); - return; - } LOG.trace("Offered new completion {}", completionMessage); this.completions.put(completionMessage.getEntryIndex(), completionMessage); @@ -54,18 +45,14 @@ void offerCompletionParser( } void onNewReadyResult(OnNewReadyResultCallback callback) { - if (this.onNewReadyResultCallback != null) { - throw new IllegalStateException("Two concurrent reads were requested."); - } - this.onNewReadyResultCallback = callback; + this.assertCallbackNotSet("Two concurrent reads were requested."); - this.tryProgress(); + this.tryProgress(callback); } void abort(Throwable cause) { - if (this.state.close(cause)) { - tryProgress(); - } + super.abort(cause); + this.consumeCallback(this::tryProgress); } private void tryParse(int entryIndex) { @@ -89,20 +76,13 @@ private void tryParse(int entryIndex) { this.results.put(completionMessage.getEntryIndex(), readyResult); // We have a new result, let's try to progress - this.tryProgress(); + this.consumeCallback(this::tryProgress); } - private void tryProgress() { - if (this.onNewReadyResultCallback != null) { - // Pop callback - OnNewReadyResultCallback cb = this.onNewReadyResultCallback; - this.onNewReadyResultCallback = null; - - // Try to consume results - boolean resolved = cb.onNewReadyResult(this.results); - if (!resolved) { - this.onNewReadyResultCallback = this.state.handleOrReturn(cb); - } + private void tryProgress(OnNewReadyResultCallback cb) { + boolean resolved = cb.onNewReadyResult(this.results); + if (!resolved) { + this.setCallback(cb); } } } diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckPublisher.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckPublisher.java deleted file mode 100644 index 5679a157..00000000 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckPublisher.java +++ /dev/null @@ -1,66 +0,0 @@ -package dev.restate.sdk.core.impl; - -import javax.annotation.Nullable; - -class SideEffectAckPublisher { - - interface OnEnterSideEffectCallback extends InputChannelState.SuspendableCallback { - void onEnter(); - } - - private int lastAcknowledgedEntry = -1; - /** -1 means no side effect waiting to be acked. */ - private int lastExecutedSideEffect = -1; - - private @Nullable OnEnterSideEffectCallback onEnterSideEffectCallback; - private final InputChannelState state = new InputChannelState(); - - void executeEnterSideEffect(OnEnterSideEffectCallback callback) { - if (canExecuteSideEffect()) { - callback.onEnter(); - } else { - this.onEnterSideEffectCallback = state.handleOrReturn(callback); - } - } - - void tryHandleSideEffectAck(int entryIndex) { - this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry); - if (canExecuteSideEffect()) { - tryInvokeCallback(); - } - } - - void registerExecutedSideEffect(int entryIndex) { - this.lastExecutedSideEffect = entryIndex; - } - - void abort(Throwable e) { - if (this.state.close(e)) { - tryFailCallback(); - } - } - - private void tryInvokeCallback() { - if (this.onEnterSideEffectCallback != null) { - OnEnterSideEffectCallback cb = this.onEnterSideEffectCallback; - this.onEnterSideEffectCallback = null; - cb.onEnter(); - } - } - - private void tryFailCallback() { - if (this.onEnterSideEffectCallback != null) { - OnEnterSideEffectCallback cb = this.onEnterSideEffectCallback; - this.onEnterSideEffectCallback = null; - state.handleOrReturn(cb); - } - } - - private boolean canExecuteSideEffect() { - return this.lastExecutedSideEffect <= this.lastAcknowledgedEntry; - } - - public int getLastExecutedSideEffect() { - return lastExecutedSideEffect; - } -} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckStateMachine.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckStateMachine.java new file mode 100644 index 00000000..65230e33 --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SideEffectAckStateMachine.java @@ -0,0 +1,49 @@ +package dev.restate.sdk.core.impl; + +/** State machine tracking side effects acks */ +class SideEffectAckStateMachine + extends BaseSuspendableCallbackStateMachine< + SideEffectAckStateMachine.OnEnterSideEffectCallback> { + + interface OnEnterSideEffectCallback extends SuspendableCallback { + void onEnter(); + } + + private int lastAcknowledgedEntry = -1; + /** -1 means no side effect waiting to be acked. */ + private int lastExecutedSideEffect = -1; + + void executeEnterSideEffect(OnEnterSideEffectCallback callback) { + if (canExecuteSideEffect()) { + callback.onEnter(); + } else { + this.setCallback(callback); + } + } + + void tryHandleSideEffectAck(int entryIndex) { + this.lastAcknowledgedEntry = Math.max(entryIndex, this.lastAcknowledgedEntry); + if (canExecuteSideEffect()) { + this.consumeCallback(OnEnterSideEffectCallback::onEnter); + } + } + + void registerExecutedSideEffect(int entryIndex) { + this.lastExecutedSideEffect = entryIndex; + } + + private boolean canExecuteSideEffect() { + return this.lastExecutedSideEffect <= this.lastAcknowledgedEntry; + } + + public int getLastExecutedSideEffect() { + return lastExecutedSideEffect; + } + + @Override + void abort(Throwable cause) { + super.abort(cause); + // We can't do anything else if the input stream is closed, so we just fail the callback, if any + this.tryFailCallback(); + } +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SuspendableCallback.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SuspendableCallback.java new file mode 100644 index 00000000..029d28c9 --- /dev/null +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SuspendableCallback.java @@ -0,0 +1,8 @@ +package dev.restate.sdk.core.impl; + +interface SuspendableCallback { + + void onSuspend(); + + void onError(Throwable e); +} diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java index b78a5216..a11217ab 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/SyscallsImpl.java @@ -15,12 +15,18 @@ import dev.restate.sdk.core.impl.ReadyResults.ReadyResultInternal; import dev.restate.sdk.core.serde.CustomSerdeFunctionsTypeTag; import dev.restate.sdk.core.serde.Serde; -import dev.restate.sdk.core.syscalls.*; +import dev.restate.sdk.core.syscalls.DeferredResult; +import dev.restate.sdk.core.syscalls.EnterSideEffectSyscallCallback; +import dev.restate.sdk.core.syscalls.ExitSideEffectSyscallCallback; +import dev.restate.sdk.core.syscalls.SyscallCallback; import io.grpc.MethodDescriptor; import java.nio.ByteBuffer; import java.time.Duration; import java.time.Instant; -import java.util.*; +import java.util.AbstractMap; +import java.util.Base64; +import java.util.Map; +import java.util.Objects; import java.util.function.Consumer; import java.util.function.Function; import javax.annotation.Nonnull; @@ -165,10 +171,7 @@ public void enterSideEffectBlock( TypeTag typeTag, EnterSideEffectSyscallCallback callback) { LOG.trace("enterSideEffectBlock"); this.stateMachine.enterSideEffectBlock( - span -> span.addEvent("Enter SideEffect"), - sideEffectEntryHandler(typeTag, callback), - callback::onNotExecuted, - callback::onCancel); + sideEffectEntryHandler(typeTag, callback), callback::onNotExecuted, callback::onCancel); } @Override @@ -177,7 +180,6 @@ public void exitSideEffectBlock( LOG.trace("exitSideEffectBlock with success"); this.stateMachine.exitSideEffectBlock( Java.SideEffectEntryMessage.newBuilder().setValue(serialize(typeTag, toWrite)).build(), - span -> span.addEvent("Exit SideEffect"), sideEffectEntryHandler(typeTag, callback), callback::onCancel); } @@ -214,7 +216,6 @@ public void exitSideEffectBlockWithException( this.stateMachine.exitSideEffectBlock( Java.SideEffectEntryMessage.newBuilder().setFailure(toProtocolFailure(toWrite)).build(), - span -> span.addEvent("Exit SideEffect"), sideEffectEntry -> callback.onFailure( Util.toGrpcStatus(sideEffectEntry.getFailure()).asRuntimeException()), diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/LocalStateStorage.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/UserStateStore.java similarity index 92% rename from sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/LocalStateStorage.java rename to sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/UserStateStore.java index 0a5939fc..c9be8ebd 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/LocalStateStorage.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/UserStateStore.java @@ -5,7 +5,7 @@ import java.util.Map; import java.util.stream.Collectors; -final class LocalStateStorage { +final class UserStateStore { interface State {} @@ -36,7 +36,7 @@ public ByteString getValue() { private final boolean isPartial; private final HashMap map; - LocalStateStorage(boolean isPartial, Map map) { + UserStateStore(boolean isPartial, Map map) { this.isPartial = isPartial; this.map = new HashMap<>( diff --git a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java index 4fe91e79..1175930e 100644 --- a/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java +++ b/sdk-core-impl/src/main/java/dev/restate/sdk/core/impl/Util.java @@ -4,7 +4,9 @@ import dev.restate.generated.sdk.java.Java; import dev.restate.generated.service.protocol.Protocol; import dev.restate.sdk.core.SuspendedException; -import io.grpc.*; +import io.grpc.Status; +import io.grpc.StatusException; +import io.grpc.StatusRuntimeException; import java.util.Objects; import java.util.Optional; import java.util.function.Predicate;