[jit] cannot trace tensor factory methods #9069

@ssnl


Tracing newly created tensors is not much of a problem... until we hit random numbers.

Things like torch.randn(3) are traced as prim::Constant, e.g.,

>>> def fn(x): return torch.randn(3,3)+x
...
>>>
>>> torch.jit.get_trace_graph(fn, torch.ones(1), nderivs=0)[0].graph()
graph(%0 : Float(1)) {
  %1 : Float(3, 3!) = aten::expand[size=[3, 3], implicit=1](%0)
  %2 : Float(3, 3) = prim::Constant[value= 1.0159 -0.3672 -0.7744  2.4991 -0.6615 -1.9963 -0.1653 -0.6848 -0.5436 [ CPUFloatTensor{3,3} ]]()
  %3 : Float(3, 3) = aten::add[alpha={1}](%2, %1)
  return (%3);
}

The _like variant is traced the same way, whether or not requires_grad is set:

>>> def fn(x): return torch.randn_like(x,requires_grad=True)+x
...
>>> torch.jit.get_trace_graph(fn, torch.ones(1), nderivs=0)[0].graph()
graph(%0 : Float(1)) {
  %1 : Float(1) = prim::Constant[value={0.299761}]()
  %2 : Float(1) = aten::add[alpha={1}](%1, %0)
  return (%2);
}
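To make the consequence concrete, here is a toy model in plain Python (no torch) of a tracer that records a random sample as a constant. Everything here is hypothetical and only for illustration: `trace_constant_rand`, `recording_randn`, and the two-argument `fn` are not part of the JIT, and the real tracer works on graphs, not closures. A rough sketch:

```python
import random

random.seed(0)  # fixed seed so the illustration is repeatable

def trace_constant_rand(fn, example_input):
    """Toy stand-in for a tracer that records random samples as constants."""
    constants = []

    def recording_randn():
        value = random.gauss(0.0, 1.0)  # sampled once, at trace time
        constants.append(value)
        return value

    fn(example_input, recording_randn)  # run once to record the "graph"

    def traced(x):
        # Replay reads the stored constants instead of sampling again.
        replay = iter(constants)
        return fn(x, lambda: next(replay))

    return traced

def fn(x, randn):
    # Hypothetical analogue of `torch.randn(...) + x`.
    return randn() + x

traced = trace_constant_rand(fn, 1.0)

# Eager calls draw fresh noise each time; the traced function never does.
eager_noise = {fn(0.0, lambda: random.gauss(0.0, 1.0)) for _ in range(5)}
traced_noise = {traced(0.0) for _ in range(5)}
```

In this toy, `eager_noise` holds several distinct values while `traced_noise` holds exactly one, which is the same behavior the prim::Constant nodes above imply for the real trace.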

In-place sampling methods like normal_ aren't supported either.

This is blocking moving dropout into ATen (#9008). Assuming we don't give dropout a custom backward (and we really shouldn't, since it's just sampling and multiplication), we need to trace the creation of the stochastic mask tensor.

The workaround, which I really want to avoid, is to have a custom backward. Because the backward needs the stochastic mask, we would have to write a dropout_with_mask that returns both the output and the mask, and write dropout as a wrapper around it. This shouldn't be necessary for such a simple op, and it makes the mask depend on the input in the graph.
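For reference, the shape of that workaround looks roughly like the following. This is a plain-Python sketch over lists, not ATen or autograd code: only the name `dropout_with_mask` comes from the text above, while `dropout_backward`, the `rng` parameter, and the inverted-dropout scaling by 1/(1-p) are assumptions made for the illustration.

```python
import random

def dropout_with_mask(x, p, rng):
    """Forward pass that returns both the output and the sampled mask.

    Assumes 0 <= p < 1. Kept elements are scaled by 1/(1-p)
    (an assumption of this sketch).
    """
    scale = 1.0 / (1.0 - p)
    mask = [0.0 if rng.random() < p else scale for _ in x]
    out = [xi * mi for xi, mi in zip(x, mask)]
    return out, mask

def dropout_backward(grad_out, mask):
    """Hand-written backward: it needs the exact mask sampled in forward."""
    return [gi * mi for gi, mi in zip(grad_out, mask)]

def dropout(x, p, rng):
    """Public wrapper that hides the mask from callers."""
    out, _ = dropout_with_mask(x, p, rng)
    return out

x = [1.0, 2.0, 3.0, 4.0]
out, mask = dropout_with_mask(x, 0.5, random.Random(0))
grad_in = dropout_backward([1.0] * len(x), mask)
```

The extra function and the second return value exist only so the backward can reuse the mask. If the tracer could record the sampling itself, the forward would be a plain sample-and-multiply and no hand-written backward would be needed.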

The fundamental solution is to fix how the JIT tracer handles random numbers.

Related: #8450.

cc @apaszke @jamesr66a

Labels: oncall: jit
