Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,18 @@ project/boot/
project/plugins/project/
.bsp
.metals
metals.sbt
.bloop

# intellij
.idea/

# VSCode
.vscode

# Helix
.helix

# Mac
*.DS_Store

Expand Down
52 changes: 7 additions & 45 deletions src/main/scala/org/graphframes/lib/ConnectedComponents.scala
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.storage.StorageLevel
import org.graphframes.WithAlgorithmChoice

/**
* Connected Components algorithm.
Expand All @@ -40,11 +41,13 @@ import org.apache.spark.storage.StorageLevel
*/
class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
extends Arguments
with Logging {
with Logging
with WithAlgorithmChoice {

import org.graphframes.lib.ConnectedComponents._

private var broadcastThreshold: Int = 1000000
setAlgorithm(ALGO_GRAPHFRAMES)

/**
* Sets broadcast threshold in propagating component assignments (default: 1000000). If a node
Expand All @@ -71,34 +74,6 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
*/
def getBroadcastThreshold: Int = broadcastThreshold

private var algorithm: String = ALGO_GRAPHFRAMES

/**
* Sets the connected components algorithm to use (default: "graphframes"). Supported algorithms
* are:
* - "graphframes": Uses alternating large star and small star iterations proposed in
* [[http://dx.doi.org/10.1145/2670979.2670997 Connected Components in MapReduce and Beyond]]
* with skewed join optimization.
* - "graphx": Converts the graph to a GraphX graph and then uses the connected components
* implementation in GraphX.
* @see
* [[org.graphframes.lib.ConnectedComponents.supportedAlgorithms]]
*/
def setAlgorithm(value: String): this.type = {
require(
supportedAlgorithms.contains(value),
s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.")
algorithm = value
this
}

/**
* Gets the connected component algorithm to use.
* @see
* [[org.graphframes.lib.ConnectedComponents.setAlgorithm]].
*/
def getAlgorithm: String = algorithm

private var checkpointInterval: Int = 2

/**
Expand Down Expand Up @@ -159,7 +134,7 @@ class ConnectedComponents private[graphframes] (private val graph: GraphFrame)
def run(): DataFrame = {
ConnectedComponents.run(
graph,
algorithm = algorithm,
runInGraphX = algorithm == ALGO_GRAPHX,
broadcastThreshold = broadcastThreshold,
checkpointInterval = checkpointInterval,
intermediateStorageLevel = intermediateStorageLevel)
Expand All @@ -176,15 +151,6 @@ object ConnectedComponents extends Logging {
private val CNT = "cnt"
private val CHECKPOINT_NAME_PREFIX = "connected-components"

private val ALGO_GRAPHX = "graphx"
private val ALGO_GRAPHFRAMES = "graphframes"

/**
* Supported algorithms in [[org.graphframes.lib.ConnectedComponents.setAlgorithm]]:
* "graphframes" and "graphx".
*/
private val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)

/**
* Returns the symmetric directed graph of the graph specified by input edges.
* @param ee
Expand Down Expand Up @@ -279,15 +245,11 @@ object ConnectedComponents extends Logging {

private def run(
graph: GraphFrame,
algorithm: String,
runInGraphX: Boolean,
broadcastThreshold: Int,
checkpointInterval: Int,
intermediateStorageLevel: StorageLevel): DataFrame = {
require(
supportedAlgorithms.contains(algorithm),
s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $algorithm.")

if (algorithm == ALGO_GRAPHX) {
if (runInGraphX) {
return runGraphX(graph)
}

Expand Down
124 changes: 118 additions & 6 deletions src/main/scala/org/graphframes/lib/ShortestPaths.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.graphx.{lib => graphxlib}
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.api.java.UDF1
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.functions.{col, udf, map, lit, when, map_zip_with, reduce, map_values, transform_values, collect_list}
import org.apache.spark.sql.types.{IntegerType, MapType}

import org.graphframes.GraphFrame
import org.graphframes.Logging
import org.graphframes.WithAlgorithmChoice
import org.graphframes.GraphFrame.quote

/**
Expand All @@ -39,7 +41,11 @@ import org.graphframes.GraphFrame.quote
* - distances (`MapType[vertex ID type, IntegerType]`): For each vertex v, a map containing the
* shortest-path distance to each reachable landmark vertex.
*/
class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends Arguments {
class ShortestPaths private[graphframes] (private val graph: GraphFrame)
extends Arguments
with WithAlgorithmChoice {
import org.graphframes.lib.ShortestPaths._

private var lmarks: Option[Seq[Any]] = None

/**
Expand All @@ -59,13 +65,17 @@ class ShortestPaths private[graphframes] (private val graph: GraphFrame) extends
}

def run(): DataFrame = {
ShortestPaths.run(graph, check(lmarks, "landmarks"))
val lmarksChecked = check(lmarks, "landmarks")
algorithm match {
case ALGO_GRAPHX => runInGraphX(graph, lmarksChecked)
case ALGO_GRAPHFRAMES => runInGraphFrames(graph, lmarksChecked)
}
}
}

private object ShortestPaths {
private object ShortestPaths extends Logging {

private def run(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
private def runInGraphX(graph: GraphFrame, landmarks: Seq[Any]): DataFrame = {
val idType = graph.vertices.schema(GraphFrame.ID).dataType
val longIdToLandmark = landmarks.map(l => GraphXConversions.integralId(graph, l) -> l).toMap
val gx = graphxlib.ShortestPaths
Expand Down Expand Up @@ -95,6 +105,108 @@ private object ShortestPaths {
g.vertices.select(cols.toSeq: _*)
}

private val DISTANCE_ID = "distances"
private def runInGraphFrames(
graph: GraphFrame,
landmarks: Seq[Any],
isDirected: Boolean = true): DataFrame = {
logWarn("The GraphFrames based implementation is slow and considered experimental!")
val vertexType = graph.vertices.schema(GraphFrame.ID).dataType

// For landmark vertices the initial distance to itself is set to 0
// Example: graph with vertices a, b, c, d; landmarks = (c, d)
// we shoudl init the following:
// (a, Map()), (b, Map()), (c, Map(c -> 0)), (d, Map(d -> 0))
//
// Inside the following function it is done by applying multiple case-when
// because we know exactly that only one landmark could be equal to the nodeId.
// For example, for vertex c it will be:
// when(id == "a", Map(a -> 0))
// .when(id == "b", Map(b -> 0))
// .when(id == "c", Map(c -> 0)) --> this one is the only true
// .when(id == "d", Map(d -> 0))
def initDistancesMap(vertexId: Column): Column = {
val firstLmarkCol = lit(landmarks.head)
var initCol = when(vertexId === firstLmarkCol, map(firstLmarkCol, lit(0)))
for (lmark <- landmarks.tail) {
initCol = initCol.when(vertexId === lit(lmark), map(lit(lmark), lit(0)))
}
initCol
}

// Concatenations of two distance maps:
// If one map is null just take another.
// In case both maps are not null:
// - iterate over keys
// - if value in the left map is null or greater than value from the right map take right one
// else take left one
def concatMaps(distancesLeft: Column, distancesRight: Column): Column =
when(distancesLeft.isNull, distancesRight)
.when(distancesRight.isNull, distancesLeft)
.otherwise(map_zip_with(
distancesLeft,
distancesRight,
(_, leftDistance, rightDistance) => {
when(leftDistance.isNull || (leftDistance > rightDistance), rightDistance)
.otherwise(leftDistance)
}))

// If distance is null, result of d + 1 will be null too
def incrementDistances(distancesMap: Column): Column =
transform_values(distancesMap, (_, distance) => distance + lit(1))

// Takes an array of distance maps and reduce them with concatMaps
def aggregateArrayOfDistanceMaps(arrayCol: Column): Column =
reduce(arrayCol, lit(null).cast(MapType(vertexType, IntegerType)), concatMaps)

// Checks that a sent distances map can change the destination distances.
// Evaluation would be "true" in case in the new distances map
// for one of keys present a non-null value but in the old distances map it is null
// or new distance is less than old one.
def isDistanceImprovedWithMessage(newMap: Column, oldMap: Column): Column = reduce(
map_values(
map_zip_with(
newMap,
oldMap,
(_, newDistance, rightDistance) =>
(newDistance.isNotNull && rightDistance.isNull) || (newDistance < rightDistance))),
lit(false),
(left, right) => left || right)

val srcDistanceCol = Pregel.src(DISTANCE_ID)
val dstDistanceCol = Pregel.dst(DISTANCE_ID)

// Overall:
// 1. Initialize distances
// 2. If new message can improve distances send it
// 3. Collect and aggregate messages
val pregel = graph.pregel
.setMaxIter(Int.MaxValue) // That is how the GraphX implementation works
.withVertexColumn(
DISTANCE_ID,
when(col(GraphFrame.ID).isInCollection(landmarks), initDistancesMap(col(GraphFrame.ID)))
.otherwise(map().cast(MapType(vertexType, IntegerType))),
concatMaps(col(DISTANCE_ID), Pregel.msg))
.sendMsgToSrc(when(
isDistanceImprovedWithMessage(incrementDistances(dstDistanceCol), srcDistanceCol),
incrementDistances(dstDistanceCol)))
.aggMsgs(aggregateArrayOfDistanceMaps(collect_list(Pregel.msg)))
.setEarlyStopping(true)

// Experimental feature
if (isDirected) {
pregel.run()
} else {
// For consider edges as undirected,
// it is enough to send messages in both directions
pregel
.sendMsgToDst(
when(
isDistanceImprovedWithMessage(incrementDistances(srcDistanceCol), dstDistanceCol),
incrementDistances(srcDistanceCol)))
.run()
}

}

private val DISTANCE_ID = "distances"
}
18 changes: 18 additions & 0 deletions src/main/scala/org/graphframes/mixins.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package org.graphframes

/**
 * Mixin for algorithm wrapper classes that can run on one of two backends:
 * the pure-DataFrame "graphframes" implementation or the GraphX-based
 * "graphx" implementation. Provides the backend constants, the mutable
 * selection state, and its getter/setter.
 */
private[graphframes] trait WithAlgorithmChoice {
  // Backend identifiers accepted by setAlgorithm.
  protected val ALGO_GRAPHX = "graphx"
  protected val ALGO_GRAPHFRAMES = "graphframes"
  // Currently selected backend; defaults to GraphX. Mixing classes may pick a
  // different default by calling setAlgorithm in their initializer.
  protected var algorithm: String = ALGO_GRAPHX
  // All valid values for setAlgorithm, used for validation and error messages.
  val supportedAlgorithms: Array[String] = Array(ALGO_GRAPHX, ALGO_GRAPHFRAMES)

  /**
   * Sets the algorithm backend to use.
   *
   * @param value one of "graphframes" or "graphx"
   * @return this instance, for call chaining
   * @throws IllegalArgumentException if `value` is not a supported algorithm
   */
  def setAlgorithm(value: String): this.type = {
    require(
      supportedAlgorithms.contains(value),
      s"Supported algorithms are {${supportedAlgorithms.mkString(", ")}}, but got $value.")
    algorithm = value
    this
  }

  /** Gets the currently selected algorithm backend. */
  def getAlgorithm: String = algorithm
}
53 changes: 53 additions & 0 deletions src/test/scala/org/graphframes/lib/ShortestPathsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,39 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(results === expected)
}

// Verifies the DataFrame-based ("graphframes") shortest-paths backend on a
// small undirected graph against hand-computed distances.
test("Simple test with GraphFrames") {
  // Build an undirected graph by emitting each edge in both directions.
  val edgeSeq = Seq((1, 2), (1, 5), (2, 3), (2, 5), (3, 4), (4, 5), (4, 6))
    .flatMap { case e =>
      Seq(e, e.swap)
    }
    .map { case (src, dst) => (src.toLong, dst.toLong) }
  val edges = spark.createDataFrame(edgeSeq).toDF("src", "dst")
  val graph = GraphFrame.fromEdges(edges)

  // Ground truth
  val shortestPaths = Set(
    (1, Map(1 -> 0, 4 -> 2)),
    (2, Map(1 -> 1, 4 -> 2)),
    (3, Map(1 -> 2, 4 -> 1)),
    (4, Map(1 -> 2, 4 -> 0)),
    (5, Map(1 -> 1, 4 -> 1)),
    (6, Map(1 -> 3, 4 -> 1)))

  val landmarks = Seq(1, 4).map(_.toLong)
  // Select the experimental GraphFrames backend explicitly.
  val v2 = graph.shortestPaths.landmarks(landmarks).setAlgorithm("graphframes").run()

  // Output schema must match the GraphX backend: id columns preserved plus a
  // "distances" map of vertex ID -> Int.
  TestUtils.testSchemaInvariants(graph, v2)
  TestUtils.checkColumnType(
    v2.schema,
    "distances",
    DataTypes.createMapType(v2.schema("id").dataType, DataTypes.IntegerType, true))
  val newVs = v2.select("id", "distances").collect().toSeq
  val results = newVs.map { case Row(id: Long, spMap: Map[Long, Int] @unchecked) =>
    (id, spMap)
  }
  assert(results.toSet === shortestPaths)
}

test("friends graph") {
val friends = examples.Graphs.friends
val v = friends.shortestPaths.landmarks(Seq("a", "d")).run()
Expand All @@ -92,6 +125,26 @@ class ShortestPathsSuite extends SparkFunSuite with GraphFrameTestSparkContext {
assert(results === expected)
}

// Mirrors the existing "friends graph" test but runs the experimental
// "graphframes" backend; vertices that cannot reach any landmark get an
// empty distances map.
test("friends graph with GraphFrames") {
  val friends = examples.Graphs.friends
  val v = friends.shortestPaths.landmarks(Seq("a", "d")).setAlgorithm("graphframes").run()
  val expected = Set[(String, Map[String, Int])](
    ("a", Map("a" -> 0, "d" -> 2)),
    ("b", Map.empty),
    ("c", Map.empty),
    ("d", Map("a" -> 1, "d" -> 0)),
    ("e", Map("a" -> 2, "d" -> 1)),
    ("f", Map.empty),
    ("g", Map.empty))
  val results = v
    .select("id", "distances")
    .collect()
    .map { case Row(id: String, spMap: Map[String, Int] @unchecked) =>
      (id, spMap)
    }
    .toSet
  assert(results === expected)
}
test("Test vertices with column name") {
val verticeSeq =
Seq((1L, "one"), (2L, "two"), (3L, "three"), (4L, "four"), (5L, "five"), (6L, "six"))
Expand Down