Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,4 +83,11 @@ spark-*
# Emacs
.dir-locals.el
*~

# AI
.claude
.opencode
.qwen
.cursor
openspec
.aider*
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ lazy val siteBaseUri = sys.props.getOrElse("docs.mode", "preview") match {

lazy val protocVersion = sparkMajorVer match {
case "4" => "4.29.3"
case "3" => "3.23.4"
case "3" => "3.21.12"
case _ => throw new IllegalArgumentException(s"Unsupported Spark version: $sparkVer.")
}

Expand Down
32 changes: 32 additions & 0 deletions connect/src/main/protobuf/graphframes.proto
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ message GraphFramesAPI {
KCore kcore = 21;
MaximalIndependentSet mis = 22;
RandomWalkEmbeddings rw_embeddings = 23;
AggregateNeighbors aggregate_neighbors = 24;
}
}

Expand Down Expand Up @@ -210,6 +211,37 @@ message KCore {
optional StorageLevel storage_level = 3;
}

message AggregateNeighbors {
// Starting vertices condition (Boolean column expression)
ColumnOrExpression starting_vertices = 1;
// Maximum number of hops to explore
int32 max_hops = 2;
// Accumulator names
repeated string accumulator_names = 3;
// Accumulator initial value expressions
repeated ColumnOrExpression accumulator_inits = 4;
// Accumulator update expressions
repeated ColumnOrExpression accumulator_updates = 5;
// Optional stopping condition (Boolean column expression)
optional ColumnOrExpression stopping_condition = 6;
// Optional target condition (Boolean column expression)
optional ColumnOrExpression target_condition = 7;
// Optional required vertex attributes to carry through traversal
repeated string required_vertex_attributes = 8;
// Optional required edge attributes to carry through traversal
repeated string required_edge_attributes = 9;
// Optional edge filter condition (Boolean column expression)
optional ColumnOrExpression edge_filter = 10;
// Whether to remove self-loops
bool remove_loops = 11;
// Checkpoint interval (0 means disabled)
int32 checkpoint_interval = 12;
// Whether to use local checkpoints
bool use_local_checkpoints = 13;
// Optional storage level for intermediate results
optional StorageLevel storage_level = 14;
}

message RandomWalkEmbeddings {
bool use_edge_direction = 1;
string rw_model = 2;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,67 @@ object GraphFramesConnectUtils {

kCoreBuilder.run()
}
case proto.GraphFramesAPI.MethodCase.AGGREGATE_NEIGHBORS => {
val anProto = apiMessage.getAggregateNeighbors
var anBuilder = graphFrame.aggregateNeighbors
.setStartingVertices(parseColumnOrExpression(anProto.getStartingVertices, planner))
.setMaxHops(anProto.getMaxHops)

// Set accumulators
val accNames = anProto.getAccumulatorNamesList.asScala.toSeq
val accInits = anProto.getAccumulatorInitsList.asScala
.map(parseColumnOrExpression(_, planner))
.toSeq
val accUpdates = anProto.getAccumulatorUpdatesList.asScala
.map(parseColumnOrExpression(_, planner))
.toSeq

if (accNames.nonEmpty) {
anBuilder = anBuilder.setAccumulators(accNames, accInits, accUpdates)
}

// Optional parameters
if (anProto.hasStoppingCondition) {
anBuilder = anBuilder.setStoppingCondition(
parseColumnOrExpression(anProto.getStoppingCondition, planner))
}

if (anProto.hasTargetCondition) {
anBuilder = anBuilder.setTargetCondition(
parseColumnOrExpression(anProto.getTargetCondition, planner))
}

val reqVertexAttrs = anProto.getRequiredVertexAttributesList.asScala.toSeq
if (reqVertexAttrs.nonEmpty) {
anBuilder = anBuilder.setRequiredVertexAttributes(reqVertexAttrs)
}

val reqEdgeAttrs = anProto.getRequiredEdgeAttributesList.asScala.toSeq
if (reqEdgeAttrs.nonEmpty) {
anBuilder = anBuilder.setRequiredEdgeAttributes(reqEdgeAttrs)
}

if (anProto.hasEdgeFilter) {
anBuilder =
anBuilder.setEdgeFilter(parseColumnOrExpression(anProto.getEdgeFilter, planner))
}

anBuilder = anBuilder.setRemoveLoops(anProto.getRemoveLoops)

if (anProto.getCheckpointInterval > 0) {
anBuilder = anBuilder.setCheckpointInterval(anProto.getCheckpointInterval)
}

anBuilder = anBuilder.setUseLocalCheckpoints(anProto.getUseLocalCheckpoints)

if (anProto.hasStorageLevel) {
anBuilder =
anBuilder.setIntermediateStorageLevel(parseStorageLevel(anProto.getStorageLevel))
}

anBuilder.run()
}

case proto.GraphFramesAPI.MethodCase.RW_EMBEDDINGS => {
val message = apiMessage.getRwEmbeddings()

Expand Down
51 changes: 51 additions & 0 deletions core/src/main/scala/org/graphframes/GraphFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -664,6 +664,57 @@ class GraphFrame private (
*/
def bfs: BFS = new BFS(this)

/**
* Aggregate information from neighboring vertices and edges through a controlled traversal.
*
* This method provides a flexible way to perform graph traversals while accumulating state at
* each vertex. It can be used to implement various graph algorithms that require propagating
* information through the graph, such as influence propagation, belief propagation, or custom
* message-passing algorithms.
*
* The traversal starts from a set of starting vertices (by default all vertices) and proceeds
* for up to a specified number of hops. At each step, accumulators are updated based on
* neighboring vertices and edges. The traversal can be stopped early based on conditions, and
* results can be collected when target conditions are met.
*
* Key features:
* - Configurable starting vertices via `setStartingVertices()`
* - Maximum number of hops via `setMaxHops()`
* - Accumulators to maintain state during traversal via `setAccumulators()` or
* `addAccumulator()`
* - Stopping conditions to terminate traversal early via `setStoppingCondition()`
* - Target conditions to collect results when specific conditions are met via
* `setTargetCondition()`
* - Edge filtering via `setEdgeFilter()`
* - Control over intermediate storage and checkpointing
*
* The algorithm works as follows:
* 1. Initialize accumulators for starting vertices
* 2. For each iteration up to maxHops:
* - Join current frontier with edges to get neighbors
* - Update accumulators using the provided update expressions
* - Apply stopping conditions to determine which vertices should stop
* - Apply target conditions to determine which stopped vertices should be collected
* - Continue with vertices that haven't stopped
* 3. Return collected results as a DataFrame
*
* The result DataFrame contains:
* - The accumulators' final values for collected vertices
* - The vertex ID (in column "id")
* - The number of hops taken (in column "hop")
*
* Note: This is a stateful iterative algorithm that may be performance-intensive for large
* graphs or large maxHops values. Consider using appropriate storage levels and checkpoint
* intervals for stability.
*
* @see
* [[org.graphframes.lib.AggregateNeighbors]] for implementation details
* @return
* an [[org.graphframes.lib.AggregateNeighbors]] instance for configuration
* @group stdlib
*/
def aggregateNeighbors: AggregateNeighbors = new AggregateNeighbors(this)

/**
* This is a primitive for implementing graph algorithms. This method aggregates values from the
* neighboring edges and vertices of each vertex. See
Expand Down
Loading
Loading