.github/workflows/python-ci.yml (3 changes: 0 additions & 3 deletions)

@@ -16,9 +16,6 @@ jobs:
           - spark-version: 3.0.3
             scala-version: 2.12.12
             python-version: 3.8
-          - spark-version: 2.4.8
-            scala-version: 2.11.12
-            python-version: 3.7
     runs-on: ubuntu-20.04
     env:
       # define Java options for both official sbt and sbt-extras
.github/workflows/scala-ci.yml (2 changes: 0 additions & 2 deletions)

@@ -14,8 +14,6 @@ jobs:
             scala-version: 2.12.12
           - spark-version: 3.0.3
             scala-version: 2.12.12
-          - spark-version: 2.4.8
-            scala-version: 2.11.12
     runs-on: ubuntu-20.04
     env:
       # define Java options for both official sbt and sbt-extras
build.sbt (1 change: 0 additions & 1 deletion)

@@ -12,7 +12,6 @@ val defaultScalaVer = sparkBranch match {
   case "3.2" => "2.12.15"
   case "3.1" => "2.12.15"
   case "3.0" => "2.12.15"
-  case "2.4" => "2.11.12"
   case _ => throw new IllegalArgumentException(s"Unsupported Spark version: $sparkVer.")
 }
 val scalaVer = sys.props.getOrElse("scala.version", defaultScalaVer)
dev/release.py (2 changes: 1 addition & 1 deletion)

@@ -39,7 +39,7 @@ def verify(prompt, interactive):
 @click.option("--publish-docs", type=bool, default=PUBLISH_DOCS_DEFAULT, show_default=True,
               help="Publish docs to github-pages.")
 @click.option("--spark-version", multiple=True, show_default=True,
-              default=["2.4.8", "3.0.3", "3.1.3", "3.2.2", "3.3.0"])
+              default=["3.0.3", "3.1.3", "3.2.2", "3.3.0"])
 def main(release_version, next_version, publish_to, no_prompt, git_remote, publish_docs,
          spark_version):
     interactive = not no_prompt
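Because the option uses multiple=True, click collects repeated --spark-version flags into a tuple and only falls back to the trimmed default list when none are given. A self-contained sketch of that behavior (a hypothetical standalone script, not part of this PR):

import click

@click.command()
@click.option("--spark-version", multiple=True, show_default=True,
              default=["3.0.3", "3.1.3", "3.2.2", "3.3.0"])
def main(spark_version):
    # Repeated --spark-version flags arrive as a tuple; with none given,
    # click uses the default list above (2.4.8 is no longer included).
    for v in spark_version:
        click.echo(v)

if __name__ == "__main__":
    main()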
python/graphframes/examples/belief_propagation.py (21 changes: 9 additions & 12 deletions)

@@ -17,14 +17,12 @@
 
 import math
 
-from pyspark import SparkConf, SparkContext
-from pyspark.sql import SQLContext, functions as sqlfunctions, types
-
-# Import subpackage examples here explicitly so that
-# this module can be run directly with spark-submit.
-import graphframes.examples
 from graphframes import GraphFrame
 from graphframes.lib import AggregateMessages as AM
+# Import subpackage examples here explicitly so that this module can be
+# run directly with spark-submit.
+import graphframes.examples
+from pyspark.sql import SparkSession, functions as sqlfunctions, types
 
 __all__ = ['BeliefPropagation']
 
@@ -151,13 +149,11 @@ def _sigmoid(x):
 
 def main():
     """Run the belief propagation algorithm for an example problem."""
-    # setup context
-    conf = SparkConf().setAppName("BeliefPropagation example")
-    sc = SparkContext.getOrCreate(conf)
-    sql = SQLContext.getOrCreate(sc)
+    # setup spark session
+    spark = SparkSession.builder.appName("BeliefPropagation example").getOrCreate()
 
     # create graphical model g of size 3 x 3
-    g = graphframes.examples.Graphs(sql).gridIsingModel(3)
+    g = graphframes.examples.Graphs(spark).gridIsingModel(3)
     print("Original Ising model:")
     g.vertices.show()
     g.edges.show()
@@ -171,7 +167,8 @@ def main():
     print("Done with BP. Final beliefs after {} iterations:".format(numIter))
     beliefs.show()
 
-    sc.stop()
+    spark.stop()
+
 
 if __name__ == '__main__':
     main()
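Note the shape of the migrated entry point: a single SparkSession replaces the old SparkConf/SparkContext/SQLContext trio, and the session object is what gets handed to the examples helper. A minimal sketch distilled from the diff (assumes the graphframes package is on the Spark classpath):

from pyspark.sql import SparkSession

import graphframes.examples

# One SparkSession replaces the old SparkConf/SparkContext/SQLContext setup.
spark = SparkSession.builder.appName("BeliefPropagation example").getOrCreate()

# The examples helper now takes the SparkSession directly.
g = graphframes.examples.Graphs(spark).gridIsingModel(3)
g.vertices.show()
g.edges.show()

spark.stop()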
python/graphframes/examples/graphs.py (15 changes: 7 additions & 8 deletions)

@@ -27,18 +27,17 @@
 class Graphs(object):
     """Example GraphFrames for testing the API
 
-    :param sqlContext: SQLContext
+    :param spark: SparkSession
     """
 
-    def __init__(self, sqlContext):
-        self._sql = sqlContext
-        self._sc = sqlContext._sc
+    def __init__(self, spark):
+        self._spark = spark
+        self._sc = spark._sc
 
     def friends(self):
         """A GraphFrame of friends in a (fake) social network."""
-        sqlContext = self._sql
         # Vertex DataFrame
-        v = sqlContext.createDataFrame([
+        v = self._spark.createDataFrame([
             ("a", "Alice", 34),
             ("b", "Bob", 36),
             ("c", "Charlie", 30),
@@ -47,7 +46,7 @@ def friends(self):
             ("f", "Fanny", 36)
         ], ["id", "name", "age"])
         # Edge DataFrame
-        e = sqlContext.createDataFrame([
+        e = self._spark.createDataFrame([
             ("a", "b", "friend"),
             ("b", "c", "follow"),
             ("c", "b", "follow"),
@@ -92,7 +91,7 @@ def gridIsingModel(self, n, vStd=1.0, eStd=1.0):
                          .format(n))
 
         # create coordinates grid
-        coordinates = self._sql.createDataFrame(
+        coordinates = self._spark.createDataFrame(
             itertools.product(range(n), range(n)),
             schema=('i', 'j'))
 
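Callers of the examples module update the same way; a quick usage sketch of the new constructor (the app name is arbitrary; assumes graphframes is on the classpath):

from pyspark.sql import SparkSession
from graphframes.examples import Graphs

spark = SparkSession.builder.appName("friends-demo").getOrCreate()

# Graphs now takes a SparkSession rather than a SQLContext.
g = Graphs(spark).friends()
g.vertices.show()
g.edges.show()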
python/graphframes/graphframe.py (54 changes: 27 additions & 27 deletions)

@@ -20,20 +20,20 @@
 basestring = str
 
 from pyspark import SparkContext
-from pyspark.sql import Column, DataFrame, SQLContext
+from pyspark.sql import Column, DataFrame, SparkSession
 from pyspark.storagelevel import StorageLevel
 
 from graphframes.lib import Pregel
 
 
-def _from_java_gf(jgf, sqlContext):
+def _from_java_gf(jgf, spark):
     """
     (internal) creates a python GraphFrame wrapper from a java GraphFrame.
 
     :param jgf:
     """
-    pv = DataFrame(jgf.vertices(), sqlContext)
-    pe = DataFrame(jgf.edges(), sqlContext)
+    pv = DataFrame(jgf.vertices(), spark)
+    pe = DataFrame(jgf.edges(), spark)
     return GraphFrame(pv, pe)
 
 def _java_api(jsc):
@@ -55,16 +55,16 @@ class GraphFrame(object):
 
     >>> localVertices = [(1,"A"), (2,"B"), (3, "C")]
    >>> localEdges = [(1,2,"love"), (2,1,"hate"), (2,3,"follow")]
-    >>> v = sqlContext.createDataFrame(localVertices, ["id", "name"])
-    >>> e = sqlContext.createDataFrame(localEdges, ["src", "dst", "action"])
+    >>> v = spark.createDataFrame(localVertices, ["id", "name"])
+    >>> e = spark.createDataFrame(localEdges, ["src", "dst", "action"])
     >>> g = GraphFrame(v, e)
     """
 
     def __init__(self, v, e):
         self._vertices = v
         self._edges = e
-        self._sqlContext = v.sql_ctx
-        self._sc = self._sqlContext._sc
+        self._spark = SparkSession.getActiveSession()
+        self._sc = self._spark._sc
         self._jvm_gf_api = _java_api(self._sc)
 
         self.ID = self._jvm_gf_api.ID()
@@ -142,7 +142,7 @@ def outDegrees(self):
         :return: DataFrame with new vertices column "outDegree"
         """
         jdf = self._jvm_graph.outDegrees()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     @property
     def inDegrees(self):
@@ -156,7 +156,7 @@ def inDegrees(self):
         :return: DataFrame with new vertices column "inDegree"
         """
         jdf = self._jvm_graph.inDegrees()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     @property
     def degrees(self):
@@ -170,7 +170,7 @@ def degrees(self):
         :return: DataFrame with new vertices column "degree"
         """
         jdf = self._jvm_graph.degrees()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     @property
     def triplets(self):
@@ -185,7 +185,7 @@ def triplets(self):
         :return: DataFrame with columns 'src', 'edge', and 'dst'
         """
         jdf = self._jvm_graph.triplets()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     @property
     def pregel(self):
@@ -206,7 +206,7 @@ def find(self, pattern):
         :return: DataFrame with one Row for each instance of the motif found
         """
         jdf = self._jvm_graph.find(pattern)
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def filterVertices(self, condition):
         """
@@ -222,7 +222,7 @@ def filterVertices(self, condition):
             jdf = self._jvm_graph.filterVertices(condition._jc)
         else:
             raise TypeError("condition should be string or Column")
-        return _from_java_gf(jdf, self._sqlContext)
+        return _from_java_gf(jdf, self._spark)
 
     def filterEdges(self, condition):
         """
@@ -237,7 +237,7 @@ def filterEdges(self, condition):
             jdf = self._jvm_graph.filterEdges(condition._jc)
         else:
             raise TypeError("condition should be string or Column")
-        return _from_java_gf(jdf, self._sqlContext)
+        return _from_java_gf(jdf, self._spark)
 
     def dropIsolatedVertices(self):
         """
@@ -246,7 +246,7 @@ def dropIsolatedVertices(self):
         :return: GraphFrame with filtered vertices.
         """
         jdf = self._jvm_graph.dropIsolatedVertices()
-        return _from_java_gf(jdf, self._sqlContext)
+        return _from_java_gf(jdf, self._spark)
 
     def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10):
         """
@@ -263,7 +263,7 @@ def bfs(self, fromExpr, toExpr, edgeFilter=None, maxPathLength=10):
         if edgeFilter is not None:
             builder.edgeFilter(edgeFilter)
         jdf = builder.run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None):
         """
@@ -305,7 +305,7 @@ def aggregateMessages(self, aggCol, sendToSrc=None, sendToDst=None):
             jdf = builder.agg(aggCol._jc)
         else:
             jdf = builder.agg(aggCol)
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     # Standard algorithms
 
@@ -329,7 +329,7 @@ def connectedComponents(self, algorithm = "graphframes", checkpointInterval = 2,
             .setCheckpointInterval(checkpointInterval) \
             .setBroadcastThreshold(broadcastThreshold) \
             .run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def labelPropagation(self, maxIter):
         """
@@ -341,7 +341,7 @@ def labelPropagation(self, maxIter):
         :return: DataFrame with new vertices column "label"
         """
         jdf = self._jvm_graph.labelPropagation().maxIter(maxIter).run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None,
                  tol = None):
@@ -369,7 +369,7 @@ def pageRank(self, resetProbability = 0.15, sourceId = None, maxIter = None,
             assert tol is not None, "Exactly one of maxIter or tol should be set."
             builder = builder.tol(tol)
         jgf = builder.run()
-        return _from_java_gf(jgf, self._sqlContext)
+        return _from_java_gf(jgf, self._spark)
 
     def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None,
                                      maxIter = None):
@@ -392,7 +392,7 @@ def parallelPersonalizedPageRank(self, resetProbability = 0.15, sourceIds = None,
         builder = builder.sourceIds(sourceIds)
         builder = builder.maxIter(maxIter)
         jgf = builder.run()
-        return _from_java_gf(jgf, self._sqlContext)
+        return _from_java_gf(jgf, self._spark)
 
     def shortestPaths(self, landmarks):
         """
@@ -404,7 +404,7 @@ def shortestPaths(self, landmarks):
         :return: DataFrame with new vertices column "distances"
         """
         jdf = self._jvm_graph.shortestPaths().landmarks(landmarks).run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def stronglyConnectedComponents(self, maxIter):
         """
@@ -416,7 +416,7 @@ def stronglyConnectedComponents(self, maxIter):
         :return: DataFrame with new vertex column "component"
         """
         jdf = self._jvm_graph.stronglyConnectedComponents().maxIter(maxIter).run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
     def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0,
                     gamma1 = 0.007, gamma2 = 0.007, gamma6 = 0.005, gamma7 = 0.015):
@@ -433,7 +433,7 @@ def svdPlusPlus(self, rank = 10, maxIter = 2, minValue = 0.0, maxValue = 5.0,
         builder.gamma1(gamma1).gamma2(gamma2).gamma6(gamma6).gamma7(gamma7)
         jdf = builder.run()
         loss = builder.loss()
-        v = DataFrame(jdf, self._sqlContext)
+        v = DataFrame(jdf, self._spark)
         return (v, loss)
 
     def triangleCount(self):
@@ -445,15 +445,15 @@ def triangleCount(self):
         :return: DataFrame with new vertex column "count"
         """
         jdf = self._jvm_graph.triangleCount().run()
-        return DataFrame(jdf, self._sqlContext)
+        return DataFrame(jdf, self._spark)
 
 
 def _test():
     import doctest
     import graphframe
     globs = graphframe.__dict__.copy()
     globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
-    globs['sqlContext'] = SQLContext(globs['sc'])
+    globs['spark'] = SparkSession(globs['sc']).builder.getOrCreate()
     (failure_count, test_count) = doctest.testmod(
         globs=globs, optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
     globs['sc'].stop()
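The updated doctest doubles as the downstream migration recipe: build DataFrames from the session and let GraphFrame resolve the active session internally. Spelled out as a standalone sketch (assumes a graphframes-enabled Spark classpath; the app name is arbitrary):

from pyspark.sql import SparkSession
from graphframes import GraphFrame

spark = SparkSession.builder.appName("graphframe-demo").getOrCreate()

v = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")], ["id", "name"])
e = spark.createDataFrame([(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")],
                          ["src", "dst", "action"])

# __init__ now calls SparkSession.getActiveSession() instead of reading v.sql_ctx.
g = GraphFrame(v, e)
g.inDegrees.show()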
python/graphframes/lib/aggregate_messages.py (8 changes: 4 additions & 4 deletions)

@@ -16,7 +16,7 @@
 #
 
 from pyspark import SparkContext
-from pyspark.sql import DataFrame, functions as sqlfunctions
+from pyspark.sql import DataFrame, functions as sqlfunctions, SparkSession
 
 
 def _java_api(jsc):
@@ -77,7 +77,7 @@ def getCachedDataFrame(df):
         WARNING: This is NOT the same as `DataFrame.cache()`.
         The original DataFrame will NOT be cached.
         """
-        sqlContext = df.sql_ctx
-        jvm_gf_api = _java_api(sqlContext._sc)
+        spark = SparkSession.getActiveSession()
+        jvm_gf_api = _java_api(spark._sc)
         jdf = jvm_gf_api.aggregateMessages().getCachedDataFrame(df._jdf)
-        return DataFrame(jdf, sqlContext)
+        return DataFrame(jdf, spark)
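For context, the canonical aggregateMessages pattern this module serves, as in the user guide (assumes g is the friends example graph from above):

from pyspark.sql import functions as sqlfunctions
from graphframes.lib import AggregateMessages as AM

# For each vertex, sum the ages of adjacent vertices: every edge sends its
# destination's age to its source and its source's age to its destination.
agg = g.aggregateMessages(
    sqlfunctions.sum(AM.msg).alias("summedAges"),
    sendToSrc=AM.dst["age"],
    sendToDst=AM.src["age"])
agg.show()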
python/graphframes/lib/pregel.py (6 changes: 3 additions & 3 deletions)

@@ -19,9 +19,9 @@
 if sys.version > '3':
     basestring = str
 
-from pyspark.sql import DataFrame
+from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col
-from pyspark.ml.wrapper import JavaWrapper, _jvm
+from pyspark.ml.wrapper import JavaWrapper
 
 
 class Pregel(JavaWrapper):
@@ -169,7 +169,7 @@ def run(self):
 
         :return: the result vertex DataFrame from the final iteration including both original and additional columns.
         """
-        return DataFrame(self._java_obj.run(), self.graph.vertices.sql_ctx)
+        return DataFrame(self._java_obj.run(), SparkSession.getActiveSession())
 
     @staticmethod
     def msg():
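And the Pregel entry point this change touches, in its documented PageRank form (adapted from the library's doctest; assumes g is an existing GraphFrame and graphframes is on the classpath):

from pyspark.sql.functions import coalesce, lit, sum as sqlsum
from graphframes import GraphFrame
from graphframes.lib import Pregel

# Attach out-degrees as a vertex column, then iterate PageRank via Pregel.
graph = GraphFrame(g.outDegrees, g.edges)
n = graph.vertices.count()
alpha = 0.15

ranks = graph.pregel \
    .setMaxIter(5) \
    .withVertexColumn(
        "rank", lit(1.0 / n),
        coalesce(Pregel.msg(), lit(0.0)) * lit(1.0 - alpha) + lit(alpha / n)) \
    .sendMsgToDst(Pregel.src("rank") / Pregel.src("outDegree")) \
    .aggMsgs(sqlsum(Pregel.msg())) \
    .run()
ranks.show()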