9 changes: 9 additions & 0 deletions tensorflow/core/kernels/segment_reduction_ops.h
@@ -32,6 +32,9 @@ namespace tensorflow {

class OpKernelContext;

bool RequireDeterminism();
bool DisableSegmentReductionOpDeterminismExceptions();

namespace functor {

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@@ -54,6 +57,8 @@ struct SegmentReductionFunctor {
typename TTypes<Index>::ConstFlat segment_ids,
const Index data_size, const T* data,
typename TTypes<T, 2>::Tensor output);
static constexpr bool atomic_reduction_is_associative =
AtomicReductionF::is_associative;
};

#endif
@@ -76,6 +81,7 @@ struct AtomicSumOpGpu {
const T& value) {
GpuAtomicAdd(dest, value);
}
static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
@@ -84,6 +90,7 @@ struct AtomicProdOpGpu {
const T& value) {
GpuAtomicMul(dest, value);
}
static constexpr bool is_associative = std::is_integral<T>::value;
};

template <typename T>
@@ -92,6 +99,7 @@ struct AtomicMaxOpGpu {
const T& value) {
GpuAtomicMax(dest, value);
}
static constexpr bool is_associative = true;
};

template <typename T>
@@ -100,6 +108,7 @@ struct AtomicMinOpGpu {
const T& value) {
GpuAtomicMin(dest, value);
}
static constexpr bool is_associative = true;
};

// Non-atomic reduction functors for the gpu.
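The is_associative trait added above is what the new determinism checks key off. GPU atomics commit in a nondeterministic order across threads, which only changes the result when the operation is non-associative, as floating-point addition and multiplication are; min, max, and integer sum and prod are order-insensitive, so those remain deterministic. A minimal Python sketch, not part of this diff, of the underlying numeric fact:

# Floating-point addition is not associative, so the order in which GPU
# atomics commit can change the rounded result:
a, b, c = 0.1, 0.2, 0.3
print((a + b) + c == a + (b + c))  # False: 0.6000000000000001 vs 0.6
# Integer addition is associative; any commit order yields the same sum:
print((1 + 2) + 3 == 1 + (2 + 3))  # True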
35 changes: 35 additions & 0 deletions tensorflow/core/kernels/segment_reduction_ops_gpu.cu.cc
@@ -25,6 +25,7 @@ limitations under the License.

#include "tensorflow/core/kernels/segment_reduction_ops.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/gpu_device_functions.h"

namespace tensorflow {
@@ -126,6 +127,30 @@ __global__ void UnsortedSegmentCustomKernel(
}
}

// TODO(duncanriach): move this into a utility and share it
bool RequireDeterminism() {
static bool require_determinism = [] {
bool deterministic_ops = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
/*default_val=*/false,
&deterministic_ops));
return deterministic_ops;
}();
return require_determinism;
}

bool DisableSegmentReductionOpDeterminismExceptions() {
static bool cached_disable = [] {
bool disable = false;
TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
"TF_DISABLE_SEGMENT_REDUCTION_OP_DETERMINISM_EXCEPTIONS",
/*default_val=*/false,
&disable));
return disable;
}();
return cached_disable;
}
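Both helpers parse their environment variable once and cache the result in a function-local static, so the variable must be set before the first op that consults it runs; later changes within the process are ignored. A rough Python analogue of this read-once pattern (names illustrative, not TensorFlow API):

import functools
import os

@functools.lru_cache(maxsize=None)
def require_determinism():
  # Read and parsed once per process, like the C++ function-local static.
  return os.environ.get("TF_DETERMINISTIC_OPS", "false").lower() in ("1", "true")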

namespace functor {

template <typename T, typename Index, typename InitialValueF,
@@ -141,6 +166,7 @@ void SegmentReductionFunctor<
if (output.size() == 0) {
return;
}

// Set 'output' to initial value.
GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
const T InitialValue = InitialValueF()();
@@ -188,6 +214,15 @@ struct UnsortedSegmentFunctor<GPUDevice, T, Index, InitialValueF, ReductionF> {
if (output.size() == 0) {
return;
}

bool determinism_requirement_met =
ReductionF::is_associative ||
!RequireDeterminism() ||
DisableSegmentReductionOpDeterminismExceptions();
OP_REQUIRES(ctx, determinism_requirement_met, errors::Unimplemented(
"Deterministic GPU implementation of unsorted segment reduction op"
" not available."));

// Set 'output' to initial value.
GPUDevice d = ctx->template eigen_device<GPUDevice>();
GpuLaunchConfig config = GetGpuLaunchConfig(output.size(), d);
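The determinism_requirement_met expression above lets the kernel proceed when any one of three conditions holds: the atomic reduction is associative (and therefore deterministic even with atomics), determinism was never requested, or the new exceptions were explicitly disabled. A hedged end-to-end sketch of the resulting behavior on a CUDA build (exact availability depends on the TensorFlow version):

import os
os.environ["TF_DETERMINISTIC_OPS"] = "1"  # must be set before TensorFlow initializes

import tensorflow as tf

data = tf.constant([[1., 2., 3., 4.], [5., 6., 7., 8.]])
segment_ids = tf.constant([0, 1])
with tf.device("/GPU:0"):
  # Runs: max is associative, so the atomic implementation is deterministic.
  tf.math.unsorted_segment_max(data, segment_ids, num_segments=2)
  try:
    tf.math.unsorted_segment_sum(data, segment_ids, num_segments=2)
  except tf.errors.UnimplementedError as e:
    print(e)  # Deterministic GPU implementation of unsorted segment
              # reduction op not available.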
14 changes: 14 additions & 0 deletions tensorflow/core/kernels/segment_reduction_ops_impl.h
@@ -206,6 +206,7 @@ class SegmentReductionOp : public OpKernel {
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// SegmentReductionGPUOp is a segment reduction operator implemented for GPU
// only.
// TODO: This implementation of SegmentReductionGPUOp is sometimes slower than
@@ -292,6 +293,19 @@ class SegmentReductionGPUOp : public AsyncOpKernel {
OP_REQUIRES_OK_ASYNC(
context, context->allocate_output(0, output_shape, &output), done);

// The determinism check is here, rather than inside the functor (as it is
// for the unsorted segment reduction ops) because the done callback
// (required for OP_REQUIRES_ASYNC) is not available inside the functor.
bool determinism_requirement_met =
SegmentReductionFunctor::atomic_reduction_is_associative ||
!RequireDeterminism() ||
DisableSegmentReductionOpDeterminismExceptions();
OP_REQUIRES_ASYNC(
context, determinism_requirement_met, errors::Unimplemented(
"Deterministic GPU implementation of sorted segment reduction op"
" not available."),
done);

auto output_flat = output->flat_outer_dims<T>();
auto data_ptr = input.template flat<T>().data();
auto segment_flat = segment_ids.flat<Index>();
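For the sorted ops, the same gate runs in the op rather than the functor, for the reason given in the comment above, and per the tests only sum and prod throw for floating-point inputs. Continuing the previous sketch (same data, segment_ids, and environment variable):

with tf.device("/GPU:0"):
  tf.math.segment_max(data, segment_ids)  # runs: atomic max is associative
  try:
    tf.math.segment_sum(data, segment_ids)
  except tf.errors.UnimplementedError as e:
    print(e)  # Deterministic GPU implementation of sorted segment
              # reduction op not available.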
14 changes: 14 additions & 0 deletions tensorflow/python/kernel_tests/BUILD
@@ -990,6 +990,20 @@ tf_py_test(
],
)

cuda_py_test(
name = "segment_reduction_ops_deterministic_test",
size = "small",
srcs = ["segment_reduction_ops_deterministic_test.py"],
xla_enable_strict_auto_jit = False,
deps = [
"//tensorflow/python:array_ops",
"//tensorflow/python:framework",
"//tensorflow/python:math_ops",
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
],
)

cuda_py_test(
name = "segment_reduction_ops_test",
size = "medium",
162 changes: 162 additions & 0 deletions tensorflow/python/kernel_tests/segment_reduction_ops_deterministic_test.py
@@ -0,0 +1,162 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for deterministic functionality of segment reduction ops."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


class SegmentReductionDeterminismExceptionsTest(test.TestCase):
"""
Test that tf.errors.UnimplementedError is thrown or not thrown, as
appropriate, by the GPU code-paths for the segment reduction ops when
determinsitic ops are enabled. This test assumes that the base op test
runs all the same test cases when deterministic ops are not enabled and
will therefore detect erroneous exception throwing in those cases.
"""

def _input(self, data_type, segment_ids_type):
data = constant_op.constant([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=data_type)
segment_ids = constant_op.constant([0, 1], dtype=segment_ids_type)
num_segments = 2
return data, segment_ids, num_segments

@test_util.run_cuda_only
def testSortedOps(self):
op_should_throw_for_float = {
math_ops.segment_max : False,
math_ops.segment_min : False,
math_ops.segment_mean: False, # implemented on CPU only
math_ops.segment_prod: True,
math_ops.segment_sum : True,
}
for op, should_throw_for_float in op_should_throw_for_float.items():
for segment_ids_type in [dtypes.int32, dtypes.int64]:
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64]:
with self.cached_session(force_gpu=True):
data, segment_ids, _ = self._input(data_type, segment_ids_type)
if should_throw_for_float:
with self.assertRaisesRegex(
errors_impl.UnimplementedError,
"Deterministic GPU implementation of sorted segment " +
"reduction op not available."):
op(data, segment_ids)
else:
op(data, segment_ids)

_UNSORTED_ERROR_MESSAGE = ("Deterministic GPU implementation of unsorted " +
"segment reduction op not available.")

@test_util.run_cuda_only
@test_util.run_in_graph_and_eager_modes
def testUnsortedOps(self):
op_should_throw_for_float = {
math_ops.unsorted_segment_max : False,
math_ops.unsorted_segment_min : False,
math_ops.unsorted_segment_mean : True, # uses unsorted_segment_sum
math_ops.unsorted_segment_sqrt_n: True, # uses unsorted_segment_sum
math_ops.unsorted_segment_prod : True,
math_ops.unsorted_segment_sum : True,
}
with self.session(force_gpu=True):
for op, should_throw_for_float in op_should_throw_for_float.items():
for segment_ids_type in [dtypes.int32, dtypes.int64]:
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.int32]:
if (op == math_ops.unsorted_segment_sqrt_n and
data_type == dtypes.int32): # sqrt_n doesn't support int32
continue
data, segment_ids, num_segments = self._input(data_type,
segment_ids_type)
if (data_type != dtypes.int32) and should_throw_for_float:
with self.assertRaisesRegex(
errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE):
result = op(data, segment_ids, num_segments)
self.evaluate(result)
else:
result = op(data, segment_ids, num_segments)
self.evaluate(result)

@test_util.run_cuda_only
def testUnsortedOpsComplex(self):
for op in [
math_ops.unsorted_segment_mean,  # uses unsorted_segment_sum
math_ops.unsorted_segment_sqrt_n, # uses unsorted_segment_sum
math_ops.unsorted_segment_sum,
]:
for data_type in [dtypes.complex64, dtypes.complex128]:
for segment_ids_type in [dtypes.int32, dtypes.int64]:
with self.cached_session(force_gpu=True):
data, segment_ids, num_segments = self._input(data_type,
segment_ids_type)
with self.assertRaisesRegex(
errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE):
op(data, segment_ids, num_segments)

@test_util.run_cuda_only
@test_util.run_in_graph_and_eager_modes
def testConvertToTensor(self):
with self.session(force_gpu=True):
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128]:
for segment_ids_type in [dtypes.int32, dtypes.int64]:
values, indices, _ = self._input(data_type, segment_ids_type)
sparse_value = indexed_slices.IndexedSlices(
values, indices, dense_shape=values.shape)
with self.assertRaisesRegex(
errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE):
# convert_to_tensor with IndexedSlices uses unsorted_segment_sum
result = ops.convert_to_tensor(sparse_value)
self.evaluate(result)

@test_util.run_cuda_only
def testGatherBackprop(self):
for data_type in [dtypes.float16, dtypes.float32, dtypes.float64,
dtypes.complex64, dtypes.complex128]:
for segment_ids_type in [dtypes.int32, dtypes.int64]:
with self.cached_session(force_gpu=True):
params, indices, _ = self._input(data_type, segment_ids_type)
params = variables.Variable(params)
with backprop.GradientTape() as tape:
tape.watch(params)
op_output = array_ops.gather(params, indices)
gradient = tape.gradient(op_output, params)
with self.assertRaisesRegex(
errors_impl.UnimplementedError, self._UNSORTED_ERROR_MESSAGE):
params.assign(gradient) # convert_to_tensor on IndexedSlices


if __name__ == "__main__":
# Note that the effect of setting the following environment variable to
# 'true' is not tested. Unless we can find a simpler pattern for testing
# these environment variables, doing so would require making this file into
# a base and then creating two more test files.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
test.main()
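The last two tests exercise an indirect path worth spelling out: the gradient of tf.gather with respect to a variable is an IndexedSlices, and converting it to a dense tensor (which variable assignment does) goes through unsorted_segment_sum, so the exception surfaces there as well. A hedged sketch of that path, again assuming TF_DETERMINISTIC_OPS=1 and a GPU:

import tensorflow as tf

params = tf.Variable([[1., 2.], [3., 4.]])
with tf.GradientTape() as tape:
  out = tf.gather(params, [0, 1])
grad = tape.gradient(out, params)  # an IndexedSlices, not a dense tensor
# Densifying the IndexedSlices calls unsorted_segment_sum under the hood, so
# this raises tf.errors.UnimplementedError when determinism is required:
dense = tf.convert_to_tensor(grad)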