ZZY 172d72b0a0 feat(backend/riscv32): 实现基础的编译器功能
- 完成 RV32IMA 指令集的代码生成
- 添加整数运算、分支、调用等基本指令支持
- 实现从 IR 到机器码的转换
- 添加简单的测试用例和测试框架
2025-03-08 16:50:21 +08:00

440 lines
13 KiB
C

#include "ir.h"
#include "ir_lib.h"
#include "ir_type.h"
#include "../frontend/frontend.h"
// Generation context: records the lowering state while IR is being generated.
typedef struct {
ir_func_t* cur_func; // function currently being lowered (its bblocks list receives new blocks)
ir_bblock_t* cur_block; // basic block that emit_instr(NULL, ...) appends instructions to
} IRGenContext;
// Global lowering state; gen_ir_func saves and restores it around each function body.
// NOTE(review): `ctx` has external linkage with a very generic name; consider making it
// `static` if no other translation unit refers to it (not visible from this file).
IRGenContext ctx;
// The IR program under construction: gen_ir_func appends lowered functions to
// prog.funcs, and function declarations without a definition are appended to
// prog.extern_funcs (see gen_ir_from_ast, NT_DECL_FUNC).
ir_prog_t prog;
/*
 * Append `node` to the end of a basic block's instruction list.
 *
 * block: destination block; pass NULL to use the current block (ctx.cur_block).
 * node:  instruction to append (stored by pointer, so it may still be filled in
 *        after this call).
 */
static void emit_instr(ir_bblock_t* block, ir_node_t* node) {
    ir_bblock_t* dest = (block != NULL) ? block : ctx.cur_block;
    vector_push(dest->instrs, node);
}
/*
 * Build an IR_NODE_BRANCH on `cond` with successors `trueb` / `falseb`,
 * append it to the current block (ctx.cur_block), and return it.
 */
static ir_node_t* emit_br(ir_node_t* cond, ir_bblock_t* trueb, ir_bblock_t* falseb) {
    ir_node_t* branch = new_ir_node(NULL, IR_NODE_BRANCH);
    branch->data.branch.cond = cond;
    branch->data.branch.true_bblock = trueb;
    branch->data.branch.false_bblock = falseb;
    emit_instr(NULL, branch);
    return branch;
}
static ir_node_t* gen_ir_expr(ast_node_t* node);

/*
 * Lower a leaf expression (constant, identifier reference, or call) to IR.
 *
 * - NT_TERM_VAL:   returns a fresh IR_NODE_CONST_INT; nothing is emitted.
 * - NT_TERM_IDENT: returns the IR node stored on the variable's declaration
 *                  (the IR_NODE_ALLOC created by gen_ir_func / NT_DECL_VAR);
 *                  nothing is emitted.
 * - NT_TERM_CALL:  lowers the arguments left to right, records the call in each
 *                  argument's used_by list, emits an IR_NODE_CALL into the
 *                  current block and returns it.
 *
 * Any other node type is a caller bug: it asserts, and (with NDEBUG) returns
 * NULL instead of falling off the end of the function.
 */
static ir_node_t* gen_ir_term(ast_node_t* node) {
    switch (node->type) {
    case NT_TERM_VAL: {
        ir_node_t* ir = new_ir_node(NULL, IR_NODE_CONST_INT);
        ir->data.const_int.val = node->syms.tok.val.i;
        return ir;
    }
    case NT_TERM_IDENT: {
        ir_node_t* decl = node->syms.decl_node->decl_val.data;
        return decl;
    }
    case NT_TERM_CALL: {
        ir_node_t* call = new_ir_node(NULL, IR_NODE_CALL);
        call->data.call.callee = node->call.func_decl->decl_func.def->func.data;
        for (int i = 0; i < node->call.params->params.params.size; i++) {
            ast_node_t* param = vector_at(node->call.params->params.params, i);
            ir_node_t* arg = gen_ir_expr(param);
            vector_push(call->data.call.args, arg);
            // Keep def-use information consistent with operator lowering.
            if (arg != NULL) {
                vector_push(arg->used_by, call);
            }
        }
        emit_instr(NULL, call);
        return call;
    }
    default:
        break;
    }
    assert(0 && "gen_ir_term: node is not a term");
    return NULL;
}
/*
 * Lower an expression AST node to IR and return the node that holds its value.
 *
 * Leaf terms are delegated to gen_ir_term. For operators, the left operand is
 * lowered before the right one. The resulting instruction is then:
 *   - recorded in the used_by list of each operand node, and
 *   - appended to the current block (ctx.cur_block).
 *
 * Special cases:
 *   - NT_COMMA:   both sides are lowered; the right-hand value is returned.
 *   - NT_ASSIGN:  emits IR_NODE_STORE(target = lhs, value = rhs) and returns rhs.
 *   - NT_NOT:     lowered as (0 == x).
 *   - NT_BIT_NOT: lowered as (x ^ -1).
 *   - NT_GT:      lowered as (rhs < lhs) using IR_OP_LT.
 *
 * Unsupported operators (&&, ||, anything unknown) are reported through
 * error(). If error() returns, nothing is emitted and NULL is returned.
 */
static ir_node_t* gen_ir_expr(ast_node_t* node) {
    switch (node->type) {
    case NT_TERM_VAL:
    case NT_TERM_IDENT:
    case NT_TERM_CALL:
        return gen_ir_term(node);
    default:
        break;
    }

    ir_node_t* lhs = gen_ir_expr(node->expr.left);
    ir_node_t* rhs = node->expr.right ? gen_ir_expr(node->expr.right) : NULL;
    if (node->type == NT_COMMA) {
        // Both operands were lowered above (for their side effects); the value is the right one.
        return rhs;
    }

    ir_node_t* instr = NULL; // instruction to emit, if any
    ir_node_t* ret = NULL;   // value of the whole expression

/* Build IR_NODE_OP `operand(l, r)`; the op node is both the instruction and the value. */
#define BINOP_OPERANDS(operand, l, r) do { \
        instr = new_ir_node(NULL, IR_NODE_OP); \
        instr->data.op.op = (operand); \
        instr->data.op.lhs = (l); \
        instr->data.op.rhs = (r); \
        ret = instr; \
    } while (0)
#define BINOP(operand) BINOP_OPERANDS(operand, lhs, rhs)

    switch (node->type) {
    case NT_ADD: BINOP(IR_OP_ADD); break; // (expr) + (expr)
    case NT_SUB: BINOP(IR_OP_SUB); break; // (expr) - (expr)
    case NT_MUL: BINOP(IR_OP_MUL); break; // (expr) * (expr)
    case NT_DIV: BINOP(IR_OP_DIV); break; // (expr) / (expr)
    case NT_MOD: BINOP(IR_OP_MOD); break; // (expr) % (expr)
    case NT_AND: BINOP(IR_OP_AND); break; // (expr) & (expr)
    case NT_OR:  BINOP(IR_OP_OR);  break; // (expr) | (expr)
    case NT_XOR: BINOP(IR_OP_XOR); break; // (expr) ^ (expr)
    case NT_L_SH: BINOP(IR_OP_SHL); break; // (expr) << (expr)
    case NT_R_SH:
        // (expr) >> (expr): always a logical shift for now.
        // TODO: signed operands need an arithmetic shift (IR_OP_SAR) once typing exists.
        BINOP(IR_OP_SHR);
        break;
    case NT_EQ:  BINOP(IR_OP_EQ);  break; // (expr) == (expr)
    case NT_NEQ: BINOP(IR_OP_NEQ); break; // (expr) != (expr)
    case NT_LE:  BINOP(IR_OP_LE);  break; // (expr) <= (expr)
    case NT_GE:  BINOP(IR_OP_GE);  break; // (expr) >= (expr)
    case NT_LT:  BINOP(IR_OP_LT);  break; // (expr) < (expr)
    case NT_GT:
        // (expr) > (expr): a > b  <=>  b < a. Operands were already lowered
        // left-to-right above, so swapping the slots does not reorder side effects.
        BINOP_OPERANDS(IR_OP_LT, rhs, lhs);
        break;
    case NT_BIT_NOT: {
        // ~ (expr): ~x == x ^ -1 (all bits set) in two's complement.
        ir_node_t* all_ones = new_ir_node(NULL, IR_NODE_CONST_INT);
        all_ones->data.const_int.val = -1;
        BINOP_OPERANDS(IR_OP_XOR, lhs, all_ones);
        break;
    }
    case NT_NOT:
        // ! (expr): !x == (0 == x)
        BINOP_OPERANDS(IR_OP_EQ, &node_zero, lhs);
        break;
    case NT_ASSIGN:
        // (expr) = (expr): store, and the expression's value is the stored value.
        instr = new_ir_node(NULL, IR_NODE_STORE);
        instr->data.store.target = lhs;
        instr->data.store.value = rhs;
        ret = rhs;
        break;
    case NT_AND_AND: // (expr) && (expr)
    case NT_OR_OR:   // (expr) || (expr)
        // Short-circuit evaluation needs control flow; not implemented yet.
        error("unimpliment");
        break;
    // case NT_COND: // (expr) ? (expr) : (expr)
    default:
        // TODO self error msg
        error("Unsupported IR generation for AST node type %d", node->type);
        break;
    }
#undef BINOP
#undef BINOP_OPERANDS

    if (instr == NULL) {
        // Unsupported operator (error() returned): never emit a NULL instruction.
        return ret;
    }
    // Record def-use edges only once the user instruction actually exists.
    vector_push(lhs->used_by, instr);
    if (rhs != NULL) {
        vector_push(rhs->used_by, instr);
    }
    emit_instr(NULL, instr);
    return ret;
}
/*
 * Lower one function definition (an NT_FUNC node) into `func`.
 *
 * Steps:
 *   1. Creates the "entry" block and registers `func` in prog.funcs.
 *   2. Emits one IR_NODE_ALLOC per parameter into the entry block. Each one is
 *      appended to func->params and stored on the parameter's declaration, so
 *      identifier references resolve to it.
 *   3. Lowers the body.
 *
 * The global generation context is saved on entry and restored on exit.
 */
static void gen_ir_func(ast_node_t* node, ir_func_t* func) {
    assert(node->type == NT_FUNC);

    ir_bblock_t* entry_block = new_ir_bblock("entry");
    vector_push(func->bblocks, entry_block);
    vector_push(prog.funcs, func);

    const IRGenContext saved = ctx;
    ctx.cur_func = func;
    ctx.cur_block = entry_block;

    ast_node_t* param_list = node->func.decl->decl_func.params;
    int idx = 0;
    while (idx < param_list->params.params.size) {
        ast_node_t* p = param_list->params.params.data[idx];
        ir_node_t* slot = new_ir_node(p->decl_val.name->syms.tok.val.str, IR_NODE_ALLOC);
        emit_instr(entry_block, slot);
        vector_push(func->params, slot);
        slot->type = &type_i32; // TODO: real typing system
        p->decl_val.data = slot;
        ++idx;
    }

    gen_ir_from_ast(node->func.body);
    ctx = saved;
}
void gen_ir_jmp(ast_node_t* node) {
ir_bblock_t *bblocks[3];
for (int i = 0; i < sizeof(bblocks)/sizeof(bblocks[0]); i++) {
bblocks[i] = new_ir_bblock(NULL);
vector_push(ctx.cur_func->bblocks, bblocks[i]);
}
#define NEW_IR_JMP(name, block) do { \
name = new_ir_node(NULL, IR_NODE_JUMP); \
name->data.jump.target_bblock = block; \
} while (0)
switch (node->type) {
case NT_STMT_IF: {
ir_bblock_t* trueb = bblocks[0];
ir_bblock_t* falseb = bblocks[1];
ir_bblock_t* endb = bblocks[2];
ir_node_t* jmp;
// cond
ir_node_t *cond = gen_ir_expr(node->if_stmt.cond);
emit_br(cond, trueb, falseb);
// true block
vector_push(ctx.cur_func->bblocks, trueb);
ctx.cur_block = trueb;
gen_ir_from_ast(node->if_stmt.if_stmt);
// else block
if (node->if_stmt.else_stmt != NULL) {
vector_push(ctx.cur_func->bblocks, falseb);
ctx.cur_block = falseb;
gen_ir_from_ast(node->if_stmt.else_stmt);
ir_node_t* jmp;
ctx.cur_block = endb;
vector_push(ctx.cur_func->bblocks, ctx.cur_block);
NEW_IR_JMP(jmp, ctx.cur_block);
emit_instr(falseb, jmp);
} else {
ctx.cur_block = falseb;
}
NEW_IR_JMP(jmp, ctx.cur_block);
emit_instr(trueb, jmp);
break;
}
case NT_STMT_WHILE: {
ir_bblock_t* entryb = bblocks[0];
ir_bblock_t* bodyb = bblocks[1];
ir_bblock_t* endb = bblocks[2];
ir_node_t* entry;
NEW_IR_JMP(entry, entryb);
emit_instr(NULL, entry);
// Entry:
ctx.cur_block = entryb;
ir_node_t *cond = gen_ir_expr(node->while_stmt.cond);
emit_br(cond, bodyb, endb);
// Body:
ir_node_t* jmp;
ctx.cur_block = bodyb;
gen_ir_from_ast(node->while_stmt.body);
NEW_IR_JMP(jmp, entryb);
emit_instr(NULL, jmp);
// End:
ctx.cur_block = endb;
break;
}
case NT_STMT_DOWHILE: {
ir_bblock_t* entryb = bblocks[0];
ir_bblock_t* bodyb = bblocks[1];
ir_bblock_t* endb = bblocks[2];
ir_node_t* entry;
NEW_IR_JMP(entry, bodyb);
emit_instr(NULL, entry);
// Body:
ctx.cur_block = bodyb;
gen_ir_from_ast(node->do_while_stmt.body);
ir_node_t* jmp;
NEW_IR_JMP(jmp, entryb);
emit_instr(NULL, jmp);
// Entry:
ctx.cur_block = entryb;
ir_node_t *cond = gen_ir_expr(node->do_while_stmt.cond);
emit_br(cond, bodyb, endb);
// End:
ctx.cur_block = endb;
break;
}
case NT_STMT_FOR: {
ir_bblock_t* entryb = bblocks[0];
ir_bblock_t* bodyb = bblocks[1];
ir_bblock_t* endb = bblocks[2];
if (node->for_stmt.init) {
gen_ir_from_ast(node->for_stmt.init);
}
ir_node_t* entry;
NEW_IR_JMP(entry, entryb);
emit_instr(NULL, entry);
// Entry:
ctx.cur_block = entryb;
if (node->for_stmt.cond) {
ir_node_t *cond = gen_ir_expr(node->for_stmt.cond);
emit_br(cond, bodyb, endb);
} else {
ir_node_t* jmp;
NEW_IR_JMP(jmp, bodyb);
}
// Body:
ctx.cur_block = bodyb;
gen_ir_from_ast(node->for_stmt.body);
if (node->for_stmt.iter) {
gen_ir_expr(node->for_stmt.iter);
}
ir_node_t* jmp;
NEW_IR_JMP(jmp, entryb);
emit_instr(NULL, jmp);
// End:
ctx.cur_block = endb;
break;
}
default:
error("ir jmp can't hit here");
}
}
/*
 * Lower a statement or declaration AST node, and its children, to IR.
 * Instructions are appended to ctx.cur_block.
 *
 * - Containers (NT_ROOT, NT_BLOCK, NT_STMT_BLOCK) recurse into their children
 *   in order.
 * - Control flow is delegated to gen_ir_jmp.
 * - Function definitions are delegated to gen_ir_func.
 * - NT_DECL_FUNC creates the ir_func_t, with its return type fixed to
 *   type_i32 for now.
 * - NT_DECL_VAR emits an IR_NODE_ALLOC and stores it on the declaration.
 * - NT_STMT_RETURN emits IR_NODE_RET and switches to a fresh block.
 *
 * Unknown node types are reported through error().
 */
void gen_ir_from_ast(ast_node_t* node) {
    switch (node->type) {
    // ---- containers ----
    case NT_ROOT:
        for (int k = 0; k < node->root.children.size; ++k) {
            gen_ir_from_ast(node->root.children.data[k]);
        }
        break;
    case NT_BLOCK:
        for (int k = 0; k < node->block.children.size; ++k) {
            gen_ir_from_ast(node->block.children.data[k]);
        }
        break;
    case NT_STMT_BLOCK:
        gen_ir_from_ast(node->block_stmt.block);
        break;

    // ---- simple statements ----
    case NT_STMT_EMPTY:
        break;
    case NT_STMT_EXPR:
        gen_ir_expr(node->expr_stmt.expr_stmt);
        break;
    case NT_STMT_RETURN: {
        ir_node_t* value = (node->return_stmt.expr_stmt != NULL)
                               ? gen_ir_expr(node->return_stmt.expr_stmt)
                               : NULL;
        ir_node_t* ret_instr = new_ir_node(NULL, IR_NODE_RET);
        ret_instr->data.ret.ret_val = value;
        emit_instr(NULL, ret_instr);
        // Code following the return goes into a new block that no branch emitted here targets.
        ir_bblock_t* after = new_ir_bblock(NULL);
        ctx.cur_block = after;
        vector_push(ctx.cur_func->bblocks, after);
        break;
    }

    // ---- control flow ----
    case NT_STMT_IF:
    case NT_STMT_WHILE:
    case NT_STMT_DOWHILE:
    case NT_STMT_FOR:
        gen_ir_jmp(node);
        break;

    // ---- declarations / definitions ----
    case NT_DECL_FUNC: {
        ir_func_t* fn = new_ir_func(node->decl_func.name->syms.tok.val.str, &type_i32);
        if (node->decl_func.def == NULL) {
            // Declaration without a body: synthesize an empty definition node so that
            // call sites can reach the ir_func_t through decl_func.def->func.data,
            // and record the function as external.
            ast_node_t* stub = new_ast_node();
            stub->func.body = NULL;
            stub->func.decl = node;
            node->decl_func.def = stub;
            vector_push(prog.extern_funcs, fn);
        }
        node->decl_func.def->func.data = fn;
        break;
    }
    case NT_FUNC:
        gen_ir_func(node, node->func.data);
        break;
    case NT_DECL_VAR: {
        ir_node_t* slot = new_ir_node(node->decl_val.name->syms.tok.val.str, IR_NODE_ALLOC);
        emit_instr(NULL, slot);
        slot->type = &type_i32; // TODO: real typing system
        node->decl_val.data = slot;
        // NOTE(review): the initializer is lowered as a statement; presumably the
        // parser builds it as an assignment statement to the variable — confirm.
        if (node->decl_val.expr_stmt != NULL) {
            gen_ir_from_ast(node->decl_val.expr_stmt);
        }
        break;
    }

    default:
        // TODO: proper error handling
        error("unknown node type");
        break;
    }
}