@@ -2117,38 +2117,53 @@ def upsample_bilinear(input, size=None, scale_factor=None):
21172117
21182118def grid_sample (input , grid , mode = 'bilinear' , padding_mode = 'zeros' ):
21192119 r"""Given an :attr:`input` and a flow-field :attr:`grid`, computes the
2120- `output` using input pixel locations from the grid.
2121-
2122- Uses bilinear interpolation to sample the input pixels.
2123- Currently, only spatial (4 dimensional) and volumetric (5 dimensional)
2124- inputs are supported.
2125-
2126- For each output location, :attr:`grid` has `x`, `y`
2127- input pixel locations which are used to compute output.
2128- In the case of 5D inputs, :attr:`grid` has `x`, `y`, `z` pixel locations.
2129-
2130- .. Note::
2131- To avoid confusion in notation, let's note that `x` corresponds to the `width` dimension `IW`,
2132- `y` corresponds to the height dimension `IH` and `z` corresponds to the `depth` dimension `ID`.
2133-
2134- :attr:`grid` has values in the range of `[-1, 1]`. This is because the
2135- pixel locations are normalized by the input height and width.
2136-
2137- For example, values: x: -1, y: -1 is the left-top pixel of the input, and
2138- values: x: 1, y: 1 is the right-bottom pixel of the input.
2139-
2140- If :attr:`grid` has values outside the range of `[-1, 1]`, those locations
2141- are handled as defined by `padding_mode`. Options are `zeros` or `border`,
2142- defining those locations to use 0 or image border values as contribution
2143- to the bilinear interpolation.
2144-
2145- .. Note:: This function is used in building Spatial Transformer Networks
2120+ ``output`` using :attr:`input` values and pixel locations from :attr:`grid`.
2121+
2122+ Currently, only spatial (4-D) and volumetric (5-D) :attr:`input` are
2123+ supported.
2124+
2125+ In the spatial (4-D) case, for :attr:`input` with shape
2126+ :math:`(N, C, H_\text{in}, W_\text{in})` and :attr:`grid` with shape
2127+ :math:`(N, H_\text{out}, W_\text{out}, 2)`, the output will have shape
2128+ :math:`(N, C, H_\text{out}, W_\text{out})`.
2129+
2130+ For each output location ``output[n, :, h, w]``, the size-2 vector
2131+ ``grid[n, h, w]`` specifies :attr:`input` pixel locations ``x`` and ``y``,
2132+ which are used to interpolate the output value ``output[n, :, h, w]``.
2133+ In the case of 5D inputs, ``grid[n, d, h, w]`` specifies the
2134+ ``x``, ``y``, ``z`` pixel locations for interpolating
2135+ ``output[n, :, d, h, w]``. :attr:`mode` argument specifies ``nearest`` or
2136+ ``bilinear`` interpolation method to sample the input pixels.
2137+
2138+ :attr:`grid` should have most values in the range of ``[-1, 1]``. This is
2139+ because the pixel locations are normalized by the :attr:`input` spatial
2140+ dimensions. For example, values ``x = -1, y = -1`` is the left-top pixel of
2141+ :attr:`input`, and values ``x = 1, y = 1`` is the right-bottom pixel of
2142+ :attr:`input`.
2143+
2144+ If :attr:`grid` has values outside the range of ``[-1, 1]``, those locations
2145+ are handled as defined by :attr:`padding_mode`. Options are
2146+
2147+ * ``padding_mode="zeros"``: use ``0`` for out-of-bound values,
2148+ * ``padding_mode="border"``: use border values for out-of-bound values,
2149+ * ``padding_mode="reflection"``: use values at locations reflected by
2150+ the border for out-of-bound values. For location far away from the
2151+ border, it will keep being reflected until becoming in bound, e.g.,
2152+ (normalized) pixel location ``x = -3.5`` reflects by ``-1`` and
2153+ becomes ``x' = 2.5``, then reflects by border ``1`` and becomes
2154+ ``x'' = -0.5``.
2155+
2156+ .. Note:: This function is often used in building Spatial Transformer Networks.
21462157
21472158 Args:
2148- input (Tensor): input batch (N x C x IH x IW) or (N x C x ID x IH x IW)
2149- grid (Tensor): flow-field of size (N x OH x OW x 2) or (N x OD x OH x OW x 3)
2159+ input (Tensor): input of shape :math:`(N, C, H_\text{in}, W_\text{in})` (4-D case)
2160+ or :math:`(N, C, D_\text{in}, H_\text{in}, W_\text{in})` (5-D case)
2161+ grid (Tensor): flow-field of shape :math:`(N, H_\text{out}, W_\text{out}, 2)` (4-D case)
2162+ or :math:`(N, D_\text{out}, H_\text{out}, W_\text{out}, 3)` (5-D case)
2163+ mode (str): interpolation mode to calculate output values
2164+ 'bilinear' | 'nearest'. Default: 'bilinear'
21502165 padding_mode (str): padding mode for outside grid values
2151- 'zeros' | 'border'. Default: 'zeros'
2166+ 'zeros' | 'border' | 'reflection' . Default: 'zeros'
21522167
21532168 Returns:
21542169 output (Tensor): output Tensor
0 commit comments