
The goal is to use a pandas user-defined function as a window function in PySpark. Here is a minimal example.

df is a pandas DataFrame that is also registered as a Spark table:

import pandas as pd
from pyspark.sql import SparkSession

df = pd.DataFrame(
    {'x': [1, 1, 2, 2, 2, 3, 3],
     'y': [1, 2, 3, 4, 5, 6, 7]})
spark = SparkSession.builder.getOrCreate()
spark.createDataFrame(df).createOrReplaceTempView('df')

Here is df as a Spark table:

In [10]: spark.sql('SELECT * FROM df').show()
+---+---+
|  x|  y|
+---+---+
|  1|  1|
|  1|  2|
|  2|  3|
|  2|  4|
|  2|  5|
|  3|  6|
|  3|  7|
+---+---+

The minimal example is to implement a cumulative sum of y partitioned by x. Without any pandas user-defined function, that looks like:

dx = spark.sql(f"""
    SELECT x, y,
    SUM(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()

where dx is then

In [2]: dx
Out[2]:
   x  y  ysum
0  1  1     1
1  1  2     3
2  2  3     3
3  2  4     7
4  2  5    12
5  3  6     6
6  3  7    13

And a non-working attempt to do the same with pandas_udf is:

from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

@pandas_udf(returnType=DoubleType())
def func(x: pd.Series) -> pd.Series:
    return x.cumsum()
spark.udf.register('func', func)

dx = spark.sql(f"""
    SELECT x, y,
    func(y) OVER (PARTITION BY x ORDER BY y) AS ysum
    FROM df
    ORDER BY x""").toPandas()

which returns this error

AnalysisException: Expression 'func(y#1L)' not supported within a window function.;
...
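
For reference, the scalar pandas_udf above cannot appear in an OVER clause, but a grouped-aggregate pandas UDF (one that reduces a pd.Series to a single value) can be applied over a window through the DataFrame API. A rough sketch, assuming Spark 3.x; pd_sum is a hypothetical name, not part of the original post:

import pandas as pd
from pyspark.sql import Window
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType

# Grouped-aggregate pandas UDF: reduces a Series to a single value.
@pandas_udf(DoubleType())
def pd_sum(y: pd.Series) -> float:
    return float(y.sum())

# Growing frame per x, ordered by y, i.e. a cumulative sum.
w = (Window.partitionBy('x').orderBy('y')
     .rowsBetween(Window.unboundedPreceding, Window.currentRow))

sdf = spark.table('df')
sdf.withColumn('ysum', pd_sum('y').over(w)).orderBy('x', 'y').show()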

UPDATE: Based on the answer by wwnde, the solution was:

def pdf_cumsum(pdf):
    pdf['ysum'] = pdf['y'].cumsum()
    return pdf

sdf = spark.table('df')  # the Spark DataFrame registered above
dx = (sdf.groupby('x')
      .applyInPandas(pdf_cumsum, schema='x long, y long, ysum long')
      .toPandas())

1 Answer


Use mapInPandas from the Map Pandas Function API:

from typing import Iterator
from pyspark.sql.functions import lit

# df here is the Spark DataFrame; lit(3) just adds a ysum column to the schema
sch = df.withColumn('ysum', lit(3)).schema

def cumsum_pdf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    for pdf in iterator:
        yield pdf.assign(ysum=pdf.groupby('x')['y'].cumsum())

df.mapInPandas(cumsum_pdf, schema=sch).show()

Outcome

+---+---+----+
|  x|  y|ysum|
+---+---+----+
|  1|  1|   1|
|  1|  2|   3|
|  2|  3|   3|
|  2|  4|   7|
|  2|  5|  12|
|  3|  6|   6|
|  3|  7|  13|
+---+---+----+
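
Note that mapInPandas receives each partition as an iterator of Arrow batches, so the running sum is only correct when all rows for a given x end up in the same batch. A rough sketch of one way to encourage that (the repartition and sort are assumptions, not part of the original answer); applyInPandas with groupby('x') avoids the issue entirely:

# Assumption, not from the original answer: co-locate and order the rows
# for each x before mapping over batches.
prepared = df.repartition('x').sortWithinPartitions('x', 'y')
prepared.mapInPandas(cumsum_pdf, schema=sch).show()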

3 Comments

I don't think you need the groupby within the UDF, as at that point there will be only a single value for x.
You can use mapInPandas if you do not want groupby. See my edited answer.
Nice, thanks! I preferred the first answer via applyInPandas; I just did not understand why using groupby within the user-defined function was needed. See the edited question.
