#include "ast.h"
#include "../parser.h"

#include <stddef.h>
#include <stdio.h>

/*
 * Allocate a new AST node and reset it with init_ast_node().
 *
 * Memory comes from xmalloc (declared elsewhere in the project).
 * NOTE(review): the result is not NULL-checked here on the assumption that
 * xmalloc aborts on allocation failure — confirm.
 */
struct ASTNode* new_ast_node(void) {
    struct ASTNode* node = xmalloc(sizeof *node);
    init_ast_node(node);
    return node;
}

/*
 * Reset `node` in place: its type becomes NT_INIT and every slot of the
 * generic `children` array is set to NULL.
 */
void init_ast_node(struct ASTNode* node) {
    node->type = NT_INIT;
    for (size_t i = 0; i < sizeof(node->children) / sizeof(node->children[0]); i++) {
        node->children[i] = NULL;
    }
}

/*
 * Find a node of the given type.
 *
 * Returns `node` when it is non-NULL and its type equals `type`, otherwise
 * NULL.
 *
 * TODO: descendants are not searched yet. Child layouts differ per node kind
 * (root/block use a counted array, expressions use left/right, statements use
 * named fields — see pnt_ast), so a generic walk over `children` would
 * misinterpret some nodes.
 */
struct ASTNode* find_ast_node(struct ASTNode* node, enum ASTType type) {
    if (node != NULL && node->type == type) {
        return node;
    }
    return NULL;
}

/* Print `depth` levels of indentation, one space per level, with no newline. */
static void pnt_depth(int depth) {
    for (int i = 0; i < depth; i++) {
        printf(" ");
    }
}

/*
 * Debug-print the subtree rooted at `node` to stdout, one line per node,
 * indented by `depth`.
 *
 * A NULL `node` prints nothing. Children are generally printed at depth + 1;
 * the children of NT_ROOT and the expression of NT_STMT_EXPR stay at `depth`.
 */
void pnt_ast(struct ASTNode* node, int depth) {
    if (!node) return;

    /* The root is a transparent container: emit no indentation or text for
     * it, only its children at the same depth. */
    if (node->type == NT_ROOT) {
        for (int i = 0; i < node->root.child_size; i++) {
            pnt_ast(node->root.children[i], depth);
        }
        return;
    }

    pnt_depth(depth);
    switch (node->type) {
    case NT_ADD      : printf("+ \n"); break; // (expr) + (expr)
    case NT_SUB      : printf("- \n"); break; // (expr) - (expr)
    case NT_MUL      : printf("* \n"); break; // (expr) * (expr)
    case NT_DIV      : printf("/ \n"); break; // (expr) / (expr)
    case NT_MOD      : printf("%% \n"); break; // (expr) % (expr)
    case NT_AND      : printf("& \n"); break; // (expr) & (expr)
    case NT_OR       : printf("| \n"); break; // (expr) | (expr)
    case NT_XOR      : printf("^ \n"); break; // (expr) ^ (expr)
    case NT_L_SH     : printf("<<\n"); break; // (expr) << (expr)
    case NT_R_SH     : printf(">>\n"); break; // (expr) >> (expr)
    case NT_EQ       : printf("==\n"); break; // (expr) == (expr)
    case NT_NEQ      : printf("!=\n"); break; // (expr) != (expr)
    case NT_LE       : printf("<=\n"); break; // (expr) <= (expr)
    case NT_GE       : printf(">=\n"); break; // (expr) >= (expr)
    case NT_LT       : printf("< \n"); break; // (expr) < (expr)
    case NT_GT       : printf("> \n"); break; // (expr) > (expr)
    case NT_AND_AND  : printf("&&\n"); break; // (expr) && (expr)
    case NT_OR_OR    : printf("||\n"); break; // (expr) || (expr)
    case NT_NOT      : printf("! \n"); break; // ! (expr)
    case NT_BIT_NOT  : printf("~ \n"); break; // ~ (expr)
    case NT_COMMA    : printf(", \n"); break; // expr, expr  (comma operator)
    case NT_ASSIGN   : printf("= \n"); break; // (expr) = (expr)
    /* TODO: NT_COND, (expr) ? (expr) : (expr), is not printed yet. */

    case NT_STMT_EMPTY: // ;
        printf(";\n");
        break;
    case NT_STMT_IF: // if (cond) { ... } [else { ... }]
        printf("if\n");
        pnt_ast(node->if_stmt.cond, depth + 1);
        pnt_ast(node->if_stmt.if_stmt, depth + 1);
        if (node->if_stmt.else_stmt) {
            pnt_depth(depth);
            printf("else\n");
            pnt_ast(node->if_stmt.else_stmt, depth + 1);
        }
        break;
    case NT_STMT_WHILE: // while (cond) { ... }
        printf("while\n");
        pnt_ast(node->while_stmt.cond, depth + 1);
        pnt_ast(node->while_stmt.body, depth + 1);
        break;
    case NT_STMT_DOWHILE: // do { ... } while (cond)
        printf("do-while\n");
        pnt_ast(node->do_while_stmt.body, depth + 1);
        pnt_ast(node->do_while_stmt.cond, depth + 1);
        break;
    case NT_STMT_FOR: // for (init; cond; iter) { ... }
        printf("for\n");
        /* Any of the three header parts may be NULL; pnt_ast skips NULL. */
        pnt_ast(node->for_stmt.init, depth + 1);
        pnt_ast(node->for_stmt.cond, depth + 1);
        pnt_ast(node->for_stmt.iter, depth + 1);
        pnt_ast(node->for_stmt.body, depth + 1);
        break;

    /* TODO: these statements print only their keyword; their operands and
     * bodies are not printed yet. */
    case NT_STMT_SWITCH  : printf("switch\n");   break; // switch (expr) { case ... }
    case NT_STMT_BREAK   : printf("break\n");    break; // break;
    case NT_STMT_CONTINUE: printf("continue\n"); break; // continue;
    case NT_STMT_GOTO    : printf("goto\n");     break; // goto label;
    case NT_STMT_CASE    : printf("case\n");     break; // case const_expr:
    case NT_STMT_DEFAULT : printf("default\n");  break; // default:
    case NT_STMT_LABEL   : printf("label\n");    break; // label:

    case NT_STMT_BLOCK: // { ... }
        printf("{\n");
        for (int i = 0; i < node->block.child_size; i++) {
            pnt_ast(node->block.children[i], depth + 1);
        }
        pnt_depth(depth);
        printf("}\n");
        break;
    case NT_STMT_RETURN: // return expr;
        printf("return\n");
        /* The expression is optional; pnt_ast ignores NULL. */
        pnt_ast(node->return_stmt.expr_stmt, depth + 1);
        break;
    case NT_STMT_EXPR: // expr;
        printf("stmt\n");
        pnt_ast(node->expr_stmt.expr_stmt, depth);
        pnt_depth(depth);
        printf(";\n");
        break;

    case NT_DECL_VAR: // type name;  or  type name = expr;
        printf("decl_val\n");
        break;
    case NT_DECL_FUNC: // type func_name(param_list);
        printf("decl func %s\n", node->func.name->syms.tok.constant.str);
        break;
    case NT_FUNC: // type func_name(param_list) { ... }
        printf("def func %s\n", node->func.name->syms.tok.constant.str);
        /* TODO: parameters and return type are not printed yet. */
        pnt_ast(node->func.body, depth);
        break;

    case NT_PARAM: // formal parameter of a function
        printf("param\n");
        break;
    case NT_ARG_LIST: // actual-argument list (used together with NT_TERM_CALL)
        printf("arg_list\n");
        break;
    case NT_TERM_CALL: // func (expr)
        printf("call\n");
        break;

    case NT_TERM_IDENT:
        printf("%s\n", node->syms.tok.constant.str);
        break;
    case NT_TERM_VAL: { // terminal value: char, int or string literal
        /* Braces give the declaration its own scope; before C23 a case
         * label cannot be followed directly by a declaration. */
        struct Token* tok = &node->syms.tok;
        switch (tok->type) {
        case TOKEN_CHAR_LITERAL:
            printf("%c\n", tok->constant.ch);
            break;
        case TOKEN_INT_LITERAL:
            /* NOTE(review): assumes constant.i is an int — confirm the
             * format specifier against struct Token. */
            printf("%d\n", tok->constant.i);
            break;
        case TOKEN_STRING_LITERAL:
            printf("%s\n", tok->constant.str);
            break;
        default:
            printf("unknown term val\n");
            break;
        }
        break;
    }
    default:
        break;
    }

    /* Expression nodes share one child layout, so their operands are printed
     * here. NOTE(review): this assumes every expression enumerator is declared
     * at or before NT_ASSIGN in enum ASTType, and that no other node kind
     * before NT_ASSIGN stores non-NULL data in expr.left/expr.right — confirm
     * in ast.h. */
    if (node->type <= NT_ASSIGN) {
        pnt_ast(node->expr.left, depth + 1);
        pnt_ast(node->expr.right, depth + 1);
    }
}