-
Notifications
You must be signed in to change notification settings - Fork 15.5k
[CIR][X86]Implement handling for convert-half builtins #173143
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-clangir @llvm/pr-subscribers-clang Author: Priyanshu Kumar (Priyanshu3820) ChangesRelated to: #167765 Full diff: https://github.com/llvm/llvm-project/pull/173143.diff 2 Files Affected:
diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
index 75bf25b20f1af..59d467da3a9fb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenBuiltinX86.cpp
@@ -362,6 +362,27 @@ static mlir::Value emitX86Muldq(CIRGenBuilderTy &builder, mlir::Location loc,
return builder.createMul(loc, lhs, rhs);
}
+static mlir::Value emitX86CvtF16ToFloatExpr(CIRGenBuilderTy &builder,
+ mlir::Location loc,
+ mlir::Type dstTy,
+ SmallVectorImpl<mlir::Value> &ops) {
+
+ mlir::Value src = ops[0];
+ mlir::Value passthru = ops[1];
+
+ auto vecTy = mlir::cast<cir::VectorType>(src.getType());
+ uint64_t numElems = vecTy.getSize();
+
+ mlir::Value mask = getMaskVecValue(builder, loc, ops[2], numElems);
+
+ auto halfTy = cir::VectorType::get(builder.getF16Type(), numElems);
+ mlir::Value srcF16 = builder.createBitcast(loc, src, halfTy);
+
+ mlir::Value res = builder.createFloatingCast(srcF16, dstTy);
+
+ return emitX86Select(builder, loc, mask, res, passthru);
+}
+
static mlir::Value emitX86vpcom(CIRGenBuilderTy &builder, mlir::Location loc,
llvm::SmallVector<mlir::Value> ops,
bool isSigned) {
@@ -1662,12 +1683,40 @@ mlir::Value CIRGenFunction::emitX86BuiltinExpr(unsigned builtinID,
case X86::BI__builtin_ia32_cmpnltsd:
case X86::BI__builtin_ia32_cmpnlesd:
case X86::BI__builtin_ia32_cmpordsd:
+ cgm.errorNYI(expr->getSourceRange(),
+ std::string("unimplemented X86 builtin call: ") +
+ getContext().BuiltinInfo.getName(builtinID));
+ return mlir::Value{};
case X86::BI__builtin_ia32_vcvtph2ps_mask:
case X86::BI__builtin_ia32_vcvtph2ps256_mask:
- case X86::BI__builtin_ia32_vcvtph2ps512_mask:
- case X86::BI__builtin_ia32_cvtneps2bf16_128_mask:
+ case X86::BI__builtin_ia32_vcvtph2ps512_mask: {
+ mlir::Location loc = getLoc(expr->getExprLoc());
+ return emitX86CvtF16ToFloatExpr(builder, loc, convertType(expr->getType()),
+ ops);
+ }
+ case X86::BI__builtin_ia32_cvtneps2bf16_128_mask: {
+ mlir::Location loc = getLoc(expr->getExprLoc());
+ mlir::Value intrinsicMask = getMaskVecValue(builder, loc, ops[2], 4);
+ return emitIntrinsicCallOp(builder, loc,
+ "x86.avx512bf16.mask.cvtneps2bf16.128",
+ convertType(expr->getType()),
+ mlir::ValueRange{ops[0], ops[1], intrinsicMask});
+ }
case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
- case X86::BI__builtin_ia32_cvtneps2bf16_512_mask:
+ case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: {
+ mlir::Location loc = getLoc(expr->getExprLoc());
+ unsigned numElts = cast<cir::VectorType>(ops[1].getType()).getSize();
+ mlir::Value selectMask = getMaskVecValue(builder, loc, ops[2], numElts);
+ StringRef intrinsicName;
+ if (builtinID == X86::BI__builtin_ia32_cvtneps2bf16_256_mask)
+ intrinsicName = "x86.avx512bf16.cvtneps2bf16.256";
+ else
+ intrinsicName = "x86.avx512bf16.cvtneps2bf16.512";
+ mlir::Value intrinsicResult =
+ emitIntrinsicCallOp(builder, loc, intrinsicName, ops[1].getType(),
+ mlir::ValueRange{ops[0]});
+ return emitX86Select(builder, loc, selectMask, intrinsicResult, ops[1]);
+ }
case X86::BI__cpuid:
case X86::BI__cpuidex:
case X86::BI__emul:
diff --git a/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c
new file mode 100644
index 0000000000000..ccfc0d4a6a813
--- /dev/null
+++ b/clang/test/CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.c
@@ -0,0 +1,80 @@
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -fclangir -emit-cir -o %t.cir -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefix=CIR --input-file=%t.cir %s
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -fclangir -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefixes=LLVM --input-file=%t.ll %s
+// RUN: %clang_cc1 -flax-vector-conversions=none -ffreestanding %s -triple=x86_64-unknown-linux -target-feature +avx512f -target-feature +avx512vl -target-feature +avx512bf16 -emit-llvm -o %t.ll -Wall -Werror -Wsign-conversion
+// RUN: FileCheck --check-prefixes=OGCG --input-file=%t.ll %s
+
+#include <immintrin.h>
+
+__m256bh test_mm512_mask_cvtneps_pbh(__m256bh src, __mmask16 k, __m512 a) {
+ // CIR-LABEL: @test_mm512_mask_cvtneps_pbh
+ // CIR: cir.call @_mm512_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<16 x !cir.bf16>, !u16i, !cir.vector<16 x !cir.float>) -> !cir.vector<16 x !cir.bf16>
+
+ // LLVM-LABEL: @test_mm512_mask_cvtneps_pbh
+ // LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+
+ // OGCG-LABEL: @test_mm512_mask_cvtneps_pbh
+ // OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512
+ return _mm512_mask_cvtneps_pbh(src, k, a);
+}
+
+__m256bh test_mm512_maskz_cvtneps_pbh(__mmask16 k, __m512 a) {
+ // CIR-LABEL: @test_mm512_maskz_cvtneps_pbh
+ // CIR: cir.call @_mm512_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u16i, !cir.vector<16 x !cir.float>) -> !cir.vector<16 x !cir.bf16>
+
+ // LLVM-LABEL: @test_mm512_maskz_cvtneps_pbh
+ // LLVM: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> {{.+}})
+
+ // OGCG-LABEL: @test_mm512_maskz_cvtneps_pbh
+ // OGCG: call <16 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> {{.+}})
+ return _mm512_maskz_cvtneps_pbh(k, a);
+}
+
+__m128bh test_mm256_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m256 a) {
+ // CIR-LABEL: test_mm256_mask_cvtneps_pbh
+ // CIR: cir.call @_mm256_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<8 x !cir.bf16>, !u8i, !cir.vector<8 x !cir.float>) -> !cir.vector<8 x !cir.bf16>
+
+ // LLVM-LABEL: test_mm256_mask_cvtneps_pbh
+ // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}})
+
+ // OGCG-LABEL: test_mm256_mask_cvtneps_pbh
+ // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}})
+ return _mm256_mask_cvtneps_pbh(src, k, a);
+}
+
+__m128bh test_mm256_maskz_cvtneps_pbh(__mmask8 k, __m256 a) {
+ // CIR-LABEL: test_mm256_maskz_cvtneps_pbh
+ // CIR: cir.call @_mm256_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u8i, !cir.vector<8 x !cir.float>) -> !cir.vector<8 x !cir.bf16>
+
+ // LLVM-LABEL: test_mm256_maskz_cvtneps_pbh
+ // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}})
+
+ // OGCG-LABEL: test_mm256_maskz_cvtneps_pbh
+ // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.cvtneps2bf16.256(<8 x float> {{.+}})
+ return _mm256_maskz_cvtneps_pbh(k, a);
+}
+
+__m128bh test_mm_mask_cvtneps_pbh(__m128bh src, __mmask8 k, __m128 a) {
+ // CIR-LABEL: test_mm_mask_cvtneps_pbh
+ // CIR: cir.call @_mm_mask_cvtneps_pbh({{.+}}, {{.+}}, {{.+}}) : (!cir.vector<8 x !cir.bf16>, !u8i, !cir.vector<4 x !cir.float>) -> !cir.vector<8 x !cir.bf1{{.+}}
+
+ // LLVM-LABEL: test_mm_mask_cvtneps_pbh
+ // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> %extract.i)
+
+ // OGCG-LABEL: test_mm_mask_cvtneps_pbh
+ // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> %extract.i)
+ return _mm_mask_cvtneps_pbh(src, k, a);
+}
+
+__m128bh test_mm_maskz_cvtneps_pbh(__mmask8 k, __m128 a) {
+ // CIR-LABEL: test_mm_maskz_cvtneps_pbh
+ // CIR: cir.call @_mm_maskz_cvtneps_pbh({{.+}}, {{.+}}) : (!u8i, !cir.vector<4 x !cir.float>) -> !cir.vector<8 x !cir.bf16>
+
+ // LLVM-LABEL: test_mm_maskz_cvtneps_pbh
+ // LLVM: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> %extract.i)
+
+ // OGCG-LABEL: test_mm_maskz_cvtneps_pbh
+ // OGCG: call <8 x bfloat> @llvm.x86.avx512bf16.mask.cvtneps2bf16.128(<4 x float> {{.+}}, <8 x bfloat> {{.+}}, <4 x i1> %extract.i)
+ return _mm_maskz_cvtneps_pbh(k, a);
+}
|
🐧 Linux x64 Test Results
Failed Tests(click on a test name to see its output) ClangClang.CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.cClang.CIR/CodeGenBuiltins/X86/avx512vlbf16-builtins.cIf these failures are unrelated to your changes (for example tests are broken or flaky at HEAD), please open an issue at https://github.com/llvm/llvm-project/issues and add the |
Related to: #167765