forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathspecial_math_ops.py
More file actions
162 lines (128 loc) · 5.17 KB
/
Copy pathspecial_math_ops.py
File metadata and controls
162 lines (128 loc) · 5.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Arithmetic Operations that don't fit into math_ops due to dependencies.
To avoid circular dependencies, some math_ops should go here. Documentation
callouts, e.g. "@@my_op" should go in math_ops. To the user, these are just
normal math_ops.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import re
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
def lbeta(x, name='lbeta'):
  r"""Computes `ln(|Beta(x)|)`, reducing along the last dimension.

  For a one-dimensional `z = [z_0,...,z_{K-1}]` the multivariate beta function
  is

  ```Beta(z) = \prod_j Gamma(z_j) / Gamma(\sum_j z_j)```

  For an `n + 1` dimensional `x` of shape `[N1, ..., Nn, K]` the last axis
  plays the role of `z`, i.e.
  `lbeta(x)[i1, ..., in] = Log(|Beta(x[i1, ..., in, :])|)`.

  With `z = [u, v]` this reduces to the traditional bivariate beta function
  `Beta(z) = int_0^1 t^{u-1} (1 - t)^{v-1} dt`.

  Args:
    x: A rank `n + 1` `Tensor` with type `float`, or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of `|Beta(x)|` reducing along the last dimension.

  Raises:
    ValueError: If `x` is empty with rank one or less.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor(x, name='x')
    # Tie the rank >= 1 assertion to `x` so every later use depends on it.
    rank_check = check_ops.assert_rank_at_least(x, 1)
    x = control_flow_ops.with_dependencies([rank_check], x)
    # Dynamic emptiness flag; only consumed when the static size is unknown.
    is_empty = math_ops.equal(0, array_ops.size(x))

    def nonempty_lbeta():
      # sum_j lgamma(x_j) - lgamma(sum_j x_j), both sums over the last axis.
      sum_of_log_gammas = math_ops.reduce_sum(
          math_ops.lgamma(x), reduction_indices=[-1])
      total = math_ops.reduce_sum(x, reduction_indices=[-1])
      return sum_of_log_gammas - math_ops.lgamma(total)

    def empty_lbeta():
      # An empty `x` yields a result with one dimension fewer, which is only
      # possible when the rank is at least 2; the leading axis is squeezed.
      rank2_check = check_ops.assert_rank_at_least(x, 2)
      with ops.control_dependencies([rank2_check]):
        return array_ops.squeeze(x, squeeze_dims=[0])

    # Prefer a graph-construction-time decision when the size is static.
    static_size = x.get_shape().num_elements()
    if static_size is None:
      return control_flow_ops.cond(is_empty, empty_lbeta, nonempty_lbeta)
    if static_size > 0:
      return nonempty_lbeta()
    return empty_lbeta()
def einsum(axes, *inputs):
  """A generalized contraction between tensors of arbitrary dimension.

  Like numpy.einsum, restricted to the explicit form
  `"<in_0>,<in_1>,...-><out>"` in which every subscript is a lowercase
  letter. Each input is transposed so that its axes are in alphabetical
  order, reshaped so that all inputs share one broadcastable layout (one axis
  per distinct subscript), multiplied elementwise, and finally summed over
  every subscript that does not appear in the output. The result is then
  transposed into the axis order requested after `->`.

  Repeating a subscript within a single input (e.g. `ii->i`) is not
  supported.

  Args:
    axes: Subscript specification string, e.g. `'ij,jk->ik'`.
    *inputs: The tensors to contract, one per comma-separated subscript group.
      Each must report a static rank through `get_shape()` equal to the
      length of its subscript group.

  Returns:
    The contracted tensor, with axes ordered as given on the right-hand side
    of `->`.

  Raises:
    AssertionError: If `axes` is malformed, the number of inputs does not
      match the specification, an output subscript is unknown or repeated, an
      input has the wrong rank, or statically known dimensions disagree on a
      shared subscript.
  """
  # Validation raises explicitly (rather than using `assert`) so it is not
  # stripped under `python -O`; AssertionError is kept for backward
  # compatibility with existing callers.
  # The pattern is anchored so trailing characters are rejected instead of
  # being silently dropped from the output specification.
  match = re.match(r'([a-z,]+)->([a-z]+)$', axes)
  if not match:
    raise AssertionError("Indices have incorrect format: %s" % axes)

  inputs = list(inputs)
  idx_in = match.group(1).split(',')
  idx_out = match.group(2)
  idx_all = set(''.join(idx_in))
  ordered_idx = sorted(idx_all)

  if len(idx_in) != len(inputs):
    raise AssertionError(
        "Expected %d inputs but only got %d" % (len(idx_in), len(inputs)))

  # These checks only depend on the specification string, so run them before
  # any graph ops are created.
  missing_idx = set(idx_out).difference(idx_all)
  if missing_idx:
    raise AssertionError("Unknown output axes: %s" % missing_idx)
  if len(set(idx_out)) != len(idx_out):
    raise AssertionError("Repeated output axes: %s" % idx_out)

  # Transpose inputs so axes are in alphabetical order.
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    ndims = input_.get_shape().ndims
    if ndims != len(axes_):
      # `%s` for the actual rank: it is None when the rank is unknown.
      raise AssertionError(
          "Input %d with axes %s has incorrect"
          " number of dimensions (expected %d, got %s)" % (
              i, axes_, len(axes_), ndims))
    sorted_idx = sorted(axes_)
    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = ''.join(sorted_idx)

  # Unknown static dimensions become -1 for the reshape below. Compare with
  # None explicitly so that a genuine size-0 dimension is preserved.
  # NOTE(review): an input with more than one unknown dimension yields several
  # -1 entries in its reshape target -- confirm reshape accepts that.
  shapes = [[-1 if dim is None else dim
             for dim in tensor.get_shape().as_list()]
            for tensor in inputs]

  # Build the common layout and validate shapes for broadcasting. Because
  # every input's axes are alphabetical, position j of each shape corresponds
  # to the j-th subscript overall once size-1 axes are inserted for the
  # subscripts that input lacks.
  reduction_idx = []
  for j, ax in enumerate(ordered_idx):
    dims = set()
    for shape, idx in zip(shapes, idx_in):
      if ax not in idx:
        shape.insert(j, 1)
      elif isinstance(shape[j], int) and shape[j] > 1:
        dims.add(shape[j])
    if len(dims) > 1:
      raise AssertionError("Dimension mismatch on axis: %s" % ax)
    if ax not in idx_out:
      reduction_idx.append(j)

  # Reshape to the common layout and multiply with broadcasting.
  expanded_output = 1
  for input_, shape in zip(inputs, shapes):
    expanded_output *= array_ops.reshape(input_, shape)

  # Contract over the subscripts that are absent from the output.
  result = math_ops.reduce_sum(expanded_output, reduction_idx)

  # The surviving axes are in alphabetical order; permute them into the order
  # requested by the caller (e.g. 'ij->ji').
  kept_idx = [ax for ax in ordered_idx if ax in idx_out]
  if list(idx_out) != kept_idx:
    result = array_ops.transpose(
        result, [kept_idx.index(ax) for ax in idx_out])
  return result