-
Notifications
You must be signed in to change notification settings - Fork 26.3k
ONNX Export LayerNorm #22265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ONNX Export LayerNorm #22265
Changes from all commits
04efde0
9c81b6f
7b88ae6
9d33fd6
7059042
21e6d43
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| ir_version: 4 | ||
| producer_name: "pytorch" | ||
| producer_version: "1.1" | ||
| graph { | ||
| node { | ||
| input: "input" | ||
| input: "weight" | ||
| input: "bias" | ||
| output: "3" | ||
| op_type: "ATen" | ||
| attribute { | ||
| name: "cudnn_enable" | ||
| i: 1 | ||
| type: INT | ||
| } | ||
| attribute { | ||
| name: "eps" | ||
| f: 1e-05 | ||
| type: FLOAT | ||
| } | ||
| attribute { | ||
| name: "normalized_shape" | ||
| ints: 10 | ||
| ints: 10 | ||
| type: INTS | ||
| } | ||
| attribute { | ||
| name: "operator" | ||
| s: "layer_norm" | ||
| type: STRING | ||
| } | ||
| } | ||
| name: "torch-jit-export" | ||
| initializer { | ||
| dims: 10 | ||
| dims: 10 | ||
| data_type: 1 | ||
| name: "bias" | ||
| raw_data: "\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000\000" | ||
| } | ||
| initializer { | ||
| dims: 10 | ||
| dims: 10 | ||
| data_type: 1 | ||
| name: "weight" | ||
| raw_data: "\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?\000\000\200?" | ||
| } | ||
| input { | ||
| name: "input" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 20 | ||
| } | ||
| dim { | ||
| dim_value: 5 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "weight" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| input { | ||
| name: "bias" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| output { | ||
| name: "3" | ||
| type { | ||
| tensor_type { | ||
| elem_type: 1 | ||
| shape { | ||
| dim { | ||
| dim_value: 20 | ||
| } | ||
| dim { | ||
| dim_value: 5 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| dim { | ||
| dim_value: 10 | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| } | ||
| opset_import { | ||
| version: 9 | ||
| } |
| Original file line number | Diff line number | Diff line change | ||
|---|---|---|---|---|
|
|
@@ -910,6 +910,33 @@ def batch_norm(g, input, weight, bias, running_mean, running_var, training, mome | |||
| return res | ||||
|
|
||||
|
|
||||
| @parse_args('v', 'is', 'v', 'v', 'f', 'i') | ||||
| def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): | ||||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a check here, if export mode is ONNX_ATEN_FALLBACK, we still export layer norm as ATen operator. Otherwise, the new export logic will significantly degrade the performance. Reference:
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I can add a function _set_operator_export_type (similar to _set_opset_version) in symbolic_helper.py, to save the operator_export_type and access it here
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sounds good to me. |
||||
| if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK: | ||||
| return g.op("ATen", input, weight, bias, normalized_shape_i=normalized_shape, | ||||
| eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm") | ||||
|
|
||||
| axes = [-i for i in range(len(normalized_shape), 0, -1)] | ||||
|
|
||||
| two_cst = g.op("Constant", value_t=torch.tensor(2.)) | ||||
| eps_cst = g.op("Constant", value_t=torch.tensor(eps)) | ||||
|
|
||||
| mean = g.op("ReduceMean", input, axes_i=axes) | ||||
| numerator = sub(g, input, mean) | ||||
| # variance = e((x - e(x))^2), and (x - e(x)) is the numerator in the layer_norm formula | ||||
| variance = g.op("ReduceMean", pow(g, numerator, two_cst), axes_i=axes) | ||||
| denominator = sqrt(g, add(g, variance, eps_cst)) | ||||
|
|
||||
| layer_norm = div(g, numerator, denominator) | ||||
|
|
||||
| if not (weight is None or weight.node().mustBeNone()): | ||||
| layer_norm = mul(g, layer_norm, weight) | ||||
| if not (bias is None or bias.node().mustBeNone()): | ||||
| layer_norm = add(g, layer_norm, bias) | ||||
|
|
||||
| return layer_norm | ||||
|
|
||||
|
|
||||
| @parse_args('v', 'v', 'v', 'v', 'v', 'i', 'f', 'f', 'i') | ||||
| def instance_norm(g, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled): | ||||
| input_sizes = input.type().sizes() | ||||
|
|
@@ -973,12 +1000,6 @@ def type_as(g, self, other): | |||
| return g.op("ATen", self, other, operator_s="type_as") | ||||
|
|
||||
|
|
||||
| @parse_args('v', 'is', 'v', 'v', 'f', 'i') | ||||
| def layer_norm(g, self, normalized_shape, weight, bias, eps, cudnn_enable): | ||||
| return g.op("ATen", self, weight, bias, normalized_shape_i=normalized_shape, | ||||
| eps_f=eps, cudnn_enable_i=cudnn_enable, operator_s="layer_norm") | ||||
|
|
||||
|
|
||||
| @parse_args('v', 'v', 'i', 'f') | ||||
| def cosine_similarity(g, x1, x2, dim, eps): | ||||
| return g.op("ATen", x1, x2, dim_i=dim, eps_f=eps, operator_s="cosine_similarity") | ||||
|
|
||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: one more empty line?