|
| 1 | +#! /usr/bin/python |
| 2 | +# -*- coding: utf-8 -*- |
| 3 | +"""A file containing functions related to array manipulation.""" |
| 4 | + |
| 5 | +from tensorflow.python.eager import context |
| 6 | +from tensorflow.python.framework import constant_op |
| 7 | +from tensorflow.python.framework import dtypes |
| 8 | +from tensorflow.python.framework import ops |
| 9 | +from tensorflow.python.framework import tensor_shape |
| 10 | +from tensorflow.python.framework.constant_op import constant |
| 11 | +from tensorflow.python.framework.ops import convert_to_tensor |
| 12 | +from tensorflow.python.ops.array_ops import shape_internal |
| 13 | +from tensorflow.python.ops.gen_array_ops import fill |
| 14 | +from tensorflow.python.ops.gen_array_ops import reshape |
| 15 | + |
| 16 | +__all__ = ['alphas', 'alphas_like'] |
| 17 | + |
| 18 | + |
| 19 | +def alphas(shape, alpha_value, name=None): |
| 20 | + """Creates a tensor with all elements set to `alpha_value`. |
| 21 | + This operation returns a tensor of type `dtype` with shape `shape` and all |
| 22 | + elements set to alpha. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + shape: A list of integers, a tuple of integers, or a 1-D `Tensor` of type `int32`. |
| 27 | + The shape of the desired tensor |
| 28 | + alpha_value: `float32`, `float64`, `int8`, `uint8`, `int16`, `uint16`, int32`, `int64` |
| 29 | + The value used to fill the resulting `Tensor`. |
| 30 | + name: str |
| 31 | + A name for the operation (optional). |
| 32 | +
|
| 33 | + Returns |
| 34 | + ------- |
| 35 | + A `Tensor` with all elements set to alpha. |
| 36 | +
|
| 37 | + Examples |
| 38 | + -------- |
| 39 | + >>> tl.alphas([2, 3], tf.int32) # [[alpha, alpha, alpha], [alpha, alpha, alpha]] |
| 40 | + """ |
| 41 | + |
| 42 | + with ops.name_scope(name, "alphas", [shape]) as name: |
| 43 | + |
| 44 | + alpha_tensor = convert_to_tensor(alpha_value) |
| 45 | + alpha_dtype = dtypes.as_dtype(alpha_tensor.dtype).base_dtype |
| 46 | + |
| 47 | + if not isinstance(shape, ops.Tensor): |
| 48 | + try: |
| 49 | + shape = constant_op._tensor_shape_tensor_conversion_function(tensor_shape.TensorShape(shape)) |
| 50 | + except (TypeError, ValueError): |
| 51 | + shape = ops.convert_to_tensor(shape, dtype=dtypes.int32) |
| 52 | + |
| 53 | + if not shape._shape_tuple(): |
| 54 | + shape = reshape(shape, [-1]) # Ensure it's a vector |
| 55 | + |
| 56 | + try: |
| 57 | + output = constant(alpha_value, shape=shape, dtype=alpha_dtype, name=name) |
| 58 | + |
| 59 | + except (TypeError, ValueError): |
| 60 | + output = fill(shape, constant(alpha_value, dtype=alpha_dtype), name=name) |
| 61 | + |
| 62 | + if output.dtype.base_dtype != alpha_dtype: |
| 63 | + raise AssertionError("Dtypes do not corresponds: %s and %s" % (output.dtype.base_dtype, alpha_dtype)) |
| 64 | + |
| 65 | + return output |
| 66 | + |
| 67 | + |
| 68 | +def alphas_like(tensor, alpha_value, name=None, optimize=True): |
| 69 | + """Creates a tensor with all elements set to `alpha_value`. |
| 70 | + Given a single tensor (`tensor`), this operation returns a tensor of the same |
| 71 | + type and shape as `tensor` with all elements set to `alpha_value`. |
| 72 | +
|
| 73 | + Parameters |
| 74 | + ---------- |
| 75 | + tensor: tf.Tensor |
| 76 | + The Tensorflow Tensor that will be used as a template. |
| 77 | + alpha_value: `float32`, `float64`, `int8`, `uint8`, `int16`, `uint16`, int32`, `int64` |
| 78 | + The value used to fill the resulting `Tensor`. |
| 79 | + name: str |
| 80 | + A name for the operation (optional). |
| 81 | + optimize: bool |
| 82 | + if true, attempt to statically determine the shape of 'tensor' and encode it as a constant. |
| 83 | +
|
| 84 | + Returns |
| 85 | + ------- |
| 86 | + A `Tensor` with all elements set to `alpha_value`. |
| 87 | +
|
| 88 | + Examples |
| 89 | + -------- |
| 90 | + >>> tensor = tf.constant([[1, 2, 3], [4, 5, 6]]) |
| 91 | + >>> tl.alphas_like(tensor, 0.5) # [[0.5, 0.5, 0.5], [0.5, 0.5, 0.5]] |
| 92 | + """ |
| 93 | + |
| 94 | + with ops.name_scope(name, "alphas_like", [tensor]) as name: |
| 95 | + tensor = ops.convert_to_tensor(tensor, name="tensor") |
| 96 | + |
| 97 | + if context.in_eager_mode(): #and dtype is not None and dtype != tensor.dtype: |
| 98 | + ret = alphas(shape_internal(tensor, optimize=optimize), alpha_value=alpha_value, name=name) |
| 99 | + |
| 100 | + else: # if context.in_graph_mode(): |
| 101 | + |
| 102 | + # For now, variant types must be created via zeros_like; as we need to |
| 103 | + # pass the input variant object to the proper zeros callback. |
| 104 | + |
| 105 | + if (optimize and tensor.shape.is_fully_defined()): |
| 106 | + # We can produce a zeros tensor independent of the value of 'tensor', |
| 107 | + # since the shape is known statically. |
| 108 | + ret = alphas(tensor.shape, alpha_value=alpha_value, name=name) |
| 109 | + |
| 110 | + # elif dtype is not None and dtype != tensor.dtype and dtype != dtypes.variant: |
| 111 | + else: |
| 112 | + ret = alphas(shape_internal(tensor, optimize=optimize), alpha_value=alpha_value, name=name) |
| 113 | + |
| 114 | + ret.set_shape(tensor.get_shape()) |
| 115 | + |
| 116 | + return ret |
0 commit comments