From ca75c58f4391fcd20d325333c802f9e225ff9c47 Mon Sep 17 00:00:00 2001 From: Liam Date: Wed, 25 Oct 2023 12:59:11 -0400 Subject: sockets: use safe access helpers --- src/core/hle/service/sockets/bsd.cpp | 77 +++++++++++++++++------------------- src/core/hle/service/sockets/bsd.h | 2 +- 2 files changed, 38 insertions(+), 41 deletions(-) (limited to 'src') diff --git a/src/core/hle/service/sockets/bsd.cpp b/src/core/hle/service/sockets/bsd.cpp index 85849d5f3..dd652ca42 100644 --- a/src/core/hle/service/sockets/bsd.cpp +++ b/src/core/hle/service/sockets/bsd.cpp @@ -39,6 +39,18 @@ bool IsConnectionBased(Type type) { } } +template +T GetValue(std::span buffer) { + T t{}; + std::memcpy(&t, buffer.data(), std::min(sizeof(T), buffer.size())); + return t; +} + +template +void PutValue(std::span buffer, const T& t) { + std::memcpy(buffer.data(), &t, std::min(sizeof(T), buffer.size())); +} + } // Anonymous namespace void BSD::PollWork::Execute(BSD* bsd) { @@ -316,22 +328,12 @@ void BSD::SetSockOpt(HLERequestContext& ctx) { const s32 fd = rp.Pop(); const u32 level = rp.Pop(); const OptName optname = static_cast(rp.Pop()); - - const auto buffer = ctx.ReadBuffer(); - const u8* optval = buffer.empty() ? nullptr : buffer.data(); - size_t optlen = buffer.size(); - - std::array values; - if ((optname == OptName::SNDTIMEO || optname == OptName::RCVTIMEO) && buffer.size() == 8) { - std::memcpy(values.data(), buffer.data(), sizeof(values)); - optlen = sizeof(values); - optval = reinterpret_cast(values.data()); - } + const auto optval = ctx.ReadBuffer(); LOG_DEBUG(Service, "called. fd={} level={} optname=0x{:x} optlen={}", fd, level, - static_cast(optname), optlen); + static_cast(optname), optval.size()); - BuildErrnoResponse(ctx, SetSockOptImpl(fd, level, optname, optlen, optval)); + BuildErrnoResponse(ctx, SetSockOptImpl(fd, level, optname, optval)); } void BSD::Shutdown(HLERequestContext& ctx) { @@ -521,18 +523,19 @@ std::pair BSD::SocketImpl(Domain domain, Type type, Protocol protoco std::pair BSD::PollImpl(std::vector& write_buffer, std::span read_buffer, s32 nfds, s32 timeout) { - if (write_buffer.size() < nfds * sizeof(PollFD)) { - return {-1, Errno::INVAL}; - } - - if (nfds == 0) { + if (nfds <= 0) { // When no entries are provided, -1 is returned with errno zero return {-1, Errno::SUCCESS}; } + if (read_buffer.size() < nfds * sizeof(PollFD)) { + return {-1, Errno::INVAL}; + } + if (write_buffer.size() < nfds * sizeof(PollFD)) { + return {-1, Errno::INVAL}; + } - const size_t length = std::min(read_buffer.size(), write_buffer.size()); std::vector fds(nfds); - std::memcpy(fds.data(), read_buffer.data(), length); + std::memcpy(fds.data(), read_buffer.data(), nfds * sizeof(PollFD)); if (timeout >= 0) { const s64 seconds = timeout / 1000; @@ -580,7 +583,7 @@ std::pair BSD::PollImpl(std::vector& write_buffer, std::span BSD::AcceptImpl(s32 fd, std::vector& write_buffer) { new_descriptor.is_connection_based = descriptor.is_connection_based; const SockAddrIn guest_addr_in = Translate(result.sockaddr_in); - const size_t length = std::min(sizeof(guest_addr_in), write_buffer.size()); - std::memcpy(write_buffer.data(), &guest_addr_in, length); + PutValue(write_buffer, guest_addr_in); return {new_fd, Errno::SUCCESS}; } @@ -619,8 +621,7 @@ Errno BSD::BindImpl(s32 fd, std::span addr) { return Errno::BADF; } ASSERT(addr.size() == sizeof(SockAddrIn)); - SockAddrIn addr_in; - std::memcpy(&addr_in, addr.data(), sizeof(addr_in)); + auto addr_in = GetValue(addr); return Translate(file_descriptors[fd]->socket->Bind(Translate(addr_in))); } @@ -631,8 +632,7 @@ Errno BSD::ConnectImpl(s32 fd, std::span addr) { } UNIMPLEMENTED_IF(addr.size() != sizeof(SockAddrIn)); - SockAddrIn addr_in; - std::memcpy(&addr_in, addr.data(), sizeof(addr_in)); + auto addr_in = GetValue(addr); return Translate(file_descriptors[fd]->socket->Connect(Translate(addr_in))); } @@ -650,7 +650,7 @@ Errno BSD::GetPeerNameImpl(s32 fd, std::vector& write_buffer) { ASSERT(write_buffer.size() >= sizeof(guest_addrin)); write_buffer.resize(sizeof(guest_addrin)); - std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); + PutValue(write_buffer, guest_addrin); return Translate(bsd_errno); } @@ -667,7 +667,7 @@ Errno BSD::GetSockNameImpl(s32 fd, std::vector& write_buffer) { ASSERT(write_buffer.size() >= sizeof(guest_addrin)); write_buffer.resize(sizeof(guest_addrin)); - std::memcpy(write_buffer.data(), &guest_addrin, sizeof(guest_addrin)); + PutValue(write_buffer, guest_addrin); return Translate(bsd_errno); } @@ -725,7 +725,7 @@ Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector& o optval.size() == sizeof(Errno), { return Errno::INVAL; }, "Incorrect getsockopt option size"); optval.resize(sizeof(Errno)); - memcpy(optval.data(), &translated_pending_err, sizeof(Errno)); + PutValue(optval, translated_pending_err); } return Translate(getsockopt_err); } @@ -735,7 +735,7 @@ Errno BSD::GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector& o } } -Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval) { +Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, std::span optval) { if (!IsFileDescriptorValid(fd)) { return Errno::BADF; } @@ -748,17 +748,15 @@ Errno BSD::SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, con Network::SocketBase* const socket = file_descriptors[fd]->socket.get(); if (optname == OptName::LINGER) { - ASSERT(optlen == sizeof(Linger)); - Linger linger; - std::memcpy(&linger, optval, sizeof(linger)); + ASSERT(optval.size() == sizeof(Linger)); + auto linger = GetValue(optval); ASSERT(linger.onoff == 0 || linger.onoff == 1); return Translate(socket->SetLinger(linger.onoff != 0, linger.linger)); } - ASSERT(optlen == sizeof(u32)); - u32 value; - std::memcpy(&value, optval, sizeof(value)); + ASSERT(optval.size() == sizeof(u32)); + auto value = GetValue(optval); switch (optname) { case OptName::REUSEADDR: @@ -862,7 +860,7 @@ std::pair BSD::RecvFromImpl(s32 fd, u32 flags, std::vector& mess } else { ASSERT(addr.size() == sizeof(SockAddrIn)); const SockAddrIn result = Translate(addr_in); - std::memcpy(addr.data(), &result, sizeof(result)); + PutValue(addr, result); } } @@ -886,8 +884,7 @@ std::pair BSD::SendToImpl(s32 fd, u32 flags, std::span mes Network::SockAddrIn* p_addr_in = nullptr; if (!addr.empty()) { ASSERT(addr.size() == sizeof(SockAddrIn)); - SockAddrIn guest_addr_in; - std::memcpy(&guest_addr_in, addr.data(), sizeof(guest_addr_in)); + auto guest_addr_in = GetValue(addr); addr_in = Translate(guest_addr_in); p_addr_in = &addr_in; } diff --git a/src/core/hle/service/sockets/bsd.h b/src/core/hle/service/sockets/bsd.h index 161f22b9b..4f69d382c 100644 --- a/src/core/hle/service/sockets/bsd.h +++ b/src/core/hle/service/sockets/bsd.h @@ -163,7 +163,7 @@ private: Errno ListenImpl(s32 fd, s32 backlog); std::pair FcntlImpl(s32 fd, FcntlCmd cmd, s32 arg); Errno GetSockOptImpl(s32 fd, u32 level, OptName optname, std::vector& optval); - Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, size_t optlen, const void* optval); + Errno SetSockOptImpl(s32 fd, u32 level, OptName optname, std::span optval); Errno ShutdownImpl(s32 fd, s32 how); std::pair RecvImpl(s32 fd, u32 flags, std::vector& message); std::pair RecvFromImpl(s32 fd, u32 flags, std::vector& message, -- cgit v1.2.3