11#include < algorithm> // For std::for_each, std::transform, std::sort
2- #include < random>
32#include < iomanip>
3+ #include < random>
44#include < sstream>
55// #include <execution> // For std::execution::par (requires C++17 and
66// potentially TBB)
1212#include < catch2/benchmark/catch_benchmark.hpp>
1313#include < catch2/catch_test_macros.hpp>
1414
15- #include " cpp-toolbox/utils/plot.hpp"
15+ #include " cpp-toolbox/logger/thread_logger.hpp"
16+ #include " cpp-toolbox/utils/print.hpp"
17+ #include " cpp-toolbox/utils/random.hpp"
18+ #include " cpp-toolbox/utils/timer.hpp"
1619
1720// Include your parallel header
1821#include < cpp-toolbox/concurrent/parallel.hpp> // Corrected include path
// Serial reference sum used as the baseline for the parallel reductions.
//
// Adds every element of `data` in order. The accumulator is seeded with
// `0LL`, so the additions are performed in `long long` rather than `int`.
//
// data:    values to add up; an empty vector yields 0.
// returns: the total of all elements.
long long serial_sum(const std::vector<int>& data)
{
  return std::accumulate(data.begin(), data.end(), 0LL);
}
2735
2836// Parallel sum using std::execution::par (Requires C++17 and proper
@@ -78,9 +86,14 @@ void square_in_place_op(int& x)
7886TEST_CASE (" Benchmark Parallel Algorithms" )
7987{
8088 // Prepare large test data
81- const size_t data_size = 5'000'000 ; // Five million elements to stress parallelism
82- std::vector<int > data (data_size);
83- std::iota (data.begin (), data.end (), 1 ); // Fill with 1, 2, 3, ...
89+ const size_t data_size =
90+ 10'000'000 ; // Ten million elements to stress parallelism
91+ const auto data =
92+ toolbox::utils::generate<std::vector<int > >(data_size, -100 , 100 );
93+
94+ const size_t sum_data_size = 100'000'000 ; // Hundred million elements
95+ const auto sum_data =
96+ toolbox::utils::generate<std::vector<int > >(sum_data_size, -100 , 100 );
8497
8598 std::vector<int > output_data (data_size); // For transform output
8699 std::vector<long long > scan_output (data_size); // For inclusive scan
@@ -175,7 +188,7 @@ TEST_CASE("Benchmark Parallel Algorithms")
175188
176189 BENCHMARK (" Serial Sum (std::accumulate)" )
177190 {
178- return serial_sum (data );
191+ return serial_sum (sum_data );
179192 };
180193
181194 // Optional: Benchmark std::execution::par if configured
@@ -187,7 +200,7 @@ TEST_CASE("Benchmark Parallel Algorithms")
187200
188201 BENCHMARK (" Parallel Sum (toolbox::parallel_reduce)" )
189202 {
190- return toolbox_parallel_sum (data );
203+ return toolbox_parallel_sum (sum_data );
191204 };
192205 }
193206
@@ -322,12 +335,8 @@ TEST_CASE("Benchmark Parallel Algorithms")
322335
323336 BENCHMARK (" Parallel Inclusive Scan (toolbox::parallel_inclusive_scan)" )
324337 {
325- toolbox::concurrent::parallel_inclusive_scan (data.cbegin (),
326- data.cend (),
327- scan_out.begin (),
328- 0 ,
329- std::plus<int >(),
330- 0 );
338+ toolbox::concurrent::parallel_inclusive_scan (
339+ data.cbegin (), data.cend (), scan_out.begin (), 0 , std::plus<int >(), 0 );
331340 return scan_out.back ();
332341 };
333342 }
@@ -360,51 +369,70 @@ TEST_CASE("Benchmark Parallel Algorithms")
360369 // --- Timing Table and Plot ---------------------------------------------
361370 SECTION (" Timing Table" )
362371 {
363- using namespace std ::chrono;
372+ // 使用 toolbox::utils::stop_watch_timer_t 进行更准确的计时
373+ // Use toolbox::utils::stop_watch_timer_t for more accurate timing
364374 auto measure = [&](auto && func)
365375 {
366- const int iters = 3 ;
367- double total = 0.0 ;
376+ const int iters = 5 ; // 增加迭代次数以获得更稳定的结果
377+ double total_ms = 0.0 ;
368378 for (int i = 0 ; i < iters; ++i) {
369- auto start = high_resolution_clock::now ();
379+ toolbox::utils::stop_watch_timer_t timer;
380+ timer.start ();
370381 func ();
371- auto end = high_resolution_clock::now ();
372- total += duration< double >(end - start). count ();
382+ timer. stop ();
383+ total_ms += timer. elapsed_time_ms (); // 直接获取毫秒值
373384 }
374- return total / static_cast <double >(iters);
385+ return total_ms / static_cast <double >(iters);
375386 };
376387
377- // Measure all algorithms
378- double reduce_serial = measure ([&]() { serial_sum (data); });
379- double reduce_parallel = measure ([&]() { toolbox_parallel_sum (data); });
380-
381- double for_each_serial = measure ([&]() {
382- std::vector<int > tmp = data;
383- std::for_each (tmp.begin (), tmp.end (), square_in_place_op);
384- });
385- double for_each_parallel = measure ([&]() {
386- std::vector<int > tmp = data;
387- toolbox::concurrent::parallel_for_each (tmp.begin (), tmp.end (),
388- square_in_place_op);
389- });
390-
391- double transform_serial = measure ([&]() {
392- std::transform (data.begin (), data.end (), output_data.begin (), square_op);
393- });
394- double transform_parallel = measure ([&]() {
395- toolbox::concurrent::parallel_transform (data.begin (), data.end (),
396- output_data.begin (), square_op);
397- });
388+ // Measure all algorithms - 确保使用相同大小的数据集
389+ // Make sure to use the same dataset size for fair comparison
390+ double reduce_serial = measure ([&]() { return serial_sum (sum_data); });
391+ LOG_DEBUG_S << " reduce_serial: " << reduce_serial << " ms" ;
392+ double reduce_parallel =
393+ measure ([&]() { return toolbox_parallel_sum (sum_data); });
394+
395+ double for_each_serial = measure (
396+ [&]()
397+ {
398+ std::vector<int > tmp = data;
399+ std::for_each (tmp.begin (), tmp.end (), square_in_place_op);
400+ });
401+ double for_each_parallel = measure (
402+ [&]()
403+ {
404+ std::vector<int > tmp = data;
405+ toolbox::concurrent::parallel_for_each (
406+ tmp.begin (), tmp.end (), square_in_place_op);
407+ });
408+
409+ double transform_serial = measure (
410+ [&]()
411+ {
412+ std::transform (
413+ data.begin (), data.end (), output_data.begin (), square_op);
414+ });
415+ double transform_parallel = measure (
416+ [&]()
417+ {
418+ toolbox::concurrent::parallel_transform (
419+ data.begin (), data.end (), output_data.begin (), square_op);
420+ });
398421
399422 std::vector<int > scan_tmp (data_size);
400- double scan_serial = measure ([&]() {
401- std::inclusive_scan (data.begin (), data.end (), scan_tmp.begin ());
402- });
403- double scan_parallel = measure ([&]() {
404- toolbox::concurrent::parallel_inclusive_scan (data.begin (), data.end (),
405- scan_tmp.begin (), 0 ,
406- std::plus<int >(), 0 );
407- });
423+ double scan_serial = measure (
424+ [&]()
425+ { std::inclusive_scan (data.begin (), data.end (), scan_tmp.begin ()); });
426+ double scan_parallel = measure (
427+ [&]()
428+ {
429+ toolbox::concurrent::parallel_inclusive_scan (data.begin (),
430+ data.end (),
431+ scan_tmp.begin (),
432+ 0 ,
433+ std::plus<int >(),
434+ 0 );
435+ });
408436
409437 std::vector<int > sort_input (data_size);
410438 {
@@ -414,29 +442,54 @@ TEST_CASE("Benchmark Parallel Algorithms")
414442 v = dist (rng);
415443 }
416444 }
417- double sort_serial = measure ([&]() {
418- auto tmp = sort_input;
419- std::sort (tmp.begin (), tmp.end ());
420- });
421- double sort_parallel = measure ([&]() {
422- auto tmp = sort_input;
423- toolbox::concurrent::parallel_merge_sort (tmp.begin (), tmp.end ());
424- });
445+ double sort_serial = measure (
446+ [&]()
447+ {
448+ auto tmp = sort_input;
449+ std::sort (tmp.begin (), tmp.end ());
450+ });
451+ double sort_parallel = measure (
452+ [&]()
453+ {
454+ auto tmp = sort_input;
455+ toolbox::concurrent::parallel_merge_sort (tmp.begin (), tmp.end ());
456+ });
425457
426458 toolbox::utils::table_t table;
427- table.set_headers ({" Benchmark" , " Serial (s)" , " Parallel (s)" , " Speedup" });
428- auto add_row = [&](const std::string& name, double s, double p) {
429- std::ostringstream ss;
430- ss.setf (std::ios::fixed);
431- ss << std::setprecision (6 ) << s;
432- std::ostringstream sp;
433- sp.setf (std::ios::fixed);
434- sp << std::setprecision (6 ) << p;
435- double speedup = s / p;
436- std::ostringstream sd;
437- sd.setf (std::ios::fixed);
438- sd << std::setprecision (2 ) << speedup;
439- table.add_row (name, ss.str (), sp.str (), sd.str ());
459+ table.set_headers ({" Benchmark" , " Serial (ms)" , " Parallel (ms)" , " Speedup" });
460+ auto add_row =
461+ [&](const std::string& name, double serial_ms, double parallel_ms)
462+ {
463+ // 添加调试输出,显示实际的计时值
464+ // Add debug output to show actual timing values
465+ std::cout << " DEBUG - " << name << " - Serial: " << serial_ms
466+ << " ms, Parallel: " << parallel_ms
467+ << " ms, Speedup: " << (serial_ms / parallel_ms) << " \n " ;
468+
469+ // 已经是毫秒值,不需要转换
470+ // Already in milliseconds, no conversion needed
471+ std::ostringstream serial_str;
472+ serial_str.setf (std::ios::fixed);
473+ serial_str << std::setprecision (3 ) << serial_ms;
474+
475+ std::ostringstream parallel_str;
476+ parallel_str.setf (std::ios::fixed);
477+ parallel_str << std::setprecision (3 ) << parallel_ms;
478+
479+ // 计算加速比
480+ // Calculate speedup
481+ double speedup = 1.0 ;
482+ if (parallel_ms > 0.001 ) { // 避免除以非常小的值 / Avoid division by very
483+ // small values
484+ speedup = serial_ms / parallel_ms;
485+ }
486+
487+ std::ostringstream speedup_str;
488+ speedup_str.setf (std::ios::fixed);
489+ speedup_str << std::setprecision (2 ) << speedup;
490+
491+ table.add_row (
492+ name, serial_str.str (), parallel_str.str (), speedup_str.str ());
440493 };
441494
442495 add_row (" Reduce" , reduce_serial, reduce_parallel);
@@ -447,17 +500,6 @@ TEST_CASE("Benchmark Parallel Algorithms")
447500
448501 std::cout << table << " \n " ;
449502
450- toolbox::utils::plot_t plot;
451- plot.set_title (" Sum Benchmark (sec)" );
452- plot.set_x_axis ();
453- plot.set_y_axis ();
454- plot.enable_axis_grid ();
455- plot.add_scatter_series ({1.0 , 2.0 },
456- {reduce_serial, reduce_parallel},
457- toolbox::utils::color_t ::GREEN,
458- toolbox::utils::plot_t ::style_t ::CROSS);
459- std::cout << plot.render (40 , 10 ) << " \n " ;
460-
461503 REQUIRE (reduce_serial > 0.0 );
462504 REQUIRE (reduce_parallel > 0.0 );
463505 }
0 commit comments