Skip to content

Commit f2c704b

Browse files
d0krtg0795
authored and committed
[XLA:CPU] Refactor CpuExecutable so LLVM errors can be propagated
Otherwise we'd crash in cases such as a non-existent CustomCall target. PiperOrigin-RevId: 562563302
1 parent 8543b7b commit f2c704b

File tree

5 files changed

+73
-40
lines changed

5 files changed

+73
-40
lines changed

tensorflow/compiler/tests/BUILD

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2116,6 +2116,7 @@ tf_xla_py_test(
21162116
"//tensorflow/python:framework",
21172117
"//tensorflow/python:platform_test",
21182118
"//tensorflow/python:training",
2119+
"//tensorflow/python/framework:errors",
21192120
"//tensorflow/python/platform:client_testlib",
21202121
],
21212122
)

tensorflow/compiler/tests/xla_custom_call_ops_test.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from tensorflow.compiler.tf2xla.python import xla
1919
from tensorflow.python.eager import def_function
2020
from tensorflow.python.framework import dtypes
21+
from tensorflow.python.framework import errors_impl
2122
from tensorflow.python.framework import ops
2223
from tensorflow.python.framework import tensor_spec
2324
from tensorflow.python.ops import random_ops
@@ -46,6 +47,22 @@ def f(x, y):
4647
self.assertIn('custom_call_target="my_call"', hlo)
4748
self.assertIn('backend_config="my_backend_config"', hlo)
4849

50+
def testXlaCustomCallOpDoesntExist(self):
51+
with ops.device('device:{}:0'.format(self.device)):
52+
53+
def f():
54+
return xla.custom_call(
55+
args=(1, 2),
56+
target_name='my_non_existing_call_target',
57+
dtype=dtypes.int32,
58+
shape=(),
59+
backend_config='my_backend_config',
60+
)
61+
62+
with self.assertRaises(errors_impl.InvalidArgumentError):
63+
compiled_f = def_function.function(f, jit_compile=True)
64+
compiled_f()
65+
4966
def testXlaCustomCallV2Op(self):
5067
with ops.device('device:{}:0'.format(self.device)):
5168

tensorflow/compiler/xla/service/cpu/cpu_compiler.cc

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1403,9 +1403,12 @@ CpuCompiler::CompileLegacyCpuExecutable(std::unique_ptr<HloModule> module) {
14031403
std::move(llvm_context));
14041404
cantFail((*jit)->AddModule(std::move(thread_safe_module)));
14051405

1406-
auto cpu_executable = std::make_unique<CpuExecutable>(
1407-
std::move(*jit), std::move(assignment), std::move(module), function_name,
1408-
std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map));
1406+
TF_ASSIGN_OR_RETURN(
1407+
auto cpu_executable,
1408+
CpuExecutable::Create(std::move(*jit), std::move(assignment),
1409+
std::move(module), function_name,
1410+
std::move(hlo_profile_printer_data),
1411+
std::move(hlo_profile_index_map)));
14091412

14101413
if (embed_ir_in_executable) {
14111414
cpu_executable->set_ir_module_string(ir_module_string);
@@ -1507,7 +1510,7 @@ CpuCompiler::CompileXlaRuntimeCpuExecutable(
15071510
obj_file);
15081511
}
15091512

1510-
return std::make_unique<CpuExecutable>(
1513+
return CpuExecutable::Create(
15111514
std::move(hlo_module), std::move(hlo_profile_printer_data),
15121515
std::move(hlo_profile_index_map), std::move(assignment),
15131516
std::move(xla_runtime_executable));

tensorflow/compiler/xla/service/cpu/cpu_executable.cc

Lines changed: 34 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -58,53 +58,60 @@ namespace cpu {
5858

5959
namespace runtime = ::xla::runtime;
6060

61-
CpuExecutable::CpuExecutable(
61+
StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
6262
std::unique_ptr<SimpleOrcJIT> jit,
6363
std::unique_ptr<const BufferAssignment> assignment,
6464
std::unique_ptr<HloModule> hlo_module,
6565
const std::string& entry_function_name,
6666
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
67-
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
68-
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
69-
std::move(hlo_profile_index_map)),
70-
jit_(std::move(jit)),
71-
assignment_(std::move(assignment)),
72-
module_name_(entry_function_name) {
73-
if (assignment_) {
74-
buffer_assignment_ =
75-
std::make_shared<BufferAssignmentProto>(assignment_->ToProto());
76-
}
77-
if (has_module()) {
78-
XlaDebugInfoManager::Get()->RegisterModule(
79-
module().unique_id(), shared_module(), buffer_assignment_);
80-
}
67+
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) {
68+
std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
69+
std::move(hlo_module), std::move(hlo_profile_printer_data),
70+
std::move(hlo_profile_index_map), std::move(assignment)));
71+
executable->jit_ = std::move(jit);
72+
executable->module_name_ = entry_function_name;
8173

8274
// Resolve symbols in the constructor rather than at execution time to avoid
8375
// races because FindSymbol is not thread safe.
8476
llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
85-
jit_->FindCompiledSymbol(entry_function_name);
77+
executable->jit_->FindCompiledSymbol(entry_function_name);
8678
// We expect to find the symbol provided with entry_function_name; otherwise
8779
// this is an internal error.
88-
CHECK(sym->getAddress()) << "Symbol " << entry_function_name << " not found.";
80+
if (!sym) {
81+
return absl::InvalidArgumentError(
82+
absl::StrCat("Symbol ", entry_function_name, " not found."));
83+
}
8984
// getAddress can do work under the hood in the jit, so it needs to be
9085
// guarded by the mutex.
91-
compute_function_ =
86+
executable->compute_function_ =
9287
reinterpret_cast<ComputeFunctionType>(sym->getAddress().getValue());
9388
VLOG(1) << "compute_function_ at address "
94-
<< reinterpret_cast<void*>(compute_function_);
95-
jit_->DoneCompiling();
89+
<< reinterpret_cast<void*>(executable->compute_function_);
90+
executable->jit_->DoneCompiling();
91+
return executable;
9692
}
9793

98-
CpuExecutable::CpuExecutable(
94+
StatusOr<std::unique_ptr<CpuExecutable>> CpuExecutable::Create(
9995
std::unique_ptr<HloModule> hlo_module,
10096
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
10197
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
10298
std::unique_ptr<const BufferAssignment> assignment,
103-
std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable)
99+
std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
100+
std::unique_ptr<CpuExecutable> executable(new CpuExecutable(
101+
std::move(hlo_module), std::move(hlo_profile_printer_data),
102+
std::move(hlo_profile_index_map), std::move(assignment)));
103+
executable->xla_runtime_executable_ = std::move(xla_runtime_executable);
104+
return executable;
105+
}
106+
107+
CpuExecutable::CpuExecutable(
108+
std::unique_ptr<HloModule> hlo_module,
109+
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
110+
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
111+
std::unique_ptr<const BufferAssignment> assignment)
104112
: Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
105113
std::move(hlo_profile_index_map)),
106-
assignment_(std::move(assignment)),
107-
xla_runtime_executable_(std::move(xla_runtime_executable)) {
114+
assignment_(std::move(assignment)) {
108115
if (assignment_) {
109116
buffer_assignment_ =
110117
std::make_shared<BufferAssignmentProto>(assignment_->ToProto());
@@ -328,9 +335,9 @@ StatusOr<std::unique_ptr<Executable>> CpuExecutable::LoadFromObjFile(
328335
std::move(executable_ptr), xla_framework_mapping,
329336
std::move(*ffi_modules_state));
330337

331-
return std::unique_ptr<Executable>(new CpuExecutable(
332-
std::move(hlo_module), nullptr, nullptr, std::move(buffer_assignment),
333-
std::move(xla_runtime_executable)));
338+
return CpuExecutable::Create(std::move(hlo_module), nullptr, nullptr,
339+
std::move(buffer_assignment),
340+
std::move(xla_runtime_executable));
334341
}
335342

336343
StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer(

tensorflow/compiler/xla/service/cpu/cpu_executable.h

Lines changed: 14 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -138,14 +138,15 @@ class XlaRuntimeCpuExecutable {
138138
// architecture, so JIT-ed code and host code share the same ABI.
139139
class CpuExecutable : public Executable {
140140
public:
141-
CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
142-
std::unique_ptr<const BufferAssignment> assignment,
143-
std::unique_ptr<HloModule> hlo_module,
144-
const std::string& entry_function_name,
145-
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
146-
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
147-
// XLA Runtime constructor.
148-
CpuExecutable(
141+
static StatusOr<std::unique_ptr<CpuExecutable>> Create(
142+
std::unique_ptr<SimpleOrcJIT> jit,
143+
std::unique_ptr<const BufferAssignment> assignment,
144+
std::unique_ptr<HloModule> hlo_module,
145+
const std::string& entry_function_name,
146+
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
147+
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
148+
// XLA Runtime factory method.
149+
static StatusOr<std::unique_ptr<CpuExecutable>> Create(
149150
std::unique_ptr<HloModule> hlo_module,
150151
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
151152
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
@@ -257,7 +258,7 @@ class CpuExecutable : public Executable {
257258
const InstructionValueSet& GetRootValueSet() const;
258259

259260
// The JIT containing compiled modules.
260-
const std::unique_ptr<SimpleOrcJIT> jit_;
261+
std::unique_ptr<SimpleOrcJIT> jit_;
261262

262263
// Buffer assignment for the buffers we need to allocate.
263264
const std::unique_ptr<const BufferAssignment> assignment_;
@@ -281,6 +282,10 @@ class CpuExecutable : public Executable {
281282
// If not null, XLA Runtime is enabled.
282283
std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable_;
283284

285+
CpuExecutable(std::unique_ptr<HloModule> hlo_module,
286+
std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
287+
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
288+
std::unique_ptr<const BufferAssignment> assignment);
284289
CpuExecutable(const CpuExecutable&) = delete;
285290
CpuExecutable& operator=(const CpuExecutable&) = delete;
286291
};

0 commit comments

Comments
 (0)