diff --git a/Sources/MySQLNIO/Protocol/MySQLProtocol+HandshakeV10.swift b/Sources/MySQLNIO/Protocol/MySQLProtocol+HandshakeV10.swift index 0636b36..aaf4a99 100644 --- a/Sources/MySQLNIO/Protocol/MySQLProtocol+HandshakeV10.swift +++ b/Sources/MySQLNIO/Protocol/MySQLProtocol+HandshakeV10.swift @@ -103,13 +103,13 @@ extension MySQLProtocol { guard let reserved1 = packet.payload.readSlice(length: 6) else { throw Error.missingReserved } - assert(reserved1.isZeroes, "invalid reserve 1 \(reserved1)") + assert(reserved1.readableBytesView.allSatisfy { $0 == 0 }, "invalid reserve 1 \(reserved1)") if capabilities.contains(.CLIENT_LONG_PASSWORD) { /// string[4] reserved (all [00]) guard let reserved2 = packet.payload.readSlice(length: 4) else { throw Error.missingReserved } - assert(reserved2.isZeroes, "invalid reserve 2: \(reserved2)") + assert(reserved2.readableBytesView.allSatisfy { $0 == 0 }, "invalid reserve 2: \(reserved2)") } else { /// Capabilities 3rd part. MariaDB specific flags. /// MariaDB Initial Handshake Packet specific flags diff --git a/Sources/MySQLNIO/Utilities/NIOUtils.swift b/Sources/MySQLNIO/Utilities/NIOUtils.swift index 5d45727..f720ae5 100644 --- a/Sources/MySQLNIO/Utilities/NIOUtils.swift +++ b/Sources/MySQLNIO/Utilities/NIOUtils.swift @@ -1,58 +1,44 @@ extension ByteBuffer { mutating func readNullTerminatedString() -> String? { - var copy = self - while let byte = copy.readInteger(as: UInt8.self), byte != 0x00 { continue } - defer { self.moveReaderIndex(forwardBy: 1) } - return self.readString(length: (self.readableBytes - copy.readableBytes) - 1) + guard let nullIndex = readableBytesView.firstIndex(of: 0) else { + return nil + } + + defer { moveReaderIndex(forwardBy: 1) } + return readString(length: nullIndex - readerIndex) + } + + @discardableResult + mutating func writeNullTerminatedString(_ string: String) -> Int { + return self.writeString(string) + + self.writeInteger(0, as: UInt8.self) } mutating func readInteger(endianness: Endianness = .big, as: T.Type = T.self) -> T? where T: RawRepresentable, T.RawValue: FixedWidthInteger { return self.readInteger(endianness: endianness, as: T.RawValue.self) - .flatMap { T(rawValue: $0) } + .flatMap(T.init(rawValue:)) } - mutating func writeNullTerminatedString(_ string: String) { - self.writeString(string) - self.writeInteger(0, as: UInt8.self) - } - - mutating func writeLengthEncodedInteger(_ integer: UInt64) { + @discardableResult + mutating func writeLengthEncodedInteger(_ integer: UInt64) -> Int { switch integer { case 0..<251: - self.writeInteger(numericCast(integer), as: UInt8.self) + return self.writeInteger(numericCast(integer), as: UInt8.self) case 251..<1<<16: - self.writeInteger(0xFC, as: UInt8.self) - self.writeInteger(numericCast(integer), endianness: .little, as: UInt16.self) + return self.writeBytes([0xfc, .init(integer & 0xff), .init(integer >> 8 & 0xff)]) case 1<<16..<1<<24: - self.writeInteger(0xFD, as: UInt8.self) - self.writeInteger(numericCast(integer & 0xFF), as: UInt8.self) - self.writeInteger(numericCast(integer >> 8 & 0xFF), as: UInt8.self) - self.writeInteger(numericCast(integer >> 16 & 0xFF), as: UInt8.self) + return self.writeBytes([0xfd, .init(integer & 0xff), .init(integer >> 8 & 0xff), .init(integer >> 16 & 0xff)]) default: - self.writeInteger(0xFE, as: UInt8.self) - self.writeInteger(numericCast(integer), endianness: .little, as: UInt64.self) + return self.writeInteger(0xfe, as: UInt8.self) + self.writeInteger(integer, endianness: .little, as: UInt64.self) } } - mutating func writeLengthEncodedSlice(_ buffer: inout ByteBuffer) { - self.writeLengthEncodedInteger(numericCast(buffer.readableBytes)) - self.writeBuffer(&buffer) - } - - var readableString: String? { - return self.getString(at: self.readerIndex, length: self.readableBytes) - } - - var isZeroes: Bool { - for byte in self.readableBytesView { - switch byte { - case 0x00: continue - default: return false - } - } - return true + @discardableResult + mutating func writeLengthEncodedSlice(_ buffer: inout ByteBuffer) -> Int { + return self.writeLengthEncodedInteger(numericCast(buffer.readableBytes)) + + self.writeBuffer(&buffer) } mutating func readLengthEncodedString() -> String? { @@ -70,38 +56,17 @@ extension ByteBuffer { } mutating func readLengthEncodedInteger() -> UInt64? { - guard let first = self.readInteger(endianness: .little, as: UInt8.self) else { - return nil - } - - switch first { - case 0xFC: - guard let uint16 = readInteger(endianness: .little, as: UInt16.self) else { - return nil - } - return numericCast(uint16) - case 0xFD: - guard let one = readInteger(endianness: .little, as: UInt8.self) else { + switch self.readInteger(endianness: .little, as: UInt8.self) { + case .none: return nil - } - guard let two = readInteger(endianness: .little, as: UInt8.self) else { - return nil - } - guard let three = readInteger(endianness: .little, as: UInt8.self) else { - return nil - } - var num: UInt64 = 0 - num += numericCast(one) << 0 - num += numericCast(two) << 8 - num += numericCast(three) << 16 - return num - case 0xFE: - guard let uint64 = readInteger(endianness: .little, as: UInt64.self) else { - return nil - } - return uint64 - default: - return numericCast(first) + case .some(0xfc): + return readInteger(endianness: .little, as: UInt16.self).map(numericCast) + case .some(0xfd): + return readBytes(length: 3).map { $0.reversed().reduce(UInt64.zero) { ($0 << 8) | numericCast($1) } } + case .some(0xfe): + return readInteger(endianness: .little, as: UInt64.self) + case .some(let byte): + return numericCast(byte) } } }