@@ -303,7 +303,10 @@ Tensor nested_from_padded_generic(
303303 std::move (new_buffer), sizes);
304304}
305305
306- Tensor NestedTensor_to_padded_tensor_generic (const Tensor& t, double padding) {
306+ Tensor NestedTensor_to_padded_tensor_generic (
307+ const Tensor& t,
308+ double padding,
309+ OptionalIntArrayRef output_size) {
307310 // TODO: skipped optimization for case of all 1x1 tensors
308311 auto & nt = *get_nested_tensor_impl (t);
309312 auto max_size = NestedTensor_get_max_size (nt);
@@ -356,7 +359,22 @@ Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding) {
356359 buffers.push_back (pad_tensor_to_shape (to_pad, max_size, padding));
357360 sizes_ptr += sizes_num_columns;
358361 }
359- return at::stack (buffers);
362+ auto ret_val = at::stack (buffers);
363+
364+ // Pad output tensor to output_size if provided
365+ if (output_size.has_value ()) {
366+ auto output_size_ = output_size.value ();
367+ TORCH_CHECK (
368+ (int64_t )output_size_.size () == ret_val.dim (),
369+ " Length of output_size does not match NestedTensor dims. Broadcasting is not supported." );
370+ for (int64_t i = 0 ; i < (int64_t )ret_val.dim (); i++) {
371+ TORCH_CHECK (
372+ output_size_[i] >= ret_val.size (i),
373+ " Value in output_size is less than NestedTensor padded size. Truncation is not supported." );
374+ }
375+ return pad_tensor_to_shape (ret_val, output_size_, padding);
376+ }
377+ return ret_val;
360378}
361379
362380Tensor NestedTensor_embedding (
@@ -385,5 +403,119 @@ Tensor NestedTensor_embedding(
385403 return at::detail::make_tensor<NestedTensorImpl>(
386404 result_buffer.reshape ({-1 }), std::move (new_sizes));
387405}
406+
407+ std::pair<NestedTensorImpl*, NestedTensorImpl*>
408+ get_elementwise_nested_tensor_impl (
409+ const Tensor& self,
410+ const Tensor& other,
411+ const std::string& op_name) {
412+ if (self.is_nested () && !(other.is_nested ())) {
413+ TORCH_CHECK (
414+ false ,
415+ " Expected both self and other to be nested, but got a nested self and non-nested other" );
416+ } else if (!(self.is_nested ()) && other.is_nested ()) {
417+ TORCH_CHECK (
418+ false ,
419+ " Expected both self and other to be nested, but got a non-nested self and nested other" );
420+ } else if (!(self.is_nested ()) || !(other.is_nested ())) {
421+ TORCH_CHECK (
422+ false ,
423+ " Expected both self and other to be nested, but got a non-nested self and non-nested other" );
424+ }
425+
426+ auto self_ptr = get_nested_tensor_impl (self);
427+ auto other_ptr = get_nested_tensor_impl (other);
428+
429+ TORCH_CHECK (
430+ self.dim () == other.dim (),
431+ op_name,
432+ " does not support broadcasting when given a NestedTensor" );
433+ TORCH_CHECK (
434+ at::equal (
435+ self_ptr->get_nested_size_tensor (),
436+ other_ptr->get_nested_size_tensor ()),
437+ op_name,
438+ " does not support broadcasting when given a NestedTensor" );
439+ TORCH_CHECK (
440+ nested_tensor_impl_is_contiguous (self_ptr) &&
441+ nested_tensor_impl_is_contiguous (other_ptr),
442+ op_name,
443+ " does not support non-contiguous NestedTensor inputs" );
444+ return std::make_pair (self_ptr, other_ptr);
445+ }
446+
447+ template <typename Func>
448+ Tensor NestedTensor_elementwise_Tensor (
449+ const Tensor& self,
450+ const Tensor& other,
451+ const std::string& op_name,
452+ Func f) {
453+ NestedTensorImpl* self_impl = nullptr ;
454+ NestedTensorImpl* other_impl = nullptr ;
455+ std::tie (self_impl, other_impl) =
456+ get_elementwise_nested_tensor_impl (self, other, op_name);
457+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (self_impl);
458+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (other_impl);
459+ const auto & nt_self = *self_impl;
460+ const auto & nt_other = *other_impl;
461+ const auto & self_sizes = nt_self.get_nested_size_tensor ();
462+ return wrap_buffer (
463+ f (nt_self.get_buffer ().reshape ({-1 }),
464+ nt_other.get_buffer ().reshape ({-1 })),
465+ self_sizes);
466+ }
467+
468+ Tensor NestedTensor_add_Tensor (
469+ const Tensor& self,
470+ const Tensor& other,
471+ const Scalar& alpha) {
472+ return NestedTensor_elementwise_Tensor (
473+ self, other, " add" , [alpha](const Tensor& b1, const Tensor& b2) {
474+ return at::add (b1, b2, alpha);
475+ });
476+ }
477+
478+ Tensor NestedTensor_mul_Tensor (const Tensor& self, const Tensor& other) {
479+ return NestedTensor_elementwise_Tensor (
480+ self, other, " mul" , [](const Tensor& b1, const Tensor& b2) {
481+ return at::mul (b1, b2);
482+ });
483+ }
484+
485+ template <typename Func>
486+ Tensor& NestedTensor_elementwise__Tensor (
487+ Tensor& self,
488+ const Tensor& other,
489+ const std::string& op_name,
490+ Func f) {
491+ NestedTensorImpl* self_impl = nullptr ;
492+ NestedTensorImpl* other_impl = nullptr ;
493+ std::tie (self_impl, other_impl) =
494+ get_elementwise_nested_tensor_impl (self, other, op_name);
495+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (self_impl);
496+ TORCH_INTERNAL_ASSERT_DEBUG_ONLY (other_impl);
497+ const auto & nt_self = *self_impl;
498+ const auto & nt_other = *other_impl;
499+ f (nt_self.get_buffer ().view ({-1 }), nt_other.get_buffer ().view ({-1 }));
500+ return self;
501+ }
502+
503+ Tensor& NestedTensor_add__Tensor (
504+ Tensor& self,
505+ const Tensor& other,
506+ const Scalar& alpha) {
507+ return NestedTensor_elementwise__Tensor (
508+ self, other, " add_" , [alpha](const Tensor& b1, const Tensor& b2) {
509+ return b1.add_ (b2, alpha);
510+ });
511+ }
512+
513+ Tensor& NestedTensor_mul__Tensor (Tensor& self, const Tensor& other) {
514+ return NestedTensor_elementwise__Tensor (
515+ self, other, " mul_" , [](const Tensor& b1, const Tensor& b2) {
516+ return b1.mul_ (b2);
517+ });
518+ }
519+
388520} // namespace native
389521} // namespace at
0 commit comments