
Conversation

@angelayi (Contributor) commented Jun 9, 2023

https://docs.google.com/document/d/10DYFG2sU3TSvguFP5kYwYLlo45KHFg3BhBOkUk0NKsU/edit#bookmark=id.hgfzmhlzkamk

Renamed mixed_dtype --> out_dtype because "mixed_dtype is not very descriptive in the context of regular pytorch where we support type promotion on most ops"
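For context, the op in question is a higher-order operator that takes the target op, the desired output dtype, and the op's arguments, and runs the computation so that the result is produced in that dtype (see the naming discussion further down). A rough sketch of a call site, with the tensor names made up for illustration:

# x_int8 and w_int8 are int8 tensors; instead of producing an int8 result,
# the matmul accumulates into and returns an int32 tensor
out_int32 = out_dtype(torch.ops.aten.mm.default, torch.int32, x_int8, w_int8)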

cc @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @aakhundov

@pytorch-bot (bot) commented Jun 9, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/103333

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit b2d5b57:

NEW FAILURE - The following job has failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@angelayi changed the title from "Add torch.ops.mixed_dtype" to "Add torch.ops.output_dtype" on Jun 10, 2023
@pytorch-bot added the "release notes: fx" (release notes category) label on Jun 10, 2023
@jerryzh168 (Contributor) commented

I feel the output_dtype name is a bit weird, especially when it also has an argument called out_dtype... are there any better alternatives?

@angelayi (Contributor, Author) commented

@jerryzh168 Yeah, naming is hard. Some other options: promote_output_dtype, cast_output_dtype, as_dtype, with_dtype.
cc @ezyang

@ezyang (Contributor) commented Jun 13, 2023

How about compute_precision?

@zou3519 (Contributor) left a comment

This looks like a good start; I added some comments around testing and the implementation.

Comment on lines 64 to 65
@output_dtype.py_impl(DispatchKey.Autograd)
def output_dtype_autograd(

Will you care about autograd in the near future? If so -- what would the backward formula for output_dtype(torch.ops.aten.mm, int8) look like? (it might end up mattering)

@kimishpatel (Contributor) replied

Yeah, this is exactly what I would like to talk about. Take the following example:

class M(torch.nn.Module):
  def __init__(self):
    super().__init__()
    self.linear1 = ...
    self.linear2 = ...

  def forward(self, x):
    return self.linear2(self.linear1(x))

If I want to quantize M via quantization-aware training, the current workflow on top of the exported graph will modify the forward graph to effectively do the following (call this FakeQuantizeRepresentation):
  def forward(self, x):
    x = fake_quantize(x, <some-quant-params-that-are-being-learnt-for-linear1>)  # this step effectively just uses the quant params to quantize and dequantize x
    x = self.linear1(x)
    x = fake_quantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)
    x = self.linear2(x)
    return x

The above may not be entirely accurate, but it is close enough (@jerryzh168 @andrewor14 can confirm).
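For reference, a minimal sketch of what a fake_quantize step boils down to: a per-tensor affine quantize immediately followed by a dequantize, so everything downstream still runs in floating point (the helper name and the int8 qmin/qmax defaults are illustrative, not the actual exported-graph ops):

def fake_quantize(x, scale, zero_point, qmin=-128, qmax=127):
  # quantize: snap the float values onto the int8 grid
  q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax)
  # dequantize: map back to float so downstream ops keep running in fp32
  return (q - zero_point) * scale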

Fake quantize nodes do not accurately capture the numerical behavior of the target hardware. While achieving bit-exact behavior might be hard, it would be very useful to approximate the numerics more closely than fake quantize nodes do. If we can do a better approximation, the hypothesis is that the computed loss will be more accurate, which may result in gradient updates that account for the accuracy achievable on the target hardware. So what should forward look like with better numerics? We approximate that with integer compute, e.g. (call this IntRepresentation):
  def forward(self, x):
    x = quantize_int8(x, <some-quant-params-that-are-being-learnt-for-linear1>)
    x = self.int8_linear1(x, <some-quant-params-that-are-being-learnt-for-linear1>)  # the decomposition of this op uses the "mixed_dtype" op introduced in this diff; note that int8_linear returns an int8 tensor
    x = self.int8_linear2(x, <some-quant-params-that-are-being-learnt-for-linear2>)
    return dequantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)

Note that in IntRepresentation, linear1 and linear2 use int8 compute, as opposed to FakeQuantizeRepresentation. If we can figure out the autograd story for IntRepresentation, that would be great. One of the challenges, and I have heard this from other vendors as well, is that int tensors don't store gradients.

If some of this is not super clear, I will draft something with more details. cc: @ezyang

@jerryzh168 (Contributor) commented Jun 14, 2023

@kimishpatel thanks for writing down the details. Did you get any feedback from hardware vendors on how to do backward on integer compute? It's not enough just to swap in the integer representation; we also need the computation to be differentiable in order to train. Maybe one option is to just use the fp version of the int8_linear operator, i.e. convert all inputs to fp32, like the following:

def forward(self, x):
    x = quantize_int8(x, <some-quant-params-that-are-being-learnt-for-linear1>)
    # although x, qweight1 are int8, we convert them to fp32 to make the computation differentiable
    # (qweight1/qweight2 are the quantized int8 weights of linear1/linear2)
    x = x.to(torch.float32)
    qweight1 = qweight1.to(torch.float32)
    x = mm(x, qweight1, ...)
    x = rescale(x)
    x = clamp(x, qmin, qmax).to(torch.int8)
    # although x, qweight2 are int8, we convert them to fp32 to make the computation differentiable
    x = x.to(torch.float32)
    qweight2 = qweight2.to(torch.float32)
    x = mm(x, qweight2, ...)
    x = rescale(x)
    x = clamp(x, qmin, qmax).to(torch.int8)
    return dequantize(x, <some-quant-params-that-are-being-learnt-for-linear2>)

But then the computation would happen in fp32 as well, so I'm not sure whether this can closely approximate mm_int8_int8_int32 or not.
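For comparison, the integer path being approximated could be written directly with the op from this PR, keeping the int8 inputs but accumulating into a wider dtype (a sketch only, reusing rescale/clamp/qmin/qmax from the snippet above and assuming the (op, output_dtype, *args) calling convention discussed in this thread):

# x and qweight1 stay int8; the matmul itself produces an int32 result
acc = out_dtype(torch.ops.aten.mm.default, torch.int32, x, qweight1)
x = clamp(rescale(acc), qmin, qmax).to(torch.int8)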

@kimishpatel (Contributor) replied

> feedback from hardware vendors how to do backward on integer compute?

They can do that today, and some of them do; the blocker for them is the inability of int tensors to store gradients. But if we could enable this natively for decompositions expressed via aten ops, that would add a lot of value.

> maybe one thing is to just use the fp version of the int8_linear operator, i.e. we convert all inputs to fp32

If we do this, it will defeat the purpose of capturing the compute with integer arithmetic. This is why it is somewhat non-trivial. Let me start a chat thread on this.

@jerryzh168 (Contributor) replied

> blockers for them is inability of int tensors to store gradient

Is this in PyTorch? I remember trying this once and it just says an int Tensor doesn't have gradients; is that what you are referring to?

@kimishpatel (Contributor) replied

Yes, that's what I'm referring to.
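For reference, a quick repro of the limitation being discussed (the exact error text may differ across PyTorch versions):

import torch

x = torch.zeros(4, dtype=torch.int8)
x.requires_grad_(True)  # RuntimeError: only floating point and complex tensors can require gradients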

@jerryzh168 (Contributor) commented Jun 14, 2023

I checked with ChatGPT as well, and considering all the above suggestions, maybe promote_output_dtype_and_apply or cast_input_to_output_dtype_and_apply? These are a bit long, though.

Here is the original Q&A:

Q: I'm defining a new higher order operator in PyTorch that takes an operator and an output dtype and the arguments of the operator, it casts all the input Tensors to use the output dtype and then call the operator, what can be the names for this higher order operator?

A:
There are several possible names for this higher-order operator in PyTorch. Here are a few suggestions:

dtype_cast_and_apply: This name clearly conveys the functionality of the operator, indicating that it casts the input tensors to the specified output dtype and then applies the given operator.

type_converted_operator: This name emphasizes the conversion of tensor types before applying the operator.

dtype_transform_operator: This name highlights the transformation of the dtype for the input tensors before applying the operator.

type_adaptive_operator: This name suggests that the operator adapts to the specified output dtype by converting the input tensors accordingly.

type_preserving_apply: This name conveys that the operator preserves the specified output dtype by casting the input tensors before applying the operator.

Ultimately, the choice of name depends on your preference and the context in which the operator will be used. It's important to select a name that accurately reflects the functionality and purpose of the operator to make it clear to other developers who may use or interact with your code.

@kimishpatel (Contributor) left a comment

Left some comments

@kimishpatel (Contributor) commented

Another candidate could be cast_inputs_to_dtype, which is basically what the op is doing.

@angelayi (Contributor, Author) commented Jun 14, 2023

After offline discussion, we decided on higher_precision as the name, but changed some of the semantics of this op. We will pause until we figure out the details. It will look something like this: the operator will be computed with the inputs at a generally higher precision (with the inputs promoted based on some promotion table), and then the results will be cast to a user-specified output_dtype kwarg. (See Jerry's upcoming doc for more specific details.)

We will also skip the autograd implementation for now because we still need to figure out plans for QAT.
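In eager terms, the described semantics would look roughly like the following (a sketch only; the function name, the use of torch.promote_types as the promotion table, and the kwarg spelling are assumptions rather than the final design):

import torch

def higher_precision(op, *args, output_dtype):
    # stand-in for "some promotion table": fold the inputs (and the requested
    # output dtype) into a common, wider compute dtype
    compute_dtype = output_dtype
    for a in args:
        compute_dtype = torch.promote_types(compute_dtype, a.dtype)
    # run the op at the promoted precision, then cast to the user-specified dtype
    result = op(*[a.to(compute_dtype) for a in args])
    return result.to(output_dtype)

# e.g. an int8 x int8 matmul computed with int32 accumulation (CPU):
x = torch.randint(-128, 127, (4, 8), dtype=torch.int8)
w = torch.randint(-128, 127, (8, 16), dtype=torch.int8)
out = higher_precision(torch.mm, x, w, output_dtype=torch.int32)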

@pytorch-bot (bot) commented Jul 17, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: argument command: invalid choice: 'rebae' (choose from 'merge', 'revert', 'rebase', 'label', 'drci')

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@angelayi (Contributor, Author) commented

@pytorchbot rebase

@pytorchmergebot (Collaborator) commented

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot (Collaborator) commented

Successfully rebased mixed_dtype onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout mixed_dtype && git pull --rebase)

@angelayi (Contributor, Author) commented

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

This PR updates submodules third_party/cutlass, third_party/pybind11, third_party/cudnn_frontend

If those updates are intentional, please add "submodule" keyword to PR title/description.

@angelayi (Contributor, Author) commented

@pytorchbot merge

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator) commented

Merge failed

Reason: 1 jobs have failed, first few of them are: periodic / linux-bionic-cuda11.8-py3.9-gcc7 / test (multigpu, 1, 1, linux.16xlarge.nvidia.gpu)

Details for Dev Infra team: raised by workflow job.

@angelayi (Contributor, Author) commented

@pytorchbot merge -f "test timed out 🥺"

@pytorchmergebot (Collaborator) commented

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

