@@ -74,6 +74,7 @@ struct C10_EXPORT Tuple : public List<IValue> {
7474using IntList = List<int64_t >;
7575using TensorList = List<at::Tensor>;
7676using DoubleList = List<double >;
77+ using BoolList = List<bool >;
7778using GenericList = List<IValue>;
7879
7980// IValue is the generic tagged union used by the interpreter to hold
@@ -88,9 +89,11 @@ using GenericList = List<IValue>;
8889 _ (Tensor) \
8990 _ (Double) \
9091 _ (Int) \
92+ _ (Bool) \
9193 _ (Tuple) \
9294 _ (IntList) \
9395 _ (DoubleList) \
96+ _ (BoolList) \
9497 _ (String) \
9598 _ (TensorList) \
9699 _ (Blob) \
@@ -224,8 +227,6 @@ struct CAFFE2_API IValue final {
224227 // allow you to pass literals (3, 4) without ambiguity
225228 IValue (int32_t i)
226229 : IValue(static_cast <int64_t >(i)) {}
227- IValue (bool b)
228- : IValue(static_cast <int64_t >(b)) {}
229230
230231 bool isInt () const { return Tag::Int == tag; }
231232
@@ -234,6 +235,17 @@ struct CAFFE2_API IValue final {
234235 return payload.as_int ;
235236 }
236237
238+ // Bool
239+ IValue (bool b)
240+ : tag(Tag::Bool), is_intrusive_ptr(false ) {
241+ payload.as_bool = b;
242+ }
243+ bool isBool () const { return Tag::Bool == tag; }
244+ bool toBool () const {
245+ AT_ASSERT (isBool ());
246+ return payload.as_bool ;
247+ }
248+
237249 // IntList
238250 IValue (c10::intrusive_ptr<IntList> v);
239251 IValue (std::vector<int64_t > v);
@@ -251,6 +263,7 @@ struct CAFFE2_API IValue final {
251263
252264 const std::vector<int64_t >& toIntListRef () const ;
253265 const std::vector<double >& toDoubleListRef () const ;
266+ const std::vector<bool >& toBoolListRef () const ;
254267 const std::vector<at::Tensor>& toTensorListRef () const ;
255268 const std::vector<IValue>& toGenericListRef () const ;
256269
@@ -280,6 +293,19 @@ struct CAFFE2_API IValue final {
280293 return toIntrusivePtr<DoubleList>();
281294 }
282295
296+ // BoolList
297+ IValue (c10::intrusive_ptr<BoolList> v);
298+ IValue (std::vector<bool > v);
299+ bool isBoolList () const { return Tag::BoolList == tag; }
300+ c10::intrusive_ptr<BoolList> toBoolList () && {
301+ AT_ASSERT (isBoolList ());
302+ return moveToIntrusivePtr<BoolList>();
303+ }
304+ c10::intrusive_ptr<BoolList> toBoolList () const & {
305+ AT_ASSERT (isBoolList ());
306+ return toIntrusivePtr<BoolList>();
307+ }
308+
283309 // TensorList
284310 IValue (c10::intrusive_ptr<TensorList> v);
285311 IValue (std::vector<at::Tensor> v);
@@ -323,15 +349,16 @@ struct CAFFE2_API IValue final {
323349 }
324350 }
325351 bool isScalar () {
326- return isDouble () || isInt ();
352+ return isDouble () || isInt () || isBool () ;
327353 }
328354 at::Scalar toScalar () const {
329355 if (isDouble ())
330356 return toDouble ();
331357 else if (isInt ())
332358 return toInt ();
333- else
334- throw std::runtime_error (" IValue is not a Scalar" );
359+ else if (isBool ())
360+ return int (toBool ());
361+ throw std::runtime_error (" IValue is not a Scalar" );
335362 }
336363
337364 // for debugging
@@ -396,6 +423,7 @@ struct CAFFE2_API IValue final {
396423 union {
397424 int64_t as_int;
398425 double as_double;
426+ bool as_bool;
399427 c10::intrusive_ptr_target* as_intrusive_ptr;
400428 World as_world;
401429 } payload;
@@ -419,15 +447,16 @@ DEFINE_TO(at::Tensor, toTensor)
419447DEFINE_TO (c10::intrusive_ptr<Tuple>, toTuple)
420448DEFINE_TO (double , toDouble)
421449DEFINE_TO (int64_t , toInt)
450+ DEFINE_TO (bool , toBool)
422451DEFINE_TO (c10::intrusive_ptr<DoubleList>, toDoubleList)
423452DEFINE_TO (c10::intrusive_ptr<IntList>, toIntList)
424453DEFINE_TO (c10::intrusive_ptr<TensorList>, toTensorList)
425454DEFINE_TO (c10::intrusive_ptr<GenericList>, toGenericList)
426455DEFINE_TO (c10::intrusive_ptr<ConstantString>, toString)
427456DEFINE_TO (at::Scalar, toScalar)
428- DEFINE_TO (bool , toInt)
429457DEFINE_TO (std::vector<int64_t >, toIntListRef)
430458DEFINE_TO (std::vector<double >, toDoubleListRef)
459+ DEFINE_TO (std::vector<bool >, toBoolListRef)
431460DEFINE_TO (std::vector<at::Tensor>, toTensorListRef)
432461DEFINE_TO (std::vector<IValue>, toGenericListRef)
433462DEFINE_TO (World, toWorld)
@@ -490,6 +519,13 @@ inline IValue::IValue(c10::intrusive_ptr<DoubleList> v)
490519inline IValue::IValue (std::vector<double > v)
491520: IValue(DoubleList::create(std::move(v))) {}
492521
522+ inline IValue::IValue (c10::intrusive_ptr<BoolList> v)
523+ : tag(Tag::BoolList), is_intrusive_ptr(true ) {
524+ payload.as_intrusive_ptr = v.release ();
525+ }
526+ inline IValue::IValue (std::vector<bool > v)
527+ : IValue(BoolList::create(std::move(v))) {}
528+
493529inline IValue::IValue (c10::intrusive_ptr<TensorList> v)
494530: tag(Tag::TensorList), is_intrusive_ptr(true ) {
495531 payload.as_intrusive_ptr = v.release ();
@@ -517,6 +553,10 @@ inline const std::vector<at::Tensor>& IValue::toTensorListRef() const {
517553 return toTensorList ()->elements ();
518554}
519555
556+ inline const std::vector<bool >& IValue::toBoolListRef () const {
557+ return toBoolList ()->elements ();
558+ }
559+
520560inline const std::vector<IValue>& IValue::toGenericListRef () const {
521561 return toGenericList ()->elements ();
522562}
0 commit comments