diff --git a/src/ast/ast.c b/src/ast/ast.c index c59dfa8..b985695 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -4,6 +4,7 @@ // 使用宏简化节点创建 #define NEW(alloc, k) \ AstNode* n = (AstNode*)arena_alloc_impl(alloc, sizeof(AstNode)); \ + if (!n) return NULL; \ n->kind = (k); n->type.kind = TYPE_UNKNOWN; n->type.struct_name = NULL; \ n->line = line; n->col = col diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 2a38412..c3bbf0c 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -4,6 +4,10 @@ #include #include +// === 递归深度限制 +static int codegen_depth = 0; +#define MAX_CODEGEN_DEPTH 1000 + // === 内部状态 === typedef struct VarEntry { const char* name; @@ -387,15 +391,13 @@ static void codegen_stmt(CgCtx* ctx, AstNode* node) { LLVMBuildStore(ctx->builder, init_val, alloca); add_var(ctx, node->as.let_stmt.name, alloca); - // 自动内存管理: str 堆分配追踪 - // 只有 BINARY_EXPR (拼接) 和 STRUCT_INIT 产生堆内存 + // 自动内存管理: 只追踪 str 堆分配 (拼接/malloc) + // struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展 if (node->as.let_stmt.init->type.kind == TYPE_STR) { AstKind ik = node->as.let_stmt.init->kind; - if (ik == AST_BINARY_EXPR || ik == AST_STRUCT_INIT || ik == AST_CALL_EXPR) { + if (ik == AST_BINARY_EXPR || ik == AST_CALL_EXPR) { cleanup_add(ctx, alloca); } - } else if (node->as.let_stmt.init->type.kind == TYPE_STRUCT) { - cleanup_add(ctx, alloca); // struct 可能含 str 字段 } break; } @@ -421,6 +423,20 @@ static void codegen_stmt(CgCtx* ctx, AstNode* node) { ret_val = codegen_expr(ctx, node->as.return_stmt.expr); if (!ret_val) return; } + // 如果返回的是 str 类型的变量,从清理列表移除以防止 use-after-free + if (has_val && node->as.return_stmt.expr->type.kind == TYPE_STR && + node->as.return_stmt.expr->kind == AST_IDENT_EXPR) { + LLVMValueRef alloca = find_var(ctx, node->as.return_stmt.expr->as.ident.name); + if (alloca) { + for (size_t i = 0; i < ctx->cleanup_count; i++) { + if (ctx->cleanup_list[i] == alloca) { + ctx->cleanup_list[i] = ctx->cleanup_list[ctx->cleanup_count - 1]; + ctx->cleanup_count--; + break; + } + } + } + } // return 前释放当前作用域所有 str 堆分配 cleanup_emit(ctx, 0); // 然后 emit ret @@ -430,11 +446,13 @@ static void codegen_stmt(CgCtx* ctx, AstNode* node) { } case AST_BLOCK: { + if (++codegen_depth > MAX_CODEGEN_DEPTH) { codegen_depth--; return; } size_t block_mark = ctx->cleanup_count; for (size_t i = 0; i < node->as.block.stmt_count; i++) { codegen_stmt(ctx, node->as.block.stmts[i]); } cleanup_emit(ctx, block_mark); // 作用域退出: 释放块内 str 堆分配 + codegen_depth--; break; } @@ -498,12 +516,14 @@ static void codegen_stmt(CgCtx* ctx, AstNode* node) { // === 程序级代码生成 === LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, - const char* name, const char** error_msg) { + const char* name, const char** error_msg, + LLVMContextRef* out_context) { CgCtx ctx = {0}; ctx.arena = codegen_arena; ctx.context = LLVMContextCreate(); if (!ctx.context) { *error_msg = "无法创建 LLVM Context"; + *out_context = NULL; return NULL; } ctx.module = LLVMModuleCreateWithNameInContext(name, ctx.context); @@ -616,7 +636,9 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, LLVMBuildRetVoid(ctx.builder); else LLVMBuildRet(ctx.builder, - LLVMConstInt(to_llvm_type(&ctx, fn->as.function.return_type), 0, false)); + (fn->as.function.return_type == TYPE_F64 + ? LLVMConstReal(to_llvm_type(&ctx, TYPE_F64), 0.0) + : LLVMConstInt(to_llvm_type(&ctx, fn->as.function.return_type), 0, false))); } } @@ -625,10 +647,11 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, if (LLVMVerifyModule(ctx.module, LLVMReturnStatusAction, &verify_err)) { *error_msg = verify_err ? verify_err : "模块验证失败(错误消息为 NULL)"; LLVMDisposeBuilder(ctx.builder); - LLVMContextDispose(ctx.context); + *out_context = ctx.context; return NULL; } LLVMDisposeBuilder(ctx.builder); + *out_context = ctx.context; return ctx.module; } diff --git a/src/codegen/codegen.h b/src/codegen/codegen.h index cc1eae8..368b0c6 100644 --- a/src/codegen/codegen.h +++ b/src/codegen/codegen.h @@ -9,6 +9,7 @@ // codegen_arena 用于内部分配(VarEntry/FnEntry 等),需在整个 Module 生命周期保持存活。 // 出错时返回 NULL 并设置 *error_msg。 LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, - const char* module_name, const char** error_msg); + const char* module_name, const char** error_msg, + LLVMContextRef* out_context); #endif diff --git a/src/driver/main.c b/src/driver/main.c index 097899d..2f28f51 100644 --- a/src/driver/main.c +++ b/src/driver/main.c @@ -53,6 +53,12 @@ int main(int argc, char** argv) { return 1; } + // 安全: 拒绝含 shell 元字符的文件名,防止命令注入 + if (strpbrk(input, "\"'`\\$;|&()<>") || strpbrk(output, "\"'`\\$;|&()<>")) { + fprintf(stderr, "文件名包含非法字符\n"); + return 1; + } + // 1. 读取源文件 size_t src_size; char* source = read_file(input, &src_size); @@ -94,9 +100,11 @@ int main(int argc, char** argv) { // 6. LLVM IR 生成 const char* codegen_error = NULL; - LLVMModuleRef module = codegen_module(ast, &arena, "l_module", &codegen_error); + LLVMContextRef context = NULL; + LLVMModuleRef module = codegen_module(ast, &arena, "l_module", &codegen_error, &context); if (!module) { fprintf(stderr, "IR 生成错误: %s\n", codegen_error); + if (context) LLVMContextDispose(context); free(source); arena_destroy(&arena); return 1; } @@ -145,6 +153,7 @@ int main(int argc, char** argv) { // 清理 LLVMDisposeModule(module); + LLVMContextDispose(context); free(source); arena_destroy(&arena); return 0; diff --git a/src/parser/parser.c b/src/parser/parser.c index 91234d7..66e3a2d 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -10,6 +10,10 @@ typedef struct { Arena* arena; } Parser; +// === 递归深度限制 === +static int parse_depth = 0; +#define MAX_PARSE_DEPTH 1000 + // === 向前看 === static const Token* peek(const Parser* p) { return &p->tokens[p->pos]; } static const Token* advance(Parser* p) { return &p->tokens[p->pos++]; } @@ -110,6 +114,7 @@ static AstNode* parse_struct_init(Parser* p, const Token* name, ErrorInfo* error int fcount = 0; while (peek(p)->kind != TOK_RBRACE && !error->message) { + if (fcount >= 32) { error->message = "结构体初始化字段过多 (最多32)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } const Token* fname = expect(p, TOK_IDENT, error, "字段名"); if (!fname) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; @@ -252,6 +257,7 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) { AstNode* fields[32]; int fcount = 0; while (peek(p)->kind != TOK_RBRACE && !error->message) { + if (fcount >= 32) { error->message = "结构体字段过多 (最多32)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } const Token* fname = expect(p, TOK_IDENT, error, "字段名"); if (!fname) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; @@ -285,17 +291,24 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) { // === 语句解析 === static AstNode* parse_block(Parser* p, ErrorInfo* error) { + if (++parse_depth > MAX_PARSE_DEPTH) { + error->message = "嵌套过深"; error->filename = p->filename; + error->line = peek(p)->line; error->col = peek(p)->col; + parse_depth--; return NULL; + } const Token* open = peek(p); - if (!expect(p, TOK_LBRACE, error, "缺少 '{'")) return NULL; + if (!expect(p, TOK_LBRACE, error, "缺少 '{'")) { parse_depth--; return NULL; } AstNode* stmts[256]; int count = 0; while (peek(p)->kind != TOK_RBRACE && peek(p)->kind != TOK_EOF && !error->message) { + if (count >= 256) { error->message = "代码块语句过多 (最多256)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; parse_depth--; return NULL; } AstNode* s = parse_statement(p, error); - if (!s) return NULL; + if (!s) { parse_depth--; return NULL; } stmts[count++] = s; } - if (!expect(p, TOK_RBRACE, error, "缺少 '}'")) return NULL; + if (!expect(p, TOK_RBRACE, error, "缺少 '}'")) { parse_depth--; return NULL; } AstNode** arr = arena_alloc_impl(p->arena, count * sizeof(AstNode*)); memcpy(arr, stmts, count * sizeof(AstNode*)); + parse_depth--; return ast_make_block(p->arena, arr, count, open->line, open->col); } @@ -498,6 +511,7 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) { // 参数列表 AstNode* params[64]; int pcount = 0; while (peek(p)->kind != TOK_RPAREN && !error->message) { + if (pcount >= 64) { error->message = "函数参数过多 (最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } const Token* pname = expect(p, TOK_IDENT, error, "参数名"); if (!pname) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; @@ -544,8 +558,10 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, AstNode* structs[64]; int struct_count = 0; while (peek(&p)->kind != TOK_EOF && !error->message) { if (peek(&p)->kind == TOK_STRUCT) { + if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } structs[struct_count++] = parse_struct_decl(&p, error); } else if (peek(&p)->kind == TOK_FN) { + if (fn_count >= 256) { error->message = "函数过多 (最多256)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } functions[fn_count++] = parse_function(&p, error); } else { error->message = "顶层只允许 fn 或 struct"; diff --git a/src/sema/sema.c b/src/sema/sema.c index bcfdd48..3ad1548 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -2,6 +2,8 @@ #include // === 类型关系 === +static TypeKind current_return_type = TYPE_VOID; + static TypeKind promote(TypeKind a, TypeKind b) { if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; if (a == TYPE_I64 || b == TYPE_I64) return TYPE_I64; @@ -303,7 +305,10 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* AstNode* p = node->as.function.params[i]; scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); } + TypeKind saved = current_return_type; + current_return_type = node->as.function.return_type; analyze_node(node->as.function.body, fn_scope, errors, a); + current_return_type = saved; break; } @@ -427,6 +432,16 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* if (node->as.return_stmt.expr) { analyze_expr(node->as.return_stmt.expr, scope, errors, a); node->type.kind = node->as.return_stmt.expr->type.kind; + TypeKind actual = node->as.return_stmt.expr->type.kind; + TypeKind expected = current_return_type; + if (actual != TYPE_ERROR && expected != TYPE_VOID && actual != expected) { + error_add(errors, "", node->line, node->col, + "返回类型不匹配: 期望 '%s',得到 '%s'", + type_name(expected), type_name(actual)); + } + } else if (current_return_type != TYPE_VOID) { + error_add(errors, "", node->line, node->col, + "函数应返回值类型 '%s'", type_name(current_return_type)); } break; diff --git a/test/test_codegen.c b/test/test_codegen.c index 0468997..0a386ea 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -16,7 +16,8 @@ void test_codegen_simple_function() { AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); const char* err = NULL; - LLVMModuleRef mod = codegen_module(prog, &a, "test_mod", &err); + LLVMContextRef ctx = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_mod", &err, &ctx); ASSERT(mod != NULL); ASSERT(err == NULL); @@ -26,6 +27,7 @@ void test_codegen_simple_function() { ASSERT(!failed); LLVMDisposeModule(mod); + LLVMContextDispose(ctx); arena_destroy(&a); } @@ -48,7 +50,8 @@ void test_codegen_if_else() { AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); const char* err = NULL; - LLVMModuleRef mod = codegen_module(prog, &a, "test_mod2", &err); + LLVMContextRef ctx2 = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_mod2", &err, &ctx2); ASSERT(mod != NULL); char* verify_err = NULL; @@ -56,6 +59,7 @@ void test_codegen_if_else() { ASSERT(!failed); LLVMDisposeModule(mod); + LLVMContextDispose(ctx2); arena_destroy(&a); } @@ -77,7 +81,8 @@ void test_codegen_binary_ops() { AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); const char* err = NULL; - LLVMModuleRef mod = codegen_module(prog, &a, "test_mod3", &err); + LLVMContextRef ctx3 = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_mod3", &err, &ctx3); ASSERT(mod != NULL); char* verify_err = NULL; @@ -85,6 +90,7 @@ void test_codegen_binary_ops() { ASSERT(!failed); LLVMDisposeModule(mod); + LLVMContextDispose(ctx3); arena_destroy(&a); } @@ -105,11 +111,13 @@ void test_codegen_while_loop() { AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); const char* err = NULL; - LLVMModuleRef mod = codegen_module(prog, &a, "test_while", &err); + LLVMContextRef ctx4 = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_while", &err, &ctx4); ASSERT(mod != NULL); char* verify_err = NULL; ASSERT(!LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err)); LLVMDisposeModule(mod); + LLVMContextDispose(ctx4); arena_destroy(&a); }