From 01a2d978ebc24aa517048f941f53213db9de722e Mon Sep 17 00:00:00 2001 From: Liam Date: Thu, 25 Jan 2024 00:50:03 -0500 Subject: service: add template serializer for method calls --- src/core/hle/service/cmif_serialization.h | 337 ++++++++++++++++++++++++++++++ 1 file changed, 337 insertions(+) create mode 100644 src/core/hle/service/cmif_serialization.h (limited to 'src/core/hle/service/cmif_serialization.h') diff --git a/src/core/hle/service/cmif_serialization.h b/src/core/hle/service/cmif_serialization.h new file mode 100644 index 000000000..8e8cf2507 --- /dev/null +++ b/src/core/hle/service/cmif_serialization.h @@ -0,0 +1,337 @@ +// SPDX-FileCopyrightText: Copyright 2024 yuzu Emulator Project +// SPDX-License-Identifier: GPL-2.0-or-later + +#pragma once + +#include "common/div_ceil.h" + +#include "core/hle/service/cmif_types.h" +#include "core/hle/service/ipc_helpers.h" +#include "core/hle/service/service.h" + +namespace Service { + +// clang-format off +struct RequestLayout { + u32 copy_handle_count; + u32 move_handle_count; + u32 cmif_raw_data_size; + u32 domain_interface_count; +}; + +template +constexpr u32 GetArgumentRawDataSize() { + if constexpr (ArgIndex >= std::tuple_size_v) { + return static_cast(DataOffset); + } else { + using ArgType = std::tuple_element_t; + + if constexpr (ArgumentTraits::Type == Type1 || ArgumentTraits::Type == Type2) { + constexpr size_t ArgAlign = alignof(ArgType); + constexpr size_t ArgSize = sizeof(ArgType); + + static_assert(PrevAlign <= ArgAlign, "Input argument is not ordered by alignment"); + + constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign); + constexpr size_t ArgEnd = ArgOffset + ArgSize; + + return GetArgumentRawDataSize(); + } else { + return GetArgumentRawDataSize(); + } + } +} + +template +constexpr u32 GetArgumentTypeCount() { + if constexpr (ArgIndex >= std::tuple_size_v) { + return static_cast(ArgCount); + } else { + using ArgType = std::tuple_element_t; + + if constexpr (ArgumentTraits::Type == DataType) { + return GetArgumentTypeCount(); + } else { + return GetArgumentTypeCount(); + } + } +} + +template +constexpr RequestLayout GetNonDomainReplyInLayout() { + return RequestLayout{ + .copy_handle_count = GetArgumentTypeCount(), + .move_handle_count = 0, + .cmif_raw_data_size = GetArgumentRawDataSize(), + .domain_interface_count = 0, + }; +} + +template +constexpr RequestLayout GetDomainReplyInLayout() { + return RequestLayout{ + .copy_handle_count = GetArgumentTypeCount(), + .move_handle_count = 0, + .cmif_raw_data_size = GetArgumentRawDataSize(), + .domain_interface_count = GetArgumentTypeCount(), + }; +} + +template +constexpr RequestLayout GetNonDomainReplyOutLayout() { + return RequestLayout{ + .copy_handle_count = GetArgumentTypeCount(), + .move_handle_count = GetArgumentTypeCount() + GetArgumentTypeCount(), + .cmif_raw_data_size = GetArgumentRawDataSize(), + .domain_interface_count = 0, + }; +} + +template +constexpr RequestLayout GetDomainReplyOutLayout() { + return RequestLayout{ + .copy_handle_count = GetArgumentTypeCount(), + .move_handle_count = GetArgumentTypeCount(), + .cmif_raw_data_size = GetArgumentRawDataSize(), + .domain_interface_count = GetArgumentTypeCount(), + }; +} + +template +constexpr RequestLayout GetReplyInLayout() { + return Domain ? GetDomainReplyInLayout() : GetNonDomainReplyInLayout(); +} + +template +constexpr RequestLayout GetReplyOutLayout() { + return Domain ? GetDomainReplyOutLayout() : GetNonDomainReplyOutLayout(); +} + +using OutTemporaryBuffers = std::array, 3>; + +template +void ReadInArgument(CallArguments& args, const u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) { + if constexpr (ArgIndex >= std::tuple_size_v) { + return; + } else { + using ArgType = std::tuple_element_t; + + if constexpr (ArgumentTraits::Type == ArgumentType::InData || ArgumentTraits::Type == ArgumentType::InProcessId) { + constexpr size_t ArgAlign = alignof(ArgType); + constexpr size_t ArgSize = sizeof(ArgType); + + static_assert(PrevAlign <= ArgAlign, "Input argument is not ordered by alignment"); + static_assert(!RawDataFinished, "All input interface arguments must appear after raw data"); + + constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign); + constexpr size_t ArgEnd = ArgOffset + ArgSize; + + if constexpr (ArgumentTraits::Type == ArgumentType::InProcessId) { + // TODO: abort parsing if PID is not provided? + // TODO: validate against raw data value? + std::get(args).pid = ctx.GetPID(); + } else { + std::memcpy(&std::get(args), raw_data + ArgOffset, ArgSize); + } + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::InInterface) { + constexpr size_t ArgAlign = alignof(u32); + constexpr size_t ArgSize = sizeof(u32); + constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign); + constexpr size_t ArgEnd = ArgOffset + ArgSize; + + static_assert(Domain); + ASSERT(ctx.GetDomainMessageHeader().input_object_count > 0); + + u32 value{}; + std::memcpy(&value, raw_data + ArgOffset, ArgSize); + std::get(args) = ctx.GetDomainHandler(value - 1); + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::InCopyHandle) { + std::get(args) = std::move(ctx.GetObjectFromHandle(ctx.GetCopyHandle(HandleIndex))); + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::InLargeData) { + constexpr size_t BufferSize = sizeof(ArgType); + + // Clear the existing data. + std::memset(&std::get(args), 0, BufferSize); + + std::span buffer{}; + + ASSERT(ctx.CanReadBuffer(InBufferIndex)); + if constexpr (ArgType::Attr & BufferAttr_HipcAutoSelect) { + buffer = ctx.ReadBuffer(InBufferIndex); + } else if constexpr (ArgType::Attr & BufferAttr_HipcMapAlias) { + buffer = ctx.ReadBufferA(InBufferIndex); + } else /* if (ArgType::Attr & BufferAttr_HipcPointer) */ { + buffer = ctx.ReadBufferX(InBufferIndex); + } + + std::memcpy(&std::get(args), buffer.data(), std::min(BufferSize, buffer.size())); + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::InBuffer) { + using ElementType = typename ArgType::Type; + + std::span buffer{}; + + if (ctx.CanReadBuffer(InBufferIndex)) { + if constexpr (ArgType::Attr & BufferAttr_HipcAutoSelect) { + buffer = ctx.ReadBuffer(InBufferIndex); + } else if constexpr (ArgType::Attr & BufferAttr_HipcMapAlias) { + buffer = ctx.ReadBufferA(InBufferIndex); + } else /* if (ArgType::Attr & BufferAttr_HipcPointer) */ { + buffer = ctx.ReadBufferX(InBufferIndex); + } + } + + ElementType* ptr = (ElementType*) buffer.data(); + size_t size = buffer.size() / sizeof(ElementType); + + std::get(args) = std::span(ptr, size); + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutLargeData) { + constexpr size_t BufferSize = sizeof(ArgType); + + // Clear the existing data. + std::memset(&std::get(args), 0, BufferSize); + + return ReadInArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutBuffer) { + using ElementType = typename ArgType::Type; + + // Set up scratch buffer. + auto& buffer = temp[OutBufferIndex]; + if (ctx.CanWriteBuffer(OutBufferIndex)) { + buffer.resize_destructive(ctx.GetWriteBufferSize(OutBufferIndex)); + } else { + buffer.resize_destructive(0); + } + + ElementType* ptr = (ElementType*) buffer.data(); + size_t size = buffer.size() / sizeof(ElementType); + + std::get(args) = std::span(ptr, size); + + return ReadInArgument(args, raw_data, ctx, temp); + } else { + return ReadInArgument(args, raw_data, ctx, temp); + } + } +} + +template +void WriteOutArgument(CallArguments& args, u8* raw_data, HLERequestContext& ctx, OutTemporaryBuffers& temp) { + if constexpr (ArgIndex >= std::tuple_size_v) { + return; + } else { + using ArgType = std::tuple_element_t; + + if constexpr (ArgumentTraits::Type == ArgumentType::OutData) { + constexpr size_t ArgAlign = alignof(ArgType); + constexpr size_t ArgSize = sizeof(ArgType); + + static_assert(PrevAlign <= ArgAlign, "Output argument is not ordered by alignment"); + static_assert(!RawDataFinished, "All output interface arguments must appear after raw data"); + + constexpr size_t ArgOffset = Common::AlignUp(DataOffset, ArgAlign); + constexpr size_t ArgEnd = ArgOffset + ArgSize; + + std::memcpy(raw_data + ArgOffset, &std::get(args), ArgSize); + + return WriteOutArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutInterface) { + if constexpr (Domain) { + ctx.AddDomainObject(std::get(args)); + } else { + ctx.AddMoveInterface(std::get(args)); + } + + return WriteOutArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutCopyHandle) { + ctx.AddCopyObject(std::get(args).GetPointerUnsafe()); + + return WriteOutArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutMoveHandle) { + ctx.AddMoveObject(std::get(args).GetPointerUnsafe()); + + return WriteOutArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutLargeData) { + constexpr size_t BufferSize = sizeof(ArgType); + + ASSERT(ctx.CanWriteBuffer(OutBufferIndex)); + if constexpr (ArgType::Attr & BufferAttr_HipcAutoSelect) { + ctx.WriteBuffer(std::get(args), OutBufferIndex); + } else if constexpr (ArgType::Attr & BufferAttr_HipcMapAlias) { + ctx.WriteBufferB(&std::get(args), BufferSize, OutBufferIndex); + } else /* if (ArgType::Attr & BufferAttr_HipcPointer) */ { + ctx.WriteBufferC(&std::get(args), BufferSize, OutBufferIndex); + } + + return WriteOutArgument(args, raw_data, ctx, temp); + } else if constexpr (ArgumentTraits::Type == ArgumentType::OutBuffer) { + auto& buffer = temp[OutBufferIndex]; + const size_t size = buffer.size(); + + if (ctx.CanWriteBuffer(OutBufferIndex)) { + if constexpr (ArgType::Attr & BufferAttr_HipcAutoSelect) { + ctx.WriteBuffer(buffer.data(), size, OutBufferIndex); + } else if constexpr (ArgType::Attr & BufferAttr_HipcMapAlias) { + ctx.WriteBufferB(buffer.data(), size, OutBufferIndex); + } else /* if (ArgType::Attr & BufferAttr_HipcPointer) */ { + ctx.WriteBufferC(buffer.data(), size, OutBufferIndex); + } + } + + return WriteOutArgument( args, raw_data, ctx, temp); + } else { + return WriteOutArgument(args, raw_data, ctx, temp); + } + } +} + +template +void CmifReplyWrapImpl(HLERequestContext& ctx, T& t, Result (T::*f)(A...)) { + // Verify domain state. + if constexpr (Domain) { + ASSERT_MSG(ctx.GetManager()->IsDomain(), "Domain reply used on non-domain session"); + } else { + ASSERT_MSG(!ctx.GetManager()->IsDomain(), "Non-domain reply used on domain session"); + } + + using MethodArguments = std::tuple...>; + + OutTemporaryBuffers buffers{}; + auto call_arguments = std::tuple::Type...>(); + + // Read inputs. + const size_t offset_plus_command_id = ctx.GetDataPayloadOffset() + 2; + ReadInArgument(call_arguments, reinterpret_cast(ctx.CommandBuffer() + offset_plus_command_id), ctx, buffers); + + // Call. + const auto Callable = [&](CallArgs&... args) { + return (t.*f)(args...); + }; + const Result res = std::apply(Callable, call_arguments); + + // Write result. + constexpr RequestLayout layout = GetReplyOutLayout(); + IPC::ResponseBuilder rb{ctx, 2 + Common::DivCeil(layout.cmif_raw_data_size, sizeof(u32)), layout.copy_handle_count, layout.move_handle_count + layout.domain_interface_count}; + rb.Push(res); + + // Write out arguments. + WriteOutArgument(call_arguments, reinterpret_cast(ctx.CommandBuffer() + rb.GetCurrentOffset()), ctx, buffers); +} +// clang-format on + +template +template +inline void ServiceFramework::CmifReplyWrap(HLERequestContext& ctx) { + return CmifReplyWrapImpl(ctx, *static_cast(this), F); +} + +} // namespace Service -- cgit v1.2.3