Skip to content

Commit

Permalink
Simple Json Schemas for Kotlin (#403)
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Nov 18, 2024
1 parent fa728dd commit 069a162
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ class CodegenTest : TestDefinitions.TestSuite {
@Exclusive
suspend fun complexType(
context: ObjectContext,
request: Map<Output, List<out Input>>
): Map<Input, List<out Output>> {
request: Map<String, List<out Input>>
): Map<String, List<out Output>> {
return mapOf()
}
}
Expand Down
87 changes: 85 additions & 2 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/KtSerdes.kt
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,26 @@
package dev.restate.sdk.kotlin

import dev.restate.sdk.common.DurablePromiseKey
import dev.restate.sdk.common.RichSerde
import dev.restate.sdk.common.Serde
import dev.restate.sdk.common.StateKey
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets
import kotlin.reflect.typeOf
import kotlinx.serialization.ExperimentalSerializationApi
import kotlinx.serialization.KSerializer
import kotlinx.serialization.Serializable
import kotlinx.serialization.builtins.ListSerializer
import kotlinx.serialization.builtins.serializer
import kotlinx.serialization.descriptors.PrimitiveKind
import kotlinx.serialization.descriptors.SerialDescriptor
import kotlinx.serialization.descriptors.StructureKind
import kotlinx.serialization.encodeToString
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.JsonArray
import kotlinx.serialization.json.JsonElement
import kotlinx.serialization.json.JsonNull
import kotlinx.serialization.json.JsonTransformingSerializer
import kotlinx.serialization.serializer

object KtStateKey {
Expand Down Expand Up @@ -70,12 +82,13 @@ object KtSerdes {
}

/** Creates a [Serde] implementation using the `kotlinx.serialization` json module. */
fun <T : Any?> json(serializer: KSerializer<T>): Serde<T> {
return object : Serde<T> {
inline fun <reified T : Any?> json(serializer: KSerializer<T>): Serde<T> {
return object : RichSerde<T> {
override fun serialize(value: T?): ByteArray {
if (value == null) {
return Json.encodeToString(JsonNull.serializer(), JsonNull).encodeToByteArray()
}

return Json.encodeToString(serializer, value).encodeToByteArray()
}

Expand All @@ -86,6 +99,76 @@ object KtSerdes {
override fun contentType(): String {
return "application/json"
}

override fun jsonSchema(): String {
val schema: JsonSchema = serializer.descriptor.jsonSchema()
return Json.encodeToString(schema)
}
}
}

@Serializable
@PublishedApi
internal data class JsonSchema(
@Serializable(with = StringListSerializer::class) val type: List<String>? = null,
val format: String? = null,
) {
companion object {
val INT = JsonSchema(type = listOf("number"), format = "int32")

val LONG = JsonSchema(type = listOf("number"), format = "int64")

val DOUBLE = JsonSchema(type = listOf("number"), format = "double")

val FLOAT = JsonSchema(type = listOf("number"), format = "float")

val STRING = JsonSchema(type = listOf("string"))

val BOOLEAN = JsonSchema(type = listOf("boolean"))

val OBJECT = JsonSchema(type = listOf("object"))

val LIST = JsonSchema(type = listOf("array"))

val ANY = JsonSchema()
}
}

object StringListSerializer :
JsonTransformingSerializer<List<String>>(ListSerializer(String.Companion.serializer())) {
override fun transformSerialize(element: JsonElement): JsonElement {
require(element is JsonArray)
return element.singleOrNull() ?: element
}
}

/**
* Super simplistic json schema generation. We should replace this with an appropriate library.
*/
@OptIn(ExperimentalSerializationApi::class)
@PublishedApi
internal fun SerialDescriptor.jsonSchema(): JsonSchema {
var schema =
when (this.kind) {
PrimitiveKind.BOOLEAN -> JsonSchema.BOOLEAN
PrimitiveKind.BYTE -> JsonSchema.INT
PrimitiveKind.CHAR -> JsonSchema.STRING
PrimitiveKind.DOUBLE -> JsonSchema.DOUBLE
PrimitiveKind.FLOAT -> JsonSchema.FLOAT
PrimitiveKind.INT -> JsonSchema.INT
PrimitiveKind.LONG -> JsonSchema.LONG
PrimitiveKind.SHORT -> JsonSchema.INT
PrimitiveKind.STRING -> JsonSchema.STRING
StructureKind.LIST -> JsonSchema.LIST
StructureKind.MAP -> JsonSchema.OBJECT
else -> JsonSchema.ANY
}

// Add nullability constraint
if (this.isNullable && schema.type != null) {
schema = schema.copy(type = schema.type.plus("null"))
}

return schema
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@
/**
* Richer version of {@link Serde} containing schema information.
*
* <p>This API should be considered unstable to implement.
*
* <p>You can create one using {@link #withSchema(Object, Serde)}.
*/
public interface RichSerde<T extends @Nullable Object> extends Serde<T> {

/**
* @return a Draft 2020-12 Json Schema
* @return a Draft 2020-12 Json Schema. It should be self-contained, and MUST not contain refs to
* files. If the schema shouldn't be serialized with Jackson, return a {@link String}
*/
Object jsonSchema();

Expand Down
27 changes: 23 additions & 4 deletions sdk-core/src/main/java/dev/restate/sdk/core/EndpointManifest.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import static dev.restate.sdk.core.ServiceProtocol.*;

import com.fasterxml.jackson.core.JsonProcessingException;
import dev.restate.sdk.common.HandlerType;
import dev.restate.sdk.common.RichSerde;
import dev.restate.sdk.common.ServiceType;
Expand Down Expand Up @@ -108,8 +109,17 @@ private static Input convertHandlerInput(HandlerSpecification<?, ?> spec) {
: new Input().withRequired(true).withContentType(acceptContentType);

if (spec.getRequestSerde() instanceof RichSerde) {
input.setJsonSchema(
Objects.requireNonNull(((RichSerde<?>) spec.getRequestSerde()).jsonSchema()));
Object jsonSchema =
Objects.requireNonNull(((RichSerde<?>) spec.getRequestSerde()).jsonSchema());
if (jsonSchema instanceof String) {
// We need to convert it to databind JSON value
try {
jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema);
} catch (JsonProcessingException e) {
throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e);
}
}
input.setJsonSchema(jsonSchema);
}
return input;
}
Expand All @@ -123,8 +133,17 @@ private static Output convertHandlerOutput(HandlerSpecification<?, ?> spec) {
.withSetContentTypeIfEmpty(false);

if (spec.getResponseSerde() instanceof RichSerde) {
output.setJsonSchema(
Objects.requireNonNull(((RichSerde<?>) spec.getResponseSerde()).jsonSchema()));
Object jsonSchema =
Objects.requireNonNull(((RichSerde<?>) spec.getResponseSerde()).jsonSchema());
if (jsonSchema instanceof String) {
// We need to convert it to databind JSON value
try {
jsonSchema = MANIFEST_OBJECT_MAPPER.readTree((String) jsonSchema);
} catch (JsonProcessingException e) {
throw new RuntimeException("The schema generated by RichSerde is not a valid JSON", e);
}
}
output.setJsonSchema(jsonSchema);
}

return output;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ static String serviceDiscoveryProtocolVersionToHeaderValue(
"Service discovery protocol version '%s' has no header value", version.getNumber()));
}

private static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper();
static final ObjectMapper MANIFEST_OBJECT_MAPPER = new ObjectMapper();

@JsonFilter("V2FieldsFilter")
interface V2Mixin {}
Expand Down

0 comments on commit 069a162

Please sign in to comment.