#include "riscv32.h"
/* NOTE(review): SOURCE carried a second #include directive here whose header
 * name was lost. Restore it if the helpers used below (Assert, Panic,
 * LOG_ERROR, rt, rt_strcmp, vector_at, strpool_intern) are not already
 * reachable through riscv32.h. */

enum {
    RV32_SLOT_BYTES = 4,    /* every stack slot (saved ra, instruction result) is one word */
    RV32_MAX_REG_ARGS = 8,  /* arguments are passed in a0..a7 only */
};

/* Per-function state shared by the emit helpers. */
typedef struct {
    ir_func_t* func;   /* function currently being lowered */
    int stack_offset;  /* total frame size in bytes */
    int stack_base;    /* byte offset of the first instruction slot (after saved ra) */
    int func_idx;
    int block_idx;
} gen_ctx_t;

/*
 * Return the sp-relative byte offset of the stack slot that belongs to `ptr`.
 *
 * Each instruction of the function owns one word-sized slot; slots are laid
 * out in block order, then instruction order, starting at ctx->stack_base.
 * Panics when `ptr` is not an instruction of ctx->func.
 */
static inline int stack_pos(ir_node_t* ptr, gen_ctx_t *ctx) {
    int offset = ctx->stack_base;
    for (int b = 0; b < ctx->func->bblocks.size; b++) {
        ir_bblock_t* block = vector_at(ctx->func->bblocks, b);
        for (int k = 0; k < block->instrs.size; k++) {
            if (vector_at(block->instrs, k) == ptr) {
                offset += k * RV32_SLOT_BYTES;
                Assert(offset >= 0 && offset < ctx->stack_offset);
                return offset;
            }
        }
        /* not in this block: skip over all of its slots */
        offset += block->instrs.size * RV32_SLOT_BYTES;
    }
    Panic("stack pos got error");
    return 0;
}

/*
 * Look up `name` in the table of built-in functions that are lowered to an
 * ecall. Returns the ecall number, or -1 when `name` is not a built-in.
 */
static int system_func(const char* name) {
    static const struct {
        const char* name;
        int ecall_num;
    } defined_func[] = {
        {"ecall_pnt_int", 1},
        {"ecall_pnt_char", 11},
        {"ecall_scan_int", 1025 + 4},
    };
    const size_t count = sizeof(defined_func) / sizeof(defined_func[0]);

    for (size_t i = 0; i < count; i++) {
        if (rt_strcmp(name, defined_func[i].name) == 0) {
            return defined_func[i].ecall_num;
        }
    }
    return -1;
}

/* Return the idx-th argument register; idx must be in [0, RV32_MAX_REG_ARGS). */
static inline int arg_reg(int idx) {
    const int regs[RV32_MAX_REG_ARGS] = {
        REG_A0, REG_A1, REG_A2, REG_A3,
        REG_A4, REG_A5, REG_A6, REG_A7
    };
    return regs[idx];
}

/*
 * Fill `label` with the interned local label of `block`.
 *
 * The name is "L" + block->label + the block's address, so the label emitted
 * by gen_block and the ones referenced by branches/jumps agree. Fields other
 * than name/attr (if any) are zeroed.
 */
static void intern_block_label(rv32_prog_t* out_asm, ir_bblock_t* block,
                               symasm_entry_t* label) {
    char buf[1024];
    rt.snprintf(buf, sizeof(buf), "L%s%p", block->label, (void*)block);
    *label = (symasm_entry_t){
        .name = strpool_intern(out_asm->strpool, buf),
        .attr = LOCAL,
    };
}

/*
 * Emit code that places the value of node `ptr` into register `reg`.
 *
 * Integer constants are materialised with li; every other node is read back
 * from its own stack slot. Currently always returns 0.
 */
static int get_node_val(rv32_prog_t* out_asm, gen_ctx_t* ctx, ir_node_t* ptr, int reg) {
    if (ptr->tag == IR_NODE_CONST_INT) {
        // TODO
        rv32_li(out_asm, reg, ptr->data.const_int.val);
    } else {
        rv32_lw(out_asm, reg, REG_SP, stack_pos(ptr, ctx));
    }
    return 0;
}

/*
 * Emit the rv32 instructions for one IR instruction.
 *
 * Operands are fetched with get_node_val; value-producing instructions
 * (load, op, call) write their result to the instruction's own stack slot so
 * later instructions can read it back. Returns the accumulated results of
 * get_node_val (currently always 0).
 */
static int gen_instr(rv32_prog_t* out_asm, gen_ctx_t* ctx, ir_node_t* instr) {
    int idx = 0;
    symasm_entry_t label;

    switch (instr->tag) {
    case IR_NODE_ALLOC: {
        // TODO
        break;
    }
    case IR_NODE_LOAD: {
        // t0 = M[sp + offset(target)]
        rv32_lw(out_asm, REG_T0, REG_SP, stack_pos(instr->data.load.target, ctx));
        // Spill to the load's own slot: get_node_val reads every non-constant
        // operand from stack_pos(node), so the value must live there.
        rv32_sw(out_asm, REG_T0, REG_SP, stack_pos(instr, ctx));
        break;
    }
    case IR_NODE_STORE: {
        idx += get_node_val(out_asm, ctx, instr->data.store.value, REG_T0);
        // M[sp + offset(target)] = t0
        rv32_sw(out_asm, REG_T0, REG_SP, stack_pos(instr->data.store.target, ctx));
        break;
    }
    case IR_NODE_RET: {
        // a0 = return value (if any)
        if (instr->data.ret.ret_val != NULL) {
            idx += get_node_val(out_asm, ctx, instr->data.ret.ret_val, REG_A0);
        }
        // ra = M[sp + 0]
        rv32_lw(out_asm, REG_RA, REG_SP, 0);
        // sp = sp + stack_offset
        // NOTE(review): assumes rv32_addi copes with (or the frame never
        // exceeds) the 12-bit immediate range -- confirm.
        rv32_addi(out_asm, REG_SP, REG_SP, ctx->stack_offset);
        // ret == JALR(REG_X0, REG_RA, 0)
        rv32_ret(out_asm);
        break;
    }
    case IR_NODE_OP: {
        idx += get_node_val(out_asm, ctx, instr->data.op.lhs, REG_T1);
        idx += get_node_val(out_asm, ctx, instr->data.op.rhs, REG_T2);

        // Template for "t0 = t1 <op> t2"; GEN_BIN_OP fills in the opcode.
        rv32_instr_t bin = {
            .rd = REG_T0,
            .rs1 = REG_T1,
            .rs2 = REG_T2,
            .imm = 0
        };
#define GEN_BIN_OP(type)                                          \
    do {                                                          \
        bin.instr_type = (type);                                  \
        emit_rv32_instr(out_asm, &bin, EMIT_PUSH_BACK, NULL);     \
    } while (0)

        switch (instr->data.op.op) {
        case IR_OP_ADD: GEN_BIN_OP(RV_ADD); break;
        case IR_OP_SUB: GEN_BIN_OP(RV_SUB); break;
        case IR_OP_MUL: GEN_BIN_OP(RV_MUL); break;
        case IR_OP_DIV: GEN_BIN_OP(RV_DIV); break;
        case IR_OP_MOD: GEN_BIN_OP(RV_REM); break;
        case IR_OP_EQ:
            // (t1 ^ t2) == 0
            GEN_BIN_OP(RV_XOR);
            rv32_seqz(out_asm, REG_T0, REG_T0);
            break;
        case IR_OP_NEQ:
            // xor alone leaves an arbitrary non-zero value when the operands
            // differ; applying seqz twice collapses it to 0/1 so the result
            // matches the other comparisons.
            GEN_BIN_OP(RV_XOR);
            rv32_seqz(out_asm, REG_T0, REG_T0);
            rv32_seqz(out_asm, REG_T0, REG_T0);
            break;
        case IR_OP_GE:
            // !(t1 < t2)
            GEN_BIN_OP(RV_SLT);
            rv32_seqz(out_asm, REG_T0, REG_T0);
            break;
        case IR_OP_GT:
            // t1 > t2  ==  t2 < t1
            rv32_slt(out_asm, REG_T0, REG_T2, REG_T1);
            break;
        case IR_OP_LE:
            // !(t2 < t1)
            rv32_slt(out_asm, REG_T0, REG_T2, REG_T1);
            rv32_seqz(out_asm, REG_T0, REG_T0);
            break;
        case IR_OP_LT:
            rv32_slt(out_asm, REG_T0, REG_T1, REG_T2);
            break;
        default:
            LOG_ERROR("ERROR gen_instr op in riscv");
            break;
        }
#undef GEN_BIN_OP

        // M[sp + offset(instr)] = t0
        rv32_sw(out_asm, REG_T0, REG_SP, stack_pos(instr, ctx));
        break;
    }
    case IR_NODE_BRANCH: {
        get_node_val(out_asm, ctx, instr->data.branch.cond, REG_T0);

        // cond != 0 -> true block
        intern_block_label(out_asm, instr->data.branch.true_bblock, &label);
        rv32_bne_l(out_asm, REG_T0, REG_X0, &label);

        // otherwise fall into an unconditional jump to the false block
        intern_block_label(out_asm, instr->data.branch.false_bblock, &label);
        rv32_jal_l(out_asm, REG_X0, &label);
        break;
    }
    case IR_NODE_JUMP: {
        // TODO
        intern_block_label(out_asm, instr->data.jump.target_bblock, &label);
        rv32_jal_l(out_asm, REG_X0, &label);
        break;
    }
    case IR_NODE_CALL: {
        int argc = instr->data.call.args.size;
        if (argc > RV32_MAX_REG_ARGS) {
            LOG_ERROR("can't add so much params");
            // Only a0..a7 exist; never index past the register table.
            argc = RV32_MAX_REG_ARGS;
        }
        for (int i = 0; i < argc; i++) {
            ir_node_t* param = vector_at(instr->data.call.args, i);
            idx += get_node_val(out_asm, ctx, param, arg_reg(i));
        }

        int system_func_idx = system_func(instr->data.call.callee->name);
        if (system_func_idx != -1) {
            // Built-in: ecall number goes in a7.
            rv32_li(out_asm, REG_A7, system_func_idx);
            rv32_ecall(out_asm);
        } else {
            // TODO CALL
            label = (symasm_entry_t){
                .name = strpool_intern(out_asm->strpool, instr->data.call.callee->name),
                .attr = GLOBAL,
            };
            rv32_call_l(out_asm, &label);
        }

        // M[sp + offset(instr)] = a0 (call result)
        rv32_sw(out_asm, REG_A0, REG_SP, stack_pos(instr, ctx));
        break;
    }
    default:
        LOG_ERROR("ERROR gen_instr in riscv");
    }
    return idx;
}

/*
 * Emit the local label of `block` at the current end of the text section,
 * followed by the code for each of its instructions. Always returns 0.
 */
static int gen_block(rv32_prog_t* out_asm, gen_ctx_t* ctx, ir_bblock_t* block) {
    symasm_entry_t label;
    intern_block_label(out_asm, block, &label);
    rv32_append_label(out_asm, &label, out_asm->text.size);

    for (int i = 0; i < block->instrs.size; i++) {
        gen_instr(out_asm, ctx, vector_at(block->instrs, i));
    }
    return 0;
}

/*
 * Emit one function: global label, prologue (allocate frame, save ra),
 * spill of incoming register parameters, then every basic block.
 *
 * Frame layout (sp-relative): word 0 holds ra, followed by one word per IR
 * instruction of the function. Always returns 0.
 */
static int gen_func(rv32_prog_t* out_asm, ir_func_t* func) {
    symasm_entry_t label = {
        .name = strpool_intern(out_asm->strpool, func->name),
        .attr = GLOBAL,
    };
    rv32_append_label(out_asm, &label, out_asm->text.size);

    // Slot 0 is reserved for ra; instruction slots start right after it.
    const int stack_base = RV32_SLOT_BYTES;
    int stack_offset = stack_base;
    for (int i = 0; i < func->bblocks.size; i++) {
        // TODO every instr push ret val to stack
        stack_offset += RV32_SLOT_BYTES * (*vector_at(func->bblocks, i)).instrs.size;
    }

    gen_ctx_t ctx = {
        .func = func,
        .stack_offset = stack_offset,
        .stack_base = stack_base,
        .func_idx = 0,
        .block_idx = 0,
    };

    // TODO Alignment by 16
    // sp = sp - stack_offset;
    rv32_addi(out_asm, REG_SP, REG_SP, -stack_offset);
    // M[sp] = ra;
    rv32_sw(out_asm, REG_RA, REG_SP, 0);

    int param_count = func->params.size;
    if (param_count > RV32_MAX_REG_ARGS) {
        LOG_ERROR("can't add so much params");
        // Only a0..a7 exist; never index past the register table.
        param_count = RV32_MAX_REG_ARGS;
    }
    for (int i = 0; i < param_count; i++) {
        int offset = stack_pos(vector_at(func->params, i), &ctx);
        // M[sp + offset] = param[i];
        rv32_sw(out_asm, arg_reg(i), REG_SP, offset);
    }

    for (int i = 0; i < func->bblocks.size; i++) {
        gen_block(out_asm, &ctx, vector_at(func->bblocks, i));
    }
    return 0;
}

/*
 * Lower every function of `ir` into `out_asm`.
 *
 * `out_asm` is (re)initialised with init_rv32_prog before any code is
 * emitted. Always returns 0.
 */
int gen_rv32_from_ir(ir_prog_t* ir, rv32_prog_t* out_asm) {
    init_rv32_prog(out_asm, NULL);
    for (int i = 0; i < ir->funcs.size; i++) {
        gen_func(out_asm, vector_at(ir->funcs, i));
    }
    return 0;
}