Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
244 changes: 232 additions & 12 deletions Lib/test/test_context.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
import collections.abc
import concurrent.futures
import contextvars
import functools
import gc
import random
import sys
import time
import unittest
import weakref
Expand All @@ -26,8 +27,7 @@ def wrapper(*args, **kwargs):


class ContextTest(unittest.TestCase):
# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_context_var_new_1(self):
with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
contextvars.ContextVar()
Expand Down Expand Up @@ -63,6 +63,14 @@ def test_context_var_repr_1(self):
c.reset(t)
self.assertIn(' used ', repr(t))

@isolated_context
def test_token_repr_1(self):
c = contextvars.ContextVar('a')
tok = c.set(1)
self.assertRegex(repr(tok),
r"^<Token var=<ContextVar name='a' "
r"at 0x[0-9a-fA-F]+> at 0x[0-9a-fA-F]+>$")

def test_context_subclassing_1(self):
with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
class MyContextVar(contextvars.ContextVar):
Expand All @@ -77,8 +85,7 @@ class MyContext(contextvars.Context):
class MyToken(contextvars.Token):
pass

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_context_new_1(self):
with self.assertRaisesRegex(TypeError, 'any arguments'):
contextvars.Context(1)
Expand All @@ -88,8 +95,17 @@ def test_context_new_1(self):
contextvars.Context(a=1)
contextvars.Context(**{})

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: TypeError not raised
def test_context_new_unhashable_str_subclass(self):
# gh-132002: it used to crash on unhashable str subtypes.
class weird_str(str):
def __eq__(self, other):
pass

with self.assertRaisesRegex(TypeError, 'unhashable type'):
contextvars.ContextVar(weird_str())

@unittest.expectedFailure # TODO: RUSTPYTHON
def test_context_typerrors_1(self):
ctx = contextvars.Context()

Expand All @@ -104,8 +120,7 @@ def test_context_get_context_1(self):
ctx = contextvars.copy_context()
self.assertIsInstance(ctx, contextvars.Context)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
def test_context_run_1(self):
ctx = contextvars.Context()

Expand Down Expand Up @@ -153,8 +168,7 @@ def func(*args, **kwargs):
with self.assertRaises(ZeroDivisionError):
ctx.run(func, 1, 2, a=123)

# TODO: RUSTPYTHON
@unittest.expectedFailure
@unittest.expectedFailure # TODO: RUSTPYTHON
@isolated_context
def test_context_run_4(self):
ctx1 = contextvars.Context()
Expand Down Expand Up @@ -353,9 +367,22 @@ def ctx2_fun():

ctx1.run(ctx1_fun)

def test_context_isinstance(self):
ctx = contextvars.Context()
self.assertIsInstance(ctx, collections.abc.Mapping)
self.assertTrue(issubclass(contextvars.Context, collections.abc.Mapping))

mapping_methods = (
'__contains__', '__eq__', '__getitem__', '__iter__', '__len__',
'__ne__', 'get', 'items', 'keys', 'values',
)
for name in mapping_methods:
with self.subTest(name=name):
self.assertTrue(callable(getattr(ctx, name)))

@unittest.skipIf(sys.platform == "darwin", "TODO: RUSTPYTHON; Flaky on Mac, self.assertEqual(cvar.get(), num + i) AssertionError: 8 != 12")
@isolated_context
@threading_helper.requires_working_threading()
@unittest.skipIf(sys.platform == 'darwin', 'TODO: RUSTPYTHON; Flaky on Mac, self.assertEqual(cvar.get(), num + i) AssertionError: 8 != 12')
def test_context_threads_1(self):
cvar = contextvars.ContextVar('cvar')

Expand All @@ -373,6 +400,199 @@ def sub(num):
tp.shutdown()
self.assertEqual(results, list(range(10)))

@isolated_context
@threading_helper.requires_working_threading()
def test_context_thread_inherit(self):
import threading

cvar = contextvars.ContextVar('cvar')

def run_context_none():
if sys.flags.thread_inherit_context:
expected = 1
else:
expected = None
self.assertEqual(cvar.get(None), expected)

# By default, context is inherited based on the
# sys.flags.thread_inherit_context option.
cvar.set(1)
thread = threading.Thread(target=run_context_none)
thread.start()
thread.join()

# Passing 'None' explicitly should have same behaviour as not
# passing parameter.
thread = threading.Thread(target=run_context_none, context=None)
thread.start()
thread.join()

# An explicit Context value can also be passed
custom_ctx = contextvars.Context()
custom_var = None

def setup_context():
nonlocal custom_var
custom_var = contextvars.ContextVar('custom')
custom_var.set(2)

custom_ctx.run(setup_context)

def run_custom():
self.assertEqual(custom_var.get(), 2)

thread = threading.Thread(target=run_custom, context=custom_ctx)
thread.start()
thread.join()

# You can also pass a new Context() object to start with an empty context
def run_empty():
with self.assertRaises(LookupError):
cvar.get()

thread = threading.Thread(target=run_empty, context=contextvars.Context())
thread.start()
thread.join()

def test_token_contextmanager_with_default(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
with c.set(36):
self.assertEqual(c.get(), 36)

self.assertEqual(c.get(), 42)

ctx.run(fun)

def test_token_contextmanager_without_default(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c')

def fun():
with c.set(36):
self.assertEqual(c.get(), 36)

with self.assertRaisesRegex(LookupError, "<ContextVar name='c'"):
c.get()

ctx.run(fun)

def test_token_contextmanager_on_exception(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
with c.set(36):
self.assertEqual(c.get(), 36)
raise ValueError("custom exception")

self.assertEqual(c.get(), 42)

with self.assertRaisesRegex(ValueError, "custom exception"):
ctx.run(fun)

def test_token_contextmanager_reentrant(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
token = c.set(36)
with self.assertRaisesRegex(
RuntimeError,
"<Token .+ has already been used once"
):
with token:
with token:
self.assertEqual(c.get(), 36)

self.assertEqual(c.get(), 42)

ctx.run(fun)

def test_token_contextmanager_multiple_c_set(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
with c.set(36):
self.assertEqual(c.get(), 36)
c.set(24)
self.assertEqual(c.get(), 24)
c.set(12)
self.assertEqual(c.get(), 12)

self.assertEqual(c.get(), 42)

ctx.run(fun)

def test_token_contextmanager_with_explicit_reset_the_same_token(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
with self.assertRaisesRegex(
RuntimeError,
"<Token .+ has already been used once"
):
with c.set(36) as token:
self.assertEqual(c.get(), 36)
c.reset(token)

self.assertEqual(c.get(), 42)

self.assertEqual(c.get(), 42)

ctx.run(fun)

def test_token_contextmanager_with_explicit_reset_another_token(self):
ctx = contextvars.Context()
c = contextvars.ContextVar('c', default=42)

def fun():
with c.set(36):
self.assertEqual(c.get(), 36)

token = c.set(24)
self.assertEqual(c.get(), 24)
c.reset(token)
self.assertEqual(c.get(), 36)

self.assertEqual(c.get(), 42)

ctx.run(fun)

def test_context_eq_reentrant_contextvar_set(self):
var = contextvars.ContextVar("v")
ctx1 = contextvars.Context()
ctx2 = contextvars.Context()

class ReentrantEq:
def __eq__(self, other):
ctx1.run(lambda: var.set(object()))
return True

ctx1.run(var.set, ReentrantEq())
ctx2.run(var.set, object())
ctx1 == ctx2

def test_context_eq_reentrant_contextvar_set_in_hash(self):
var = contextvars.ContextVar("v")
ctx1 = contextvars.Context()
ctx2 = contextvars.Context()

class ReentrantHash:
def __hash__(self):
ctx1.run(lambda: var.set(object()))
return 0
def __eq__(self, other):
return isinstance(other, ReentrantHash)

ctx1.run(var.set, ReentrantHash())
ctx2.run(var.set, ReentrantHash())
ctx1 == ctx2


# HAMT Tests

Expand Down
20 changes: 9 additions & 11 deletions crates/stdlib/src/contextvars.rs
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,8 @@ mod _contextvars {
#[derive(FromArgs)]
struct ContextVarOptions {
#[pyarg(positional)]
#[allow(dead_code)] // TODO: RUSTPYTHON
name: PyStrRef,
#[pyarg(any, optional)]
#[allow(dead_code)] // TODO: RUSTPYTHON
default: OptionalArg<PyObjectRef>,
}

Expand Down Expand Up @@ -533,15 +531,15 @@ mod _contextvars {
impl Representable for ContextVar {
#[inline]
fn repr_str(zelf: &Py<Self>, vm: &VirtualMachine) -> PyResult<String> {
// unimplemented!("<ContextVar name={{}} default={{}} at {{}}")
Ok(format!(
"<ContextVar name={} default={:?} at {:#x}>",
zelf.name.as_str(),
zelf.default
.as_ref()
.and_then(|default| default.str(vm).ok()),
zelf.get_id()
))
let name = zelf.name.as_str();
let id = zelf.get_id();

Ok(if let Some(arg) = zelf.default.as_ref() {
let default = arg.str(vm).ok();
format!("<ContextVar name='{name}' default={default:?} at {id:#x}>",)
} else {
format!("<ContextVar name='{name}' at {id:#x}>")
})
}
}

Expand Down
Loading