|
27 | 27 | #include "arrow/compute/exec/util.h" |
28 | 28 | #include "arrow/compute/kernels/row_encoder.h" |
29 | 29 | #include "arrow/compute/kernels/test_util.h" |
| 30 | +#include "arrow/testing/extension_type.h" |
30 | 31 | #include "arrow/testing/gtest_util.h" |
31 | 32 | #include "arrow/testing/matchers.h" |
32 | 33 | #include "arrow/testing/random.h" |
@@ -1801,6 +1802,114 @@ TEST(HashJoin, UnsupportedTypes) { |
1801 | 1802 | } |
1802 | 1803 | } |
1803 | 1804 |
|
| 1805 | +void TestSimpleJoinHelper(BatchesWithSchema input_left, BatchesWithSchema input_right, |
| 1806 | + BatchesWithSchema expected) { |
| 1807 | + ExecContext exec_ctx; |
| 1808 | + ASSERT_OK_AND_ASSIGN(auto plan, ExecPlan::Make(&exec_ctx)); |
| 1809 | + AsyncGenerator<util::optional<ExecBatch>> sink_gen; |
| 1810 | + |
| 1811 | + ExecNode* left_source; |
| 1812 | + ExecNode* right_source; |
| 1813 | + ASSERT_OK_AND_ASSIGN( |
| 1814 | + left_source, |
| 1815 | + MakeExecNode("source", plan.get(), {}, |
| 1816 | + SourceNodeOptions{input_left.schema, input_left.gen(/*parallel=*/false, |
| 1817 | + /*slow=*/false)})); |
| 1818 | + |
| 1819 | + ASSERT_OK_AND_ASSIGN(right_source, |
| 1820 | + MakeExecNode("source", plan.get(), {}, |
| 1821 | + SourceNodeOptions{input_right.schema, |
| 1822 | + input_right.gen(/*parallel=*/false, |
| 1823 | + /*slow=*/false)})); |
| 1824 | + |
| 1825 | + HashJoinNodeOptions join_opts{JoinType::INNER, |
| 1826 | + /*left_keys=*/{"lkey"}, |
| 1827 | + /*right_keys=*/{"rkey"}, literal(true), "_l", "_r"}; |
| 1828 | + |
| 1829 | + ASSERT_OK_AND_ASSIGN( |
| 1830 | + auto hashjoin, |
| 1831 | + MakeExecNode("hashjoin", plan.get(), {left_source, right_source}, join_opts)); |
| 1832 | + |
| 1833 | + ASSERT_OK_AND_ASSIGN(std::ignore, MakeExecNode("sink", plan.get(), {hashjoin}, |
| 1834 | + SinkNodeOptions{&sink_gen})); |
| 1835 | + |
| 1836 | + ASSERT_FINISHES_OK_AND_ASSIGN(auto result, StartAndCollect(plan.get(), sink_gen)); |
| 1837 | + |
| 1838 | + ASSERT_OK_AND_ASSIGN( |
| 1839 | + auto output_rows_test, |
| 1840 | + TableFromExecBatches(std::move(hashjoin->output_schema()), result)); |
| 1841 | + ASSERT_OK_AND_ASSIGN( |
| 1842 | + auto expected_rows_test, |
| 1843 | + TableFromExecBatches(std::move(expected.schema), expected.batches)); |
| 1844 | + |
| 1845 | + AssertTablesEqual(*output_rows_test, *expected_rows_test, /*same_chunk_layout=*/false, |
| 1846 | + /*flatten=*/true); |
| 1847 | + AssertSchemaEqual(expected.schema, hashjoin->output_schema()); |
| 1848 | +} |
| 1849 | + |
| 1850 | +TEST(HashJoin, ExtensionTypesSwissJoin) { |
| 1851 | + // For simpler types swiss join will be used. |
| 1852 | + auto ext_arr = ExampleUuid(); |
| 1853 | + auto l_int_arr = ArrayFromJSON(int32(), "[1, 2, 3, 4]"); |
| 1854 | + auto l_int_arr2 = ArrayFromJSON(int32(), "[4, 5, 6, 7]"); |
| 1855 | + auto r_int_arr = ArrayFromJSON(int32(), "[4, 3, 2, null, 1]"); |
| 1856 | + |
| 1857 | + BatchesWithSchema input_left; |
| 1858 | + ASSERT_OK_AND_ASSIGN(ExecBatch left_batches, |
| 1859 | + ExecBatch::Make({l_int_arr, l_int_arr2, ext_arr})); |
| 1860 | + input_left.batches = {left_batches}; |
| 1861 | + input_left.schema = schema( |
| 1862 | + {field("lkey", int32()), field("shared", int32()), field("ldistinct", uuid())}); |
| 1863 | + |
| 1864 | + BatchesWithSchema input_right; |
| 1865 | + ASSERT_OK_AND_ASSIGN(ExecBatch right_batches, ExecBatch::Make({r_int_arr})); |
| 1866 | + input_right.batches = {right_batches}; |
| 1867 | + input_right.schema = schema({field("rkey", int32())}); |
| 1868 | + |
| 1869 | + BatchesWithSchema expected; |
| 1870 | + ASSERT_OK_AND_ASSIGN(ExecBatch expected_batches, |
| 1871 | + ExecBatch::Make({l_int_arr, l_int_arr2, ext_arr, l_int_arr})); |
| 1872 | + expected.batches = {expected_batches}; |
| 1873 | + expected.schema = schema({field("lkey", int32()), field("shared", int32()), |
| 1874 | + field("ldistinct", uuid()), field("rkey", int32())}); |
| 1875 | + |
| 1876 | + TestSimpleJoinHelper(input_left, input_right, expected); |
| 1877 | +} |
| 1878 | + |
| 1879 | +TEST(HashJoin, ExtensionTypesHashJoin) { |
| 1880 | + // Swiss join doesn't support dictionaries so HashJoin will be used. |
| 1881 | + auto dict_type = dictionary(int64(), int8()); |
| 1882 | + auto ext_arr = ExampleUuid(); |
| 1883 | + auto l_int_arr = ArrayFromJSON(int32(), "[1, 2, 3, 4]"); |
| 1884 | + auto l_int_arr2 = ArrayFromJSON(int32(), "[4, 5, 6, 7]"); |
| 1885 | + auto r_int_arr = ArrayFromJSON(int32(), "[4, 3, 2, null, 1]"); |
| 1886 | + auto l_dict_array = |
| 1887 | + DictArrayFromJSON(dict_type, R"([2, 0, 1, null])", R"([null, 0, 1])"); |
| 1888 | + |
| 1889 | + BatchesWithSchema input_left; |
| 1890 | + ASSERT_OK_AND_ASSIGN(ExecBatch left_batches, |
| 1891 | + ExecBatch::Make({l_int_arr, l_int_arr2, ext_arr, l_dict_array})); |
| 1892 | + input_left.batches = {left_batches}; |
| 1893 | + input_left.schema = schema({field("lkey", int32()), field("shared", int32()), |
| 1894 | + field("ldistinct", uuid()), field("dict_type", dict_type)}); |
| 1895 | + |
| 1896 | + BatchesWithSchema input_right; |
| 1897 | + ASSERT_OK_AND_ASSIGN(ExecBatch right_batches, ExecBatch::Make({r_int_arr})); |
| 1898 | + input_right.batches = {right_batches}; |
| 1899 | + input_right.schema = schema({field("rkey", int32())}); |
| 1900 | + |
| 1901 | + BatchesWithSchema expected; |
| 1902 | + ASSERT_OK_AND_ASSIGN( |
| 1903 | + ExecBatch expected_batches, |
| 1904 | + ExecBatch::Make({l_int_arr, l_int_arr2, ext_arr, l_dict_array, l_int_arr})); |
| 1905 | + expected.batches = {expected_batches}; |
| 1906 | + expected.schema = schema({field("lkey", int32()), field("shared", int32()), |
| 1907 | + field("ldistinct", uuid()), field("dict_type", dict_type), |
| 1908 | + field("rkey", int32())}); |
| 1909 | + |
| 1910 | + TestSimpleJoinHelper(input_left, input_right, expected); |
| 1911 | +} |
| 1912 | + |
1804 | 1913 | TEST(HashJoin, CheckHashJoinNodeOptionsValidation) { |
1805 | 1914 | auto exec_ctx = |
1806 | 1915 | arrow::internal::make_unique<ExecContext>(default_memory_pool(), nullptr); |
|
0 commit comments