From 069a16269c35836171fba5cd8354ff29084f3f52 Mon Sep 17 00:00:00 2001 From: Francesco Guardiani Date: Mon, 18 Nov 2024 14:59:55 +0100 Subject: [PATCH] Simple Json Schemas for Kotlin (#403) --- .../dev/restate/sdk/kotlin/CodegenTest.kt | 4 +- .../kotlin/dev/restate/sdk/kotlin/KtSerdes.kt | 87 ++++++++++++++++++- .../dev/restate/sdk/common/RichSerde.java | 5 +- .../restate/sdk/core/EndpointManifest.java | 27 +++++- .../dev/restate/sdk/core/ServiceProtocol.java | 2 +- 5 files changed, 115 insertions(+), 10 deletions(-) diff --git a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt index 6adec37b..94eef43c 100644 --- a/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt +++ b/sdk-api-kotlin-gen/src/test/kotlin/dev/restate/sdk/kotlin/CodegenTest.kt @@ -57,8 +57,8 @@ class CodegenTest : TestDefinitions.TestSuite { @Exclusive suspend fun complexType( context: ObjectContext, - request: Map> - ): Map> { + request: Map> + ): Map> { return mapOf() } } diff --git a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt index 798bac18..ea94b20d 100644 --- a/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt +++ b/sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt @@ -9,14 +9,26 @@ package dev.restate.sdk.kotlin import dev.restate.sdk.common.DurablePromiseKey +import dev.restate.sdk.common.RichSerde import dev.restate.sdk.common.Serde import dev.restate.sdk.common.StateKey import java.nio.ByteBuffer import java.nio.charset.StandardCharsets import kotlin.reflect.typeOf +import kotlinx.serialization.ExperimentalSerializationApi import kotlinx.serialization.KSerializer +import kotlinx.serialization.Serializable +import kotlinx.serialization.builtins.ListSerializer +import kotlinx.serialization.builtins.serializer +import kotlinx.serialization.descriptors.PrimitiveKind +import kotlinx.serialization.descriptors.SerialDescriptor +import kotlinx.serialization.descriptors.StructureKind +import kotlinx.serialization.encodeToString import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonArray +import kotlinx.serialization.json.JsonElement import kotlinx.serialization.json.JsonNull +import kotlinx.serialization.json.JsonTransformingSerializer import kotlinx.serialization.serializer object KtStateKey { @@ -70,12 +82,13 @@ object KtSerdes { } /** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */ - fun json(serializer: KSerializer): Serde { - return object : Serde { + inline fun json(serializer: KSerializer): Serde { + return object : RichSerde { override fun serialize(value: T?): ByteArray { if (value == null) { return Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray() } + return Json.encodeToString(serializer, value).encodeToByteArray() } @@ -86,6 +99,76 @@ object KtSerdes { override fun contentType(): String { return "application/json" } + + override fun jsonSchema(): String { + val schema: JsonSchema = serializer.descriptor.jsonSchema() + return Json.encodeToString(schema) + } } } + + @Serializable + @PublishedApi + internal data class JsonSchema( + @Serializable(with = StringListSerializer::class) val type: List? = null, + val format: String? = null, + ) { + companion object { + val INT = JsonSchema(type = listOf("number"), format = "int32") + + val LONG = JsonSchema(type = listOf("number"), format = "int64") + + val DOUBLE = JsonSchema(type = listOf("number"), format = "double") + + val FLOAT = JsonSchema(type = listOf("number"), format = "float") + + val STRING = JsonSchema(type = listOf("string")) + + val BOOLEAN = JsonSchema(type = listOf("boolean")) + + val OBJECT = JsonSchema(type = listOf("object")) + + val LIST = JsonSchema(type = listOf("array")) + + val ANY = JsonSchema() + } + } + + object StringListSerializer : + JsonTransformingSerializer>(ListSerializer(String.Companion.serializer())) { + override fun transformSerialize(element: JsonElement): JsonElement { + require(element is JsonArray) + return element.singleOrNull() ?: element + } + } + + /** + * Super simplistic json schema generation. We should replace this with an appropriate library. + */ + @OptIn(ExperimentalSerializationApi::class) + @PublishedApi + internal fun SerialDescriptor.jsonSchema(): JsonSchema { + var schema = + when (this.kind) { + PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN + PrimitiveKind.BYTE -> JsonSchema.INT + PrimitiveKind.CHAR -> JsonSchema.STRING + PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE + PrimitiveKind.FLOAT -> JsonSchema.FLOAT + PrimitiveKind.INT -> JsonSchema.INT + PrimitiveKind.LONG -> JsonSchema.LONG + PrimitiveKind.SHORT -> JsonSchema.INT + PrimitiveKind.STRING -> JsonSchema.STRING + StructureKind.LIST -> JsonSchema.LIST + StructureKind.MAP -> JsonSchema.OBJECT + else -> JsonSchema.ANY + } + + // Add nullability constraint + if (this.isNullable && schema.type != null) { + schema = schema.copy(type = schema.type.plus("null")) + } + + return schema + } } diff --git a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java b/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java index 20e5418f..2e7e0961 100644 --- a/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java +++ b/sdk-common/src/main/java/dev/restate/sdk/common/RichSerde.java @@ -14,12 +14,15 @@ /** * Richer version of {@link Serde} containing schema information. * + *

This API should be considered unstable to implement. + * *

You can create one using {@link #withSchema(Object, Serde)}. */ public interface RichSerde extends Serde { /** - * @return a Draft 2020-12 Json Schema + * @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to + * files. If the schema shouldn't be serialized with Jackson, return a {@link String} */ Object jsonSchema(); diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java index 03f58f29..782e0c6e 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java @@ -10,6 +10,7 @@ import static dev.restate.sdk.core.ServiceProtocol.*; +import com.fasterxml.jackson.core.JsonProcessingException; import dev.restate.sdk.common.HandlerType; import dev.restate.sdk.common.RichSerde; import dev.restate.sdk.common.ServiceType; @@ -108,8 +109,17 @@ private static Input convertHandlerInput(HandlerSpecification spec) { : new Input().withRequired(true).withContentType(acceptContentType); if (spec.getRequestSerde() instanceof RichSerde) { - input.setJsonSchema( - Objects.requireNonNull(((RichSerde) spec.getRequestSerde()).jsonSchema())); + Object jsonSchema = + Objects.requireNonNull(((RichSerde) spec.getRequestSerde()).jsonSchema()); + if (jsonSchema instanceof String) { + // We need to convert it to databind JSON value + try { + jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); + } catch (JsonProcessingException e) { + throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); + } + } + input.setJsonSchema(jsonSchema); } return input; } @@ -123,8 +133,17 @@ private static Output convertHandlerOutput(HandlerSpecification spec) { .withSetContentTypeIfEmpty(false); if (spec.getResponseSerde() instanceof RichSerde) { - output.setJsonSchema( - Objects.requireNonNull(((RichSerde) spec.getResponseSerde()).jsonSchema())); + Object jsonSchema = + Objects.requireNonNull(((RichSerde) spec.getResponseSerde()).jsonSchema()); + if (jsonSchema instanceof String) { + // We need to convert it to databind JSON value + try { + jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema); + } catch (JsonProcessingException e) { + throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e); + } + } + output.setJsonSchema(jsonSchema); } return output; diff --git a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java b/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java index 241aca53..f410a638 100644 --- a/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java +++ b/sdk-core/src/main/java/dev/restate/sdk/core/ServiceProtocol.java @@ -135,7 +135,7 @@ static String serviceDiscoveryProtocolVersionToHeaderValue( "Service discovery protocol version '%s' has no header value", version.getNumber())); } - private static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper(); + static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper(); @JsonFilter("V2FieldsFilter") interface V2Mixin {}