Skip to content

Commit 1b2feed

Browse files
feat: enhance benchmark tests and random number generation utilities
1 parent f757dc7 commit 1b2feed

File tree

8 files changed

+576
-88
lines changed

8 files changed

+576
-88
lines changed

benchmark/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ list(APPEND BENCHMARK_FILES
66
types/types_benc.cpp
77
file/memory_mapped_file_benc.cpp
88
pcl/downsampling_benc.cpp
9+
${CMAKE_SOURCE_DIR}/test/my_catch2_main.cpp
910
)
1011

1112

benchmark/concurrent/parallel_benc.cpp

Lines changed: 123 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <algorithm> // For std::for_each, std::transform, std::sort
2-
#include <random>
32
#include <iomanip>
3+
#include <random>
44
#include <sstream>
55
// #include <execution> // For std::execution::par (requires C++17 and
66
// potentially TBB)
@@ -12,7 +12,10 @@
1212
#include <catch2/benchmark/catch_benchmark.hpp>
1313
#include <catch2/catch_test_macros.hpp>
1414

15-
#include "cpp-toolbox/utils/plot.hpp"
15+
#include "cpp-toolbox/logger/thread_logger.hpp"
16+
#include "cpp-toolbox/utils/print.hpp"
17+
#include "cpp-toolbox/utils/random.hpp"
18+
#include "cpp-toolbox/utils/timer.hpp"
1619

1720
// Include your parallel header
1821
#include <cpp-toolbox/concurrent/parallel.hpp> // Corrected include path
@@ -23,6 +26,11 @@
2326
long long serial_sum(const std::vector<int>& data)
2427
{
2528
return std::accumulate(data.begin(), data.end(), 0LL);
29+
// long long sum = 0LL;
30+
// for (const auto& val : data) {
31+
// sum += val;
32+
// }
33+
// return sum;
2634
}
2735

2836
// Parallel sum using std::execution::par (Requires C++17 and proper
@@ -78,9 +86,14 @@ void square_in_place_op(int& x)
7886
TEST_CASE("Benchmark Parallel Algorithms")
7987
{
8088
// Prepare large test data
81-
const size_t data_size = 5'000'000; // Five million elements to stress parallelism
82-
std::vector<int> data(data_size);
83-
std::iota(data.begin(), data.end(), 1); // Fill with 1, 2, 3, ...
89+
const size_t data_size =
90+
10'000'000; // Five million elements to stress parallelism
91+
const auto data =
92+
toolbox::utils::generate<std::vector<int> >(data_size, -100, 100);
93+
94+
const size_t sum_data_size = 100'000'000; // Hundred million elements
95+
const auto sum_data =
96+
toolbox::utils::generate<std::vector<int> >(sum_data_size, -100, 100);
8497

8598
std::vector<int> output_data(data_size); // For transform output
8699
std::vector<long long> scan_output(data_size); // For inclusive scan
@@ -175,7 +188,7 @@ TEST_CASE("Benchmark Parallel Algorithms")
175188

176189
BENCHMARK("Serial Sum (std::accumulate)")
177190
{
178-
return serial_sum(data);
191+
return serial_sum(sum_data);
179192
};
180193

181194
// Optional: Benchmark std::execution::par if configured
@@ -187,7 +200,7 @@ TEST_CASE("Benchmark Parallel Algorithms")
187200

188201
BENCHMARK("Parallel Sum (toolbox::parallel_reduce)")
189202
{
190-
return toolbox_parallel_sum(data);
203+
return toolbox_parallel_sum(sum_data);
191204
};
192205
}
193206

@@ -322,12 +335,8 @@ TEST_CASE("Benchmark Parallel Algorithms")
322335

323336
BENCHMARK("Parallel Inclusive Scan (toolbox::parallel_inclusive_scan)")
324337
{
325-
toolbox::concurrent::parallel_inclusive_scan(data.cbegin(),
326-
data.cend(),
327-
scan_out.begin(),
328-
0,
329-
std::plus<int>(),
330-
0);
338+
toolbox::concurrent::parallel_inclusive_scan(
339+
data.cbegin(), data.cend(), scan_out.begin(), 0, std::plus<int>(), 0);
331340
return scan_out.back();
332341
};
333342
}
@@ -360,51 +369,70 @@ TEST_CASE("Benchmark Parallel Algorithms")
360369
// --- Timing Table and Plot ---------------------------------------------
361370
SECTION("Timing Table")
362371
{
363-
using namespace std::chrono;
372+
// 使用 toolbox::utils::stop_watch_timer_t 进行更准确的计时
373+
// Use toolbox::utils::stop_watch_timer_t for more accurate timing
364374
auto measure = [&](auto&& func)
365375
{
366-
const int iters = 3;
367-
double total = 0.0;
376+
const int iters = 5; // 增加迭代次数以获得更稳定的结果
377+
double total_ms = 0.0;
368378
for (int i = 0; i < iters; ++i) {
369-
auto start = high_resolution_clock::now();
379+
toolbox::utils::stop_watch_timer_t timer;
380+
timer.start();
370381
func();
371-
auto end = high_resolution_clock::now();
372-
total += duration<double>(end - start).count();
382+
timer.stop();
383+
total_ms += timer.elapsed_time_ms(); // 直接获取毫秒值
373384
}
374-
return total / static_cast<double>(iters);
385+
return total_ms / static_cast<double>(iters);
375386
};
376387

377-
// Measure all algorithms
378-
double reduce_serial = measure([&]() { serial_sum(data); });
379-
double reduce_parallel = measure([&]() { toolbox_parallel_sum(data); });
380-
381-
double for_each_serial = measure([&]() {
382-
std::vector<int> tmp = data;
383-
std::for_each(tmp.begin(), tmp.end(), square_in_place_op);
384-
});
385-
double for_each_parallel = measure([&]() {
386-
std::vector<int> tmp = data;
387-
toolbox::concurrent::parallel_for_each(tmp.begin(), tmp.end(),
388-
square_in_place_op);
389-
});
390-
391-
double transform_serial = measure([&]() {
392-
std::transform(data.begin(), data.end(), output_data.begin(), square_op);
393-
});
394-
double transform_parallel = measure([&]() {
395-
toolbox::concurrent::parallel_transform(data.begin(), data.end(),
396-
output_data.begin(), square_op);
397-
});
388+
// Measure all algorithms - 确保使用相同大小的数据集
389+
// Make sure to use the same dataset size for fair comparison
390+
double reduce_serial = measure([&]() { return serial_sum(sum_data); });
391+
LOG_DEBUG_S << "reduce_serial: " << reduce_serial << "ms";
392+
double reduce_parallel =
393+
measure([&]() { return toolbox_parallel_sum(sum_data); });
394+
395+
double for_each_serial = measure(
396+
[&]()
397+
{
398+
std::vector<int> tmp = data;
399+
std::for_each(tmp.begin(), tmp.end(), square_in_place_op);
400+
});
401+
double for_each_parallel = measure(
402+
[&]()
403+
{
404+
std::vector<int> tmp = data;
405+
toolbox::concurrent::parallel_for_each(
406+
tmp.begin(), tmp.end(), square_in_place_op);
407+
});
408+
409+
double transform_serial = measure(
410+
[&]()
411+
{
412+
std::transform(
413+
data.begin(), data.end(), output_data.begin(), square_op);
414+
});
415+
double transform_parallel = measure(
416+
[&]()
417+
{
418+
toolbox::concurrent::parallel_transform(
419+
data.begin(), data.end(), output_data.begin(), square_op);
420+
});
398421

399422
std::vector<int> scan_tmp(data_size);
400-
double scan_serial = measure([&]() {
401-
std::inclusive_scan(data.begin(), data.end(), scan_tmp.begin());
402-
});
403-
double scan_parallel = measure([&]() {
404-
toolbox::concurrent::parallel_inclusive_scan(data.begin(), data.end(),
405-
scan_tmp.begin(), 0,
406-
std::plus<int>(), 0);
407-
});
423+
double scan_serial = measure(
424+
[&]()
425+
{ std::inclusive_scan(data.begin(), data.end(), scan_tmp.begin()); });
426+
double scan_parallel = measure(
427+
[&]()
428+
{
429+
toolbox::concurrent::parallel_inclusive_scan(data.begin(),
430+
data.end(),
431+
scan_tmp.begin(),
432+
0,
433+
std::plus<int>(),
434+
0);
435+
});
408436

409437
std::vector<int> sort_input(data_size);
410438
{
@@ -414,29 +442,54 @@ TEST_CASE("Benchmark Parallel Algorithms")
414442
v = dist(rng);
415443
}
416444
}
417-
double sort_serial = measure([&]() {
418-
auto tmp = sort_input;
419-
std::sort(tmp.begin(), tmp.end());
420-
});
421-
double sort_parallel = measure([&]() {
422-
auto tmp = sort_input;
423-
toolbox::concurrent::parallel_merge_sort(tmp.begin(), tmp.end());
424-
});
445+
double sort_serial = measure(
446+
[&]()
447+
{
448+
auto tmp = sort_input;
449+
std::sort(tmp.begin(), tmp.end());
450+
});
451+
double sort_parallel = measure(
452+
[&]()
453+
{
454+
auto tmp = sort_input;
455+
toolbox::concurrent::parallel_merge_sort(tmp.begin(), tmp.end());
456+
});
425457

426458
toolbox::utils::table_t table;
427-
table.set_headers({"Benchmark", "Serial (s)", "Parallel (s)", "Speedup"});
428-
auto add_row = [&](const std::string& name, double s, double p) {
429-
std::ostringstream ss;
430-
ss.setf(std::ios::fixed);
431-
ss << std::setprecision(6) << s;
432-
std::ostringstream sp;
433-
sp.setf(std::ios::fixed);
434-
sp << std::setprecision(6) << p;
435-
double speedup = s / p;
436-
std::ostringstream sd;
437-
sd.setf(std::ios::fixed);
438-
sd << std::setprecision(2) << speedup;
439-
table.add_row(name, ss.str(), sp.str(), sd.str());
459+
table.set_headers({"Benchmark", "Serial (ms)", "Parallel (ms)", "Speedup"});
460+
auto add_row =
461+
[&](const std::string& name, double serial_ms, double parallel_ms)
462+
{
463+
// 添加调试输出,显示实际的计时值
464+
// Add debug output to show actual timing values
465+
std::cout << "DEBUG - " << name << " - Serial: " << serial_ms
466+
<< " ms, Parallel: " << parallel_ms
467+
<< " ms, Speedup: " << (serial_ms / parallel_ms) << "\n";
468+
469+
// 已经是毫秒值,不需要转换
470+
// Already in milliseconds, no conversion needed
471+
std::ostringstream serial_str;
472+
serial_str.setf(std::ios::fixed);
473+
serial_str << std::setprecision(3) << serial_ms;
474+
475+
std::ostringstream parallel_str;
476+
parallel_str.setf(std::ios::fixed);
477+
parallel_str << std::setprecision(3) << parallel_ms;
478+
479+
// 计算加速比
480+
// Calculate speedup
481+
double speedup = 1.0;
482+
if (parallel_ms > 0.001) { // 避免除以非常小的值 / Avoid division by very
483+
// small values
484+
speedup = serial_ms / parallel_ms;
485+
}
486+
487+
std::ostringstream speedup_str;
488+
speedup_str.setf(std::ios::fixed);
489+
speedup_str << std::setprecision(2) << speedup;
490+
491+
table.add_row(
492+
name, serial_str.str(), parallel_str.str(), speedup_str.str());
440493
};
441494

442495
add_row("Reduce", reduce_serial, reduce_parallel);
@@ -447,17 +500,6 @@ TEST_CASE("Benchmark Parallel Algorithms")
447500

448501
std::cout << table << "\n";
449502

450-
toolbox::utils::plot_t plot;
451-
plot.set_title("Sum Benchmark (sec)");
452-
plot.set_x_axis();
453-
plot.set_y_axis();
454-
plot.enable_axis_grid();
455-
plot.add_scatter_series({1.0, 2.0},
456-
{reduce_serial, reduce_parallel},
457-
toolbox::utils::color_t::GREEN,
458-
toolbox::utils::plot_t::style_t::CROSS);
459-
std::cout << plot.render(40, 10) << "\n";
460-
461503
REQUIRE(reduce_serial > 0.0);
462504
REQUIRE(reduce_parallel > 0.0);
463505
}

0 commit comments

Comments
 (0)