-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
### Motivation: There will be times when a client wishes to send larger requests with gzipped bodies to save on network traffic. This PR adds a `NIOHTTPRequestDecompressor` which can be added to the server's channel pipeline so those requests are automatically inflated. ### Modifications: - Added a `CNIOExtrasZlib_voidPtr_to_BytefPtr` C method. - Added a `NIOHTTPRequestDecompressor` type. - Added a `HTTPResponseDecompressorTest` test case. ### Result: Now you don't have to manually check the `Content-Encoding` header and decompress the body on each incoming request.
- Loading branch information
1 parent
16fbdf3
commit 0584020
Showing
6 changed files
with
334 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,3 +6,4 @@ Package.pins | |
*.pem | ||
/docs | ||
Package.resolved | ||
.swiftpm/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the SwiftNIO open source project | ||
// | ||
// Copyright (c) 2019 Apple Inc. and the SwiftNIO project authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import CNIOExtrasZlib | ||
import NIOHTTP1 | ||
import NIO | ||
|
||
public final class NIOHTTPRequestDecompressor: ChannelDuplexHandler, RemovableChannelHandler { | ||
public typealias InboundIn = HTTPServerRequestPart | ||
public typealias InboundOut = HTTPServerRequestPart | ||
public typealias OutboundIn = HTTPServerResponsePart | ||
public typealias OutboundOut = HTTPServerResponsePart | ||
|
||
private struct Compression { | ||
let algorithm: NIOHTTPDecompression.CompressionAlgorithm | ||
let contentLength: Int | ||
} | ||
|
||
private var decompressor: NIOHTTPDecompression.Decompressor | ||
private var compression: Compression? | ||
|
||
public init(limit: NIOHTTPDecompression.DecompressionLimit) { | ||
self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit) | ||
self.compression = nil | ||
} | ||
|
||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) { | ||
let request = self.unwrapInboundIn(data) | ||
|
||
switch request { | ||
case .head(let head): | ||
if | ||
let encoding = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased(), | ||
let algorithm = NIOHTTPDecompression.CompressionAlgorithm(header: encoding), | ||
let length = head.headers[canonicalForm: "Content-Length"].first.flatMap({ Int($0) }) | ||
{ | ||
do { | ||
try self.decompressor.initializeDecoder(encoding: algorithm, length: length) | ||
self.compression = Compression(algorithm: algorithm, contentLength: length) | ||
} catch let error { | ||
context.fireErrorCaught(error) | ||
return | ||
} | ||
} | ||
|
||
context.fireChannelRead(data) | ||
case .body(var part): | ||
guard let compression = self.compression else { | ||
context.fireChannelRead(data) | ||
return | ||
} | ||
|
||
while part.readableBytes > 0 { | ||
do { | ||
var buffer = context.channel.allocator.buffer(capacity: 16384) | ||
try self.decompressor.decompress(part: &part, buffer: &buffer, originalLength: compression.contentLength) | ||
|
||
context.fireChannelRead(self.wrapInboundOut(.body(buffer))) | ||
} catch let error { | ||
context.fireErrorCaught(error) | ||
return | ||
} | ||
} | ||
case .end: | ||
if self.compression != nil { | ||
self.decompressor.deinitializeDecoder() | ||
self.compression = nil | ||
} | ||
|
||
context.fireChannelRead(data) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
36 changes: 36 additions & 0 deletions
36
Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest+XCTest.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the SwiftNIO open source project | ||
// | ||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// HTTPRequestDecompressorTest+XCTest.swift | ||
// | ||
import XCTest | ||
|
||
/// | ||
/// NOTE: This file was generated by generate_linux_tests.rb | ||
/// | ||
/// Do NOT edit this file directly as it will be regenerated automatically when needed. | ||
/// | ||
|
||
extension HTTPRequestDecompressorTest { | ||
|
||
static var allTests : [(String, (HTTPRequestDecompressorTest) -> () throws -> Void)] { | ||
return [ | ||
("testDecompressionNoLimit", testDecompressionNoLimit), | ||
("testDecompressionLimitRatio", testDecompressionLimitRatio), | ||
("testDecompressionLimitSize", testDecompressionLimitSize), | ||
("testDecompression", testDecompression), | ||
] | ||
} | ||
} | ||
|
210 changes: 210 additions & 0 deletions
210
Tests/NIOHTTPCompressionTests/HTTPRequestDecompressorTest.swift
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
//===----------------------------------------------------------------------===// | ||
// | ||
// This source file is part of the SwiftNIO open source project | ||
// | ||
// Copyright (c) 2017-2018 Apple Inc. and the SwiftNIO project authors | ||
// Licensed under Apache License v2.0 | ||
// | ||
// See LICENSE.txt for license information | ||
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
import XCTest | ||
import CNIOExtrasZlib | ||
@testable import NIO | ||
@testable import NIOHTTP1 | ||
@testable import NIOHTTPCompression | ||
|
||
private let testString = "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua." | ||
|
||
private final class DecompressedAssert: ChannelInboundHandler { | ||
typealias InboundIn = HTTPServerRequestPart | ||
|
||
public func channelRead(context: ChannelHandlerContext, data: NIOAny) { | ||
let request = self.unwrapInboundIn(data) | ||
|
||
switch request { | ||
case .body(let buffer): | ||
let string = buffer.getString(at: buffer.readerIndex, length: buffer.readableBytes) | ||
guard string == testString else { | ||
context.fireErrorCaught(NIOHTTPDecompression.DecompressionError.inflationError(42)) | ||
return | ||
} | ||
default: context.fireChannelRead(data) | ||
} | ||
} | ||
} | ||
|
||
class HTTPRequestDecompressorTest: XCTestCase { | ||
func testDecompressionNoLimit() throws { | ||
let channel = EmbeddedChannel() | ||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() | ||
try channel.pipeline.addHandler(DecompressedAssert()).wait() | ||
|
||
let buffer = ByteBuffer.of(string: testString) | ||
let compressed = compress(buffer, "gzip") | ||
|
||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "\(compressed.readableBytes)")]) | ||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) | ||
|
||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.body(compressed))) | ||
} | ||
|
||
func testDecompressionLimitRatio() throws { | ||
let channel = EmbeddedChannel() | ||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .ratio(10))).wait() | ||
|
||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "13")]) | ||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) | ||
|
||
let buffer = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) | ||
let compressed = compress(buffer, "gzip") | ||
|
||
do { | ||
try channel.writeInbound(HTTPServerRequestPart.body(compressed)) | ||
} catch let error as NIOHTTPDecompression.DecompressionError { | ||
switch error { | ||
case .limit: | ||
// ok | ||
break | ||
default: | ||
XCTFail("Unexptected error: \(error)") | ||
} | ||
} | ||
} | ||
|
||
func testDecompressionLimitSize() throws { | ||
let channel = EmbeddedChannel() | ||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .size(10))).wait() | ||
|
||
let headers = HTTPHeaders([("Content-Encoding", "gzip"), ("Content-Length", "13")]) | ||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) | ||
|
||
let buffer = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17]) | ||
let compressed = compress(buffer, "gzip") | ||
|
||
do { | ||
try channel.writeInbound(HTTPServerRequestPart.body(compressed)) | ||
} catch let error as NIOHTTPDecompression.DecompressionError { | ||
switch error { | ||
case .limit: | ||
// ok | ||
break | ||
default: | ||
XCTFail("Unexptected error: \(error)") | ||
} | ||
} | ||
} | ||
|
||
func testDecompression() throws { | ||
let channel = EmbeddedChannel() | ||
try channel.pipeline.addHandler(NIOHTTPRequestDecompressor(limit: .none)).wait() | ||
|
||
let body = Array(repeating: testString, count: 1000).joined() | ||
|
||
for algorithm in [nil, "gzip", "deflate"] { | ||
let compressed: ByteBuffer | ||
var headers = HTTPHeaders() | ||
if let algorithm = algorithm { | ||
headers.add(name: "Content-Encoding", value: algorithm) | ||
compressed = compress(ByteBuffer.of(string: body), algorithm) | ||
} else { | ||
compressed = ByteBuffer.of(string: body) | ||
} | ||
headers.add(name: "Content-Length", value: "\(compressed.readableBytes)") | ||
|
||
XCTAssertNoThrow( | ||
try channel.writeInbound(HTTPServerRequestPart.head(.init(version: .init(major: 1, minor: 1), method: .POST, uri: "https://nio.swift.org/test", headers: headers))) | ||
) | ||
|
||
do { | ||
try channel.writeInbound(HTTPServerRequestPart.body(compressed)) | ||
} catch let error as NIOHTTPDecompression.DecompressionError { | ||
switch error { | ||
case .limit: | ||
// ok | ||
break | ||
default: | ||
XCTFail("Unexptected error: \(error)") | ||
} | ||
} | ||
} | ||
|
||
XCTAssertNoThrow(try channel.writeInbound(HTTPServerRequestPart.end(nil))) | ||
} | ||
|
||
private func compress(_ body: ByteBuffer, _ algorithm: String) -> ByteBuffer { | ||
var stream = z_stream() | ||
|
||
stream.zalloc = nil | ||
stream.zfree = nil | ||
stream.opaque = nil | ||
|
||
var buffer = ByteBufferAllocator().buffer(capacity: 1000) | ||
|
||
let windowBits: Int32 | ||
switch algorithm { | ||
case "deflate": | ||
windowBits = 15 | ||
case "gzip": | ||
windowBits = 16 + 15 | ||
default: | ||
XCTFail("Unsupported algorithm: \(algorithm)") | ||
return buffer | ||
} | ||
|
||
let rc = CNIOExtrasZlib_deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, windowBits, 8, Z_DEFAULT_STRATEGY) | ||
XCTAssertEqual(Z_OK, rc) | ||
|
||
defer { | ||
stream.avail_in = 0 | ||
stream.next_in = nil | ||
stream.avail_out = 0 | ||
stream.next_out = nil | ||
} | ||
|
||
var body = body | ||
|
||
body.readWithUnsafeMutableReadableBytes { dataPtr in | ||
let typedPtr = dataPtr.baseAddress!.assumingMemoryBound(to: UInt8.self) | ||
let typedDataPtr = UnsafeMutableBufferPointer(start: typedPtr, | ||
count: dataPtr.count) | ||
|
||
stream.avail_in = UInt32(typedDataPtr.count) | ||
stream.next_in = typedDataPtr.baseAddress! | ||
|
||
buffer.writeWithUnsafeMutableBytes { outputPtr in | ||
let typedOutputPtr = UnsafeMutableBufferPointer(start: outputPtr.baseAddress!.assumingMemoryBound(to: UInt8.self), | ||
count: outputPtr.count) | ||
stream.avail_out = UInt32(typedOutputPtr.count) | ||
stream.next_out = typedOutputPtr.baseAddress! | ||
let rc = deflate(&stream, Z_FINISH) | ||
XCTAssertTrue(rc == Z_OK || rc == Z_STREAM_END) | ||
return typedOutputPtr.count - Int(stream.avail_out) | ||
} | ||
|
||
return typedDataPtr.count - Int(stream.avail_in) | ||
} | ||
|
||
deflateEnd(&stream) | ||
|
||
return buffer | ||
} | ||
} | ||
|
||
extension ByteBuffer { | ||
fileprivate static func of(string: String) -> ByteBuffer { | ||
var buffer = ByteBufferAllocator().buffer(capacity: string.count) | ||
buffer.writeString(string) | ||
return buffer | ||
} | ||
|
||
fileprivate static func of(bytes: [UInt8]) -> ByteBuffer { | ||
var buffer = ByteBufferAllocator().buffer(capacity: bytes.count) | ||
buffer.writeBytes(bytes) | ||
return buffer | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters