A Class-less Intro to LLVM

November 28, 2023

I've dabbled with LLVM IR before, but only targeting it directly without using the LLVM tools. So, I thought it would be a fun exercise to work through the official LLVM tutorial to implement Kaleidoscope.

Compiler design #

First things first. I have found that many compiler tutorials and toy compilers implement lexers that hand out one or a few tokens at a time, on demand from the parser (but never all of them). I don't like this interleaved compilation and have always used the entire source code as input to the lexer.

I also want a clear separation between compiler parts, i.e. the tokenizer tokenizes the source code into tokens, the parser parses the tokens into a tree, and the interpreter interprets the tree (or the compiler compiles the tree into lower-level code).

// tokens
std::vector<Token> tokens = tokenize(source);
// ast
AST tree = parse(tokens);
// generate llvm ir
llvm::Value *result = codegen(tree);

I don't know why this architecture is not more common in the wild.

(Or am I missing something obvious?)

An executable LLVM IR module needs a main function, and main conventionally returns i32, so to make things easier for myself I'll be using i32 instead of double (the LLVM tutorial uses double).

Note that I use tokenizer and lexer interchangeably. It just sounds better to say a tokenizer tokenizes than a lexer lexes (or lexe? or lex?). There's no interpret or compile part since I'm not really interpreting or compiling, just mapping the AST to LLVM values.

LLVM in C++ will be fun... right? #

After going through about half of the tutorial, my answer is: not so much.

The official LLVM tutorial implementation is a mess. There are classes, constructors, initializers, member functions, and getters for encapsulated values everywhere. It's nearly impossible to follow along amongst all the abstractions.

I can't help but think that this code style is bad. What if you find some code that is particularly neat and want to copy-paste it into another project? You can't, or at least not without bringing along the rest of the randomly dispersed dependencies. These kinds of tutorials and code examples would likely make me not want to use C++ at all, especially coming from C. It's just too far from what a program in C looks like.

Why not C-style programming in C++? I found this mirror of a now-deleted post by John Carmack on Functional Programming in C++. No one would call John Carmack a bad programmer, so it's probably safe to adopt this kind of style.

Reimplementing the LLVM tutorial in functional-style #

Lexer #

The LLVM tutorial lexer is not that bad, although it doesn't tokenize operators like +, -, *, and /. I think it's a good idea to have a token for each operator, so I'll do that.

I'm also providing the entire source as input to the lexer, which gives back a list of tokens.

std::string program = R"(
    4+5*2;
)";
std::regex pattern(R"(^\s+)", std::regex::multiline);
std::string source = std::regex_replace(program, pattern, "");
std::vector<Token> tokens = tokenize(source);
for (auto &token : tokens) { std::cout << token_to_string(token) << std::endl; }
// tok_number(4)
// tok_plus(+)
// tok_number(5)
// tok_multiply(*)
// tok_number(2)
// tok_semicolon(;)

A list of tokens makes it trivial to check the token before or after the current one, which makes certain grammars easier to parse. But otherwise it's not so different.

AST #

The AST implementation though... so much boilerplate. There's one class definition per grammar rule, each with a constructor and a codegen function. Below is my implementation.

struct AST {
    Token token;
    std::vector<AST> children;
};

void print_tree(const AST& node, int level = 0) {
    std::cout << std::string(level * 2, ' ') << token_to_string(node.token) << std::endl;
    for (const AST& child : node.children) {
        print_tree(child, level + 1);
    }
}

I even tossed in a pretty printer for trees, and it's still only 11 lines of readable code with only a single non-inline function.

Parser #

The LLVM tutorial parser is also rather funky. It's recursive descent but uses a map to determine operator precedence, i.e. "a combination of recursive descent parsing and operator-precedence parsing". Why not just add a grammar rule for each operator when you're already implementing recursive descent?

(Am I missing something obvious again?)

Anyway, my parser implementation is 80 lines of readable code and much more elegant in my opinion. One function per grammar rule.

// parse : factor ::= tok_number
AST parse_factor(const std::vector<Token>& tokens, size_t& token_index) {
    Token token = tokens[token_index];
    switch (token.type) {
        case TokenType::tok_number:
            token_index++;
            return AST { token };
        default:
            log_error("parse_factor : unexpected token", token.lexeme);
            token_index++;
            return AST { token };
    }
}
// parse : term ::= factor (('*' | '/') factor)*
AST parse_term(const std::vector<Token>& tokens, size_t& token_index) {
    AST term = parse_factor(tokens, token_index);
    while (tokens[token_index].type == TokenType::tok_multiply || tokens[token_index].type == TokenType::tok_divide) {
        Token token = tokens[token_index];
        token_index++;
        // factor
        auto children = { term, parse_factor(tokens, token_index) };
        term = AST { token, children };
    }
    return term;
}
// parse : expression ::= term (('+' | '-') term)* ';'
AST parse_expression(const std::vector<Token>& tokens, size_t& token_index) {
    AST expression = parse_term(tokens, token_index);
    while (tokens[token_index].type == TokenType::tok_plus || tokens[token_index].type == TokenType::tok_minus) {
        Token token = tokens[token_index];
        token_index++;
        // term
        auto children = { expression, parse_term(tokens, token_index) };
        expression = AST { token, children };
    }
    // semicolon
    if (tokens[token_index].type != TokenType::tok_semicolon) {
        log_error("parse_expression : expected ';'", tokens[token_index].lexeme);
    }
    token_index++;
    return expression;
}
// parse
AST parse(const std::vector<Token>& tokens) {
    size_t token_index = 0;
    AST program;
    program.token = Token { TokenType::tok_bof, "" };
    while (token_index < tokens.size()) {
        program.children.push_back(parse_expression(tokens, token_index));
    }
    return program;
}

Giveth the tokens and get back a tree.

AST tree = parse(tokens);
print_tree(tree);
/*
< BOF >
  tok_plus(+)
    tok_number(4)
    tok_multiply(*)
      tok_number(5)
      tok_number(2)
*/

Codegen #

The LLVM tutorial codegen part is not that bad, but not that good either.

llvm::Value *NumberExprAST::codegen() {
    return llvm::ConstantFP::get(*context, llvm::APFloat(num));
}

Where did num come from? Is this even correct? Or is it number, val, value, m_num? I don't know. I'd have to jump up and down in the code to find out.

Here's my implementation of the same codegen function:

llvm::Value *codegen_number(const AST& tree) {
    return llvm::ConstantInt::get(*context, llvm::APInt(32, tree.token.number, true));
}

This is much clearer. Nothing goes in or out that isn't obvious, except maybe the 32, which is the integer width, and the true, which means signed. The token is defined in one place and has three distinct fields: type, lexeme, and number.

The LLVM tutorial implements a top-level parser with separate handlers to orchestrate the calls to codegen functions.

static void handle_toplevel_expression() {
    if (auto function_ast = parse_toplevel_expr()) {
        if (auto *function_ir = function_ast->codegen()) {
            function_ir->print(llvm::errs());
            fprintf(stderr, "\n");
            // remove anonymous expression
            function_ir->eraseFromParent();
        }
    } else {
        get_next_token(); // skip for error recovery
    }
}
// top ::= expression | ';'
static void main_loop() {
    while (true) {
        switch (current_token) {
        case tok_eof:
            return;
        case ';': // ignore
            get_next_token();
            break;
        default:
            handle_toplevel_expression();
            break;
        }
    }
}

What is even going on here? Skip for error recovery? What is eraseFromParent doing, exactly? It's also ignoring semicolons instead of crashing violently and loudly when users forget to terminate their expressions (it's wrapped inside a REPL where the semicolon terminates a line, so it's not that bad, but still).

It took me a while to dissect, but below is my implementation.

llvm::Value *codegen(const AST& tree) {
    switch (tree.token.type) {
    case TokenType::tok_bof: {
        // generate code for every top-level expression, not just the first
        llvm::Value *last = nullptr;
        for (const AST& child : tree.children) {
            last = codegen(child);
        }
        return last;
    }
    case TokenType::tok_eof:
        return nullptr;
    case TokenType::tok_number:
        return codegen_number(tree);
    case TokenType::tok_plus:
    case TokenType::tok_minus:
    case TokenType::tok_multiply:
    case TokenType::tok_divide:
        return codegen_binop(tree);
    default:
        return log_error_value("codegen : unexpected token", token_to_string(tree.token));
    }
}

There's still some magic to perform to get valid LLVM IR (that is executable).

llvm::FunctionType *function_type = llvm::FunctionType::get(builder->getInt32Ty(), false);
llvm::Function *function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, "main", module.get());
llvm::BasicBlock *entry_block = llvm::BasicBlock::Create(*context, "entry", function);
builder->SetInsertPoint(entry_block);
// codegen after SetInsertPoint, so any emitted instructions land in the entry block
// (constant expressions like 4+5*2 fold without one, but don't rely on it)
llvm::Value *result = codegen(tree);
builder->CreateRet(result);
module->print(llvm::outs(), nullptr);

Conclusion #

Running clang++ -g -O0 fparser.cpp -o fparser `llvm-config --cxxflags --ldflags --system-libs --libs core` && ./fparser outputs the LLVM IR.

; ModuleID = 'llvm codegen'
source_filename = "llvm codegen"

define i32 @main() {
entry:
  ret i32 14
}

We can also redirect the output to out.ll and execute directly with lli out.ll; echo $?.

14

In the end, both implementations output the same LLVM IR. My implementation is a hundred or so fewer lines of readable code and uses zero classes. I can basically drag and drop functions into another project and, given the same input, it would just work.

Don't be a classist, go Class-less.

Full code listing #

#include <cctype>
#include <cstdlib>
#include <iostream>
#include <regex>
#include <string>
#include <vector>
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Verifier.h"
#include "llvm/ADT/APInt.h"

#define log_error(message, lexeme) do { std::cout << "error : " << message << " : " << lexeme << std::endl; exit(1); } while (0)

// --------------------------------------------------------
// lexer

enum class TokenType {
    tok_number,
    tok_plus,
    tok_minus,
    tok_multiply,
    tok_divide,
    tok_semicolon,
    tok_bof,
    tok_eof,
};

struct Token {
    TokenType type;
    std::string lexeme;
    int number;
};

std::string token_to_string(const Token& token) {
    switch (token.type) {
        case TokenType::tok_bof:
            return "< BOF >";
        case TokenType::tok_eof:
            return "";
        case TokenType::tok_semicolon:
            return "tok_semicolon(;)";
        case TokenType::tok_plus:
            return "tok_plus(+)";
        case TokenType::tok_minus:
            return "tok_minus(-)";
        case TokenType::tok_multiply:
            return "tok_multiply(*)";
        case TokenType::tok_divide:
            return "tok_divide(/)";
        case TokenType::tok_number:
            return "tok_number(" + token.lexeme + ")";
        default:
            return "?";
    }
}

std::vector<Token> tokenize(const std::string& source) {
    std::vector<Token> tokens;
    size_t ch_index = 0;
    while (ch_index < source.size()) {
        char ch = source[ch_index];
        switch (ch) {
            case ' ':
            case '\t':
            case '\r':
            case '\n':
                ch_index++;
                break;
            case ';':
                tokens.push_back({ TokenType::tok_semicolon, ";" });
                ch_index++;
                break;
            case '+':
                tokens.push_back({ TokenType::tok_plus, "+" });
                ch_index++;
                break;
            case '-':
                tokens.push_back({ TokenType::tok_minus, "-" });
                ch_index++;
                break;
            case '*':
                tokens.push_back({ TokenType::tok_multiply, "*" });
                ch_index++;
                break;
            case '/':
                tokens.push_back({ TokenType::tok_divide, "/" });
                ch_index++;
                break;
            default:
                if (isdigit(ch)) {
                    size_t start = ch_index;
                    while (ch_index < source.size() && isdigit(source[ch_index])) {
                        ch_index++;
                    }
                    std::string lexeme = source.substr(start, ch_index - start);
                    // tokens.push_back({ TokenType::tok_number, lexeme, std::stod(lexeme) });
                    tokens.push_back({ TokenType::tok_number, lexeme, std::stoi(lexeme) });
                } else if (isalpha(ch)) {
                    size_t start = ch_index;
                    while (ch_index < source.size() && isalnum(source[ch_index])) {
                        ch_index++;
                    }
                    std::string lexeme = source.substr(start, ch_index - start);
                    // not implemented yet
                    log_error("tokenize : unexpected identifier", lexeme);
                } else {
                    log_error("tokenize : unexpected character", ch);
                    ch_index++;
                }
        }
    }
    return tokens;
}

// --------------------------------------------------------
// ast

struct AST {
    Token token;
    std::vector<AST> children;
};

void print_tree(const AST& node, int level = 0) {
    std::cout << std::string(level * 2, ' ') << token_to_string(node.token) << std::endl;
    for (const AST& child : node.children) {
        print_tree(child, level + 1);
    }
}

// --------------------------------------------------------
// parser

// parse : factor ::= tok_number
AST parse_factor(const std::vector<Token>& tokens, size_t& token_index) {
    Token token = tokens[token_index];
    switch (token.type) {
        case TokenType::tok_number:
            token_index++;
            return AST { token };
        default:
            log_error("parse_factor : unexpected token", token.lexeme);
            token_index++;
            return AST { token };
    }
}
// parse : term ::= factor (('*' | '/') factor)*
AST parse_term(const std::vector<Token>& tokens, size_t& token_index) {
    AST term = parse_factor(tokens, token_index);
    while (tokens[token_index].type == TokenType::tok_multiply || tokens[token_index].type == TokenType::tok_divide) {
        Token token = tokens[token_index];
        token_index++;
        // factor
        auto children = { term, parse_factor(tokens, token_index) };
        term = AST { token, children };
    }
    return term;
}
// parse : expression ::= term (('+' | '-') term)* ';'
AST parse_expression(const std::vector<Token>& tokens, size_t& token_index) {
    AST expression = parse_term(tokens, token_index);
    while (tokens[token_index].type == TokenType::tok_plus || tokens[token_index].type == TokenType::tok_minus) {
        Token token = tokens[token_index];
        token_index++;
        // term
        auto children = { expression, parse_term(tokens, token_index) };
        expression = AST { token, children };
    }
    // semicolon
    if (tokens[token_index].type != TokenType::tok_semicolon) {
        log_error("parse_expression : expected ';'", tokens[token_index].lexeme);
    }
    token_index++;
    return expression;
}
// parse
AST parse(const std::vector<Token>& tokens) {
    size_t token_index = 0;
    AST program;
    program.token = Token { TokenType::tok_bof, "" };
    while (token_index < tokens.size()) {
        program.children.push_back(parse_expression(tokens, token_index));
    }
    return program;
}

// --------------------------------------------------------
// codegen

static std::unique_ptr<llvm::LLVMContext> context;
static std::unique_ptr<llvm::Module> module;
static std::unique_ptr<llvm::IRBuilder<>> builder;

llvm::Value *codegen(const AST& tree);

llvm::Value *log_error_value(const std::string message, const std::string lexeme) {
    log_error(message, lexeme);
    return nullptr;
}

// codegen : binop
llvm::Value *codegen_binop(const AST& tree) {
    llvm::Value *left = codegen(tree.children[0]);
    llvm::Value *right = codegen(tree.children[1]);
    if (!left || !right) {
        return nullptr;
    }
    if (tree.token.type == TokenType::tok_plus) {
        return builder->CreateAdd(left, right, "addtmp");
    }
    if (tree.token.type == TokenType::tok_minus) {
        return builder->CreateSub(left, right, "subtmp");
    }
    if (tree.token.type == TokenType::tok_multiply) {
        return builder->CreateMul(left, right, "multmp");
    }
    if (tree.token.type == TokenType::tok_divide) {
        return builder->CreateSDiv(left, right, "divtmp");
    }
    return log_error_value("codegen_binop : unexpected token", token_to_string(tree.token));
}
// codegen : tok_number
llvm::Value *codegen_number(const AST& tree) {
    return llvm::ConstantInt::get(*context, llvm::APInt(32, tree.token.number, true));
}
// codegen
llvm::Value *codegen(const AST& tree) {
    switch (tree.token.type) {
    case TokenType::tok_bof: {
        // generate code for every top-level expression, not just the first
        llvm::Value *last = nullptr;
        for (const AST& child : tree.children) {
            last = codegen(child);
        }
        return last;
    }
    case TokenType::tok_eof:
        return nullptr;
    case TokenType::tok_number:
        return codegen_number(tree);
    case TokenType::tok_plus:
    case TokenType::tok_minus:
    case TokenType::tok_multiply:
    case TokenType::tok_divide:
        return codegen_binop(tree);
    default:
        return log_error_value("codegen : unexpected token", token_to_string(tree.token));
    }
}

// --------------------------------------------------------
// driver

// clang++ -g -O0 fparser.cpp -o fparser `llvm-config --cxxflags --ldflags --system-libs --libs core` && ./fparser
// clang++ -g -O0 fparser.cpp -o fparser `llvm-config --cxxflags --ldflags --system-libs --libs core` && ./fparser > out.ll 
// lli out.ll; echo $?
int main() {
    std::string program = R"(
        4+5*2;
    )";
    std::regex pattern(R"(^\s+)", std::regex::multiline);
    std::string source = std::regex_replace(program, pattern, "");
    // tokens
    std::vector<Token> tokens = tokenize(source);
    // for (auto &token : tokens) { std::cout << token_to_string(token) << std::endl; }
    // ast
    AST tree = parse(tokens);
    // print_tree(tree);
    // llvm
    context = std::make_unique<llvm::LLVMContext>();
    module = std::make_unique<llvm::Module>("llvm codegen", *context);
    builder = std::make_unique<llvm::IRBuilder<>>(*context);
    // create function type
    llvm::FunctionType *function_type = llvm::FunctionType::get(builder->getInt32Ty(), false);
    // create function
    llvm::Function *function = llvm::Function::Create(function_type, llvm::Function::ExternalLinkage, "main", module.get());
    // create entry block
    llvm::BasicBlock *entry_block = llvm::BasicBlock::Create(*context, "entry", function);
    builder->SetInsertPoint(entry_block);
    // generate llvm ir after SetInsertPoint, so emitted instructions land in the entry block
    llvm::Value *result = codegen(tree);
    // create return instruction
    builder->CreateRet(result);
    // verify
    llvm::verifyFunction(*function);
    // print llvm ir
    module->print(llvm::outs(), nullptr);
}