From 2b46ad173b05bfb54235da3444577664dbc00e32 Mon Sep 17 00:00:00 2001 From: Josh Elkins Date: Thu, 16 May 2024 11:23:28 -0500 Subject: [PATCH] chore: Move event marshal/unmarshal code to smithy-swift (#728) --- .../events/MessageMarshallableGenerator.kt | 204 ++++++++++++++ .../events/MessageUnmarshallableGenerator.kt | 210 +++++++++++++++ .../HTTPBindingProtocolGenerator.kt | 48 +++- .../src/test/kotlin/EventStreamTests.kt | 249 ++++++++++++++++++ .../MockHTTPAWSJson11ProtocolGenerator.kt | 8 - .../MockHTTPEC2QueryProtocolGenerator.kt | 8 - .../MockHTTPRestJsonProtocolGenerator.kt | 8 - .../mocks/MockHTTPRestXMLProtocolGenerator.kt | 8 - .../src/test/resources/eventstream.smithy | 85 ++++++ 9 files changed, 794 insertions(+), 34 deletions(-) create mode 100644 smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageMarshallableGenerator.kt create mode 100644 smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageUnmarshallableGenerator.kt create mode 100644 smithy-swift-codegen/src/test/kotlin/EventStreamTests.kt create mode 100644 smithy-swift-codegen/src/test/resources/eventstream.smithy diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageMarshallableGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageMarshallableGenerator.kt new file mode 100644 index 000000000..74821f4f5 --- /dev/null +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageMarshallableGenerator.kt @@ -0,0 +1,204 @@ +package software.amazon.smithy.swift.codegen.events + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EventHeaderTrait +import software.amazon.smithy.model.traits.EventPayloadTrait +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftDependency +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator +import software.amazon.smithy.swift.codegen.integration.serde.readwrite.NodeInfoUtils +import software.amazon.smithy.swift.codegen.integration.serde.readwrite.WritingClosureUtils +import software.amazon.smithy.swift.codegen.integration.serde.readwrite.requestWireProtocol +import software.amazon.smithy.swift.codegen.integration.serde.readwrite.responseWireProtocol +import software.amazon.smithy.swift.codegen.integration.serde.struct.writerSymbol +import software.amazon.smithy.swift.codegen.model.eventStreamEvents +import software.amazon.smithy.swift.codegen.model.hasTrait + +class MessageMarshallableGenerator( + private val ctx: ProtocolGenerator.GenerationContext, + private val payloadContentType: String +) { + internal fun render(streamShape: UnionShape) { + val streamSymbol: Symbol = ctx.symbolProvider.toSymbol(streamShape) + val rootNamespace = ctx.settings.moduleName + val streamMember = Symbol.builder() + .definitionFile("./$rootNamespace/models/${streamSymbol.name}+MessageMarshallable.swift") + .name(streamSymbol.name) + .build() + ctx.delegator.useShapeWriter(streamMember) { writer -> + writer.apply { + addImport(SwiftDependency.CLIENT_RUNTIME.target) + openBlock("extension \$L {", "}", streamSymbol.fullName) { + openBlock( + "static var marshal: \$N<\$N> {", "}", + ClientRuntimeTypes.EventStream.MarshalClosure, + streamSymbol + ) { + openBlock("{ (self) in", "}") { + write( + "var headers: [\$N] = [.init(name: \":message-type\", value: .string(\"event\"))]", + ClientRuntimeTypes.EventStream.Header + ) + write("var payload: \$D", ClientRuntimeTypes.Core.Data) + write("switch self {") + streamShape.eventStreamEvents(ctx.model).forEach { member -> + val memberName = ctx.symbolProvider.toMemberName(member) + write("case .\$L(let value):", memberName) + indent() + addStringHeader(":event-type", member.memberName) + val variant = ctx.model.expectShape(member.target) + val eventHeaderBindings = variant.members().filter { + it.hasTrait() + } + val eventPayloadBinding = variant.members().firstOrNull { + it.hasTrait() + } + val unbound = variant.members().filterNot { + it.hasTrait() || it.hasTrait() + } + + eventHeaderBindings.forEach { + renderSerializeEventHeader(ctx, it, writer) + } + + when { + eventPayloadBinding != null -> renderSerializeEventPayload(ctx, eventPayloadBinding, writer) + unbound.isNotEmpty() -> { + writer.addStringHeader(":content-type", payloadContentType) + writer.addImport(ctx.service.writerSymbol.namespace) + val nodeInfo = NodeInfoUtils(ctx, writer, ctx.service.requestWireProtocol).nodeInfo(member, true) + writer.write("let writer = \$N(nodeInfo: \$L)", ctx.service.writerSymbol, nodeInfo) + unbound.forEach { + val writingClosure = WritingClosureUtils(ctx, writer).writingClosure(ctx.model.expectShape(it.target)) + writer.write( + "try writer[\$S].write(value.\$L, with: \$L)", + it.memberName, + ctx.symbolProvider.toMemberName(it), + writingClosure, + ) + } + writer.write("payload = try writer.data()") + } + } + writer.dedent() + } + writer.write("case .sdkUnknown(_):") + writer.indent() + writer.write( + "throw \$N(\"cannot serialize the unknown event type!\")", + ClientRuntimeTypes.Core.UnknownClientError + ) + writer.dedent() + writer.write("}") + writer.write( + "return \$N(headers: headers, payload: payload ?? .init())", + ClientRuntimeTypes.EventStream.Message + ) + } + } + } + } + } + } + + private fun renderSerializeEventPayload(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) { + val target = ctx.model.expectShape(member.target) + val memberName = ctx.symbolProvider.toMemberName(member) + when (target.type) { + ShapeType.BLOB -> { + writer.addStringHeader(":content-type", "application/octet-stream") + writer.write("payload = value.\$L", memberName) + } + ShapeType.STRING -> { + writer.addStringHeader(":content-type", "text/plain") + writer.write("payload = value.\$L?.data(using: .utf8)", memberName) + } + ShapeType.STRUCTURE, ShapeType.UNION -> { + writer.addStringHeader(":content-type", payloadContentType) + renderPayloadSerialization(ctx, writer, member) + } + else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member") + } + } + + /** + * + * if let headerValue = value.blob { + * headers.append(.init(name: "blob", value: .byteArray(headerValue))) + * } + * if let headerValue = value.boolean { + * headers.append(.init(name: "boolean", value: .bool(headerValue))) + * } + * if let headerValue = value.byte { + * headers.append(.init(name: "byte", value: .byte(headerValue))) + * } + * if let headerValue = value.int { + * headers.append(.init(name: "int", value: .int32(Int32(headerValue)))) + * } + * if let headerValue = value.long { + * headers.append(.init(name: "long", value: .int64(Int64(headerValue)))) + * } + * if let headerValue = value.short { + * headers.append(.init(name: "short", value: .int16(headerValue))) + * } + * if let headerValue = value.string { + * headers.append(.init(name: "string", value: .string(headerValue))) + * } + * if let headerValue = value.timestamp { + * headers.append(.init(name: "timestamp", value: .timestamp(headerValue))) + * } + */ + private fun renderSerializeEventHeader(ctx: ProtocolGenerator.GenerationContext, member: MemberShape, writer: SwiftWriter) { + val target = ctx.model.expectShape(member.target) + val headerValue = when (target.type) { + ShapeType.BOOLEAN -> "bool" + ShapeType.BYTE -> "byte" + ShapeType.SHORT -> "int16" + ShapeType.INTEGER -> "int32" + ShapeType.LONG -> "int64" + ShapeType.BLOB -> "byteArray" + ShapeType.STRING -> "string" + ShapeType.TIMESTAMP -> "timestamp" + else -> throw CodegenException("unsupported shape type `${target.type}` for eventHeader member `$member`; target: $target") + } + + val memberName = ctx.symbolProvider.toMemberName(member) + writer.openBlock("if let headerValue = value.\$L {", "}", memberName) { + when (target.type) { + ShapeType.INTEGER -> { + writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int32(headerValue))))", headerValue) + } + ShapeType.LONG -> { + writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(Int64(headerValue))))", headerValue) + } + else -> { + writer.write("headers.append(.init(name: \"${member.memberName}\", value: .\$L(headerValue)))", headerValue) + } + } + } + } + + private fun SwiftWriter.addStringHeader(name: String, value: String) { + write("headers.append(.init(name: \$S, value: .string(\$S)))", name, value) + } + + private fun renderPayloadSerialization(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter, memberShape: MemberShape) { + // get a payload serializer for the given members of the variant + val nodeInfoUtils = NodeInfoUtils(ctx, writer, ctx.service.responseWireProtocol) + val rootNodeInfo = nodeInfoUtils.nodeInfo(memberShape, true) + val valueWritingClosure = WritingClosureUtils(ctx, writer).writingClosure(memberShape) + writer.addImport(ctx.service.writerSymbol.namespace) + writer.write( + "payload = try \$N.write(value.\$L, rootNodeInfo: \$L, with: \$L)", + ctx.service.writerSymbol, + ctx.symbolProvider.toMemberName(memberShape), + rootNodeInfo, + valueWritingClosure, + ) + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageUnmarshallableGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageUnmarshallableGenerator.kt new file mode 100644 index 000000000..479d86da3 --- /dev/null +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/events/MessageUnmarshallableGenerator.kt @@ -0,0 +1,210 @@ +package software.amazon.smithy.swift.codegen.events + +import software.amazon.smithy.codegen.core.CodegenException +import software.amazon.smithy.codegen.core.Symbol +import software.amazon.smithy.model.shapes.MemberShape +import software.amazon.smithy.model.shapes.ShapeType +import software.amazon.smithy.model.shapes.UnionShape +import software.amazon.smithy.model.traits.EventHeaderTrait +import software.amazon.smithy.model.traits.EventPayloadTrait +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftDependency +import software.amazon.smithy.swift.codegen.SwiftTypes +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.HTTPProtocolCustomizable +import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator +import software.amazon.smithy.swift.codegen.integration.serde.readwrite.ReadingClosureUtils +import software.amazon.smithy.swift.codegen.integration.serde.struct.readerSymbol +import software.amazon.smithy.swift.codegen.model.eventStreamErrors +import software.amazon.smithy.swift.codegen.model.eventStreamEvents +import software.amazon.smithy.swift.codegen.model.expectShape +import software.amazon.smithy.swift.codegen.model.hasTrait + +class MessageUnmarshallableGenerator( + val ctx: ProtocolGenerator.GenerationContext, + val customizations: HTTPProtocolCustomizable, +) { + fun render( + streamingMember: MemberShape + ) { + val symbol: Symbol = ctx.symbolProvider.toSymbol(ctx.model.expectShape(streamingMember.target)) + val rootNamespace = ctx.settings.moduleName + val streamMember = Symbol.builder() + .definitionFile("./$rootNamespace/models/${symbol.name}+MessageUnmarshallable.swift") + .name(symbol.name) + .build() + + val streamShape = ctx.model.expectShape(streamingMember.target) + val streamSymbol = ctx.symbolProvider.toSymbol(streamShape) + + ctx.delegator.useShapeWriter(streamMember) { writer -> + + writer.addImport(SwiftDependency.CLIENT_RUNTIME.target) + writer.addImport(customizations.unknownServiceErrorSymbol.namespace) + writer.openBlock("extension \$L {", "}", streamSymbol.fullName) { + writer.openBlock( + "static var unmarshal: \$N<\$N> {", "}", + ClientRuntimeTypes.EventStream.UnmarshalClosure, + streamSymbol, + ) { + writer.openBlock("{ message in", "}") { + writer.write("switch try message.type() {") + writer.write("case .event(let params):") + writer.indent { + writer.write("switch params.eventType {") + streamShape.eventStreamEvents(ctx.model).forEach { member -> + writer.write("case \"${member.memberName}\":") + writer.indent { + renderDeserializeEventVariant(ctx, streamSymbol, member, writer) + } + } + writer.write("default:") + writer.indent { + writer.write("return .sdkUnknown(\"error processing event stream, unrecognized event: \\(params.eventType)\")") + } + writer.write("}") + } + writer.write("case .exception(let params):") + writer.indent { + writer.write( + "let makeError: (\$N, \$N) throws -> \$N = { message, params in", + ClientRuntimeTypes.EventStream.Message, + ClientRuntimeTypes.EventStream.ExceptionParams, + SwiftTypes.Error + ) + writer.indent { + writer.write("switch params.exceptionType {") + streamShape.eventStreamErrors(ctx.model).forEach { member -> + writer.write("case \$S:", member.memberName) + writer.indent { + renderReadToValue(writer, member) + writer.write("return value") + } + } + writer.write("default:") + writer.indent { + writer.write("let httpResponse = HttpResponse(body: .data(message.payload), statusCode: .ok)") + writer.write( + "return \$L(httpResponse: httpResponse, message: \"error processing event stream, unrecognized ':exceptionType': \\(params.exceptionType); contentType: \\(params.contentType ?? \"nil\")\", requestID: nil, typeName: nil)", + customizations.unknownServiceErrorSymbol, + ) + } + writer.write("}") + } + writer.write("}") + writer.write("let error = try makeError(message, params)") + writer.write("throw error") + } + writer.write("case .error(let params):") + writer.indent { + // this is a service exception still, just un-modeled + writer.write("let httpResponse = HttpResponse(body: .data(message.payload), statusCode: .ok)") + writer.write( + "throw \$L(httpResponse: httpResponse, message: \"error processing event stream, unrecognized ':errorType': \\(params.errorCode); message: \\(params.message ?? \"nil\")\", requestID: nil, typeName: nil)", + customizations.unknownServiceErrorSymbol, + ) + } + writer.write("case .unknown(messageType: let messageType):") + writer.indent { + // this is a client exception because we failed to parse it + writer.write( + "throw \$L(\"unrecognized event stream message ':message-type': \\(messageType)\")", + ClientRuntimeTypes.Core.UnknownClientError + ) + } + writer.write("}") + } + } + } + } + } + + private fun renderDeserializeEventVariant(ctx: ProtocolGenerator.GenerationContext, unionSymbol: Symbol, member: MemberShape, writer: SwiftWriter) { + val variant = ctx.model.expectShape(member.target) + + val eventHeaderBindings = variant.members().filter { it.hasTrait() } + val eventPayloadBinding = variant.members().firstOrNull { it.hasTrait() } + val unbound = variant.members().filterNot { it.hasTrait() || it.hasTrait() } + val memberName = ctx.symbolProvider.toMemberName(member) + + if (eventHeaderBindings.isEmpty() && eventPayloadBinding == null) { + renderReadToValue(writer, member) + writer.write("return .\$L(value)", memberName) + } else { + val variantSymbol = ctx.symbolProvider.toSymbol(variant) + writer.write("var event = \$N()", variantSymbol) + // render members bound to header + eventHeaderBindings.forEach { hdrBinding -> + val target = ctx.model.expectShape(hdrBinding.target) + + val conversionFn = when (target.type) { + ShapeType.BOOLEAN -> "bool" + ShapeType.BYTE -> "byte" + ShapeType.SHORT -> "int16" + ShapeType.INTEGER -> "int32" + ShapeType.LONG -> "int64" + ShapeType.BLOB -> "byteArray" + ShapeType.STRING -> "string" + ShapeType.TIMESTAMP -> "timestamp" + else -> throw CodegenException("unsupported eventHeader shape: member=$hdrBinding; targetShape=$target") + } + + writer.openBlock("if case .\$L(let value) = message.headers.value(name: \$S) {", "}", conversionFn, hdrBinding.memberName) { + val memberName = ctx.symbolProvider.toMemberName(hdrBinding) + when (target.type) { + ShapeType.INTEGER, ShapeType.LONG -> { + writer.write("event.\$L = Int(value)", memberName) + } + else -> { + writer.write("event.\$L = value", memberName) + } + } + } + } + + if (eventPayloadBinding != null) { + renderDeserializeExplicitEventPayloadMember(ctx, eventPayloadBinding, writer) + } else { + if (unbound.isNotEmpty()) { + // all remaining members are bound to payload (but not explicitly bound via @eventPayload) + // generate a payload deserializer specific to the unbound members (note this will be a deserializer + // for the overall event shape but only payload members will be considered for deserialization), + // and then assign each deserialized payload member to the current builder instance + unbound.forEach { + renderReadToValue(writer, it) + writer.write("event.\$L = value", ctx.symbolProvider.toMemberName(it)) + } + } + } + writer.write("return .\$L(event)", memberName) + } + } + + private fun renderDeserializeExplicitEventPayloadMember( + ctx: ProtocolGenerator.GenerationContext, + member: MemberShape, + writer: SwiftWriter, + ) { + val target = ctx.model.expectShape(member.target) + val memberName = ctx.symbolProvider.toMemberName(member) + when (target.type) { + ShapeType.BLOB -> writer.write("event.\$L = message.payload", memberName) + ShapeType.STRING -> writer.write("event.\$L = String(data: message.payload, encoding: .utf8)", memberName) + ShapeType.STRUCTURE, ShapeType.UNION -> { + renderReadToValue(writer, member) + writer.write("event.\$L = value", ctx.symbolProvider.toMemberName(member)) + } + else -> throw CodegenException("unsupported shape type `${target.type}` for target: $target; expected blob, string, structure, or union for eventPayload member: $member") + } + } + + private fun renderReadToValue(writer: SwiftWriter, memberShape: MemberShape) { + writer.addImport(ctx.service.readerSymbol.namespace) + val readingClosure = ReadingClosureUtils(ctx, writer).readingClosure(memberShape) + writer.write( + "let value = try \$N.readFrom(message.payload, with: \$L)", + ctx.service.readerSymbol, + readingClosure, + ) + } +} diff --git a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HTTPBindingProtocolGenerator.kt b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HTTPBindingProtocolGenerator.kt index dfd5d11d9..1e1c084ea 100644 --- a/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HTTPBindingProtocolGenerator.kt +++ b/smithy-swift-codegen/src/main/kotlin/software/amazon/smithy/swift/codegen/integration/HTTPBindingProtocolGenerator.kt @@ -39,6 +39,8 @@ import software.amazon.smithy.swift.codegen.SwiftDependency import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.customtraits.NeedsReaderTrait import software.amazon.smithy.swift.codegen.customtraits.NeedsWriterTrait +import software.amazon.smithy.swift.codegen.events.MessageMarshallableGenerator +import software.amazon.smithy.swift.codegen.events.MessageUnmarshallableGenerator import software.amazon.smithy.swift.codegen.integration.httpResponse.HTTPResponseGenerator import software.amazon.smithy.swift.codegen.integration.middlewares.AuthSchemeMiddleware import software.amazon.smithy.swift.codegen.integration.middlewares.ContentLengthMiddleware @@ -66,6 +68,8 @@ import software.amazon.smithy.swift.codegen.model.ShapeMetadata import software.amazon.smithy.swift.codegen.model.findStreamingMember import software.amazon.smithy.swift.codegen.model.hasEventStreamMember import software.amazon.smithy.swift.codegen.model.hasTrait +import software.amazon.smithy.swift.codegen.model.isInputEventStream +import software.amazon.smithy.swift.codegen.model.isOutputEventStream import software.amazon.smithy.swift.codegen.model.targetOrSelf import software.amazon.smithy.swift.codegen.supportsStreamingAndIsRPC import software.amazon.smithy.utils.OptionalUtils @@ -474,9 +478,49 @@ abstract class HTTPBindingProtocolGenerator( return containedOperations } - abstract override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) + fun outputStreamingShapes(ctx: ProtocolGenerator.GenerationContext): MutableSet { + val streamingShapes = mutableMapOf() + val streamingOperations = getHttpBindingOperations(ctx).filter { it.isOutputEventStream(ctx.model) } + streamingOperations.forEach { operation -> + val input = operation.output.get() + val streamingMember = ctx.model.expectShape(input).findStreamingMember(ctx.model) + streamingMember?.let { + val targetType = ctx.model.expectShape(it.target) + streamingShapes[targetType.id] = it + } + } + return streamingShapes.values.toMutableSet() + } + + fun inputStreamingShapes(ctx: ProtocolGenerator.GenerationContext): MutableSet { + val streamingShapes = mutableSetOf() + val streamingOperations = getHttpBindingOperations(ctx).filter { it.isInputEventStream(ctx.model) } + streamingOperations.forEach { operation -> + val input = operation.input.get() + val streamingMember = ctx.model.expectShape(input).findStreamingMember(ctx.model) + streamingMember?.let { + val targetType = ctx.model.expectShape(it.target) + streamingShapes.add(targetType as UnionShape) + } + } + return streamingShapes + } + + override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) { + var streamingShapes = inputStreamingShapes(ctx) + val messageMarshallableGenerator = MessageMarshallableGenerator(ctx, defaultContentType) + streamingShapes.forEach { streamingMember -> + messageMarshallableGenerator.render(streamingMember) + } + } - abstract override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) + override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) { + var streamingShapes = outputStreamingShapes(ctx) + val messageUnmarshallableGenerator = MessageUnmarshallableGenerator(ctx, customizations) + streamingShapes.forEach { streamingMember -> + messageUnmarshallableGenerator.render(streamingMember) + } + } } class DefaultServiceConfig(writer: SwiftWriter, serviceName: String) : diff --git a/smithy-swift-codegen/src/test/kotlin/EventStreamTests.kt b/smithy-swift-codegen/src/test/kotlin/EventStreamTests.kt new file mode 100644 index 000000000..0a0683f34 --- /dev/null +++ b/smithy-swift-codegen/src/test/kotlin/EventStreamTests.kt @@ -0,0 +1,249 @@ +import io.kotest.matchers.string.shouldContainOnlyOnce +import org.junit.jupiter.api.Test +import software.amazon.smithy.swift.codegen.DefaultClientConfigurationIntegration + +class EventStreamTests { + @Test + fun `test MessageMarshallable`() { + val context = setupTests("eventstream.smithy", "aws.protocoltests.restjson#TestService") + println(context.manifest.files) + val contents = getFileContents(context.manifest, "/Example/models/TestStream+MessageMarshallable.swift") + val expected = """ +extension EventStreamTestClientTypes.TestStream { + static var marshal: ClientRuntime.MarshalClosure { + { (self) in + var headers: [ClientRuntime.EventStream.Header] = [.init(name: ":message-type", value: .string("event"))] + var payload: ClientRuntime.Data? = nil + switch self { + case .messagewithblob(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithBlob"))) + headers.append(.init(name: ":content-type", value: .string("application/octet-stream"))) + payload = value.data + case .messagewithstring(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithString"))) + headers.append(.init(name: ":content-type", value: .string("text/plain"))) + payload = value.data?.data(using: .utf8) + case .messagewithstruct(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithStruct"))) + headers.append(.init(name: ":content-type", value: .string("application/json"))) + payload = try SmithyJSON.Writer.write(value.someStruct, rootNodeInfo: "", with: EventStreamTestClientTypes.TestStruct.write(value:to:)) + case .messagewithunion(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithUnion"))) + headers.append(.init(name: ":content-type", value: .string("application/json"))) + payload = try SmithyJSON.Writer.write(value.someUnion, rootNodeInfo: "", with: EventStreamTestClientTypes.TestUnion.write(value:to:)) + case .messagewithheaders(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithHeaders"))) + if let headerValue = value.blob { + headers.append(.init(name: "blob", value: .byteArray(headerValue))) + } + if let headerValue = value.boolean { + headers.append(.init(name: "boolean", value: .bool(headerValue))) + } + if let headerValue = value.byte { + headers.append(.init(name: "byte", value: .byte(headerValue))) + } + if let headerValue = value.int { + headers.append(.init(name: "int", value: .int32(Int32(headerValue)))) + } + if let headerValue = value.long { + headers.append(.init(name: "long", value: .int64(Int64(headerValue)))) + } + if let headerValue = value.short { + headers.append(.init(name: "short", value: .int16(headerValue))) + } + if let headerValue = value.string { + headers.append(.init(name: "string", value: .string(headerValue))) + } + if let headerValue = value.timestamp { + headers.append(.init(name: "timestamp", value: .timestamp(headerValue))) + } + case .messagewithheaderandpayload(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithHeaderAndPayload"))) + if let headerValue = value.header { + headers.append(.init(name: "header", value: .string(headerValue))) + } + headers.append(.init(name: ":content-type", value: .string("application/octet-stream"))) + payload = value.payload + case .messagewithnoheaderpayloadtraits(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithNoHeaderPayloadTraits"))) + headers.append(.init(name: ":content-type", value: .string("application/json"))) + let writer = SmithyJSON.Writer(nodeInfo: "") + try writer["someInt"].write(value.someInt, with: Swift.Int.write(value:to:)) + try writer["someString"].write(value.someString, with: Swift.String.write(value:to:)) + payload = try writer.data() + case .messagewithunboundpayloadtraits(let value): + headers.append(.init(name: ":event-type", value: .string("MessageWithUnboundPayloadTraits"))) + if let headerValue = value.header { + headers.append(.init(name: "header", value: .string(headerValue))) + } + headers.append(.init(name: ":content-type", value: .string("application/json"))) + let writer = SmithyJSON.Writer(nodeInfo: "") + try writer["unboundString"].write(value.unboundString, with: Swift.String.write(value:to:)) + payload = try writer.data() + case .sdkUnknown(_): + throw ClientRuntime.ClientError.unknownError("cannot serialize the unknown event type!") + } + return ClientRuntime.EventStream.Message(headers: headers, payload: payload ?? .init()) + } + } +} +""" + contents.shouldContainOnlyOnce(expected) + } + + @Test + fun `test MessageUnmarshallable`() { + val context = setupTests("eventstream.smithy", "aws.protocoltests.restjson#TestService") + val contents = getFileContents(context.manifest, "/Example/models/TestStream+MessageUnmarshallable.swift") + val expected = """ +extension EventStreamTestClientTypes.TestStream { + static var unmarshal: ClientRuntime.UnmarshalClosure { + { message in + switch try message.type() { + case .event(let params): + switch params.eventType { + case "MessageWithBlob": + var event = EventStreamTestClientTypes.MessageWithBlob() + event.data = message.payload + return .messagewithblob(event) + case "MessageWithString": + var event = EventStreamTestClientTypes.MessageWithString() + event.data = String(data: message.payload, encoding: .utf8) + return .messagewithstring(event) + case "MessageWithStruct": + var event = EventStreamTestClientTypes.MessageWithStruct() + let value = try SmithyJSON.Reader.readFrom(message.payload, with: EventStreamTestClientTypes.TestStruct.read(from:)) + event.someStruct = value + return .messagewithstruct(event) + case "MessageWithUnion": + var event = EventStreamTestClientTypes.MessageWithUnion() + let value = try SmithyJSON.Reader.readFrom(message.payload, with: EventStreamTestClientTypes.TestUnion.read(from:)) + event.someUnion = value + return .messagewithunion(event) + case "MessageWithHeaders": + var event = EventStreamTestClientTypes.MessageWithHeaders() + if case .byteArray(let value) = message.headers.value(name: "blob") { + event.blob = value + } + if case .bool(let value) = message.headers.value(name: "boolean") { + event.boolean = value + } + if case .byte(let value) = message.headers.value(name: "byte") { + event.byte = value + } + if case .int32(let value) = message.headers.value(name: "int") { + event.int = Int(value) + } + if case .int64(let value) = message.headers.value(name: "long") { + event.long = Int(value) + } + if case .int16(let value) = message.headers.value(name: "short") { + event.short = value + } + if case .string(let value) = message.headers.value(name: "string") { + event.string = value + } + if case .timestamp(let value) = message.headers.value(name: "timestamp") { + event.timestamp = value + } + return .messagewithheaders(event) + case "MessageWithHeaderAndPayload": + var event = EventStreamTestClientTypes.MessageWithHeaderAndPayload() + if case .string(let value) = message.headers.value(name: "header") { + event.header = value + } + event.payload = message.payload + return .messagewithheaderandpayload(event) + case "MessageWithNoHeaderPayloadTraits": + let value = try SmithyJSON.Reader.readFrom(message.payload, with: EventStreamTestClientTypes.MessageWithNoHeaderPayloadTraits.read(from:)) + return .messagewithnoheaderpayloadtraits(value) + case "MessageWithUnboundPayloadTraits": + var event = EventStreamTestClientTypes.MessageWithUnboundPayloadTraits() + if case .string(let value) = message.headers.value(name: "header") { + event.header = value + } + let value = try SmithyJSON.Reader.readFrom(message.payload, with: Swift.String.read(from:)) + event.unboundString = value + return .messagewithunboundpayloadtraits(event) + default: + return .sdkUnknown("error processing event stream, unrecognized event: \(params.eventType)") + } + case .exception(let params): + let makeError: (ClientRuntime.EventStream.Message, ClientRuntime.EventStream.MessageType.ExceptionParams) throws -> Swift.Error = { message, params in + switch params.exceptionType { + case "SomeError": + let value = try SmithyJSON.Reader.readFrom(message.payload, with: SomeError.read(from:)) + return value + default: + let httpResponse = HttpResponse(body: .data(message.payload), statusCode: .ok) + return ClientRuntime.UnknownHTTPServiceError(httpResponse: httpResponse, message: "error processing event stream, unrecognized ':exceptionType': \(params.exceptionType); contentType: \(params.contentType ?? "nil")", requestID: nil, typeName: nil) + } + } + let error = try makeError(message, params) + throw error + case .error(let params): + let httpResponse = HttpResponse(body: .data(message.payload), statusCode: .ok) + throw ClientRuntime.UnknownHTTPServiceError(httpResponse: httpResponse, message: "error processing event stream, unrecognized ':errorType': \(params.errorCode); message: \(params.message ?? "nil")", requestID: nil, typeName: nil) + case .unknown(messageType: let messageType): + throw ClientRuntime.ClientError.unknownError("unrecognized event stream message ':message-type': \(messageType)") + } + } + } +} +""" + contents.shouldContainOnlyOnce(expected) + } + + @Test + fun `operation stack`() { + val context = setupTests("eventstream.smithy", "aws.protocoltests.restjson#TestService") + println(context.manifest.files) + val contents = getFileContents(context.manifest, "/Example/EventStreamTestClient.swift") + var expected = """ + public func testStreamOp(input: TestStreamOpInput) async throws -> TestStreamOpOutput { + let context = ClientRuntime.HttpContextBuilder() + .withMethod(value: .post) + .withServiceName(value: serviceName) + .withOperation(value: "testStreamOp") + .withIdempotencyTokenGenerator(value: config.idempotencyTokenGenerator) + .withLogger(value: config.logger) + .withPartitionID(value: config.partitionID) + .withAuthSchemes(value: config.authSchemes ?? []) + .withAuthSchemeResolver(value: config.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) + .withSocketTimeout(value: config.httpClientConfiguration.socketTimeout) + .build() + var operation = ClientRuntime.OperationStack(id: "testStreamOp") + operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware(TestStreamOpInput.urlPathProvider(_:))) + operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware()) + operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) + operation.serializeStep.intercept(position: .after, middleware: ContentTypeMiddleware(contentType: "application/json")) + operation.serializeStep.intercept(position: .after, middleware: ClientRuntime.EventStreamBodyMiddleware(keyPath: \.value, defaultBody: "{}", marshalClosure: EventStreamTestClientTypes.TestStream.marshal)) + operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) + operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) + operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) + operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware(TestStreamOpOutput.httpOutput(from:), TestStreamOpOutputError.httpError(from:))) + operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) + let result = try await operation.handleMiddleware(context: context, input: input, next: client.getHandler()) + return result + } +""" + contents.shouldContainOnlyOnce(expected) + } + + private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext { + val context = TestContext.initContextFrom( + listOf(smithyFile), + serviceShapeId, + MockHTTPRestJsonProtocolGenerator(), + { model -> model.defaultSettings(serviceShapeId, "Example", "456", "EventStreamTest") }, + listOf(DefaultClientConfigurationIntegration()) + ) + context.generator.initializeMiddleware(context.generationCtx) + context.generator.generateProtocolClient(context.generationCtx) + context.generator.generateMessageMarshallable(context.generationCtx) + context.generator.generateMessageUnmarshallable(context.generationCtx) + context.generationCtx.delegator.flushWriters() + return context + } +} diff --git a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPAWSJson11ProtocolGenerator.kt b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPAWSJson11ProtocolGenerator.kt index 566df0680..ad241f210 100644 --- a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPAWSJson11ProtocolGenerator.kt +++ b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPAWSJson11ProtocolGenerator.kt @@ -71,14 +71,6 @@ class MockHTTPAWSJson11ProtocolGenerator() : HTTPBindingProtocolGenerator(MockAW // Intentionally empty } - override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - - override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - override fun getProtocolHttpBindingResolver(ctx: ProtocolGenerator.GenerationContext, defaultContentType: String): HttpBindingResolver = MockJsonHttpBindingResolver(ctx, defaultContentType) } diff --git a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPEC2QueryProtocolGenerator.kt b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPEC2QueryProtocolGenerator.kt index 47748e40e..ef37ebf9a 100644 --- a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPEC2QueryProtocolGenerator.kt +++ b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPEC2QueryProtocolGenerator.kt @@ -48,14 +48,6 @@ class MockHTTPEC2QueryProtocolGenerator : HTTPBindingProtocolGenerator(MockEC2Qu // Intentionally empty } - override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - - override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - override fun getProtocolHttpBindingResolver(ctx: ProtocolGenerator.GenerationContext, defaultContentType: String): HttpBindingResolver = MockEC2QueryHttpBindingResolver(ctx, defaultContentType) diff --git a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestJsonProtocolGenerator.kt b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestJsonProtocolGenerator.kt index a22323dd6..a376a0319 100644 --- a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestJsonProtocolGenerator.kt +++ b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestJsonProtocolGenerator.kt @@ -25,14 +25,6 @@ class MockHTTPRestJsonProtocolGenerator : HTTPBindingProtocolGenerator(MockRestJ // Intentionally empty } - override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - - override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - override fun generateProtocolUnitTests(ctx: ProtocolGenerator.GenerationContext): Int { val requestTestBuilder = HttpProtocolUnitTestRequestGenerator.Builder() val responseTestBuilder = HttpProtocolUnitTestResponseGenerator.Builder() diff --git a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestXMLProtocolGenerator.kt b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestXMLProtocolGenerator.kt index 09f70f02a..64c21cdfb 100644 --- a/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestXMLProtocolGenerator.kt +++ b/smithy-swift-codegen/src/test/kotlin/mocks/MockHTTPRestXMLProtocolGenerator.kt @@ -26,14 +26,6 @@ class MockHTTPRestXMLProtocolGenerator : HTTPBindingProtocolGenerator(MockRestXM // Intentionally empty } - override fun generateMessageMarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - - override fun generateMessageUnmarshallable(ctx: ProtocolGenerator.GenerationContext) { - TODO("Not yet implemented") - } - override fun generateProtocolUnitTests(ctx: ProtocolGenerator.GenerationContext): Int { val requestTestBuilder = HttpProtocolUnitTestRequestGenerator.Builder() val responseTestBuilder = HttpProtocolUnitTestResponseGenerator.Builder() diff --git a/smithy-swift-codegen/src/test/resources/eventstream.smithy b/smithy-swift-codegen/src/test/resources/eventstream.smithy new file mode 100644 index 000000000..c153874b5 --- /dev/null +++ b/smithy-swift-codegen/src/test/resources/eventstream.smithy @@ -0,0 +1,85 @@ +namespace aws.protocoltests.restjson + +use aws.protocols#restJson1 +use aws.api#service +use aws.auth#sigv4 + +@restJson1 +@sigv4(name: "event-stream-test") +@service(sdkId: "EventStreamTest") +service TestService { version: "123", operations: [TestStreamOp] } + +@documentation("This operation is cool.") +@http(method: "POST", uri: "/test") +operation TestStreamOp { + input: TestStreamInputOutput, + output: TestStreamInputOutput, + errors: [SomeError], +} + +structure TestStreamInputOutput { + @httpPayload + @required + value: TestStream +} + +@documentation("You don't have permission.") +@error("client") +structure SomeError { + Message: String, +} + +union TestUnion { + Foo: String, + Bar: Integer, +} + +structure TestStruct { + someString: String, + someInt: Integer, +} + +structure MessageWithBlob { @eventPayload data: Blob } + +structure MessageWithString { @eventPayload data: String } + +structure MessageWithStruct { @eventPayload someStruct: TestStruct } + +structure MessageWithUnion { @eventPayload someUnion: TestUnion } + +structure MessageWithHeaders { + @eventHeader blob: Blob, + @eventHeader boolean: Boolean, + @eventHeader byte: Byte, + @eventHeader int: Integer, + @eventHeader long: Long, + @eventHeader short: Short, + @eventHeader string: String, + @eventHeader timestamp: Timestamp, +} +structure MessageWithHeaderAndPayload { + @eventHeader header: String, + @eventPayload payload: Blob, +} +structure MessageWithNoHeaderPayloadTraits { + someInt: Integer, + someString: String, +} + +structure MessageWithUnboundPayloadTraits { + @eventHeader header: String, + unboundString: String, +} + +@streaming +union TestStream { + MessageWithBlob: MessageWithBlob, + MessageWithString: MessageWithString, + MessageWithStruct: MessageWithStruct, + MessageWithUnion: MessageWithUnion, + MessageWithHeaders: MessageWithHeaders, + MessageWithHeaderAndPayload: MessageWithHeaderAndPayload, + MessageWithNoHeaderPayloadTraits: MessageWithNoHeaderPayloadTraits, + MessageWithUnboundPayloadTraits: MessageWithUnboundPayloadTraits, + SomeError: SomeError, +}