Skip to content

Commit b057cd4

Browse files
committed
PYTHON-1075 Support running the entire test suite with TLS
1 parent dff826b commit b057cd4

File tree

2 files changed

+148
-217
lines changed

2 files changed

+148
-217
lines changed

test/__init__.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@
3535

3636
from bson.py3compat import _unicode
3737
from pymongo import common
38+
from pymongo.ssl_support import HAVE_SSL
3839
from test.version import Version
3940

41+
if HAVE_SSL:
42+
import ssl
43+
4044
# hostnames retrieved from isMaster will be of unicode type in Python 2,
4145
# so ensure these hostnames are unicodes, too. It makes tests like
4246
# `test_repr` predictable.
@@ -53,6 +57,24 @@
5357
db_user = _unicode(os.environ.get("DB_USER", "user"))
5458
db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password"))
5559

60+
CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)),
61+
'certificates')
62+
CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem')
63+
64+
65+
def is_server_resolvable():
66+
"""Returns True if 'server' is resolvable."""
67+
socket_timeout = socket.getdefaulttimeout()
68+
socket.setdefaulttimeout(1)
69+
try:
70+
try:
71+
socket.gethostbyname('server')
72+
return True
73+
except socket.error:
74+
return False
75+
finally:
76+
socket.setdefaulttimeout(socket_timeout)
77+
5678

5779
class client_knobs(object):
5880
def __init__(
@@ -119,18 +141,38 @@ def __init__(self):
119141
self.is_mongos = False
120142
self.is_rs = False
121143
self.has_ipv6 = False
144+
self.ssl_cert_none = False
145+
self.ssl_certfile = False
146+
self.server_is_resolvable = is_server_resolvable()
122147

123-
try:
124-
client = pymongo.MongoClient(host, port,
125-
serverSelectionTimeoutMS=100)
126-
client.admin.command('ismaster') # Can we connect?
148+
self.client = self.rs_or_standalone_client = None
127149

128-
# If so, then reset client to defaults.
129-
self.client = pymongo.MongoClient(host, port)
130-
131-
except pymongo.errors.ConnectionFailure:
132-
self.client = self.rs_or_standalone_client = None
133-
else:
150+
def connect(**kwargs):
151+
try:
152+
client = pymongo.MongoClient(
153+
host, port, serverSelectionTimeoutMS=100, **kwargs)
154+
client.admin.command('ismaster') # Can we connect?
155+
# If connected, then return client with default timeout
156+
return pymongo.MongoClient(host, port, **kwargs)
157+
except pymongo.errors.ConnectionFailure:
158+
return None
159+
160+
self.client = connect()
161+
162+
if HAVE_SSL and not self.client:
163+
# Is MongoDB configured for SSL?
164+
self.client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE)
165+
if self.client:
166+
self.ssl_cert_none = True
167+
168+
# Can client connect with certfile?
169+
client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE,
170+
ssl_certfile=CLIENT_PEM,)
171+
if client:
172+
self.ssl_certfile = True
173+
self.client = client
174+
175+
if self.client:
134176
self.connected = True
135177
self.ismaster = self.client.admin.command('ismaster')
136178
self.w = len(self.ismaster.get("hosts", [])) or 1
@@ -338,6 +380,36 @@ def require_test_commands(self, func):
338380
"Test commands must be enabled",
339381
func=func)
340382

383+
def require_ssl(self, func):
384+
"""Run a test only if the client can connect over SSL."""
385+
return self._require(self.ssl_cert_none or self.ssl_certfile,
386+
"Must be able to connect via SSL",
387+
func=func)
388+
389+
def require_no_ssl(self, func):
390+
"""Run a test only if the client can connect over SSL."""
391+
return self._require(not (self.ssl_cert_none or self.ssl_certfile),
392+
"Must be able to connect without SSL",
393+
func=func)
394+
395+
def require_ssl_cert_none(self, func):
396+
"""Run a test only if the client can connect with ssl.CERT_NONE."""
397+
return self._require(self.ssl_cert_none,
398+
"Must be able to connect with ssl.CERT_NONE",
399+
func=func)
400+
401+
def require_ssl_certfile(self, func):
402+
"""Run a test only if the client can connect with ssl_certfile."""
403+
return self._require(self.ssl_certfile,
404+
"Must be able to connect with ssl_certfile",
405+
func=func)
406+
407+
def require_server_resolvable(self, func):
408+
"""Run a test only if the hostname 'server' is resolvable."""
409+
return self._require(self.server_is_resolvable,
410+
"No hosts entry for 'server'. Cannot validate "
411+
"hostname in the certificate",
412+
func=func)
341413

342414
# Reusable client context
343415
client_context = ClientContext()

0 commit comments

Comments
 (0)