Skip to content

Commit

Permalink
Fixes + Improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
andiricum2 committed May 23, 2024
1 parent a4bcf29 commit 4ebc9cb
Show file tree
Hide file tree
Showing 28 changed files with 321 additions and 334 deletions.
11 changes: 3 additions & 8 deletions EXAMPLE.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,9 @@
import threading
from pieraknet.server import Server as PieRakNet
from pieraknet.packets.game_packet import GamePacket
from pieraknet.connection import Connection as RakNetConnection
from pieraknet.packets.frame_set import Frame

import logging
import os
import time
import random

from pieraknet.server import Server as PieRakNet

class BedrockServer:
def __init__(self, hostname="0.0.0.0", port=19132, logger=logging.getLogger("PieBedrock"), gamemode="survival", timeout=20):
self.initialized = False
Expand Down Expand Up @@ -85,4 +80,4 @@ def stop(self):
server.start()
except KeyboardInterrupt:
server.logger.info('Stopping...')
server.stop()
server.stop()
2 changes: 1 addition & 1 deletion pieraknet/__main__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .server import Server

if __name__ == '__main__':
server = Server()
server = Server(logginglevel = "DEBUG")
server.responseData = "MCPE;PieRakNet Server;589;1.20.0;2;20;13253860892328930865;Powered by PieMC;Survival;1;19132;19133;"
try:
server.start()
Expand Down
18 changes: 4 additions & 14 deletions pieraknet/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,35 +23,27 @@
class UnsupportedIPVersion(Exception):
pass


class EOSError(Exception):
pass


class BuffError(Exception):
pass


class Buffer(BytesIO):
def feos(self):
if len(self.getvalue()[self.tell():]) == 0:
return True
else:
return False
return len(self.getvalue()[self.tell():]) == 0

def read_packet_id(self): # Read Packet ID
return self.read_byte()

def write_packet_id(self, data):
self.write_byte(str(data))
self.write_byte(data)

def read_byte(self):
return struct.unpack('B', self.read(1))[0]

def write_byte(self, data):
if not isinstance(data, bytes):
data = str(data).encode()
self.write(struct.pack('B', int(data)))
self.write(struct.pack('B', data))

def read_bits(self, num_bits):
byte_data = self.read((num_bits + 7) // 8)
Expand All @@ -76,8 +68,6 @@ def read_ubyte(self):
return struct.unpack('<B', self.read(1))[0]

def write_ubyte(self, data):
if not isinstance(data, bytes):
data = data.encode('utf-8')
self.write(struct.pack('<B', data))

def read_short(self):
Expand Down Expand Up @@ -146,9 +136,9 @@ def read_string(self):
return string

def write_string(self, data):
self.write_short(len(data))
if not isinstance(data, bytes):
data = data.encode('ascii')
self.write_short(len(data))
self.write(data)

def read_address(self):
Expand Down
169 changes: 77 additions & 92 deletions pieraknet/connection.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import time
from pieraknet.packets.frame_set import FrameSet, Frame
from time import time

from pieraknet.packets.frame_set import Frame, FrameSet
from pieraknet.protocol_info import ProtocolInfo
from pieraknet.packets.acknowledgement import Ack, Nack
from pieraknet.handlers.connection_request import ConnectionRequestHandler
Expand All @@ -22,17 +23,16 @@ def __init__(self, address, server, mtu_size, guid):
self.fragmented_packets = {}
self.compound_id = 0
self.client_sequence_numbers = []
# self.server_sequence_numbers = []
self.client_sequence_number = 0
self.server_sequence_number = 0
self.queue = FrameSet()
self.server_reliable_frame_index = 0
self.client_reliable_frame_index = 0
self.channel_index = [0] * 32
self.last_receive_time = time.time()
self.last_receive_time = time()

def update(self):
if (time.time() - self.last_receive_time) >= self.server.timeout:
if (time() - self.last_receive_time) >= self.server.timeout:
self.disconnect()
self.send_ack_queue()
self.send_nack_queue()
Expand All @@ -42,27 +42,26 @@ def send_data(self, data: bytes):
self.server.send(data, self.address)

def handle(self, data):
self.last_receive_time = time.time()
self.server.logger.debug(f"New Packet: {data}") # Log the received packet data
if data[0] == ProtocolInfo.ACK:
self.last_receive_time = time()
self.server.logger.info(f"New Packet: {data}")
packet_type = data[0]
if packet_type == ProtocolInfo.ACK:
self.handle_ack(data)
elif data[0] == ProtocolInfo.NACK:
elif packet_type == ProtocolInfo.NACK:
self.handle_nack(data)
elif ProtocolInfo.FRAME_SET_0 <= data[0] <= ProtocolInfo.FRAME_SET_F:
self.handle_frame_set(data) # Pass the raw binary data to the method
self.server.logger.debug("Frame Set handled.") # Log that the frame set was handled

elif ProtocolInfo.FRAME_SET_0 <= packet_type <= ProtocolInfo.FRAME_SET_F:
self.handle_frame_set(data)
self.server.logger.info("Frame Set handled.") # Log that the frame set was handled

def handle_ack(self, data: bytes):
self.server.logger.debug("Handling ACK packet...")
self.server.logger.info("Handling ACK packet...")
packet = Ack(data)
packet.decode()
for sequence_number in packet.sequence_numbers:
if sequence_number in self.recovery_queue:
del self.recovery_queue[sequence_number]
self.recovery_queue.pop(sequence_number, None)

def handle_nack(self, data: bytes):
self.server.logger.debug("Handling NACK packet...")
self.server.logger.info("Handling NACK packet...")
packet = Nack(data)
packet.decode()
for sequence_number in packet.sequence_numbers:
Expand All @@ -75,20 +74,17 @@ def handle_nack(self, data: bytes):
del self.recovery_queue[sequence_number]

def handle_frame_set(self, data):
self.server.logger.debug("Handling Frame Set...")
buf = Buffer(data) # Create a Buffer instance from the received data
self.server.logger.info("Handling Frame Set...")
buf = Buffer(data)
frame_set = FrameSet()
frame_set.decode(data) # Pass the Buffer instance to the decode method
frame_set.decode(data)
if frame_set.sequence_number not in self.client_sequence_numbers:
if frame_set.sequence_number in self.nack_queue:
self.nack_queue.remove(frame_set.sequence_number)
self.client_sequence_numbers.append(frame_set.sequence_number)
self.ack_queue.append(frame_set.sequence_number)
hole_size = frame_set.sequence_number - self.client_sequence_number
if hole_size > 0:
for sequence_number in range(self.client_sequence_number + 1, hole_size):
if sequence_number not in self.client_sequence_numbers:
self.nack_queue.append(sequence_number)
self.nack_queue.extend(
range(self.client_sequence_number + 1, hole_size))
self.client_sequence_number = frame_set.sequence_number
for frame in frame_set.frames:
if not (2 <= frame.reliability <= 7 and frame.reliability != 5):
Expand All @@ -99,65 +95,58 @@ def handle_frame_set(self, data):
self.handle_frame(frame)
self.client_reliable_frame_index += 1


def handle_fragmented_frame(self, packet):
self.server.logger.debug("Handling Fragmented Frame...")
if packet.compound_id not in self.fragmented_packets:
self.fragmented_packets[packet.compound_id] = {packet.index: packet}
else:
self.fragmented_packets[packet.compound_id][packet.index] = packet
if len(self.fragmented_packets[packet.compound_id]) == packet.compound_size:
self.server.logger.info("Handling Fragmented Frame...")
fragments = self.fragmented_packets.setdefault(packet.compound_id, {})
fragments[packet.index] = packet
if len(fragments) == packet.compound_size:
new_frame = Frame()
new_frame.body = b''
for i in range(0, packet.compound_size):
new_frame.body += self.fragmented_packets[packet.compound_id][i].body
new_frame.body = b''.join(fragments[i].body for i in range(packet.compound_size))
del self.fragmented_packets[packet.compound_id]
self.handle_frame(new_frame)

def handle_frame(self, packet):
self.server.logger.debug("Handling Frame...")
self.server.logger.info("Handling Frame...")
if packet.fragmented:
self.handle_fragmented_frame(packet)
else:
packet_type = packet.body[0]
if not self.connected:
if packet.body[0] == ProtocolInfo.CONNECTION_REQUEST:
new_frame = Frame()
new_frame.reliability = 0
if packet_type == ProtocolInfo.CONNECTION_REQUEST:
new_frame = Frame(reliability=0)
new_frame.body = ConnectionRequestHandler.handle(packet.body, self.server, self)
self.add_to_queue(new_frame)
elif packet.body[0] == ProtocolInfo.NEW_INCOMING_CONNECTION:
elif packet_type == ProtocolInfo.NEW_INCOMING_CONNECTION:
packet = NewIncomingConnection(packet.body)
packet.decode()
if packet.server_address[1] == self.server.port:
self.connected = True
if hasattr(self.server, "interface"):
if hasattr(self.server.interface, "on_new_incoming_connection"):
self.server.interface.on_new_incoming_connection(self)
elif packet.body[0] == ProtocolInfo.ONLINE_PING:
new_frame = Frame()
new_frame.reliability = 0
new_frame.body = OnlinePingHandler.handle(OnlinePing(packet.body), self, self.server)
self.add_to_queue(new_frame, False)
elif packet.body[0] == ProtocolInfo.DISCONNECT:
self.disconnect()
elif packet.body[0] == ProtocolInfo.GAME_PACKET:
if hasattr(self.server, "interface"):
if hasattr(self.server.interface, "on_game_packet"):
self.server.interface.on_game_packet(packet, self)
if hasattr(self.server, "interface") and hasattr(self.server.interface, "on_new_incoming"):
self.server.interface.on_new_incoming(self)
else:
if hasattr(self.server, "interface"):
if hasattr(self.server.interface, "on_unknown_packet"):
if packet_type == ProtocolInfo.ONLINE_PING:
new_frame = Frame(reliability=0)
new_frame.body = OnlinePingHandler.handle(OnlinePing(packet.body), self, self.server)
self.add_to_queue(new_frame, False)
elif packet_type == ProtocolInfo.DISCONNECT:
self.disconnect()
elif packet_type == ProtocolInfo.GAME_PACKET:
if hasattr(self.server, "interface") and hasattr(self.server.interface, "on_game_packet"):
self.server.interface.on_game_packet(packet, self)
else:
if hasattr(self.server, "interface") and hasattr(self.server.interface, "on_unknown_packet"):
self.server.interface.on_unknown_packet(packet, self)

def send_queue(self):
self.server.logger.debug("Sending Queue...")
if len(self.queue.frames) > 0:
self.queue.sequence_number = self.server_sequence_number
self.server_sequence_number += 1
self.recovery_queue[self.queue.sequence_number] = self.queue
self.queue.encode()
self.send_data(self.queue.getvalue())
self.queue = FrameSet()
if not self.queue.frames:
return

self.queue.sequence_number = self.server_sequence_number
self.server_sequence_number += 1
self.recovery_queue[self.queue.sequence_number] = self.queue
self.queue.encode()
self.send_data(self.queue.getvalue())
self.queue = FrameSet()

def add_to_queue(self, packet: Frame, is_immediate=True):
if 2 <= packet.reliability <= 7 and packet.reliability != 5:
Expand All @@ -166,18 +155,18 @@ def add_to_queue(self, packet: Frame, is_immediate=True):
if packet.reliability == 3:
packet.ordered_frame_index = self.channel_index[packet.order_channel]
self.channel_index[packet.order_channel] += 1

if packet.get_size() > self.mtu_size:
fragmented_body = []
for i in range(0, len(packet.body), self.mtu_size):
fragmented_body.append(packet.body[i:i + self.mtu_size])
fragmented_body = [packet.body[i:i + self.mtu_size] for i in range(0, len(packet.body), self.mtu_size)]
for index, body in enumerate(fragmented_body):
new_packet = Frame()
new_packet.fragmented = True
new_packet.reliability = packet.reliability
new_packet.compound_id = self.compound_id
new_packet.compound_size = len(fragmented_body)
new_packet.index = index
new_packet.body = body
new_packet = Frame(
fragmented=True,
reliability=packet.reliability,
compound_id=self.compound_id,
compound_size=len(fragmented_body),
index=index,
body=body
)
if index != 0:
new_packet.reliable_frame_index = self.server_reliable_frame_index
self.server_reliable_frame_index += 1
Expand All @@ -188,8 +177,8 @@ def add_to_queue(self, packet: Frame, is_immediate=True):
self.queue.frames.append(new_packet)
self.send_queue()
else:
frame_size: int = new_packet.get_size()
queue_size: int = self.queue.get_size()
frame_size = new_packet.get_size()
queue_size = self.queue.get_size()
if frame_size + queue_size >= self.mtu_size:
self.send_queue()
self.queue.frames.append(new_packet)
Expand All @@ -199,37 +188,33 @@ def add_to_queue(self, packet: Frame, is_immediate=True):
self.queue.frames.append(packet)
self.send_queue()
else:
frame_size: int = packet.get_size()
queue_size: int = self.queue.get_size()
frame_size = packet.get_size()
queue_size = self.queue.get_size()
if frame_size + queue_size >= self.mtu_size:
self.send_queue()
self.queue.frames.append(packet)

def send_ack_queue(self):
if len(self.ack_queue) > 0:
packet = Ack()
packet.sequence_numbers = self.ack_queue
self.ack_queue = []
if self.ack_queue:
packet = Ack(sequence_numbers=self.ack_queue)
packet.encode()
self.send_data(packet.getvalue())
self.ack_queue.clear()

def send_nack_queue(self):
if len(self.nack_queue) > 0:
packet = Nack()
packet.sequence_numbers = self.nack_queue
self.nack_queue = []
if self.nack_queue:
packet = Nack(sequence_numbers=self.nack_queue)
packet.encode()
self.send_data(packet.getvalue())
self.nack_queue.clear()

def disconnect(self):
self.server.logger.debug("Disconnecting...")
new_frame = Frame()
new_frame.reliability = 0
self.server.logger.info("Disconnecting...")
new_frame = Frame(reliability=0)
disconnect_packet = Disconnect()
disconnect_packet.encode()
new_frame.body = disconnect_packet.getvalue()
self.add_to_queue(new_frame)
self.server.remove_connection(self.address)
if hasattr(self.server, "interface"):
if hasattr(self.server.interface, "on_disconnect"):
self.server.interface.on_disconnect(self)
if hasattr(self.server.interface, "on_disconnect"):
self.server.interface.on_disconnect(self)
11 changes: 6 additions & 5 deletions pieraknet/handlers/connection_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
from pieraknet.packets.connection_request import ConnectionRequest
from pieraknet.packets.connection_request_accepted import ConnectionRequestAccepted


class ConnectionRequestHandler:
@staticmethod
def handle(packet: bytes, server, connection):
packet = ConnectionRequest(packet)
packet.decode()
request_packet = ConnectionRequest(packet)
request_packet.decode()

new_packet = ConnectionRequestAccepted()
new_packet.client_address = connection.address
new_packet.system_index = 0
new_packet.internal_ids = [('255.255.255.255', 19132)] * 10
new_packet.request_time = packet.client_timestamp
new_packet.request_time = request_packet.client_timestamp
new_packet.accepted_time = int(time.time())
new_packet.encode()
return new_packet.getvalue()

return new_packet.getvalue()
Loading

0 comments on commit 4ebc9cb

Please sign in to comment.