Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 5 additions & 30 deletions Lib/test/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,6 @@ def test_typevartuple(self):
class A(Generic[Unpack[Ts]]): ...
Alias = Optional[Unpack[Ts]]

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevartuple_specialization(self):
T = TypeVar("T")
Ts = TypeVarTuple('Ts', default=Unpack[Tuple[str, int]])
Expand All @@ -691,8 +689,6 @@ class A(Generic[T, Unpack[Ts]]): ...
self.assertEqual(A[float, range].__args__, (float, range))
self.assertEqual(A[float, *tuple[int, ...]].__args__, (float, *tuple[int, ...]))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevar_and_typevartuple_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Expand Down Expand Up @@ -740,8 +736,6 @@ class A(Generic[T, P]): ...
self.assertEqual(A[float].__args__, (float, (str, int)))
self.assertEqual(A[float, [range]].__args__, (float, (range,)))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_typevar_and_paramspec_specialization(self):
T = TypeVar("T")
U = TypeVar("U", default=float)
Expand All @@ -752,8 +746,6 @@ class A(Generic[T, U, P]): ...
self.assertEqual(A[float, int].__args__, (float, int, (str, int)))
self.assertEqual(A[float, int, [range]].__args__, (float, int, (range,)))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_and_typevar_specialization(self):
T = TypeVar("T")
P = ParamSpec('P', default=[str, int])
Expand Down Expand Up @@ -1049,8 +1041,6 @@ class C(Generic[T1, T2]): pass
eval(expected_str)
)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_three_parameters(self):
T1 = TypeVar('T1')
T2 = TypeVar('T2')
Expand Down Expand Up @@ -2543,8 +2533,6 @@ def __call__(self):
self.assertIs(a().__class__, C1)
self.assertEqual(a().__orig_class__, C1[[int], T])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec(self):
Callable = self.Callable
fullname = f"{Callable.__module__}.Callable"
Expand Down Expand Up @@ -2579,8 +2567,6 @@ def test_paramspec(self):
self.assertEqual(repr(C2), f"{fullname}[~P, int]")
self.assertEqual(repr(C2[int, str]), f"{fullname}[[int, str], int]")

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_concatenate(self):
Callable = self.Callable
fullname = f"{Callable.__module__}.Callable"
Expand Down Expand Up @@ -2608,8 +2594,6 @@ def test_concatenate(self):
Callable[Concatenate[int, str, P2], int])
self.assertEqual(C[...], Callable[Concatenate[int, ...], int])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_nested_paramspec(self):
# Since Callable has some special treatment, we want to be sure
# that substituion works correctly, see gh-103054
Expand Down Expand Up @@ -2652,8 +2636,6 @@ class My(Generic[P, T]):
self.assertEqual(C4[bool, bytes, float],
My[[Callable[[int, bool, bytes, str], float], float], float])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_errors(self):
Callable = self.Callable
alias = Callable[[int, str], float]
Expand Down Expand Up @@ -2682,6 +2664,11 @@ def test_consistency(self):
class CollectionsCallableTests(BaseCallableTests, BaseTestCase):
Callable = collections.abc.Callable

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_errors(self):
super().test_errors()


class LiteralTests(BaseTestCase):
def test_basics(self):
Expand Down Expand Up @@ -4631,8 +4618,6 @@ class Base(Generic[T_co]):
class Sub(Base, Generic[T]):
...

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_parameter_detection(self):
self.assertEqual(List[T].__parameters__, (T,))
self.assertEqual(List[List[T]].__parameters__, (T,))
Expand All @@ -4650,8 +4635,6 @@ class A:
# C version of GenericAlias
self.assertEqual(list[A()].__parameters__, (T,))

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_non_generic_subscript(self):
T = TypeVar('T')
class G(Generic[T]):
Expand Down Expand Up @@ -8858,8 +8841,6 @@ def test_bad_var_substitution(self):
with self.assertRaises(TypeError):
collections.abc.Callable[P, T][arg, str]

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_type_var_subst_for_other_type_vars(self):
T = TypeVar('T')
T2 = TypeVar('T2')
Expand Down Expand Up @@ -8981,8 +8962,6 @@ class PandT(Generic[P, T]):
self.assertEqual(C3.__args__, ((int, *Ts), T))
self.assertEqual(C3[str, bool, bytes], PandT[[int, str, bool], bytes])

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_in_nested_generics(self):
# Although ParamSpec should not be found in __parameters__ of most
# generics, they probably should be found when nested in
Expand All @@ -9001,8 +8980,6 @@ def test_paramspec_in_nested_generics(self):
self.assertEqual(G2[[int, str], float], list[C])
self.assertEqual(G3[[int, str], float], list[C] | int)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_paramspec_gets_copied(self):
# bpo-46581
P = ParamSpec('P')
Expand Down Expand Up @@ -9090,8 +9067,6 @@ def test_invalid_uses(self):
):
Concatenate[int]

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_var_substitution(self):
T = TypeVar('T')
P = ParamSpec('P')
Expand Down
145 changes: 113 additions & 32 deletions vm/src/builtins/genericalias.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ use crate::{
function::{FuncArgs, PyComparisonValue},
protocol::{PyMappingMethods, PyNumberMethods},
types::{
AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, PyComparisonOp,
Representable,
AsMapping, AsNumber, Callable, Comparable, Constructor, GetAttr, Hashable, Iterable,
PyComparisonOp, Representable,
},
};
use std::fmt;
Expand Down Expand Up @@ -78,6 +78,7 @@ impl Constructor for PyGenericAlias {
Constructor,
GetAttr,
Hashable,
Iterable,
Representable
),
flags(BASETYPE)
Expand Down Expand Up @@ -166,17 +167,17 @@ impl PyGenericAlias {
}

#[pymethod]
fn __getitem__(&self, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
fn __getitem__(zelf: PyRef<Self>, needle: PyObjectRef, vm: &VirtualMachine) -> PyResult {
let new_args = subs_parameters(
|vm| self.repr(vm),
self.args.clone(),
self.parameters.clone(),
zelf.to_owned().into(),
zelf.args.clone(),
zelf.parameters.clone(),
needle,
vm,
)?;

Ok(
PyGenericAlias::new(self.origin.clone(), new_args.to_pyobject(vm), vm)
PyGenericAlias::new(zelf.origin.clone(), new_args.to_pyobject(vm), vm)
.into_pyobject(vm),
)
}
Expand Down Expand Up @@ -277,6 +278,18 @@ fn tuple_index(vec: &[PyObjectRef], item: &PyObjectRef) -> Option<usize> {
vec.iter().position(|element| element.is(item))
}

fn is_unpacked_typevartuple(arg: &PyObjectRef, vm: &VirtualMachine) -> PyResult<bool> {
if arg.class().is(vm.ctx.types.type_type) {
return Ok(false);
}

if let Ok(attr) = arg.get_attr(identifier!(vm, __typing_is_unpacked_typevartuple__), vm) {
attr.try_to_bool(vm)
} else {
Ok(false)
}
}

fn subs_tvars(
obj: PyObjectRef,
params: &PyTupleRef,
Expand Down Expand Up @@ -324,22 +337,40 @@ fn subs_tvars(
}

// _Py_subs_parameters
pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(
repr: F,
pub fn subs_parameters(
alias: PyObjectRef, // The GenericAlias object itself
args: PyTupleRef,
parameters: PyTupleRef,
needle: PyObjectRef,
vm: &VirtualMachine,
) -> PyResult<PyTupleRef> {
let num_params = parameters.len();
if num_params == 0 {
return Err(vm.new_type_error(format!("There are no type variables left in {}", repr(vm)?)));
return Err(vm.new_type_error(format!("{} is not a generic class", alias.repr(vm)?)));
}

let items = needle.try_to_ref::<PyTuple>(vm);
// Handle __typing_prepare_subst__ for each parameter
// Following CPython: each prepare function transforms the args
let mut prepared_args = needle.clone();

// Ensure args is a tuple
if prepared_args.try_to_ref::<PyTuple>(vm).is_err() {
prepared_args = PyTuple::new_ref(vec![prepared_args], &vm.ctx).into();
}

for param in parameters.iter() {
if let Ok(prepare) = param.get_attr(identifier!(vm, __typing_prepare_subst__), vm) {
if !prepare.is(&vm.ctx.none) {
// Call prepare(cls, args) where cls is the GenericAlias
prepared_args = prepare.call((alias.clone(), prepared_args), vm)?;
}
}
}

let items = prepared_args.try_to_ref::<PyTuple>(vm);
let arg_items = match items {
Ok(tuple) => tuple.as_slice(),
Err(_) => std::slice::from_ref(&needle),
Err(_) => std::slice::from_ref(&prepared_args),
};

let num_items = arg_items.len();
Expand All @@ -362,40 +393,82 @@ pub fn subs_parameters<F: Fn(&VirtualMachine) -> PyResult<String>>(

let min_required = num_params - params_with_defaults;
if num_items < min_required {
let repr_str = alias.repr(vm)?;
return Err(vm.new_type_error(format!(
"Too few arguments for {}; actual {}, expected at least {}",
repr(vm)?,
num_items,
min_required
"Too few arguments for {repr_str}; actual {num_items}, expected at least {min_required}"
)));
}
} else if num_items > num_params {
let repr_str = alias.repr(vm)?;
return Err(vm.new_type_error(format!(
"Too many arguments for {}; actual {}, expected {}",
repr(vm)?,
num_items,
num_params
"Too many arguments for {repr_str}; actual {num_items}, expected {num_params}"
)));
}

let mut new_args = Vec::new();
let mut new_args = Vec::with_capacity(args.len());

for arg in args.iter() {
// Skip bare Python classes
if arg.class().is(vm.ctx.types.type_type) {
new_args.push(arg.clone());
continue;
}

// Check if this is an unpacked TypeVarTuple
let unpack = is_unpacked_typevartuple(arg, vm)?;

// Check for __typing_subst__ attribute directly (like CPython)
if let Ok(subst) = arg.get_attr(identifier!(vm, __typing_subst__), vm) {
let idx = tuple_index(parameters.as_slice(), arg).unwrap();
if idx < num_items {
// Call __typing_subst__ with the argument
let substituted = subst.call((arg_items[idx].clone(),), vm)?;
new_args.push(substituted);
if let Some(idx) = tuple_index(parameters.as_slice(), arg) {
if idx < num_items {
// Call __typing_subst__ with the argument
let substituted = subst.call((arg_items[idx].clone(),), vm)?;

if unpack {
// Unpack the tuple if it's a TypeVarTuple
if let Ok(tuple) = substituted.try_to_ref::<PyTuple>(vm) {
for elem in tuple.iter() {
new_args.push(elem.clone());
}
} else {
new_args.push(substituted);
}
} else {
new_args.push(substituted);
}
} else {
// Use default value if available
if let Ok(default_val) = vm.call_method(arg, "__default__", ()) {
if !default_val.is(&vm.ctx.typing_no_default) {
new_args.push(default_val);
} else {
return Err(vm.new_type_error(format!(
"No argument provided for parameter at index {idx}"
)));
}
} else {
return Err(vm.new_type_error(format!(
"No argument provided for parameter at index {idx}"
)));
}
}
} else {
// CPython doesn't support default values in this context
return Err(
vm.new_type_error(format!("No argument provided for parameter at index {idx}"))
);
new_args.push(arg.clone());
}
} else {
new_args.push(subs_tvars(arg.clone(), &parameters, arg_items, vm)?);
let subst_arg = subs_tvars(arg.clone(), &parameters, arg_items, vm)?;
if unpack {
// Unpack the tuple if it's a TypeVarTuple
if let Ok(tuple) = subst_arg.try_to_ref::<PyTuple>(vm) {
for elem in tuple.iter() {
new_args.push(elem.clone());
}
} else {
new_args.push(subst_arg);
}
} else {
new_args.push(subst_arg);
}
}
}

Expand All @@ -406,7 +479,8 @@ impl AsMapping for PyGenericAlias {
fn as_mapping() -> &'static PyMappingMethods {
static AS_MAPPING: LazyLock<PyMappingMethods> = LazyLock::new(|| PyMappingMethods {
subscript: atomic_func!(|mapping, needle, vm| {
PyGenericAlias::mapping_downcast(mapping).__getitem__(needle.to_owned(), vm)
let zelf = PyGenericAlias::mapping_downcast(mapping);
PyGenericAlias::__getitem__(zelf.to_owned(), needle.to_owned(), vm)
}),
..PyMappingMethods::NOT_IMPLEMENTED
});
Expand Down Expand Up @@ -490,6 +564,13 @@ impl Representable for PyGenericAlias {
}
}

impl Iterable for PyGenericAlias {
fn iter(zelf: PyRef<Self>, vm: &VirtualMachine) -> PyResult {
// Return an iterator over the args tuple
Ok(zelf.args.clone().to_pyobject(vm).get_iter(vm)?.into())
}
}

pub fn init(context: &Context) {
let generic_alias_type = &context.types.generic_alias_type;
PyGenericAlias::extend_class(context, generic_alias_type);
Expand Down
Loading
Loading