Add torch.ops.out_dtype #103333
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103333
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 1 Unrelated Failure as of commit b2d5b57.
NEW FAILURE - The following job has failed:
FLAKY - The following job failed but was likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
I feel
@jerryzh168 Yeah... naming is hard. Some other options: promote_output_dtype, cast_output_dtype, as_dtype, with_dtype
How about
zou3519 left a comment:
This looks like a good start. I added some comments around testing and the implementation.
@output_dtype.py_impl(DispatchKey.Autograd)
def output_dtype_autograd(
Will you care about autograd in the near future? If so -- what would the backward formula for output_dtype(torch.ops.aten.mm, int8) look like? (it might end up mattering)
Yeah, so this is exactly what I would like to talk about. Take the following example:

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = ...
        self.linear2 = ...

    def forward(self, x):
        return self.linear2(self.linear1(x))
If I want to quantize M via quantization-aware training, the current workflow, on top of the exported graph, will modify the forward graph to effectively do the following (call this FakeQuantizeRepresentation):
def forward(self, x):
    x = fake_quantize(x, <some-quant-params-that-are-being-learnt-for-linear1>)  # this step effectively just uses the quant params to quantize and dequantize x
    x = self.linear1(x)
    x = fake_quantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)
    x = self.linear2(x)
    return x
The above may not be entirely accurate, but it is close enough (@jerryzh168 @andrewor14 can confirm).
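(For readers less familiar with QAT, here is a minimal per-tensor sketch of what a fake-quantize step does: quantize, then immediately dequantize, so downstream compute stays in floating point but carries the quantization error. The function name and signature below are illustrative, not the actual torch.ao op.)

import torch

def fake_quantize_sketch(x, scale, zero_point, qmin=-128, qmax=127):
    # quantize: scale, shift by zero_point, round, clamp to the int8 range ...
    q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
    # ... then dequantize right away, so the tensor stays fp32
    return (q - zero_point) * scale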
Fake quantize does not accurately capture the numerical behavior of the target hardware. While achieving bit-exact behavior might be hard, it would be very useful to approximate the numerics better than the fake quantize nodes do. If we could, the hypothesis is that the computed loss would be more accurate, which may result in gradient updates that account for the accuracy achievable on the target hardware. So what should forward look like with better numerics? We approximate it with integer compute, e.g. (call this IntRepresentation):
def forward(self, x):
    x = quantize_int8(x, <some-quant-params-that-are-being-learnt-for-linear1>)
    x = self.int8_linear1(x, <some-quant-params-that-are-being-learnt-for-linear1>)  # the decomposition of this op uses the "mixed_dtype" op introduced in this diff. Note that int8_linear returns an int8 tensor
    x = self.int8_linear2(x, <some-quant-params-that-are-being-learnt-for-linear2>)
    return dequantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)
Note that in IntRepresentation, linear1 and linear2 use int8 compute, as opposed to FakeQuantizeRepresentation. If we can figure out the autograd story for IntRepresentation, that would be great. One of the challenges, and I have heard this from other vendors as well, is that int tensors don't store gradients.
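(To make the int8_linear decomposition mentioned above concrete, a rough sketch under stated assumptions: out_dtype_mm below is a stand-in for what the op introduced in this diff would express, roughly out_dtype(torch.ops.aten.mm, torch.int32, x_int8, w_int8); the helper names and the requantization details are illustrative, not the final API.)

import torch

def out_dtype_mm(x_int8, w_int8):
    # stand-in for the new op: int8 x int8 matmul with an explicit int32 accumulator,
    # which is what real int8 kernels (e.g. mm_int8_int8_int32) do
    prod = x_int8.to(torch.int32).unsqueeze(-1) * w_int8.to(torch.int32).unsqueeze(0)
    return prod.sum(dim=1, dtype=torch.int32)

def int8_linear_sketch(x_int8, w_int8, bias_int32, x_scale, w_scale, out_scale, out_zero_point):
    acc_int32 = out_dtype_mm(x_int8, w_int8) + bias_int32
    # requantize the int32 accumulator back to int8 for the next layer
    out = torch.round(acc_int32 * (x_scale * w_scale / out_scale)) + out_zero_point
    return torch.clamp(out, -128, 127).to(torch.int8)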
If some of this is not super clear, I will draft something with more details. cc: @ezyang
@kimishpatel thanks for writing down the details. Did you get any feedback from hardware vendors on how to do backward on integer compute? Since it's not enough to just swap in the integer representation, we also need the computation to be differentiable in order to train. Maybe one thing is to just use the fp version of the int8_linear operator, i.e. we convert all inputs to fp32, like the following:
def forward(self, x):
    x = quantize_int8(x, <some-quant-params-that-are-being-learnt-for-linear1>)
    # although x, qweight1 are int8, we convert them to fp32 to make the computation differentiable
    x = x.to(torch.float32)
    qweight1 = qweight1.to(torch.float32)
    x = mm(x, qweight1, ...)
    x = rescale(x)
    x = clamp(x, qmin, qmax).to(torch.int8)
    # although x, qweight2 are int8, we convert them to fp32 to make the computation differentiable
    x = x.to(torch.float32)
    qweight2 = qweight2.to(torch.float32)
    x = mm(x, qweight2, ...)
    x = rescale(x)
    x = clamp(x, qmin, qmax).to(torch.int8)
    return dequantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)
But then the computation would happen in fp32 as well, so I'm not sure whether this can closely approximate mm_int8_int8_int32 or not.
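(One data point on that question, as a sketch rather than a claim about the final design: fp32 represents every integer up to 2**24 exactly, so the fp32 simulation of an int8 x int8 matmul matches int32 accumulation bit-for-bit as long as every partial sum stays in that range, which is guaranteed even in the worst case when the reduction dimension is at most 1024; beyond that it can start to diverge.)

import torch

torch.manual_seed(0)
M, K, N = 8, 512, 8  # K small enough that every fp32 partial sum stays below 2**24
a = torch.randint(-128, 128, (M, K), dtype=torch.int8)
b = torch.randint(-128, 128, (K, N), dtype=torch.int8)

# reference int8 x int8 -> int32 accumulation (what an integer kernel computes)
ref_int32 = (a.to(torch.int32).unsqueeze(-1) * b.to(torch.int32).unsqueeze(0)).sum(dim=1, dtype=torch.int32)

# the fp32 simulation from the snippet above
sim_fp32 = a.to(torch.float32) @ b.to(torch.float32)

print(torch.equal(ref_int32.to(torch.float32), sim_fp32))  # expected: True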
> feedback from hardware vendors how to do backward on integer compute?

They can do that today, and some of them do; the blocker for them is the inability of int tensors to store gradients. But if we could enable this natively for decompositions expressed via aten ops, that would add a lot of value.

> maybe one thing is to just use the fp version of the int8_linear operator, i.e. we convert all inputs to fp32

If we do this, it will defeat the purpose of capturing the compute with integer arithmetic. This is why it is somewhat non-trivial. Let me start a chat thread on this.
> blockers for them is inability of int tensors to store gradient

Is this in PyTorch? I remember trying this once and it just said an int Tensor doesn't have gradient. Is that what you are referring to?
Yes, that's what I'm referring to.
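(For anyone following along, this is the limitation being referred to; a quick repro, with the error message paraphrased from memory.)

import torch

try:
    torch.zeros(4, dtype=torch.int8, requires_grad=True)
except RuntimeError as e:
    # roughly: "Only Tensors of floating point and complex dtype can require gradients"
    print(e)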
I checked with ChatGPT as well, considering all the above suggestions. Here is the original Q & A:

Q: I'm defining a new higher-order operator in PyTorch that takes an operator, an output dtype, and the arguments of the operator; it casts all the input Tensors to the output dtype and then calls the operator. What could the name for this higher-order operator be?

A:
- dtype_cast_and_apply: clearly conveys the functionality of the operator, indicating that it casts the input tensors to the specified output dtype and then applies the given operator.
- type_converted_operator: emphasizes the conversion of tensor types before applying the operator.
- dtype_transform_operator: highlights the transformation of the dtype for the input tensors before applying the operator.
- type_adaptive_operator: suggests that the operator adapts to the specified output dtype by converting the input tensors accordingly.
- type_preserving_apply: conveys that the operator preserves the specified output dtype by casting the input tensors before applying the operator.

Ultimately, the choice of name depends on your preference and the context in which the operator will be used. It's important to select a name that accurately reflects the functionality and purpose of the operator, to make it clear to other developers who may use or interact with your code.
kimishpatel left a comment:
Left some comments.
Another candidate could be
After offline discussion, we decided on out_dtype. We will also skip the autograd implementation for now because we still need to figure out plans for QAT.
❌ 🤖 pytorchbot command failed: Try
@pytorchbot rebase
@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here.
Successfully rebased. (Branch updated from aa3350d to e1163ac.)
@pytorchbot merge
This PR updates submodules third_party/cutlass, third_party/pybind11, third_party/cudnn_frontend. If those updates are intentional, please add the "submodule" keyword to the PR title/description.
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
Merge failed. Reason: 1 job has failed; the first few of them are: periodic / linux-bionic-cuda11.8-py3.9-gcc7 / test (multigpu, 1, 1, linux.16xlarge.nvidia.gpu). Details for Dev Infra team: raised by workflow job.
@pytorchbot merge -f "test timed out 🥺"
Merge started. Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
https://docs.google.com/document/d/10DYFG2sU3TSvguFP5kYwYLlo45KHFg3BhBOkUk0NKsU/edit#bookmark=id.hgfzmhlzkamk
Renamed mixed_dtype --> out_dtype because "mixed_dtype is not very descriptive in the context of regular pytorch where we support type promotion on most ops"
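(For readers landing here later, a minimal reference sketch of the semantics discussed in this PR: cast the Tensor inputs to the requested output dtype, then call the op, so e.g. an int8 x int8 product can accumulate in a wider dtype, as in mm_int8_int8_int32. This illustrates the idea only; it is not the code the PR adds, and the helper name out_dtype_reference is made up here.)

import torch

def out_dtype_reference(op, output_dtype, *args):
    # cast Tensor arguments to the requested output dtype, then run the op
    casted = [a.to(output_dtype) if isinstance(a, torch.Tensor) else a for a in args]
    return op(*casted)

x = torch.tensor([100, 100], dtype=torch.int8)
w = torch.tensor([100, 100], dtype=torch.int8)
print(x * w)                                              # overflows in int8 (the product does not fit)
print(out_dtype_reference(torch.mul, torch.int32, x, w))  # tensor([10000, 10000], dtype=torch.int32)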
cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov