Skip to content

Commit e8b2483

Browse files
feat: refactor KNN algorithms to use metrics infrastructure
- Create generic KNN base class supporting arbitrary element and metric types - Add knn_traits template for type abstraction - Update bfknn, bfknn_parallel, and kdtree implementations - Add point_t operator() overloads to vector metrics - Update tests and benchmarks to use new metric system - Maintain backward compatibility with legacy metric_type_t enum Note: This is a work in progress with some compilation issues remaining.
1 parent 9e2999b commit e8b2483

File tree

13 files changed

+1464
-773
lines changed

13 files changed

+1464
-773
lines changed

benchmark/pcl/knn_benc.cpp

Lines changed: 234 additions & 145 deletions
Large diffs are not rendered by default.

src/include/cpp-toolbox/metrics/angular_metrics.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <numeric>
66

77
#include <cpp-toolbox/metrics/base_metric.hpp>
8+
#include <cpp-toolbox/types/point.hpp>
89

910
namespace toolbox::metrics
1011
{
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#pragma once
2+
3+
#include <cpp-toolbox/metrics/metric_factory.hpp>
4+
#include <cpp-toolbox/types/point.hpp>
5+
6+
namespace toolbox::metrics
7+
{
8+
9+
// Adapter to use IMetric with point types
10+
template<typename T>
11+
class PointMetricAdapter : public IMetric<T>
12+
{
13+
public:
14+
explicit PointMetricAdapter(std::shared_ptr<IMetric<T>> metric)
15+
: m_metric(metric) {}
16+
17+
T distance(const T* a, const T* b, std::size_t size) const override
18+
{
19+
return m_metric->distance(a, b, size);
20+
}
21+
22+
T squared_distance(const T* a, const T* b, std::size_t size) const override
23+
{
24+
return m_metric->squared_distance(a, b, size);
25+
}
26+
27+
// Adapter method for point types
28+
T operator()(const toolbox::types::point_t<T>& a,
29+
const toolbox::types::point_t<T>& b) const
30+
{
31+
T arr_a[3] = {a.x, a.y, a.z};
32+
T arr_b[3] = {b.x, b.y, b.z};
33+
return distance(arr_a, arr_b, 3);
34+
}
35+
36+
private:
37+
std::shared_ptr<IMetric<T>> m_metric;
38+
};
39+
40+
} // namespace toolbox::metrics

src/include/cpp-toolbox/metrics/vector_metrics.hpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <stdexcept>
99

1010
#include <cpp-toolbox/metrics/base_metric.hpp>
11+
#include <cpp-toolbox/types/point.hpp>
1112

1213
namespace toolbox::metrics
1314
{
@@ -34,6 +35,17 @@ class L2Metric : public base_metric_t<L2Metric<T>, T>
3435
}
3536
return sum;
3637
}
38+
39+
// Overload for point_t types
40+
template<typename U>
41+
constexpr auto operator()(const toolbox::types::point_t<U>& a,
42+
const toolbox::types::point_t<U>& b) const
43+
{
44+
T dx = a.x - b.x;
45+
T dy = a.y - b.y;
46+
T dz = a.z - b.z;
47+
return std::sqrt(dx * dx + dy * dy + dz * dz);
48+
}
3749
};
3850

3951
template<typename T>
@@ -58,6 +70,16 @@ class L1Metric : public base_metric_t<L1Metric<T>, T>
5870
T dist = distance_impl(a, b, size);
5971
return dist * dist;
6072
}
73+
74+
// Overload for point_t types
75+
template<typename U>
76+
constexpr auto operator()(const toolbox::types::point_t<U>& a,
77+
const toolbox::types::point_t<U>& b) const
78+
{
79+
return std::abs(static_cast<T>(a.x - b.x)) +
80+
std::abs(static_cast<T>(a.y - b.y)) +
81+
std::abs(static_cast<T>(a.z - b.z));
82+
}
6183
};
6284

6385
template<typename T>
@@ -85,6 +107,16 @@ class LinfMetric : public base_metric_t<LinfMetric<T>, T>
85107
T dist = distance_impl(a, b, size);
86108
return dist * dist;
87109
}
110+
111+
// Overload for point_t types
112+
template<typename U>
113+
constexpr auto operator()(const toolbox::types::point_t<U>& a,
114+
const toolbox::types::point_t<U>& b) const
115+
{
116+
return std::max({std::abs(static_cast<T>(a.x - b.x)),
117+
std::abs(static_cast<T>(a.y - b.y)),
118+
std::abs(static_cast<T>(a.z - b.z))});
119+
}
88120
};
89121

90122
template<typename T, int P = 2>
@@ -206,6 +238,35 @@ class GeneralizedLpMetric : public base_metric_t<GeneralizedLpMetric<T>, T>
206238
return dist * dist;
207239
}
208240
}
241+
242+
// Overload for point_t types
243+
template<typename U>
244+
constexpr auto operator()(const toolbox::types::point_t<U>& a,
245+
const toolbox::types::point_t<U>& b) const
246+
{
247+
if (std::abs(p_ - T(1)) < std::numeric_limits<T>::epsilon()) {
248+
return std::abs(static_cast<T>(a.x - b.x)) +
249+
std::abs(static_cast<T>(a.y - b.y)) +
250+
std::abs(static_cast<T>(a.z - b.z));
251+
}
252+
else if (std::abs(p_ - T(2)) < std::numeric_limits<T>::epsilon()) {
253+
T dx = a.x - b.x;
254+
T dy = a.y - b.y;
255+
T dz = a.z - b.z;
256+
return std::sqrt(dx * dx + dy * dy + dz * dz);
257+
}
258+
else if (std::isinf(p_)) {
259+
return std::max({std::abs(static_cast<T>(a.x - b.x)),
260+
std::abs(static_cast<T>(a.y - b.y)),
261+
std::abs(static_cast<T>(a.z - b.z))});
262+
}
263+
else {
264+
T sum = std::pow(std::abs(static_cast<T>(a.x - b.x)), p_) +
265+
std::pow(std::abs(static_cast<T>(a.y - b.y)), p_) +
266+
std::pow(std::abs(static_cast<T>(a.z - b.z)), p_);
267+
return std::pow(sum, T(1) / p_);
268+
}
269+
}
209270

210271
T get_p() const { return p_; }
211272

src/include/cpp-toolbox/pcl/knn/base_knn.hpp

Lines changed: 88 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,28 @@
11
#pragma once
22

33
#include <memory>
4+
#include <type_traits>
45

56
#include <cpp-toolbox/cpp-toolbox_export.hpp>
67
#include <cpp-toolbox/types/point.hpp>
8+
#include <cpp-toolbox/metrics/vector_metrics.hpp>
9+
#include <cpp-toolbox/metrics/metric_factory.hpp>
710

811
namespace toolbox::pcl
912
{
1013

1114
using toolbox::types::point_t;
1215

16+
// KNN traits for generic element and metric types
17+
template<typename Element, typename Metric>
18+
struct knn_traits
19+
{
20+
using element_type = Element;
21+
using metric_type = Metric;
22+
using distance_type = typename Metric::result_type;
23+
};
24+
25+
// Legacy metric type enum for backward compatibility
1326
enum class CPP_TOOLBOX_EXPORT metric_type_t : std::uint8_t
1427
{
1528
euclidean,
@@ -18,59 +31,107 @@ enum class CPP_TOOLBOX_EXPORT metric_type_t : std::uint8_t
1831
minkowski,
1932
};
2033

21-
template<typename Derived, typename DataType>
22-
class CPP_TOOLBOX_EXPORT base_knn_t
34+
// Primary template: generic element and metric support
35+
template<typename Derived, typename Element, typename Metric = toolbox::metrics::L2Metric<typename Element::value_type>>
36+
class CPP_TOOLBOX_EXPORT base_knn_generic_t
2337
{
2438
public:
25-
using data_type = DataType;
26-
using point_cloud = toolbox::types::point_cloud_t<data_type>;
27-
using point_cloud_ptr =
28-
std::shared_ptr<toolbox::types::point_cloud_t<data_type>>;
29-
30-
std::size_t set_input(const point_cloud& cloud)
39+
using traits_type = knn_traits<Element, Metric>;
40+
using element_type = typename traits_type::element_type;
41+
using metric_type = typename traits_type::metric_type;
42+
using distance_type = typename traits_type::distance_type;
43+
using container_type = std::vector<element_type>;
44+
using container_ptr = std::shared_ptr<container_type>;
45+
46+
// Set input data
47+
std::size_t set_input(const container_type& data)
3148
{
32-
return static_cast<Derived*>(this)->set_input_impl(cloud);
49+
return static_cast<Derived*>(this)->set_input_impl(data);
3350
}
3451

35-
std::size_t set_input(const point_cloud_ptr& cloud)
52+
std::size_t set_input(const container_ptr& data)
3653
{
37-
return static_cast<Derived*>(this)->set_input_impl(cloud);
54+
return static_cast<Derived*>(this)->set_input_impl(data);
3855
}
3956

40-
std::size_t set_metric(metric_type_t metric)
57+
// Set metric (compile-time version)
58+
void set_metric(const metric_type& metric)
4159
{
42-
return static_cast<Derived*>(this)->set_metric_impl(metric);
60+
static_cast<Derived*>(this)->set_metric_impl(metric);
4361
}
4462

45-
[[nodiscard]] metric_type_t get_metric() const noexcept { return m_metric; }
63+
// Set metric (runtime version)
64+
template<typename T = typename Element::value_type>
65+
void set_metric(std::shared_ptr<toolbox::metrics::IMetric<T>> metric)
66+
{
67+
static_cast<Derived*>(this)->set_metric_impl(metric);
68+
}
4669

47-
bool kneighbors(const point_t<data_type>& query_point,
70+
// K-nearest neighbors search
71+
bool kneighbors(const element_type& query,
4872
std::size_t num_neighbors,
4973
std::vector<std::size_t>& indices,
50-
std::vector<data_type>& distances)
74+
std::vector<distance_type>& distances)
5175
{
5276
return static_cast<Derived*>(this)->kneighbors_impl(
53-
query_point, num_neighbors, indices, distances);
77+
query, num_neighbors, indices, distances);
5478
}
5579

56-
bool radius_neighbors(const point_t<data_type>& query_point,
57-
data_type radius,
80+
// Radius neighbors search
81+
bool radius_neighbors(const element_type& query,
82+
distance_type radius,
5883
std::vector<std::size_t>& indices,
59-
std::vector<data_type>& distances)
84+
std::vector<distance_type>& distances)
6085
{
6186
return static_cast<Derived*>(this)->radius_neighbors_impl(
62-
query_point, radius, indices, distances);
87+
query, radius, indices, distances);
6388
}
6489

6590
protected:
66-
base_knn_t() = default;
67-
~base_knn_t() = default;
91+
base_knn_generic_t() = default;
92+
~base_knn_generic_t() = default;
6893

6994
public:
70-
base_knn_t(const base_knn_t&) = delete;
71-
base_knn_t& operator=(const base_knn_t&) = delete;
72-
base_knn_t(base_knn_t&&) = delete;
73-
base_knn_t& operator=(base_knn_t&&) = delete;
95+
base_knn_generic_t(const base_knn_generic_t&) = delete;
96+
base_knn_generic_t& operator=(const base_knn_generic_t&) = delete;
97+
base_knn_generic_t(base_knn_generic_t&&) = delete;
98+
base_knn_generic_t& operator=(base_knn_generic_t&&) = delete;
99+
}; // class base_knn_generic_t
100+
101+
// Legacy base_knn_t for backward compatibility
102+
template<typename Derived, typename DataType>
103+
class CPP_TOOLBOX_EXPORT base_knn_t : public base_knn_generic_t<Derived, point_t<DataType>, toolbox::metrics::L2Metric<DataType>>
104+
{
105+
public:
106+
using base_type = base_knn_generic_t<Derived, point_t<DataType>, toolbox::metrics::L2Metric<DataType>>;
107+
using data_type = DataType;
108+
using point_cloud = toolbox::types::point_cloud_t<data_type>;
109+
using point_cloud_ptr = std::shared_ptr<toolbox::types::point_cloud_t<data_type>>;
110+
111+
// Legacy interface
112+
std::size_t set_input(const point_cloud& cloud)
113+
{
114+
// Convert point cloud to vector of points
115+
std::vector<point_t<DataType>> points(cloud.points.begin(), cloud.points.end());
116+
return base_type::set_input(points);
117+
}
118+
119+
std::size_t set_input(const point_cloud_ptr& cloud)
120+
{
121+
if (!cloud) return 0;
122+
return set_input(*cloud);
123+
}
124+
125+
std::size_t set_metric(metric_type_t metric)
126+
{
127+
return static_cast<Derived*>(this)->set_metric_impl(metric);
128+
}
129+
130+
[[nodiscard]] metric_type_t get_metric() const noexcept { return m_metric; }
131+
132+
protected:
133+
base_knn_t() = default;
134+
~base_knn_t() = default;
74135

75136
private:
76137
metric_type_t m_metric = metric_type_t::euclidean;

src/include/cpp-toolbox/pcl/knn/bfknn.hpp

Lines changed: 55 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,61 @@
11
#pragma once
22

33
#include <cpp-toolbox/pcl/knn/base_knn.hpp>
4+
#include <cpp-toolbox/metrics/metric_factory.hpp>
5+
#include <cpp-toolbox/metrics/vector_metrics.hpp>
46

57
namespace toolbox::pcl
68
{
79

10+
// Generic brute-force KNN implementation
11+
template<typename Element, typename Metric = toolbox::metrics::L2Metric<typename Element::value_type>>
12+
class CPP_TOOLBOX_EXPORT bfknn_generic_t : public base_knn_generic_t<bfknn_generic_t<Element, Metric>, Element, Metric>
13+
{
14+
public:
15+
using base_type = base_knn_generic_t<bfknn_generic_t<Element, Metric>, Element, Metric>;
16+
using traits_type = typename base_type::traits_type;
17+
using element_type = typename traits_type::element_type;
18+
using metric_type = typename traits_type::metric_type;
19+
using distance_type = typename traits_type::distance_type;
20+
using container_type = typename base_type::container_type;
21+
using container_ptr = typename base_type::container_ptr;
22+
23+
bfknn_generic_t() = default;
24+
~bfknn_generic_t() = default;
25+
26+
bfknn_generic_t(const bfknn_generic_t&) = delete;
27+
bfknn_generic_t& operator=(const bfknn_generic_t&) = delete;
28+
bfknn_generic_t(bfknn_generic_t&&) = delete;
29+
bfknn_generic_t& operator=(bfknn_generic_t&&) = delete;
30+
31+
// Set input data implementations
32+
std::size_t set_input_impl(const container_type& data);
33+
std::size_t set_input_impl(const container_ptr& data);
34+
35+
// Set metric implementations
36+
void set_metric_impl(const metric_type& metric);
37+
template<typename T = typename Element::value_type>
38+
void set_metric_impl(std::shared_ptr<toolbox::metrics::IMetric<T>> metric);
39+
40+
// KNN search implementations
41+
bool kneighbors_impl(const element_type& query,
42+
std::size_t num_neighbors,
43+
std::vector<std::size_t>& indices,
44+
std::vector<distance_type>& distances);
45+
46+
bool radius_neighbors_impl(const element_type& query,
47+
distance_type radius,
48+
std::vector<std::size_t>& indices,
49+
std::vector<distance_type>& distances);
50+
51+
private:
52+
container_ptr m_data;
53+
metric_type m_compile_time_metric;
54+
std::shared_ptr<toolbox::metrics::IMetric<typename Element::value_type>> m_runtime_metric;
55+
bool m_use_runtime_metric = false;
56+
};
57+
58+
// Legacy brute-force KNN for backward compatibility
859
template<typename DataType>
960
class CPP_TOOLBOX_EXPORT bfknn_t : public base_knn_t<bfknn_t<DataType>, DataType>
1061
{
@@ -14,6 +65,9 @@ class CPP_TOOLBOX_EXPORT bfknn_t : public base_knn_t<bfknn_t<DataType>, DataType
1465
using point_cloud = typename base_type::point_cloud;
1566
using point_cloud_ptr = typename base_type::point_cloud_ptr;
1667

68+
// Internal generic implementation
69+
using generic_impl_type = bfknn_generic_t<point_t<DataType>, toolbox::metrics::L2Metric<DataType>>;
70+
1771
bfknn_t() = default;
1872
~bfknn_t() = default;
1973

@@ -37,11 +91,7 @@ class CPP_TOOLBOX_EXPORT bfknn_t : public base_knn_t<bfknn_t<DataType>, DataType
3791
std::vector<data_type>& distances);
3892

3993
private:
40-
data_type compute_distance(const point_t<data_type>& p1,
41-
const point_t<data_type>& p2,
42-
metric_type_t metric) const;
43-
44-
point_cloud_ptr m_cloud;
94+
generic_impl_type m_impl;
4595
metric_type_t m_metric = metric_type_t::euclidean;
4696
};
4797

0 commit comments

Comments
 (0)