forked from SoftwareDesignXRays/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtf_export.py
More file actions
128 lines (103 loc) · 4.04 KB
/
tf_export.py
File metadata and controls
128 lines (103 loc) · 4.04 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for exporting TensorFlow symbols to the API.

Exporting a function or a class:

To export a function or a class, use the tf_export decorator. The first name
given is treated as the canonical API name. For example:
```python
@tf_export('foo', 'bar.foo')
def foo(...):
  ...
```

If a function is assigned to a variable, you can export it by calling
tf_export explicitly. For example:
```python
foo = get_foo(...)
tf_export('foo', 'bar.foo')(foo)
```

Exporting a constant:
```python
foo = 1
tf_export("consts.foo").export_constant(__name__, 'foo')
```
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from tensorflow.python.util import tf_decorator
class SymbolAlreadyExposedError(Exception):
  """Signals an attempt to export a symbol that already carries API names.

  Raised by `tf_export.__call__` when the undecorated target already defines
  `_tf_api_names` in its own `__dict__`.
  """
class tf_export(object):  # pylint: disable=invalid-name
  """Records TensorFlow API export names on symbols and constants.

  An instance is configured with one or more dot-delimited API names. It is
  then either applied to a function/class (as a decorator, or by calling it
  directly) or used to register a module-level constant through
  `export_constant`.
  """

  def __init__(self, *args, **kwargs):
    """Captures the API names and optional overrides for this export.

    Args:
      *args: API names in dot-delimited format. The first one is considered
        the canonical name.
      **kwargs: Optional keyword arguments. Only 'overrides' is read; any
        other key is ignored.
        overrides: List of symbols that this export supersedes. Their
          `_tf_api_names` attribute is deleted when this export is applied
          with `__call__`. Passing overrides has no effect on
          `export_constant`.
    """
    overrides = kwargs.get('overrides', [])
    self._names = args
    self._overrides = overrides

  def __call__(self, func):
    """Attaches this export's API names to `func`.

    Args:
      func: Symbol to export (function or class), possibly wrapped by
        TensorFlow decorators.

    Returns:
      `func` itself, unchanged. The API names are stored as `_tf_api_names`
      on the undecorated target returned by `tf_decorator.unwrap`.

    Raises:
      SymbolAlreadyExposedError: The undecorated target already defines
        `_tf_api_names` in its own `__dict__`.
      AttributeError: An entry in `overrides` has no deletable
        `_tf_api_names` attribute.
    """
    # Strip existing API names from every symbol this export supersedes.
    # NOTE(review): this runs before the duplicate-export check below, so the
    # overrides lose their names even if SymbolAlreadyExposedError is raised.
    for overridden in self._overrides:
      _unused_decorators, overridden_target = tf_decorator.unwrap(overridden)
      del overridden_target._tf_api_names  # pylint: disable=protected-access

    _unused_decorators, target = tf_decorator.unwrap(func)

    # pylint: disable=protected-access
    # Look in the target's own __dict__ instead of using hasattr, so that a
    # subclass must carry its own export rather than inheriting the parent's
    # _tf_api_names.
    if '_tf_api_names' in target.__dict__:
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (target.__name__, target._tf_api_names))

    # Complete the export by creating (or shadowing an inherited) attribute.
    target._tf_api_names = self._names
    # pylint: enable=protected-access
    return func

  def export_constant(self, module_name, name):
    """Registers a constant or string literal for export.

    Export information is kept on the module that defines the constant, in a
    `_tf_api_constants` list of `(api_names, name)` tuples. The list is
    created on first use.

    e.g.
    ```python
    foo = 1
    bar = 2
    tf_export("consts.foo").export_constant(__name__, 'foo')
    tf_export("consts.bar").export_constant(__name__, 'bar')
    ```

    Args:
      module_name: (string) Name of the module holding the constant. It must
        already be present in `sys.modules`.
      name: (string) Current name of the constant within that module.

    Raises:
      KeyError: If `module_name` is not present in `sys.modules`.
    """
    module = sys.modules[module_name]
    # pylint: disable=protected-access
    try:
      registry = module._tf_api_constants
    except AttributeError:
      registry = module._tf_api_constants = []
    registry.append((self._names, name))
    # pylint: enable=protected-access