@@ -287,6 +287,55 @@ Tensor full_like(const Tensor& self, Scalar fill_value, const TensorOptions& opt
  return native::full(self.sizes(), fill_value, options);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ fill diagonal ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

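// Fills the main diagonal of `self` in place with `fill_value`. Inputs with
// more than two dimensions must have all dimensions of equal length; for tall
// 2-D matrices, `wrap` restarts the diagonal below the main block, matching
// NumPy's fill_diagonal.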
Tensor& fill_diagonal_(Tensor& self, Scalar fill_value, bool wrap) {
  int64_t nDims = self.dim();
  TORCH_CHECK(nDims >= 2, "dimensions must be larger than 1");

  int64_t height = self.size(0);
  int64_t width = self.size(1);

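  // For inputs with more than two dimensions, every dimension must have the
  // same length so that a single main diagonal is well defined.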
  if (nDims > 2) {
    int64_t dim1 = height;
    for (int64_t i = 1; i < nDims; i++) {
      if (self.size(i) != dim1) {
        AT_ERROR("all dimensions of input must be of equal length");
      }
    }
  }

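  // View the main diagonal as a 1-D strided tensor: its length is the smallest
  // dimension and its stride is the sum of self's strides, which steps from one
  // diagonal element to the next.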
  int64_t storage_offset = self.storage_offset();
  std::vector<int64_t> sizes;
  std::vector<int64_t> strides;
  int64_t size = std::min(height, width);

  int64_t stride = 0;
  for (int64_t i = 0; i < nDims; i++) {
    stride += self.stride(i);
  }
  strides.push_back(stride);
  sizes.push_back(size);

  auto main_diag = self.as_strided(sizes, strides, storage_offset);
  main_diag.fill_(fill_value);

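  // With wrap=true, a tall 2-D matrix has its diagonal restarted below the main
  // block: the wrapped section begins at row (width + 1) and reuses the same
  // diagonal stride, as in NumPy's fill_diagonal(wrap=True).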
  if (wrap && nDims == 2 && height > width + 1) {
    std::vector<int64_t> wrap_sizes;

    int64_t step = width + 1;
    int64_t wrap_size = ((self.numel() + step - 1) / step) - size;
    wrap_sizes.push_back(wrap_size);

    int64_t offset = self.stride(0) * (width + 1);

    auto wrap_diag = self.as_strided(wrap_sizes, strides, storage_offset + offset);
    wrap_diag.fill_(fill_value);
  }

  return self;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linspace ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Tensor linspace(
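For context, here is a minimal usage sketch of the behaviour this hunk implements. It assumes `fill_diagonal_` is also registered as a `Tensor` method in `native_functions.yaml` (not shown in this hunk, with `wrap` defaulting to false) and that the program links against ATen; the expected outputs follow NumPy's `np.fill_diagonal` semantics.

```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Tall 5x3 matrix: without wrap, only the top 3x3 block gets a diagonal.
  at::Tensor a = at::zeros({5, 3});
  a.fill_diagonal_(4);
  std::cout << a << "\n";
  //  4 0 0
  //  0 4 0
  //  0 0 4
  //  0 0 0
  //  0 0 0

  // With wrap=true the diagonal restarts after (width + 1) = 4 rows,
  // so element (4, 0) is filled as well.
  at::Tensor b = at::zeros({5, 3});
  b.fill_diagonal_(4, /*wrap=*/true);
  std::cout << b << "\n";
  //  4 0 0
  //  0 4 0
  //  0 0 4
  //  0 0 0
  //  4 0 0
  return 0;
}
```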