The goal is to use a pandas user-defined function as a window function in PySpark. Here is a minimal example.
df is a pandas DataFrame that is also registered as a Spark table:
import pandas as pd
from pyspark.sql import SparkSession

df = pd.DataFrame(
    {'x': [1, 1, 2, 2, 2, 3, 3],
     'y': [1, 2, 3, 4, 5, 6, 7]})
spark = SparkSession.builder.getOrCreate()
sdf = spark.createDataFrame(df)  # Spark DataFrame, used again in the UPDATE below
sdf.createOrReplaceTempView('df')
Here is df as a Spark table:
In [10]: spark.sql('SELECT * FROM df').show()
+---+---+
| x| y|
+---+---+
| 1| 1|
| 1| 2|
| 2| 3|
| 2| 4|
| 2| 5|
| 3| 6|
| 3| 7|
+---+---+
The minimal example is to implement a cumulative sum of y partitioned by x. Without a pandas user-defined function, that looks like:
dx = spark.sql("""
    SELECT x, y,
           SUM(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()
where dx is then
In [2]: dx
Out[2]:
x y ysum
0 1 1 1
1 1 2 3
2 2 3 3
3 2 4 7
4 2 5 12
5 3 6 6
6 3 7 13
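For reference, the same window query can be written with the DataFrame API instead of SQL. A sketch, assuming the sdf Spark DataFrame created above; with an ORDER BY, the default window frame is cumulative, matching the SQL query:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Default frame with orderBy runs from the partition start to the
# current row, giving a cumulative sum of y per x.
w = Window.partitionBy('x').orderBy('y')
dx = sdf.withColumn('ysum', F.sum('y').over(w)).orderBy('x').toPandas()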
And a non-working attempt to do the same with pandas_udf is:
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(returnType=DoubleType())
def func(x: pd.Series) -> pd.Series:
    return x.cumsum()

spark.udf.register('func', func)
dx = spark.sql("""
    SELECT x, y,
           func(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()
which returns this error:
AnalysisException: Expression 'func(y#1L)' not supported within a window function.;
...
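The error occurs because func is a scalar pandas UDF (Series to Series), which Spark does not allow inside an OVER clause. What Spark does accept as a window function is a grouped aggregate pandas UDF (Series to a single value). A sketch of that route, assuming Spark 3.x and the sdf DataFrame from above; note that Spark re-evaluates the aggregate over each row's frame, which can be slow on wide partitions:

import pandas as pd
from pyspark.sql.functions import pandas_udf
from pyspark.sql.window import Window

# Grouped aggregate pandas UDF: reduces a pd.Series to one value.
@pandas_udf('long')
def psum(y: pd.Series) -> int:
    return y.sum()

# Cumulative frame: partition start up to the current row.
w = (Window.partitionBy('x').orderBy('y')
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))
dx = sdf.withColumn('ysum', psum('y').over(w)).orderBy('x').toPandas()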
UPDATE: Based on the answer by wwnde, the solution was:
def pdf_cumsum(pdf):
    pdf['ysum'] = pdf['y'].cumsum()
    return pdf

dx = sdf.groupby('x').applyInPandas(pdf_cumsum, schema='x long, y long, ysum long').toPandas()
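One caveat worth noting: applyInPandas hands each x-group to the function as an in-memory pandas DataFrame, and Spark does not guarantee row order within a group. A slightly more defensive sketch that sorts inside the function so the cumulative sum always follows y-order:

def pdf_cumsum(pdf):
    # Sort within the group first; Spark makes no ordering guarantee here.
    pdf = pdf.sort_values('y')
    pdf['ysum'] = pdf['y'].cumsum()
    return pdf

dx = (sdf.groupby('x')
         .applyInPandas(pdf_cumsum, schema='x long, y long, ysum long')
         .orderBy('x', 'y')
         .toPandas())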