23 changes: 20 additions & 3 deletions torch/csrc/jit/symbolic_script.cpp
@@ -355,13 +355,30 @@ const std::vector<std::string> functions = {
out = mat.permute(dims)
return out

def AD_matmul_size(mat1, mat2,
# In the matmul backward case of [b, m, n] * [b, n, p] => [m, p],
# instead of computing [b, m, p] and then reducing to [m, p],
# which potentially uses a large intermediate of size b*m*p,
# we do [m, bn] * [bn, p] to avoid having the large
# intermediate, thus reducing max memory usage.
def AD_matmul_bw_special_fold(mat1, mat2):
mat1_transpose = AD_mat_transpose(mat1)
mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
mat2_fold = mat2.reshape(-1, mat2.size()[-1])
return mat1_fold.t().mm(mat2_fold)

def AD_matmul_bw_size(mat1, mat2,
out_size: List[int]):
dim1 = mat1.dim()
dim2 = mat2.dim()
dim_out = len(out_size)
if dim1 == 0 or dim2 == 0:
out = mat1 * mat2
elif dim_out == 2 and dim1 == dim2 and dim1 >= 3:
out = AD_matmul_bw_special_fold(mat1, mat2)
elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
mat2_unsqueeze = mat2.unsqueeze(-1)
out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
out = out.squeeze(-1)
elif dim1 + dim2 == dim_out:
if dim2 == 1:
target_dim2 = 0
@@ -380,8 +397,8 @@ const std::vector<std::string> functions = {
def backward(grad_output):
self_size = self.size()
other_size = other.size()
grad_self = AD_matmul_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
grad_other = AD_matmul_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
return grad_self, grad_other

return torch.matmul(self, other), backward
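
For context, a minimal standalone PyTorch sketch of the folding trick that the new AD_matmul_bw_special_fold helper relies on (not part of the patch; fold_mm, a, c, and the concrete sizes are illustrative). Folding the batch dimension into the contraction dimension of a single mm computes the batch-summed product directly, so the [b, m, p] batched intermediate is never materialized:

import torch

def fold_mm(mat1, mat2):
    # mat1: [b, m, n], mat2: [b, n, p]
    # Equivalent to torch.matmul(mat1, mat2).sum(0), but folds the batch
    # dimension into the contraction dimension so the [b, m, p] intermediate
    # is never allocated.
    mat1_t = mat1.transpose(-2, -1)                   # [b, n, m]
    mat1_fold = mat1_t.reshape(-1, mat1_t.size(-1))   # [b*n, m]
    mat2_fold = mat2.reshape(-1, mat2.size(-1))       # [b*n, p]
    return mat1_fold.t().mm(mat2_fold)                # [m, p]

b, m, n, p = 4, 3, 5, 2
a = torch.randn(b, m, n)
c = torch.randn(b, n, p)
reference = torch.matmul(a, c).sum(0)   # batched product, then reduce over b
assert torch.allclose(fold_mm(a, c), reference, atol=1e-5)

Note that .reshape may copy the transposed tensor when it is not contiguous; the trade-off is accepting that copy in exchange for never allocating the batched [b, m, p] result that would otherwise be reduced by _grad_sum_to_size.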