Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aten/src/ATen/native/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Ten
// Fused op is marginally faster.
return at::addmm(*bias, input, weight.t());
}
if (input.dim() == 3 && bias->defined() && input.is_contiguous()) {
// Also hit the fused path for contiguous 3D input.
const auto input_sizes = input.sizes();
const auto result = at::addmm(*bias, input.view({input_sizes[0] * input_sizes[1], input_sizes[2]}), weight.t());
return result.view({input_sizes[0], input_sizes[1], result.size(1)});
}
auto output = at::matmul(input, weight.t());
if (bias->defined()) {
output.add_(*bias);
Expand Down