
Question regarding computing ValueRanges for symbolic nodes #128640

@peri044

🐛 Describe the bug

Hello,

For some models, we have observed SymInt nodes in the graph of the form `s0*s1`, where `s0` and `s1` are other SymInt nodes in the graph.
Here is the `shape_env` info for the `s0*s1` node:

```
(Pdb) node
s0*s1
(Pdb) shape_env.var_to_val
{s0: 2, s1: 2}
(Pdb) shape_env.var_to_range
{s0: VR[2, 10], s1: VR[2, 10]}
```

I would like to get the ValueRanges for `s0*s1` directly. The following currently works for us:

```
(Pdb) node = dim.node # dim is `SymInt`
(Pdb) expr = node.expr
(Pdb) var_range = shape_env.bound_sympy(expr)
(Pdb) var_range
VR[4, 100]
(Pdb) var_val = expr.xreplace(shape_env.var_to_val)
(Pdb) var_val
4
```
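For reuse, the two calls from the session above could be wrapped in a single helper. This is only a sketch, assuming `dim` is a `torch.SymInt` and `shape_env` is the `ShapeEnv` that owns its symbols; it relies solely on the attributes already used in the session (`dim.node.expr`, `shape_env.bound_sympy`, `shape_env.var_to_val`), and the name `symint_range_and_hint` is hypothetical, not a PyTorch API.

```python
from typing import Any, Tuple


def symint_range_and_hint(dim: Any, shape_env: Any) -> Tuple[Any, Any, Any]:
    """Return (lower, upper, hint) for a symbolic size such as s0*s1.

    Hypothetical helper: assumes `dim` is a torch.SymInt and `shape_env`
    is its ShapeEnv. It performs the same calls as the pdb session.
    """
    expr = dim.node.expr                     # sympy expression, e.g. s0*s1
    var_range = shape_env.bound_sympy(expr)  # ValueRanges, e.g. VR[4, 100]
    # Substitute each symbol's recorded example value. Any symbol without an
    # entry in var_to_val (possibly an unbacked one) is left symbolic.
    hint = expr.xreplace(shape_env.var_to_val)
    return var_range.lower, var_range.upper, hint
```

The returned bounds are sympy values and may be infinite when a symbol has no finite upper bound, so they may need checking before being converted to Python ints.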

We treat the value obtained from `var_to_val` as the optimal (or typical) value of the expression, and `var_range` as the (min, max) range of values it can take.
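To make these semantics concrete, here is a simplified pure-Python sketch of what the two quantities mean for a product like `s0*s1`: the range follows from interval arithmetic over each symbol's `[min, max]`, and the hint follows from substituting each symbol's recorded value. It assumes plain dicts of strings and `(lo, hi)` tuples stand in for `var_to_val` and `var_to_range`; `mul_range`, `bound_product`, and `hint_product` are hypothetical names, and this is an illustration, not PyTorch's implementation of `bound_sympy`.

```python
from typing import Dict, Sequence, Tuple

# Hypothetical stand-ins for shape_env.var_to_val / shape_env.var_to_range.
var_to_val: Dict[str, int] = {"s0": 2, "s1": 2}
var_to_range: Dict[str, Tuple[int, int]] = {"s0": (2, 10), "s1": (2, 10)}


def mul_range(a: Tuple[int, int], b: Tuple[int, int]) -> Tuple[int, int]:
    """Interval product: the extremes of x*y over x in a, y in b occur at endpoint pairs."""
    products = [a[0] * b[0], a[0] * b[1], a[1] * b[0], a[1] * b[1]]
    return min(products), max(products)


def bound_product(names: Sequence[str], ranges: Dict[str, Tuple[int, int]]) -> Tuple[int, int]:
    """Range of the product of the named symbols, folding mul_range left to right."""
    lo, hi = 1, 1
    for name in names:
        lo, hi = mul_range((lo, hi), ranges[name])
    return lo, hi


def hint_product(names: Sequence[str], values: Dict[str, int]) -> int:
    """Hint value: substitute each symbol's example value and multiply."""
    result = 1
    for name in names:
        result *= values[name]
    return result


print(bound_product(["s0", "s1"], var_to_range))  # (4, 100)
print(hint_product(["s0", "s1"], var_to_val))     # 4
```

Because each factor's range is treated independently, a bound computed this way is sound but can be loose when the same symbol appears more than once (for example `s0*s0` when `s0` can be negative).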

So, are `shape_env.bound_sympy(expr)` and `expr.xreplace(shape_env.var_to_val)` the correct way to compute the ValueRanges and the representative value for SymInt nodes such as `s0*s1`? Are there any other recommendations? Thanks

cc: @ezyang @avikchaudhuri @angelayi

Versions

[pip3] numpy==1.26.4
[pip3] pytorch-triton==3.0.0+45fff310c8
[pip3] torch==2.4.0.dev20240610+cu121
[pip3] torch-tensorrt==2.4.0.dev0+a8a079715
[pip3] torchvision==0.19.0.dev20240610+cu121
[pip3] triton==2.3.1
transformers==4.41.2

cc @ezyang @anijain2305 @chauhang

Labels: module: dynamic shapes, oncall: pt2, triaged
