Description
🐛 Describe the bug
Hello,
For some models, we have observed SymInt nodes in the graph whose expression has the form s0*s1, where s0 and s1 are other SymInt nodes in the graph.
Here is the shape_env info for the s0*s1 node:
(Pdb) node
s0*s1
(Pdb) shape_env.var_to_val
{s0: 2, s1: 2}
(Pdb) shape_env.var_to_range
{s0: VR[2, 10], s1: VR[2, 10]}
I would like to get the ValueRanges for s0*s1 directly. The following currently works for us:
(Pdb) node = dim.node # dim is `SymInt`
(Pdb) expr = node.expr
(Pdb) var_range = shape_env.bound_sympy(expr)
(Pdb) var_range
VR[4, 100]
(Pdb) var_val = expr.xreplace(shape_env.var_to_val)
(Pdb) var_val
4
We use var_to_val as the optimal (or general) value that each variable takes, and var_range as the (min, max) range of values it can take.
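To make the two quantities concrete, here is a minimal pure-Python sketch of what they amount to for a product: interval arithmetic over the symbol ranges gives the bound, and substituting the hint values gives the value. The `Interval` class and the string-keyed dicts are hypothetical stand-ins, not PyTorch's `ValueRanges` or `ShapeEnv`. The sketch assumes that bound_sympy behaves like standard interval arithmetic for this expression, which is consistent with the VR[4, 100] result in the Pdb session.

```python
from dataclasses import dataclass
from itertools import product


@dataclass(frozen=True)
class Interval:
    """Hypothetical stand-in for a closed integer range [lower, upper], like VR[2, 10]."""
    lower: int
    upper: int

    def __mul__(self, other: "Interval") -> "Interval":
        # The extremes of a product are always among the four endpoint
        # products, even when the bounds can be negative.
        candidates = [
            a * b
            for a, b in product((self.lower, self.upper), (other.lower, other.upper))
        ]
        return Interval(min(candidates), max(candidates))

    def __sub__(self, other: "Interval") -> "Interval":
        # Subtraction pairs opposite endpoints: smallest minus largest, and vice versa.
        return Interval(self.lower - other.upper, self.upper - other.lower)


# Hypothetical dicts mirroring the shape_env state shown in the Pdb session.
var_to_range = {"s0": Interval(2, 10), "s1": Interval(2, 10)}
var_to_val = {"s0": 2, "s1": 2}

# Range of s0*s1 via interval multiplication: [4, 100].
prod_range = var_to_range["s0"] * var_to_range["s1"]

# Value of s0*s1 by substituting the hint values: 2 * 2 = 4.
prod_val = var_to_val["s0"] * var_to_val["s1"]

# Plugging the lower bounds into s0 - s1 would give 0,
# but the range derived by interval arithmetic is [-8, 8].
diff_range = var_to_range["s0"] - var_to_range["s1"]

print(prod_range, prod_val, diff_range)
```

The subtraction case is included only to show why evaluating an expression at the range endpoints is not enough in general. Bounds have to be propagated through each operation, which is presumably why a dedicated routine such as bound_sympy is used for the range, while a plain substitution such as xreplace suffices for the value.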
So, are shape_env.bound_sympy(expr) and expr.xreplace(shape_env.var_to_val) the correct way to compute the ValueRanges and the value of SymInt nodes such as s0*s1? Are there any other recommendations? Thanks
cc: @ezyang @avikchaudhuri @angelayi
Versions
[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0.dev20240610+cu121
[pip3] torch-tensorrt==2.4.0.dev0+a8a079715
[pip3] torchvision==0.19.0.dev20240610+cu121
[pip3] triton==2.3.1
transformers==4.41.2