I have the following requirements:
- Pivot the dataframe to sum the amount column by document type
- Join the pivot dataframe back to the original dataframe to get additional columns
- Filter the joined dataframe using a window function
Sample code:
Setting up the dataframe:
from datetime import date
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DateType
import pyspark.sql.functions as F
from pyspark.sql.window import Window
schema = StructType([
StructField('company_code', StringType(), True)
, StructField('line_no', IntegerType(), True)
, StructField('document_type', StringType(), True)
, StructField('amount', IntegerType(), True)
, StructField('posting_date', DateType(), True)
])
data = [
['AB', 10, 'RE', 12, date(2019,1,1)]
, ['AB', 10, 'RE', 13, date(2019,2,10)]
, ['AB', 20, 'WE', 14, date(2019,1,11)]
, ['BC', 10, 'WL', 11, date(2019,2,12)]
, ['BC', 20, 'RE', 15, date(2019,1,21)]
]
df = spark.createDataFrame(data, schema)  # 'spark' is the SparkSession Databricks provides
First, using the pivot way:
# Partition upfront so we don't shuffle twice (once in the groupBy and once in the window)
partition_df = df.repartition('company_code', 'line_no').cache()
pivot_df = (
partition_df.groupBy('company_code', 'line_no')
.pivot('document_type', ['RE', 'WE', 'WL'])
.sum('amount')
)
# Spark should pick a broadcast join because pivot_df is small (it is small in my actual case as well); see the hint sketch after the output below
join_df = (
partition_df.join(pivot_df, ['company_code', 'line_no'])
.select(partition_df['*'], 'RE', 'WE', 'WL')
)
window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')
final_df = join_df.withColumn("Row_num", F.row_number().over(window_spec)).filter("Row_num == 1").drop("Row_num")
final_df.show()
+------------+-------+-------------+------+------------+----+----+----+
|company_code|line_no|document_type|amount|posting_date|  RE|  WE|  WL|
+------------+-------+-------------+------+------------+----+----+----+
|          AB|     10|           RE|    12|  2019-01-01|  25|NULL|NULL|
|          AB|     20|           WE|    14|  2019-01-11|NULL|  14|NULL|
|          BC|     10|           WL|    11|  2019-02-12|NULL|NULL|  11|
|          BC|     20|           RE|    15|  2019-01-21|  15|NULL|NULL|
+------------+-------+-------------+------+------------+----+----+----+
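In case the optimizer does not pick the broadcast join on its own, I could also force the hint explicitly. This is only a sketch (reusing partition_df and pivot_df from above), not a variant I have benchmarked:

# Hypothetical variant: explicitly broadcast the small pivoted side
join_df = (
    partition_df.join(F.broadcast(pivot_df), ['company_code', 'line_no'])
    .select(partition_df['*'], 'RE', 'WE', 'WL')
)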
And now using the window way:
# Emulate the pivot: one conditional column per document type
t_df = df.withColumns({
'RE': F.when(F.col('document_type') == 'RE', F.col('amount')).otherwise(0)
, 'WE': F.when(F.col('document_type') == 'WE', F.col('amount')).otherwise(0)
, 'WL': F.when(F.col('document_type') == 'WL', F.col('amount')).otherwise(0)
})
window_spec = Window.partitionBy('company_code', 'line_no').orderBy('posting_date')
sum_window_spec = window_spec.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
# Sum each conditional column over the full partition and number the rows by posting_date
t2_df = t_df.withColumns({
'RE': F.sum('RE').over(sum_window_spec)
, 'WE': F.sum('WE').over(sum_window_spec)
, 'WL': F.sum('WL').over(sum_window_spec)
, 'Row_num': F.row_number().over(window_spec)
})
final_df = t2_df.filter("Row_num == 1").drop("Row_num")
final_df.show()
+------------+-------+-------------+------+------------+---+---+---+
|company_code|line_no|document_type|amount|posting_date| RE| WE| WL|
+------------+-------+-------------+------+------------+---+---+---+
|          AB|     10|           RE|    12|  2019-01-01| 25|  0|  0|
|          AB|     20|           WE|    14|  2019-01-11|  0| 14|  0|
|          BC|     10|           WL|    11|  2019-02-12|  0|  0| 11|
|          BC|     20|           RE|    15|  2019-01-21| 15|  0|  0|
+------------+-------+-------------+------+------------+---+---+---+
I have not included the output of explain here as it would make the question too long, but there is only one shuffle in both methods. So how do I decide which one will take more time?
I'm using Databricks Runtime 14.3 LTS.
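For reference, this is roughly how I have been comparing the two, with final_df standing for either version. It is just a sketch; the noop format is a sink that executes the full query without writing any output:

# Look for Exchange / BroadcastExchange nodes in the physical plan
final_df.explain(mode='formatted')

# Rough wall-clock timing by forcing full execution
import time
start = time.time()
final_df.write.format('noop').mode('overwrite').save()
print(f'elapsed: {time.time() - start:.2f}s')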