Skip to content

Commit

Permalink
chore: Move event marshal/unmarshal code to smithy-swift (#728)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbelkins authored May 16, 2024
1 parent dbf0fb1 commit 2b46ad1
Show file tree
Hide file tree
Showing 9 changed files with 794 additions and 34 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package software.amazon.smithy.swift.codegen.events

import software.amazon.smithy.codegen.core.CodegenException
import software.amazon.smithy.codegen.core.Symbol
import software.amazon.smithy.model.shapes.MemberShape
import software.amazon.smithy.model.shapes.ShapeType
import software.amazon.smithy.model.shapes.UnionShape
import software.amazon.smithy.model.traits.EventHeaderTrait
import software.amazon.smithy.model.traits.EventPayloadTrait
import software.amazon.smithy.swift.codegen.ClientRuntimeTypes
import software.amazon.smithy.swift.codegen.SwiftDependency
import software.amazon.smithy.swift.codegen.SwiftWriter
import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.NodeInfoUtils
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.WritingClosureUtils
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.requestWireProtocol
import software.amazon.smithy.swift.codegen.integration.serde.readwrite.responseWireProtocol
import software.amazon.smithy.swift.codegen.integration.serde.struct.writerSymbol
import software.amazon.smithy.swift.codegen.model.eventStreamEvents
import software.amazon.smithy.swift.codegen.model.hasTrait

class MessageMarshallableGenerator(
private val ctx: ProtocolGenerator.GenerationContext,
private val payloadContentType: String
) {
internal fun render(streamShape: UnionShape) {
val streamSymbol: Symbol = ctx.symbolProvider.toSymbol(streamShape)
val rootNamespace = ctx.settings.moduleName
val streamMember = Symbol.builder()
.definitionFile("./$rootNamespace/models/${streamSymbol.name}+MessageMarshallable.swift")
.name(streamSymbol.name)
.build()
ctx.delegator.useShapeWriter(streamMember) { writer ->
writer.apply {
addImport(SwiftDependency.CLIENT_RUNTIME.target)
openBlock("extension \$L {", "}", streamSymbol.fullName) {
openBlock(
"static var marshal: \$N<\$N> {", "}",
ClientRuntimeTypes.EventStream.MarshalClosure,
streamSymbol
) {
openBlock("{ (self) in", "}") {
write(
"var headers: [\$N] = [.init(name: \":message-type\", value: .string(\"event\"))]",
ClientRuntimeTypes.EventStream.Header
)
write("var payload: \$D", ClientRuntimeTypes.Core.Data)
write("switch self {")
streamShape.eventStreamEvents(ctx.model).forEach { member ->
val memberName = ctx.symbolProvider.toMemberName(member)
write("case .\$L(let value):", memberName)
indent()
addStringHeader(":event-type", member.memberName)
val variant = ctx.model.expectShape(member.target)
val eventHeaderBindings = variant.members().filter {
it.hasTrait<EventHeaderTrait>()
}
val eventPayloadBinding = variant.members().firstOrNull {
it.hasTrait<EventPayloadTrait>()
}
val unbound = variant.members().filterNot {
it.hasTrait<EventHeaderTrait>() || it.hasTrait<EventPayloadTrait>()
}

eventHeaderBindings.forEach {
renderSerializeEventHeader(ctx, it, writer)
}

when {
eventPayloadBinding != null -> renderSerializeEventPayload(ctx, eventPayloadBinding, writer)
unbound.isNotEmpty() -> {
writer.addStringHeader(":content-type", payloadContentType)
writer.addImport(ctx.service.writerSymbol.namespace)
val nodeInfo = NodeInfoUtils(ctx, writer, ctx.service.requestWireProtocol).nodeInfo(member, true)
writer.write("let writer = \$N(nodeInfo: \$L)", ctx.service.writerSymbol, nodeInfo)
unbound.forEach {
val writingClosure = WritingClosureUtils(ctx, writer).writingClosure(ctx.model.expectShape(it.target))
writer.write(
"try writer[\$S].write(value.\$L, with: \$L)",
it.memberName,
ctx.symbolProvider.toMemberName(it),
writingClosure,
)
}
writer.write("payload = try writer.data()")
}
}
writer.dedent()
}
writer.write("case .sdkUnknown(_):")
writer.indent()
writer.write(
"throw \$N(\"cannot serialize the unknown event type!\")",
ClientRuntimeTypes.Core.UnknownClientError
)
writer.dedent()
writer.write("}")
writer.write(
"return \$N(headers: headers, payload: payload ?? .init())",
ClientRuntimeTypes.EventStream.Message
)
}
}
}
}
}
}

private fun renderSerializeEventPayload(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) {
val target = ctx.model.expectShape(member.target)
val memberName = ctx.symbolProvider.toMemberName(member)
when (target.type) {
ShapeType.BLOB -> {
writer.addStringHeader(":content-type", "application/octet-stream")
writer.write("payload = value.\$L", memberName)
}
ShapeType.STRING -> {
writer.addStringHeader(":content-type", "text/plain")
writer.write("payload = value.\$L?.data(using: .utf8)", memberName)
}
ShapeType.STRUCTURE, ShapeType.UNION -> {
writer.addStringHeader(":content-type", payloadContentType)
renderPayloadSerialization(ctx, writer, member)
}
else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member")
}
}

/**
*
* if let headerValue = value.blob {
* headers.append(.init(name: "blob", value: .byteArray(headerValue)))
* }
* if let headerValue = value.boolean {
* headers.append(.init(name: "boolean", value: .bool(headerValue)))
* }
* if let headerValue = value.byte {
* headers.append(.init(name: "byte", value: .byte(headerValue)))
* }
* if let headerValue = value.int {
* headers.append(.init(name: "int", value: .int32(Int32(headerValue))))
* }
* if let headerValue = value.long {
* headers.append(.init(name: "long", value: .int64(Int64(headerValue))))
* }
* if let headerValue = value.short {
* headers.append(.init(name: "short", value: .int16(headerValue)))
* }
* if let headerValue = value.string {
* headers.append(.init(name: "string", value: .string(headerValue)))
* }
* if let headerValue = value.timestamp {
* headers.append(.init(name: "timestamp", value: .timestamp(headerValue)))
* }
*/
private fun renderSerializeEventHeader(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) {
val target = ctx.model.expectShape(member.target)
val headerValue = when (target.type) {
ShapeType.BOOLEAN -> "bool"
ShapeType.BYTE -> "byte"
ShapeType.SHORT -> "int16"
ShapeType.INTEGER -> "int32"
ShapeType.LONG -> "int64"
ShapeType.BLOB -> "byteArray"
ShapeType.STRING -> "string"
ShapeType.TIMESTAMP -> "timestamp"
else -> throw CodegenException("unsupported shape type `${target.type}` for eventHeader member `$member`; target: $target")
}

val memberName = ctx.symbolProvider.toMemberName(member)
writer.openBlock("if let headerValue = value.\$L {", "}", memberName) {
when (target.type) {
ShapeType.INTEGER -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int32(headerValue))))", headerValue)
}
ShapeType.LONG -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int64(headerValue))))", headerValue)
}
else -> {
writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(headerValue)))", headerValue)
}
}
}
}

private fun SwiftWriter.addStringHeader(name: String, value: String) {
write("headers.append(.init(name: \$S, value: .string(\$S)))", name, value)
}

private fun renderPayloadSerialization(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter, memberShape: MemberShape) {
// get a payload serializer for the given members of the variant
val nodeInfoUtils = NodeInfoUtils(ctx, writer, ctx.service.responseWireProtocol)
val rootNodeInfo = nodeInfoUtils.nodeInfo(memberShape, true)
val valueWritingClosure = WritingClosureUtils(ctx, writer).writingClosure(memberShape)
writer.addImport(ctx.service.writerSymbol.namespace)
writer.write(
"payload = try \$N.write(value.\$L, rootNodeInfo: \$L, with: \$L)",
ctx.service.writerSymbol,
ctx.symbolProvider.toMemberName(memberShape),
rootNodeInfo,
valueWritingClosure,
)
}
}
Loading

0 comments on commit 2b46ad1

Please sign in to comment.