-
Notifications
You must be signed in to change notification settings - Fork 75.2k
[determinism] Add segment reduction op exceptions for GPU determinism #47772
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
copybara-service
merged 4 commits into
tensorflow:master
from
duncanriach:segment-sum-nond9m-exceptions
Mar 18, 2021
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
de45867
[determinism] Add segment reduction op exceptions
duncanriach 56c34b2
[determinism] Disable XLA auto-jit for segment reduction d9m-unimplem…
duncanriach b61a34d
[determinism] Address review, step 1, on PR 47772
duncanriach 71d777c
[determinism] Fix buildifier error
duncanriach File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
162 changes: 162 additions & 0 deletions
162
tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # Copyright 2021 The TensorFlow Authors. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| # ============================================================================== | ||
| """Tests for deterministic functionality of segment reduction ops.""" | ||
|
|
||
| from __future__ import absolute_import | ||
| from __future__ import division | ||
| from __future__ import print_function | ||
|
|
||
| import os | ||
|
|
||
| from tensorflow.python.eager import backprop | ||
| from tensorflow.python.framework import constant_op | ||
| from tensorflow.python.framework import dtypes | ||
| from tensorflow.python.framework import errors_impl | ||
| from tensorflow.python.framework import indexed_slices | ||
| from tensorflow.python.framework import ops | ||
| from tensorflow.python.framework import test_util | ||
| from tensorflow.python.ops import array_ops | ||
| from tensorflow.python.ops import math_ops | ||
| from tensorflow.python.ops import variables | ||
| from tensorflow.python.platform import test | ||
|
|
||
|
|
||
| class SegmentReductionDeterminismExceptionsTest(test.TestCase): | ||
| """ | ||
| Test that tf.errors.UnimplementedError is thrown or not thrown, as | ||
| appropriate, by the GPU code-paths for the segment reduction ops when | ||
| determinsitic ops are enabled. This test assumes that the base op test | ||
| runs all the same test cases when deterministic ops are not enabled and | ||
| will therefore detect erroneous exception throwing in those cases. | ||
| """ | ||
|
|
||
| def _input(self, data_type, segment_ids_type): | ||
| data = constant_op.constant([[1,2,3,4], [5,6,7,8]], dtype=data_type) | ||
| segment_ids = constant_op.constant([0, 1], dtype=segment_ids_type) | ||
| num_segments = 2 | ||
| return data, segment_ids, num_segments | ||
|
|
||
| @test_util.run_cuda_only | ||
| def testSortedOps(self): | ||
| op_should_throw_for_float = { | ||
| math_ops.segment_max : False, | ||
| math_ops.segment_min : False, | ||
| math_ops.segment_mean: False, # implemented on CPU only | ||
| math_ops.segment_prod: True, | ||
| math_ops.segment_sum : True, | ||
| } | ||
| for op, should_throw_for_float in op_should_throw_for_float.items(): | ||
| for segment_ids_type in [dtypes.int32, dtypes.int64]: | ||
| for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]: | ||
| with self.cached_session(force_gpu=True): | ||
| data, segment_ids, _ = self._input(data_type, segment_ids_type) | ||
| if should_throw_for_float: | ||
| with self.assertRaisesRegex( | ||
| errors_impl.UnimplementedError, | ||
| "Deterministic GPU implementation of sorted segment " + | ||
| "reduction op not available."): | ||
| op(data, segment_ids) | ||
| else: | ||
| op(data, segment_ids) | ||
|
|
||
| _UNSORTED_ERROR_MESSAGE = ("Deterministic GPU implementation of unsorted " + | ||
| "segment reduction op not available.") | ||
|
|
||
| @test_util.run_cuda_only | ||
| @test_util.run_in_graph_and_eager_modes | ||
| def testUnsortedOps(self): | ||
| op_should_throw_for_float = { | ||
| math_ops.unsorted_segment_max : False, | ||
| math_ops.unsorted_segment_min : False, | ||
| math_ops.unsorted_segment_mean : True, # uses unsorted_segment_sum | ||
| math_ops.unsorted_segment_sqrt_n: True, # uses unsorted_segment_sum | ||
| math_ops.unsorted_segment_prod : True, | ||
| math_ops.unsorted_segment_sum : True, | ||
| } | ||
| with self.session(force_gpu=True): | ||
| for op, should_throw_for_float in op_should_throw_for_float.items(): | ||
| for segment_ids_type in [dtypes.int32, dtypes.int64]: | ||
| for data_type in [dtypes.float16, dtypes.float32, dtypes.float64, | ||
| dtypes.int32]: | ||
| if (op == math_ops.unsorted_segment_sqrt_n and | ||
| data_type == dtypes.int32): # sqrt_n doesn't support int32 | ||
| continue | ||
| data, segment_ids, num_segments = self._input(data_type, | ||
| segment_ids_type) | ||
| if (data_type != dtypes.int32) and should_throw_for_float: | ||
| with self.assertRaisesRegex( | ||
| errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE): | ||
| result = op(data, segment_ids, num_segments) | ||
| self.evaluate(result) | ||
| else: | ||
| result = op(data, segment_ids, num_segments) | ||
| self.evaluate(result) | ||
|
|
||
| @test_util.run_cuda_only | ||
| def testUnsortedOpsComplex(self): | ||
| for op in [ | ||
| math_ops.unsorted_segment_mean , # uses unsorted_segment_sum | ||
| math_ops.unsorted_segment_sqrt_n, # uses unsorted_segment_sum | ||
| math_ops.unsorted_segment_sum, | ||
| ]: | ||
| for data_type in [dtypes.complex64, dtypes.complex128]: | ||
| for segment_ids_type in [dtypes.int32, dtypes.int64]: | ||
| with self.cached_session(force_gpu=True): | ||
| data, segment_ids, num_segments = self._input(data_type, | ||
| segment_ids_type) | ||
| with self.assertRaisesRegex( | ||
| errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE): | ||
| op(data, segment_ids, num_segments) | ||
|
|
||
| @test_util.run_cuda_only | ||
| @test_util.run_in_graph_and_eager_modes | ||
| def testConvertToTensor(self): | ||
| with self.session(force_gpu=True): | ||
| for data_type in [dtypes.float16, dtypes.float32, dtypes.float64, | ||
| dtypes.complex64, dtypes.complex128]: | ||
| for segment_ids_type in [dtypes.int32, dtypes.int64]: | ||
| values, indices, _ = self._input(data_type, segment_ids_type) | ||
| sparse_value = indexed_slices.IndexedSlices( | ||
| values, indices, dense_shape=values.shape) | ||
| with self.assertRaisesRegex( | ||
| errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE): | ||
| # convert_to_tensor with IndexedSlices uses unsorted_segment_sum | ||
| result = ops.convert_to_tensor(sparse_value) | ||
| self.evaluate(result) | ||
|
|
||
| @test_util.run_cuda_only | ||
| def testGatherBackprop(self): | ||
| for data_type in [dtypes.float16, dtypes.float32, dtypes.float64, | ||
| dtypes.complex64, dtypes.complex128]: | ||
| for segment_ids_type in [dtypes.int32, dtypes.int64]: | ||
| with self.cached_session(force_gpu=True): | ||
| params, indices, _ = self._input(dtypes.float32, dtypes.int32) | ||
| params = variables.Variable(params) | ||
| with backprop.GradientTape() as tape: | ||
| tape.watch(params) | ||
| op_output = array_ops.gather(params, indices) | ||
| gradient = tape.gradient(op_output, params) | ||
| with self.assertRaisesRegex( | ||
| errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE): | ||
| params.assign(gradient) # convert_to_tensor on IndexedSlices | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| # Note that the effect of setting the following environment variable to | ||
| # 'true' is not tested. Unless we can find a simpler pattern for testing these | ||
| # environment variables, it would require this file to be made into a base | ||
| # and then two more test files to be created. | ||
| os.environ['TF_DETERMINISTIC_OPS'] = '1' | ||
| test.main() |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.