Skip to content

Commit dfee391

Browse files
ARROW-9187: [R] Add bindings for arithmetic kernels
Replaces apache#8947 @jonkeane Closes apache#9117 from nealrichardson/arith2 Lead-authored-by: Jonathan Keane <jkeane@gmail.com> Co-authored-by: Neal Richardson <neal.p.richardson@gmail.com> Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
1 parent 3622a2e commit dfee391

11 files changed

Lines changed: 406 additions & 39 deletions

File tree

r/NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
## Enhancements
3131

32+
* Arithmetic operations (`+`, `*`, etc.) are supported on Arrays and ChunkedArrays and can be used in filter expressions in Arrow `dplyr` pipelines
3233
* Table columns can now be added, replaced, or removed by assigning (`<-`) with either `$` or `[[`
3334
* Column names of Tables and RecordBatches can be renamed by assigning `names()`
3435
* Large string types can now be written to Parquet files

r/R/arrowExports.R

Lines changed: 0 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r/R/expression.R

Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,53 @@ build_array_expression <- function(.Generic, e1, e2, ...) {
5757
if (.Generic %in% names(.unary_function_map)) {
5858
expr <- array_expression(.unary_function_map[[.Generic]], e1)
5959
} else {
60-
e1 <- .wrap_arrow(e1, .Generic, e2$type)
61-
e2 <- .wrap_arrow(e2, .Generic, e1$type)
60+
e1 <- .wrap_arrow(e1, .Generic)
61+
e2 <- .wrap_arrow(e2, .Generic)
62+
63+
# In Arrow, "divide" is one function, which does integer division on
64+
# integer inputs and floating-point division on floats
65+
if (.Generic == "/") {
66+
# TODO: omg so many ways it's wrong to assume these types
67+
e1 <- cast_array_expression(e1, float64())
68+
e2 <- cast_array_expression(e2, float64())
69+
} else if (.Generic == "%/%") {
70+
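# In R, integer division works like floor(float division).
# NOTE(review): the truncating cast below rounds toward zero rather than
# down, so negative quotients (e.g. -3 %/% 2) may not match R -- confirm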
71+
out <- build_array_expression("/", e1, e2)
72+
return(cast_array_expression(out, int32(), allow_float_truncate = TRUE))
73+
} else if (.Generic == "%%") {
74+
# {e1 - e2 * ( e1 %/% e2 )}
75+
# ^^^ form doesn't work because Ops.Array evaluates eagerly,
76+
# but we can build that up
77+
quotient <- build_array_expression("%/%", e1, e2)
78+
# this cast is to ensure that the result of this and e1 are the same
79+
# (autocasting only applies to scalars)
80+
base <- cast_array_expression(quotient * e2, e1$type)
81+
return(build_array_expression("-", e1, base))
82+
}
83+
6284
expr <- array_expression(.binary_function_map[[.Generic]], e1, e2, ...)
6385
}
6486
expr
6587
}
6688

67-
.wrap_arrow <- function(arg, fun, type) {
89+
cast_array_expression <- function(x, to_type, safe = TRUE, ...) {
90+
opts <- list(
91+
to_type = to_type,
92+
allow_int_overflow = !safe,
93+
allow_time_truncate = !safe,
94+
allow_float_truncate = !safe
95+
)
96+
array_expression("cast", x, options = modifyList(opts, list(...)))
97+
}
98+
99+
.wrap_arrow <- function(arg, fun) {
68100
if (!inherits(arg, c("ArrowObject", "array_expression"))) {
69101
# TODO: Array$create if lengths are equal?
70102
# TODO: these kernels should autocast like the dataset ones do (e.g. int vs. float)
71103
if (fun == "%in%") {
72-
arg <- Array$create(arg, type = type)
104+
arg <- Array$create(arg)
73105
} else {
74-
arg <- Scalar$create(arg, type = type)
106+
arg <- Scalar$create(arg)
75107
}
76108
}
77109
arg
@@ -91,6 +123,15 @@ build_array_expression <- function(.Generic, e1, e2, ...) {
91123
"<=" = "less_equal",
92124
"&" = "and_kleene",
93125
"|" = "or_kleene",
126+
"+" = "add_checked",
127+
"-" = "subtract_checked",
128+
"*" = "multiply_checked",
129+
"/" = "divide_checked",
130+
"%/%" = "divide_checked",
131+
# we don't actually use divide_checked with `%%`, rather it is rewritten to
132+
# use %/% above.
133+
"%%" = "divide_checked",
134+
# TODO: "^" (ARROW-11070)
94135
"%in%" = "is_in_meta_binary"
95136
)
96137

@@ -104,6 +145,16 @@ eval_array_expression <- function(x) {
104145
a
105146
}
106147
})
148+
if (length(x$args) == 2L) {
149+
# Insert implicit casts
150+
if (inherits(x$args[[1]], "Scalar")) {
151+
x$args[[1]] <- x$args[[1]]$cast(x$args[[2]]$type)
152+
} else if (inherits(x$args[[2]], "Scalar")) {
153+
x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type)
154+
} else if (x$fun == "is_in_meta_binary" && inherits(x$args[[2]], "Array")) {
155+
x$args[[2]] <- x$args[[2]]$cast(x$args[[1]]$type)
156+
}
157+
}
107158
call_function(x$fun, args = x$args, options = x$options %||% empty_named_list())
108159
}
109160

@@ -160,7 +211,16 @@ print.array_expression <- function(x, ...) {
160211
#' @export
161212
Expression <- R6Class("Expression", inherit = ArrowObject,
162213
public = list(
163-
ToString = function() dataset___expr__ToString(self)
214+
ToString = function() dataset___expr__ToString(self),
215+
cast = function(to_type, safe = TRUE, ...) {
216+
opts <- list(
217+
to_type = to_type,
218+
allow_int_overflow = !safe,
219+
allow_time_truncate = !safe,
220+
allow_float_truncate = !safe
221+
)
222+
Expression$create("cast", self, options = modifyList(opts, list(...)))
223+
}
164224
)
165225
)
166226
Expression$create <- function(function_name,
@@ -196,6 +256,21 @@ build_dataset_expression <- function(.Generic, e1, e2, ...) {
196256
if (!inherits(e2, "Expression")) {
197257
e2 <- Expression$scalar(e2)
198258
}
259+
260+
# In Arrow, "divide" is one function, which does integer division on
261+
# integer inputs and floating-point division on floats
262+
if (.Generic == "/") {
263+
# TODO: omg so many ways it's wrong to assume these types
264+
e1 <- e1$cast(float64())
265+
e2 <- e2$cast(float64())
266+
} else if (.Generic == "%/%") {
267+
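# In R, integer division works like floor(float division).
# NOTE(review): the truncating cast below rounds toward zero rather than
# down, so negative quotients (e.g. -3 %/% 2) may not match R -- confirm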
268+
out <- build_dataset_expression("/", e1, e2)
269+
return(out$cast(int32(), allow_float_truncate = TRUE))
270+
} else if (.Generic == "%%") {
271+
return(e1 - e2 * ( e1 %/% e2 ))
272+
}
273+
199274
expr <- Expression$create(.binary_function_map[[.Generic]], e1, e2, ...)
200275
}
201276
expr

r/R/scalar.R

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,14 @@ Scalar <- R6Class("Scalar",
3232
# TODO: document the methods
3333
public = list(
3434
ToString = function() Scalar__ToString(self),
35-
cast = function(target_type) {
36-
Scalar__CastTo(self, as_type(target_type))
35+
cast = function(target_type, safe = TRUE, ...) {
36+
opts <- list(
37+
to_type = as_type(target_type),
38+
allow_int_overflow = !safe,
39+
allow_time_truncate = !safe,
40+
allow_float_truncate = !safe
41+
)
42+
call_function("cast", self, options = modifyList(opts, list(...)))
3743
},
3844
as_vector = function() Scalar__as_vector(self)
3945
),

r/src/arrowExports.cpp

Lines changed: 0 additions & 17 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

r/src/compute.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,33 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
185185
cpp11::as_cpp<bool>(options["skip_nulls"]));
186186
}
187187

188+
// hacky attempt to pass through to_type and other options
189+
if (func_name == "cast") {
190+
using Options = arrow::compute::CastOptions;
191+
auto out = std::make_shared<Options>(true);
192+
SEXP to_type = options["to_type"];
193+
if (!Rf_isNull(to_type) && cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type)) {
194+
out->to_type = cpp11::as_cpp<std::shared_ptr<arrow::DataType>>(to_type);
195+
}
196+
197+
SEXP allow_float_truncate = options["allow_float_truncate"];
198+
if (!Rf_isNull(allow_float_truncate) && cpp11::as_cpp<bool>(allow_float_truncate)) {
199+
out->allow_float_truncate = cpp11::as_cpp<bool>(allow_float_truncate);
200+
}
201+
202+
SEXP allow_time_truncate = options["allow_time_truncate"];
203+
if (!Rf_isNull(allow_time_truncate) && cpp11::as_cpp<bool>(allow_time_truncate)) {
204+
out->allow_time_truncate = cpp11::as_cpp<bool>(allow_time_truncate);
205+
}
206+
207+
SEXP allow_int_overflow = options["allow_int_overflow"];
208+
if (!Rf_isNull(allow_int_overflow) && cpp11::as_cpp<bool>(allow_int_overflow)) {
209+
out->allow_int_overflow = cpp11::as_cpp<bool>(allow_int_overflow);
210+
}
211+
212+
return out;
213+
}
214+
188215
return nullptr;
189216
}
190217

r/src/scalar.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ std::string Scalar__ToString(const std::shared_ptr<arrow::Scalar>& s) {
4747
return s->ToString();
4848
}
4949

50-
// [[arrow::export]]
51-
std::shared_ptr<arrow::Scalar> Scalar__CastTo(const std::shared_ptr<arrow::Scalar>& s,
52-
const std::shared_ptr<arrow::DataType>& t) {
53-
return ValueOrStop(s->CastTo(t));
54-
}
55-
5650
// [[arrow::export]]
5751
std::shared_ptr<arrow::Scalar> StructScalar__field(
5852
const std::shared_ptr<arrow::StructScalar>& s, int i) {
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
test_that("Addition", {
19+
a <- Array$create(c(1:4, NA_integer_))
20+
expect_type_equal(a, int32())
21+
expect_type_equal(a + 4, int32())
22+
expect_equal(a + 4, Array$create(c(5:8, NA_integer_)))
23+
expect_identical(as.vector(a + 4), c(5:8, NA_integer_))
24+
expect_equal(a + 4L, Array$create(c(5:8, NA_integer_)))
25+
expect_vector(a + 4L, c(5:8, NA_integer_))
26+
expect_equal(a + NA_integer_, Array$create(rep(NA_integer_, 5)))
27+
28+
# overflow errors — this is slightly different from R's `NA` coercion when
29+
# overflowing, but better than the alternative of silently wrapping around
30+
casted <- a$cast(int8())
31+
expect_error(casted + 127)
32+
expect_error(casted + 200)
33+
34+
skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919")
35+
expect_type_equal(a + 4.1, float64())
36+
expect_equal(a + 4.1, Array$create(c(5.1, 6.1, 7.1, 8.1, NA_real_)))
37+
})
38+
39+
test_that("Subtraction", {
40+
a <- Array$create(c(1:4, NA_integer_))
41+
expect_equal(a - 3, Array$create(c(-2:1, NA_integer_)))
42+
})
43+
44+
test_that("Multiplication", {
45+
a <- Array$create(c(1:4, NA_integer_))
46+
expect_equal(a * 2, Array$create(c(1:4 * 2L, NA_integer_)))
47+
})
48+
49+
test_that("Division", {
50+
a <- Array$create(c(1:4, NA_integer_))
51+
expect_equal(a / 2, Array$create(c(1:4 / 2, NA_real_)))
52+
expect_equal(a %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_)))
53+
expect_equal(a / 2 / 2, Array$create(c(1:4 / 2 / 2, NA_real_)))
54+
expect_equal(a %/% 2 %/% 2, Array$create(c(0L, 0L, 0L, 1L, NA_integer_)))
55+
56+
b <- a$cast(float64())
57+
expect_equal(b / 2, Array$create(c(1:4 / 2, NA_real_)))
58+
expect_equal(b %/% 2, Array$create(c(0L, 1L, 1L, 2L, NA_integer_)))
59+
60+
# the behavior of %/% matches R's (i.e. the integer of the quotient, not
61+
# simply dividing two integers)
62+
expect_equal(b / 2.2, Array$create(c(1:4 / 2.2, NA_real_)))
63+
# c(1:4) %/% 2.2 != c(1:4) %/% as.integer(2.2)
64+
# c(1:4) %/% 2.2 == c(0L, 0L, 1L, 1L)
65+
# c(1:4) %/% as.integer(2.2) == c(0L, 1L, 1L, 2L)
66+
expect_equal(b %/% 2.2, Array$create(c(0L, 0L, 1L, 1L, NA_integer_)))
67+
68+
expect_equal(a %% 2, Array$create(c(1L, 0L, 1L, 0L, NA_integer_)))
69+
70+
expect_equal(b %% 2, Array$create(c(1:4 %% 2, NA_real_)))
71+
})
72+
73+
test_that("Dates casting", {
74+
a <- Array$create(c(Sys.Date() + 1:4, NA_integer_))
75+
76+
skip("autocasting should happen in compute kernels; R workaround fails on this ARROW-8919")
77+
# Keep NA_integer_ inside c() so the expected array has the same trailing
# null as `a` (five elements, not four)
expect_equal(a + 2, Array$create(c((Sys.Date() + 1:4) + 2, NA_integer_)))
78+
})

0 commit comments

Comments
 (0)