Skip to content

Commit

Permalink
Restrict key type for key extraction to string (#968)
Browse files Browse the repository at this point in the history
* Forbid registration with non-string type keys
* Make sure we generate a string version of the uuid for unkeyed service key
* Remove the service_key poli views
* Remove the service key udfs
* Update protos as in restatedev/proto#31
* Propagate dev.event.Event changes.
* Remove length delimiter 
* Remove varint encoded length at the beginning of the key
  • Loading branch information
slinkydeveloper authored Dec 12, 2023
1 parent dc4a3df commit 5c45df3
Show file tree
Hide file tree
Showing 34 changed files with 127 additions and 990 deletions.
7 changes: 2 additions & 5 deletions crates/errors/src/error_codes/META0002.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@

Bad key definition encountered while registering/updating a service.
When a service is keyed, for each method the input message must have a field annotated with `dev.restate.ext.field`.
When defining the key field, make sure:

* The field type is either a primitive or a custom message, and not a repeated field nor a map.
* The field type is the same for every method input message of the same service.
The key field type must be `string`.

Example:

Expand All @@ -17,6 +14,6 @@ service HelloWorld {
}
message GreetingRequest {
Person person = 1 [(dev.restate.ext.field) = KEY];
string person_id = 1 [(dev.restate.ext.field) = KEY];
}
```
6 changes: 3 additions & 3 deletions crates/ingress-dispatcher/src/event_remapping.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ use std::fmt;
#[derive(Debug, thiserror::Error)]
#[error("Field {field_name} cannot be mapped to field tag {tag} because it's not a valid UTF-8 string: {reason}")]
pub struct Error {
field_name: &'static str,
tag: u32,
pub(crate) field_name: &'static str,
pub(crate) tag: u32,
#[source]
reason: core::str::Utf8Error,
pub(crate) reason: core::str::Utf8Error,
}

/// Structure that implements the remapping of the event fields.
Expand Down
26 changes: 10 additions & 16 deletions crates/ingress-dispatcher/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use bytes::{Buf, BufMut, Bytes, BytesMut};
use bytes::Bytes;
use bytestring::ByteString;
use prost::Message;
use restate_pb::restate::Event;
Expand Down Expand Up @@ -188,13 +188,19 @@ impl IngressRequest {
match instance_type {
EventReceiverServiceInstanceType::Keyed {
ordering_key_is_key,
} => generate_restate_key(if *ordering_key_is_key {
} => Bytes::from(if *ordering_key_is_key {
event.ordering_key.clone()
} else {
event.key.clone()
std::str::from_utf8(&event.key)
.map_err(|e| EventError {
field_name: "key",
tag: 2,
reason: e,
})?
.to_owned()
}),
EventReceiverServiceInstanceType::Unkeyed => {
Bytes::copy_from_slice(InvocationUuid::now_v7().as_bytes())
Bytes::from(InvocationUuid::now_v7().to_string())
}
EventReceiverServiceInstanceType::Singleton => Bytes::new(),
},
Expand Down Expand Up @@ -254,18 +260,6 @@ impl IngressRequest {
}
}

fn generate_restate_key(key: impl Buf) -> Bytes {
// Because this needs to be a valid Restate key, we need to prepend it with its length to make it
// look like it was extracted using the RestateKeyExtractor
// This is done to ensure all the other operations on the key will work correctly (e.g. key to json)
let key_len = key.remaining();
let mut buf =
BytesMut::with_capacity(prost::encoding::encoded_len_varint(key_len as u64) + key_len);
prost::encoding::encode_varint(key_len as u64, &mut buf);
buf.put(key);
buf.freeze()
}

// -- Types used by the network to interact with the ingress dispatcher service

pub type IngressDispatcherInputReceiver = mpsc::Receiver<IngressDispatcherInput>;
Expand Down
1 change: 1 addition & 0 deletions crates/ingress-kafka/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ restate-schema-api = { workspace = true, features = ["subscription"] }
restate-timer-queue = { workspace = true }
restate-types = { workspace = true }

base64 = { workspace = true }
bytes = { workspace = true }
derive_builder = { workspace = true }
drain = { workspace = true }
Expand Down
29 changes: 16 additions & 13 deletions crates/ingress-kafka/src/consumer_task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0.

use bytes::{BufMut, Bytes, BytesMut};
use base64::Engine;
use bytes::Bytes;
use opentelemetry_api::trace::TraceContextExt;
use rdkafka::consumer::{Consumer, DefaultConsumerContext, StreamConsumer};
use rdkafka::error::KafkaError;
Expand Down Expand Up @@ -161,20 +162,22 @@ impl MessageSender {
ordering_key_format: &KafkaOrderingKeyFormat,
ordering_key_prefix: &str,
msg: &impl Message,
) -> Bytes {
let mut buf = BytesMut::new();
buf.put(ordering_key_prefix.as_bytes());
buf.put(msg.topic().as_bytes());
buf.put_i32(msg.partition());
) -> String {
let partition = msg.partition().to_string();

match ordering_key_format {
KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey if msg.key().is_some() => {
buf.put(msg.key().unwrap())
}
_ => {}
};
let mut buf =
String::with_capacity(ordering_key_prefix.len() + msg.topic().len() + partition.len());
buf.push_str(ordering_key_prefix);
buf.push_str(msg.topic());
buf.push_str(&partition);

if let (KafkaOrderingKeyFormat::ConsumerGroupTopicPartitionKey, Some(key)) =
(ordering_key_format, msg.key())
{
buf.push_str(&base64::prelude::BASE64_STANDARD.encode(key));
}

buf.freeze()
buf
}

fn generate_events_attributes(
Expand Down
22 changes: 1 addition & 21 deletions crates/pb/proto/dev/restate/events.proto
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,16 @@ syntax = "proto3";
package dev.restate;

import "dev/restate/ext.proto";
import "google/protobuf/empty.proto";
import "google/protobuf/struct.proto";

option java_multiple_files = true;
option java_package = "dev.restate.generated";
option go_package = "restate.dev/sdk-go/pb";

message Event {
bytes ordering_key = 1 [(dev.restate.ext.field) = KEY];
string ordering_key = 1 [(dev.restate.ext.field) = KEY];

bytes key = 2;
bytes payload = 3;

map<string, string> attributes = 15;
}

message KeyedEvent {
option deprecated = true;

// Payload
bytes key = 1 [(dev.restate.ext.field) = KEY];
bytes payload = 2 [(dev.restate.ext.field) = EVENT_PAYLOAD];
map<string, string> attributes = 15 [(dev.restate.ext.field) = EVENT_METADATA];
}

message StringKeyedEvent {
option deprecated = true;

// Payload
string key = 1 [(dev.restate.ext.field) = KEY];
bytes payload = 2 [(dev.restate.ext.field) = EVENT_PAYLOAD];
map<string, string> attributes = 15 [(dev.restate.ext.field) = EVENT_METADATA];
}
6 changes: 6 additions & 0 deletions crates/pb/proto/dev/restate/ext.proto
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,14 @@ enum ServiceType {

enum FieldType {
// protolint:disable:next ENUM_FIELD_NAMES_ZERO_VALUE_END_WITH
// Note: only string fields can be used for service key fields
KEY = 0;

// Flag a field as event payload. When receiving events, this field will be filled with the event payload.
// Note: only string fields can be used for event payload fields
EVENT_PAYLOAD = 1;
// Flag a field as event metadata. When receiving events, this field will be filled with the event metadata.
// Note: only type map<string, string> can be used for event payload fields
EVENT_METADATA = 2;
}

Expand Down
8 changes: 0 additions & 8 deletions crates/pb/tests/proto/event_handler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,3 @@ package eventhandler;
service EventHandler {
rpc Handle(dev.restate.Event) returns (google.protobuf.Empty);
}

service KeyedEventHandler {
rpc Handle(dev.restate.KeyedEvent) returns (google.protobuf.Empty);
}

service StringKeyedEventHandler {
rpc Handle(dev.restate.StringKeyedEvent) returns (google.protobuf.Empty);
}
2 changes: 2 additions & 0 deletions crates/schema-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,8 @@ pub mod key {
UnexpectedServiceInstanceType,
#[error("unexpected value for a singleton service. Singleton service have no service key associated")]
UnexpectedNonNullSingletonKey,
#[error("bad unkeyed service key. Expected a string")]
BadUnkeyedKey,
#[error("error when decoding the json key: {0}")]
DecodeJson(#[from] serde_json::Error),
}
Expand Down
66 changes: 9 additions & 57 deletions crates/schema-impl/src/json_key_conversion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,9 @@ use bytes::Bytes;
use prost::Message;
use prost_reflect::{DynamicMessage, MethodDescriptor};
use restate_schema_api::key::json_conversion::{Error, RestateKeyConverter};
use restate_serde_util::SerdeableUuid;
use serde::de::IntoDeserializer;
use serde::{Deserialize, Serialize};
use serde::Serialize;
use serde_json::{Map, Value};
use uuid::Uuid;

impl RestateKeyConverter for Schemas {
fn key_to_json(
Expand Down Expand Up @@ -89,10 +87,7 @@ fn key_to_json(
})
}
InstanceTypeMetadata::Unkeyed => Ok(Value::String(
uuid::Builder::from_slice(key.as_ref())
.unwrap()
.into_uuid()
.to_string(),
String::from_utf8(key.as_ref().to_vec()).expect("Must be a valid UTF-8 string"),
)),
InstanceTypeMetadata::Singleton => Ok(Value::Object(Map::new())),
InstanceTypeMetadata::Unsupported => Err(Error::NotFound),
Expand Down Expand Up @@ -135,9 +130,11 @@ fn json_to_key(
)?)
}
InstanceTypeMetadata::Unkeyed => {
let parse_result: Uuid = SerdeableUuid::deserialize(key.into_deserializer())?.into();

Ok(parse_result.as_bytes().to_vec().into())
return if let Some(key_str) = key.as_str() {
Ok(Bytes::copy_from_slice(key_str.as_bytes()))
} else {
Err(Error::BadUnkeyedKey)
}
}
InstanceTypeMetadata::Singleton if key.is_null() => Ok(Bytes::default()),
InstanceTypeMetadata::Singleton => Err(Error::UnexpectedNonNullSingletonKey),
Expand All @@ -156,7 +153,7 @@ mod tests {
use restate_pb::mocks::test::*;
use restate_schema_api::discovery::KeyStructure;
use serde::Serialize;
use std::collections::{BTreeMap, HashMap};
use std::collections::HashMap;
use uuid::Uuid;

static METHOD_NAME: &str = "Test";
Expand Down Expand Up @@ -189,18 +186,6 @@ mod tests {
}
}

fn nested_key_structure() -> KeyStructure {
KeyStructure::Nested(BTreeMap::from([
(1, KeyStructure::Scalar),
(2, KeyStructure::Scalar),
(3, KeyStructure::Scalar),
(
4,
KeyStructure::Nested(BTreeMap::from([(1, KeyStructure::Scalar)])),
),
]))
}

fn mock_keyed_service_instance_type(
key_structure: KeyStructure,
field_number: u32,
Expand Down Expand Up @@ -328,36 +313,6 @@ mod tests {
}

json_tests!(string);
json_tests!(bytes);
json_tests!(number);
json_tests!(nested_message, nested_key_structure());
json_tests!(
test: nested_message_with_default,
field_name: nested_message,
key_structure: nested_key_structure(),
test_message: TestMessage {
nested_message: Some(NestedKey {
b: "b".to_string(),
..Default::default()
}),
..Default::default()
}
);
json_tests!(
test: double_nested_message,
field_name: nested_message,
key_structure: nested_key_structure(),
test_message: TestMessage {
nested_message: Some(NestedKey {
b: "b".to_string(),
other: Some(OtherMessage {
d: "d".to_string()
}),
..Default::default()
}),
..Default::default()
}
);

#[test]
fn unkeyed_convert_key_to_json() {
Expand Down Expand Up @@ -411,14 +366,11 @@ mod tests {
let expected_restate_key = extract(&service_instance_type, METHOD_NAME, Bytes::new())
.expect("successful key extraction");

// Parse this as uuid
let uuid = Uuid::from_slice(&expected_restate_key).unwrap();

// Now convert the key to json
let actual_restate_key = json_to_key(
&service_instance_type,
test_method_descriptor(),
Value::String(uuid.as_simple().to_string()),
Value::String(String::from_utf8(expected_restate_key.to_vec()).unwrap()),
)
.unwrap();

Expand Down
Loading

0 comments on commit 5c45df3

Please sign in to comment.