From bc1b70dcde4cf22ef52fe794125c533afcee2655 Mon Sep 17 00:00:00 2001 From: Tiger Wang Date: Sat, 18 Jul 2020 19:22:28 +0100 Subject: Use cMultiVersionProtocol's buffer --- src/Protocol/Protocol.h | 2 +- src/Protocol/ProtocolRecognizer.cpp | 4 +-- src/Protocol/Protocol_1_8.cpp | 57 ++++++++++++++++++------------------- src/Protocol/Protocol_1_8.h | 7 ++--- 4 files changed, 33 insertions(+), 37 deletions(-) (limited to 'src/Protocol') diff --git a/src/Protocol/Protocol.h b/src/Protocol/Protocol.h index 96e837bb0..c2b3cd3f0 100644 --- a/src/Protocol/Protocol.h +++ b/src/Protocol/Protocol.h @@ -343,7 +343,7 @@ public: }; /** Called when client sends some data */ - virtual void DataReceived(const char * a_Data, size_t a_Size) = 0; + virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) = 0; // Sending stuff to clients (alphabetically sorted): virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) = 0; diff --git a/src/Protocol/ProtocolRecognizer.cpp b/src/Protocol/ProtocolRecognizer.cpp index 6dc1c8bd8..46c2d239f 100644 --- a/src/Protocol/ProtocolRecognizer.cpp +++ b/src/Protocol/ProtocolRecognizer.cpp @@ -50,7 +50,7 @@ struct sTriedToJoinWithUnsupportedProtocolException : public std::runtime_error cMultiVersionProtocol::cMultiVersionProtocol() : HandleIncomingData(std::bind(&cMultiVersionProtocol::HandleIncomingDataInRecognitionStage, this, std::placeholders::_1, std::placeholders::_2)), - m_Buffer(8 KiB) // We need a larger buffer to support BungeeCord - it sends one huge packet at the start + m_Buffer(32 KiB) { } @@ -107,7 +107,7 @@ void cMultiVersionProtocol::HandleIncomingDataInRecognitionStage(cClientHandle & HandleIncomingData = [this](cClientHandle &, const std::string_view a_In) { // TODO: make it take our a_ReceivedData - m_Protocol->DataReceived(a_In.data(), a_In.size()); + m_Protocol->DataReceived(m_Buffer, a_In.data(), a_In.size()); }; } catch (const sUnsupportedButPingableProtocolException &) diff --git a/src/Protocol/Protocol_1_8.cpp b/src/Protocol/Protocol_1_8.cpp index 1c2792461..6cfe370cd 100644 --- a/src/Protocol/Protocol_1_8.cpp +++ b/src/Protocol/Protocol_1_8.cpp @@ -107,7 +107,6 @@ cProtocol_1_8_0::cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_Ser m_ServerAddress(a_ServerAddress), m_ServerPort(a_ServerPort), m_State(a_State), - m_ReceivedData(32 KiB), m_IsEncrypted(false) { AStringVector Params; @@ -183,7 +182,7 @@ cProtocol_1_8_0::cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_Ser -void cProtocol_1_8_0::DataReceived(const char * a_Data, size_t a_Size) +void cProtocol_1_8_0::DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) { if (m_IsEncrypted) { @@ -192,14 +191,14 @@ void cProtocol_1_8_0::DataReceived(const char * a_Data, size_t a_Size) { size_t NumBytes = (a_Size > sizeof(Decrypted)) ? sizeof(Decrypted) : a_Size; m_Decryptor.ProcessData(Decrypted, reinterpret_cast(a_Data), NumBytes); - AddReceivedData(reinterpret_cast(Decrypted), NumBytes); + AddReceivedData(a_Buffer, reinterpret_cast(Decrypted), NumBytes); a_Size -= NumBytes; a_Data += NumBytes; } } else { - AddReceivedData(a_Data, a_Size); + AddReceivedData(a_Buffer, a_Data, a_Size); } } @@ -1873,19 +1872,19 @@ UInt32 cProtocol_1_8_0::GetProtocolMobType(eMonsterType a_MobType) -void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) +void cProtocol_1_8_0::AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) { // Write the incoming data into the comm log file: if (g_ShouldLogCommIn && m_CommLogFile.IsOpen()) { - if (m_ReceivedData.GetReadableSpace() > 0) + if (a_Buffer.GetReadableSpace() > 0) { AString AllData; - size_t OldReadableSpace = m_ReceivedData.GetReadableSpace(); - m_ReceivedData.ReadAll(AllData); - m_ReceivedData.ResetRead(); - m_ReceivedData.SkipRead(m_ReceivedData.GetReadableSpace() - OldReadableSpace); - ASSERT(m_ReceivedData.GetReadableSpace() == OldReadableSpace); + size_t OldReadableSpace = a_Buffer.GetReadableSpace(); + a_Buffer.ReadAll(AllData); + a_Buffer.ResetRead(); + a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); + ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); AString Hex; CreateHexDump(Hex, AllData.data(), AllData.size(), 16); m_CommLogFile.Printf("Incoming data, %zu (0x%zx) unparsed bytes already present in buffer:\n%s\n", @@ -1900,7 +1899,7 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) m_CommLogFile.Flush(); } - if (!m_ReceivedData.Write(a_Data, a_Size)) + if (!a_Buffer.Write(a_Data, a_Size)) { // Too much data in the incoming queue, report to caller: m_Client->PacketBufferFull(); @@ -1911,16 +1910,16 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) for (;;) { UInt32 PacketLen; - if (!m_ReceivedData.ReadVarInt(PacketLen)) + if (!a_Buffer.ReadVarInt(PacketLen)) { // Not enough data - m_ReceivedData.ResetRead(); + a_Buffer.ResetRead(); break; } - if (!m_ReceivedData.CanReadBytes(PacketLen)) + if (!a_Buffer.CanReadBytes(PacketLen)) { // The full packet hasn't been received yet - m_ReceivedData.ResetRead(); + a_Buffer.ResetRead(); break; } @@ -1929,15 +1928,15 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) AString UncompressedData; if (m_State == 3) { - UInt32 NumBytesRead = static_cast(m_ReceivedData.GetReadableSpace()); + UInt32 NumBytesRead = static_cast(a_Buffer.GetReadableSpace()); - if (!m_ReceivedData.ReadVarInt(UncompressedSize)) + if (!a_Buffer.ReadVarInt(UncompressedSize)) { m_Client->Kick("Compression packet incomplete"); return; } - NumBytesRead -= static_cast(m_ReceivedData.GetReadableSpace()); // How many bytes has the UncompressedSize taken up? + NumBytesRead -= static_cast(a_Buffer.GetReadableSpace()); // How many bytes has the UncompressedSize taken up? ASSERT(PacketLen > NumBytesRead); PacketLen -= NumBytesRead; @@ -1945,7 +1944,7 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) { // Decompress the data: AString CompressedData; - VERIFY(m_ReceivedData.ReadString(CompressedData, PacketLen)); + VERIFY(a_Buffer.ReadString(CompressedData, PacketLen)); if (InflateString(CompressedData.data(), PacketLen, UncompressedData) != Z_OK) { m_Client->Kick("Compression failure"); @@ -1965,14 +1964,14 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) if (UncompressedSize == 0) { // No compression was used, move directly - VERIFY(m_ReceivedData.ReadToByteBuffer(bb, static_cast(PacketLen))); + VERIFY(a_Buffer.ReadToByteBuffer(bb, static_cast(PacketLen))); } else { // Compression was used, move the uncompressed data: VERIFY(bb.Write(UncompressedData.data(), UncompressedData.size())); } - m_ReceivedData.CommitRead(); + a_Buffer.CommitRead(); UInt32 PacketType; if (!bb.ReadVarInt(PacketType)) @@ -2048,18 +2047,18 @@ void cProtocol_1_8_0::AddReceivedData(const char * a_Data, size_t a_Size) } // for (ever) // Log any leftover bytes into the logfile: - if (g_ShouldLogCommIn && (m_ReceivedData.GetReadableSpace() > 0) && m_CommLogFile.IsOpen()) + if (g_ShouldLogCommIn && (a_Buffer.GetReadableSpace() > 0) && m_CommLogFile.IsOpen()) { AString AllData; - size_t OldReadableSpace = m_ReceivedData.GetReadableSpace(); - m_ReceivedData.ReadAll(AllData); - m_ReceivedData.ResetRead(); - m_ReceivedData.SkipRead(m_ReceivedData.GetReadableSpace() - OldReadableSpace); - ASSERT(m_ReceivedData.GetReadableSpace() == OldReadableSpace); + size_t OldReadableSpace = a_Buffer.GetReadableSpace(); + a_Buffer.ReadAll(AllData); + a_Buffer.ResetRead(); + a_Buffer.SkipRead(a_Buffer.GetReadableSpace() - OldReadableSpace); + ASSERT(a_Buffer.GetReadableSpace() == OldReadableSpace); AString Hex; CreateHexDump(Hex, AllData.data(), AllData.size(), 16); m_CommLogFile.Printf("There are %zu (0x%zx) bytes of non-parse-able data left in the buffer:\n%s", - m_ReceivedData.GetReadableSpace(), m_ReceivedData.GetReadableSpace(), Hex.c_str() + a_Buffer.GetReadableSpace(), a_Buffer.GetReadableSpace(), Hex.c_str() ); m_CommLogFile.Flush(); } diff --git a/src/Protocol/Protocol_1_8.h b/src/Protocol/Protocol_1_8.h index b62e3129f..e7686577f 100644 --- a/src/Protocol/Protocol_1_8.h +++ b/src/Protocol/Protocol_1_8.h @@ -34,7 +34,7 @@ public: cProtocol_1_8_0(cClientHandle * a_Client, const AString & a_ServerAddress, UInt16 a_ServerPort, UInt32 a_State); /** Called when client sends some data: */ - virtual void DataReceived(const char * a_Data, size_t a_Size) override; + virtual void DataReceived(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size) override; /** Sending stuff to clients (alphabetically sorted): */ virtual void SendAttachEntity (const cEntity & a_Entity, const cEntity & a_Vehicle) override; @@ -143,9 +143,6 @@ protected: /** State of the protocol. 1 = status, 2 = login, 3 = game */ UInt32 m_State; - /** Buffer for the received data */ - cByteBuffer m_ReceivedData; - bool m_IsEncrypted; cAesCfb128Decryptor m_Decryptor; @@ -155,7 +152,7 @@ protected: cFile m_CommLogFile; /** Adds the received (unencrypted) data to m_ReceivedData, parses complete packets */ - virtual void AddReceivedData(const char * a_Data, size_t a_Size); + virtual void AddReceivedData(cByteBuffer & a_Buffer, const char * a_Data, size_t a_Size); /** Nobody inherits 1.8, so it doesn't use this method */ virtual UInt32 GetPacketID(ePacketType a_Packet) override; -- cgit v1.2.3