
Commit 557a094

hawkinsp authored and tensorflower-gardener committed
[PJRT] Add support for compiling MHLO/CHLO modules to PjRt.
[XLA:Python] Add support for compiling MHLO/CHLO modules to the XLA Python bindings.

PiperOrigin-RevId: 408317187
Change-Id: Ibd08bc57ea3073761c8eb5717da72b257f23ded3
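
For orientation, a minimal usage sketch of the overload introduced here, assuming the caller already holds a constructed PjRtClient (`client`) and an MLIR module parsed into `mlir::ModuleOp module`; those surrounding names are illustrative, not part of this commit:

  // Hedged sketch: compile an MHLO/CHLO module through the new PjRt overload.
  // Only Compile(mlir::ModuleOp, CompileOptions) itself comes from this change;
  // `client`, `module`, and the logging are assumptions for illustration.
  xla::CompileOptions options;
  xla::StatusOr<std::unique_ptr<xla::PjRtExecutable>> executable =
      client->Compile(module, options);
  if (!executable.ok()) {
    LOG(ERROR) << "MHLO compilation failed: " << executable.status();
  }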
1 parent 5b70b9e commit 557a094

18 files changed: +332 -3 lines changed

tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.cc

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ limitations under the License.
 namespace xla {
 
 Status ConvertHloToMlirHlo(mlir::ModuleOp module,
-                           xla::HloModuleProto* hlo_module_proto,
+                           xla::HloModuleProto const* hlo_module_proto,
                            bool import_all_computation) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   return HloModuleImporter(module, import_all_computation)

tensorflow/compiler/mlir/xla/hlo_to_mlir_hlo.h

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ class HloModuleProto;
 // If import_all_computation is set to true, imports all computations
 // irrespective if transitively called from entry computation.
 Status ConvertHloToMlirHlo(mlir::ModuleOp module,
-                           xla::HloModuleProto* hlo_module,
+                           xla::HloModuleProto const* hlo_module,
                            bool import_all_computations = false);
 
 // Converts an HLO module to a MLIR module in HLO dialect.

tensorflow/compiler/xla/pjrt/BUILD

Lines changed: 24 additions & 0 deletions
@@ -152,10 +152,12 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_cost_analysis",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/base",
+        "@com_google_absl//absl/container:inlined_vector",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:span",
+        "@llvm-project//mlir:IR",
     ],
 )
 
@@ -196,6 +198,7 @@ cc_library(
         ":event_pool",
         ":local_device_state",
         ":metrics",
+        ":mlir_to_hlo",
        ":pjrt_client",
         ":tracked_device_buffer",
         ":transpose",
@@ -377,6 +380,26 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "mlir_to_hlo",
+    srcs = ["mlir_to_hlo.cc"],
+    hdrs = ["mlir_to_hlo.h"],
+    deps = [
+        "//tensorflow/compiler/mlir/hlo",
+        "//tensorflow/compiler/mlir/hlo:all_passes",
+        "//tensorflow/compiler/mlir/hlo:legalize_trigonometric_to_approximation",
+        "//tensorflow/compiler/mlir/tensorflow:error_util",
+        "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "@com_google_absl//absl/strings",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:StandardOps",
+    ],
+)
+
 cc_library(
     name = "tracked_tfrt_cpu_device_buffer",
     srcs = ["tracked_tfrt_cpu_device_buffer.cc"],
@@ -397,6 +420,7 @@ cc_library(
     srcs = ["tfrt_cpu_pjrt_client.cc"],
     hdrs = ["tfrt_cpu_pjrt_client.h"],
     deps = [
+        ":mlir_to_hlo",
        ":pjrt_client",
        ":semaphore",
        ":tracked_tfrt_cpu_device_buffer",
tensorflow/compiler/xla/pjrt/mlir_to_hlo.cc

Lines changed: 62 additions & 0 deletions

@@ -0,0 +1,62 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
+
+#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
+#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
+
+namespace xla {
+
+Status MlirToXlaComputation(mlir::ModuleOp module,
+                            XlaComputation& xla_computation,
+                            bool use_tuple_args, bool return_tuple) {
+  mlir::StatusScopedDiagnosticHandler diagnostic_handler(module->getContext());
+  {
+    mlir::PassManager pm(module->getContext());
+    pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createChloLegalizeToHloPass(
+        /*legalize_broadcasts=*/true, /*expand_compositions=*/true));
+    pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
+    // In order to export to XLA, we must sink constants to control flow
+    // regions, since XLA uses functional control flow.
+    pm.addNestedPass<mlir::FuncOp>(
+        mlir::mhlo::createSinkConstantsToControlFlowPass());
+    if (failed(pm.run(module))) {
+      VLOG(1) << "MHLO->HLO lowering passes failed.";
+      module->dump();
+      return diagnostic_handler.ConsumeStatus();
+    }
+
+    VLOG(5) << "MHLO module after lowering, before HLO import ";
+    if (VLOG_IS_ON(5)) {
+      module->dump();
+    }
+  }
+
+  HloProto proto;
+  TF_RETURN_IF_ERROR(
+      ConvertMlirHloToHlo(module, &proto, use_tuple_args, return_tuple));
+
+  xla_computation = XlaComputation(std::move(*proto.mutable_hlo_module()));
+  return Status::OK();
+}
+
+}  // namespace xla
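
For readers of the new helper above, a short sketch of calling it directly, assuming a well-formed MHLO `mlir::ModuleOp` named `module` is already in hand; the variable names here are illustrative only:

  // Hedged sketch: lower an MHLO/CHLO module to an XlaComputation.
  // On pass or import failure, the returned Status carries the MLIR
  // diagnostics collected by the scoped handler inside the helper.
  xla::XlaComputation computation;
  TF_RETURN_IF_ERROR(xla::MlirToXlaComputation(
      module, computation,
      /*use_tuple_args=*/false, /*return_tuple=*/false));
  // The imported HLO is then available as computation.proto().

The PjRt client changes below wire this helper into their Compile(mlir::ModuleOp, CompileOptions) overloads, forwarding parameter_is_tupled_arguments as use_tuple_args.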
tensorflow/compiler/xla/pjrt/mlir_to_hlo.h

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_PJRT_MLIR_TO_HLO_H_
+#define TENSORFLOW_COMPILER_XLA_PJRT_MLIR_TO_HLO_H_
+
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/compiler/xla/status.h"
+
+namespace xla {
+
+// Converts an CHLO/MHLO module to XLA HLO.
+Status MlirToXlaComputation(mlir::ModuleOp module,
+                            XlaComputation& xla_computation,
+                            bool use_tuple_args, bool return_tuple);
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_PJRT_MLIR_TO_HLO_H_

tensorflow/compiler/xla/pjrt/pjrt_client.h

Lines changed: 5 additions & 0 deletions
@@ -25,6 +25,7 @@ limitations under the License.
 #include "absl/synchronization/notification.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "tensorflow/compiler/xla/client/executable_build_options.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
@@ -269,6 +270,10 @@ class PjRtClient {
   virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options) = 0;
 
+  // Variant of `Compile` that accepts an MLIR module.
+  virtual StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      mlir::ModuleOp module, CompileOptions options) = 0;
+
   // Generates a unique fingerprint for `executable`, may be absl::nullopt.
   virtual StatusOr<absl::optional<std::string>> ExecutableFingerprint(
       const PjRtExecutable& executable) const = 0;

tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc

Lines changed: 11 additions & 0 deletions
@@ -90,6 +90,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/pjrt/event_pool.h"
 #include "tensorflow/compiler/xla/pjrt/local_device_state.h"
 #include "tensorflow/compiler/xla/pjrt/metrics.h"
+#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
 #include "tensorflow/compiler/xla/pjrt/tracked_device_buffer.h"
 #include "tensorflow/compiler/xla/pjrt/utils.h"
 #include "tensorflow/compiler/xla/service/computation_layout.h"
@@ -2411,4 +2412,14 @@ StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
   return std::unique_ptr<PjRtExecutable>(std::move(executable));
 }
 
+StatusOr<std::unique_ptr<PjRtExecutable>> PjRtStreamExecutorClient::Compile(
+    mlir::ModuleOp module, CompileOptions options) {
+  XlaComputation xla_computation;
+  TF_RETURN_IF_ERROR(MlirToXlaComputation(
+      module, xla_computation,
+      /*use_tuple_args=*/options.parameter_is_tupled_arguments,
+      /*return_tuple=*/false));
+  return Compile(xla_computation, options);
+}
+
 }  // namespace xla

tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h

Lines changed: 2 additions & 0 deletions
@@ -171,6 +171,8 @@ class PjRtStreamExecutorClient : public PjRtClient {
 
   StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options) override;
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      mlir::ModuleOp mlir_module, CompileOptions options) override;
 
   StatusOr<absl::optional<std::string>> ExecutableFingerprint(
       const PjRtExecutable& executable) const override {

tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc

Lines changed: 11 additions & 0 deletions
@@ -32,6 +32,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/layout.h"
 #include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/pjrt/mlir_to_hlo.h"
 #include "tensorflow/compiler/xla/pjrt/pjrt_client.h"
 #include "tensorflow/compiler/xla/pjrt/semaphore.h"
 #include "tensorflow/compiler/xla/pjrt/utils.h"
@@ -372,6 +373,16 @@ StatusOr<std::unique_ptr<PjRtExecutable>> TfrtCpuClient::Compile(
   return std::unique_ptr<PjRtExecutable>(std::move(executable));
 }
 
+StatusOr<std::unique_ptr<PjRtExecutable>> TfrtCpuClient::Compile(
+    mlir::ModuleOp module, CompileOptions options) {
+  XlaComputation xla_computation;
+  TF_RETURN_IF_ERROR(MlirToXlaComputation(
+      module, xla_computation,
+      /*use_tuple_args=*/options.parameter_is_tupled_arguments,
+      /*return_tuple=*/false));
+  return Compile(xla_computation, options);
+}
+
 StatusOr<std::unique_ptr<TfrtCpuBuffer>> AllocateDestinationBuffer(
     const Shape& on_device_shape,
     absl::InlinedVector<tfrt::AsyncValueRef<CpuEvent>, 4> definition_events,

tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.h

Lines changed: 2 additions & 0 deletions
@@ -132,6 +132,8 @@ class TfrtCpuClient final : public PjRtClient {
 
   StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
       const XlaComputation& computation, CompileOptions options) override;
+  StatusOr<std::unique_ptr<PjRtExecutable>> Compile(
+      mlir::ModuleOp module, CompileOptions options) override;
 
   StatusOr<absl::optional<std::string>> ExecutableFingerprint(
       const PjRtExecutable& executable) const override;
