@@ -58,53 +58,60 @@ namespace cpu {
5858
5959namespace runtime = ::xla::runtime;
6060
61- CpuExecutable:: CpuExecutable (
61+ StatusOr<std::unique_ptr< CpuExecutable>> CpuExecutable::Create (
6262 std::unique_ptr<SimpleOrcJIT> jit,
6363 std::unique_ptr<const BufferAssignment> assignment,
6464 std::unique_ptr<HloModule> hlo_module,
6565 const std::string& entry_function_name,
6666 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
67- std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
68- : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
69- std::move (hlo_profile_index_map)),
70- jit_(std::move(jit)),
71- assignment_(std::move(assignment)),
72- module_name_(entry_function_name) {
73- if (assignment_) {
74- buffer_assignment_ =
75- std::make_shared<BufferAssignmentProto>(assignment_->ToProto ());
76- }
77- if (has_module ()) {
78- XlaDebugInfoManager::Get ()->RegisterModule (
79- module ().unique_id (), shared_module (), buffer_assignment_);
80- }
67+ std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map) {
68+ std::unique_ptr<CpuExecutable> executable (new CpuExecutable (
69+ std::move (hlo_module), std::move (hlo_profile_printer_data),
70+ std::move (hlo_profile_index_map), std::move (assignment)));
71+ executable->jit_ = std::move (jit);
72+ executable->module_name_ = entry_function_name;
8173
8274 // Resolve symbols in the constructor rather than at execution time to avoid
8375 // races because FindSymbol is not thread safe.
8476 llvm::Expected<llvm::orc::ExecutorSymbolDef> sym =
85- jit_->FindCompiledSymbol (entry_function_name);
77+ executable-> jit_ ->FindCompiledSymbol (entry_function_name);
8678 // We expect to find the symbol provided with entry_function_name; otherwise
8779 // this is an internal error.
88- CHECK (sym->getAddress ()) << " Symbol " << entry_function_name << " not found." ;
80+ if (!sym) {
81+ return absl::InvalidArgumentError (
82+ absl::StrCat (" Symbol " , entry_function_name, " not found." ));
83+ }
8984 // getAddress can do work under the hood in the jit, so it needs to be
9085 // guarded by the mutex.
91- compute_function_ =
86+ executable-> compute_function_ =
9287 reinterpret_cast <ComputeFunctionType>(sym->getAddress ().getValue ());
9388 VLOG (1 ) << " compute_function_ at address "
94- << reinterpret_cast <void *>(compute_function_);
95- jit_->DoneCompiling ();
89+ << reinterpret_cast <void *>(executable->compute_function_ );
90+ executable->jit_ ->DoneCompiling ();
91+ return executable;
9692}
9793
98- CpuExecutable:: CpuExecutable (
94+ StatusOr<std::unique_ptr< CpuExecutable>> CpuExecutable::Create (
9995 std::unique_ptr<HloModule> hlo_module,
10096 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
10197 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
10298 std::unique_ptr<const BufferAssignment> assignment,
103- std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable)
99+ std::unique_ptr<XlaRuntimeCpuExecutable> xla_runtime_executable) {
100+ std::unique_ptr<CpuExecutable> executable (new CpuExecutable (
101+ std::move (hlo_module), std::move (hlo_profile_printer_data),
102+ std::move (hlo_profile_index_map), std::move (assignment)));
103+ executable->xla_runtime_executable_ = std::move (xla_runtime_executable);
104+ return executable;
105+ }
106+
107+ CpuExecutable::CpuExecutable (
108+ std::unique_ptr<HloModule> hlo_module,
109+ std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
110+ std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map,
111+ std::unique_ptr<const BufferAssignment> assignment)
104112 : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
105113 std::move (hlo_profile_index_map)),
106- assignment_(std::move(assignment)),
107- xla_runtime_executable_(std::move(xla_runtime_executable)) {
114+ assignment_(std::move(assignment)) {
108115 if (assignment_) {
109116 buffer_assignment_ =
110117 std::make_shared<BufferAssignmentProto>(assignment_->ToProto ());
@@ -328,9 +335,9 @@ StatusOr<std::unique_ptr<Executable>> CpuExecutable::LoadFromObjFile(
328335 std::move (executable_ptr), xla_framework_mapping,
329336 std::move (*ffi_modules_state));
330337
331- return std::unique_ptr<Executable>( new CpuExecutable (
332- std::move (hlo_module), nullptr , nullptr , std::move (buffer_assignment),
333- std::move (xla_runtime_executable) ));
338+ return CpuExecutable::Create ( std::move (hlo_module), nullptr , nullptr ,
339+ std::move (buffer_assignment),
340+ std::move (xla_runtime_executable));
334341}
335342
336343StatusOr<ExecutionOutput> CpuExecutable::CreateResultShapedBuffer (
0 commit comments