HW-SW co-design in the RISC-V Ecosystem [Part 2]: MLIR to LLVM

#compilation #llvm #mlir

In Part 1, we explored the overarching concept of hardware-software co-design. In Part 2, we walk through the implementation of an MLIR pass. Passes are transformations or analyses applied to MLIR code during compilation; they optimize, analyze, or otherwise manipulate the IR, and they cover both IR analysis and dialect-to-dialect conversion. For more detail, see the MLIR documentation on passes.

In this pass, our objective is to convert a single operation (arith.mulf) into an equivalent LLVM intrinsic call when certain conditions hold. An LLVM intrinsic can be regarded as a specialized function whose name and semantics are known to the compiler. At a high level, the pass replaces

  • arith.mulf {approx = "exp"} with
  • llvm.call @llvm.riscv.floatexp.mul(%arg0, %arg1) : (f32, f32) -> f32.

It’s important to note that if the approx attribute is not set to exp, or if the inputs of arith.mulf are not f32, no action should be taken.
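
The two conditions can be modeled in isolation. The following is a minimal standalone C++ sketch, not code from the patch and independent of the MLIR API: MulFModel, ScalarType, and shouldLowerToIntrinsic are hypothetical names introduced only to illustrate the rule the pass applies.

```cpp
#include <map>
#include <string>
#include <vector>

// Toy stand-in for an MLIR scalar type (hypothetical, not an MLIR class).
enum class ScalarType { F16, F32, F64 };

// Toy stand-in for an arith.mulf operation (hypothetical, not an MLIR class).
struct MulFModel {
  std::map<std::string, std::string> attrs; // attribute name -> string value
  std::vector<ScalarType> operandTypes;
  ScalarType resultType;
};

// Returns true only when the op carries approx = "exp" and every operand and
// result type is f32; in all other cases the op must be left untouched.
bool shouldLowerToIntrinsic(const MulFModel &op) {
  auto it = op.attrs.find("approx");
  if (it == op.attrs.end() || it->second != "exp")
    return false;
  for (ScalarType ty : op.operandTypes)
    if (ty != ScalarType::F32)
      return false;
  return op.resultType == ScalarType::F32;
}
```

In the actual pass, a false result here corresponds to the pattern reporting a match failure, which leaves the original arith.mulf in place for later passes.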

The overall implementation of the pass is in this patch. We go over each part of the pass implementation below.

mlir/include/mlir/Conversion/Passes.h

    #include "mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h"

This includes the new pass in the set of all passes that mlir-opt supports.

mlir/include/mlir/Conversion/Passes.td

def ConvertArithToRISCVNNPass : Pass<"convert-arith-to-riscvnn"> {
  let summary = "Convert arith dialect operations to LLVM RISCV intrinsics for NN";
  let dependentDialects = ["LLVM::LLVMDialect", "arith::ArithDialect"];
  let constructor = "mlir::createConvertArithToRISCVNN()";
}

This TableGen snippet defines a record that describes the pass at a high level. The pass is named convert-arith-to-riscvnn, which is also the flag that enables it in mlir-opt. The record declares that the pass depends on two dialects, LLVM::LLVMDialect and arith::ArithDialect, names the constructor function that creates the pass, and provides the summary that mlir-opt prints in its help output.

mlir/include/mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h

//===- ArithToRISCVNN.h - Arith to LLVM dialect conversion -----------*- C++ -*-===//

#ifndef MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H
#define MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H

#include "mlir/Pass/Pass.h"  // from @llvm-project

// Extra includes needed for dependent dialects
#include "mlir/Dialect/Arith/IR/Arith.h"   // from @llvm-project
#include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
namespace mlir {
class ModuleOp;

#define GEN_PASS_DECL_CONVERTARITHTORISCVNNPASS
#include "mlir/Conversion/Passes.h.inc"

std::unique_ptr<OperationPass<>> createConvertArithToRISCVNN();

}
#endif // MLIR_CONVERSION_ARITHTORISCVNN_ARITHTORISCVNN_H

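This header declares the pass and its factory function, createConvertArithToRISCVNN(). The pass itself, implemented in the file below, converts certain operations (e.g. arith.mulf {approx="exp"}) into the LLVM intrinsic call described above.

mlir/lib/Conversion/ArithToRISCVNN/ArithToRISCVNNPass.cpp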

//===- ArithToRISCVNNPass.cpp - Arith to LLVM Pass ------------------------===//

#include "mlir/Conversion/ArithToRISCVNN/ArithToRISCVNN.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/TypeID.h"

#include "mlir/Target/LLVMIR/ModuleTranslation.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/Support/Casting.h"

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"

#include <memory>
#include <utility>


namespace mlir {
#define GEN_PASS_DEF_CONVERTARITHTORISCVNNPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

// ************** Patterns **********
// Only f32 operands and results are supported by the intrinsic.
static bool isSupportedSourceType(Type originalType) {
  // Type::isF32() is defined in
  // https://github.com/llvm/llvm-project/blob/c5f839bd58e7f888acc4cb39a18e9e5bbaa9fb0a/mlir/lib/IR/Types.cpp#L123
  return originalType.isF32();
}

static LogicalResult checkSourceOpTypes(PatternRewriter &rewriter,
                                        Operation *sourceOp) {
  auto allTypes = llvm::to_vector(sourceOp->getOperandTypes());
  llvm::append_range(allTypes, sourceOp->getResultTypes());

  for (Type ty : allTypes) {
    if (!isSupportedSourceType(ty)) {
      return rewriter.notifyMatchFailure(
          sourceOp,
          llvm::formatv(
              "unsupported source type for Arith to LLVM conversion: {0}",
              ty));
    }
  }
  return success();
}

namespace {
// lower arith.mulf{approx='exp'}
// to llvm intrinsic

struct ApproxPattern : public OpRewritePattern<arith::MulFOp> {
  ApproxPattern(MLIRContext *context) : OpRewritePattern<arith::MulFOp>(context) {}

  // Match arith.mulf ops that carry approx = "exp" and operate on f32, then
  // rewrite them into a call to the LLVM intrinsic.
  LogicalResult matchAndRewrite(arith::MulFOp op, PatternRewriter &rewriter) const override {
    // Check if the operation has the "approx" attribute.
    StringAttr approxAttr = op->getAttrOfType<StringAttr>("approx");
    // TODO: Add other patterns here for other attribute values!
    if (!approxAttr || approxAttr.getValue() != "exp")
      return failure();

    if (LogicalResult res = checkSourceOpTypes(rewriter, op); failed(res))
      return res;


    // Replace the arith.mulf operation with a call to the
    // llvm.riscv.floatexp.mul intrinsic.
    ModuleOp parentModule = op->getParentOfType<ModuleOp>();
    // TODO: Choose the function name based on the attribute value
    auto fnName = "llvm.riscv.floatexp.mul";
    auto context = parentModule->getContext();

    // The intrinsic takes two f32 values and returns an f32.
    auto llvmFnType = getFnType(context);

    // Get a symbol reference to the intrinsic, declaring it in the module if
    // necessary.
    auto intrinsicRef = getLLVMFuncRef(rewriter, parentModule, fnName);

    // Pass the operands of arith.mulf through unchanged as call arguments.
    auto operands = op.getOperands();
    SmallVector<Value, 4> args;
    args.reserve(operands.size());
    for (auto operand : operands) {
        args.push_back(operand);
    }
    auto newOp = rewriter.create<LLVM::CallOp>(
        op.getLoc(), llvmFnType, intrinsicRef, args);

    // Replace the original operation with the newly created LLVM intrinsic call.
    rewriter.replaceOp(op, newOp->getResult(0));
    return success();
  }

  static LLVM::LLVMFunctionType getFnType(MLIRContext *context) {
    auto llvmF32Ty = Float32Type::get(context); // 'mlir::Float32Type'

    auto llvmFnType =  LLVM::LLVMFunctionType::get(
      llvmF32Ty, // return type.
      {llvmF32Ty, llvmF32Ty}, // parameter type.
      false);
    return llvmFnType;
  }
  /// Return a symbol reference to the intrinsic function, inserting its
  /// declaration into the module if necessary.
  static FlatSymbolRefAttr getLLVMFuncRef(PatternRewriter &rewriter,
                                             ModuleOp module,
                                             std::string funcName) {
    auto *context = module.getContext();
    if (module.lookupSymbol<LLVM::LLVMFuncOp>(funcName))
      return SymbolRefAttr::get(context, funcName);

    // Insert the intrinsic declaration into the body of the parent module.
    PatternRewriter::InsertionGuard insertGuard(rewriter);
    rewriter.setInsertionPointToStart(module.getBody());
    rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), funcName,
                                      getFnType(context));
    return SymbolRefAttr::get(context, funcName);
  }
};
}

// ************** Patterns **********

namespace {
/// A pass converting annotated arith.mulf operations into calls to LLVM
/// RISC-V intrinsics.
class ConvertArithToRISCVNNPass
    : public impl::ConvertArithToRISCVNNPassBase<ConvertArithToRISCVNNPass>  {

  void runOnOperation() override;
};
} // namespace

void ConvertArithToRISCVNNPass::runOnOperation() {
  MLIRContext *context = &getContext();
  LLVMConversionTarget target(*context);

  RewritePatternSet patterns(context);
  patterns.insert<ApproxPattern>(context);

  if (failed(applyPartialConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}

std::unique_ptr<OperationPass<>> mlir::createConvertArithToRISCVNN() {
  return std::make_unique<ConvertArithToRISCVNNPass>();
}

This patch demonstrates the overall pattern matching and the corresponding lowering process. Specifically, it lowers one particular case (arith.mulf {approx="exp"}) into an LLVM intrinsic call (llvm.riscv.floatexp.mul). The TODO comments mark where lowerings to other LLVM intrinsics could be added for other values of the approx attribute. One of the key challenges in writing passes effectively is understanding the template structures used to select patterns and how new operations are specified. Fortunately, existing passes such as ArithToLLVM and SPIRVToLLVM serve as valuable examples for study.
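
The first TODO (choosing the function name from the attribute value) could be handled with a small lookup table. The sketch below is standalone C++ under stated assumptions, not part of the patch: ApproxLowering, kApproxLowerings, and intrinsicForApprox are hypothetical names, and only the "exp" row comes from the source.

```cpp
#include <string>

// One row per supported value of the approx attribute (hypothetical helper).
struct ApproxLowering {
  const char *approxValue;   // value of the approx attribute
  const char *intrinsicName; // LLVM intrinsic to call instead of arith.mulf
};

// Only the "exp" row is taken from the patch. Further rows would be added
// here as new intrinsics are defined.
static const ApproxLowering kApproxLowerings[] = {
    {"exp", "llvm.riscv.floatexp.mul"},
};

// Returns the intrinsic name for an approx value, or an empty string when
// the value has no registered lowering.
std::string intrinsicForApprox(const std::string &approx) {
  for (const ApproxLowering &entry : kApproxLowerings)
    if (approx == entry.approxValue)
      return entry.intrinsicName;
  return std::string();
}
```

Inside the pass, an empty result would correspond to returning failure() from matchAndRewrite, so operations with unsupported approx values stay untouched.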

mlir/lib/Conversion/ArithToRISCVNN/CMakeLists.txt

add_mlir_conversion_library(MLIRArithToRISCVNN
  ArithToRISCVNNPass.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Arith
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVM
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/IR

  DEPENDS
  MLIRConversionPassIncGen

  LINK_LIBS PUBLIC
  MLIRIR
  MLIRArithDialect
  MLIRMathDialect
  MLIRLLVMDialect
  MLIRPass
  MLIRSupport
  MLIRTransformUtils
  )

The CMakeLists.txt file included in this repository provides support for compiling the new pass, along with the necessary dependencies.

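To test the pass, build LLVM and then follow the instructions in the README of the code repository. In the pipeline below, three passes run on each func.func: convert-arith-to-riscvnn, convert-arith-to-llvm, and convert-math-to-llvm. The new pass runs first so that the annotated arith.mulf is rewritten before convert-arith-to-llvm lowers the remaining arith operations. The convert-func-to-llvm pass then converts MLIR functions (func.func) into LLVM Dialect functions (llvm.func), and convert-vector-to-llvm lowers any vector operations.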

mlir-opt --help | grep -n riscvnn
249:      --convert-arith-to-riscvnn   -   Convert math dialect operations to LLVM RISCV intrinsics for NN

mlir-opt \
  -pass-pipeline="builtin.module(func.func(convert-arith-to-riscvnn,convert-arith-to-llvm,convert-math-to-llvm),convert-func-to-llvm,convert-vector-to-llvm)" \
  benchmark.mlir > benchmark_llvm.mlir

Lowering from the MLIR LLVM Dialect to LLVM IR

Code in the LLVM Dialect of MLIR can be translated directly into LLVM IR using the mlir-translate tool. For our example, no changes to the tool are necessary.

mlir-translate -mlir-to-llvmir -split-input-file \
  -verify-diagnostics benchmark_llvm.mlir > benchmark_llvm.ll

Conclusion

In this post, we implemented an MLIR pass that rewrites arith.mulf {approx="exp"} into a call to an LLVM intrinsic. We began with the definition of the pass and its dependencies, then examined the pattern matching and lowering process in detail, and finally translated the result into LLVM IR.

Moving forward, readers are encouraged to build LLVM and test the pass using the instructions provided in the code repository.

In the next post, we will add support for new intrinsics and custom instructions in LLVM, specifically targeting the RISC-V architecture.

Follow @debjyoti0891