Closed
Issue description
I was experimenting with the torch.jit module on PyTorch master to measure time savings on static Pyro models. I found that the call to torch._C._infer_size inside torch.distributions.utils._broadcast_shape fails when the enclosing function is traced. All distributions call this function to ensure that their parameters are broadcast correctly and stored as attributes with the correct shape (even when no broadcasting is actually needed).
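For context, the shape inference involved here is the usual broadcasting rule applied to two sizes. The following is a minimal pure-Python sketch of that rule, not PyTorch's actual implementation; the name `broadcast_shape` is hypothetical and exists only for illustration. It works on plain tuples of ints, which is the kind of input the error below suggests `_infer_size` expects.

```python
from itertools import zip_longest


def broadcast_shape(shape_a, shape_b):
    """Hypothetical helper: broadcast two shapes given as sequences of ints."""
    result = []
    # Align the shapes from the trailing dimension; a missing leading
    # dimension is treated as size 1.
    pairs = zip_longest(reversed(shape_a), reversed(shape_b), fillvalue=1)
    for dim_a, dim_b in pairs:
        if dim_a == dim_b or dim_b == 1:
            result.append(dim_a)
        elif dim_a == 1:
            result.append(dim_b)
        else:
            raise ValueError(
                "shapes {} and {} are not broadcastable".format(
                    tuple(shape_a), tuple(shape_b)
                )
            )
    return tuple(reversed(result))
```

With two (3, 3) inputs, as in the example in this report, no broadcasting is needed and the sketch returns the shape unchanged.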
A minimal example that raises the error is pasted below. I am not sure what the current status of the JIT is (or whether distributions are expected to work with it yet), so please feel free to close this if it is a known issue.
Code example
>>> import torch
>>> @torch.jit.trace(torch.randn(3, 3), torch.randn(3, 3))
... def fn_to_jit(a, b):
...     shape = torch._C._infer_size(a.size(), b.size())
...     return torch.ones(shape).sum()
...
>>> fn_to_jit(torch.randn(3, 3), torch.randn(3, 3))
Error Trace:
Traceback (most recent call last):
  File "/Applications/PyCharm.app/Contents/helpers/pydev/pydev_run_in_console.py", line 53, in run_file
    pydev_imports.execfile(file, globals, locals) # execute the script
  File "/Users/npradhan/workspace/pyro_dev/pyro/examples/jit.py", line 20, in <module>
    @torch.jit.trace(torch.randn(3, 3), torch.randn(3, 3))
  File "/Users/npradhan/miniconda2/envs/pytorch-build/lib/python2.7/site-packages/torch/jit/__init__.py", line 308, in wrapper
    return torch._C.GraphExecutor(func, args, **executor_options)
  File "/Users/npradhan/workspace/pyro_dev/pyro/examples/jit.py", line 22, in fn_to_jit
    shape = torch._C._infer_size(a.size(), b.size())
RuntimeError: expected int at position 0, but got: Tensor
apaszke