LLVM API(C++)でBrankfuck->LLVM IRコンパイラを作ってみた

前回の記事LLVMの入門として、Brankfuck->LLVM IRコンパイラを作ったことについて書きました。 前回は生のLLVM IRを扱っていましたが、今回はLLVM APIを利用してBrankfuck->LLVM IRコンパイラを作ってみました。

開発環境

$ llvm-config --version
10.0.0

$ clang++ --version
clang version 10.0.0-4ubuntu1
Target: x86_64-pc-linux-gnu
Thread model: posix
InstalledDir: /usr/bin

main関数とcharを格納するメモリ領域の確保

今回作成するのは、標準入力からBrainFuckソースコードを受け取り、標準出力にそのBrainFuckソースコードに対応するLLVM IRを出力するC++プログラムです。

作業は以下の流れで進めていきます。
1. 出力したいLLVM IRを把握する(LLVM IRは前回の記事で扱ったもの使います)
2. LLVM: LLVMで出力したいLLVM IR命令に相当するLLVM APIを探す
3. LLVM APIを利用するC++プログラムを書く

1. 出力したいLLVM IRを把握する

まず、main関数の作成とcharを格納するメモリ領域確保の処理についてです。

define i32 @main() {
  %1 = alloca [30000 x i8]
  %2 = alloca i8*
  %3 = bitcast [30000 x i8]* %1 to i8*
  call void @llvm.memset.p0i8.i64(i8* %3, i8 0, i64 30000, i1 false)
  %4 = getelementptr inbounds [30000 x i8], [30000 x i8]* %1, i64 0, i64 0
  store i8* %4, i8** %2
  ret i32 0
}

2. LLVM: LLVMで出力したいLLVM IR命令に相当するLLVM APIを探す

define i32 @main() {}の部分は、LLVM: llvm::Function ClassのCreateメソッドを使います。
alloca, bitcast, memset, getelementptr inbounds, store, ret命令は、LLVM: llvm::IRBuilderBase Classにそれぞれ対応するメソッドがあるのでそれを使います。

3. C++プログラムを書く

llvm::LLVMContext context;
llvm::Module module("top", context);
llvm::IRBuilder<> builder(context);

int main()
{
  llvm::Function *mainFunc = llvm::Function::Create(
      llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false),
      llvm::Function::ExternalLinkage, "main", module);

  llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc);
  builder.SetInsertPoint(basicBlock);

  llvm::Value *array = builder.CreateAlloca(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      nullptr);

  llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr);

  llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy());
  builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign());

  llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0));
  const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero};
  llvm::Value *element = builder.CreateInBoundsGEP(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      array,
      llvm::ArrayRef<llvm::Value *>(indexes));

  builder.CreateStore(element, elementPtr, false);

  builder.CreateRet(builder.getInt32(0));

  module.print(llvm::outs(), nullptr);

  return 0;
}

「>」「<」命令の実装

「>」「<」命令は、ポインタのインクリメント・デクリメントです。
LLVM IRは以下のようになっています。

%6 = load i8*, i8** %3, align 8
%7 = getelementptr inbounds i8, i8* %6, i32 1
store i8* %7, i8** %3, align 8

%8 = load i8*, i8** %3, align 8
%9 = getelementptr inbounds i8, i8* %8, i32 -1
store i8* %9, i8** %3, align 8

インクリメントとデクリメントは、getelementptrの第3引数が違うだけなので一つの関数にまとめました。

void emit_add_ptr(llvm::Value *ptr, int diff)
{
  builder.CreateStore(
      builder.CreateInBoundsGEP(
          builder.getInt8Ty(),
          builder.CreateLoad(builder.getInt8PtrTy(), ptr),
          builder.getInt32(diff)),
      ptr);
}

「+」「-」命令の実装

「+」「-」命令は、値のインクリメント・デクリメントです。
LLVM IRは以下のようになっています。

%10 = load i8*, i8** %3, align 8
%11 = load i8, i8* %10, align 1
%12 = add i8 %11, 1
store i8 %12, i8* %10, align 1

%13 = load i8*, i8** %3, align 8
%14 = load i8, i8* %13, align 1
%15 = add i8 %14, -1
store i8 %15, i8* %13, align 1

インクリメントとデクリメントは、addの第2引数が違うだけなので一つの関数にまとめました。

void emit_add(llvm::Value *ptr, int diff)
{
  llvm::Value *tmpPtr = builder.CreateLoad(builder.getInt8PtrTy(), ptr);
  builder.CreateStore(
      builder.CreateAdd(
          builder.CreateLoad(builder.getInt8Ty(), tmpPtr),
          builder.getInt8(diff)),
      tmpPtr);
}

「[」「]」命令の実装

「[」命令は、ポインタの指す値が0なら、後の]までジャンプ、
「]」命令は、ポインタの指す値が0でなければ、前の[までジャンプ、を示します。
LLVM IRは以下のようになります。

define i32 @main() {
; -----省略-----
  %15 = load i8*, i8** %3, align 8
  %16 = load i8, i8* %15, align 1
  %17 = add i8 %16, 1
  store i8 %17, i8* %15, align 1
  br label %18

18:                                               ; preds = %23, %0
  %19 = load i8*, i8** %3, align 8
  %20 = load i8, i8* %19, align 1
  %21 = zext i8 %20 to i32
  %22 = icmp eq i32 %21, 0
  br i1 %22, label %23, label %27

23:                                               ; preds = %18
  %24 = load i8*, i8** %3, align 8
  %25 = load i8, i8* %24, align 1
  %26 = add i8 %25, -1
  store i8 %26, i8* %24, align 1
  br label %18

27:                                               ; preds = %18
  %28 = load i8*, i8** %3, align 8
  %29 = load i8, i8* %28, align 1
  %30 = add i8 %29, 1
  store i8 %30, i8* %28, align 1
; -----省略-----
  ret i32 0
}

前回の記事では、「[」「]」命令のLLVM IRを出力するコードは以下のようになりました。

void emit_while_start(int while_index)
{
  std::cout << "  br label %while_start" << while_index << std::endl;
  std::cout << "while_start" << while_index << ":" << std::endl;
  std::cout << "  %" << index << " = load i8*, i8** %2, align 8" << std::endl;
  std::cout << "  %" << index + 1 << " = load i8, i8* %" << index << ", align 1" << std::endl;
  std::cout << "  %" << index + 2 << " = icmp ne i8 %" << index + 1 << ", 0" << std::endl;
  std::cout << "  br i1 %" << index + 2 << ", label %while_body" << while_index << ", label %while_end" << while_index << std::endl;
  std::cout << "while_body" << while_index << ":" << std::endl;
  index += 3;
}

void emit_while_end(int while_index)
{
  std::cout << "  br label %while_start" << while_index << std::endl;
  std::cout << "while_end" << while_index << ":" << std::endl;
}

LLVM APIでは、while_start, while_body, while_endのようなコードのまとまりを、BasicBlockを使用して指定します。

struct WhileBlock
{
  llvm::BasicBlock *cond_block;
  llvm::BasicBlock *body_block;
  llvm::BasicBlock *end_block;
};

int while_index = 0;
WhileBlock while_blocks[1000];
WhileBlock *while_block_ptr = while_blocks;

void emit_while_start(llvm::Function *func, llvm::Value *ptr, WhileBlock *while_block, int while_index)
{
  while_block->cond_block = llvm::BasicBlock::Create(
      context, std::string("while_cond") + std::to_string(while_index), func);
  while_block->body_block = llvm::BasicBlock::Create(
      context, std::string("while_body") + std::to_string(while_index), func);
  while_block->end_block = llvm::BasicBlock::Create(
      context, std::string("while_end") + std::to_string(while_index), func);

  builder.CreateBr(while_block->cond_block);

  builder.SetInsertPoint(while_block->cond_block);
  builder.CreateCondBr(
      builder.CreateICmpNE(
          builder.CreateLoad(
              builder.getInt8Ty(),
              builder.CreateLoad(builder.getInt8PtrTy(), ptr)),
          builder.getInt8(0)),
      while_block->body_block,
      while_block->end_block);

  builder.SetInsertPoint(while_block->body_block);
}

void emit_while_end(WhileBlock *while_block)
{
  builder.CreateBr(while_block->cond_block);

  builder.SetInsertPoint(while_block->end_block);
}

「.」命令の実装

「.」命令は、ポインタの値の出力です。
LLVM IRでは、ポインタの値の出力は以下のようになっています。

%16 = call i32 @getchar()
%17 = trunc i32 %16 to i8
%18 = load i8*, i8** %3, align 8
store i8 %17, i8* %18, align 1

LLVM APIを利用したC++のコードは以下のようになります。

void emit_put(llvm::Value *ptr)
{
  llvm::FunctionCallee calleePutChar = module.getOrInsertFunction(
      "putchar", builder.getInt32Ty(), builder.getInt32Ty());
  llvm::ArrayRef<llvm::Value *> argsRefPutChar(
      builder.CreateZExt(
          builder.CreateLoad(
              builder.getInt8Ty(),
              builder.CreateLoad(builder.getInt8PtrTy(), ptr)),
          builder.getInt32Ty()));
  builder.CreateCall(
      calleePutChar,
      argsRefPutChar);
}

「,」命令の実装

「,」命令は、入力から1バイト読み込んで、ポインタが指す値に代入、を示します。 LLVM IRは以下のようになります。

%16 = call i32 @getchar()
%17 = trunc i32 %16 to i8
%18 = load i8*, i8** %3, align 8
store i8 %17, i8* %18, align 1

LLVM APIを利用したC++のコードは以下のようになります。

void emit_get(llvm::Value *ptr)
{
  llvm::FunctionCallee calleeGetChar = module.getOrInsertFunction(
      "getchar", builder.getInt32Ty());
  builder.CreateStore(
      builder.CreateTrunc(
          builder.CreateCall(calleeGetChar),
          builder.getInt8Ty()),
      builder.CreateLoad(builder.getInt8PtrTy(), ptr));
}

実行してみる

main関数にこれまで作成した関数を呼び出すコードを追加します。

int main()
{
  llvm::Function *mainFunc = llvm::Function::Create(
      llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false),
      llvm::Function::ExternalLinkage, "main", module);

  llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc);
  builder.SetInsertPoint(basicBlock);

  llvm::Value *array = builder.CreateAlloca(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      nullptr);

  llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr);

  llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy());
  builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign());

  llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0));
  const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero};
  llvm::Value *element = builder.CreateInBoundsGEP(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      array,
      llvm::ArrayRef<llvm::Value *>(indexes));

  builder.CreateStore(element, elementPtr, false);

  char c;
  while ((c = getchar()) != EOF)
  {
    switch (c)
    {
    case '>':
      emit_add_ptr(elementPtr, 1);
      break;
    case '<':
      emit_add_ptr(elementPtr, -1);
      break;
    case '+':
      emit_add(elementPtr, 1);
      break;
    case '-':
      emit_add(elementPtr, -1);
      break;
    case '[':
      emit_while_start(mainFunc, elementPtr, while_block_ptr++, while_index++);
      break;
    case ']':
      emit_while_end(--while_block_ptr);
      break;
    case '.':
      emit_put(elementPtr);
      break;
    case ',':
      emit_get(elementPtr);
      break;
    }
  }

  builder.CreateRet(builder.getInt32(0));

  module.print(llvm::outs(), nullptr);

  return 0;
}

以下のコマンドを実行すると、Hello world!出力されます。

$ clang++ -c $(llvm-config --cxxflags) main.cpp -o main.o
$ clang++ main.o $(llvm-config --ldflags --libs)
$ echo "+++++++++[>++++++++>+++++++++++>+++>+<<<<-]>.>++.+++++++..+++.>+++++.<<+++++++++++++++.>.+++.------.--------.>+.>+." | ./a.out | lli
Hello world!

全コード

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/Alignment.h"
#include <array>

struct WhileBlock
{
  llvm::BasicBlock *cond_block;
  llvm::BasicBlock *body_block;
  llvm::BasicBlock *end_block;
};

llvm::LLVMContext context;
llvm::Module module("top", context);
llvm::IRBuilder<> builder(context);

int while_index = 0;
WhileBlock while_blocks[1000];
WhileBlock *while_block_ptr = while_blocks;

void emit_add_ptr(llvm::Value *ptr, int diff)
{
  builder.CreateStore(
      builder.CreateInBoundsGEP(
          builder.getInt8Ty(),
          builder.CreateLoad(builder.getInt8PtrTy(), ptr),
          builder.getInt32(diff)),
      ptr);
}

void emit_add(llvm::Value *ptr, int diff)
{
  llvm::Value *tmpPtr = builder.CreateLoad(builder.getInt8PtrTy(), ptr);
  builder.CreateStore(
      builder.CreateAdd(
          builder.CreateLoad(builder.getInt8Ty(), tmpPtr),
          builder.getInt8(diff)),
      tmpPtr);
}

void emit_while_start(llvm::Function *func, llvm::Value *ptr, WhileBlock *while_block, int while_index)
{
  while_block->cond_block = llvm::BasicBlock::Create(
      context, std::string("while_cond") + std::to_string(while_index), func);
  while_block->body_block = llvm::BasicBlock::Create(
      context, std::string("while_body") + std::to_string(while_index), func);
  while_block->end_block = llvm::BasicBlock::Create(
      context, std::string("while_end") + std::to_string(while_index), func);

  builder.CreateBr(while_block->cond_block);

  builder.SetInsertPoint(while_block->cond_block);
  builder.CreateCondBr(
      builder.CreateICmpNE(
          builder.CreateLoad(
              builder.getInt8Ty(),
              builder.CreateLoad(builder.getInt8PtrTy(), ptr)),
          builder.getInt8(0)),
      while_block->body_block,
      while_block->end_block);

  builder.SetInsertPoint(while_block->body_block);
}

void emit_while_end(WhileBlock *while_block)
{
  builder.CreateBr(while_block->cond_block);

  builder.SetInsertPoint(while_block->end_block);
}

void emit_put(llvm::Value *ptr)
{
  llvm::FunctionCallee calleePutChar = module.getOrInsertFunction(
      "putchar", builder.getInt32Ty(), builder.getInt32Ty());
  llvm::ArrayRef<llvm::Value *> argsRefPutChar(
      builder.CreateZExt(
          builder.CreateLoad(
              builder.getInt8Ty(),
              builder.CreateLoad(builder.getInt8PtrTy(), ptr)),
          builder.getInt32Ty()));
  builder.CreateCall(
      calleePutChar,
      argsRefPutChar);
}

void emit_get(llvm::Value *ptr)
{
  llvm::FunctionCallee calleeGetChar = module.getOrInsertFunction(
      "getchar", builder.getInt32Ty());
  builder.CreateStore(
      builder.CreateTrunc(
          builder.CreateCall(calleeGetChar),
          builder.getInt8Ty()),
      builder.CreateLoad(builder.getInt8PtrTy(), ptr));
}

int main()
{
  llvm::Function *mainFunc = llvm::Function::Create(
      llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false),
      llvm::Function::ExternalLinkage, "main", module);

  llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc);
  builder.SetInsertPoint(basicBlock);

  llvm::Value *array = builder.CreateAlloca(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      nullptr);

  llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr);

  llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy());
  builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign());

  llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0));
  const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero};
  llvm::Value *element = builder.CreateInBoundsGEP(
      llvm::ArrayType::get(builder.getInt8Ty(), 30000),
      array,
      llvm::ArrayRef<llvm::Value *>(indexes));

  builder.CreateStore(element, elementPtr, false);

  char c;
  while ((c = getchar()) != EOF)
  {
    switch (c)
    {
    case '>':
      emit_add_ptr(elementPtr, 1);
      break;
    case '<':
      emit_add_ptr(elementPtr, -1);
      break;
    case '+':
      emit_add(elementPtr, 1);
      break;
    case '-':
      emit_add(elementPtr, -1);
      break;
    case '[':
      emit_while_start(mainFunc, elementPtr, while_block_ptr++, while_index++);
      break;
    case ']':
      emit_while_end(--while_block_ptr);
      break;
    case '.':
      emit_put(elementPtr);
      break;
    case ',':
      emit_get(elementPtr);
      break;
    }
  }

  builder.CreateRet(builder.getInt32(0));

  module.print(llvm::outs(), nullptr);

  return 0;
}

参考にした記事