Skip to content

Commit d1ac1eb

Browse files
David Riazatifacebook-github-bot
authored andcommitted
Add bool type to IR (#11834)
Summary: This PR adds a bool type to `IValue` and puts it into place. * changes conds for `prim::If` and `prim::Loop` to use `bool` type * changes operators that take `bool`s to match their native ops * fixes ambiguous `aten` ops `aten::std` and `aten::var` * fixes tests in `test_jit.py TestJitGenerated` ``` 'test_std_dim', 'test_std_dim_1d', 'test_std_dim_1d_neg0', 'test_std_dim_neg0', 'test_var_dim', 'test_var_dim_1d', 'test_var_dim_1d_neg0', 'test_var_dim_neg0' ``` * adds `prim::BoolToTensor` and `prim::TensorToBool` apaszke zdevito Pull Request resolved: #11834 Differential Revision: D9928570 Pulled By: driazati fbshipit-source-id: 373c53df2f1a8ffa9e33d9a517002fbeef25f3eb
1 parent c029c83 commit d1ac1eb

File tree

58 files changed

+440
-227
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+440
-227
lines changed

aten/src/ATen/core/ivalue.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,11 @@
66
_(Tensor) \
77
_(Double) \
88
_(Int) \
9+
_(Bool) \
910
_(Tuple) \
1011
_(IntList) \
1112
_(DoubleList) \
13+
_(BoolList) \
1214
_(String) \
1315
_(TensorList) \
1416
_(Blob) \

aten/src/ATen/core/ivalue.h

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ struct C10_EXPORT Tuple : public List<IValue> {
7474
using IntList = List<int64_t>;
7575
using TensorList = List<at::Tensor>;
7676
using DoubleList = List<double>;
77+
using BoolList = List<bool>;
7778
using GenericList = List<IValue>;
7879

7980
// IValue is the generic tagged union used by the interpreter to hold
@@ -88,9 +89,11 @@ using GenericList = List<IValue>;
8889
_(Tensor) \
8990
_(Double) \
9091
_(Int) \
92+
_(Bool) \
9193
_(Tuple) \
9294
_(IntList) \
9395
_(DoubleList) \
96+
_(BoolList) \
9497
_(String) \
9598
_(TensorList) \
9699
_(Blob) \
@@ -224,8 +227,6 @@ struct CAFFE2_API IValue final {
224227
// allow you to pass literals (3, 4) without ambiguity
225228
IValue(int32_t i)
226229
: IValue(static_cast<int64_t>(i)) {}
227-
IValue(bool b)
228-
: IValue(static_cast<int64_t>(b)) {}
229230

230231
bool isInt() const { return Tag::Int == tag; }
231232

@@ -234,6 +235,17 @@ struct CAFFE2_API IValue final {
234235
return payload.as_int;
235236
}
236237

238+
// Bool
239+
IValue(bool b)
240+
: tag(Tag::Bool), is_intrusive_ptr(false) {
241+
payload.as_bool = b;
242+
}
243+
bool isBool() const { return Tag::Bool == tag; }
244+
bool toBool() const {
245+
AT_ASSERT(isBool());
246+
return payload.as_bool;
247+
}
248+
237249
// IntList
238250
IValue(c10::intrusive_ptr<IntList> v);
239251
IValue(std::vector<int64_t> v);
@@ -251,6 +263,7 @@ struct CAFFE2_API IValue final {
251263

252264
const std::vector<int64_t>& toIntListRef() const;
253265
const std::vector<double>& toDoubleListRef() const;
266+
const std::vector<bool>& toBoolListRef() const;
254267
const std::vector<at::Tensor>& toTensorListRef() const;
255268
const std::vector<IValue>& toGenericListRef() const;
256269

@@ -280,6 +293,19 @@ struct CAFFE2_API IValue final {
280293
return toIntrusivePtr<DoubleList>();
281294
}
282295

296+
// BoolList
297+
IValue(c10::intrusive_ptr<BoolList> v);
298+
IValue(std::vector<bool> v);
299+
bool isBoolList() const { return Tag::BoolList == tag; }
300+
c10::intrusive_ptr<BoolList> toBoolList() && {
301+
AT_ASSERT(isBoolList());
302+
return moveToIntrusivePtr<BoolList>();
303+
}
304+
c10::intrusive_ptr<BoolList> toBoolList() const & {
305+
AT_ASSERT(isBoolList());
306+
return toIntrusivePtr<BoolList>();
307+
}
308+
283309
//TensorList
284310
IValue(c10::intrusive_ptr<TensorList> v);
285311
IValue(std::vector<at::Tensor> v);
@@ -323,15 +349,16 @@ struct CAFFE2_API IValue final {
323349
}
324350
}
325351
bool isScalar() {
326-
return isDouble() || isInt();
352+
return isDouble() || isInt() || isBool();
327353
}
328354
at::Scalar toScalar() const {
329355
if(isDouble())
330356
return toDouble();
331357
else if(isInt())
332358
return toInt();
333-
else
334-
throw std::runtime_error("IValue is not a Scalar");
359+
else if (isBool())
360+
return int(toBool());
361+
throw std::runtime_error("IValue is not a Scalar");
335362
}
336363

337364
// for debugging
@@ -396,6 +423,7 @@ struct CAFFE2_API IValue final {
396423
union {
397424
int64_t as_int;
398425
double as_double;
426+
bool as_bool;
399427
c10::intrusive_ptr_target* as_intrusive_ptr;
400428
World as_world;
401429
} payload;
@@ -419,15 +447,16 @@ DEFINE_TO(at::Tensor, toTensor)
419447
DEFINE_TO(c10::intrusive_ptr<Tuple>, toTuple)
420448
DEFINE_TO(double, toDouble)
421449
DEFINE_TO(int64_t, toInt)
450+
DEFINE_TO(bool, toBool)
422451
DEFINE_TO(c10::intrusive_ptr<DoubleList>, toDoubleList)
423452
DEFINE_TO(c10::intrusive_ptr<IntList>, toIntList)
424453
DEFINE_TO(c10::intrusive_ptr<TensorList>, toTensorList)
425454
DEFINE_TO(c10::intrusive_ptr<GenericList>, toGenericList)
426455
DEFINE_TO(c10::intrusive_ptr<ConstantString>, toString)
427456
DEFINE_TO(at::Scalar, toScalar)
428-
DEFINE_TO(bool, toInt)
429457
DEFINE_TO(std::vector<int64_t>, toIntListRef)
430458
DEFINE_TO(std::vector<double>, toDoubleListRef)
459+
DEFINE_TO(std::vector<bool>, toBoolListRef)
431460
DEFINE_TO(std::vector<at::Tensor>, toTensorListRef)
432461
DEFINE_TO(std::vector<IValue>, toGenericListRef)
433462
DEFINE_TO(World, toWorld)
@@ -490,6 +519,13 @@ inline IValue::IValue(c10::intrusive_ptr<DoubleList> v)
490519
inline IValue::IValue(std::vector<double> v)
491520
: IValue(DoubleList::create(std::move(v))) {}
492521

522+
inline IValue::IValue(c10::intrusive_ptr<BoolList> v)
523+
: tag(Tag::BoolList), is_intrusive_ptr(true) {
524+
payload.as_intrusive_ptr = v.release();
525+
}
526+
inline IValue::IValue(std::vector<bool> v)
527+
: IValue(BoolList::create(std::move(v))) {}
528+
493529
inline IValue::IValue(c10::intrusive_ptr<TensorList> v)
494530
: tag(Tag::TensorList), is_intrusive_ptr(true) {
495531
payload.as_intrusive_ptr = v.release();
@@ -517,6 +553,10 @@ inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
517553
return toTensorList()->elements();
518554
}
519555

556+
inline const std::vector<bool>& IValue::toBoolListRef() const {
557+
return toBoolList()->elements();
558+
}
559+
520560
inline const std::vector<IValue>& IValue::toGenericListRef() const {
521561
return toGenericList()->elements();
522562
}

test/expect/TestBatched.test_for.expect

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ graph(%x.1_data : Dynamic
55
%y_mask : Dynamic
66
%y_dims : Dynamic) {
77
%6 : int = prim::Constant[value=10]()
8-
%7 : int = prim::Constant[value=1]()
8+
%7 : bool = prim::Constant[value=1]()
99
%x : Dynamic, %9 : Dynamic, %10 : Dynamic = prim::Loop(%6, %7, %x.1_data, %x.1_mask, %x.1_dims)
1010
block0(%loop_num : int, %5_data : Dynamic, %5_mask : Dynamic, %5_dims : Dynamic) {
1111
%15 : int = prim::Constant[value=1]()
@@ -14,7 +14,7 @@ graph(%x.1_data : Dynamic
1414
%data.1 : Dynamic = aten::add(%5_data, %y_data, %alpha)
1515
%mask : Dynamic = aten::mul(%5_mask, %y_mask)
1616
%dims : Dynamic = aten::__or__(%5_dims, %y_dims)
17-
%21 : int = prim::Constant[value=1]()
17+
%21 : bool = prim::Constant[value=1]()
1818
%data : Dynamic = aten::where(%mask, %data.1, %5_data)
1919
-> (%21, %data, %mask, %dims)
2020
}

test/expect/TestBatched.test_if_else.expect

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic
77
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
88
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
99
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
10-
%9 : int = prim::TensorToNum(%6)
10+
%9 : bool = prim::TensorToBool(%6)
1111
%10 : int = prim::Constant[value=1]()
1212
%11 : Long() = prim::NumToTensor(%10)
1313
%alpha.1 : float = prim::TensorToNum(%11)
@@ -24,17 +24,17 @@ graph(%a.1_data : Dynamic
2424
%23 : Dynamic = aten::type_as(%7, %6)
2525
%cond_mask.1 : Dynamic = aten::mul(%6, %23)
2626
%25 : int = aten::dim(%cond_mask.1)
27-
%26 : int = aten::eq(%25, %22)
27+
%26 : bool = aten::eq(%25, %22)
2828
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%26)
2929
block0() {
3030
%30 : int = aten::dim(%data.1)
3131
%31 : int = aten::sub(%30, %22)
32-
%32 : int = prim::Constant[value=1]()
32+
%32 : bool = prim::Constant[value=1]()
3333
%data.3 : Dynamic = prim::Loop(%31, %32, %cond_mask.1)
3434
block0(%_ : int, %35 : Dynamic) {
3535
%36 : int = aten::dim(%35)
3636
%data.2 : Dynamic = aten::unsqueeze(%35, %36)
37-
%38 : int = prim::Constant[value=1]()
37+
%38 : bool = prim::Constant[value=1]()
3838
-> (%38, %data.2)
3939
}
4040
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)

test/expect/TestBatched.test_if_else_with_scalar.expect

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic
88
%7 : Float() = prim::NumToTensor(%6)
99
%other : float = prim::TensorToNum(%7)
1010
%9 : Dynamic = aten::gt(%a.1_data, %other)
11-
%10 : int = prim::TensorToNum(%9)
11+
%10 : bool = prim::TensorToBool(%9)
1212
%11 : int = prim::Constant[value=1]()
1313
%12 : Long() = prim::NumToTensor(%11)
1414
%alpha.1 : float = prim::TensorToNum(%12)
@@ -25,17 +25,17 @@ graph(%a.1_data : Dynamic
2525
%24 : Dynamic = aten::type_as(%a.1_mask, %9)
2626
%cond_mask.1 : Dynamic = aten::mul(%9, %24)
2727
%26 : int = aten::dim(%cond_mask.1)
28-
%27 : int = aten::eq(%26, %23)
28+
%27 : bool = aten::eq(%26, %23)
2929
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%27)
3030
block0() {
3131
%31 : int = aten::dim(%data.1)
3232
%32 : int = aten::sub(%31, %23)
33-
%33 : int = prim::Constant[value=1]()
33+
%33 : bool = prim::Constant[value=1]()
3434
%data.3 : Dynamic = prim::Loop(%32, %33, %cond_mask.1)
3535
block0(%_ : int, %36 : Dynamic) {
3636
%37 : int = aten::dim(%36)
3737
%data.2 : Dynamic = aten::unsqueeze(%36, %37)
38-
%39 : int = prim::Constant[value=1]()
38+
%39 : bool = prim::Constant[value=1]()
3939
-> (%39, %data.2)
4040
}
4141
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)

test/expect/TestBatched.test_if_noelse.expect

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ graph(%a.1_data : Dynamic
77
%6 : Dynamic = aten::gt(%a.1_data, %b_data)
88
%7 : Dynamic = aten::mul(%a.1_mask, %b_mask)
99
%8 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
10-
%9 : int = prim::TensorToNum(%6)
10+
%9 : bool = prim::TensorToBool(%6)
1111
%10 : int = prim::Constant[value=1]()
1212
%11 : Long() = prim::NumToTensor(%10)
1313
%alpha : float = prim::TensorToNum(%11)
@@ -18,17 +18,17 @@ graph(%a.1_data : Dynamic
1818
%17 : Dynamic = aten::type_as(%7, %6)
1919
%cond_mask.1 : Dynamic = aten::mul(%6, %17)
2020
%19 : int = aten::dim(%cond_mask.1)
21-
%20 : int = aten::eq(%19, %16)
21+
%20 : bool = aten::eq(%19, %16)
2222
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%20)
2323
block0() {
2424
%24 : int = aten::dim(%data.1)
2525
%25 : int = aten::sub(%24, %16)
26-
%26 : int = prim::Constant[value=1]()
26+
%26 : bool = prim::Constant[value=1]()
2727
%data.3 : Dynamic = prim::Loop(%25, %26, %cond_mask.1)
2828
block0(%_ : int, %29 : Dynamic) {
2929
%30 : int = aten::dim(%29)
3030
%data.2 : Dynamic = aten::unsqueeze(%29, %30)
31-
%32 : int = prim::Constant[value=1]()
31+
%32 : bool = prim::Constant[value=1]()
3232
-> (%32, %data.2)
3333
}
3434
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)

test/expect/TestBatched.test_if_noelse_with_scalar.expect

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ graph(%a.1_data : Dynamic
88
%7 : Float() = prim::NumToTensor(%6)
99
%other : float = prim::TensorToNum(%7)
1010
%9 : Dynamic = aten::gt(%a.1_data, %other)
11-
%10 : int = prim::TensorToNum(%9)
11+
%10 : bool = prim::TensorToBool(%9)
1212
%11 : int = prim::Constant[value=1]()
1313
%12 : Long() = prim::NumToTensor(%11)
1414
%alpha : float = prim::TensorToNum(%12)
@@ -19,17 +19,17 @@ graph(%a.1_data : Dynamic
1919
%18 : Dynamic = aten::type_as(%a.1_mask, %9)
2020
%cond_mask.1 : Dynamic = aten::mul(%9, %18)
2121
%20 : int = aten::dim(%cond_mask.1)
22-
%21 : int = aten::eq(%20, %17)
22+
%21 : bool = aten::eq(%20, %17)
2323
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%21)
2424
block0() {
2525
%25 : int = aten::dim(%data.1)
2626
%26 : int = aten::sub(%25, %17)
27-
%27 : int = prim::Constant[value=1]()
27+
%27 : bool = prim::Constant[value=1]()
2828
%data.3 : Dynamic = prim::Loop(%26, %27, %cond_mask.1)
2929
block0(%_ : int, %30 : Dynamic) {
3030
%31 : int = aten::dim(%30)
3131
%data.2 : Dynamic = aten::unsqueeze(%30, %31)
32-
%33 : int = prim::Constant[value=1]()
32+
%33 : bool = prim::Constant[value=1]()
3333
-> (%33, %data.2)
3434
}
3535
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)

test/expect/TestBatched.test_while.expect

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@ graph(%a.1_data : Dynamic
88
%7 : Dynamic = aten::gt(%a.1_data, %b_data)
99
%8 : Dynamic = aten::mul(%a.1_mask, %b_mask)
1010
%9 : Dynamic = aten::__or__(%a.1_dims, %b_dims)
11-
%10 : int = prim::TensorToNum(%7)
11+
%10 : bool = prim::TensorToBool(%7)
1212
%11 : int = prim::Constant[value=0]()
1313
%12 : Dynamic = aten::mul(%7, %8)
1414
%13 : Dynamic = aten::sum(%12)
1515
%14 : Dynamic = aten::gt(%13, %11)
16-
%15 : int = prim::TensorToNum(%14)
16+
%15 : bool = prim::TensorToBool(%14)
1717
%16 : Dynamic, %17 : Dynamic, %18 : Dynamic, %a : Dynamic, %20 : Dynamic, %21 : Dynamic = prim::Loop(%6, %15, %7, %8, %9, %a.1_data, %a.1_mask, %a.1_dims)
1818
block0(%loop_num : int, %cond_data.2 : Dynamic, %cond_mask.3 : Dynamic, %cond_dims : Dynamic, %6_data : Dynamic, %6_mask : Dynamic, %6_dims : Dynamic) {
1919
%29 : int = prim::Constant[value=1]()
@@ -25,22 +25,22 @@ graph(%a.1_data : Dynamic
2525
%35 : Dynamic = aten::gt(%data.1, %b_data)
2626
%36 : Dynamic = aten::mul(%mask, %b_mask)
2727
%37 : Dynamic = aten::__or__(%dims, %b_dims)
28-
%38 : int = prim::TensorToNum(%35)
28+
%38 : bool = prim::TensorToBool(%35)
2929
%39 : int = prim::Constant[value=1]()
3030
%40 : Dynamic = aten::type_as(%cond_mask.3, %cond_data.2)
3131
%cond_mask.1 : Dynamic = aten::mul(%cond_data.2, %40)
3232
%42 : int = aten::dim(%cond_mask.1)
33-
%43 : int = aten::eq(%42, %39)
33+
%43 : bool = aten::eq(%42, %39)
3434
%cond_data : Dynamic, %cond_mask : Dynamic, %data : Dynamic = prim::If(%43)
3535
block0() {
3636
%47 : int = aten::dim(%data.1)
3737
%48 : int = aten::sub(%47, %39)
38-
%49 : int = prim::Constant[value=1]()
38+
%49 : bool = prim::Constant[value=1]()
3939
%data.3 : Dynamic = prim::Loop(%48, %49, %cond_mask.1)
4040
block0(%_ : int, %52 : Dynamic) {
4141
%53 : int = aten::dim(%52)
4242
%data.2 : Dynamic = aten::unsqueeze(%52, %53)
43-
%55 : int = prim::Constant[value=1]()
43+
%55 : bool = prim::Constant[value=1]()
4444
-> (%55, %data.2)
4545
}
4646
%cond_data.1 : Dynamic = aten::expand_as(%data.3, %data.1)
@@ -57,7 +57,7 @@ graph(%a.1_data : Dynamic
5757
%62 : Dynamic = aten::mul(%35, %36)
5858
%63 : Dynamic = aten::sum(%62)
5959
%64 : Dynamic = aten::gt(%63, %61)
60-
%65 : int = prim::TensorToNum(%64)
60+
%65 : bool = prim::TensorToBool(%64)
6161
-> (%65, %35, %36, %37, %res_data, %res_mask, %res_dims)
6262
}
6363
return (%a, %20, %21);

test/expect/TestJit.test_batchnorm.expect

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ graph(%0 : Double(2, 2, 2, 2)
44
%3 : Double(2)
55
%4 : Double(2)
66
%5 : Long()) {
7-
%6 : int = prim::Constant[value=1](), scope: BatchNorm2d
7+
%6 : bool = prim::Constant[value=1](), scope: BatchNorm2d
88
%7 : float = prim::Constant[value=0.1](), scope: BatchNorm2d
99
%8 : float = prim::Constant[value=1e-05](), scope: BatchNorm2d
10-
%9 : int = prim::Constant[value=1](), scope: BatchNorm2d
10+
%9 : bool = prim::Constant[value=1](), scope: BatchNorm2d
1111
%10 : Double(2, 2, 2, 2) = aten::batch_norm(%0, %1, %2, %3, %4, %6, %7, %8, %9), scope: BatchNorm2d
1212
return (%10);
1313
}

test/expect/TestJit.test_constant_prop_if_constant.expect

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
graph(%a : Dynamic
22
%b : Dynamic) {
33
%c2.1 : int = prim::Constant[value=1]()
4-
%3 : int = prim::TensorToNum(%a)
4+
%3 : bool = prim::TensorToBool(%a)
55
%c0.4 : int, %c1 : int = prim::If(%3)
66
block0() {
7-
%6 : int = prim::TensorToNum(%b)
7+
%6 : bool = prim::TensorToBool(%b)
88
%c0.3 : int = prim::If(%6)
99
block0() {
1010
%8 : int = prim::Constant[value=2]()

0 commit comments

Comments
 (0)