Proposal for a more general usage of the parameter "example_inputs" in torch.jit.trace() #80019
Support Python callable objects' parameter-unpacking feature in torch.jit.trace()
Python supports passing a tuple or dictionary to a callable and automatically unpacking it so that the function receives the right values. For a tuple, the elements' order determines whether arguments are matched correctly, while a dictionary matches arguments by name.
In torch.jit.trace(), we support passing a tuple or a single tensor as example_inputs to trigger graph construction. Passing a tuple means the tuple elements' order must strictly align with the order of the parameters declared in the forward() function. But in some scenarios it is not easy, or simply impossible, to manually build such a tuple to trace a model when we already have an off-the-shelf dataset consisting of dictionaries. Take the two examples below as illustration:
In the first example, say we have a dataset where each item is a dictionary such as data = {'key1': value1, 'key2': value2, 'key3': value3}. We fetch the data from the dataloader, extract its values, and assemble a tuple as the example_inputs passed to jit.trace(). But suppose the forward() method of the module we want to trace is declared as:
```python
def forward(self, key2=value2, key3=value3, key1=value1):
```
Thus, we need to manually reorder the values in the tuple and check that they match the declaration. This case is merely unfriendly to users, but we can still trace the module by reordering, as sketched below.
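A minimal sketch of the manual workaround (the module and tensor values here are illustrative, not taken from the original report):
```python
import torch

class MyModule(torch.nn.Module):
    def forward(self, key2, key3, key1):
        return key1 + key2 + key3

data = {'key1': torch.ones(2), 'key2': torch.ones(2), 'key3': torch.ones(2)}

# Manually reorder the dict values to match forward()'s declared order.
example_inputs = (data['key2'], data['key3'], data['key1'])
traced = torch.jit.trace(MyModule(), example_inputs)
```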
In the second example, say we still have the same dataset, but the forward() method is declared as below:
```python
def forward(self, key2=value2, key3=value3, key4=None, key5=None, key1=value1, key6=value6):
```
This time we can't pass a tuple by simply reordering the values, because key4 and key5 are missing from the dataset, and we can't pass a tuple like example_inputs = (value2, value3, None, None, value1), since the tracer doesn't support None inputs even when the corresponding missing arguments default to None. In this situation the module is not traceable at all.
These cases are not fabricated deliberately. We met these issues in Hugging Face's text-classification tasks with some datasets (such as the GLUE tasks), and we expect this to be a common issue.
So we propose that torch.jit.trace() support Python's kwargs-unpacking feature with a dictionary to address these issues. If we can unpack parameters from a dictionary, then we can automatically reorder the values by argument name inside the tracer, and thus serialize the dictionary and assemble a traceable stack for the graph. A sketch of this mechanism follows.
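A runnable sketch of the name-based reordering the tracer could perform internally (the helper dict_to_positional, the module, and the tensors are hypothetical illustrations, not the actual implementation):
```python
import inspect
import torch

class MyModule(torch.nn.Module):
    def forward(self, key2, key3=None, key1=None):
        return key1 + key2 + key3

def dict_to_positional(module, dict_input):
    # Match dict entries to forward()'s parameters by name and rebuild
    # the positional stack in declaration order.
    params = inspect.signature(module.forward).parameters
    return tuple(dict_input[name] for name in params if name in dict_input)

dict_input = {'key1': torch.ones(2), 'key2': torch.ones(2), 'key3': torch.ones(2)}
example_inputs = dict_to_positional(MyModule(), dict_input)  # order: key2, key3, key1
traced_module = torch.jit.trace(MyModule(), example_inputs)
```
With name-based matching, missing optional arguments (key4/key5 in the second example) could simply keep their defaults instead of requiring None placeholders.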
Some concerns about the implementation
With dictionary unpacking, the traced graph's inputs would follow the dictionary's key order:
```
graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key1 : type1, %key2 : type2, %key3 : type3)
```
rather than the order declared in the Python code:
```
graph(%self : __torch__.module.___torch_mangle_n.Mymodule, %key2 : type2, %key3 : type3, %key1 : type1)
```
If anything elsewhere assumes that the IR's input parameters' order strictly aligns with the order defined in the Python code, then our implementation may be infeasible, as in the following situation:
Suppose we have a module to be traced whose forward() method is declared as below:
```python
def forward(self, key2, key3=value3, key4=value4, key5=value5, key1=value1, key6=value6):
```
(with key2 as a positional argument). When we trace the module, we pass it a dictionary:
```python
traced_module = torch.jit.trace(origin_module, dict_input, strict=False)
```
where dict_input is a Python dictionary such as {'key1': value1, 'key2': value2, 'key3': value3}. But when we use the traced module, we might pass it positional arguments under the assumption that their order aligns with what the Python forward() declares:
```python
traced_module(inputs[1], key1=inputs[0], key3=inputs[2])
```
where inputs[1] actually corresponds to key2 as declared in forward(). In this situation, invoking the traced module will fail, because we reordered the arguments and key1 became the graph's first argument.
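A small runnable illustration of the order mismatch at the heart of this concern (the signature and dictionary mirror the example above):
```python
import inspect

def forward(self, key2, key3=None, key4=None, key5=None, key1=None, key6=None):
    pass

# Declaration order and dictionary key order differ, so a positional call
# written against the declaration would misbind under dict-based reordering.
decl_order = [p for p in inspect.signature(forward).parameters if p != 'self']
dict_order = list({'key1': 1, 'key2': 2, 'key3': 3})
print(decl_order)  # ['key2', 'key3', 'key4', 'key5', 'key1', 'key6']
print(dict_order)  # ['key1', 'key2', 'key3']
```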
Possible solution 1:
The first solution we came up with adds attributes to FunctionSchema that record whether the arguments' order was changed before tracing and what the original order was. These attributes can be set before tracing, when we know that the example_inputs parameter is a dictionary, and are used in the function createStackForSchema() to decide whether to recover the original order when constructing the stack.
Cons:
To support this feature and fix this issue, we shouldn't have to modify any basic data structure or class such as FunctionSchema.
Possible solution 2:
We specify in the documentation of torch.jit.trace() that if the user passes a dictionary as their example_inputs, then they must invoke the traced module with a dictionary too (see the sketch after the cons below).
Cons:
This restricts the flexibility of the API.
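A hedged sketch of the calling convention solution 2 would document (origin_module and dict_input come from the example above; the keyword-expansion call is one hypothetical way to "use the traced module with a dictionary"):
```python
traced_module = torch.jit.trace(origin_module, dict_input, strict=False)

# Invoke with the same dictionary rather than positional arguments, so the
# name-based matching used at trace time is reproduced at call time.
out = traced_module(**dict_input)
```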