@@ -171,7 +171,13 @@ Expression$create <- function(function_name,
171171 args = list (... ),
172172 options = empty_named_list()) {
173173 assert_that(is.string(function_name ))
174- assert_that(is_list_of(args , " Expression" ), msg = " Expression arguments must be Expression objects" )
174+ # Make sure all inputs are Expressions
175+ args <- lapply(args , function (x ) {
176+ if (! inherits(x , " Expression" )) {
177+ x <- Expression $ scalar(x )
178+ }
179+ x
180+ })
175181 expr <- compute___expr__call(function_name , args , options )
176182 if (length(args )) {
177183 expr $ schema <- unify_schemas(schemas = lapply(args , function (x ) x $ schema ))
@@ -187,7 +193,10 @@ Expression$field_ref <- function(name) {
187193 compute___expr__field_ref(name )
188194}
# Build a literal (scalar) Expression from `x`.
#
# @param x Either an object inheriting from "Scalar", which is used as-is,
#   or any other value, which is first passed through Scalar$create().
# @return The Expression produced by compute___expr__scalar(), with its
#   `schema` field set to the result of schema() called with no arguments.
Expression$scalar <- function(x) {
  # Only convert when the caller has not already supplied a Scalar
  scalar <- if (inherits(x, "Scalar")) x else Scalar$create(x)
  expr <- compute___expr__scalar(scalar)
  expr$schema <- schema()
  expr
}
@@ -208,21 +217,20 @@ build_expr <- function(FUN,
208217 }
209218 if (FUN == " %in%" ) {
210219 # Special-case %in%, which is different from the Array function name
220+ value_set <- Array $ create(args [[2 ]])
221+ try(
222+ value_set <- cast_or_parse(value_set , args [[1 ]]$ type()),
223+ silent = TRUE
224+ )
225+
211226 expr <- Expression $ create(" is_in" , args [[1 ]],
212227 options = list (
213- # If args[[2]] is already an Arrow object (like a scalar),
214- # this wouldn't work
215- value_set = Array $ create(args [[2 ]]),
228+ value_set = value_set ,
216229 skip_nulls = TRUE
217230 )
218231 )
219232 } else {
220- args <- lapply(args , function (x ) {
221- if (! inherits(x , " Expression" )) {
222- x <- Expression $ scalar(x )
223- }
224- x
225- })
233+ args <- wrap_scalars(args , FUN )
226234
227235 # In Arrow, "divide" is one function, which does integer division on
228236 # integer inputs and floating-point division on floats
@@ -258,6 +266,101 @@ build_expr <- function(FUN,
258266 expr
259267}
260268
# Convert the non-Expression members of `args` into scalar Expressions,
# first trying to cast them to the type shared by the Expression arguments.
#
# @param args List of call arguments: any mix of Expressions and other
#   values.
# @param FUN Name of the function being built. It is looked up in
#   .array_function_map (falling back to FUN itself) and decides whether a
#   cast is attempted.
# @return `args` with every non-Expression entry replaced by the result of
#   Expression$scalar(). For if_else, the first entry is returned untouched.
wrap_scalars <- function(args, FUN) {
  arrow_fun <- .array_function_map[[FUN]] %||% FUN

  if (arrow_fun == "if_else") {
    # The first argument of if_else should be a boolean Expression, so it is
    # left alone and excluded from the type matching applied to the other
    # arguments. The recursive call uses an empty function name.
    args[-1] <- wrap_scalars(args[-1], FUN = "")
    return(args)
  }

  is_expr <- map_lgl(args, function(arg) inherits(arg, "Expression"))
  if (all(is_expr)) {
    # Everything is already an Expression: nothing to wrap
    return(args)
  }

  needs_wrap <- !is_expr
  scalars <- lapply(args[needs_wrap], Scalar$create)

  # Function-specific exceptions to casting:
  # * %/%: build_expr switches behavior on int vs. dbl in R, so the R type
  #   must be kept
  # * binary_repeat, list_element: the 2nd argument must be an integer and
  #   Acero will handle it
  try_cast <- any(is_expr) &&
    !(arrow_fun %in% c("binary_repeat", "list_element")) &&
    !(FUN %in% "%/%")

  if (try_cast) {
    # Best effort only. This errors when an Expression has no Schema
    # embedded (its type cannot be resolved here), when the Expressions do
    # not share a common type, or when a cast fails. In any of those cases
    # the uncast scalars are kept.
    casted <- try(
      {
        to_type <- common_type(args[is_expr])
        lapply(scalars, cast_or_parse, type = to_type)
      },
      silent = TRUE
    )
    if (!inherits(casted, "try-error")) {
      scalars <- casted
    }
  }

  args[needs_wrap] <- lapply(scalars, Expression$scalar)
  args
}
308+
# Determine the single type shared by a set of expressions.
#
# @param exprs A non-empty list of objects exposing a `$type()` method
#   (Expressions, in practice).
# @return The type reported by the first element, provided every element
#   reports an equal type as judged by that type's `$Equals()` method.
#
# Raises an error when `exprs` is empty or when the types are not all equal.
# Errors raised by `$type()` itself propagate unchanged; callers that treat
# type resolution as best-effort are expected to catch them.
common_type <- function(exprs) {
  if (length(exprs) == 0) {
    # Fail with an explicit message rather than an opaque
    # "subscript out of bounds" from types[[1]] below
    stop("Cannot determine a common type: at least one expression is required")
  }

  types <- lapply(exprs, function(expr) expr$type())
  first_type <- types[[1]]

  # Functions (in our tests) that have multiple exprs to check:
  # * case_when
  # * pmin/pmax
  all_equal <- length(types) == 1 ||
    all(vapply(types, function(type) type$Equals(first_type), logical(1)))
  if (all_equal) {
    return(first_type)
  }

  stop("There is no common type in these expressions")
}
320+
# Cast `x` to `type`, parsing strings when the target is a date or timestamp.
#
# @param x An object exposing `$type_id()` and `$cast()`.
# @param type The target type; its `id` field selects the conversion path.
# @return `x` unchanged when `type` is a decimal type, otherwise the result
#   of `x$cast(type)` (after string parsing where applicable). Errors from
#   the parse or the cast are not caught here.
cast_or_parse <- function(x, type) {
  target_id <- type$id

  if (target_id %in% c(Type[["DECIMAL128"]], Type[["DECIMAL256"]])) {
    # TODO: determine the minimum size of decimal (or integer) required to
    # accommodate x.
    # Ideally calculations would stay in decimal when that is what the data
    # has, so that no precision is lost. Given current limitations, x is kept
    # as is (which is probably double, coming from R) and Acero casts the
    # decimal to double to compute. A query can state explicitly that x
    # should be decimal or integer when that is known to be safe.
    # * ARROW-17601: multiply(decimal, decimal) can fail to make output type
    return(x)
  }

  # Most conversions are a plain cast, but string -> date/time has to go
  # through a parsing function first
  from_string <- x$type_id() %in% c(Type[["STRING"]], Type[["LARGE_STRING"]])

  if (from_string && target_id %in% c(Type[["DATE32"]], Type[["DATE64"]])) {
    x <- call_function(
      "strptime",
      x,
      options = list(format = "%Y-%m-%d", unit = 0L)
    )
  } else if (from_string && target_id == Type[["TIMESTAMP"]]) {
    parsed <- call_function(
      "strptime",
      x,
      options = list(format = "%Y-%m-%d %H:%M:%S", unit = 1L)
    )
    # R treats timestamps with no timezone as local time, whereas Arrow
    # assumes UTC. Attach the local timezone to stay consistent with R.
    x <- call_function(
      "assume_timezone",
      parsed,
      options = list(timezone = Sys.timezone())
    )
  }

  x$cast(type)
}
363+
261364# ' @export
262365Ops.Expression <- function (e1 , e2 ) {
263366 if (.Generic == " !" ) {
0 commit comments