-
Notifications
You must be signed in to change notification settings - Fork 266
Expand file tree
/
Copy pathconftest.py
More file actions
132 lines (103 loc) · 4.35 KB
/
conftest.py
File metadata and controls
132 lines (103 loc) · 4.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations
import os
import pathlib
import tempfile
import warnings
import pytest
from py4j.java_gateway import JavaObject
from pyspark.sql import SparkSession
from pyspark.version import __version__
from graphframes import GraphFrame
from graphframes.classic.graphframe import _java_api
# Parse the PySpark version once so comparisons are numeric, not lexicographic.
# String comparisons such as __version__[:3] >= "3.4" give the wrong answer for
# multi-digit minors (e.g. "3.10" < "3.4" lexicographically).
_version_parts = __version__.split(".")
_major = int(_version_parts[0])
_minor = int(_version_parts[1])

if (_major, _minor) >= (3, 4):
    from pyspark.sql.utils import is_remote
else:
    # Spark Connect does not exist before 3.4, so "remote" is always false.
    def is_remote() -> bool:
        return False

# Major version as a string; used to select the matching GraphFrames JAR names.
spark_major_version = str(_major)
# Scala version of the locally built JARs; Spark 4+ builds are Scala 2.13.
scala_version = os.environ.get("SCALA_VERSION", "2.12" if _major < 4 else "2.13")
def get_gf_jar_locations() -> tuple[str, str, str]:
    """Locate the locally built GraphFrames JARs.

    Returns:
        A tuple ``(core_jar, connect_jar, graphx_jar)`` of absolute paths.

    Raises:
        ValueError: if any of the three JARs cannot be found, which typically
            means the Scala build was not run or your PySpark version is not
            compatible with the version of GraphFrames that was built.
    """
    project_root = pathlib.Path(__file__).parent.parent.parent

    def find_jar(subproject: str, pattern: str) -> str:
        # Searches <root>/<subproject>/target/scala-<ver>/ for a JAR matching
        # the glob pattern; the last match wins, mirroring the original loops.
        target_dir = project_root / subproject / "target" / f"scala-{scala_version}"
        jar: str | None = None
        for candidate in target_dir.glob(pattern):
            assert isinstance(candidate, pathlib.Path)  # type checking
            jar = str(candidate.absolute())
        if jar is None:
            raise ValueError(
                f"Failed to find {pattern} for Spark {spark_major_version} in {target_dir}"
            )
        return jar

    graphx_jar = find_jar("graphx", f"graphframes-graphx-spark{spark_major_version}*.jar")
    core_jar = find_jar("core", f"graphframes-spark{spark_major_version}*.jar")
    connect_jar = find_jar("connect", f"graphframes-connect-spark{spark_major_version}*.jar")
    return core_jar, connect_jar, graphx_jar
@pytest.fixture(scope="module")
def spark():
    """Module-scoped SparkSession with the locally built GraphFrames JARs on the classpath.

    Uses a temporary directory as the checkpoint dir and stops the session at
    module teardown. On Spark 3 the connect artifact is pulled in explicitly
    because Spark 3 does not bundle Spark Connect; when running against a
    remote (Connect) session the GraphFrames relation plugin is registered.
    """
    warnings.filterwarnings("ignore", category=ResourceWarning)
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    (core_jar, connect_jar, graphx_jar) = get_gf_jar_locations()
    with tempfile.TemporaryDirectory() as tmp_dir:
        builder = (
            SparkSession.Builder()
            .appName("GraphFramesTest")
            .config("spark.sql.shuffle.partitions", 4)
            .config("spark.checkpoint.dir", tmp_dir)
            .config("spark.jars", f"{core_jar},{connect_jar},{graphx_jar}")
            .config("spark.driver.memory", "6g")
        )
        if spark_major_version == "3":
            # Spark 3 does not include connect by default
            builder = builder.config(
                "spark.jars.packages",
                f"org.apache.spark:spark-connect_{scala_version}:{__version__}",
            )
        if is_remote():
            builder = builder.remote("local[4]").config(
                "spark.connect.extensions.relation.classes",
                "org.apache.spark.sql.graphframes.GraphFramesConnect",
            )
        else:
            builder = builder.master("local[4]")
        spark = builder.getOrCreate()
        try:
            yield spark
        finally:
            # Stop even when a test in the module fails, so the JVM is torn
            # down and later modules get a fresh session.
            spark.stop()
@pytest.fixture(scope="module")
def local_g(spark: SparkSession):
    """A tiny three-vertex, three-edge GraphFrame for tests that need a toy graph."""
    vertex_rows = [(1, "A"), (2, "B"), (3, "C")]
    edge_rows = [(1, 2, "love"), (2, 1, "hate"), (2, 3, "follow")]
    vertices = spark.createDataFrame(vertex_rows, ["id", "name"])
    edges = spark.createDataFrame(edge_rows, ["src", "dst", "action"])
    yield GraphFrame(vertices, edges)
@pytest.fixture(scope="module")
def examples(spark: SparkSession):
    """Java-side GraphFrames examples API, or None under Spark Connect.

    Yields a py4j JavaObject backed by the JVM examples helper; unavailable on
    remote sessions because the examples API is py4j based.
    """
    if is_remote():
        # TODO: We should update tests to be able to run all of them on Spark Connect
        # At the moment the problem is that examples API is py4j based.
        yield None
        return
    java_api = _java_api(spark._sc)
    assert java_api is not None
    java_examples = java_api.examples()
    assert java_examples is not None
    assert isinstance(java_examples, JavaObject)
    yield java_examples