Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,16 @@

package org.graphframes.lib

import java.io.IOException
import java.math.BigDecimal
import java.util.UUID

import org.graphframes.{GraphFrame, Logging}

import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.hadoop.fs.Path
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.storage.StorageLevel
import org.graphframes.{GraphFrame, Logging}

import java.io.IOException
import java.math.BigDecimal
import java.util.UUID

/**
* Connected components algorithm.
Expand Down Expand Up @@ -441,9 +440,21 @@ object ConnectedComponents extends Logging {
logInfo(s"$logPrefix Connected components converged in ${iteration - 1} iterations.")

logInfo(s"$logPrefix Join and return component assignments with original vertex IDs.")
vv.join(ee, vv(ID) === ee(DST), "left_outer")
val output = vv
.join(ee, vv(ID) === ee(DST), "left_outer")
.select(vv(ATTR), when(ee(SRC).isNull, vv(ID)).otherwise(ee(SRC)).as(COMPONENT))
.select(col(s"$ATTR.*"), col(COMPONENT))
.persist(intermediateStorageLevel)

// materialize the output DataFrame
output.count()

// clean up persisted DFs
for (persisted_df <- lastRoundPersistedDFs) {
persisted_df.unpersist()
}

output
} finally {
// Restore original AQE setting
spark.conf.set("spark.sql.adaptive.enabled", originalAQE)
Expand Down
40 changes: 32 additions & 8 deletions src/test/scala/org/graphframes/lib/ConnectedComponentsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,20 @@

package org.graphframes.lib

import java.io.IOException

import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.storage.StorageLevel

import org.graphframes._
import org.graphframes.GraphFrame._
import org.graphframes._
import org.graphframes.examples.Graphs

import java.io.IOException
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag

class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkContext {

test("default params") {
Expand Down Expand Up @@ -253,6 +253,30 @@ class ConnectedComponentsSuite extends SparkFunSuite with GraphFrameTestSparkCon
}
}

test("uses intermediate caches") {
  val components = Graphs.friends.connectedComponents.run()

  // Always release the persisted result, even when the assertion below fails,
  // so that a failure here cannot leave cached data behind for later tests.
  try {
    // Count how many times the executed physical plan mentions an in-memory relation.
    val marker = "InMemoryRelation"
    val count = components.queryExecution.executedPlan
      .toString()
      .sliding(marker.length)
      .count(window => window == marker)

    // The expected value 17 is the count observed when the run method
    // materializes its output with an output.count() call.
    assert(count === 17)
  } finally {
    components.unpersist(blocking = true)
  }
}

test("not leaking cached data") {
  // Snapshot the number of persistent RDDs registered before the algorithm runs.
  val persistedBefore = spark.sparkContext.getPersistentRDDs.size

  // Run connected components on the sample graph and release the returned result.
  val result = Graphs.friends.connectedComponents.run()
  result.unpersist(blocking = true)

  // After releasing the result, the registry must be back at its initial size.
  val persistedAfter = spark.sparkContext.getPersistentRDDs.size
  assert(persistedAfter === persistedBefore)
}

private def assertComponents[T: ClassTag: TypeTag](
actual: DataFrame,
expected: Set[Set[T]]): Unit = {
Expand Down