@@ -1331,14 +1331,17 @@ void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
13311331 THTensor_ (copy )(r_ , t );
13321332 }
13331333
1334- if (mat -> stride [0 ] == 1 )
1334+ // n == 1 || lda >= max(1, m)
1335+ #define LDA_COND (M , N , LDA ) ((N) == 1 || (LDA) >= THMax(1, (M)))
1336+
1337+ if (mat -> stride [0 ] == 1 && LDA_COND (mat -> size [0 ], mat -> size [1 ], mat -> stride [1 ]))
13351338 {
13361339 THBlas_ (gemv )('n' , mat -> size [0 ], mat -> size [1 ],
13371340 alpha , THTensor_ (data )(mat ), mat -> stride [1 ],
13381341 THTensor_ (data )(vec ), vec -> stride [0 ],
13391342 beta , THTensor_ (data )(r_ ), r_ -> stride [0 ]);
13401343 }
1341- else if (mat -> stride [1 ] == 1 )
1344+ else if (mat -> stride [1 ] == 1 && LDA_COND ( mat -> size [ 1 ], mat -> size [ 0 ], mat -> stride [ 0 ]) )
13421345 {
13431346 THBlas_ (gemv )('t' , mat -> size [1 ], mat -> size [0 ],
13441347 alpha , THTensor_ (data )(mat ), mat -> stride [0 ],
@@ -1356,6 +1359,8 @@ void THTensor_(addmv)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
13561359
13571360 THTensor_ (free )(cmat );
13581361 }
1362+
1363+ #undef LDA_COND
13591364}
13601365
13611366void THTensor_ (match )(THTensor * r_ , THTensor * m1 , THTensor * m2 , real gain )
@@ -1434,15 +1439,18 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
14341439 }
14351440 }
14361441
1442+ // n == 1 || ldc >= max(1, m)
1443+ #define LDC_COND (M , N , LDC ) ((N) == 1 || (LDC) >= THMax(1, M))
1444+
14371445 /* r_ */
14381446 if (r_ -> stride [0 ] == 1 &&
1439- r_ -> stride [ 1 ] != 0 )
1447+ LDC_COND ( r_ -> size [ 0 ], r_ -> size [ 1 ], r_ -> stride [ 1 ]) )
14401448 {
14411449 transpose_r = 'n' ;
14421450 r__ = r_ ;
14431451 }
14441452 else if (r_ -> stride [1 ] == 1 &&
1445- r_ -> stride [ 0 ] != 0 )
1453+ LDC_COND ( r_ -> size [ 1 ], r_ -> size [ 0 ], r_ -> stride [ 0 ]) )
14461454 {
14471455 THTensor * swap = m2 ;
14481456 m2 = m1 ;
@@ -1453,22 +1461,30 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
14531461 else
14541462 {
14551463 transpose_r = 'n' ;
1456-
1464+ // make r__ FORTRAN contiguous
14571465 THTensor * transp_r_ = THTensor_ (newTranspose )(r_ , 0 , 1 );
14581466 r__ = THTensor_ (newClone )(transp_r_ );
14591467 THTensor_ (free )(transp_r_ );
14601468 THTensor_ (transpose )(r__ , NULL , 0 , 1 );
14611469 }
14621470
1471+ #undef LDC_COND
1472+
1473+ int64_t m = r__ -> size [(transpose_r == 'n' ? 0 : 1 )];
1474+ int64_t n = r__ -> size [(transpose_r == 'n' ? 1 : 0 )];
1475+ int64_t k = m1 -> size [(transpose_r == 'n' ? 1 : 0 )];
1476+ int64_t ldr__ = r__ -> stride [(transpose_r == 'n' ? 1 : 0 )];
1477+
14631478 /* m1 */
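1479+ /* Need ldm1_ >= max(1, (transpose_m1 == 'n' ? m : k)). NOTE(review): the stride checks below compare against k in the 'n' case and m in the 't' case, which looks reversed -- confirm against the BLAS gemm LDA rule. */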
14641480 if (m1 -> stride [(transpose_r == 'n' ? 0 : 1 )] == 1 &&
1465- m1 -> stride [(transpose_r == 'n' ? 1 : 0 )] != 0 )
1481+ m1 -> stride [(transpose_r == 'n' ? 1 : 0 )] >= THMax ( 1 , k ) )
14661482 {
14671483 transpose_m1 = 'n' ;
14681484 m1_ = m1 ;
14691485 }
14701486 else if (m1 -> stride [(transpose_r == 'n' ? 1 : 0 )] == 1 &&
1471- m1 -> stride [(transpose_r == 'n' ? 0 : 1 )] != 0 )
1487+ m1 -> stride [(transpose_r == 'n' ? 0 : 1 )] >= THMax ( 1 , m ) )
14721488 {
14731489 transpose_m1 = 't' ;
14741490 m1_ = m1 ;
@@ -1481,14 +1497,15 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
14811497 }
14821498
14831499 /* m2 */
1500+ /* Need ldm2_ >= max(1, (transpose_m2 == 't' ? n : k)) */
14841501 if (m2 -> stride [(transpose_r == 'n' ? 0 : 1 )] == 1 &&
1485- m2 -> stride [(transpose_r == 'n' ? 1 : 0 )] != 0 )
1502+ m2 -> stride [(transpose_r == 'n' ? 1 : 0 )] >= THMax ( 1 , k ) )
14861503 {
14871504 transpose_m2 = 'n' ;
14881505 m2_ = m2 ;
14891506 }
14901507 else if (m2 -> stride [(transpose_r == 'n' ? 1 : 0 )] == 1 &&
1491- m2 -> stride [(transpose_r == 'n' ? 0 : 1 )] != 0 )
1508+ m2 -> stride [(transpose_r == 'n' ? 0 : 1 )] >= THMax ( 1 , n ) )
14921509 {
14931510 transpose_m2 = 't' ;
14941511 m2_ = m2 ;
@@ -1500,21 +1517,24 @@ void THTensor_(addmm)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
15001517 free_m2 = 1 ;
15011518 }
15021519
1520+ int64_t ldm1_ = (transpose_m1 == 'n' ? m1_ -> stride [(transpose_r == 'n' ? 1 : 0 )] : m1_ -> stride [(transpose_r == 'n' ? 0 : 1 )]);
1521+ int64_t ldm2_ = (transpose_m2 == 'n' ? m2_ -> stride [(transpose_r == 'n' ? 1 : 0 )] : m2_ -> stride [(transpose_r == 'n' ? 0 : 1 )]);
1522+
15031523#pragma omp critical(blasgemm)
15041524 /* do the operation */
15051525 THBlas_ (gemm )(transpose_m1 ,
15061526 transpose_m2 ,
1507- r__ -> size [( transpose_r == 'n' ? 0 : 1 )] ,
1508- r__ -> size [( transpose_r == 'n' ? 1 : 0 )] ,
1509- m1_ -> size [( transpose_r == 'n' ? 1 : 0 )] ,
1527+ m ,
1528+ n ,
1529+ k ,
15101530 alpha ,
15111531 THTensor_ (data )(m1_ ),
1512- ( transpose_m1 == 'n' ? m1_ -> stride [( transpose_r == 'n' ? 1 : 0 )] : m1_ -> stride [( transpose_r == 'n' ? 0 : 1 )]) ,
1532+ ldm1_ ,
15131533 THTensor_ (data )(m2_ ),
1514- ( transpose_m2 == 'n' ? m2_ -> stride [( transpose_r == 'n' ? 1 : 0 )] : m2_ -> stride [( transpose_r == 'n' ? 0 : 1 )]) ,
1534+ ldm2_ ,
15151535 beta ,
15161536 THTensor_ (data )(r__ ),
1517- r__ -> stride [( transpose_r == 'n' ? 1 : 0 )] );
1537+ ldr__ );
15181538
15191539 /* free intermediate variables */
15201540 if (free_m1 )
@@ -1555,14 +1575,17 @@ void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
15551575 else if (beta != 1 )
15561576 THTensor_ (mul )(r_ , r_ , beta );
15571577
1558- if (r_ -> stride [0 ] == 1 )
1578+ // n == 1 || lda >= max(1, m)
1579+ #define LDA_COND (M , N , LDA ) ((N) == 1 || (LDA) >= THMax(1, (M)))
1580+
1581+ if (r_ -> stride [0 ] == 1 && LDA_COND (vec1 -> size [0 ], vec2 -> size [0 ], r_ -> stride [1 ]))
15591582 {
15601583 THBlas_ (ger )(vec1 -> size [0 ], vec2 -> size [0 ],
15611584 alpha , THTensor_ (data )(vec1 ), vec1 -> stride [0 ],
15621585 THTensor_ (data )(vec2 ), vec2 -> stride [0 ],
15631586 THTensor_ (data )(r_ ), r_ -> stride [1 ]);
15641587 }
1565- else if (r_ -> stride [1 ] == 1 )
1588+ else if (r_ -> stride [1 ] == 1 && LDA_COND ( vec2 -> size [ 0 ], vec1 -> size [ 0 ], r_ -> stride [ 0 ]) )
15661589 {
15671590 THBlas_ (ger )(vec2 -> size [0 ], vec1 -> size [0 ],
15681591 alpha , THTensor_ (data )(vec2 ), vec2 -> stride [0 ],
@@ -1580,6 +1603,8 @@ void THTensor_(addr)(THTensor *r_, real beta, THTensor *t, real alpha, THTensor
15801603
15811604 THTensor_ (freeCopyTo )(cr , r_ );
15821605 }
1606+
1607+ #undef LDA_COND
15831608}
15841609
15851610void THTensor_ (addbmm )(THTensor * result , real beta , THTensor * t , real alpha , THTensor * batch1 , THTensor * batch2 )
0 commit comments