Merged
36 changes: 22 additions & 14 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
@@ -21,16 +21,16 @@ import java.io.IOException
import java.math.BigDecimal
import java.util.UUID

import org.graphframes.{GraphFrame, Logging}
import org.apache.hadoop.fs.Path

import org.apache.hadoop.fs.{FileSystem, Path}
import org.graphframes.{GraphFrame, Logging}
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.storage.StorageLevel

/**
* Connected components algorithm.
* Connected Components algorithm.
*
* Computes the connected component membership of each vertex and returns a DataFrame of vertex
* information with each vertex assigned a component ID.
@@ -105,7 +105,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
* Sets checkpoint interval in terms of number of iterations (default: 2). Checkpointing
* regularly helps recover from failures, clean shuffle files, shorten the lineage of the
* computation graph, and reduce the complexity of plan optimization. As of Spark 2.0, the
* complexity of plan optimization would grow exponentially without checkpointing. Hence
* complexity of plan optimization would grow exponentially without checkpointing. Hence,
* disabling or setting longer-than-default checkpoint intervals are not recommended. Checkpoint
* data is saved under `org.apache.spark.SparkContext.getCheckpointDir` with prefix
* "connected-components". If the checkpoint directory is not set, this throws a
@@ -171,7 +171,6 @@ object ConnectedComponents extends Logging {
import org.graphframes.GraphFrame._

private val COMPONENT = "component"
private val ORIG_ID = "orig_id"
private val MIN_NBR = "min_nbr"
private val CNT = "cnt"
private val CHECKPOINT_NAME_PREFIX = "connected-components"
@@ -183,7 +182,7 @@ object ConnectedComponents extends Logging {
* Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]:
* "graphframes" and "graphx".
*/
val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)
private val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)

/**
* Returns the symmetric directed graph of the graph specified by input edges.
@@ -331,8 +330,7 @@ object ConnectedComponents extends Logging {
val g = prepare(graph)
val vv = g.vertices
var ee = g.edges.persist(intermediateStorageLevel) // src < dst
val numEdges = ee.count()
logInfo(s"$logPrefix Found $numEdges edges after preparation.")
logInfo(s"$logPrefix Found ${ee.count()} edges after preparation.")

var converged = false
var iteration = 1
@@ -426,11 +424,7 @@ object ConnectedComponents extends Logging {
prevSum = currSum
}

// materialize all persisted DataFrames in current round,
// then we can unpersist last round persisted DataFrames.
for (persisted_df <- currRoundPersistedDFs) {
persisted_df.count() // materialize it.
}
// clean up persisted DFs
for (persisted_df <- lastRoundPersistedDFs) {
persisted_df.unpersist()
}
@@ -441,9 +435,23 @@ object ConnectedComponents extends Logging {
logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")

logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")
vv.join(ee, vv(ID) === ee(DST), "left_outer")
val output = vv
.join(ee, vv(ID) === ee(DST), "left_outer")
.select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT))
.select(col(s"$ATTR.*"), col(COMPONENT))
.persist(intermediateStorageLevel)
Collaborator:

I believe persist is lazy and does not offer an eager flag. Will this code actually wind up using the cached dataframes if we don't cache the output df before we unpersist the child dataframes?

Contributor Author:

The DataFrame only needs to be persisted when an action is executed, so that those previous calculations can be reused.

The calculations are performed in the last round; once the loop ends, one more transformation is applied to that DataFrame and the result is cached (no new calculations are performed, because no action is involved).

Nothing changes, except that the persisted DataFrame is now the one the method returns, rather than applying the last transformations to a previously persisted DataFrame and returning the result to the user. This way the user can unpersist the returned DataFrame.

Collaborator (@james-willis, Mar 23, 2025):

Here is the diff with and without the count call. Removing the count call causes a cache miss: https://www.diffchecker.com/i57B411V/

Contributor Author:

Where do you see a cache miss? I'm debugging the "single vertex" unit test, and there is one DataFrame cached and an InMemoryTableScan in the plan:

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- InMemoryTableScan [id#1136L, vattr#1137, gender#1138, component#1133L]
+- InMemoryRelation [id#1136L, vattr#1137, gender#1138, component#1133L], StorageLevel(disk, memory, deserialized, 1 replicas)
+- LocalTableScan [id#1136L, vattr#1137, gender#1138, component#1133L]

Collaborator:

Sorry if I am just struggling to understand, but I think the count is necessary.

> Only when an action is executed the dataframe needs to be persisted, in order to reuse those previous calculations.

If you want the output dataframe to leverage the persisted child dataframes in its query plan, you need to call an action on the output dataframe before those children have unpersist called. Without the count call, you will not utilize the cached versions of the child dataframes when caching the output dataframe.

> another transformation is applied and then cached

I don't agree. cache and unpersist are lazy in Spark, so the dataframe is only marked for caching; it is not actually cached until some action is called. Without the count call, the action will always come after the child query plans have been unpersisted, so they will be recalculated by the engine. This defeats the purpose of those persist calls.

I tried to add a test for this in my PR:
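The laziness being debated here can be illustrated without Spark at all. The sketch below is a plain-Scala analogy, not the Spark API: a `lazy val` stands in for a persisted DataFrame, and reading it stands in for an action. It shows why marking something for caching computes nothing until it is first forced — the same reason `output.count()` must run before the children are unpersisted.

```scala
object LazyCacheDemo {
  def main(args: Array[String]): Unit = {
    var computations = 0

    // Marking for "caching": like persist(), this records intent but computes nothing.
    lazy val cached: Int = { computations += 1; 42 }
    assert(computations == 0) // nothing materialized yet -- persist/cache alone is lazy

    val first = cached // the first "action" forces the computation into the cache
    assert(first == 42 && computations == 1)

    val second = cached // later uses reuse the cached value; no recomputation
    assert(second == 42 && computations == 1)

    println("computed exactly once")
  }
}
```

By the same ordering in Spark, `output.persist(...)` alone schedules nothing; an action on `output` has to run while the child DataFrames are still persisted for their cached results to be reused.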

Collaborator:
> I'm debugging the "single vertex" unit test and there's one DataFrame cached and a InMemoryTableScan in the plan

I believe this is an edge case: Spark is optimizing away the second child of the join because ee is an empty LocalRelation.

I believe the chain graph test is more representative because there are edges in the table. There you will see only the top-level InMemoryRelation when the count call is removed and 16 when it is in place.

Contributor:
+1 to the count being necessary. It may be that the counts inside the loop aren't needed, since other actions like _calcMinNbrSum will trigger those DataFrames to cache. But in this case at the end, since everything is being unpersisted, output will be recomputed entirely from the last checkpoint when the user does something with it, with none of the intermediate caching.

Contributor Author:
Well, it's simple ... let's test it with a large dataset and then see whether it takes longer or not.


// An action must be performed on the DataFrame for the cache to load
output.count()

// clean up persisted DFs
for (persisted_df <- lastRoundPersistedDFs) {
persisted_df.unpersist()
}

logWarn("The DataFrame returned by ConnectedComponents is persisted and loaded.")

output
} finally {
// Restore original AQE setting
spark.conf.set("spark.sql.adaptive.enabled", originalAQE)
11 changes: 11 additions & 0 deletions src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
@@ -253,6 +253,17 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
}
}

test("not leaking cached data") {
val priorCachedDFsSize = spark.sparkContext.getPersistentRDDs.size

val cc = Graphs.friends.connectedComponents
val components = cc.run()

components.unpersist(blocking = true)

assert(spark.sparkContext.getPersistentRDDs.size === priorCachedDFsSize)
}

private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {