Skip to content

Commit 1fcbc6d

Browse files
pitroukszucs
andcommitted
ARROW-9478: [C++] Improve error message for unsupported casts
Mention both input type and target type, as far as possible. Closes apache#7773 from pitrou/ARROW-9478-better-cast-error-message Lead-authored-by: Antoine Pitrou <antoine@python.org> Co-authored-by: Krisztián Szűcs <szucs.krisztian@gmail.com> Signed-off-by: Krisztián Szűcs <szucs.krisztian@gmail.com>
1 parent bec2c85 commit 1fcbc6d

11 files changed

Lines changed: 410 additions & 223 deletions

File tree

cpp/src/arrow/compute/cast.cc

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@
3232
#include "arrow/util/logging.h"
3333

3434
namespace arrow {
35+
36+
using internal::ToTypeName;
37+
3538
namespace compute {
3639
namespace internal {
3740

@@ -54,6 +57,29 @@ void InitCastTable() {
5457

5558
void EnsureInitCastTable() { std::call_once(cast_table_initialized, InitCastTable); }
5659

60+
namespace {
61+
62+
// Private version of GetCastFunction with better error reporting
63+
// if the input type is known.
64+
Result<std::shared_ptr<CastFunction>> GetCastFunctionInternal(
65+
const std::shared_ptr<DataType>& to_type, const DataType* from_type = nullptr) {
66+
internal::EnsureInitCastTable();
67+
auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
68+
if (it == internal::g_cast_table.end()) {
69+
if (from_type != nullptr) {
70+
return Status::NotImplemented("Unsupported cast from ", *from_type, " to ",
71+
*to_type,
72+
" (no available cast function for target type)");
73+
} else {
74+
return Status::NotImplemented("Unsupported cast to ", *to_type,
75+
" (no available cast function for target type)");
76+
}
77+
}
78+
return it->second;
79+
}
80+
81+
} // namespace
82+
5783
// Metafunction for dispatching to appropraite CastFunction. This corresponds
5884
// to the standard SQL CAST(expr AS target_type)
5985
class CastMetaFunction : public MetaFunction {
@@ -79,8 +105,9 @@ class CastMetaFunction : public MetaFunction {
79105
if (args[0].type()->Equals(*cast_options->to_type)) {
80106
return args[0];
81107
}
82-
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<CastFunction> cast_func,
83-
GetCastFunction(cast_options->to_type));
108+
ARROW_ASSIGN_OR_RAISE(
109+
std::shared_ptr<CastFunction> cast_func,
110+
GetCastFunctionInternal(cast_options->to_type, args[0].type().get()));
84111
return cast_func->Execute(args, options, ctx);
85112
}
86113
};
@@ -147,9 +174,9 @@ Result<const ScalarKernel*> CastFunction::DispatchExact(
147174
}
148175

149176
if (candidate_kernels.size() == 0) {
150-
return Status::NotImplemented("Function ", this->name(),
151-
" has no kernel matching input type ",
152-
values[0].ToString());
177+
return Status::NotImplemented("Unsupported cast from ", values[0].type->ToString(),
178+
" to ", ToTypeName(impl_->out_type), " using function ",
179+
this->name());
153180
} else if (candidate_kernels.size() == 1) {
154181
// One match, return it
155182
return candidate_kernels[0];
@@ -188,13 +215,7 @@ Result<std::shared_ptr<Array>> Cast(const Array& value, std::shared_ptr<DataType
188215

189216
Result<std::shared_ptr<CastFunction>> GetCastFunction(
190217
const std::shared_ptr<DataType>& to_type) {
191-
internal::EnsureInitCastTable();
192-
auto it = internal::g_cast_table.find(static_cast<int>(to_type->id()));
193-
if (it == internal::g_cast_table.end()) {
194-
return Status::NotImplemented("No cast function available to cast to ",
195-
to_type->ToString());
196-
}
197-
return it->second;
218+
return internal::GetCastFunctionInternal(to_type);
198219
}
199220

200221
bool CanCast(const DataType& from_type, const DataType& to_type) {

cpp/src/arrow/compute/kernels/scalar_cast_test.cc

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include <string>
2323
#include <vector>
2424

25+
#include <gmock/gmock.h>
2526
#include <gtest/gtest.h>
2627

2728
#include "arrow/array.h"
@@ -1457,14 +1458,40 @@ TEST_F(TestCast, ChunkedArray) {
14571458
ASSERT_TRUE(out.chunked_array()->Equals(*ex_carr));
14581459
}
14591460

1460-
TEST_F(TestCast, UnsupportedTarget) {
1461-
std::vector<bool> is_valid = {true, false, true, true, true};
1462-
std::vector<int32_t> v1 = {0, 1, 2, 3, 4};
1461+
TEST_F(TestCast, UnsupportedInputType) {
1462+
// Casting to a supported target type, but with an unsupported input type
1463+
// for the target type.
1464+
const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
14631465

1464-
std::shared_ptr<Array> arr;
1465-
ArrayFromVector<Int32Type>(int32(), is_valid, v1, &arr);
1466+
const auto to_type = list(utf8());
1467+
const char* expected_message = "Unsupported cast from int32 to list";
1468+
1469+
// Try through concrete API
1470+
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
1471+
Cast(*arr, to_type));
1472+
1473+
// Try through general kernel API
1474+
CastOptions options;
1475+
options.to_type = to_type;
1476+
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
1477+
CallFunction("cast", {arr}, &options));
1478+
}
1479+
1480+
TEST_F(TestCast, UnsupportedTargetType) {
1481+
// Casting to an unsupported target type
1482+
const auto arr = ArrayFromJSON(int32(), "[1, 2, 3]");
1483+
const auto to_type = dense_union({field("a", int32())});
14661484

1467-
ASSERT_RAISES(NotImplemented, Cast(*arr, list(utf8())));
1485+
// Try through concrete API
1486+
const char* expected_message = "Unsupported cast from int32 to dense_union";
1487+
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
1488+
Cast(*arr, to_type));
1489+
1490+
// Try through general kernel API
1491+
CastOptions options;
1492+
options.to_type = to_type;
1493+
EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, ::testing::HasSubstr(expected_message),
1494+
CallFunction("cast", {arr}, &options));
14681495
}
14691496

14701497
TEST_F(TestCast, DateTimeZeroCopy) {

cpp/src/arrow/testing/gtest_util.cc

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,45 @@ namespace arrow {
5454
using internal::checked_cast;
5555
using internal::checked_pointer_cast;
5656

57+
std::vector<Type::type> AllTypeIds() {
58+
return {Type::NA,
59+
Type::BOOL,
60+
Type::INT8,
61+
Type::INT16,
62+
Type::INT32,
63+
Type::INT64,
64+
Type::UINT8,
65+
Type::UINT16,
66+
Type::UINT32,
67+
Type::UINT64,
68+
Type::HALF_FLOAT,
69+
Type::FLOAT,
70+
Type::DOUBLE,
71+
Type::DECIMAL,
72+
Type::DATE32,
73+
Type::DATE64,
74+
Type::TIME32,
75+
Type::TIME64,
76+
Type::TIMESTAMP,
77+
Type::INTERVAL_DAY_TIME,
78+
Type::INTERVAL_MONTHS,
79+
Type::DURATION,
80+
Type::STRING,
81+
Type::BINARY,
82+
Type::LARGE_STRING,
83+
Type::LARGE_BINARY,
84+
Type::FIXED_SIZE_BINARY,
85+
Type::STRUCT,
86+
Type::LIST,
87+
Type::LARGE_LIST,
88+
Type::FIXED_SIZE_LIST,
89+
Type::MAP,
90+
Type::DENSE_UNION,
91+
Type::SPARSE_UNION,
92+
Type::DICTIONARY,
93+
Type::EXTENSION};
94+
}
95+
5796
template <typename T, typename CompareFunctor>
5897
void AssertTsSame(const T& expected, const T& actual, CompareFunctor&& compare) {
5998
if (!compare(actual, expected)) {

cpp/src/arrow/testing/gtest_util.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,9 @@ class RecordBatch;
153153
class Table;
154154
struct Datum;
155155

156+
ARROW_TESTING_EXPORT
157+
std::vector<Type::type> AllTypeIds();
158+
156159
#define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs))
157160
#define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs))
158161
#define ASSERT_BATCHES_APPROX_EQUAL(lhs, rhs) AssertBatchesApproxEqual((lhs), (rhs))

cpp/src/arrow/type.cc

Lines changed: 62 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -93,79 +93,71 @@ constexpr Type::type DictionaryType::type_id;
9393

9494
namespace internal {
9595

96+
struct TypeIdToTypeNameVisitor {
97+
std::string out;
98+
99+
template <typename ArrowType>
100+
Status Visit(const ArrowType*) {
101+
out = ArrowType::type_name();
102+
return Status::OK();
103+
}
104+
};
105+
106+
std::string ToTypeName(Type::type id) {
107+
TypeIdToTypeNameVisitor visitor;
108+
109+
ARROW_CHECK_OK(VisitTypeIdInline(id, &visitor));
110+
return std::move(visitor.out);
111+
}
112+
96113
std::string ToString(Type::type id) {
97114
switch (id) {
98-
case Type::NA:
99-
return "NA";
100-
case Type::BOOL:
101-
return "BOOL";
102-
case Type::UINT8:
103-
return "UINT8";
104-
case Type::INT8:
105-
return "INT8";
106-
case Type::UINT16:
107-
return "UINT16";
108-
case Type::INT16:
109-
return "INT16";
110-
case Type::UINT32:
111-
return "UINT32";
112-
case Type::INT32:
113-
return "INT32";
114-
case Type::UINT64:
115-
return "UINT64";
116-
case Type::INT64:
117-
return "INT64";
118-
case Type::HALF_FLOAT:
119-
return "HALF_FLOAT";
120-
case Type::FLOAT:
121-
return "FLOAT";
122-
case Type::DOUBLE:
123-
return "DOUBLE";
124-
case Type::STRING:
125-
return "UTF8";
126-
case Type::BINARY:
127-
return "BINARY";
128-
case Type::FIXED_SIZE_BINARY:
129-
return "FIXED_SIZE_BINARY";
130-
case Type::DATE64:
131-
return "DATE64";
132-
case Type::TIMESTAMP:
133-
return "TIMESTAMP";
134-
case Type::TIME32:
135-
return "TIME32";
136-
case Type::TIME64:
137-
return "TIME64";
138-
case Type::INTERVAL_MONTHS:
139-
return "INTERVAL_MONTHS";
140-
case Type::INTERVAL_DAY_TIME:
141-
return "INTERVAL_DAY_TIME";
142-
case Type::DECIMAL:
143-
return "DECIMAL";
144-
case Type::LIST:
145-
return "LIST";
146-
case Type::STRUCT:
147-
return "STRUCT";
148-
case Type::SPARSE_UNION:
149-
return "SPARSE_UNION";
150-
case Type::DENSE_UNION:
151-
return "DENSE_UNION";
152-
case Type::DICTIONARY:
153-
return "DICTIONARY";
154-
case Type::MAP:
155-
return "MAP";
156-
case Type::EXTENSION:
157-
return "EXTENSION";
158-
case Type::FIXED_SIZE_LIST:
159-
return "FIXED_SIZE_LIST";
160-
case Type::DURATION:
161-
return "DURATION";
162-
case Type::LARGE_BINARY:
163-
return "LARGE_BINARY";
164-
case Type::LARGE_LIST:
165-
return "LARGE_LIST";
115+
#define TO_STRING_CASE(_id) \
116+
case Type::_id: \
117+
return ARROW_STRINGIFY(_id);
118+
119+
TO_STRING_CASE(NA)
120+
TO_STRING_CASE(BOOL)
121+
TO_STRING_CASE(INT8)
122+
TO_STRING_CASE(INT16)
123+
TO_STRING_CASE(INT32)
124+
TO_STRING_CASE(INT64)
125+
TO_STRING_CASE(UINT8)
126+
TO_STRING_CASE(UINT16)
127+
TO_STRING_CASE(UINT32)
128+
TO_STRING_CASE(UINT64)
129+
TO_STRING_CASE(HALF_FLOAT)
130+
TO_STRING_CASE(FLOAT)
131+
TO_STRING_CASE(DOUBLE)
132+
TO_STRING_CASE(DECIMAL)
133+
TO_STRING_CASE(DATE32)
134+
TO_STRING_CASE(DATE64)
135+
TO_STRING_CASE(TIME32)
136+
TO_STRING_CASE(TIME64)
137+
TO_STRING_CASE(TIMESTAMP)
138+
TO_STRING_CASE(INTERVAL_DAY_TIME)
139+
TO_STRING_CASE(INTERVAL_MONTHS)
140+
TO_STRING_CASE(DURATION)
141+
TO_STRING_CASE(STRING)
142+
TO_STRING_CASE(BINARY)
143+
TO_STRING_CASE(LARGE_STRING)
144+
TO_STRING_CASE(LARGE_BINARY)
145+
TO_STRING_CASE(FIXED_SIZE_BINARY)
146+
TO_STRING_CASE(STRUCT)
147+
TO_STRING_CASE(LIST)
148+
TO_STRING_CASE(LARGE_LIST)
149+
TO_STRING_CASE(FIXED_SIZE_LIST)
150+
TO_STRING_CASE(MAP)
151+
TO_STRING_CASE(DENSE_UNION)
152+
TO_STRING_CASE(SPARSE_UNION)
153+
TO_STRING_CASE(DICTIONARY)
154+
TO_STRING_CASE(EXTENSION)
155+
156+
#undef TO_STRING_CASE
157+
166158
default:
167-
DCHECK(false) << "Should not be able to reach here";
168-
return "unknown";
159+
ARROW_LOG(FATAL) << "Unhandled type id: " << id;
160+
return "";
169161
}
170162
}
171163

0 commit comments

Comments
 (0)