|
35 | 35 |
|
36 | 36 | from bson.py3compat import _unicode |
37 | 37 | from pymongo import common |
| 38 | +from pymongo.ssl_support import HAVE_SSL |
38 | 39 | from test.version import Version |
39 | 40 |
|
| 41 | +if HAVE_SSL: |
| 42 | + import ssl |
| 43 | + |
40 | 44 | # hostnames retrieved from isMaster will be of unicode type in Python 2, |
41 | 45 | # so ensure these hostnames are unicodes, too. It makes tests like |
42 | 46 | # `test_repr` predictable. |
|
53 | 57 | db_user = _unicode(os.environ.get("DB_USER", "user")) |
54 | 58 | db_pwd = _unicode(os.environ.get("DB_PASSWORD", "password")) |
55 | 59 |
|
| 60 | +CERT_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), |
| 61 | + 'certificates') |
| 62 | +CLIENT_PEM = os.path.join(CERT_PATH, 'client.pem') |
| 63 | + |
| 64 | + |
| 65 | +def is_server_resolvable(): |
| 66 | + """Returns True if 'server' is resolvable.""" |
| 67 | + socket_timeout = socket.getdefaulttimeout() |
| 68 | + socket.setdefaulttimeout(1) |
| 69 | + try: |
| 70 | + try: |
| 71 | + socket.gethostbyname('server') |
| 72 | + return True |
| 73 | + except socket.error: |
| 74 | + return False |
| 75 | + finally: |
| 76 | + socket.setdefaulttimeout(socket_timeout) |
| 77 | + |
56 | 78 |
|
57 | 79 | class client_knobs(object): |
58 | 80 | def __init__( |
@@ -119,18 +141,38 @@ def __init__(self): |
119 | 141 | self.is_mongos = False |
120 | 142 | self.is_rs = False |
121 | 143 | self.has_ipv6 = False |
| 144 | + self.ssl_cert_none = False |
| 145 | + self.ssl_certfile = False |
| 146 | + self.server_is_resolvable = is_server_resolvable() |
122 | 147 |
|
123 | | - try: |
124 | | - client = pymongo.MongoClient(host, port, |
125 | | - serverSelectionTimeoutMS=100) |
126 | | - client.admin.command('ismaster') # Can we connect? |
| 148 | + self.client = self.rs_or_standalone_client = None |
127 | 149 |
|
128 | | - # If so, then reset client to defaults. |
129 | | - self.client = pymongo.MongoClient(host, port) |
130 | | - |
131 | | - except pymongo.errors.ConnectionFailure: |
132 | | - self.client = self.rs_or_standalone_client = None |
133 | | - else: |
| 150 | + def connect(**kwargs): |
| 151 | + try: |
| 152 | + client = pymongo.MongoClient( |
| 153 | + host, port, serverSelectionTimeoutMS=100, **kwargs) |
| 154 | + client.admin.command('ismaster') # Can we connect? |
| 155 | + # If connected, then return client with default timeout |
| 156 | + return pymongo.MongoClient(host, port, **kwargs) |
| 157 | + except pymongo.errors.ConnectionFailure: |
| 158 | + return None |
| 159 | + |
| 160 | + self.client = connect() |
| 161 | + |
| 162 | + if HAVE_SSL and not self.client: |
| 163 | + # Is MongoDB configured for SSL? |
| 164 | + self.client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE) |
| 165 | + if self.client: |
| 166 | + self.ssl_cert_none = True |
| 167 | + |
| 168 | + # Can client connect with certfile? |
| 169 | + client = connect(ssl=True, ssl_cert_reqs=ssl.CERT_NONE, |
| 170 | + ssl_certfile=CLIENT_PEM,) |
| 171 | + if client: |
| 172 | + self.ssl_certfile = True |
| 173 | + self.client = client |
| 174 | + |
| 175 | + if self.client: |
134 | 176 | self.connected = True |
135 | 177 | self.ismaster = self.client.admin.command('ismaster') |
136 | 178 | self.w = len(self.ismaster.get("hosts", [])) or 1 |
@@ -338,6 +380,36 @@ def require_test_commands(self, func): |
338 | 380 | "Test commands must be enabled", |
339 | 381 | func=func) |
340 | 382 |
|
| 383 | + def require_ssl(self, func): |
| 384 | + """Run a test only if the client can connect over SSL.""" |
| 385 | + return self._require(self.ssl_cert_none or self.ssl_certfile, |
| 386 | + "Must be able to connect via SSL", |
| 387 | + func=func) |
| 388 | + |
| 389 | + def require_no_ssl(self, func): |
| 390 | + """Run a test only if the client can connect over SSL.""" |
| 391 | + return self._require(not (self.ssl_cert_none or self.ssl_certfile), |
| 392 | + "Must be able to connect without SSL", |
| 393 | + func=func) |
| 394 | + |
| 395 | + def require_ssl_cert_none(self, func): |
| 396 | + """Run a test only if the client can connect with ssl.CERT_NONE.""" |
| 397 | + return self._require(self.ssl_cert_none, |
| 398 | + "Must be able to connect with ssl.CERT_NONE", |
| 399 | + func=func) |
| 400 | + |
| 401 | + def require_ssl_certfile(self, func): |
| 402 | + """Run a test only if the client can connect with ssl_certfile.""" |
| 403 | + return self._require(self.ssl_certfile, |
| 404 | + "Must be able to connect with ssl_certfile", |
| 405 | + func=func) |
| 406 | + |
| 407 | + def require_server_resolvable(self, func): |
| 408 | + """Run a test only if the hostname 'server' is resolvable.""" |
| 409 | + return self._require(self.server_is_resolvable, |
| 410 | + "No hosts entry for 'server'. Cannot validate " |
| 411 | + "hostname in the certificate", |
| 412 | + func=func) |
341 | 413 |
|
342 | 414 | # Reusable client context |
343 | 415 | client_context = ClientContext() |
|
0 commit comments