Skip to content

Commit 7d0eedd

Browse files
committed
support | operation between typing.Union and strings
Adds support for performing '|' operation between Union objects and strings, e.g. forward type references. For example following code: from typing import Union U1 = Union[int, str] U1 | "float" The result of the operation above becomes: int | str | ForwardRef('float')
1 parent 8d07c45 commit 7d0eedd

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

Lib/test/test_typing.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2281,7 +2281,6 @@ class Ints(enum.IntEnum):
22812281
self.assertEqual(Union[Literal[1], Literal[Ints.B], Literal[True]].__args__,
22822282
(Literal[1], Literal[Ints.B], Literal[True]))
22832283

2284-
@unittest.expectedFailure # TODO: RUSTPYTHON; AssertionError: types.UnionType[int, str] | float != types.UnionType[int, str, float]
22852284
def test_allow_non_types_in_or(self):
22862285
# gh-140348: Test that using | with a Union object allows things that are
22872286
# not allowed by is_unionable().

crates/vm/src/builtins/type.rs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2048,12 +2048,7 @@ pub(crate) fn call_slot_new(
20482048
}
20492049

20502050
pub(crate) fn or_(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
2051-
if !union_::is_unionable(zelf.clone(), vm) || !union_::is_unionable(other.clone(), vm) {
2052-
return Ok(vm.ctx.not_implemented());
2053-
}
2054-
2055-
let tuple = PyTuple::new_ref(vec![zelf, other], &vm.ctx);
2056-
union_::make_union(&tuple, vm)
2051+
union_::or_op(zelf, other, vm)
20572052
}
20582053

20592054
fn take_next_base(bases: &mut [Vec<PyTypeRef>]) -> Option<PyTypeRef> {

crates/vm/src/builtins/union.rs

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ use super::{genericalias, type_};
22
use crate::{
33
AsObject, Context, Py, PyObject, PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
44
atomic_func,
5-
builtins::{PyFrozenSet, PySet, PyStr, PyTuple, PyTupleRef, PyType},
5+
builtins::{PyFrozenSet, PySet, PyStr, PyTuple, PyTupleRef, PyType, pystr::AsPyStr},
66
class::PyClassImpl,
77
common::hash,
88
convert::ToPyObject,
9+
function::IntoFuncArgs,
910
function::PyComparisonValue,
1011
protocol::{PyMappingMethods, PyNumberMethods},
1112
stdlib::typing::TypeAliasType,
@@ -193,7 +194,7 @@ impl PyUnion {
193194
}
194195
}
195196

196-
pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
197+
fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
197198
let cls = obj.class();
198199
cls.is(vm.ctx.types.none_type)
199200
|| obj.downcastable::<PyType>()
@@ -202,6 +203,58 @@ pub fn is_unionable(obj: PyObjectRef, vm: &VirtualMachine) -> bool {
202203
|| obj.downcast_ref::<TypeAliasType>().is_some()
203204
}
204205

206+
fn _call_typing_func_object<'a>(
207+
vm: &VirtualMachine,
208+
func_name: impl AsPyStr<'a>,
209+
args: impl IntoFuncArgs,
210+
) -> PyResult {
211+
let module = vm.import("typing", 0)?;
212+
let func = module.get_attr(func_name.as_pystr(&vm.ctx), vm)?;
213+
func.call(args, vm)
214+
}
215+
216+
fn type_check(arg: PyObjectRef, vm: &VirtualMachine) -> PyResult<PyObjectRef> {
217+
// Fast path to avoid calling into typing.py
218+
if is_unionable(arg.clone(), vm) {
219+
return Ok(arg);
220+
}
221+
let message_str: PyObjectRef = vm
222+
.ctx
223+
.new_str("Union[arg, ...]: each arg must be a type.")
224+
.into();
225+
_call_typing_func_object(vm, "_type_check", (arg, message_str))
226+
}
227+
228+
fn has_union_operands(a: PyObjectRef, b: PyObjectRef, vm: &VirtualMachine) -> bool {
229+
let union_type = vm.ctx.types.union_type;
230+
a.class().is(union_type) || b.class().is(union_type)
231+
}
232+
233+
pub fn or_op(zelf: PyObjectRef, other: PyObjectRef, vm: &VirtualMachine) -> PyResult {
234+
if !has_union_operands(zelf.clone(), other.clone(), vm) {
235+
if !is_unionable(zelf.clone(), vm) || !is_unionable(other.clone(), vm) {
236+
return Ok(vm.ctx.not_implemented());
237+
}
238+
}
239+
240+
let left = match type_check(zelf, vm) {
241+
Ok(v) => v,
242+
err => {
243+
return err;
244+
}
245+
};
246+
247+
let right = match type_check(other, vm) {
248+
Ok(v) => v,
249+
err => {
250+
return err;
251+
}
252+
};
253+
254+
let tuple = PyTuple::new_ref(vec![left, right], &vm.ctx);
255+
make_union(&tuple, vm)
256+
}
257+
205258
fn make_parameters(args: &Py<PyTuple>, vm: &VirtualMachine) -> PyResult<PyTupleRef> {
206259
let parameters = genericalias::make_parameters(args, vm);
207260
let result = dedup_and_flatten_args(&parameters, vm)?;

0 commit comments

Comments
 (0)