HW-SW co-design in the RISC-V Ecosystem [Part 2]: MLIR to LLVM
09 Apr 2024 #compilation #llvm #mlir
In Part 1, we explored the overarching concept of hardware-software co-design. Now, in Part 2, we walk through the implementation of an MLIR pass. Passes are transformations or analyses applied to MLIR code during compilation; they can optimize, analyze, or otherwise manipulate the IR, including dialect-to-dialect conversions. For further insights, refer to the documentation available here.
In this particular pass, our objective is to convert a single operation (arith.mulf) to an equivalent LLVM intrinsic when certain conditions hold. An LLVM intrinsic can be regarded as a specialized function that the compiler knows about. At a high level, the pass replaces instances of
arith.mulf {approx = "exp"}
with an LLVM call:
llvm.call @llvm.riscv.floatexp.mul(%arg0, %arg1) : (f32, f32) -> f32
If the approx attribute is not set to exp, or if the operands and result of arith.mulf are not f32, the operation is left unchanged.
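The two conditions can be sketched in plain C++ without any MLIR dependency. This is only an illustration of the guard logic: ToyMulF and shouldLowerToIntrinsic are hypothetical names, not part of the patch, and types are modeled as plain strings.

```cpp
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for an arith.mulf operation; this is not an MLIR type.
struct ToyMulF {
  std::optional<std::string> approx;     // value of the "approx" attribute, if set
  std::vector<std::string> operandTypes; // e.g. {"f32", "f32"}
  std::string resultType;                // e.g. "f32"
};

// Returns true only when the op should be rewritten into the intrinsic call:
// the approx attribute must be "exp" and every operand/result type must be f32.
bool shouldLowerToIntrinsic(const ToyMulF &op) {
  if (!op.approx || *op.approx != "exp")
    return false;
  for (const std::string &ty : op.operandTypes)
    if (ty != "f32")
      return false;
  return op.resultType == "f32";
}
```

In this sketch, an op with approx set to "exp" and f32 types everywhere passes the check, while a missing attribute, a different attribute value, or any non-f32 type leaves the op alone.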
The overall implementation of the pass is in this patch. We go over each part of the pass implementation below.
mlir/include/mlir/Conversion/Passes.h
#include "mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h"
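This include makes the declarations of the new pass visible wherever the conversion passes are registered, so the pass becomes part of the set of passes that mlir-opt supports.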
mlir/include/mlir/Conversion/Passes.td
def ConvertArithToRISCVNNPass : Pass<"convert-arith-to-riscvnn"> {
  let summary = "Convert arith dialect operations to LLVM RISCV intrinsics for NN";
  let dependentDialects = ["LLVM::LLVMDialect", "arith::ArithDialect"];
  let constructor = "mlir::createConvertArithToRISCVNN()";
}
This snippet is a TableGen record that declares the high-level details of the pass. The pass is registered under the name convert-arith-to-riscvnn, and it depends on two dialects: LLVM::LLVMDialect and arith::ArithDialect. The summary field is the one-line description that mlir-opt --help prints for the pass, and the constructor field names the C++ function that creates it.
mlir/include/mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h
//===- ArithToRISCVNN.h - Arith to LLVM dialect conversion -----------*- C++ -*-===//
#ifndef MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H
#define MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H
#include "mlir/Pass/Pass.h" // from @llvm-project
// Extra includes needed for dependent dialects
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
namespace mlir {
class ModuleOp;
#define GEN_PASS_DECL_CONVERTARITHTORISCVNNPASS
#include "mlir/Conversion/Passes.h.inc"
std::unique_ptr<OperationPass<>> createConvertArithToRISCVNN();
} // namespace mlir
#endif // MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H
The pass itself is implemented in ArithToRISCVNNPass.cpp. It converts certain operations (e.g. arith.mulf {approx = "exp"}) into calls to the intrinsic introduced above.
//===- ArithToRISCVNNPass.cpp - Arith to LLVM Pass ------------------------===//
#include "mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h"
// #include <iostream>
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "llvm/Support/Casting.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <memory>
#include <utility>
namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTORISCVNNPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
using namespace mlir;
// ************** Patterns **********
static bool isSupportedSourceType(Type originalType) {
  // Only f32 is supported.
  // https://github.com/llvm/llvm-project/blob/c5f839bd58e7f888acc4cb39a18e9e5bbaa9fb0a/mlir/lib/IR/Types.cpp#L123
  return originalType.isF32();
}

// Fail the match unless every operand and result type of the op is supported.
static LogicalResult checkSourceOpTypes(PatternRewriter &rewriter,
                                        Operation *sourceOp) {
  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
  llvm::append_range(allTypes, sourceOp->getResultTypes());
  for (Type ty : allTypes) {
    if (!isSupportedSourceType(ty)) {
      return rewriter.notifyMatchFailure(
          sourceOp,
          llvm::formatv(
              "unsupported source type for Arith to LLVM conversion: {0}",
              ty));
    }
  }
  return success();
}
namespace {
// Lower arith.mulf {approx = "exp"} to a call to the
// llvm.riscv.floatexp.mul intrinsic.
struct ApproxPattern : public OpRewritePattern<arith::MulFOp> {
  ApproxPattern(MLIRContext *context)
      : OpRewritePattern<arith::MulFOp>(context) {}

  // Match arith.mulf ops that carry approx = "exp" and rewrite them.
  LogicalResult matchAndRewrite(arith::MulFOp op,
                                PatternRewriter &rewriter) const override {
    // Only ops whose "approx" attribute is set to "exp" are handled.
    StringAttr approxAttr = op->getAttrOfType<StringAttr>("approx");
    // TODO: Add other patterns here for other attribute values!
    if (!approxAttr || approxAttr.getValue() != "exp")
      return failure();
    // All operand and result types must be f32.
    if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
      return res;

    // Replace the arith.mulf operation with a call to the intrinsic.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    // TODO: Choose the function name based on the attribute value
    StringRef fnName = "llvm.riscv.floatexp.mul";
    MLIRContext *context = parentModule->getContext();
    LLVM::LLVMFunctionType llvmFnType = getFnType(context);

    // Get a symbol reference to the intrinsic, declaring it if necessary.
    FlatSymbolRefAttr intrinsicRef =
        getLLVMFuncRef(rewriter, parentModule, fnName);

    // Pass the operands of arith.mulf through as the call arguments.
    auto operands = op.getOperands();
    SmallVector<Value, 4> args;
    args.reserve(operands.size());
    for (Value operand : operands)
      args.push_back(operand);

    auto newOp = rewriter.create<LLVM::CallOp>(op.getLoc(), llvmFnType,
                                               intrinsicRef, args);
    // Replace the original operation with the newly created intrinsic call.
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }

  // Build the function type (f32, f32) -> f32 of the intrinsic.
  static LLVM::LLVMFunctionType getFnType(MLIRContext *context) {
    Type f32Ty = Float32Type::get(context); // builtin 'mlir::Float32Type'
    return LLVM::LLVMFunctionType::get(f32Ty,          // return type
                                       {f32Ty, f32Ty}, // parameter types
                                       /*isVarArg=*/false);
  }

  /// Return a symbol reference to the intrinsic function, inserting its
  /// declaration into the module if necessary.
  static FlatSymbolRefAttr getLLVMFuncRef(PatternRewriter &rewriter,
                                          ModuleOp module,
                                          StringRef funcName) {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
      return SymbolRefAttr::get(context, funcName);

    // Insert the intrinsic declaration at the start of the parent module.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName,
                                      getFnType(context));
    return SymbolRefAttr::get(context, funcName);
  }
};
} // namespace
// ************** Patterns **********
namespace {
/// A pass converting annotated arith operations into calls to RISC-V LLVM
/// intrinsics.
class ConvertArithToRISCVNNPass
    : public impl::ConvertArithToRISCVNNPassBase<ConvertArithToRISCVNNPass> {
  void runOnOperation() override;
};
} // namespace

void ConvertArithToRISCVNNPass::runOnOperation() {
  MLIRContext *context = &getContext();
  LLVMConversionTarget target(*context);
  RewritePatternSet patterns(context);
  patterns.insert<ApproxPattern>(context);
  if (failed(applyPartialConversion(getOperation(), target,
                                    std::move(patterns))))
    return signalPassFailure();
}

std::unique_ptr<OperationPass<>> mlir::createConvertArithToRISCVNN() {
  return std::make_unique<ConvertArithToRISCVNNPass>();
}
This patch demonstrates the overall pattern matching and the corresponding lowering process. Specifically, it lowers one particular case (arith.mulf {approx = "exp"}) into an LLVM intrinsic call (llvm.riscv.floatexp.mul). Certain sections of the code are marked with TODO comments: lowering to other LLVM intrinsics could be added for other values of the approx attribute.
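To make the flow of the rewrite concrete, here is a small self-contained C++ sketch of the same three steps: pick an intrinsic from the approx value, declare it in the module only once, and emit the call. It deliberately avoids MLIR; ToyModule, intrinsicFor, getOrDeclare, and lowerMulF are hypothetical names used only for illustration, and the only intrinsic name taken from the post is llvm.riscv.floatexp.mul.

```cpp
#include <map>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a module: declared symbols plus emitted ops as text.
struct ToyModule {
  std::map<std::string, std::string> declarations; // symbol name -> signature
  std::vector<std::string> body;                   // emitted call ops
};

// Maps an approx value to an intrinsic name. Only "exp" comes from the post;
// further entries would be added here once other intrinsics exist.
std::optional<std::string> intrinsicFor(const std::string &approx) {
  static const std::map<std::string, std::string> table = {
      {"exp", "llvm.riscv.floatexp.mul"},
  };
  auto it = table.find(approx);
  if (it == table.end())
    return std::nullopt;
  return it->second;
}

// Declares the symbol on first use only, which is the role getLLVMFuncRef
// plays in the pass.
std::string getOrDeclare(ToyModule &module, const std::string &name,
                         const std::string &signature) {
  module.declarations.emplace(name, signature); // no-op if already declared
  return name;
}

// Emits the call text, or returns nullopt to leave the op untouched.
std::optional<std::string> lowerMulF(ToyModule &module,
                                     const std::string &approx,
                                     const std::string &lhs,
                                     const std::string &rhs) {
  std::optional<std::string> name = intrinsicFor(approx);
  if (!name)
    return std::nullopt;
  const std::string signature = "(f32, f32) -> f32";
  const std::string callee = getOrDeclare(module, *name, signature);
  std::string call =
      "llvm.call @" + callee + "(" + lhs + ", " + rhs + ") : " + signature;
  module.body.push_back(call);
  return call;
}
```

Keeping the attribute-to-intrinsic mapping in one table is one possible way to address the TODO about choosing the function name from the attribute value; the actual pass may of course resolve it differently.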
One of the key challenges in writing passes effectively is understanding the different template structures used to select patterns and how new operations are specified. Fortunately, there are existing passes, such as ArithToLLVM and SPIRVToLLVM, which can serve as valuable examples for study.
mlir/lib/Conversion/ArithToRISCVNN/CMakeLists.txt
add_mlir_conversion_library(MLIRArithToRISCVNN
  ArithToRISCVNNPass.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVM
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR

  DEPENDS
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRArithDialect
  MLIRMathDialect
  MLIRLLVMDialect
  MLIRPass
  MLIRSupport
  MLIRTransformUtils
  )
The CMakeLists.txt file included in this repository provides support for compiling the new pass, along with the necessary dependencies.
To test the pass, build LLVM and then follow the instructions in the README of the code repository. In the pipeline below, three passes run on each func.func to lower its operations into the LLVM dialect: convert-arith-to-riscvnn, convert-arith-to-llvm, and convert-math-to-llvm. At the module level, convert-func-to-llvm then converts MLIR functions (func.func) into LLVM dialect functions (llvm.func), and convert-vector-to-llvm lowers any vector operations.
mlir-opt --help | grep riscvnn
249: --convert-arith-to-riscvnn - Convert math dialect operations
to LLVM RISCV intrinsics for NN
mlir-opt \
-pass-pipeline="builtin.module(func.func(convert-arith-to-riscvnn,convert-arith-to-llvm,convert-math-to-llvm),convert-func-to-llvm,convert-vector-to-llvm)" \
benchmark.mlir > benchmark_llvm.mlir
Lowering from MLIR.LLVM to LLVM IR
The code in the LLVM dialect of MLIR can be translated directly into LLVM IR using the mlir-translate tool. For the example considered here, no changes to the tool are necessary.
mlir-translate -mlir-to-llvmir -split-input-file \
-verify-diagnostics benchmark_llvm.mlir > benchmark_llvm.ll
Conclusion
In this post, we have explored the intricacies of hardware-software co-design, focusing on the implementation of an MLIR pass to optimize and transform code within the MLIR framework. The journey began with an overview of the pass’s purpose and dependencies, followed by a detailed examination of the pattern matching and lowering process involved.
Moving forward, readers are encouraged to build LLVM and test the pass using the instructions provided in the code repository.
In the upcoming blog post, we will delve into the process of adding support for new intrinsics and custom instructions in LLVM, specifically targeting the RISC-V architecture. Stay tuned for a deeper dive into the intricacies of integrating new instructions for the RISC-V target.
References
- Github Code Repository
- MLIR Pass Manager
- Pattern Rewriting in MLIR
- An example of MLIR Pass implementation