diff options
Diffstat (limited to '')
-rw-r--r-- | src/Network.cpp | 126 |
1 files changed, 102 insertions, 24 deletions
diff --git a/src/Network.cpp b/src/Network.cpp index cf4fe15..9cb2097 100644 --- a/src/Network.cpp +++ b/src/Network.cpp @@ -1,5 +1,7 @@ #include "Network.hpp" +#include <zlib.h> + Network::Network(std::string address, unsigned short port) { try { socket = new Socket(address, port); @@ -13,6 +15,8 @@ Network::Network(std::string address, unsigned short port) { } catch (std::exception &e) { LOG(WARNING) << "Stream creation failed: " << e.what(); } + + } Network::~Network() { @@ -20,40 +24,114 @@ Network::~Network() { delete socket; } -std::shared_ptr<Packet> Network::ReceivePacket(ConnectionState state) { - int packetSize = stream->ReadVarInt(); - auto packetData = stream->ReadByteArray(packetSize); - StreamBuffer streamBuffer(packetData.data(), packetData.size()); - int packetId = streamBuffer.ReadVarInt(); - auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer); - return packet; +std::shared_ptr<Packet> Network::ReceivePacket(ConnectionState state, bool useCompression) { + if (useCompression) { + int packetLength = stream->ReadVarInt(); + auto packetData = stream->ReadByteArray(packetLength); + StreamBuffer streamBuffer(packetData.data(), packetData.size()); + + int dataLength = streamBuffer.ReadVarInt(); + if (dataLength == 0) { + auto packetData = streamBuffer.ReadByteArray(packetLength - streamBuffer.GetReadedLength()); + StreamBuffer streamBuffer(packetData.data(), packetData.size()); + int packetId = streamBuffer.ReadVarInt(); + auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer); + return packet; + } else { + std::vector<unsigned char> compressedData = streamBuffer.ReadByteArray(packetLength - streamBuffer.GetReadedLength()); + std::vector<unsigned char> uncompressedData; + uncompressedData.resize(dataLength); + + z_stream stream; + stream.avail_in = compressedData.size(); + stream.next_in = compressedData.data(); + stream.avail_out = uncompressedData.size(); + stream.next_out = uncompressedData.data(); + stream.zalloc = Z_NULL; + stream.zfree = Z_NULL; + stream.opaque = Z_NULL; + if (inflateInit(&stream) != Z_OK) + throw std::runtime_error("Zlib decompression initalization error"); + + int status = inflate(&stream, Z_FINISH); + switch (status) { + case Z_STREAM_END: + break; + case Z_OK: + case Z_STREAM_ERROR: + case Z_BUF_ERROR: + throw std::runtime_error("Zlib decompression error: " + std::to_string(status)); + } + + if (inflateEnd(&stream) != Z_OK) + throw std::runtime_error("Zlib decompression end error"); + + StreamBuffer streamBuffer(uncompressedData.data(), uncompressedData.size()); + int packetId = streamBuffer.ReadVarInt(); + auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer); + return packet; + } + } else { + int packetSize = stream->ReadVarInt(); + auto packetData = stream->ReadByteArray(packetSize); + StreamBuffer streamBuffer(packetData.data(), packetData.size()); + int packetId = streamBuffer.ReadVarInt(); + auto packet = ReceivePacketByPacketId(packetId, state, streamBuffer); + return packet; + } } -void Network::SendPacket(Packet &packet) { - StreamCounter packetSize; - packetSize.WriteVarInt(packet.GetPacketId()); - packet.ToStream(&packetSize); - stream->WriteVarInt(packetSize.GetCountedSize()); - stream->WriteVarInt(packet.GetPacketId()); - packet.ToStream(stream); +void Network::SendPacket(Packet &packet, int compressionThreshold) { + if (compressionThreshold >= 0) { + StreamCounter packetSize; + packetSize.WriteVarInt(packet.GetPacketId()); + packetSize.WriteVarInt(0); + packet.ToStream(&packetSize); + if (packetSize.GetCountedSize() < compressionThreshold) { + stream->WriteVarInt(packetSize.GetCountedSize()); + stream->WriteVarInt(0); + stream->WriteVarInt(packet.GetPacketId()); + packet.ToStream(stream); + } else { + throw std::runtime_error("Compressing data"); + /*StreamBuffer buffer(packetSize.GetCountedSize()); + packet.ToStream(&buffer); + + z_stream stream;*/ + } + } + else { + StreamCounter packetSize; + packetSize.WriteVarInt(packet.GetPacketId()); + packet.ToStream(&packetSize); + stream->WriteVarInt(packetSize.GetCountedSize()); + stream->WriteVarInt(packet.GetPacketId()); + packet.ToStream(stream); + } } std::shared_ptr<Packet> Network::ReceivePacketByPacketId(int packetId, ConnectionState state, StreamInput &stream) { std::shared_ptr < Packet > packet(nullptr); switch (state) { case Handshaking: - switch (packetId) { - case PacketNameHandshakingCB::Handshake: - packet = std::make_shared<PacketHandshake>(); - break; - } + switch (packetId) { + case PacketNameHandshakingCB::Handshake: + packet = std::make_shared<PacketHandshake>(); + break; + } break; case Login: - switch (packetId) { - case PacketNameLoginCB::LoginSuccess: - packet = std::make_shared<PacketLoginSuccess>(); - break; - } + switch (packetId) { + case PacketNameLoginCB::LoginSuccess: + packet = std::make_shared<PacketLoginSuccess>(); + break; + case PacketNameLoginCB::SetCompression: + packet = std::make_shared<PacketSetCompression>(); + break; + case PacketNameLoginCB::Disconnect: + packet = std::make_shared<PacketDisconnect>(); + break; + } break; case Play: packet = ParsePacketPlay((PacketNamePlayCB) packetId); |