Description
Tracing newly created tensors is not too much of a problem... until we hit random numbers.
Things like `torch.randn(3)` are traced as `prim::Constant`, e.g.:

```
>>> def fn(x): return torch.randn(3,3)+x
...
>>>
>>> torch.jit.get_trace_graph(fn, torch.ones(1), nderivs=0)[0].graph()
graph(%0 : Float(1)) {
%1 : Float(3, 3!) = aten::expand[size=[3, 3], implicit=1](%0)
%2 : Float(3, 3) = prim::Constant[value= 1.0159 -0.3672 -0.7744 2.4991 -0.6615 -1.9963 -0.1653 -0.6848 -0.5436 [ CPUFloatTensor{3,3} ]]()
%3 : Float(3, 3) = aten::add[alpha={1}](%2, %1)
return (%3);
}
```

The `_like` variant is traced similarly (with or without `requires_grad`):

```
>>> def fn(x): return torch.randn_like(x,requires_grad=True)+x
...
>>> torch.jit.get_trace_graph(fn, torch.ones(1), nderivs=0)[0].graph()
graph(%0 : Float(1)) {
%1 : Float(1) = prim::Constant[value={0.299761}]()
%2 : Float(1) = aten::add[alpha={1}](%1, %0)
return (%2);
}
```

In-place sampling methods like `normal_` aren't supported either.
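To see why baking the sample into the trace as a constant is wrong, compare the original function with one that replays a tensor captured once up front. This is a plain-Python sketch of the resulting semantics, not the tracer itself:

```python
import torch

def fn(x):
    return torch.randn(3, 3) + x

# What constant-folding the sample effectively records: the tensor drawn at
# trace time is frozen and replayed on every subsequent call.
captured = torch.randn(3, 3)
def replayed(x):
    return captured + x

x = torch.ones(3, 3)
print(torch.equal(fn(x), fn(x)))              # False: a fresh sample each call
print(torch.equal(replayed(x), replayed(x)))  # True: the "randomness" is gone
```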
This is blocking moving dropout into ATen (#9008), because, assuming we don't give dropout a custom backward (and we really shouldn't, since it's just sampling and multiplication), we need to trace the creation of the stochastic mask tensor.
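For reference, a minimal sketch of what "just sampling and multiplication" means here (an illustration, not the ATen implementation): autograd differentiates through the multiply on its own, so the only thing the tracer has to record is the sampling.

```python
import torch

def dropout(input, p=0.5, train=True):
    # Sampling + multiplication; assumes 0 <= p < 1. Autograd handles the
    # backward through the multiply, but the tracer must be able to record
    # the bernoulli sampling below as a graph operation.
    if not train or p == 0.0:
        return input
    mask = torch.bernoulli(torch.full_like(input, 1.0 - p)) / (1.0 - p)
    return input * mask
```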
The workaround that I really, really want to avoid is a custom backward: because the backward needs the stochastic mask, we would have to write a `dropout_with_mask` that returns both the output and the mask, and implement `dropout` as a wrapper around it. That shouldn't be necessary for such a simple op, and it makes the mask depend on the input in the graph.
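For concreteness, a rough sketch of that workaround; `dropout_with_mask` and the custom Function below are hypothetical illustrations, not existing ATen/PyTorch APIs:

```python
import torch

class _DropoutWithMask(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, p):
        # Rescaled Bernoulli keep-mask; assumes 0 <= p < 1.
        if p == 0.0:
            mask = torch.ones_like(input)
        else:
            mask = torch.bernoulli(torch.full_like(input, 1.0 - p)) / (1.0 - p)
        ctx.save_for_backward(mask)
        return input * mask, mask

    @staticmethod
    def backward(ctx, grad_output, grad_mask):
        # Custom backward: reuse the saved mask instead of tracing the sampling.
        (mask,) = ctx.saved_tensors
        return grad_output * mask, None

def dropout_with_mask(input, p=0.5):
    return _DropoutWithMask.apply(input, p)

def dropout(input, p=0.5):
    output, _mask = dropout_with_mask(input, p)
    return output
```

Note how the mask is now an explicit output computed from the input, which is exactly the input dependence in the graph objected to above.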
The fundamental solution is to fix the JIT tracer's handling of random numbers.
Relevant: #8450.