LLVM API(C++)でBrankfuck->LLVM IRコンパイラを作ってみた
前回の記事でLLVMの入門として、Brankfuck->LLVM IRコンパイラを作ったことについて書きました。 前回は生のLLVM IRを扱っていましたが、今回はLLVM APIを利用してBrankfuck->LLVM IRコンパイラを作ってみました。
開発環境
$ llvm-config --version 10.0.0 $ clang++ --version clang version 10.0.0-4ubuntu1 Target: x86_64-pc-linux-gnu Thread model: posix InstalledDir: /usr/bin
main関数とcharを格納するメモリ領域の確保
今回作成するのは、標準入力からBrainFuckソースコードを受け取り、標準出力にそのBrainFuckソースコードに対応するLLVM IRを出力するC++プログラムです。
作業は以下の流れで進めていきます。
1. 出力したいLLVM IRを把握する(LLVM IRは前回の記事で扱ったもの使います)
2. LLVM: LLVMで出力したいLLVM IR命令に相当するLLVM APIを探す
3. LLVM APIを利用するC++プログラムを書く
1. 出力したいLLVM IRを把握する
まず、main関数の作成とcharを格納するメモリ領域確保の処理についてです。
define i32 @main() { %1 = alloca [30000 x i8] %2 = alloca i8* %3 = bitcast [30000 x i8]* %1 to i8* call void @llvm.memset.p0i8.i64(i8* %3, i8 0, i64 30000, i1 false) %4 = getelementptr inbounds [30000 x i8], [30000 x i8]* %1, i64 0, i64 0 store i8* %4, i8** %2 ret i32 0 }
2. LLVM: LLVMで出力したいLLVM IR命令に相当するLLVM APIを探す
define i32 @main() {}
の部分は、LLVM: llvm::Function ClassのCreateメソッドを使います。
alloca
, bitcast
, memset
, getelementptr inbounds
, store
, ret
命令は、LLVM: llvm::IRBuilderBase Classにそれぞれ対応するメソッドがあるのでそれを使います。
3. C++プログラムを書く
llvm::LLVMContext context; llvm::Module module("top", context); llvm::IRBuilder<> builder(context); int main() { llvm::Function *mainFunc = llvm::Function::Create( llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false), llvm::Function::ExternalLinkage, "main", module); llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc); builder.SetInsertPoint(basicBlock); llvm::Value *array = builder.CreateAlloca( llvm::ArrayType::get(builder.getInt8Ty(), 30000), nullptr); llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr); llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy()); builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign()); llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0)); const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero}; llvm::Value *element = builder.CreateInBoundsGEP( llvm::ArrayType::get(builder.getInt8Ty(), 30000), array, llvm::ArrayRef<llvm::Value *>(indexes)); builder.CreateStore(element, elementPtr, false); builder.CreateRet(builder.getInt32(0)); module.print(llvm::outs(), nullptr); return 0; }
「>」「<」命令の実装
「>」「<」命令は、ポインタのインクリメント・デクリメントです。
LLVM IRは以下のようになっています。
%6 = load i8*, i8** %3, align 8 %7 = getelementptr inbounds i8, i8* %6, i32 1 store i8* %7, i8** %3, align 8 %8 = load i8*, i8** %3, align 8 %9 = getelementptr inbounds i8, i8* %8, i32 -1 store i8* %9, i8** %3, align 8
インクリメントとデクリメントは、getelementptr
の第3引数が違うだけなので一つの関数にまとめました。
void emit_add_ptr(llvm::Value *ptr, int diff) { builder.CreateStore( builder.CreateInBoundsGEP( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr), builder.getInt32(diff)), ptr); }
「+」「-」命令の実装
「+」「-」命令は、値のインクリメント・デクリメントです。
LLVM IRは以下のようになっています。
%10 = load i8*, i8** %3, align 8 %11 = load i8, i8* %10, align 1 %12 = add i8 %11, 1 store i8 %12, i8* %10, align 1 %13 = load i8*, i8** %3, align 8 %14 = load i8, i8* %13, align 1 %15 = add i8 %14, -1 store i8 %15, i8* %13, align 1
インクリメントとデクリメントは、add
の第2引数が違うだけなので一つの関数にまとめました。
void emit_add(llvm::Value *ptr, int diff) { llvm::Value *tmpPtr = builder.CreateLoad(builder.getInt8PtrTy(), ptr); builder.CreateStore( builder.CreateAdd( builder.CreateLoad(builder.getInt8Ty(), tmpPtr), builder.getInt8(diff)), tmpPtr); }
「[」「]」命令の実装
「[」命令は、ポインタの指す値が0なら、後の]までジャンプ、
「]」命令は、ポインタの指す値が0でなければ、前の[までジャンプ、を示します。
LLVM IRは以下のようになります。
define i32 @main() { ; -----省略----- %15 = load i8*, i8** %3, align 8 %16 = load i8, i8* %15, align 1 %17 = add i8 %16, 1 store i8 %17, i8* %15, align 1 br label %18 18: ; preds = %23, %0 %19 = load i8*, i8** %3, align 8 %20 = load i8, i8* %19, align 1 %21 = zext i8 %20 to i32 %22 = icmp eq i32 %21, 0 br i1 %22, label %23, label %27 23: ; preds = %18 %24 = load i8*, i8** %3, align 8 %25 = load i8, i8* %24, align 1 %26 = add i8 %25, -1 store i8 %26, i8* %24, align 1 br label %18 27: ; preds = %18 %28 = load i8*, i8** %3, align 8 %29 = load i8, i8* %28, align 1 %30 = add i8 %29, 1 store i8 %30, i8* %28, align 1 ; -----省略----- ret i32 0 }
前回の記事では、「[」「]」命令のLLVM IRを出力するコードは以下のようになりました。
void emit_while_start(int while_index) { std::cout << " br label %while_start" << while_index << std::endl; std::cout << "while_start" << while_index << ":" << std::endl; std::cout << " %" << index << " = load i8*, i8** %2, align 8" << std::endl; std::cout << " %" << index + 1 << " = load i8, i8* %" << index << ", align 1" << std::endl; std::cout << " %" << index + 2 << " = icmp ne i8 %" << index + 1 << ", 0" << std::endl; std::cout << " br i1 %" << index + 2 << ", label %while_body" << while_index << ", label %while_end" << while_index << std::endl; std::cout << "while_body" << while_index << ":" << std::endl; index += 3; } void emit_while_end(int while_index) { std::cout << " br label %while_start" << while_index << std::endl; std::cout << "while_end" << while_index << ":" << std::endl; }
LLVM APIでは、while_start
, while_body
, while_end
のようなコードのまとまりを、BasicBlock
を使用して指定します。
struct WhileBlock { llvm::BasicBlock *cond_block; llvm::BasicBlock *body_block; llvm::BasicBlock *end_block; }; int while_index = 0; WhileBlock while_blocks[1000]; WhileBlock *while_block_ptr = while_blocks; void emit_while_start(llvm::Function *func, llvm::Value *ptr, WhileBlock *while_block, int while_index) { while_block->cond_block = llvm::BasicBlock::Create( context, std::string("while_cond") + std::to_string(while_index), func); while_block->body_block = llvm::BasicBlock::Create( context, std::string("while_body") + std::to_string(while_index), func); while_block->end_block = llvm::BasicBlock::Create( context, std::string("while_end") + std::to_string(while_index), func); builder.CreateBr(while_block->cond_block); builder.SetInsertPoint(while_block->cond_block); builder.CreateCondBr( builder.CreateICmpNE( builder.CreateLoad( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr)), builder.getInt8(0)), while_block->body_block, while_block->end_block); builder.SetInsertPoint(while_block->body_block); } void emit_while_end(WhileBlock *while_block) { builder.CreateBr(while_block->cond_block); builder.SetInsertPoint(while_block->end_block); }
「.」命令の実装
「.」命令は、ポインタの値の出力です。
LLVM IRでは、ポインタの値の出力は以下のようになっています。
%16 = call i32 @getchar() %17 = trunc i32 %16 to i8 %18 = load i8*, i8** %3, align 8 store i8 %17, i8* %18, align 1
LLVM APIを利用したC++のコードは以下のようになります。
void emit_put(llvm::Value *ptr) { llvm::FunctionCallee calleePutChar = module.getOrInsertFunction( "putchar", builder.getInt32Ty(), builder.getInt32Ty()); llvm::ArrayRef<llvm::Value *> argsRefPutChar( builder.CreateZExt( builder.CreateLoad( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr)), builder.getInt32Ty())); builder.CreateCall( calleePutChar, argsRefPutChar); }
「,」命令の実装
「,」命令は、入力から1バイト読み込んで、ポインタが指す値に代入、を示します。 LLVM IRは以下のようになります。
%16 = call i32 @getchar() %17 = trunc i32 %16 to i8 %18 = load i8*, i8** %3, align 8 store i8 %17, i8* %18, align 1
LLVM APIを利用したC++のコードは以下のようになります。
void emit_get(llvm::Value *ptr) { llvm::FunctionCallee calleeGetChar = module.getOrInsertFunction( "getchar", builder.getInt32Ty()); builder.CreateStore( builder.CreateTrunc( builder.CreateCall(calleeGetChar), builder.getInt8Ty()), builder.CreateLoad(builder.getInt8PtrTy(), ptr)); }
実行してみる
main関数にこれまで作成した関数を呼び出すコードを追加します。
int main() { llvm::Function *mainFunc = llvm::Function::Create( llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false), llvm::Function::ExternalLinkage, "main", module); llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc); builder.SetInsertPoint(basicBlock); llvm::Value *array = builder.CreateAlloca( llvm::ArrayType::get(builder.getInt8Ty(), 30000), nullptr); llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr); llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy()); builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign()); llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0)); const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero}; llvm::Value *element = builder.CreateInBoundsGEP( llvm::ArrayType::get(builder.getInt8Ty(), 30000), array, llvm::ArrayRef<llvm::Value *>(indexes)); builder.CreateStore(element, elementPtr, false); char c; while ((c = getchar()) != EOF) { switch (c) { case '>': emit_add_ptr(elementPtr, 1); break; case '<': emit_add_ptr(elementPtr, -1); break; case '+': emit_add(elementPtr, 1); break; case '-': emit_add(elementPtr, -1); break; case '[': emit_while_start(mainFunc, elementPtr, while_block_ptr++, while_index++); break; case ']': emit_while_end(--while_block_ptr); break; case '.': emit_put(elementPtr); break; case ',': emit_get(elementPtr); break; } } builder.CreateRet(builder.getInt32(0)); module.print(llvm::outs(), nullptr); return 0; }
以下のコマンドを実行すると、Hello world!
出力されます。
$ clang++ -c $(llvm-config --cxxflags) main.cpp -o main.o $ clang++ main.o $(llvm-config --ldflags --libs) $ echo "+++++++++[>++++++++>+++++++++++>+++>+<<<<-]>.>++.+++++++..+++.>+++++.<<+++++++++++++++.>.+++.------.--------.>+.>+." | ./a.out | lli Hello world!
全コード
#include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/Alignment.h" #include <array> struct WhileBlock { llvm::BasicBlock *cond_block; llvm::BasicBlock *body_block; llvm::BasicBlock *end_block; }; llvm::LLVMContext context; llvm::Module module("top", context); llvm::IRBuilder<> builder(context); int while_index = 0; WhileBlock while_blocks[1000]; WhileBlock *while_block_ptr = while_blocks; void emit_add_ptr(llvm::Value *ptr, int diff) { builder.CreateStore( builder.CreateInBoundsGEP( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr), builder.getInt32(diff)), ptr); } void emit_add(llvm::Value *ptr, int diff) { llvm::Value *tmpPtr = builder.CreateLoad(builder.getInt8PtrTy(), ptr); builder.CreateStore( builder.CreateAdd( builder.CreateLoad(builder.getInt8Ty(), tmpPtr), builder.getInt8(diff)), tmpPtr); } void emit_while_start(llvm::Function *func, llvm::Value *ptr, WhileBlock *while_block, int while_index) { while_block->cond_block = llvm::BasicBlock::Create( context, std::string("while_cond") + std::to_string(while_index), func); while_block->body_block = llvm::BasicBlock::Create( context, std::string("while_body") + std::to_string(while_index), func); while_block->end_block = llvm::BasicBlock::Create( context, std::string("while_end") + std::to_string(while_index), func); builder.CreateBr(while_block->cond_block); builder.SetInsertPoint(while_block->cond_block); builder.CreateCondBr( builder.CreateICmpNE( builder.CreateLoad( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr)), builder.getInt8(0)), while_block->body_block, while_block->end_block); builder.SetInsertPoint(while_block->body_block); } void emit_while_end(WhileBlock *while_block) { builder.CreateBr(while_block->cond_block); builder.SetInsertPoint(while_block->end_block); } void emit_put(llvm::Value *ptr) { llvm::FunctionCallee calleePutChar = module.getOrInsertFunction( "putchar", builder.getInt32Ty(), builder.getInt32Ty()); llvm::ArrayRef<llvm::Value *> argsRefPutChar( builder.CreateZExt( builder.CreateLoad( builder.getInt8Ty(), builder.CreateLoad(builder.getInt8PtrTy(), ptr)), builder.getInt32Ty())); builder.CreateCall( calleePutChar, argsRefPutChar); } void emit_get(llvm::Value *ptr) { llvm::FunctionCallee calleeGetChar = module.getOrInsertFunction( "getchar", builder.getInt32Ty()); builder.CreateStore( builder.CreateTrunc( builder.CreateCall(calleeGetChar), builder.getInt8Ty()), builder.CreateLoad(builder.getInt8PtrTy(), ptr)); } int main() { llvm::Function *mainFunc = llvm::Function::Create( llvm::FunctionType::get(llvm::Type::getInt32Ty(context), false), llvm::Function::ExternalLinkage, "main", module); llvm::BasicBlock *basicBlock = llvm::BasicBlock::Create(context, "", mainFunc); builder.SetInsertPoint(basicBlock); llvm::Value *array = builder.CreateAlloca( llvm::ArrayType::get(builder.getInt8Ty(), 30000), nullptr); llvm::Value *elementPtr = builder.CreateAlloca(builder.getInt8PtrTy(), nullptr); llvm::Value *arrayPtr = builder.CreateBitCast(array, builder.getInt8PtrTy()); builder.CreateMemSet(arrayPtr, builder.getInt8(0), 30000, llvm::MaybeAlign()); llvm::Value *int64Zero = static_cast<llvm::Value *>(builder.getInt64(0)); const std::array<llvm::Value *, 2> indexes = {int64Zero, int64Zero}; llvm::Value *element = builder.CreateInBoundsGEP( llvm::ArrayType::get(builder.getInt8Ty(), 30000), array, llvm::ArrayRef<llvm::Value *>(indexes)); builder.CreateStore(element, elementPtr, false); char c; while ((c = getchar()) != EOF) { switch (c) { case '>': emit_add_ptr(elementPtr, 1); break; case '<': emit_add_ptr(elementPtr, -1); break; case '+': emit_add(elementPtr, 1); break; case '-': emit_add(elementPtr, -1); break; case '[': emit_while_start(mainFunc, elementPtr, while_block_ptr++, while_index++); break; case ']': emit_while_end(--while_block_ptr); break; case '.': emit_put(elementPtr); break; case ',': emit_get(elementPtr); break; } } builder.CreateRet(builder.getInt32(0)); module.print(llvm::outs(), nullptr); return 0; }