Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,7 @@ object BeliefPropagation {
.drop("aggMess") // drop messages
.drop("belief") // drop old beliefs
.withColumnRenamed("newBelief", "belief")
// Cache new vertices using workaround for SPARK-13346
val cachedNewVertices = AM.getCachedDataFrame(newVertices)
val cachedNewVertices = newVertices.localCheckpoint()
gx = GraphFrame(cachedNewVertices, gx.edges)
}
}
Expand Down
38 changes: 16 additions & 22 deletions core/src/main/scala/org/graphframes/lib/AggregateMessages.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.functions.expr
import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithIntermediateStorageLevel

/**
* This is a primitive for implementing graph algorithms. This method aggregates messages from the
Expand Down Expand Up @@ -59,9 +60,13 @@ import org.graphframes.Logging
*/
class AggregateMessages private[graphframes] (private val g: GraphFrame)
extends Arguments
with Serializable {
with Serializable
with WithIntermediateStorageLevel
with Logging {

import org.graphframes.GraphFrame.{DST, ID, SRC}
import org.graphframes.GraphFrame.DST
import org.graphframes.GraphFrame.ID
import org.graphframes.GraphFrame.SRC

private var msgToSrc: Option[Column] = None

Expand Down Expand Up @@ -106,19 +111,20 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame)
"To run GraphFrame.aggregateMessages," +
" messages must be sent to src, dst, or both. Set using sendToSrc(), sendToDst().")
val triplets = g.triplets
val cachedVertices = g.vertices.persist(intermediateStorageLevel)
val sentMsgsToSrc = msgToSrc.map { msg =>
val msgsToSrc =
triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(SRC)(ID).as(ID))
// Inner join: only send messages to vertices with edges
msgsToSrc
.join(g.vertices, ID)
.join(cachedVertices, ID)
.select(msgsToSrc(AggregateMessages.MSG_COL_NAME), col(ID))
}
val sentMsgsToDst = msgToDst.map { msg =>
val msgsToDst =
triplets.select(msg.as(AggregateMessages.MSG_COL_NAME), triplets(DST)(ID).as(ID))
msgsToDst
.join(g.vertices, ID)
.join(cachedVertices, ID)
.select(msgsToDst(AggregateMessages.MSG_COL_NAME), col(ID))
}
val unionedMsgs = (sentMsgsToSrc, sentMsgsToDst) match {
Expand All @@ -130,7 +136,12 @@ class AggregateMessages private[graphframes] (private val g: GraphFrame)
// Should never happen. Specify this case to avoid compilation warnings.
throw new RuntimeException("AggregateMessages: No messages were specified to be sent.")
}
unionedMsgs.groupBy(ID).agg(aggCol)
val cachedResult = unionedMsgs.groupBy(ID).agg(aggCol).persist(intermediateStorageLevel)
// materialize
cachedResult.count()
cachedVertices.unpersist()
resultIsPersistent()
cachedResult
}

/**
Expand All @@ -157,21 +168,4 @@ object AggregateMessages extends Logging with Serializable {

/** Reference for message column, used for specifying aggregation function */
def msg: Column = col(MSG_COL_NAME)

/**
* Create a new cached copy of a DataFrame. For iterative DataFrame-based algorithms.
*
* WARNING: This is NOT the same as `DataFrame.cache()`. The original DataFrame will NOT be
* cached.
*
* This is a workaround for SPARK-13346, which makes it difficult to use DataFrames in iterative
* algorithms. This workaround converts the DataFrame to an RDD, caches the RDD, and creates a
* new DataFrame. This is important for avoiding the creation of extremely complex DataFrame
* query plans when using DataFrames in iterative algorithms.
*/
def getCachedDataFrame(df: DataFrame): DataFrame = {
val rdd = df.rdd.cache()
// rdd.count()
df.sparkSession.createDataFrame(rdd, df.schema)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)

object ConnectedComponents extends Logging {

import org.graphframes.GraphFrame._
import org.graphframes.GraphFrame.*

private val COMPONENT = "component"
private val ORIG_ID = "orig_id"
Expand Down Expand Up @@ -176,7 +176,7 @@ object ConnectedComponents extends Logging {
minNbrs: DataFrame,
broadcastThreshold: Int,
logPrefix: String): DataFrame = {
import edges.sparkSession.implicits._
import edges.sparkSession.implicits.*
val hubs = minNbrs
.filter(col(CNT) > broadcastThreshold)
.select(SRC)
Expand Down
19 changes: 14 additions & 5 deletions core/src/main/scala/org/graphframes/lib/LabelPropagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.graphframes.lib

import org.apache.spark.graphframes.graphx.{lib => graphxlib}
import org.apache.spark.graphframes.graphx
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
Expand All @@ -28,6 +28,7 @@ import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithLocalCheckpoints
import org.graphframes.WithMaxIter

Expand All @@ -51,6 +52,7 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
with WithCheckpointInterval
with WithMaxIter
with WithLocalCheckpoints
with WithIntermediateStorageLevel
with Logging {

def run(): DataFrame = {
Expand All @@ -62,7 +64,8 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)
graph,
maxIterChecked,
checkpointInterval,
useLocalCheckpoints = useLocalCheckpoints)
useLocalCheckpoints = useLocalCheckpoints,
intermediateStorageLevel = intermediateStorageLevel)
}
resultIsPersistent()
res
Expand All @@ -71,7 +74,7 @@ class LabelPropagation private[graphframes] (private val graph: GraphFrame)

private object LabelPropagation {
private def runInGraphX(graph: GraphFrame, maxIter: Int): DataFrame = {
val gx = graphxlib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
val gx = graphx.lib.LabelPropagation.run(graph.cachedTopologyGraphX, maxIter)
val res = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(LABEL_ID)).vertices
res.persist(StorageLevel.MEMORY_AND_DISK_SER)
res.count()
Expand All @@ -95,13 +98,18 @@ private object LabelPropagation {
maxIter: Int,
checkpointInterval: Int,
isDirected: Boolean = true,
useLocalCheckpoints: Boolean): DataFrame = {
useLocalCheckpoints: Boolean,
intermediateStorageLevel: StorageLevel): DataFrame = {
// Overall:
// - Initial labels - IDs
// - Active vertex col (halt voting) - did the label change?
// - Choosing a new label - top across neighbours (tie-breaking is deterministic)

var pregel = graph.pregel
val preparedGraph = GraphFrame(
graph.vertices.select(GraphFrame.ID),
graph.edges.select(GraphFrame.SRC, GraphFrame.DST))

var pregel = preparedGraph.pregel
.withVertexColumn(LABEL_ID, col(GraphFrame.ID).alias(LABEL_ID), keyWithMaxValue(Pregel.msg))
.setMaxIter(maxIter)
.setStopIfAllNonActiveVertices(true)
Expand All @@ -110,6 +118,7 @@ private object LabelPropagation {
.setSkipMessagesFromNonActiveVertices(false)
.setUpdateActiveVertexExpression(col(LABEL_ID) =!= keyWithMaxValue(Pregel.msg))
.setUseLocalCheckpoints(useLocalCheckpoints)
.setIntermediateStorageLevel(intermediateStorageLevel)

if (isDirected) {
pregel = pregel.sendMsgToDst(Pregel.src(LABEL_ID))
Expand Down
16 changes: 9 additions & 7 deletions core/src/main/scala/org/graphframes/lib/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import org.apache.spark.sql.functions.struct
import org.graphframes.GraphFrame
import org.graphframes.GraphFrame._
import org.graphframes.Logging
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithLocalCheckpoints

import java.io.IOException
Expand Down Expand Up @@ -81,7 +82,10 @@ import scala.util.control.Breaks.breakable
* <a href="https://doi.org/10.1145/1807167.1807184"> Malewicz et al., Pregel: a system for
* large-scale graph processing. </a>
*/
class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
class Pregel(val graph: GraphFrame)
extends Logging
with WithLocalCheckpoints
with WithIntermediateStorageLevel {

private val withVertexColumnList = collection.mutable.ListBuffer.empty[(String, Column, Column)]

Expand Down Expand Up @@ -343,7 +347,7 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
val edges = graph.edges
.select(col(SRC).alias("edge_src"), col(DST).alias("edge_dst"), struct(col("*")).as(EDGE))
.repartition(col("edge_src"), col("edge_dst"))
.persist()
.persist(intermediateStorageLevel)

var iteration = 1

Expand All @@ -364,7 +368,7 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
while (iteration <= maxIter) {
logInfo(s"start Pregel iteration $iteration / $maxIter")
val currRoundPersistent = scala.collection.mutable.Queue[DataFrame]()
currRoundPersistent.enqueue(currentVertices.persist())
currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel))
var tripletsDF = currentVertices
.select(struct(col("*")).as(SRC))
.join(edges, Pregel.src(ID) === col("edge_src"))
Expand Down Expand Up @@ -406,14 +410,12 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
if (shouldCheckpoint && iteration % checkpointInterval == 0) {
if (useLocalCheckpoints) {
currentVertices = currentVertices.localCheckpoint(eager = false)
currRoundPersistent.enqueue(currentVertices)
} else {
currentVertices = currentVertices.checkpoint(eager = false)
currRoundPersistent.enqueue(currentVertices)
}
} else {
// checkpointing does persistence, so we do not need to do it again
currRoundPersistent.enqueue(currentVertices.persist())
currRoundPersistent.enqueue(currentVertices.persist(intermediateStorageLevel))
}

if (stopIfAllNonActiveVertices) {
Expand Down Expand Up @@ -442,7 +444,7 @@ class Pregel(val graph: GraphFrame) extends Logging with WithLocalCheckpoints {
}
}

val res = currentVertices.persist()
val res = currentVertices.persist(intermediateStorageLevel)
res.count()
while (lastRoundPersistent.nonEmpty) {
lastRoundPersistent.dequeue().unpersist()
Expand Down
21 changes: 15 additions & 6 deletions core/src/main/scala/org/graphframes/lib/ShortestPaths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.graphframes.lib

import org.apache.spark.graphframes.graphx.{lib => graphxlib}
import org.apache.spark.graphframes.graphx
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
Expand All @@ -39,6 +39,7 @@ import org.graphframes.GraphFramesUnreachableException
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.WithCheckpointInterval
import org.graphframes.WithIntermediateStorageLevel
import org.graphframes.WithLocalCheckpoints

import java.util
Expand All @@ -57,7 +58,8 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice
with WithCheckpointInterval
with WithLocalCheckpoints {
with WithLocalCheckpoints
with WithIntermediateStorageLevel {
import org.graphframes.lib.ShortestPaths._

private var lmarks: Option[Seq[Any]] = None
Expand Down Expand Up @@ -87,7 +89,8 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame)
graph,
lmarksChecked,
checkpointInterval,
useLocalCheckpoints = useLocalCheckpoints)
useLocalCheckpoints = useLocalCheckpoints,
intermediateStorageLevel = intermediateStorageLevel)
case _ => throw new GraphFramesUnreachableException()
}
resultIsPersistent()
Expand All @@ -99,7 +102,7 @@ private object ShortestPaths extends Logging {

private def runInGraphX(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap
val gx = graphxlib.ShortestPaths
val gx = graphx.lib.ShortestPaths
.run(graph.cachedTopologyGraphX, longIdToLandmark.keys.toSeq.sorted)
val g = GraphXConversions.fromGraphX(graph, gx, vertexNames = Seq(DISTANCE_ID))
val distanceCol: Column = if (graph.hasIntegralIdType) {
Expand All @@ -124,7 +127,8 @@ private object ShortestPaths extends Logging {
landmarks: Seq[Any],
checkpointInterval: Int,
isDirected: Boolean = true,
useLocalCheckpoints: Boolean): DataFrame = {
useLocalCheckpoints: Boolean,
intermediateStorageLevel: StorageLevel): DataFrame = {
logWarn("The GraphFrames based implementation is slow and considered experimental!")
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType

Expand Down Expand Up @@ -197,11 +201,16 @@ private object ShortestPaths extends Logging {
// Mark vertex as active only in the case the distance changed
val updateActiveVierticesExpr = isDistanceImprovedWithMessage(Pregel.msg, col(DISTANCE_ID))

val preparedGraph = GraphFrame(
graph.vertices.select(GraphFrame.ID),
graph.edges.select(GraphFrame.SRC, GraphFrame.DST))

// Overall:
// 1. Initialize distances
// 2. If new message can improve distances send it
// 3. Collect and aggregate messages
val pregel = graph.pregel
val pregel = preparedGraph.pregel
.setIntermediateStorageLevel(intermediateStorageLevel)
.setMaxIter(Int.MaxValue) // That is how the GraphX implementation works
.withVertexColumn(
DISTANCE_ID,
Expand Down
2 changes: 1 addition & 1 deletion python/graphframes/examples/belief_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def runBPwithGraphFrames(cls, g: GraphFrame, numIter: int) -> GraphFrame:
.withColumnRenamed("newBelief", "belief")
)
# truncate lineage of the new vertices via localCheckpoint (replaces the SPARK-13346 workaround)
cachedNewVertices = AM.getCachedDataFrame(newVertices)
cachedNewVertices = newVertices.localCheckpoint()
gx = GraphFrame(cachedNewVertices, gx.edges)

# Drop the "color" column from vertices
Expand Down
18 changes: 1 addition & 17 deletions python/graphframes/lib/aggregate_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any

from pyspark import SparkContext
from pyspark.sql import Column, DataFrame, SparkSession
from pyspark.sql import Column
from pyspark.sql import functions as sqlfunctions


Expand Down Expand Up @@ -72,19 +72,3 @@ def msg(cls) -> Column:
"""Reference for message column, used for specifying aggregation function."""
jvm_gf_api = _java_api(SparkContext)
return sqlfunctions.col(jvm_gf_api.aggregateMessages().MSG_COL_NAME())

@staticmethod
def getCachedDataFrame(df: DataFrame) -> DataFrame:
"""
Create a new cached copy of a DataFrame.

This utility method is useful for iterative DataFrame-based algorithms. See Scala
documentation for more details.

WARNING: This is NOT the same as `DataFrame.cache()`.
The original DataFrame will NOT be cached.
"""
spark = SparkSession.getActiveSession()
jvm_gf_api = _java_api(spark._sc)
jdf = jvm_gf_api.aggregateMessages().getCachedDataFrame(df._jdf)
return DataFrame(jdf, spark)