diff options
author | comex <comexk@gmail.com> | 2023-07-02 00:02:25 +0200 |
---|---|---|
committer | comex <comexk@gmail.com> | 2023-07-02 02:27:35 +0200 |
commit | 0e191c271125321589dfdbb09731413550710c9a (patch) | |
tree | e0d9e826197d2fd66cc9b7fd8d11cc8c7fd6e3ac /src/core/hle/service/ssl | |
parent | Merge remote-tracking branch 'origin/master' into ssl (diff) | |
download | yuzu-0e191c271125321589dfdbb09731413550710c9a.tar yuzu-0e191c271125321589dfdbb09731413550710c9a.tar.gz yuzu-0e191c271125321589dfdbb09731413550710c9a.tar.bz2 yuzu-0e191c271125321589dfdbb09731413550710c9a.tar.lz yuzu-0e191c271125321589dfdbb09731413550710c9a.tar.xz yuzu-0e191c271125321589dfdbb09731413550710c9a.tar.zst yuzu-0e191c271125321589dfdbb09731413550710c9a.zip |
Diffstat (limited to 'src/core/hle/service/ssl')
-rw-r--r-- | src/core/hle/service/ssl/ssl.cpp | 106 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_openssl.cpp | 66 | ||||
-rw-r--r-- | src/core/hle/service/ssl/ssl_backend_schannel.cpp | 207 |
3 files changed, 194 insertions, 185 deletions
diff --git a/src/core/hle/service/ssl/ssl.cpp b/src/core/hle/service/ssl/ssl.cpp index 5638dd693..0919be55f 100644 --- a/src/core/hle/service/ssl/ssl.cpp +++ b/src/core/hle/service/ssl/ssl.cpp @@ -64,7 +64,7 @@ public: std::shared_ptr<SslContextSharedData>& shared_data, std::unique_ptr<SSLConnectionBackend>&& backend) : ServiceFramework{system_, "ISslConnection"}, ssl_version{version}, - shared_data_{shared_data}, backend_{std::move(backend)} { + shared_data{shared_data}, backend{std::move(backend)} { // clang-format off static const FunctionInfo functions[] = { {0, &ISslConnection::SetSocketDescriptor, "SetSocketDescriptor"}, @@ -112,10 +112,10 @@ public: } ~ISslConnection() { - shared_data_->connection_count--; - if (fd_to_close_.has_value()) { - const s32 fd = *fd_to_close_; - if (!do_not_close_socket_) { + shared_data->connection_count--; + if (fd_to_close.has_value()) { + const s32 fd = *fd_to_close; + if (!do_not_close_socket) { LOG_ERROR(Service_SSL, "do_not_close_socket was changed after setting socket; is this right?"); } else { @@ -132,30 +132,30 @@ public: private: SslVersion ssl_version; - std::shared_ptr<SslContextSharedData> shared_data_; - std::unique_ptr<SSLConnectionBackend> backend_; - std::optional<int> fd_to_close_; - bool do_not_close_socket_ = false; - bool get_server_cert_chain_ = false; - std::shared_ptr<Network::SocketBase> socket_; - bool did_set_host_name_ = false; - bool did_handshake_ = false; + std::shared_ptr<SslContextSharedData> shared_data; + std::unique_ptr<SSLConnectionBackend> backend; + std::optional<int> fd_to_close; + bool do_not_close_socket = false; + bool get_server_cert_chain = false; + std::shared_ptr<Network::SocketBase> socket; + bool did_set_host_name = false; + bool did_handshake = false; ResultVal<s32> SetSocketDescriptorImpl(s32 fd) { LOG_DEBUG(Service_SSL, "called, fd={}", fd); - ASSERT(!did_handshake_); + ASSERT(!did_handshake); auto bsd = system.ServiceManager().GetService<Service::Sockets::BSD>("bsd:u"); ASSERT_OR_EXECUTE(bsd, { return ResultInternalError; }); s32 ret_fd; // Based on https://switchbrew.org/wiki/SSL_services#SetSocketDescriptor - if (do_not_close_socket_) { + if (do_not_close_socket) { auto res = bsd->DuplicateSocketImpl(fd); if (!res.has_value()) { LOG_ERROR(Service_SSL, "Failed to duplicate socket with fd {}", fd); return ResultInvalidSocket; } fd = *res; - fd_to_close_ = fd; + fd_to_close = fd; ret_fd = fd; } else { ret_fd = -1; @@ -165,34 +165,34 @@ private: LOG_ERROR(Service_SSL, "invalid socket fd {}", fd); return ResultInvalidSocket; } - socket_ = std::move(*sock); - backend_->SetSocket(socket_); + socket = std::move(*sock); + backend->SetSocket(socket); return ret_fd; } Result SetHostNameImpl(const std::string& hostname) { LOG_DEBUG(Service_SSL, "called. hostname={}", hostname); - ASSERT(!did_handshake_); - Result res = backend_->SetHostName(hostname); + ASSERT(!did_handshake); + Result res = backend->SetHostName(hostname); if (res == ResultSuccess) { - did_set_host_name_ = true; + did_set_host_name = true; } return res; } Result SetVerifyOptionImpl(u32 option) { - ASSERT(!did_handshake_); + ASSERT(!did_handshake); LOG_WARNING(Service_SSL, "(STUBBED) called. option={}", option); return ResultSuccess; } - Result SetIOModeImpl(u32 _mode) { - auto mode = static_cast<IoMode>(_mode); + Result SetIoModeImpl(u32 input_mode) { + auto mode = static_cast<IoMode>(input_mode); ASSERT(mode == IoMode::Blocking || mode == IoMode::NonBlocking); - ASSERT_OR_EXECUTE(socket_, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE(socket, { return ResultNoSocket; }); const bool non_block = mode == IoMode::NonBlocking; - const Network::Errno error = socket_->SetNonBlock(non_block); + const Network::Errno error = socket->SetNonBlock(non_block); if (error != Network::Errno::SUCCESS) { LOG_ERROR(Service_SSL, "Failed to set native socket non-block flag to {}", non_block); } @@ -200,18 +200,18 @@ private: } Result SetSessionCacheModeImpl(u32 mode) { - ASSERT(!did_handshake_); + ASSERT(!did_handshake); LOG_WARNING(Service_SSL, "(STUBBED) called. value={}", mode); return ResultSuccess; } Result DoHandshakeImpl() { - ASSERT_OR_EXECUTE(!did_handshake_ && socket_, { return ResultNoSocket; }); + ASSERT_OR_EXECUTE(!did_handshake && socket, { return ResultNoSocket; }); ASSERT_OR_EXECUTE_MSG( - did_set_host_name_, { return ResultInternalError; }, + did_set_host_name, { return ResultInternalError; }, "Expected SetHostName before DoHandshake"); - Result res = backend_->DoHandshake(); - did_handshake_ = res.IsSuccess(); + Result res = backend->DoHandshake(); + did_handshake = res.IsSuccess(); return res; } @@ -225,7 +225,7 @@ private: u32 size; u32 offset; }; - if (!get_server_cert_chain_) { + if (!get_server_cert_chain) { // Just return the first one, unencoded. ASSERT_OR_EXECUTE_MSG( !certs.empty(), { return {}; }, "Should be at least one server cert"); @@ -248,9 +248,9 @@ private: } ResultVal<std::vector<u8>> ReadImpl(size_t size) { - ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); std::vector<u8> res(size); - ResultVal<size_t> actual = backend_->Read(res); + ResultVal<size_t> actual = backend->Read(res); if (actual.Failed()) { return actual.Code(); } @@ -259,8 +259,8 @@ private: } ResultVal<size_t> WriteImpl(std::span<const u8> data) { - ASSERT_OR_EXECUTE(did_handshake_, { return ResultInternalError; }); - return backend_->Write(data); + ASSERT_OR_EXECUTE(did_handshake, { return ResultInternalError; }); + return backend->Write(data); } ResultVal<s32> PendingImpl() { @@ -295,7 +295,7 @@ private: void SetIoMode(HLERequestContext& ctx) { IPC::RequestParser rp{ctx}; const u32 mode = rp.Pop<u32>(); - const Result res = SetIOModeImpl(mode); + const Result res = SetIoModeImpl(mode); IPC::ResponseBuilder rb{ctx, 2}; rb.Push(res); } @@ -307,22 +307,26 @@ private: } void DoHandshakeGetServerCert(HLERequestContext& ctx) { + struct OutputParameters { + u32 certs_size; + u32 certs_count; + }; + static_assert(sizeof(OutputParameters) == 0x8); + const Result res = DoHandshakeImpl(); - u32 certs_count = 0; - u32 certs_size = 0; + OutputParameters out{}; if (res == ResultSuccess) { - auto certs = backend_->GetServerCerts(); + auto certs = backend->GetServerCerts(); if (certs.Succeeded()) { const std::vector<u8> certs_buf = SerializeServerCerts(*certs); ctx.WriteBuffer(certs_buf); - certs_count = static_cast<u32>(certs->size()); - certs_size = static_cast<u32>(certs_buf.size()); + out.certs_count = static_cast<u32>(certs->size()); + out.certs_size = static_cast<u32>(certs_buf.size()); } } IPC::ResponseBuilder rb{ctx, 4}; rb.Push(res); - rb.Push(certs_size); - rb.Push(certs_count); + rb.PushRaw(out); } void Read(HLERequestContext& ctx) { @@ -371,10 +375,10 @@ private: switch (parameters.option) { case OptionType::DoNotCloseSocket: - do_not_close_socket_ = static_cast<bool>(parameters.value); + do_not_close_socket = static_cast<bool>(parameters.value); break; case OptionType::GetServerCertChain: - get_server_cert_chain_ = static_cast<bool>(parameters.value); + get_server_cert_chain = static_cast<bool>(parameters.value); break; default: LOG_WARNING(Service_SSL, "Unknown option={}, value={}", parameters.option, @@ -390,7 +394,7 @@ class ISslContext final : public ServiceFramework<ISslContext> { public: explicit ISslContext(Core::System& system_, SslVersion version) : ServiceFramework{system_, "ISslContext"}, ssl_version{version}, - shared_data_{std::make_shared<SslContextSharedData>()} { + shared_data{std::make_shared<SslContextSharedData>()} { static const FunctionInfo functions[] = { {0, &ISslContext::SetOption, "SetOption"}, {1, nullptr, "GetOption"}, @@ -412,7 +416,7 @@ public: private: SslVersion ssl_version; - std::shared_ptr<SslContextSharedData> shared_data_; + std::shared_ptr<SslContextSharedData> shared_data; void SetOption(HLERequestContext& ctx) { struct Parameters { @@ -439,17 +443,17 @@ private: IPC::ResponseBuilder rb{ctx, 2, 0, 1}; rb.Push(backend_res.Code()); if (backend_res.Succeeded()) { - rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data_, + rb.PushIpcInterface<ISslConnection>(system, ssl_version, shared_data, std::move(*backend_res)); } } void GetConnectionCount(HLERequestContext& ctx) { - LOG_WARNING(Service_SSL, "connection_count={}", shared_data_->connection_count); + LOG_DEBUG(Service_SSL, "connection_count={}", shared_data->connection_count); IPC::ResponseBuilder rb{ctx, 3}; rb.Push(ResultSuccess); - rb.Push(shared_data_->connection_count); + rb.Push(shared_data->connection_count); } void ImportServerPki(HLERequestContext& ctx) { diff --git a/src/core/hle/service/ssl/ssl_backend_openssl.cpp b/src/core/hle/service/ssl/ssl_backend_openssl.cpp index e7d5801fd..f69674f77 100644 --- a/src/core/hle/service/ssl/ssl_backend_openssl.cpp +++ b/src/core/hle/service/ssl/ssl_backend_openssl.cpp @@ -51,37 +51,37 @@ public: return ResultInternalError; } - ssl_ = SSL_new(ssl_ctx); - if (!ssl_) { + ssl = SSL_new(ssl_ctx); + if (!ssl) { LOG_ERROR(Service_SSL, "SSL_new failed"); return CheckOpenSSLErrors(); } - SSL_set_connect_state(ssl_); + SSL_set_connect_state(ssl); - bio_ = BIO_new(bio_meth); - if (!bio_) { + bio = BIO_new(bio_meth); + if (!bio) { LOG_ERROR(Service_SSL, "BIO_new failed"); return CheckOpenSSLErrors(); } - BIO_set_data(bio_, this); - BIO_set_init(bio_, 1); - SSL_set_bio(ssl_, bio_, bio_); + BIO_set_data(bio, this); + BIO_set_init(bio, 1); + SSL_set_bio(ssl, bio, bio); return ResultSuccess; } - void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { - socket_ = socket; + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); } Result SetHostName(const std::string& hostname) override { - if (!SSL_set1_host(ssl_, hostname.c_str())) { // hostname for verification + if (!SSL_set1_host(ssl, hostname.c_str())) { // hostname for verification LOG_ERROR(Service_SSL, "SSL_set1_host({}) failed", hostname); return CheckOpenSSLErrors(); } - if (!SSL_set_tlsext_host_name(ssl_, hostname.c_str())) { // hostname for SNI + if (!SSL_set_tlsext_host_name(ssl, hostname.c_str())) { // hostname for SNI LOG_ERROR(Service_SSL, "SSL_set_tlsext_host_name({}) failed", hostname); return CheckOpenSSLErrors(); } @@ -89,18 +89,18 @@ public: } Result DoHandshake() override { - SSL_set_verify_result(ssl_, X509_V_OK); - const int ret = SSL_do_handshake(ssl_); - const long verify_result = SSL_get_verify_result(ssl_); + SSL_set_verify_result(ssl, X509_V_OK); + const int ret = SSL_do_handshake(ssl); + const long verify_result = SSL_get_verify_result(ssl); if (verify_result != X509_V_OK) { LOG_ERROR(Service_SSL, "SSL cert verification failed because: {}", X509_verify_cert_error_string(verify_result)); return CheckOpenSSLErrors(); } if (ret <= 0) { - const int ssl_err = SSL_get_error(ssl_, ret); + const int ssl_err = SSL_get_error(ssl, ret); if (ssl_err == SSL_ERROR_ZERO_RETURN || - (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_)) { + (ssl_err == SSL_ERROR_SYSCALL && got_read_eof)) { LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); return ResultInternalError; } @@ -110,18 +110,18 @@ public: ResultVal<size_t> Read(std::span<u8> data) override { size_t actual; - const int ret = SSL_read_ex(ssl_, data.data(), data.size(), &actual); + const int ret = SSL_read_ex(ssl, data.data(), data.size(), &actual); return HandleReturn("SSL_read_ex", actual, ret); } ResultVal<size_t> Write(std::span<const u8> data) override { size_t actual; - const int ret = SSL_write_ex(ssl_, data.data(), data.size(), &actual); + const int ret = SSL_write_ex(ssl, data.data(), data.size(), &actual); return HandleReturn("SSL_write_ex", actual, ret); } ResultVal<size_t> HandleReturn(const char* what, size_t actual, int ret) { - const int ssl_err = SSL_get_error(ssl_, ret); + const int ssl_err = SSL_get_error(ssl, ret); CheckOpenSSLErrors(); switch (ssl_err) { case SSL_ERROR_NONE: @@ -137,7 +137,7 @@ public: LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_WANT_WRITE", what); return ResultWouldBlock; default: - if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof_) { + if (ssl_err == SSL_ERROR_SYSCALL && got_read_eof) { LOG_DEBUG(Service_SSL, "{} => SSL_ERROR_SYSCALL because server hung up", what); return size_t(0); } @@ -147,7 +147,7 @@ public: } ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { - STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl_); + STACK_OF(X509)* chain = SSL_get_peer_cert_chain(ssl); if (!chain) { LOG_ERROR(Service_SSL, "SSL_get_peer_cert_chain returned nullptr"); return ResultInternalError; @@ -169,8 +169,8 @@ public: ~SSLConnectionBackendOpenSSL() { // these are null-tolerant: - SSL_free(ssl_); - BIO_free(bio_); + SSL_free(ssl); + BIO_free(bio); } static void KeyLogCallback(const SSL* ssl, const char* line) { @@ -188,9 +188,9 @@ public: static int WriteCallback(BIO* bio, const char* buf, size_t len, size_t* actual_p) { auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); ASSERT_OR_EXECUTE_MSG( - self->socket_, { return 0; }, "OpenSSL asked to send but we have no socket"); + self->socket, { return 0; }, "OpenSSL asked to send but we have no socket"); BIO_clear_retry_flags(bio); - auto [actual, err] = self->socket_->Send({reinterpret_cast<const u8*>(buf), len}, 0); + auto [actual, err] = self->socket->Send({reinterpret_cast<const u8*>(buf), len}, 0); switch (err) { case Network::Errno::SUCCESS: *actual_p = actual; @@ -207,14 +207,14 @@ public: static int ReadCallback(BIO* bio, char* buf, size_t len, size_t* actual_p) { auto self = static_cast<SSLConnectionBackendOpenSSL*>(BIO_get_data(bio)); ASSERT_OR_EXECUTE_MSG( - self->socket_, { return 0; }, "OpenSSL asked to recv but we have no socket"); + self->socket, { return 0; }, "OpenSSL asked to recv but we have no socket"); BIO_clear_retry_flags(bio); - auto [actual, err] = self->socket_->Recv(0, {reinterpret_cast<u8*>(buf), len}); + auto [actual, err] = self->socket->Recv(0, {reinterpret_cast<u8*>(buf), len}); switch (err) { case Network::Errno::SUCCESS: *actual_p = actual; if (actual == 0) { - self->got_read_eof_ = true; + self->got_read_eof = true; } return actual ? 1 : 0; case Network::Errno::AGAIN: @@ -246,11 +246,11 @@ public: } } - SSL* ssl_ = nullptr; - BIO* bio_ = nullptr; - bool got_read_eof_ = false; + SSL* ssl = nullptr; + BIO* bio = nullptr; + bool got_read_eof = false; - std::shared_ptr<Network::SocketBase> socket_; + std::shared_ptr<Network::SocketBase> socket; }; ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { diff --git a/src/core/hle/service/ssl/ssl_backend_schannel.cpp b/src/core/hle/service/ssl/ssl_backend_schannel.cpp index 775d5cc07..a1d6a186e 100644 --- a/src/core/hle/service/ssl/ssl_backend_schannel.cpp +++ b/src/core/hle/service/ssl/ssl_backend_schannel.cpp @@ -48,6 +48,12 @@ static void OneTimeInit() { return; } + if (getenv("SSLKEYLOGFILE")) { + LOG_CRITICAL(Service_SSL, "SSLKEYLOGFILE was set but Schannel does not support exporting " + "keys; not logging keys!"); + // Not fatal. + } + one_time_init_success = true; } @@ -70,25 +76,25 @@ public: return ResultSuccess; } - void SetSocket(std::shared_ptr<Network::SocketBase> socket) override { - socket_ = socket; + void SetSocket(std::shared_ptr<Network::SocketBase> socket_in) override { + socket = std::move(socket_in); } - Result SetHostName(const std::string& hostname) override { - hostname_ = hostname; + Result SetHostName(const std::string& hostname_in) override { + hostname = hostname_in; return ResultSuccess; } Result DoHandshake() override { while (1) { Result r; - switch (handshake_state_) { + switch (handshake_state) { case HandshakeState::Initial: if ((r = FlushCiphertextWriteBuf()) != ResultSuccess || (r = CallInitializeSecurityContext()) != ResultSuccess) { return r; } - // CallInitializeSecurityContext updated `handshake_state_`. + // CallInitializeSecurityContext updated `handshake_state`. continue; case HandshakeState::ContinueNeeded: case HandshakeState::IncompleteMessage: @@ -96,20 +102,20 @@ public: (r = FillCiphertextReadBuf()) != ResultSuccess) { return r; } - if (ciphertext_read_buf_.empty()) { + if (ciphertext_read_buf.empty()) { LOG_ERROR(Service_SSL, "SSL handshake failed because server hung up"); return ResultInternalError; } if ((r = CallInitializeSecurityContext()) != ResultSuccess) { return r; } - // CallInitializeSecurityContext updated `handshake_state_`. + // CallInitializeSecurityContext updated `handshake_state`. continue; case HandshakeState::DoneAfterFlush: if ((r = FlushCiphertextWriteBuf()) != ResultSuccess) { return r; } - handshake_state_ = HandshakeState::Connected; + handshake_state = HandshakeState::Connected; return ResultSuccess; case HandshakeState::Connected: LOG_ERROR(Service_SSL, "Called DoHandshake but we already handshook"); @@ -121,24 +127,24 @@ public: } Result FillCiphertextReadBuf() { - const size_t fill_size = read_buf_fill_size_ ? read_buf_fill_size_ : 4096; - read_buf_fill_size_ = 0; + const size_t fill_size = read_buf_fill_size ? read_buf_fill_size : 4096; + read_buf_fill_size = 0; // This unnecessarily zeroes the buffer; oh well. - const size_t offset = ciphertext_read_buf_.size(); + const size_t offset = ciphertext_read_buf.size(); ASSERT_OR_EXECUTE(offset + fill_size >= offset, { return ResultInternalError; }); - ciphertext_read_buf_.resize(offset + fill_size, 0); - const auto read_span = std::span(ciphertext_read_buf_).subspan(offset, fill_size); - const auto [actual, err] = socket_->Recv(0, read_span); + ciphertext_read_buf.resize(offset + fill_size, 0); + const auto read_span = std::span(ciphertext_read_buf).subspan(offset, fill_size); + const auto [actual, err] = socket->Recv(0, read_span); switch (err) { case Network::Errno::SUCCESS: ASSERT(static_cast<size_t>(actual) <= fill_size); - ciphertext_read_buf_.resize(offset + actual); + ciphertext_read_buf.resize(offset + actual); return ResultSuccess; case Network::Errno::AGAIN: - ciphertext_read_buf_.resize(offset); + ciphertext_read_buf.resize(offset); return ResultWouldBlock; default: - ciphertext_read_buf_.resize(offset); + ciphertext_read_buf.resize(offset); LOG_ERROR(Service_SSL, "Socket recv returned Network::Errno {}", err); return ResultInternalError; } @@ -146,13 +152,13 @@ public: // Returns success if the write buffer has been completely emptied. Result FlushCiphertextWriteBuf() { - while (!ciphertext_write_buf_.empty()) { - const auto [actual, err] = socket_->Send(ciphertext_write_buf_, 0); + while (!ciphertext_write_buf.empty()) { + const auto [actual, err] = socket->Send(ciphertext_write_buf, 0); switch (err) { case Network::Errno::SUCCESS: - ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf_.size()); - ciphertext_write_buf_.erase(ciphertext_write_buf_.begin(), - ciphertext_write_buf_.begin() + actual); + ASSERT(static_cast<size_t>(actual) <= ciphertext_write_buf.size()); + ciphertext_write_buf.erase(ciphertext_write_buf.begin(), + ciphertext_write_buf.begin() + actual); break; case Network::Errno::AGAIN: return ResultWouldBlock; @@ -175,9 +181,9 @@ public: // only used if `initial_call_done` { // [0] - .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), .BufferType = SECBUFFER_TOKEN, - .pvBuffer = ciphertext_read_buf_.data(), + .pvBuffer = ciphertext_read_buf.data(), }, { // [1] (will be replaced by SECBUFFER_MISSING when SEC_E_INCOMPLETE_MESSAGE is @@ -211,30 +217,30 @@ public: .pBuffers = output_buffers.data(), }; ASSERT_OR_EXECUTE_MSG( - input_buffers[0].cbBuffer == ciphertext_read_buf_.size(), + input_buffers[0].cbBuffer == ciphertext_read_buf.size(), { return ResultInternalError; }, "read buffer too large"); - bool initial_call_done = handshake_state_ != HandshakeState::Initial; + bool initial_call_done = handshake_state != HandshakeState::Initial; if (initial_call_done) { LOG_DEBUG(Service_SSL, "Passing {} bytes into InitializeSecurityContext", - ciphertext_read_buf_.size()); + ciphertext_read_buf.size()); } const SECURITY_STATUS ret = - InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt_ : nullptr, + InitializeSecurityContextA(&cred_handle, initial_call_done ? &ctxt : nullptr, // Caller ensured we have set a hostname: - const_cast<char*>(hostname_.value().c_str()), req, + const_cast<char*>(hostname.value().c_str()), req, 0, // Reserved1 0, // TargetDataRep not used with Schannel initial_call_done ? &input_desc : nullptr, 0, // Reserved2 - initial_call_done ? nullptr : &ctxt_, &output_desc, &attr, + initial_call_done ? nullptr : &ctxt, &output_desc, &attr, nullptr); // ptsExpiry if (output_buffers[0].pvBuffer) { const std::span span(static_cast<u8*>(output_buffers[0].pvBuffer), output_buffers[0].cbBuffer); - ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), span.begin(), span.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), span.begin(), span.end()); FreeContextBuffer(output_buffers[0].pvBuffer); } @@ -251,64 +257,64 @@ public: LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_I_CONTINUE_NEEDED"); if (input_buffers[1].BufferType == SECBUFFER_EXTRA) { LOG_DEBUG(Service_SSL, "EXTRA of size {}", input_buffers[1].cbBuffer); - ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf_.size()); - ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), - ciphertext_read_buf_.end() - input_buffers[1].cbBuffer); + ASSERT(input_buffers[1].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - input_buffers[1].cbBuffer); } else { ASSERT(input_buffers[1].BufferType == SECBUFFER_EMPTY); - ciphertext_read_buf_.clear(); + ciphertext_read_buf.clear(); } - handshake_state_ = HandshakeState::ContinueNeeded; + handshake_state = HandshakeState::ContinueNeeded; return ResultSuccess; case SEC_E_INCOMPLETE_MESSAGE: LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_INCOMPLETE_MESSAGE"); ASSERT(input_buffers[1].BufferType == SECBUFFER_MISSING); - read_buf_fill_size_ = input_buffers[1].cbBuffer; - handshake_state_ = HandshakeState::IncompleteMessage; + read_buf_fill_size = input_buffers[1].cbBuffer; + handshake_state = HandshakeState::IncompleteMessage; return ResultSuccess; case SEC_E_OK: LOG_DEBUG(Service_SSL, "InitializeSecurityContext => SEC_E_OK"); - ciphertext_read_buf_.clear(); - handshake_state_ = HandshakeState::DoneAfterFlush; + ciphertext_read_buf.clear(); + handshake_state = HandshakeState::DoneAfterFlush; return GrabStreamSizes(); default: LOG_ERROR(Service_SSL, "InitializeSecurityContext failed (probably certificate/protocol issue): {}", Common::NativeErrorToString(ret)); - handshake_state_ = HandshakeState::Error; + handshake_state = HandshakeState::Error; return ResultInternalError; } } Result GrabStreamSizes() { const SECURITY_STATUS ret = - QueryContextAttributes(&ctxt_, SECPKG_ATTR_STREAM_SIZES, &stream_sizes_); + QueryContextAttributes(&ctxt, SECPKG_ATTR_STREAM_SIZES, &stream_sizes); if (ret != SEC_E_OK) { LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_STREAM_SIZES) failed: {}", Common::NativeErrorToString(ret)); - handshake_state_ = HandshakeState::Error; + handshake_state = HandshakeState::Error; return ResultInternalError; } return ResultSuccess; } ResultVal<size_t> Read(std::span<u8> data) override { - if (handshake_state_ != HandshakeState::Connected) { + if (handshake_state != HandshakeState::Connected) { LOG_ERROR(Service_SSL, "Called Read but we did not successfully handshake"); return ResultInternalError; } - if (data.size() == 0 || got_read_eof_) { + if (data.size() == 0 || got_read_eof) { return size_t(0); } while (1) { - if (!cleartext_read_buf_.empty()) { - const size_t read_size = std::min(cleartext_read_buf_.size(), data.size()); - std::memcpy(data.data(), cleartext_read_buf_.data(), read_size); - cleartext_read_buf_.erase(cleartext_read_buf_.begin(), - cleartext_read_buf_.begin() + read_size); + if (!cleartext_read_buf.empty()) { + const size_t read_size = std::min(cleartext_read_buf.size(), data.size()); + std::memcpy(data.data(), cleartext_read_buf.data(), read_size); + cleartext_read_buf.erase(cleartext_read_buf.begin(), + cleartext_read_buf.begin() + read_size); return read_size; } - if (!ciphertext_read_buf_.empty()) { + if (!ciphertext_read_buf.empty()) { SecBuffer empty{ .cbBuffer = 0, .BufferType = SECBUFFER_EMPTY, @@ -316,16 +322,16 @@ public: }; std::array<SecBuffer, 5> buffers{{ { - .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf_.size()), + .cbBuffer = static_cast<unsigned long>(ciphertext_read_buf.size()), .BufferType = SECBUFFER_DATA, - .pvBuffer = ciphertext_read_buf_.data(), + .pvBuffer = ciphertext_read_buf.data(), }, empty, empty, empty, }}; ASSERT_OR_EXECUTE_MSG( - buffers[0].cbBuffer == ciphertext_read_buf_.size(), + buffers[0].cbBuffer == ciphertext_read_buf.size(), { return ResultInternalError; }, "read buffer too large"); SecBufferDesc desc{ .ulVersion = SECBUFFER_VERSION, @@ -333,7 +339,7 @@ public: .pBuffers = buffers.data(), }; SECURITY_STATUS ret = - DecryptMessage(&ctxt_, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); + DecryptMessage(&ctxt, &desc, /*MessageSeqNo*/ 0, /*pfQOP*/ nullptr); switch (ret) { case SEC_E_OK: ASSERT_OR_EXECUTE(buffers[0].BufferType == SECBUFFER_STREAM_HEADER, @@ -342,24 +348,23 @@ public: { return ResultInternalError; }); ASSERT_OR_EXECUTE(buffers[2].BufferType == SECBUFFER_STREAM_TRAILER, { return ResultInternalError; }); - cleartext_read_buf_.assign(static_cast<u8*>(buffers[1].pvBuffer), - static_cast<u8*>(buffers[1].pvBuffer) + - buffers[1].cbBuffer); + cleartext_read_buf.assign(static_cast<u8*>(buffers[1].pvBuffer), + static_cast<u8*>(buffers[1].pvBuffer) + + buffers[1].cbBuffer); if (buffers[3].BufferType == SECBUFFER_EXTRA) { - ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf_.size()); - ciphertext_read_buf_.erase(ciphertext_read_buf_.begin(), - ciphertext_read_buf_.end() - - buffers[3].cbBuffer); + ASSERT(buffers[3].cbBuffer <= ciphertext_read_buf.size()); + ciphertext_read_buf.erase(ciphertext_read_buf.begin(), + ciphertext_read_buf.end() - buffers[3].cbBuffer); } else { ASSERT(buffers[3].BufferType == SECBUFFER_EMPTY); - ciphertext_read_buf_.clear(); + ciphertext_read_buf.clear(); } continue; case SEC_E_INCOMPLETE_MESSAGE: break; case SEC_I_CONTEXT_EXPIRED: // Server hung up by sending close_notify. - got_read_eof_ = true; + got_read_eof = true; return size_t(0); default: LOG_ERROR(Service_SSL, "DecryptMessage failed: {}", @@ -371,43 +376,43 @@ public: if (r != ResultSuccess) { return r; } - if (ciphertext_read_buf_.empty()) { - got_read_eof_ = true; + if (ciphertext_read_buf.empty()) { + got_read_eof = true; return size_t(0); } } } ResultVal<size_t> Write(std::span<const u8> data) override { - if (handshake_state_ != HandshakeState::Connected) { + if (handshake_state != HandshakeState::Connected) { LOG_ERROR(Service_SSL, "Called Write but we did not successfully handshake"); return ResultInternalError; } if (data.size() == 0) { return size_t(0); } - data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes_.cbMaximumMessage)); - if (!cleartext_write_buf_.empty()) { + data = data.subspan(0, std::min<size_t>(data.size(), stream_sizes.cbMaximumMessage)); + if (!cleartext_write_buf.empty()) { // Already in the middle of a write. It wouldn't make sense to not // finish sending the entire buffer since TLS has // header/MAC/padding/etc. - if (data.size() != cleartext_write_buf_.size() || - std::memcmp(data.data(), cleartext_write_buf_.data(), data.size())) { + if (data.size() != cleartext_write_buf.size() || + std::memcmp(data.data(), cleartext_write_buf.data(), data.size())) { LOG_ERROR(Service_SSL, "Called Write but buffer does not match previous buffer"); return ResultInternalError; } return WriteAlreadyEncryptedData(); } else { - cleartext_write_buf_.assign(data.begin(), data.end()); + cleartext_write_buf.assign(data.begin(), data.end()); } - std::vector<u8> header_buf(stream_sizes_.cbHeader, 0); - std::vector<u8> tmp_data_buf = cleartext_write_buf_; - std::vector<u8> trailer_buf(stream_sizes_.cbTrailer, 0); + std::vector<u8> header_buf(stream_sizes.cbHeader, 0); + std::vector<u8> tmp_data_buf = cleartext_write_buf; + std::vector<u8> trailer_buf(stream_sizes.cbTrailer, 0); std::array<SecBuffer, 3> buffers{{ { - .cbBuffer = stream_sizes_.cbHeader, + .cbBuffer = stream_sizes.cbHeader, .BufferType = SECBUFFER_STREAM_HEADER, .pvBuffer = header_buf.data(), }, @@ -417,7 +422,7 @@ public: .pvBuffer = tmp_data_buf.data(), }, { - .cbBuffer = stream_sizes_.cbTrailer, + .cbBuffer = stream_sizes.cbTrailer, .BufferType = SECBUFFER_STREAM_TRAILER, .pvBuffer = trailer_buf.data(), }, @@ -431,17 +436,17 @@ public: .pBuffers = buffers.data(), }; - const SECURITY_STATUS ret = EncryptMessage(&ctxt_, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); + const SECURITY_STATUS ret = EncryptMessage(&ctxt, /*fQOP*/ 0, &desc, /*MessageSeqNo*/ 0); if (ret != SEC_E_OK) { LOG_ERROR(Service_SSL, "EncryptMessage failed: {}", Common::NativeErrorToString(ret)); return ResultInternalError; } - ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), header_buf.begin(), - header_buf.end()); - ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), tmp_data_buf.begin(), - tmp_data_buf.end()); - ciphertext_write_buf_.insert(ciphertext_write_buf_.end(), trailer_buf.begin(), - trailer_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), header_buf.begin(), + header_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), tmp_data_buf.begin(), + tmp_data_buf.end()); + ciphertext_write_buf.insert(ciphertext_write_buf.end(), trailer_buf.begin(), + trailer_buf.end()); return WriteAlreadyEncryptedData(); } @@ -451,15 +456,15 @@ public: return r; } // write buf is empty - const size_t cleartext_bytes_written = cleartext_write_buf_.size(); - cleartext_write_buf_.clear(); + const size_t cleartext_bytes_written = cleartext_write_buf.size(); + cleartext_write_buf.clear(); return cleartext_bytes_written; } ResultVal<std::vector<std::vector<u8>>> GetServerCerts() override { PCCERT_CONTEXT returned_cert = nullptr; const SECURITY_STATUS ret = - QueryContextAttributes(&ctxt_, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); + QueryContextAttributes(&ctxt, SECPKG_ATTR_REMOTE_CERT_CONTEXT, &returned_cert); if (ret != SEC_E_OK) { LOG_ERROR(Service_SSL, "QueryContextAttributes(SECPKG_ATTR_REMOTE_CERT_CONTEXT) failed: {}", @@ -480,8 +485,8 @@ public: } ~SSLConnectionBackendSchannel() { - if (handshake_state_ != HandshakeState::Initial) { - DeleteSecurityContext(&ctxt_); + if (handshake_state != HandshakeState::Initial) { + DeleteSecurityContext(&ctxt); } } @@ -509,21 +514,21 @@ public: // Another error was returned and we shouldn't allow initialization // to continue. Error, - } handshake_state_ = HandshakeState::Initial; + } handshake_state = HandshakeState::Initial; - CtxtHandle ctxt_; - SecPkgContext_StreamSizes stream_sizes_; + CtxtHandle ctxt; + SecPkgContext_StreamSizes stream_sizes; - std::shared_ptr<Network::SocketBase> socket_; - std::optional<std::string> hostname_; + std::shared_ptr<Network::SocketBase> socket; + std::optional<std::string> hostname; - std::vector<u8> ciphertext_read_buf_; - std::vector<u8> ciphertext_write_buf_; - std::vector<u8> cleartext_read_buf_; - std::vector<u8> cleartext_write_buf_; + std::vector<u8> ciphertext_read_buf; + std::vector<u8> ciphertext_write_buf; + std::vector<u8> cleartext_read_buf; + std::vector<u8> cleartext_write_buf; - bool got_read_eof_ = false; - size_t read_buf_fill_size_ = 0; + bool got_read_eof = false; + size_t read_buf_fill_size = 0; }; ResultVal<std::unique_ptr<SSLConnectionBackend>> CreateSSLConnectionBackend() { |