From 2f64c224701f54a75d70dafe2716722fc09988ad Mon Sep 17 00:00:00 2001 From: slinkydeveloper Date: Wed, 14 Aug 2024 11:44:32 +0200 Subject: [PATCH] Port the tests in the command_interpreter to the state machine tests. The only test that was skipped is send_response_using_invocation_id, as it was redundant. --- .../command_interpreter/tests.rs | 953 ------------------ .../worker/src/partition/state_machine/mod.rs | 900 +++++++++++++++-- 2 files changed, 822 insertions(+), 1031 deletions(-) delete mode 100644 crates/worker/src/partition/state_machine/command_interpreter/tests.rs diff --git a/crates/worker/src/partition/state_machine/command_interpreter/tests.rs b/crates/worker/src/partition/state_machine/command_interpreter/tests.rs deleted file mode 100644 index f07461d637..0000000000 --- a/crates/worker/src/partition/state_machine/command_interpreter/tests.rs +++ /dev/null @@ -1,953 +0,0 @@ -// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. -// All rights reserved. -// -// Use of this software is governed by the Business Source License -// included in the LICENSE file. -// -// As of the Change Date specified in that file, in accordance with -// the Business Source License, use of this software will be governed -// by the Apache License, Version 2.0. - -use super::*; - -use bytestring::ByteString; -use futures::stream; -use googletest::matcher::Matcher; -use googletest::{all, any, assert_that, pat, unordered_elements_are}; -use prost::Message; -use restate_invoker_api::EffectKind; -use restate_service_protocol::awakeable_id::AwakeableIdentifier; -use restate_service_protocol::codec::ProtobufRawEntryCodec; -use restate_storage_api::idempotency_table::IdempotencyMetadata; -use restate_storage_api::inbox_table::SequenceNumberInboxEntry; -use restate_storage_api::invocation_status_table::{JournalMetadata, StatusTimestamps}; -use restate_storage_api::promise_table::OwnedPromiseRow; -use restate_storage_api::timer_table::{TimerKey, TimerKeyKind}; -use restate_storage_api::{Result as StorageResult, StorageError}; -use restate_test_util::matchers::*; -use restate_test_util::{assert_eq, let_assert}; -use restate_types::errors::codes; -use restate_types::identifiers::{InvocationUuid, WithPartitionKey}; -use restate_types::invocation::InvocationTarget; -use restate_types::journal::EntryResult; -use restate_types::journal::{CompleteAwakeableEntry, Entry}; -use restate_types::service_protocol; -use std::collections::HashMap; -use test_log::test; - -use crate::partition::state_machine::command_interpreter::StateReader; -use crate::partition::state_machine::effects::Effect; - -#[derive(Default)] -struct StateReaderMock { - services: HashMap, - inboxes: HashMap>, - invocations: HashMap, - journals: HashMap>, -} - -impl StateReaderMock { - pub fn mock_invocation_metadata( - journal_length: u32, - invocation_target: InvocationTarget, - ) -> InFlightInvocationMetadata { - InFlightInvocationMetadata { - invocation_target, - journal_metadata: JournalMetadata::new( - journal_length, - ServiceInvocationSpanContext::empty(), - ), - ..InFlightInvocationMetadata::mock() - } - } - - fn lock_service(&mut self, service_id: ServiceId) { - self.services.insert( - service_id.clone(), - VirtualObjectStatus::Locked(InvocationId::from_parts( - service_id.partition_key(), - InvocationUuid::new(), - )), - ); - } - - fn register_invoked_status_and_locked( - &mut self, - invocation_target: InvocationTarget, - journal: Vec, - ) -> InvocationId { - let invocation_id = InvocationId::generate(&invocation_target); - - self.services.insert( - invocation_target.as_keyed_service_id().unwrap(), - VirtualObjectStatus::Locked(invocation_id), - ); - self.register_invocation_status( - invocation_id, - InvocationStatus::Invoked(Self::mock_invocation_metadata( - u32::try_from(journal.len()).unwrap(), - invocation_target, - )), - journal, - ); - - invocation_id - } - - fn register_suspended_status_and_locked( - &mut self, - invocation_target: InvocationTarget, - waiting_for_completed_entries: impl IntoIterator, - journal: Vec, - ) -> InvocationId { - let invocation_id = InvocationId::generate(&invocation_target); - - self.services.insert( - invocation_target.as_keyed_service_id().unwrap(), - VirtualObjectStatus::Locked(invocation_id), - ); - self.register_invocation_status( - invocation_id, - InvocationStatus::Suspended { - metadata: Self::mock_invocation_metadata( - u32::try_from(journal.len()).unwrap(), - invocation_target, - ), - waiting_for_completed_entries: HashSet::from_iter(waiting_for_completed_entries), - }, - journal, - ); - - invocation_id - } - - fn register_invocation_status( - &mut self, - invocation_id: InvocationId, - invocation_status: InvocationStatus, - journal: Vec, - ) { - self.invocations.insert(invocation_id, invocation_status); - self.journals.insert(invocation_id, journal); - } - - fn enqueue_into_inbox(&mut self, service_id: ServiceId, inbox_entry: SequenceNumberInboxEntry) { - assert_eq!( - service_id, - *inbox_entry.service_id(), - "Service invocation must have the same service_id as the inbox entry" - ); - - self.inboxes - .entry(service_id) - .or_default() - .push(inbox_entry); - } -} - -impl StateReader for StateReaderMock { - async fn get_virtual_object_status( - &mut self, - service_id: &ServiceId, - ) -> StorageResult { - Ok(self - .services - .get(service_id) - .cloned() - .unwrap_or(VirtualObjectStatus::Unlocked)) - } - - async fn get_invocation_status( - &mut self, - invocation_id: &InvocationId, - ) -> StorageResult { - Ok(self - .invocations - .get(invocation_id) - .cloned() - .unwrap_or(InvocationStatus::Free)) - } - - async fn is_entry_resumable( - &mut self, - _invocation_id: &InvocationId, - _entry_index: EntryIndex, - ) -> StorageResult { - todo!() - } - - async fn load_state( - &mut self, - _service_id: &ServiceId, - _key: &Bytes, - ) -> StorageResult> { - todo!() - } - - async fn load_state_keys(&mut self, _: &ServiceId) -> StorageResult> { - todo!() - } - - async fn load_completion_result( - &mut self, - _invocation_id: &InvocationId, - _entry_index: EntryIndex, - ) -> StorageResult> { - todo!() - } - - fn get_journal( - &mut self, - invocation_id: &InvocationId, - length: EntryIndex, - ) -> impl Stream> + Send { - ReadOnlyJournalTable::get_journal(self, invocation_id, length) - } -} - -impl ReadOnlyJournalTable for StateReaderMock { - fn get_journal_entry( - &mut self, - invocation_id: &InvocationId, - journal_index: u32, - ) -> impl Future>> + Send { - futures::future::ready(Ok(self - .journals - .get(invocation_id) - .and_then(|journal| journal.get(journal_index as usize).cloned()))) - } - - fn get_journal( - &mut self, - invocation_id: &InvocationId, - journal_length: EntryIndex, - ) -> impl Stream> + Send { - let journal = self.journals.get(invocation_id); - - let cloned_journal: Vec = journal - .map(|journal| { - journal - .iter() - .take( - usize::try_from(journal_length) - .expect("Converting from u32 to usize should be possible"), - ) - .cloned() - .collect() - }) - .unwrap_or_default(); - - stream::iter( - cloned_journal - .into_iter() - .enumerate() - .map(|(index, entry)| { - Ok(( - u32::try_from(index).expect("Journal must not be larger than 2^32 - 1"), - entry, - )) - }), - ) - } - - fn all_journals( - &self, - _range: RangeInclusive, - ) -> impl Stream> + Send { - unimplemented!(); - - // I need this for type inference to work - #[allow(unreachable_code)] - futures::stream::iter(vec![]) - } -} - -impl ReadOnlyIdempotencyTable for StateReaderMock { - async fn get_idempotency_metadata( - &mut self, - _idempotency_id: &IdempotencyId, - ) -> StorageResult> { - unimplemented!(); - } - - fn all_idempotency_metadata( - &self, - _range: RangeInclusive, - ) -> impl Stream> + Send { - unimplemented!(); - - // I need this for type inference to work - #[allow(unreachable_code)] - futures::stream::iter(vec![]) - } -} - -impl ReadOnlyPromiseTable for StateReaderMock { - async fn get_promise( - &mut self, - _service_id: &ServiceId, - _key: &ByteString, - ) -> StorageResult> { - unimplemented!(); - } - - fn all_promises( - &self, - _range: RangeInclusive, - ) -> impl Stream> + Send { - unimplemented!(); - - // I need this for type inference to work - #[allow(unreachable_code)] - futures::stream::iter(vec![]) - } -} - -#[test(tokio::test)] -async fn awakeable_with_success() { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut state_machine: CommandInterpreter = - CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut effects = Effects::default(); - let mut state_reader = StateReaderMock::default(); - - let callee_invocation_id = InvocationId::mock_random(); - let entry = ProtobufRawEntryCodec::serialize_enriched(Entry::CompleteAwakeable( - CompleteAwakeableEntry { - id: AwakeableIdentifier::new(callee_invocation_id, 1) - .to_string() - .into(), - result: EntryResult::Success(Bytes::default()), - }, - )); - - let caller_invocation_id = state_reader.register_invoked_status_and_locked( - InvocationTarget::mock_virtual_object(), - vec![JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Awakeable { - is_completed: false, - }, - Bytes::default(), - ))], - ); - - state_machine - .on_apply( - Command::InvokerEffect(InvokerEffect { - invocation_id: caller_invocation_id, - kind: EffectKind::JournalEntry { - entry_index: 1, - entry, - }, - }), - &mut effects, - &mut state_reader, - ) - .await - .unwrap(); - let_assert!(Effect::EnqueueIntoOutbox { message, .. } = effects.drain().next().unwrap()); - let_assert!( - OutboxMessage::ServiceResponse(InvocationResponse { - id, - entry_index, - result: ResponseResult::Success(_), - }) = message - ); - assert_eq!(id, callee_invocation_id); - assert_eq!(entry_index, 1); -} - -#[test(tokio::test)] -async fn awakeable_with_failure() { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut state_machine: CommandInterpreter = - CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut effects = Effects::default(); - let mut state_reader = StateReaderMock::default(); - - let callee_invocation_id = InvocationId::mock_random(); - let entry = ProtobufRawEntryCodec::serialize_enriched(Entry::CompleteAwakeable( - CompleteAwakeableEntry { - id: AwakeableIdentifier::new(callee_invocation_id, 1) - .to_string() - .into(), - result: EntryResult::Failure(codes::BAD_REQUEST, "Some failure".into()), - }, - )); - - let caller_invocation_id = state_reader.register_invoked_status_and_locked( - InvocationTarget::mock_virtual_object(), - vec![JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Awakeable { - is_completed: false, - }, - Bytes::default(), - ))], - ); - - state_machine - .on_apply( - Command::InvokerEffect(InvokerEffect { - invocation_id: caller_invocation_id, - kind: EffectKind::JournalEntry { - entry_index: 1, - entry, - }, - }), - &mut effects, - &mut state_reader, - ) - .await - .unwrap(); - let_assert!(Effect::EnqueueIntoOutbox { message, .. } = effects.drain().next().unwrap()); - let_assert!( - OutboxMessage::ServiceResponse(InvocationResponse { - id, - entry_index, - result: ResponseResult::Failure(failure), - }) = message - ); - assert_eq!(id, callee_invocation_id); - assert_eq!(entry_index, 1); - assert_eq!(failure.message(), "Some failure"); -} - -#[test(tokio::test)] -async fn send_response_using_invocation_id() { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut state_machine: CommandInterpreter = - CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut effects = Effects::default(); - let mut state_reader = StateReaderMock::default(); - - let invocation_id = state_reader - .register_invoked_status_and_locked(InvocationTarget::mock_virtual_object(), vec![]); - - state_machine - .on_apply( - Command::InvocationResponse(InvocationResponse { - id: invocation_id, - entry_index: 1, - result: ResponseResult::Success(Bytes::from_static(b"hello")), - }), - &mut effects, - &mut state_reader, - ) - .await - .unwrap(); - assert_that!( - effects.into_inner(), - all!( - contains(pat!(Effect::StoreCompletion { - invocation_id: eq(invocation_id), - completion: pat!(Completion { entry_index: eq(1) }) - })), - contains(pat!(Effect::ForwardCompletion { - invocation_id: eq(invocation_id), - completion: pat!(Completion { entry_index: eq(1) }) - })) - ) - ); -} - -#[test(tokio::test)] -async fn kill_inboxed_invocation() -> Result<(), Error> { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut command_interpreter = CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - - let mut effects = Effects::default(); - let mut state_mock = StateReaderMock::default(); - - let (inboxed_invocation_id, inboxed_invocation_target) = - InvocationId::mock_with(InvocationTarget::mock_virtual_object()); - let caller_invocation_id = InvocationId::mock_random(); - - state_mock.lock_service(ServiceId::new("svc", "key")); - state_mock.enqueue_into_inbox( - inboxed_invocation_target.as_keyed_service_id().unwrap(), - SequenceNumberInboxEntry { - inbox_sequence_number: 0, - inbox_entry: InboxEntry::Invocation( - inboxed_invocation_target.as_keyed_service_id().unwrap(), - inboxed_invocation_id, - ), - }, - ); - state_mock.invocations.insert( - inboxed_invocation_id, - InvocationStatus::Inboxed(InboxedInvocation { - inbox_sequence_number: 0, - response_sinks: HashSet::from([ServiceInvocationResponseSink::PartitionProcessor { - caller: caller_invocation_id, - entry_index: 0, - }]), - timestamps: StatusTimestamps::now(), - invocation_target: inboxed_invocation_target.clone(), - argument: Default::default(), - source: Source::Ingress, - span_context: Default::default(), - headers: vec![], - execution_time: None, - completion_retention_time: Default::default(), - idempotency_key: None, - source_table: SourceTable::New, - }), - ); - - command_interpreter - .on_apply( - Command::TerminateInvocation(InvocationTermination::kill(inboxed_invocation_id)), - &mut effects, - &mut state_mock, - ) - .await?; - - assert_that!( - effects.into_inner(), - all!( - contains(pat!(Effect::DeleteInboxEntry { - service_id: eq(inboxed_invocation_target.as_keyed_service_id().unwrap(),), - sequence_number: eq(0) - })), - contains(pat!(Effect::EnqueueIntoOutbox { - message: pat!( - restate_storage_api::outbox_table::OutboxMessage::ServiceResponse(pat!( - InvocationResponse { - id: eq(caller_invocation_id), - entry_index: eq(0), - result: eq(ResponseResult::Failure(KILLED_INVOCATION_ERROR)) - } - )) - ) - })) - ) - ); - - Ok(()) -} - -#[test(tokio::test)] -async fn kill_call_tree() -> Result<(), Error> { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut command_interpreter = CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut state_reader = StateReaderMock::default(); - let mut effects = Effects::default(); - - let call_invocation_id = InvocationId::mock_random(); - let background_call_invocation_id = InvocationId::mock_random(); - let finished_call_invocation_id = InvocationId::mock_random(); - - let invocation_target = InvocationTarget::mock_virtual_object(); - let invocation_id = state_reader.register_invoked_status_and_locked( - invocation_target.clone(), - vec![ - uncompleted_invoke_entry(call_invocation_id), - background_invoke_entry(background_call_invocation_id), - completed_invoke_entry(finished_call_invocation_id), - ], - ); - - command_interpreter - .on_apply( - Command::TerminateInvocation(InvocationTermination::kill(invocation_id)), - &mut effects, - &mut state_reader, - ) - .await?; - - let effects = effects.into_inner(); - - assert_that!( - effects, - all!( - contains(pat!(Effect::SendAbortInvocationToInvoker(eq( - invocation_id - )))), - contains(pat!(Effect::FreeInvocation(eq(invocation_id)))), - contains(pat!(Effect::DropJournal { - invocation_id: eq(invocation_id), - })), - contains(pat!(Effect::PopInbox(eq(invocation_target - .as_keyed_service_id() - .unwrap())))), - contains(terminate_invocation_outbox_message_matcher( - call_invocation_id, - TerminationFlavor::Kill - )), - not(contains(pat!(Effect::EnqueueIntoOutbox { - message: pat!( - restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( - InvocationTermination { - invocation_id: any!( - eq(background_call_invocation_id), - eq(finished_call_invocation_id) - ) - } - )) - ) - }))) - ) - ); - - Ok(()) -} - -fn completed_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Call { - is_completed: true, - enrichment_result: Some(CallEnrichmentResult { - invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), - }), - }, - Bytes::default(), - )) -} - -fn background_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::OneWayCall { - enrichment_result: CallEnrichmentResult { - invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), - }, - }, - Bytes::default(), - )) -} - -fn uncompleted_invoke_entry(invocation_id: InvocationId) -> JournalEntry { - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Call { - is_completed: false, - enrichment_result: Some(CallEnrichmentResult { - invocation_id, - invocation_target: InvocationTarget::mock_service(), - completion_retention_time: None, - span_context: ServiceInvocationSpanContext::empty(), - }), - }, - Bytes::default(), - )) -} - -#[test(tokio::test)] -async fn cancel_invoked_invocation() -> Result<(), Error> { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut command_interpreter = CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut state_reader = StateReaderMock::default(); - let mut effects = Effects::default(); - - let call_invocation_id = InvocationId::mock_random(); - let background_call_invocation_id = InvocationId::mock_random(); - let finished_call_invocation_id = InvocationId::mock_random(); - - let invocation_id = state_reader.register_invoked_status_and_locked( - InvocationTarget::mock_virtual_object(), - create_termination_journal( - call_invocation_id, - background_call_invocation_id, - finished_call_invocation_id, - ), - ); - - command_interpreter - .on_apply( - Command::TerminateInvocation(InvocationTermination::cancel(invocation_id)), - &mut effects, - &mut state_reader, - ) - .await?; - - let effects = effects.into_inner(); - - assert_that!( - effects, - unordered_elements_are![ - terminate_invocation_outbox_message_matcher( - call_invocation_id, - TerminationFlavor::Cancel - ), - store_canceled_completion_matcher(4), - store_canceled_completion_matcher(5), - store_canceled_completion_matcher(6), - forward_canceled_completion_matcher(4), - forward_canceled_completion_matcher(5), - forward_canceled_completion_matcher(6), - delete_timer(5), - ] - ); - - Ok(()) -} - -#[test(tokio::test)] -async fn cancel_suspended_invocation() -> Result<(), Error> { - let partition_key_range = PartitionKey::MIN..=PartitionKey::MAX; - let mut command_interpreter = CommandInterpreter::::new( - 0, - 0, - None, - partition_key_range, - SourceTable::New, - ); - let mut state_reader = StateReaderMock::default(); - let mut effects = Effects::default(); - - let call_invocation_id = InvocationId::mock_random(); - let background_call_invocation_id = InvocationId::mock_random(); - let finished_call_invocation_id = InvocationId::mock_random(); - - let journal = create_termination_journal( - call_invocation_id, - background_call_invocation_id, - finished_call_invocation_id, - ); - let invocation_id = state_reader.register_suspended_status_and_locked( - InvocationTarget::mock_virtual_object(), - vec![3, 4, 5, 6], - journal, - ); - - command_interpreter - .on_apply( - Command::TerminateInvocation(InvocationTermination::cancel(invocation_id)), - &mut effects, - &mut state_reader, - ) - .await?; - - let effects = effects.into_inner(); - - assert_that!( - effects, - unordered_elements_are![ - terminate_invocation_outbox_message_matcher( - call_invocation_id, - TerminationFlavor::Cancel - ), - store_canceled_completion_matcher(4), - store_canceled_completion_matcher(5), - store_canceled_completion_matcher(6), - delete_timer(5), - pat!(Effect::ResumeService { - invocation_id: eq(invocation_id), - }), - ] - ); - - Ok(()) -} - -#[test(tokio::test)] -async fn truncate_outbox_from_empty() -> Result<(), Error> { - // An outbox message with index 0 has been successfully processed, and must now be truncated - let outbox_index = 0; - - let mut command_interpreter = CommandInterpreter::::new( - 0, - 0, - None, - PartitionKey::MIN..=PartitionKey::MAX, - SourceTable::New, - ); - let mut state_reader = StateReaderMock::default(); - let mut effects = Effects::default(); - - command_interpreter - .on_apply( - Command::TruncateOutbox(outbox_index), - &mut effects, - &mut state_reader, - ) - .await?; - - let effects = effects.into_inner(); - - assert_that!( - effects, - unordered_elements_are![pat!(Effect::TruncateOutbox(eq(RangeInclusive::new( - outbox_index, - outbox_index - ))))] - ); - - // The head catches up to the next available sequence number on truncation. Since we don't know - // in advance whether we will get asked to truncate a range of more than one outbox message, we - // explicitly track the head sequence number as the next position beyond the last known - // truncation point. It's only safe to leave the head as None when the outbox is known to be - // empty. - assert_eq!(command_interpreter.outbox_head_seq_number, Some(1)); - - Ok(()) -} - -#[test(tokio::test)] -async fn truncate_outbox_with_gap() -> Result<(), Error> { - // The outbox contains items [3..=5], and the range must be truncated after message 5 is processed - let outbox_head_index = 3; - let outbox_tail_index = 5; - - let mut command_interpreter = CommandInterpreter::::new( - 0, - outbox_tail_index, - Some(outbox_head_index), - PartitionKey::MIN..=PartitionKey::MAX, - SourceTable::New, - ); - let mut state_reader = StateReaderMock::default(); - let mut effects = Effects::default(); - - command_interpreter - .on_apply( - Command::TruncateOutbox(outbox_tail_index), - &mut effects, - &mut state_reader, - ) - .await?; - - let effects = effects.into_inner(); - - assert_that!( - effects, - unordered_elements_are![pat!(Effect::TruncateOutbox(eq(RangeInclusive::new( - outbox_head_index, - outbox_tail_index - ))))] - ); - - assert_eq!( - command_interpreter.outbox_head_seq_number, - Some(outbox_tail_index + 1) - ); - - Ok(()) -} - -fn create_termination_journal( - call_invocation_id: InvocationId, - background_invocation_id: InvocationId, - finished_call_invocation_id: InvocationId, -) -> Vec { - vec![ - uncompleted_invoke_entry(call_invocation_id), - completed_invoke_entry(finished_call_invocation_id), - background_invoke_entry(background_invocation_id), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Input {}, - Bytes::default(), - )), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::GetState { - is_completed: false, - }, - Bytes::default(), - )), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Sleep { - is_completed: false, - }, - service_protocol::SleepEntryMessage { - wake_up_time: 1337, - result: None, - ..Default::default() - } - .encode_to_vec() - .into(), - )), - JournalEntry::Entry(EnrichedRawEntry::new( - EnrichedEntryHeader::Awakeable { - is_completed: false, - }, - Bytes::default(), - )), - ] -} - -fn canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Completion { - entry_index: eq(entry_index), - result: pat!(CompletionResult::Failure( - eq(codes::ABORTED), - eq(ByteString::from_static("canceled")) - )) - }) -} - -fn store_canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Effect::StoreCompletion { - completion: canceled_completion_matcher(entry_index), - }) -} - -fn forward_canceled_completion_matcher(entry_index: EntryIndex) -> impl Matcher { - pat!(Effect::ForwardCompletion { - completion: canceled_completion_matcher(entry_index), - }) -} - -fn delete_timer(entry_index: EntryIndex) -> impl Matcher { - pat!(Effect::DeleteTimer(pat!(TimerKey { - kind: pat!(TimerKeyKind::CompleteJournalEntry { - journal_index: eq(entry_index), - }), - timestamp: eq(1337), - }))) -} - -fn terminate_invocation_outbox_message_matcher( - target_invocation_id: InvocationId, - termination_flavor: TerminationFlavor, -) -> impl Matcher { - pat!(Effect::EnqueueIntoOutbox { - message: pat!( - restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( - InvocationTermination { - invocation_id: eq(target_invocation_id), - flavor: eq(termination_flavor) - } - )) - ) - }) -} diff --git a/crates/worker/src/partition/state_machine/mod.rs b/crates/worker/src/partition/state_machine/mod.rs index bca60c07e5..485694fa49 100644 --- a/crates/worker/src/partition/state_machine/mod.rs +++ b/crates/worker/src/partition/state_machine/mod.rs @@ -130,16 +130,16 @@ mod tests { use crate::partition::types::{InvokerEffect, InvokerEffectKind}; use ::tracing::info; - use assert2::assert; use bytes::Bytes; use bytestring::ByteString; use futures::{StreamExt, TryStreamExt}; use googletest::matcher::Matcher; use googletest::{all, assert_that, pat, property}; use restate_core::{task_center, TaskCenterBuilder}; - use restate_invoker_api::InvokeInputJournal; + use restate_invoker_api::{EffectKind, InvokeInputJournal}; use restate_partition_store::{OpenMode, PartitionStore, PartitionStoreManager}; use restate_rocksdb::RocksDbManager; + use restate_service_protocol::awakeable_id::AwakeableIdentifier; use restate_service_protocol::codec::ProtobufRawEntryCodec; use restate_storage_api::invocation_status_table::{ InFlightInvocationMetadata, InvocationStatus, InvocationStatusTable, @@ -154,7 +154,7 @@ mod tests { use restate_storage_api::Transaction; use restate_test_util::matchers::*; use restate_types::config::{CommonOptions, WorkerOptions}; - use restate_types::errors::KILLED_INVOCATION_ERROR; + use restate_types::errors::{codes, InvocationError, KILLED_INVOCATION_ERROR}; use restate_types::identifiers::{ IngressRequestId, InvocationId, PartitionId, PartitionKey, ServiceId, }; @@ -164,7 +164,9 @@ mod tests { ServiceInvocation, ServiceInvocationResponseSink, Source, VirtualObjectHandlerType, }; use restate_types::journal::enriched::EnrichedRawEntry; - use restate_types::journal::{Completion, CompletionResult, EntryResult, InvokeRequest}; + use restate_types::journal::{ + CompleteAwakeableEntry, Completion, CompletionResult, EntryResult, InvokeRequest, + }; use restate_types::journal::{Entry, EntryType}; use restate_types::live::{Constant, Live}; use restate_types::state_mut::ExternalStateMutation; @@ -186,6 +188,19 @@ mod tests { } pub async fn create() -> Self { + Self::create_with_state_machine(StateMachine::new( + 0, /* inbox_seq_number */ + 0, /* outbox_seq_number */ + None, /* outbox_head_seq_number */ + PartitionKey::MIN..=PartitionKey::MAX, + SourceTable::New, + )) + .await + } + + pub async fn create_with_state_machine( + state_machine: StateMachine, + ) -> Self { task_center().run_in_scope_sync("db-manager-init", None, || { RocksDbManager::init(Constant::new(CommonOptions::default())) }); @@ -212,13 +227,7 @@ mod tests { .unwrap(); Self { - state_machine: StateMachine::new( - 0, /* inbox_seq_number */ - 0, /* outbox_seq_number */ - None, /* outbox_head_seq_number */ - PartitionKey::MIN..=PartitionKey::MAX, - SourceTable::New, - ), + state_machine, rocksdb_storage, } } @@ -422,7 +431,7 @@ mod tests { } #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] - async fn kill_inboxed_invocation() -> anyhow::Result<()> { + async fn complete_awakeable_with_success() { let tc = TaskCenterBuilder::default() .default_runtime_handle(tokio::runtime::Handle::current()) .build() @@ -430,90 +439,94 @@ mod tests { let mut state_machine = tc .run_in_scope("mock-state-machine", None, MockStateMachine::create()) .await; + let invocation_id = mock_start_invocation(&mut state_machine).await; - let (invocation_id, invocation_target) = - InvocationId::mock_with(InvocationTarget::mock_virtual_object()); - let (inboxed_id, inboxed_target) = InvocationId::mock_with(invocation_target.clone()); - let caller_id = InvocationId::mock_random(); + let callee_invocation_id = InvocationId::mock_random(); + let callee_entry_index = 10; + let entry = ProtobufRawEntryCodec::serialize_enriched(Entry::CompleteAwakeable( + CompleteAwakeableEntry { + id: AwakeableIdentifier::new(callee_invocation_id, callee_entry_index) + .to_string() + .into(), + result: EntryResult::Success(Bytes::default()), + }, + )); - let _ = state_machine - .apply(Command::Invoke(ServiceInvocation { + let actions = state_machine + .apply(Command::InvokerEffect(InvokerEffect { invocation_id, - invocation_target: invocation_target.clone(), - ..ServiceInvocation::mock() + kind: EffectKind::JournalEntry { + entry_index: 1, + entry, + }, })) .await; - let _ = state_machine - .apply(Command::Invoke(ServiceInvocation { - invocation_id: inboxed_id, - invocation_target: inboxed_target, - response_sink: Some(ServiceInvocationResponseSink::PartitionProcessor { - caller: caller_id, - entry_index: 0, - }), - ..ServiceInvocation::mock() + assert_that!( + actions, + contains(pat!(Action::NewOutboxMessage { + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::ServiceResponse(pat!( + restate_types::invocation::InvocationResponse { + id: eq(callee_invocation_id), + entry_index: eq(callee_entry_index), + result: pat!(ResponseResult::Success { .. }) + } + )) + ) })) - .await; + ); + } - let current_invocation_status = state_machine - .storage() - .transaction() - .get_invocation_status(&inboxed_id) - .await?; + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn complete_awakeable_with_failure() { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + let invocation_id = mock_start_invocation(&mut state_machine).await; - // assert that inboxed invocation is in invocation_status - assert!(let InvocationStatus::Inboxed(_) = current_invocation_status); + let callee_invocation_id = InvocationId::mock_random(); + let callee_entry_index = 10; + let entry = ProtobufRawEntryCodec::serialize_enriched(Entry::CompleteAwakeable( + CompleteAwakeableEntry { + id: AwakeableIdentifier::new(callee_invocation_id, callee_entry_index) + .to_string() + .into(), + result: EntryResult::Failure(codes::BAD_REQUEST, "Some failure".into()), + }, + )); let actions = state_machine - .apply(Command::TerminateInvocation(InvocationTermination::kill( - inboxed_id, - ))) + .apply(Command::InvokerEffect(InvokerEffect { + invocation_id, + kind: EffectKind::JournalEntry { + entry_index: 1, + entry, + }, + })) .await; - let current_invocation_status = state_machine - .storage() - .transaction() - .get_invocation_status(&inboxed_id) - .await?; - - // assert that invocation status was removed - assert!(let InvocationStatus::Free = current_invocation_status); - - fn outbox_message_matcher( - caller_id: InvocationId, - ) -> impl Matcher { - pat!( - restate_storage_api::outbox_table::OutboxMessage::ServiceResponse(pat!( - restate_types::invocation::InvocationResponse { - id: eq(caller_id), - entry_index: eq(0), - result: eq(ResponseResult::Failure(KILLED_INVOCATION_ERROR)) - } - )) - ) - } - assert_that!( actions, contains(pat!(Action::NewOutboxMessage { - message: outbox_message_matcher(caller_id) + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::ServiceResponse(pat!( + restate_types::invocation::InvocationResponse { + id: eq(callee_invocation_id), + entry_index: eq(callee_entry_index), + result: eq(ResponseResult::Failure(InvocationError::new( + codes::BAD_REQUEST, + "Some failure" + ))) + } + )) + ) })) ); - - let partition_id = state_machine.partition_id(); - let outbox_message = state_machine - .storage() - .transaction() - .get_next_outbox_message(partition_id, 0) - .await?; - - assert_that!( - outbox_message, - some((ge(0), outbox_message_matcher(caller_id))) - ); - - Ok(()) } #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] @@ -2121,6 +2134,737 @@ mod tests { } } + mod kill_cancel { + use super::*; + + use assert2::assert; + use assert2::let_assert; + use googletest::any; + use prost::Message; + use restate_storage_api::journal_table::JournalTable; + use restate_storage_api::timer_table::{Timer, TimerKey, TimerKeyKind, TimerTable}; + use restate_types::identifiers::EntryIndex; + use restate_types::invocation::{ServiceInvocationSpanContext, TerminationFlavor}; + use restate_types::journal::enriched::{CallEnrichmentResult, EnrichedEntryHeader}; + use restate_types::service_protocol; + use test_log::test; + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn kill_inboxed_invocation() -> anyhow::Result<()> { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + + let (invocation_id, invocation_target) = + InvocationId::mock_with(InvocationTarget::mock_virtual_object()); + let (inboxed_id, inboxed_target) = InvocationId::mock_with(invocation_target.clone()); + let caller_id = InvocationId::mock_random(); + + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id: inboxed_id, + invocation_target: inboxed_target, + response_sink: Some(ServiceInvocationResponseSink::PartitionProcessor { + caller: caller_id, + entry_index: 0, + }), + ..ServiceInvocation::mock() + })) + .await; + + let current_invocation_status = state_machine + .storage() + .transaction() + .get_invocation_status(&inboxed_id) + .await?; + + // assert that inboxed invocation is in invocation_status + assert!(let InvocationStatus::Inboxed(_) = current_invocation_status); + + let actions = state_machine + .apply(Command::TerminateInvocation(InvocationTermination::kill( + inboxed_id, + ))) + .await; + + let current_invocation_status = state_machine + .storage() + .transaction() + .get_invocation_status(&inboxed_id) + .await?; + + // assert that invocation status was removed + assert!(let InvocationStatus::Free = current_invocation_status); + + fn outbox_message_matcher( + caller_id: InvocationId, + ) -> impl Matcher + { + pat!( + restate_storage_api::outbox_table::OutboxMessage::ServiceResponse(pat!( + restate_types::invocation::InvocationResponse { + id: eq(caller_id), + entry_index: eq(0), + result: eq(ResponseResult::Failure(KILLED_INVOCATION_ERROR)) + } + )) + ) + } + + assert_that!( + actions, + contains(pat!(Action::NewOutboxMessage { + message: outbox_message_matcher(caller_id) + })) + ); + + let partition_id = state_machine.partition_id(); + let outbox_message = state_machine + .storage() + .transaction() + .get_next_outbox_message(partition_id, 0) + .await?; + + assert_that!( + outbox_message, + some((ge(0), outbox_message_matcher(caller_id))) + ); + + Ok(()) + } + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn kill_call_tree() -> anyhow::Result<()> { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + + let call_invocation_id = InvocationId::mock_random(); + let background_call_invocation_id = InvocationId::mock_random(); + let finished_call_invocation_id = InvocationId::mock_random(); + + let invocation_target = InvocationTarget::mock_virtual_object(); + let invocation_id = InvocationId::generate(&invocation_target); + let enqueued_invocation_id_on_same_target = InvocationId::generate(&invocation_target); + + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Let's enqueue an invocation afterward + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id: enqueued_invocation_id_on_same_target, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Let's add some journal entries + let mut tx = state_machine.rocksdb_storage.transaction(); + tx.put_journal_entry( + &invocation_id, + 1, + uncompleted_invoke_entry(call_invocation_id), + ) + .await; + tx.put_journal_entry( + &invocation_id, + 2, + background_invoke_entry(background_call_invocation_id), + ) + .await; + tx.put_journal_entry( + &invocation_id, + 3, + completed_invoke_entry(finished_call_invocation_id), + ) + .await; + let mut invocation_status = tx.get_invocation_status(&invocation_id).await?; + invocation_status.get_journal_metadata_mut().unwrap().length = 4; + tx.put_invocation_status(&invocation_id, invocation_status) + .await; + tx.commit().await?; + + // Now let's send the termination command + let actions = state_machine + .apply(Command::TerminateInvocation(InvocationTermination::kill( + invocation_id, + ))) + .await; + + // Invocation should be gone + assert_that!( + state_machine + .rocksdb_storage + .get_invocation_status(&invocation_id) + .await?, + pat!(InvocationStatus::Free) + ); + assert_that!( + state_machine + .rocksdb_storage + .get_journal(&invocation_id, 4) + .try_collect::>() + .await?, + empty() + ); + + assert_that!( + actions, + all!( + contains(pat!(Action::AbortInvocation(eq(invocation_id)))), + contains(pat!(Action::Invoke { + invocation_id: eq(enqueued_invocation_id_on_same_target), + invocation_target: eq(invocation_target) + })), + contains(terminate_invocation_outbox_message_matcher( + call_invocation_id, + TerminationFlavor::Kill + )), + not(contains(pat!(Action::NewOutboxMessage { + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination( + pat!(InvocationTermination { + invocation_id: any!( + eq(background_call_invocation_id), + eq(finished_call_invocation_id) + ) + }) + ) + ) + }))) + ) + ); + + Ok(()) + } + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn cancel_invoked_invocation() -> Result<(), Error> { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + + let call_invocation_id = InvocationId::mock_random(); + let background_call_invocation_id = InvocationId::mock_random(); + let finished_call_invocation_id = InvocationId::mock_random(); + + let invocation_target = InvocationTarget::mock_virtual_object(); + let invocation_id = InvocationId::generate(&invocation_target); + + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Let's add some journal entries + let partition_id = state_machine.partition_id(); + let mut tx = state_machine.rocksdb_storage.transaction(); + let journal = create_termination_journal( + call_invocation_id, + background_call_invocation_id, + finished_call_invocation_id, + ); + let journal_length = journal.len(); + let (sleep_entry_idx, _) = journal + .iter() + .enumerate() + .find(|(_, j)| { + if let JournalEntry::Entry(e) = j { + e.header().as_entry_type() == EntryType::Sleep + } else { + false + } + }) + .unwrap(); + for (idx, entry) in journal.into_iter().enumerate() { + tx.put_journal_entry(&invocation_id, (idx + 1) as u32, entry) + .await; + } + // Update journal length + let mut invocation_status = tx.get_invocation_status(&invocation_id).await?; + invocation_status.get_journal_metadata_mut().unwrap().length = + (journal_length + 1) as EntryIndex; + tx.put_invocation_status(&invocation_id, invocation_status) + .await; + // Add timer + tx.add_timer( + partition_id, + &TimerKey { + timestamp: 1337, + kind: TimerKeyKind::CompleteJournalEntry { + invocation_uuid: invocation_id.invocation_uuid(), + journal_index: (sleep_entry_idx + 1) as u32, + }, + }, + Timer::CompleteJournalEntry(invocation_id, (sleep_entry_idx + 1) as u32), + ) + .await; + tx.commit().await?; + + let actions = state_machine + .apply(Command::TerminateInvocation(InvocationTermination::cancel( + invocation_id, + ))) + .await; + + // Invocation shouldn't be gone + assert_that!( + state_machine + .rocksdb_storage + .get_invocation_status(&invocation_id) + .await?, + pat!(InvocationStatus::Invoked { .. }) + ); + + // Timer is gone + assert_that!( + state_machine + .rocksdb_storage + .next_timers_greater_than(state_machine.partition_id(), None, usize::MAX) + .try_collect::>() + .await?, + empty() + ); + + // Entries are completed + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 4) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 5) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 6) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + + assert_that!( + actions, + all!( + contains(terminate_invocation_outbox_message_matcher( + call_invocation_id, + TerminationFlavor::Cancel + )), + contains(forward_canceled_completion_matcher(4)), + contains(forward_canceled_completion_matcher(5)), + contains(forward_canceled_completion_matcher(6)), + contains(delete_timer_matcher(5)), + ) + ); + + Ok(()) + } + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn cancel_suspended_invocation() -> Result<(), Error> { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + + let call_invocation_id = InvocationId::mock_random(); + let background_call_invocation_id = InvocationId::mock_random(); + let finished_call_invocation_id = InvocationId::mock_random(); + + let invocation_target = InvocationTarget::mock_virtual_object(); + let invocation_id = InvocationId::generate(&invocation_target); + + let _ = state_machine + .apply(Command::Invoke(ServiceInvocation { + invocation_id, + invocation_target: invocation_target.clone(), + ..ServiceInvocation::mock() + })) + .await; + + // Let's add some journal entries + let partition_id = state_machine.partition_id(); + let mut tx = state_machine.rocksdb_storage.transaction(); + let journal = create_termination_journal( + call_invocation_id, + background_call_invocation_id, + finished_call_invocation_id, + ); + let journal_length = journal.len(); + let (sleep_entry_idx, _) = journal + .iter() + .enumerate() + .find(|(_, j)| { + if let JournalEntry::Entry(e) = j { + e.header().as_entry_type() == EntryType::Sleep + } else { + false + } + }) + .unwrap(); + for (idx, entry) in journal.into_iter().enumerate() { + tx.put_journal_entry(&invocation_id, (idx + 1) as u32, entry) + .await; + } + // Update journal length and suspend invocation + let invocation_status = tx.get_invocation_status(&invocation_id).await?; + let_assert!(InvocationStatus::Invoked(mut in_flight_meta) = invocation_status); + in_flight_meta.journal_metadata.length = (journal_length + 1) as EntryIndex; + tx.put_invocation_status( + &invocation_id, + InvocationStatus::Suspended { + metadata: in_flight_meta, + waiting_for_completed_entries: HashSet::from([3, 4, 5, 6]), + }, + ) + .await; + // Add timer + tx.add_timer( + partition_id, + &TimerKey { + timestamp: 1337, + kind: TimerKeyKind::CompleteJournalEntry { + invocation_uuid: invocation_id.invocation_uuid(), + journal_index: (sleep_entry_idx + 1) as u32, + }, + }, + Timer::CompleteJournalEntry(invocation_id, (sleep_entry_idx + 1) as u32), + ) + .await; + tx.commit().await?; + + let actions = state_machine + .apply(Command::TerminateInvocation(InvocationTermination::cancel( + invocation_id, + ))) + .await; + + // Invocation shouldn't be gone + assert_that!( + state_machine + .rocksdb_storage + .get_invocation_status(&invocation_id) + .await?, + pat!(InvocationStatus::Invoked { .. }) + ); + + // Timer is gone + assert_that!( + state_machine + .rocksdb_storage + .next_timers_greater_than(state_machine.partition_id(), None, usize::MAX) + .try_collect::>() + .await?, + empty() + ); + + // Entries are completed + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 4) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 5) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + assert_that!( + state_machine + .rocksdb_storage + .get_journal_entry(&invocation_id, 6) + .await?, + some(pat!(JournalEntry::Entry(entry_completed_matcher()))) + ); + + assert_that!( + actions, + all!( + contains(terminate_invocation_outbox_message_matcher( + call_invocation_id, + TerminationFlavor::Cancel + )), + contains(delete_timer_matcher(5)), + contains(pat!(Action::Invoke { + invocation_id: eq(invocation_id), + invocation_target: eq(invocation_target) + })) + ) + ); + + Ok(()) + } + + fn completed_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Call { + is_completed: true, + enrichment_result: Some(CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) + } + + fn background_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::OneWayCall { + enrichment_result: CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }, + }, + Bytes::default(), + )) + } + + fn uncompleted_invoke_entry(invocation_id: InvocationId) -> JournalEntry { + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Call { + is_completed: false, + enrichment_result: Some(CallEnrichmentResult { + invocation_id, + invocation_target: InvocationTarget::mock_service(), + completion_retention_time: None, + span_context: ServiceInvocationSpanContext::empty(), + }), + }, + Bytes::default(), + )) + } + + fn create_termination_journal( + call_invocation_id: InvocationId, + background_invocation_id: InvocationId, + finished_call_invocation_id: InvocationId, + ) -> Vec { + vec![ + uncompleted_invoke_entry(call_invocation_id), + completed_invoke_entry(finished_call_invocation_id), + background_invoke_entry(background_invocation_id), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::GetState { + is_completed: false, + }, + Bytes::default(), + )), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Sleep { + is_completed: false, + }, + service_protocol::SleepEntryMessage { + wake_up_time: 1337, + result: None, + ..Default::default() + } + .encode_to_vec() + .into(), + )), + JournalEntry::Entry(EnrichedRawEntry::new( + EnrichedEntryHeader::Awakeable { + is_completed: false, + }, + Bytes::default(), + )), + ] + } + + fn canceled_completion_matcher( + entry_index: EntryIndex, + ) -> impl Matcher { + pat!(Completion { + entry_index: eq(entry_index), + result: pat!(CompletionResult::Failure( + eq(codes::ABORTED), + eq(ByteString::from_static("canceled")) + )) + }) + } + + fn entry_completed_matcher() -> impl Matcher { + predicate(|e: &EnrichedRawEntry| e.header().is_completed().unwrap_or(false)) + .with_description("completed entry", "uncompleted entry") + } + + fn forward_canceled_completion_matcher( + entry_index: EntryIndex, + ) -> impl Matcher { + pat!(Action::ForwardCompletion { + completion: canceled_completion_matcher(entry_index), + }) + } + + fn delete_timer_matcher(entry_index: EntryIndex) -> impl Matcher { + pat!(Action::DeleteTimer { + timer_key: pat!(TimerKey { + kind: pat!(TimerKeyKind::CompleteJournalEntry { + journal_index: eq(entry_index), + }), + timestamp: eq(1337), + }) + }) + } + + fn terminate_invocation_outbox_message_matcher( + target_invocation_id: InvocationId, + termination_flavor: TerminationFlavor, + ) -> impl Matcher { + pat!(Action::NewOutboxMessage { + message: pat!( + restate_storage_api::outbox_table::OutboxMessage::InvocationTermination(pat!( + InvocationTermination { + invocation_id: eq(target_invocation_id), + flavor: eq(termination_flavor) + } + )) + ) + }) + } + } + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn truncate_outbox_from_empty() -> Result<(), Error> { + // An outbox message with index 0 has been successfully processed, and must now be truncated + let outbox_index = 0; + + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope("mock-state-machine", None, MockStateMachine::create()) + .await; + + let _ = state_machine + .apply(Command::TruncateOutbox(outbox_index)) + .await; + + assert_that!( + state_machine + .rocksdb_storage + .get_outbox_message(state_machine.partition_id(), 0) + .await?, + none() + ); + + // The head catches up to the next available sequence number on truncation. Since we don't know + // in advance whether we will get asked to truncate a range of more than one outbox message, we + // explicitly track the head sequence number as the next position beyond the last known + // truncation point. It's only safe to leave the head as None when the outbox is known to be + // empty. + assert_eq!(state_machine.state_machine.outbox_head_seq_number, Some(1)); + + Ok(()) + } + + #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] + async fn truncate_outbox_with_gap() -> Result<(), Error> { + // The outbox contains items [3..=5], and the range must be truncated after message 5 is processed + let outbox_head_index = 3; + let outbox_tail_index = 5; + + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); + let mut state_machine = tc + .run_in_scope( + "mock-state-machine", + None, + MockStateMachine::create_with_state_machine( + StateMachine::::new( + 0, + outbox_tail_index, + Some(outbox_head_index), + PartitionKey::MIN..=PartitionKey::MAX, + SourceTable::New, + ), + ), + ) + .await; + + state_machine + .apply(Command::TruncateOutbox(outbox_tail_index)) + .await; + + assert_that!( + state_machine + .rocksdb_storage + .get_outbox_message(state_machine.partition_id(), 3) + .await?, + none() + ); + assert_that!( + state_machine + .rocksdb_storage + .get_outbox_message(state_machine.partition_id(), 4) + .await?, + none() + ); + assert_that!( + state_machine + .rocksdb_storage + .get_outbox_message(state_machine.partition_id(), 5) + .await?, + none() + ); + + assert_eq!( + state_machine.state_machine.outbox_head_seq_number, + Some(outbox_tail_index + 1) + ); + + Ok(()) + } + async fn mock_start_invocation_with_service_id( state_machine: &mut MockStateMachine, service_id: ServiceId,