fix DebugInterpreter, use it + functionalization stride debugger unconditionally in aot_eager backend #91038
Conversation
…nditionally in aot_eager backend [ghstack-poisoned]
def run(self, *args):
    self.symbol_mapping = bind_symbols(self.module, *args)
    super().run(*args)
    if hasattr(self.module, "shape_env"):
oopsie, thanks
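Since only a fragment of the diff is visible above, here is a minimal, self-contained sketch of the DebugInterpreter idea (assumed behavior; the real class lives in torch._functorch and additionally substitutes shape symbols via the symbol_mapping bound in run()): re-run the FX graph on real inputs and assert that each node's real output matches the shape/stride metadata recorded during fake-tensor tracing.

import torch
import torch.fx
from torch.fx.experimental.proxy_tensor import make_fx

class CheckingInterpreter(torch.fx.Interpreter):
    def run_node(self, n):
        result = super().run_node(n)
        fake = n.meta.get("val")  # FakeTensor recorded at trace time
        if isinstance(fake, torch.Tensor) and isinstance(result, torch.Tensor):
            assert tuple(result.shape) == tuple(fake.shape), f"{n}: shape mismatch"
            assert result.stride() == fake.stride(), f"{n}: stride mismatch"
        return result

def f(x):
    return x.t().contiguous() + 1

# trace with fake tensors so node.meta["val"] carries compile-time strides,
# then replay the graph with real inputs and check every node
gm = make_fx(f, tracing_mode="fake")(torch.randn(3, 4))
CheckingInterpreter(gm).run(torch.randn(3, 4))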
  # NB: NOT cloned!
- with enable_aot_logging():
+ with enable_aot_logging(), torch._dispatch.python.enable_crossref_functionalize(
+     crossref_functionalize
Would it be more logical to push this into AOT Autograd itself?
That's fair. Actually, what do you think of unconditionally running it in aot autograd? These extra checks probably won't be our bottleneck for compile time.
I'm mostly worried about the robustness of crossref functionalize; I'm not sure I got the logic entirely right. If we can prove it out, I'm amenable.
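For context, here is a conceptual sketch of what a crossref functionalization check amounts to (an illustration built on the public torch.func.functionalize wrapper, not the torch._dispatch.python implementation the diff enables): run the function both normally and under functionalization, and assert the outputs agree on layout — exactly the kind of stride divergence this debug mode is meant to surface.

import torch
from torch.func import functionalize

def crossref_check(fn, *args):
    ref = fn(*args)
    # re-run the same computation with mutations/views functionalized away
    func_out = functionalize(fn)(*args)
    assert ref.shape == func_out.shape, "shape diverged under functionalization"
    assert ref.stride() == func_out.stride(), "stride diverged under functionalization"
    return ref

# a view-heavy function whose functionalized strides should match eager
out = crossref_check(lambda x: x.t().contiguous().view(-1), torch.randn(3, 4))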
Gonna put this PR aside for now to focus on my other outstanding PRs. I spent a bit of time debugging why the functionalization stride checks don't play well when turned on for […]. Notes for me: what I see is that at some point when running a decomp underneath fake tensor, we end up calling […]
The last time this happened to me, it was because there was a device= argument and fake tensor hadn't taken care of it (converting it to meta).
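To illustrate the failure mode being described (an assumed minimal example, not the actual repro from this PR): under FakeTensorMode, a factory call with an explicit device= argument should still be intercepted and yield a FakeTensor; if the mode misses the device kwarg, a real allocation slips through.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # a decomp that forwards device= like this relies on fake tensor
    # intercepting the factory call and converting the device to meta
    t = torch.empty(2, 3, device="cpu")
    print(type(t))  # expect torch._subclasses.fake_tensor.FakeTensor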
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
This one... is still relevant I think?
The aot_eager backend now performs two types of stride checks automatically:
(1) at compile time, when running functionalization, we perform stride checks
(2) at runtime, when executing the graph, we perform stride checks

Adding these debug asserts so that they always run in the aot_eager backend has a small downside: at runtime, the backend will technically be slower than eager mode. This tentatively seems OK: the main purpose of this backend is debugging anyway.

This also would have automatically caught a silent correctness error that was fixed in #91029.
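For reference, a minimal usage sketch (any toy function works here; torch.compile is the standard entry point for selecting a backend): compiling with aot_eager is all it takes to get both checks, since they now run unconditionally.

import torch

def f(x):
    return (x.t() + 1).contiguous()

# the compile-time functionalization stride checks run during tracing;
# the runtime checks run on every call to the compiled function
compiled = torch.compile(f, backend="aot_eager")
print(compiled(torch.randn(4, 3)).stride())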
Stack from ghstack (oldest at bottom):
cc @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire