Skip to content

Commit 2da102f

Browse files
authored
FromArgs with error_msg (#6804)
1 parent e572023 commit 2da102f

File tree

4 files changed

+39
-33
lines changed

4 files changed

+39
-33
lines changed

crates/derive-impl/src/from_args.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ struct ArgAttribute {
3737
name: Option<String>,
3838
kind: ParameterKind,
3939
default: Option<DefaultValue>,
40+
error_msg: Option<String>,
4041
}
4142

4243
impl ArgAttribute {
@@ -63,6 +64,7 @@ impl ArgAttribute {
6364
name: None,
6465
kind,
6566
default: None,
67+
error_msg: None,
6668
});
6769
return Ok(());
6870
};
@@ -94,6 +96,12 @@ impl ArgAttribute {
9496
}
9597
let val = meta.value()?.parse::<syn::LitStr>()?;
9698
self.name = Some(val.value())
99+
} else if meta.path.is_ident("error_msg") {
100+
if self.error_msg.is_some() {
101+
return Err(meta.error("already have an error_msg"));
102+
}
103+
let val = meta.value()?.parse::<syn::LitStr>()?;
104+
self.error_msg = Some(val.value())
97105
} else {
98106
return Err(meta.error("Unrecognized pyarg attribute"));
99107
}
@@ -146,8 +154,15 @@ fn generate_field((i, field): (usize, &Field)) -> Result<TokenStream> {
146154
.or(name_string)
147155
.ok_or_else(|| err_span!(field, "field in tuple struct must have name attribute"))?;
148156

149-
let middle = quote! {
150-
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)).transpose()?
157+
let middle = if let Some(error_msg) = &attr.error_msg {
158+
quote! {
159+
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)
160+
.map_err(|_| vm.new_type_error(#error_msg))).transpose()?
161+
}
162+
} else {
163+
quote! {
164+
.map(|x| ::rustpython_vm::convert::TryFromObject::try_from_object(vm, x)).transpose()?
165+
}
151166
};
152167

153168
let ending = if let Some(default) = attr.default {

crates/vm/src/builtins/function.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -855,13 +855,13 @@ pub struct PyFunctionNewArgs {
855855
code: PyRef<PyCode>,
856856
#[pyarg(positional)]
857857
globals: PyDictRef,
858-
#[pyarg(any, optional)]
858+
#[pyarg(any, optional, error_msg = "arg 3 (name) must be None or string")]
859859
name: OptionalArg<PyStrRef>,
860-
#[pyarg(any, optional)]
860+
#[pyarg(any, optional, error_msg = "arg 4 (defaults) must be None or tuple")]
861861
argdefs: Option<PyTupleRef>,
862-
#[pyarg(any, optional)]
862+
#[pyarg(any, optional, error_msg = "arg 5 (closure) must be None or tuple")]
863863
closure: Option<PyTupleRef>,
864-
#[pyarg(any, optional)]
864+
#[pyarg(any, optional, error_msg = "arg 6 (kwdefaults) must be None or dict")]
865865
kwdefaults: Option<PyDictRef>,
866866
}
867867

crates/vm/src/builtins/interpolation.rs

Lines changed: 15 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,26 +59,16 @@ impl Constructor for PyInterpolation {
5959
type Args = InterpolationArgs;
6060

6161
fn py_new(_cls: &Py<PyType>, args: Self::Args, vm: &VirtualMachine) -> PyResult<Self> {
62-
let conversion = match args.conversion {
63-
OptionalArg::Present(c) => {
64-
if vm.is_none(&c) {
65-
vm.ctx.none()
66-
} else {
67-
let s = c.downcast::<PyStr>().map_err(|_| {
68-
vm.new_type_error(
69-
"Interpolation() argument 'conversion' must be str or None",
70-
)
71-
})?;
72-
let s_str = s.as_str();
73-
if s_str.len() != 1 || !matches!(s_str.chars().next(), Some('s' | 'r' | 'a')) {
74-
return Err(vm.new_value_error(
75-
"Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'",
76-
));
77-
}
78-
s.into()
79-
}
62+
let conversion: PyObjectRef = if let Some(s) = args.conversion {
63+
let s_str = s.as_str();
64+
if s_str.len() != 1 || !matches!(s_str.chars().next(), Some('s' | 'r' | 'a')) {
65+
return Err(vm.new_value_error(
66+
"Interpolation() argument 'conversion' must be one of 's', 'a' or 'r'",
67+
));
8068
}
81-
OptionalArg::Missing => vm.ctx.none(),
69+
s.into()
70+
} else {
71+
vm.ctx.none()
8272
};
8373

8474
let expression = args
@@ -103,8 +93,12 @@ pub struct InterpolationArgs {
10393
value: PyObjectRef,
10494
#[pyarg(any, optional)]
10595
expression: OptionalArg<PyStrRef>,
106-
#[pyarg(any, optional)]
107-
conversion: OptionalArg<PyObjectRef>,
96+
#[pyarg(
97+
any,
98+
optional,
99+
error_msg = "Interpolation() argument 'conversion' must be str or None"
100+
)]
101+
conversion: Option<PyStrRef>,
108102
#[pyarg(any, optional)]
109103
format_spec: OptionalArg<PyStrRef>,
110104
}

crates/vm/src/builtins/super.rs

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ impl Constructor for PySuper {
6060

6161
#[derive(FromArgs)]
6262
pub struct InitArgs {
63-
#[pyarg(positional, optional)]
64-
py_type: OptionalArg<PyObjectRef>,
63+
#[pyarg(positional, optional, error_msg = "super() argument 1 must be a type")]
64+
py_type: OptionalArg<PyTypeRef>,
6565
#[pyarg(positional, optional)]
6666
py_obj: OptionalArg<PyObjectRef>,
6767
}
@@ -75,10 +75,7 @@ impl Initializer for PySuper {
7575
vm: &VirtualMachine,
7676
) -> PyResult<()> {
7777
// Get the type:
78-
let (typ, obj) = if let OptionalArg::Present(ty_obj) = py_type {
79-
let ty = ty_obj
80-
.downcast::<PyType>()
81-
.map_err(|_| vm.new_type_error("super() argument 1 must be a type"))?;
78+
let (typ, obj) = if let OptionalArg::Present(ty) = py_type {
8279
(ty, py_obj.unwrap_or_none(vm))
8380
} else {
8481
let frame = vm

0 commit comments

Comments
 (0)