Skip to content

Commit d045fc5

Browse files
ARROW-17462: [R] Cast scalars to type of field in Expression building (apache#13985)
Logic is encapsulated in `wrap_scalars()` in expression.R. Most test updating (that is not linting) is changing some printed output types because `int * 2` now stays `int32`, and the printed ExecPlans don't have as many `cast`s in them. The tests added in `test-dplyr-query.R` are the explicit tests of the feature. Authored-by: Neal Richardson <neal.p.richardson@gmail.com> Signed-off-by: Neal Richardson <neal.p.richardson@gmail.com>
1 parent 8066c5e commit d045fc5

8 files changed

Lines changed: 228 additions & 33 deletions

File tree

r/R/compute.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ register_scalar_function <- function(name, fun, in_type, out_type,
379379
RegisterScalarUDF(name, scalar_function)
380380

381381
# register with dplyr binding (enables its use in mutate(), filter(), etc.)
382-
binding_fun <- function(...) build_expr(name, ...)
382+
binding_fun <- function(...) Expression$create(name, ...)
383383

384384
# inject the value of `name` into the expression to avoid saving this
385385
# execution environment in the binding, which eliminates a warning when the

r/R/expression.R

Lines changed: 114 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,13 @@ Expression$create <- function(function_name,
171171
args = list(...),
172172
options = empty_named_list()) {
173173
assert_that(is.string(function_name))
174-
assert_that(is_list_of(args, "Expression"), msg = "Expression arguments must be Expression objects")
174+
# Make sure all inputs are Expressions
175+
args <- lapply(args, function(x) {
176+
if (!inherits(x, "Expression")) {
177+
x <- Expression$scalar(x)
178+
}
179+
x
180+
})
175181
expr <- compute___expr__call(function_name, args, options)
176182
if (length(args)) {
177183
expr$schema <- unify_schemas(schemas = lapply(args, function(x) x$schema))
@@ -187,7 +193,10 @@ Expression$field_ref <- function(name) {
187193
compute___expr__field_ref(name)
188194
}
189195
Expression$scalar <- function(x) {
190-
expr <- compute___expr__scalar(Scalar$create(x))
196+
if (!inherits(x, "Scalar")) {
197+
x <- Scalar$create(x)
198+
}
199+
expr <- compute___expr__scalar(x)
191200
expr$schema <- schema()
192201
expr
193202
}
@@ -208,21 +217,20 @@ build_expr <- function(FUN,
208217
}
209218
if (FUN == "%in%") {
210219
# Special-case %in%, which is different from the Array function name
220+
value_set <- Array$create(args[[2]])
221+
try(
222+
value_set <- cast_or_parse(value_set, args[[1]]$type()),
223+
silent = TRUE
224+
)
225+
211226
expr <- Expression$create("is_in", args[[1]],
212227
options = list(
213-
# If args[[2]] is already an Arrow object (like a scalar),
214-
# this wouldn't work
215-
value_set = Array$create(args[[2]]),
228+
value_set = value_set,
216229
skip_nulls = TRUE
217230
)
218231
)
219232
} else {
220-
args <- lapply(args, function(x) {
221-
if (!inherits(x, "Expression")) {
222-
x <- Expression$scalar(x)
223-
}
224-
x
225-
})
233+
args <- wrap_scalars(args, FUN)
226234

227235
# In Arrow, "divide" is one function, which does integer division on
228236
# integer inputs and floating-point division on floats
@@ -258,6 +266,101 @@ build_expr <- function(FUN,
258266
expr
259267
}
260268

269+
wrap_scalars <- function(args, FUN) {
270+
arrow_fun <- .array_function_map[[FUN]] %||% FUN
271+
if (arrow_fun == "if_else") {
272+
# For if_else, the first arg should be a bool Expression, and we don't
273+
# want to consider that when casting the other args to the same type
274+
args[-1] <- wrap_scalars(args[-1], FUN = "")
275+
return(args)
276+
}
277+
278+
is_expr <- map_lgl(args, ~ inherits(., "Expression"))
279+
if (all(is_expr)) {
280+
# No wrapping is required
281+
return(args)
282+
}
283+
284+
args[!is_expr] <- lapply(args[!is_expr], Scalar$create)
285+
286+
# Some special casing by function
287+
# * %/%: we switch behavior based on int vs. dbl in R (see build_expr) so skip
288+
# * binary_repeat, list_element: 2nd arg must be integer, Acero will handle it
289+
if (any(is_expr) && !(arrow_fun %in% c("binary_repeat", "list_element")) && !(FUN %in% "%/%")) {
290+
try(
291+
{
292+
# If the Expression has no Schema embedded, we cannot resolve its
293+
# type here, so this will error, hence the try() wrapping it
294+
# This will also error if length(args[is_expr]) == 0, or
295+
# if there are multiple exprs that do not share a common type.
296+
to_type <- common_type(args[is_expr])
297+
# Try casting to this type, but if the cast fails,
298+
# we'll just keep the original
299+
args[!is_expr] <- lapply(args[!is_expr], cast_or_parse, type = to_type)
300+
},
301+
silent = TRUE
302+
)
303+
}
304+
305+
args[!is_expr] <- lapply(args[!is_expr], Expression$scalar)
306+
args
307+
}
308+
309+
common_type <- function(exprs) {
310+
types <- map(exprs, ~ .$type())
311+
first_type <- types[[1]]
312+
if (length(types) == 1 || all(map_lgl(types, ~ .$Equals(first_type)))) {
313+
# Functions (in our tests) that have multiple exprs to check:
314+
# * case_when
315+
# * pmin/pmax
316+
return(first_type)
317+
}
318+
stop("There is no common type in these expressions")
319+
}
320+
321+
cast_or_parse <- function(x, type) {
322+
to_type_id <- type$id
323+
if (to_type_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) {
324+
# TODO: determine the minimum size of decimal (or integer) required to
325+
# accommodate x
326+
# We would like to keep calculations on decimal if that's what the data has
327+
# so that we don't lose precision. However, there are some limitations
328+
# today, so it makes sense to keep x as double (which is probably is from R)
329+
# and let Acero cast the decimal to double to compute.
330+
# You can specify in your query that x should be decimal or integer if you
331+
# know it to be safe.
332+
# * ARROW-17601: multiply(decimal, decimal) can fail to make output type
333+
return(x)
334+
}
335+
336+
# For most types, just cast.
337+
# But for string -> date/time, we need to call a parsing function
338+
if (x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]])) {
339+
if (to_type_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) {
340+
x <- call_function(
341+
"strptime",
342+
x,
343+
options = list(format = "%Y-%m-%d", unit = 0L)
344+
)
345+
} else if (to_type_id == Type[["TIMESTAMP"]]) {
346+
x <- call_function(
347+
"strptime",
348+
x,
349+
options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L)
350+
)
351+
# R assumes timestamps without timezone specified are
352+
# local timezone while Arrow assumes UTC. For consistency
353+
# with R behavior, specify local timezone here.
354+
x <- call_function(
355+
"assume_timezone",
356+
x,
357+
options = list(timezone = Sys.timezone())
358+
)
359+
}
360+
}
361+
x$cast(type)
362+
}
363+
261364
#' @export
262365
Ops.Expression <- function(e1, e2) {
263366
if (.Generic == "!") {

r/tests/testthat/test-dataset-dplyr.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ test_that("mutate()", {
143143
chr: string
144144
dbl: double
145145
int: int32
146-
twice: double (multiply_checked(int, 2))
146+
twice: int32 (multiply_checked(int, 2))
147147
148148
* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
149149
See $.data for the source Arrow object",
@@ -219,7 +219,7 @@ test_that("arrange()", {
219219
chr: string
220220
dbl: double
221221
int: int32
222-
twice: double (multiply_checked(int, 2))
222+
twice: int32 (multiply_checked(int, 2))
223223
224224
* Filter: ((multiply_checked(dbl, 2) > 14) and (subtract_checked(dbl, 50) < 3))
225225
* Sorted by chr [asc], multiply_checked(int, 2) [desc], add_checked(dbl, int) [asc]
@@ -368,7 +368,7 @@ test_that("show_exec_plan(), show_query() and explain() with datasets", {
368368
"ExecPlan with .* nodes:.*", # boiler plate for ExecPlan
369369
"ProjectNode.*", # output columns
370370
"FilterNode.*", # filter node
371-
"int > 6.*cast.*", # filtering expressions + auto-casting of part
371+
"int > 6.*", # filtering expressions
372372
"SourceNode" # entry point
373373
)
374374
)

r/tests/testthat/test-dplyr-collapse.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ test_that("implicit_schema with mutate", {
5757
words = as.character(int)
5858
) %>%
5959
implicit_schema(),
60-
schema(numbers = float64(), words = utf8())
60+
schema(numbers = int32(), words = utf8())
6161
)
6262
})
6363

@@ -163,7 +163,7 @@ test_that("Properties of collapsed query", {
163163
"Table (query)
164164
lgl: bool
165165
total: int32
166-
extra: double (multiply_checked(total, 5))
166+
extra: int32 (multiply_checked(total, 5))
167167
168168
See $.data for the source Arrow object",
169169
fixed = TRUE

r/tests/testthat/test-dplyr-filter.R

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -217,25 +217,29 @@ test_that("filter() with between()", {
217217
filter(dbl >= int, dbl <= dbl2)
218218
)
219219

220-
expect_error(
221-
tbl %>%
222-
record_batch() %>%
220+
compare_dplyr_binding(
221+
.input %>%
223222
filter(between(dbl, 1, "2")) %>%
224-
collect()
223+
collect(),
224+
tbl
225225
)
226226

227-
expect_error(
228-
tbl %>%
229-
record_batch() %>%
227+
compare_dplyr_binding(
228+
.input %>%
230229
filter(between(dbl, 1, NA)) %>%
231-
collect()
230+
collect(),
231+
tbl
232232
)
233233

234-
expect_error(
235-
tbl %>%
236-
record_batch() %>%
237-
filter(between(chr, 1, 2)) %>%
238-
collect()
234+
expect_warning(
235+
compare_dplyr_binding(
236+
.input %>%
237+
filter(between(chr, 1, 2)) %>%
238+
collect(),
239+
tbl
240+
),
241+
# the dplyr version warns:
242+
"NAs introduced by coercion"
239243
)
240244
})
241245

r/tests/testthat/test-dplyr-mutate.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ test_that("print a mutated table", {
458458
print(),
459459
"Table (query)
460460
int: int32
461-
twice: double (multiply_checked(int, 2))
461+
twice: int32 (multiply_checked(int, 2))
462462
463463
See $.data for the source Arrow object",
464464
fixed = TRUE

r/tests/testthat/test-dplyr-query.R

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,3 +631,90 @@ test_that("collect() is identical to compute() %>% collect()", {
631631
collect()
632632
)
633633
})
634+
635+
test_that("Scalars in expressions match the type of the field, if possible", {
636+
tbl_with_datetime <- tbl
637+
tbl_with_datetime$dates <- as.Date("2022-08-28") + 1:10
638+
tbl_with_datetime$times <- lubridate::ymd_hms("2018-10-07 19:04:05") + 1:10
639+
tab <- Table$create(tbl_with_datetime)
640+
641+
# 5 is double in R but is properly interpreted as int, no cast is added
642+
expect_output(
643+
tab %>%
644+
filter(int == 5) %>%
645+
show_exec_plan(),
646+
"int == 5"
647+
)
648+
649+
# Because 5.2 can't cast to int32 without truncation, we pass as is
650+
# and Acero will cast int to float64
651+
expect_output(
652+
tab %>%
653+
filter(int == 5.2) %>%
654+
show_exec_plan(),
655+
"filter=(cast(int, {to_type=double",
656+
fixed = TRUE
657+
)
658+
expect_equal(
659+
tab %>%
660+
filter(int == 5.2) %>%
661+
nrow(),
662+
0
663+
)
664+
665+
# int == string, this works in dplyr and here too
666+
expect_output(
667+
tab %>%
668+
filter(int == "5") %>%
669+
show_exec_plan(),
670+
"int == 5",
671+
fixed = TRUE
672+
)
673+
expect_equal(
674+
tab %>%
675+
filter(int == "5") %>%
676+
nrow(),
677+
1
678+
)
679+
680+
# Strings automatically parsed to date/timestamp
681+
expect_output(
682+
tab %>%
683+
filter(dates > "2022-09-01") %>%
684+
show_exec_plan(),
685+
"dates > 2022-09-01"
686+
)
687+
compare_dplyr_binding(
688+
.input %>%
689+
filter(dates > "2022-09-01") %>%
690+
collect(),
691+
tbl_with_datetime
692+
)
693+
694+
expect_output(
695+
tab %>%
696+
filter(times > "2018-10-07 19:04:05") %>%
697+
show_exec_plan(),
698+
"times > 2018-10-0. ..:..:05"
699+
)
700+
compare_dplyr_binding(
701+
.input %>%
702+
filter(times > "2018-10-07 19:04:05") %>%
703+
collect(),
704+
tbl_with_datetime
705+
)
706+
707+
tab_with_decimal <- tab %>%
708+
mutate(dec = cast(dbl, decimal(15, 2))) %>%
709+
compute()
710+
711+
# This reproduces the issue on ARROW-17601, found in the TPC-H query 1
712+
# In ARROW-17462, we chose not to auto-cast to decimal to avoid that issue
713+
result <- tab_with_decimal %>%
714+
summarize(
715+
tpc_h_1 = sum(dec * (1 - dec) * (1 + dec), na.rm = TRUE),
716+
as_dbl = sum(dbl * (1 - dbl) * (1 + dbl), na.rm = TRUE)
717+
) %>%
718+
collect()
719+
expect_equal(result$tpc_h_1, result$as_dbl)
720+
})

r/tests/testthat/test-expression.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,10 @@ test_that("C++ expressions", {
5858
# Interprets that as a list type
5959
expect_r6_class(f == c(1L, 2L), "Expression")
6060

61-
expect_error(
61+
# Non-Expression inputs are wrapped in Expression$scalar()
62+
expect_equal(
6263
Expression$create("add", 1, 2),
63-
"Expression arguments must be Expression objects"
64+
Expression$create("add", Expression$scalar(1), Expression$scalar(2))
6465
)
6566
})
6667

0 commit comments

Comments
 (0)