From ed0319cfed2c99e6366aaf725d96bb28a9332e4d Mon Sep 17 00:00:00 2001 From: Liam Date: Sat, 2 Jul 2022 12:33:49 -0400 Subject: common/fiber: make fibers easier to use --- src/tests/common/fibers.cpp | 123 +++++++++++++------------------------------- 1 file changed, 35 insertions(+), 88 deletions(-) (limited to 'src/tests/common') diff --git a/src/tests/common/fibers.cpp b/src/tests/common/fibers.cpp index cfc84d423..4e29f9199 100644 --- a/src/tests/common/fibers.cpp +++ b/src/tests/common/fibers.cpp @@ -43,7 +43,15 @@ class TestControl1 { public: TestControl1() = default; - void DoWork(); + void DoWork() { + const u32 id = thread_ids.Get(); + u32 value = items[id]; + for (u32 i = 0; i < id; i++) { + value++; + } + results[id] = value; + Fiber::YieldTo(work_fibers[id], *thread_fibers[id]); + } void ExecuteThread(u32 id); @@ -54,35 +62,16 @@ public: std::vector results; }; -static void WorkControl1(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork(); -} - -void TestControl1::DoWork() { - const u32 id = thread_ids.Get(); - u32 value = items[id]; - for (u32 i = 0; i < id; i++) { - value++; - } - results[id] = value; - Fiber::YieldTo(work_fibers[id], *thread_fibers[id]); -} - void TestControl1::ExecuteThread(u32 id) { thread_ids.Register(id); auto thread_fiber = Fiber::ThreadToFiber(); thread_fibers[id] = thread_fiber; - work_fibers[id] = std::make_shared(std::function{WorkControl1}, this); + work_fibers[id] = std::make_shared([this] { DoWork(); }); items[id] = rand() % 256; Fiber::YieldTo(thread_fibers[id], *work_fibers[id]); thread_fibers[id]->Exit(); } -static void ThreadStart1(u32 id, TestControl1& test_control) { - test_control.ExecuteThread(id); -} - /** This test checks for fiber setup configuration and validates that fibers are * doing all the work required. */ @@ -95,7 +84,7 @@ TEST_CASE("Fibers::Setup", "[common]") { test_control.results.resize(num_threads, 0); std::vector threads; for (u32 i = 0; i < num_threads; i++) { - threads.emplace_back(ThreadStart1, i, std::ref(test_control)); + threads.emplace_back([&test_control, i] { test_control.ExecuteThread(i); }); } for (u32 i = 0; i < num_threads; i++) { threads[i].join(); @@ -167,21 +156,6 @@ public: std::shared_ptr fiber3; }; -static void WorkControl2_1(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork1(); -} - -static void WorkControl2_2(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork2(); -} - -static void WorkControl2_3(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork3(); -} - void TestControl2::ExecuteThread(u32 id) { thread_ids.Register(id); auto thread_fiber = Fiber::ThreadToFiber(); @@ -193,18 +167,6 @@ void TestControl2::Exit() { thread_fibers[id]->Exit(); } -static void ThreadStart2_1(u32 id, TestControl2& test_control) { - test_control.ExecuteThread(id); - test_control.CallFiber1(); - test_control.Exit(); -} - -static void ThreadStart2_2(u32 id, TestControl2& test_control) { - test_control.ExecuteThread(id); - test_control.CallFiber2(); - test_control.Exit(); -} - /** This test checks for fiber thread exchange configuration and validates that fibers are * that a fiber has been successfully transferred from one thread to another and that the TLS * region of the thread is kept while changing fibers. @@ -212,14 +174,19 @@ static void ThreadStart2_2(u32 id, TestControl2& test_control) { TEST_CASE("Fibers::InterExchange", "[common]") { TestControl2 test_control{}; test_control.thread_fibers.resize(2); - test_control.fiber1 = - std::make_shared(std::function{WorkControl2_1}, &test_control); - test_control.fiber2 = - std::make_shared(std::function{WorkControl2_2}, &test_control); - test_control.fiber3 = - std::make_shared(std::function{WorkControl2_3}, &test_control); - std::thread thread1(ThreadStart2_1, 0, std::ref(test_control)); - std::thread thread2(ThreadStart2_2, 1, std::ref(test_control)); + test_control.fiber1 = std::make_shared([&test_control] { test_control.DoWork1(); }); + test_control.fiber2 = std::make_shared([&test_control] { test_control.DoWork2(); }); + test_control.fiber3 = std::make_shared([&test_control] { test_control.DoWork3(); }); + std::thread thread1{[&test_control] { + test_control.ExecuteThread(0); + test_control.CallFiber1(); + test_control.Exit(); + }}; + std::thread thread2{[&test_control] { + test_control.ExecuteThread(1); + test_control.CallFiber2(); + test_control.Exit(); + }}; thread1.join(); thread2.join(); REQUIRE(test_control.assert1); @@ -270,16 +237,6 @@ public: std::shared_ptr fiber2; }; -static void WorkControl3_1(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork1(); -} - -static void WorkControl3_2(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork2(); -} - void TestControl3::ExecuteThread(u32 id) { thread_ids.Register(id); auto thread_fiber = Fiber::ThreadToFiber(); @@ -291,12 +248,6 @@ void TestControl3::Exit() { thread_fibers[id]->Exit(); } -static void ThreadStart3(u32 id, TestControl3& test_control) { - test_control.ExecuteThread(id); - test_control.CallFiber1(); - test_control.Exit(); -} - /** This test checks for one two threads racing for starting the same fiber. * It checks execution occurred in an ordered manner and by no time there were * two contexts at the same time. @@ -304,12 +255,15 @@ static void ThreadStart3(u32 id, TestControl3& test_control) { TEST_CASE("Fibers::StartRace", "[common]") { TestControl3 test_control{}; test_control.thread_fibers.resize(2); - test_control.fiber1 = - std::make_shared(std::function{WorkControl3_1}, &test_control); - test_control.fiber2 = - std::make_shared(std::function{WorkControl3_2}, &test_control); - std::thread thread1(ThreadStart3, 0, std::ref(test_control)); - std::thread thread2(ThreadStart3, 1, std::ref(test_control)); + test_control.fiber1 = std::make_shared([&test_control] { test_control.DoWork1(); }); + test_control.fiber2 = std::make_shared([&test_control] { test_control.DoWork2(); }); + const auto race_function{[&test_control](u32 id) { + test_control.ExecuteThread(id); + test_control.CallFiber1(); + test_control.Exit(); + }}; + std::thread thread1([&] { race_function(0); }); + std::thread thread2([&] { race_function(1); }); thread1.join(); thread2.join(); REQUIRE(test_control.value1 == 1); @@ -319,12 +273,10 @@ TEST_CASE("Fibers::StartRace", "[common]") { class TestControl4; -static void WorkControl4(void* control); - class TestControl4 { public: TestControl4() { - fiber1 = std::make_shared(std::function{WorkControl4}, this); + fiber1 = std::make_shared([this] { DoWork(); }); goal_reached = false; rewinded = false; } @@ -336,7 +288,7 @@ public: } void DoWork() { - fiber1->SetRewindPoint(std::function{WorkControl4}, this); + fiber1->SetRewindPoint([this] { DoWork(); }); if (rewinded) { goal_reached = true; Fiber::YieldTo(fiber1, *thread_fiber); @@ -351,11 +303,6 @@ public: bool rewinded; }; -static void WorkControl4(void* control) { - auto* test_control = static_cast(control); - test_control->DoWork(); -} - TEST_CASE("Fibers::Rewind", "[common]") { TestControl4 test_control{}; test_control.Execute(); -- cgit v1.2.3