@SplitInfinity SplitInfinity commented Jul 29, 2020

Stack from ghstack:

Summary
This commit modifies IR generation to insert explicit casts that cast
each return value to Any when a function is annotated as returning Any.
This precludes the failure in type unification (see below) that caused
this issue.

Issue #41962 reported that the use of an Any return type in
combination with different code paths returning values of different
types causes a segmentation fault. This is because the exit transform
pass tries to unify the different return types, fails, but silently sets
the type of the if node to c10::nullopt. This causes problems later in
shape analysis when that type object is dereferenced.
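
To make the failure mode and the fix concrete, here is a minimal pure-Python model. It is not PyTorch code: `unify_types`, `if_output_type`, and `cast_returns_to_any` are hypothetical stand-ins, types are plain strings, and the only subtyping rule modeled is that every type is a subtype of `Any`.

```python
from typing import List, Optional

ANY = "Any"


def unify_types(t1: str, t2: str) -> Optional[str]:
    """Toy unification: succeeds only if one type is a subtype of the other."""
    if t1 == t2 or t2 == ANY:
        return t2
    if t1 == ANY:
        return t1
    return None  # stands in for c10::nullopt


def if_output_type(then_type: str, else_type: str) -> str:
    """Type of an if node's output; raises instead of using a missing type."""
    unified = unify_types(then_type, else_type)
    if unified is None:
        raise TypeError(f"cannot unify {then_type} and {else_type}")
    return unified


def cast_returns_to_any(return_types: List[str], annotated: str) -> List[str]:
    """Model of the fix: with an Any annotation, every return is cast to Any."""
    if annotated == ANY:
        return [ANY for _ in return_types]
    return list(return_types)
```

Once every return value has type `Any`, the branches of the if trivially unify, so the unification failure described above should not arise for return values.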

Test Plan
This commit adds a unit test that checks that a function similar to the
one in #41962 can be scripted and executed.

Fixes
This commit fixes #41962.

Differential Revision: D22883244

**Summary**
This commit modifies the Python and C++ JIT frontends so that
they reject any function that is annotated with an `Any` return type.

Issue #41962 reported that the use of an `Any` return type in
combination with different code paths returning values of different
types causes a segmentation fault. During IR generation, the type of
each return is checked against the annotated return type of the
function. Allowing `Any` to be used as a valid return type would involve
checking that all code paths return a common supertype that is not
`Any`. If such a type can be found, it is reasonable to expect users to
use that type instead of `Any` to annotate the return type of the
function. If it cannot be found, then the code is not valid TorchScript.

Continuing to allow the use of `Any` as a return type annotation can
also mislead users into believing that code in which different code
paths return values of different types can be compiled.

**Test Plan**
This commit adds a unit test that checks that an exception with an
appropriate error message is thrown when a function or module with an
annotated return type of `Any` is compiled.

**Fixes**
This commit fixes #41962.

[ghstack-poisoned]
@SplitInfinity SplitInfinity requested a review from apaszke as a code owner July 29, 2020 22:30
@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jul 29, 2020
@eellison eellison left a comment

I don't think we should disallow this, after allowing it. It's going to break all sorts of code. I bet if you imported this there would be a bunch of failures internally.

I think the fix is to audit all of the callsites of unifyTypes; there aren't that many. All of the ones outside of ir_emitter.cpp/type.cpp (the initial type checking) and pybind_utils.h should be unifying Types to Any. If you added another overload, TypePtr unifyTypesWithAny (or something along those lines), that returns Any instead of c10::nullopt, I think it would fix all of the failures.
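
A rough pure-Python sketch of the overload idea, for illustration only. The subtype table and the function names are assumptions made for this toy model; the real `unifyTypes` lives in C++ and handles far more cases.

```python
from typing import Optional

ANY = "Any"

# Toy subtype table (assumption): each type maps to its direct supertypes.
SUPERTYPES = {
    "int": {"Scalar"},
    "float": {"Scalar"},
    "Scalar": set(),
    "str": set(),
}


def is_subtype(t1: str, t2: str) -> bool:
    """Every type is a subtype of itself, of Any, and of its listed supertypes."""
    return t1 == t2 or t2 == ANY or t2 in SUPERTYPES.get(t1, set())


def unify_types(t1: str, t2: str) -> Optional[str]:
    """Strict variant: returns None (like c10::nullopt) when unification fails."""
    if is_subtype(t1, t2):
        return t2
    if is_subtype(t2, t1):
        return t1
    return None


def unify_types_with_any(t1: str, t2: str) -> str:
    """Lenient variant: never fails, falls back to Any."""
    unified = unify_types(t1, t2)
    return ANY if unified is None else unified
```

Callers that do the initial type checking would keep the strict variant, while later passes that must always produce a type would use the lenient one.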

…t return Any"

**Summary**
This commit modifies the type checking performed during IR generation so
that it checks that all return types for a function marked as returning
`Any` can be unified.

Issue #41962 reported that the use of an `Any` return type in
combination with different code paths returning values of different
types causes a segmentation fault. During IR generation, the type of
each return is checked against the annotated return type of the
function, or unified with all other return types if there is no
annotation. Because every type is a subtype of `Any`, this commit
modifies the IR generation to do the latter (i.e. check if all return
types can be unified) if the annotated return type is `Any`.

**Test Plan**
This commit adds a unit test that checks that an exception with an
appropriate error message is thrown when a function or module with an
annotated return type of `Any` is compiled and the possible return types
cannot be unified.

**Fixes**
This commit fixes #41962.

[ghstack-poisoned]
@SplitInfinity SplitInfinity changed the title [JIT] Disallow 'Any' return types for functions [JIT] Check mergeability of return types for functions that return Any Jul 30, 2020
SplitInfinity pushed a commit that referenced this pull request Jul 30, 2020
ghstack-source-id: bc3ca78
Pull Request resolved: #42259
SplitInfinity commented Jul 30, 2020

@eellison

Ah, backward compatibility :'(

It turns out that the IR emitter already checks if all return values' types can be unified if a function's return type is not explicitly annotated. I just uploaded a new version which uses this aforementioned code path when the annotated return type is Any instead of the code path that checks that each return value's type is a subtype of the annotated type (which is not very useful because every type is a subtype of Any).

Do we still need the two distinct versions of unifyTypes you mentioned?
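
The difference between the two code paths can be sketched in plain Python. This is a toy model under stated assumptions: types are strings, the only subtyping rule is "everything is a subtype of `Any`", and both function names are hypothetical.

```python
from typing import List, Optional

ANY = "Any"


def is_subtype(t: str, sup: str) -> bool:
    return t == sup or sup == ANY


def check_against_annotation(return_types: List[str], annotated: str) -> bool:
    """Annotated path: each return type must be a subtype of the annotation."""
    return all(is_subtype(t, annotated) for t in return_types)


def merge_return_types(return_types: List[str]) -> Optional[str]:
    """Unannotated path: all return types must unify with each other."""
    if not return_types:
        return None
    merged = return_types[0]
    for t in return_types[1:]:
        if is_subtype(t, merged):
            continue
        if is_subtype(merged, t):
            merged = t
        else:
            return None  # the returns cannot be unified
    return merged
```

With an `Any` annotation the first check passes vacuously for any mix of return types, whereas the second check rejects `int` and `str` together.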

@eellison
Contributor

I think so, the updated fix still breaks BC. It's not just return types that are wrong, it's other stuff as well:

import torch
from typing import Any

@torch.jit.script
def foo(cond: bool):
    x : Any = "hi"
    if cond:
        x = 3
    return x
print(foo.code)

which prints:

def foo(cond: bool) -> Any:
  return "hi"

An output isn't getting added to the if because https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/frontend/convert_to_ssa.cpp#L95 is failing. I think we need to audit all of the use cases of unifyTypes and how they relate to possible Any usage, just covering up return types will still leave other errors around.

**Summary**
This commit modifies the type unification performed during the exit
transform pass to add if node outputs so that it falls back to Any
instead of nothing (c10::nullopt). If type unification fails in this
situation, it means that the two branches of the if return different
types, so the output of the if node should be of type Any.

Issue #41962 reported that the use of an `Any` return type in
combination with different code paths returning values of different
types causes a segmentation fault. This is because the exit transform
pass tries to unify the different return types, fails, but silently sets
the type of the if node to c10::nullopt. This causes problems later in
shape analysis when that type object is dereferenced.

**Test Plan**
This commit adds a unit test that checks that a function similar to the
one in #41962 can be scripted and executed.

**Fixes**
This commit fixes #41962.

Differential Revision: [D22883244](https://our.internmc.facebook.com/intern/diff/D22883244)

[ghstack-poisoned]
@SplitInfinity SplitInfinity changed the title [JIT] Check mergeability of return types for functions that return Any [JIT] Modify exit transform pass to unify types to Any Aug 4, 2020
SplitInfinity pushed a commit that referenced this pull request Aug 4, 2020
ghstack-source-id: 84adccd
Pull Request resolved: #42259
@SplitInfinity
Author

I think I understand what you're saying. The use of unifyTypes that is responsible for causing the problem in the issue is in exit_transforms.cpp:

  static void addIfOutputs(
      Node* n,
      at::ArrayRef<Value*> true_outs,
      at::ArrayRef<Value*> false_outs) {
    IfView if_view(n);
    registerBlockOutputs(if_view.thenBlock(), true_outs);
    registerBlockOutputs(if_view.elseBlock(), false_outs);
    for (size_t i = 0; i < true_outs.size(); ++i) {
      auto out_type =
          unifyTypes(true_outs.at(i)->type(), false_outs.at(i)->type());
      n->addOutput()->setType(*out_type);
    }
  }

When the two branches return values of different types, unification fails and returns c10::nullopt, but this code doesn't check for that and uses the type as is, leading to a segmentation fault later during shape analysis.

I updated the PR to add an optional flag to unifyTypes that allows the caller to force the function to fall back to Any instead of nothing.
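
As an illustration of the optional flag, here is a pure-Python sketch. It is a toy model, not the actual C++ change: types are strings, `None` stands in for `c10::nullopt`, and `add_if_outputs` returns the output types rather than mutating a node.

```python
from typing import List, Optional, Sequence

ANY = "Any"


def unify_types(t1: str, t2: str, default_to_any: bool = False) -> Optional[str]:
    """Toy unification with an opt-in fallback to Any."""
    if t1 == t2 or t2 == ANY:
        return t2
    if t1 == ANY:
        return t1
    return ANY if default_to_any else None


def add_if_outputs(true_outs: Sequence[str], false_outs: Sequence[str]) -> List[str]:
    """Compute one output type per pair of branch outputs."""
    assert len(true_outs) == len(false_outs)
    outputs = []
    for true_type, false_type in zip(true_outs, false_outs):
        out_type = unify_types(true_type, false_type, default_to_any=True)
        # With the flag set there is always a type to use.
        assert out_type is not None
        outputs.append(out_type)
    return outputs
```

Keeping the flag off by default leaves every other caller with the strict behavior.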

SplitInfinity pushed a commit that referenced this pull request Aug 4, 2020
ghstack-source-id: 25100f6
Pull Request resolved: #42259
return 3


with torch._jit_internal._disable_emit_hooks():
Author

An unintended consequence of this change is that IR -> Python conversion broke because we always generate Python for an if node that looks like this:

if ...:
  ...
  a = 1
else:
  a = 2

return a

This doesn't work for this case because a cannot have two separate types in two separate code paths. I couldn't find a quick and easy way to modify Python code generation to account for this so I disabled export/import in the test for now while we discuss.

The only way I see around this is to hoist the return into the if and else branches if it comes immediately after the if. What do you think?

if ...:
  ...
  return 1
else:
  return 2

Contributor

I'm not sure what the best solution is here. Potentially we could explicitly annotate a: Any in some cases

@eellison eellison Aug 17, 2020

It should be possible to tell when a value is used as an Any output of an If, and to add an explicit annotation.

The simple case is that it only has one use, and to annotate the value with Any when it's printed. The trickier case is:

if cond:
    x = 1
    y = x + x
    return x, y
else:
    return "str", 3

Where the return value needs to be an int for its other uses, and then up-casted to an Any as the if-output. Maybe we could print this as

if cond:
    x = 1
    y = x + x
    out = unchecked_cast<Any, x>, y
else:
    out = "str", 3
return out

Author

Is that code legal? Something like this is not:

def my_fn(...) -> Any:
  a: Any = 3
  if cond:
    a = 5
  else:
    a = "five"
  return a

but something like this is:

def my_func(...) -> Any:
  a: Any = 3
  if cond:
    return 5
  else:
    return "five"

Contributor

yea, you're right the first example is not legal

Contributor

how does that affect my comments ?

Author

I just wanted to make the point that annotating the output with Any is not sufficient, we need to emit the return inside the if/else branches, which we don't currently do.

Contributor

The problem in your example is that neither if output has Any type, which I don't think applies to my suggestion; that's what the unchecked_cast<Any> is for.

Contributor

Thanks so much for helping us on this @SplitInfinity @eellison! Just an idea: would it help resolve the issue here if we required users to call an explicit cast_to_any API when they want to unify return types to Any? Something like:

def my_func(...) -> Any:
  if cond:
    return cast_to_any(5)
  else:
    return cast_to_any("five")

Contributor

@yf225 that would be a user workaround, but I think we should just fix the problem in the compiler.

@SplitInfinity
Author

I think I managed to fix the segmentation fault from a mechanical standpoint, but I can't convince myself that this makes sense.

If the two branches of the if have corresponding outputs whose types cannot be unified without Any, doesn't that mean something is wrong?

@eellison eellison left a comment

Looks good, I'm not exactly sure how we should handle the exporting...

I think we also need to handle shape_analysis call, and the convert_to_ssa call.

- c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2) {
+ c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any) {
    // check direct subtyping relation
    if (t1->isSubtypeOf(t2)) {
@eellison eellison Aug 17, 2020

All of the subcalls to unifyTypes need to thread the default_to_any argument. Edit: never mind, that's not quite right, but this doesn't handle stuff like

return c10::nullopt;

It would probably be easier to have:

c10::optional<TypePtr> unifyTypesImpl(const TypePtr& t1, const TypePtr& t2);

c10::optional<TypePtr> unifyTypes(const TypePtr& t1, const TypePtr& t2, bool default_to_any);
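
One way to read the Impl/wrapper split, sketched in pure Python under stated assumptions: types are strings or `("List", element_type)` tuples, and `unify_types_impl` is a hypothetical recursive helper. The recursion never needs to know about the fallback; the wrapper applies it once, which covers every early `None` return (the analogue of `return c10::nullopt;`).

```python
from typing import Optional, Union

Type = Union[str, tuple]
ANY = "Any"


def unify_types_impl(t1: Type, t2: Type) -> Optional[Type]:
    """Recursive strict unification; None signals failure at any depth."""
    if t1 == t2 or t2 == ANY:
        return t2
    if t1 == ANY:
        return t1
    # Element-wise unification of list types (toy rule, an assumption).
    if (
        isinstance(t1, tuple)
        and isinstance(t2, tuple)
        and t1[0] == "List"
        and t2[0] == "List"
    ):
        elem = unify_types_impl(t1[1], t2[1])
        return None if elem is None else ("List", elem)
    return None


def unify_types(t1: Type, t2: Type, default_to_any: bool = False) -> Optional[Type]:
    """Wrapper: applies the Any fallback in exactly one place."""
    unified = unify_types_impl(t1, t2)
    if unified is None and default_to_any:
        return ANY
    return unified
```

In this model `List[int]` and `List[str]` fall back to `Any` as a whole rather than to `List[Any]`, because the nested call stays strict.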

@SplitInfinity
Author

I have listed the callsites for unifyTypes below along with my thoughts on whether they should be modified or not. I'm not as familiar with the transforms so I could be wrong. Let me know what you think.

convert_to_ssa.cpp

  void addIfLoadStores(Node* n) {
  ...
    // Following the same logic as emitIfElseBlocks in ir_emitter.cpp,
    // we emit a node output if the variable is defined in each block
    // and the types of each block can be unified
    for (const auto& x : mutated_variables) {
      auto true_type = true_vars->findInAnyFrame(x);
      auto false_type = false_vars->findInAnyFrame(x);
      auto unified = unifyTypes(true_type, false_type);
      if (!unified) {
        continue;
      }

      addBlockOutput(true_block, true_type, x);
      addBlockOutput(false_block, false_type, x);
      addNodeOutput(n, *unified, x);
    }

Judging by the comment, it seems like this logic is supposed to mirror ir_emitter.cpp, which shouldn't default to Any.
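
The skip-on-failure behavior of the loop quoted above can be modeled in a few lines of Python. This is a toy sketch: variable environments are dicts from name to type string, and `add_if_load_stores` here is a hypothetical simplification that only returns the node outputs it would add.

```python
from typing import Dict, Optional

ANY = "Any"


def unify_types(t1: str, t2: str) -> Optional[str]:
    """Strict toy unification; None stands in for c10::nullopt."""
    if t1 == t2 or t2 == ANY:
        return t2
    if t1 == ANY:
        return t1
    return None


def add_if_load_stores(
    true_vars: Dict[str, str], false_vars: Dict[str, str]
) -> Dict[str, str]:
    """Return the if node outputs: only variables whose branch types unify."""
    outputs = {}
    for name in sorted(set(true_vars) & set(false_vars)):
        unified = unify_types(true_vars[name], false_vars[name])
        if unified is None:
            continue  # the variable silently gets no node output
        outputs[name] = unified
    return outputs
```

In this model a variable assigned an `int` in one branch and a `str` in the other simply gets no output, which appears consistent with the `foo` example earlier in the thread, where the reassignment of `x` disappears from the printed code.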

      // since the loop may execute 0 or many times, the output types
      // of the loop and the input loop carried dependencies are conservatively
      // the union of the output of the body and the input to the loop
      auto block_type = loop_vars->findInThisFrame(name);
      auto unified_type = unifyTypes(parent_type, block_type).value();

      // Insert a store at the beginning of the loop block, so that all
      // loads of the variable will use the loop carried value
      addNodeInput(n, parent_type, name);
      addBlockInput(body_block, unified_type, name);
      addBlockOutput(body_block, block_type, name);
      addNodeOutput(n, unified_type, name);

I think we should default to Any here. This is basically the loop version of the change this PR makes to exit_transforms.cpp. We should always be able to unify the types of a loop-carried variable outside and inside a loop, even if that means resorting to Any. However, ir_emitter.cpp does not check that the type of a loop-carried variable stays the same inside the loop, even though it has a similar check for if statement outputs.

shape_analysis.cpp

  bool mergeTypes(
      ArrayRef<Value*> lhs,
      ArrayRef<Value*> rhs,
      ArrayRef<Value*> outputs) {
    AT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
    bool changed = false;
    for (size_t i = 0; i < lhs.size(); ++i) {
      auto old_output_type = outputs[i]->type();
      auto new_type = unifyTypes(lhs[i]->type(), rhs[i]->type(), /*default_to_any=*/true);
      AT_ASSERT(new_type);
      outputs[i]->setType(*new_type);
      if (*old_output_type != *outputs[i]->type())
        changed = true;
    }
    return changed;
  }

To be honest I have no idea what this code is for but given that it expects new_type to always be valid, it seems like defaulting to Any is the right thing to do?
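
For reference, a pure-Python model of what the quoted `mergeTypes` does mechanically: it overwrites each output type with the unification of the two inputs and reports whether anything changed. Types are strings and `unify_types` is a toy stand-in; how shape analysis uses the returned flag is not modeled here.

```python
from typing import List, Optional

ANY = "Any"


def unify_types(t1: str, t2: str, default_to_any: bool = False) -> Optional[str]:
    """Toy unification with an opt-in fallback to Any."""
    if t1 == t2 or t2 == ANY:
        return t2
    if t1 == ANY:
        return t1
    return ANY if default_to_any else None


def merge_types(lhs: List[str], rhs: List[str], outputs: List[str]) -> bool:
    """Update outputs in place; return True if any output type changed."""
    assert len(lhs) == len(rhs) == len(outputs)
    changed = False
    for i in range(len(lhs)):
        old_type = outputs[i]
        new_type = unify_types(lhs[i], rhs[i], default_to_any=True)
        assert new_type is not None  # guaranteed by the fallback
        outputs[i] = new_type
        if old_type != outputs[i]:
            changed = True
    return changed
```

Without the fallback, the assertion would fire whenever the two inputs disagree, which matches the observation that the code expects a valid type every time.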

@SplitInfinity SplitInfinity changed the title [JIT] Modify exit transform pass to unify types to Any [JIT] Cast return values of functions returning Any Aug 21, 2020
@SplitInfinity
Author

@eellison @yf225

I made the changes we discussed and this particular issue seems to be solved. I opened #43378 for the other problem we discussed because I think it is sufficiently distinct from this one.

SplitInfinity pushed a commit that referenced this pull request Aug 21, 2020
ghstack-source-id: 77d7dbf
Pull Request resolved: #42259
@SplitInfinity SplitInfinity requested a review from eellison August 21, 2020 00:54
yf225 commented Aug 21, 2020

Thanks so much @SplitInfinity! Curious do you mind checking if the following use case works as well?

import torch
from typing import Any, Dict

class IdListFeature(object):
    def __init__(self, lengths, values):
        self.lengths = lengths
        self.values = values


class IdScore(object):
    def __init__(self, ids, scores):
        self.ids = ids
        self.scores = scores


class IdScoreListFeature(object):
    def __init__(self, lengths, ids, scores):
        self.lengths = lengths
        self.values = IdScore(ids=ids, scores=scores)


class HashFeatureIds(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, module_input: Any) -> Any:
        if isinstance(module_input, IdListFeature):
            return module_input
        elif isinstance(module_input, IdScoreListFeature):
            return module_input
        raise Exception

m = HashFeatureIds()
torch.jit.script(m)

The only difference is that the return value is now actually a custom TorchScript class. Thanks again!

@SplitInfinity
Author

@yf225

It scripts! 🥳

yf225 commented Aug 21, 2020

Thanks for the confirmation @SplitInfinity! We noticed that the fbcode diff https://www.internalfb.com/intern/diff/D22883244/ has different content from the PR (and the examples don't work on the fbcode diff), wondering should we update the fbcode diff using content from this PR?

@SplitInfinity
Author

Oops, yeah, I forgot to run ghimport. The diff should be updated now.

@eellison eellison left a comment

Looks good, just needs one fix re: shape_analysis.

As #43378 shows, we still have pretty significant holes in our Any support that need fixing.

One thing I don't like about this approach is that it's very fragile to optimizations - if we run peephole optimizations on the graph then that will break export. I suspect with the fix to #43378 we will have to revisit our approach here and when we unify types to Any.

Nevertheless, I think this does make the code base better and fixes a big use case, so I think we should land this; we just need to update shape_analysis.


def_stack_.back().merged_return_type_ = result_type;

if (result_type == AnyType::get() && result->type() != AnyType::get()) {
Contributor

Can you add a comment explaining why this is here?

types early from each branch when the return
type of the function is Any.
"""
def if_function(inp: torch.Tensor) -> Any:
Contributor

I think this function is going to fail on the legacy executor because we need to update https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/passes/shape_analysis.cpp#L286 to unify to Any.

SplitInfinity pushed a commit that referenced this pull request Aug 21, 2020
ghstack-source-id: 73ceebd
Pull Request resolved: #42259
yf225 commented Aug 22, 2020

Thanks a lot for this feature @SplitInfinity @eellison - so much of the API simplification we are trying to do depends on this; without it we are not even sure how to proceed with scripting the PyPer models. Thanks again for your help!

@eellison eellison left a comment

LGTM! you have an accidental file included

SplitInfinity pushed a commit that referenced this pull request Aug 24, 2020
ghstack-source-id: c7a9652
Pull Request resolved: #42259
@facebook-github-bot

@SplitInfinity merged this pull request in 00c1501.

@facebook-github-bot facebook-github-bot deleted the gh/splitinfinity/25/head branch August 30, 2020 14:17