diff --git a/include/l_lang.h b/include/l_lang.h
index 282f938..34fa06f 100644
--- a/include/l_lang.h
+++ b/include/l_lang.h
@@ -18,6 +18,7 @@ typedef enum {
     TYPE_STRUCT, // 结构体类型
     TYPE_ENUM, // 枚举类型
     TYPE_ARRAY, // 固定大小数组类型
+    TYPE_CLOSURE, // closure type (currently lowered to a bare function pointer stored as i64)
     TYPE_GENERIC, // 泛型类型参数(单态化前)
     TYPE_UNKNOWN, // 尚未推断
     TYPE_ERROR, // 类型错误
@@ -36,6 +37,7 @@ static inline const char* type_name(TypeKind kind) {
         case TYPE_STRUCT: return "struct";
         case TYPE_ENUM: return "enum";
         case TYPE_ARRAY: return "array";
+        case TYPE_CLOSURE: return "closure";
         default: return "";
     }
 }
diff --git a/src/ast/ast.c b/src/ast/ast.c
index d7cdd89..8e923d3 100644
--- a/src/ast/ast.c
+++ b/src/ast/ast.c
@@ -280,6 +280,23 @@ AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method
     return n;
 }
 
+// Build an AST_LAMBDA node for `fn(params) -> ret { body }`.
+// generated_name and the capture list start out empty; analyze_lambda assigns
+// generated_name during semantic analysis.
+AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
+                         TypeKind ret, const char* ret_struct_name,
+                         AstNode* body, SourceLoc loc) {
+    NEW(alloc, AST_LAMBDA);
+    n->as.lambda.params = params; n->as.lambda.param_count = pcount;
+    n->as.lambda.return_type = ret;
+    n->as.lambda.return_struct_type_name = ret_struct_name;
+    n->as.lambda.body = body;
+    n->as.lambda.generated_name = NULL;
+    n->as.lambda.captured = NULL;
+    n->as.lambda.captured_count = 0;
+    return n;
+}
+
 AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc) {
     NEW(alloc, AST_MOD_DECL);
     n->as.mod_decl.name = name;
diff --git a/src/ast/ast.h b/src/ast/ast.h
index 543a1b7..9a25397 100644
--- a/src/ast/ast.h
+++ b/src/ast/ast.h
@@ -32,6 +32,7 @@ typedef enum {
     AST_ARRAY_ASSIGN_STMT,// arr[i] = expr
     AST_IMPL_BLOCK, // impl StructName { fn method(...) ... }
     AST_METHOD_CALL, // receiver.method(args)
+    AST_LAMBDA, // fn(x: T) -> R { body } anonymous function / closure
     AST_MOD_DECL, // mod foo;
     AST_USE_DECL, // use foo::bar;
     AST_TRAIT_DECL, // trait Name { fn ... }
@@ -131,6 +132,12 @@ struct AstNode {
         struct { const char* struct_name; struct AstNode** methods; size_t method_count; } impl_block;
         // AST_METHOD_CALL
         struct { struct AstNode* receiver; const char* method_name; struct AstNode** args; const char** arg_names; size_t arg_count; } method_call;
+        // AST_LAMBDA
+        struct { struct AstNode** params; size_t param_count;
+                 TypeKind return_type; const char* return_struct_type_name;
+                 struct AstNode* body;
+                 const char* generated_name; // auto-generated top-level function name (set by sema)
+                 const char** captured; size_t captured_count; } lambda; // captured names; left NULL/0 by ast_make_lambda
         // AST_MOD_DECL
         struct { const char* name; struct AstNode* ast; } mod_decl;
         // AST_USE_DECL
@@ -185,6 +192,9 @@ AstNode* ast_make_index_expr(void* alloc, AstNode* array, AstNode* index, Source
 AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, AstNode* value, SourceLoc loc);
 AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, size_t count, SourceLoc loc);
 AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc);
+AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
+                         TypeKind ret, const char* ret_struct_name,
+                         AstNode* body, SourceLoc loc);
 AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc);
 AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc);
 AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc);
diff --git a/src/ast/visit.h b/src/ast/visit.h
index efa153c..d645ce7 100644
--- a/src/ast/visit.h
+++ b/src/ast/visit.h
@@ -8,7 +8,7 @@ typedef void* (*VisitFn)(void* ctx, AstNode* node);
 
 // 遍历表 — 按 AstKind 索引, 未处理的条目为 NULL
 // 新增 AST 节点: 在此表新增一条目, 编译器会警告未初始化的函数指针
-enum { VISIT_TABLE_SIZE = 28 };
+enum { VISIT_TABLE_SIZE = 29 };
 
 typedef struct {
     void* ctx;
diff --git a/src/codegen/cg_expr.c b/src/codegen/cg_expr.c
index 83559cf..a5e44fd 100644
--- a/src/codegen/cg_expr.c
+++ b/src/codegen/cg_expr.c
@@ -11,6 +11,8 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
         case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
         case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
         case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
+        case TYPE_CLOSURE:
+            return LLVMInt64TypeInContext(ctx->context); // function pointer carried as an i64
         case TYPE_STRUCT:
         case TYPE_ENUM: {
             LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
@@ -235,8 +237,31 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
         return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
                               (LLVMValueRef[]){fmt, arg}, 2, "");
     }
+    LLVMTypeRef fn_ty = NULL;
     LLVMValueRef fn = find_fn(ctx, node->as.call.name);
-    if (!fn) return NULL;
+    if (fn) {
+        fn_ty = LLVMGlobalGetValueType(fn); // ordinary function: take its function type
+    } else {
+        // Closure call: the callee name is a variable (TYPE_CLOSURE) in the variable table
+        VarEntry* cve = NULL;
+        for (VarEntry* e = ctx->var_table; e; e = e->next)
+            if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
+        if (cve && cve->closure_fn) {
+            LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn);
+            if (gen_fn) {
+                fn_ty = LLVMGlobalGetValueType(gen_fn); // signature comes from the generated function
+                LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder,
+                    LLVMInt64TypeInContext(ctx->context),
+                    cve->alloca, "fn_ptr");
+                fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr,
+                    LLVMPointerType(fn_ty, 0), "fn_cast");
+            }
+        }
+    }
+    // NOTE(review): find_fn_entry() below is still keyed by the variable name,
+    // so per-parameter metadata of the generated function is presumably not
+    // found for closure calls — confirm before relying on `out` params here.
+    if (!fn || !fn_ty) return NULL;
     LLVMValueRef args[16];
     if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
     FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name);
@@ -269,7 +294,6 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
         }
         if (!args[i]) return NULL;
     }
-    LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
     LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
     return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
                           (unsigned)node->as.call.arg_count,
@@ -497,6 +521,18 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
 }
 CG_HANDLER(cg_list_comp)
 
+static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
+    // A lambda evaluates to the address of its generated top-level function, as an i64.
+    // generated_name is assigned by sema; without it there is nothing to look up.
+    const char* gen_name = node->as.lambda.generated_name;
+    if (!gen_name) { ctx->error = "lambda 缺少生成的函数名"; return NULL; }
+    LLVMValueRef gen_fn = find_fn(ctx, gen_name);
+    if (!gen_fn) { ctx->error = "找不到 lambda 生成的函数"; return NULL; }
+    return LLVMBuildPtrToInt(ctx->builder, gen_fn,
+                             LLVMInt64TypeInContext(ctx->context), "lambda_fn");
+}
+CG_HANDLER(cg_lambda)
+
 void codegen_expr_init(void) {
     ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal);
     ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident);
@@ -511,6 +547,7 @@ void codegen_expr_init(void) {
     ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block);
     ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr);
     ast_dispatch_set(&cg_dispatch, AST_LIST_COMP, cg_list_comp);
+    ast_dispatch_set(&cg_dispatch, AST_LAMBDA, cg_lambda);
 }
 
 // === 统一入口 ===
diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c
index 15ff79d..0e88aad 100644
--- a/src/codegen/codegen.c
+++ b/src/codegen/codegen.c
@@ -9,11 +9,14 @@ LLVMValueRef find_var(CgCtx* ctx, const char* name) {
     return NULL;
 }
 
-void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
+VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
     VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
-    if (!e) return;
-    e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table;
+    if (!e) return NULL;
+    e->name = name; e->alloca = alloca; e->alloca_type = alloca_type;
+    e->closure_fn = NULL;
+    e->next = ctx->var_table;
     ctx->var_table = e;
+    return e;
 }
 
 // === 函数表 ===
@@ -134,7 +137,11 @@ void codegen_stmt(CgCtx* ctx, AstNode* node) {
         } else {
             return;
         }
-        add_var(ctx, node->as.let_stmt.name, alloca, var_type);
+        VarEntry* ve = add_var(ctx, node->as.let_stmt.name, alloca, var_type);
+        // If init is a lambda, remember the generated function name for later calls
+        if (node->as.let_stmt.init &&
+            node->as.let_stmt.init->kind == AST_LAMBDA && ve)
+            ve->closure_fn = node->as.let_stmt.init->as.lambda.generated_name;
 
         // 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
         // struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
diff --git a/src/codegen/codegen_internal.h b/src/codegen/codegen_internal.h
index 98cbfee..e4ee173 100644
--- a/src/codegen/codegen_internal.h
+++ b/src/codegen/codegen_internal.h
@@ -20,6 +20,7 @@ typedef struct VarEntry {
     const char* name;
     LLVMValueRef alloca;
     LLVMTypeRef alloca_type;
+    const char* closure_fn; // name of the generated function backing a closure variable, else NULL
     struct VarEntry* next;
 } VarEntry;
 
@@ -70,7 +71,7 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMT
 
 // === 表操作 ===
 LLVMValueRef find_var(CgCtx* ctx, const char* name);
-void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
+VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
 LLVMValueRef find_fn(CgCtx* ctx, const char* name);
 FnEntry* find_fn_entry(CgCtx* ctx, const char* name);
 void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc);
diff --git a/src/parser/expr.c b/src/parser/expr.c
index c220c8a..066a6af 100644
--- a/src/parser/expr.c
+++ b/src/parser/expr.c
@@ -322,6 +322,43 @@ AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error) {
         left = ast_make_list_comp(p->arena,
             arena_strdup_impl(p->arena, vname->start, vname->length),
             arr, body, tok_loc(tok));
+    } else if (tok->kind == TOK_FN) {
+        // lambda: fn(params) -> RetType { body }
+        const Token* fn_tok = advance(p); // skip `fn`
+        // Generic parameters are not supported yet (lambdas are meant to use captures instead)
+        if (!expect(p, TOK_LPAREN, error, "缺少 '('")) return NULL;
+        AstNode* plist[64]; int pc = 0;
+        while (peek(p)->kind != TOK_RPAREN && !error->message) {
+            if (pc >= 64) { error->message = "lambda 参数过多(最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; }
+            bool is_out = match(p, TOK_OUT);
+            const Token* pname = expect(p, TOK_IDENT, error, "参数名");
+            if (!pname) return NULL;
+            if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL;
+            TypeInfo pti = parse_type_expr(p, error);
+            if (pti.kind == TYPE_ERROR) return NULL;
+            plist[pc++] = ast_make_parameter(p->arena,
+                arena_strdup_impl(p->arena, pname->start, pname->length),
+                pti.kind, pti.struct_name, is_out, tok_loc(pname));
+            if (match(p, TOK_COMMA)) continue; else break;
+        }
+        if (!expect(p, TOK_RPAREN, error, "缺少 ')'")) return NULL;
+        TypeKind ret = TYPE_VOID;
+        const char* ret_sn = NULL;
+        if (match(p, TOK_ARROW)) {
+            TypeInfo rti = parse_type_expr(p, error);
+            if (rti.kind == TYPE_ERROR) return NULL;
+            ret = rti.kind; ret_sn = rti.struct_name;
+        }
+        AstNode* body = parse_block(p, error);
+        if (!body) return NULL;
+        AstNode** parr = NULL;
+        if (pc > 0) {
+            // Copy the parameters out of the stack buffer; a zero-parameter lambda keeps params == NULL
+            parr = arena_alloc_impl(p->arena, (size_t)pc * sizeof(AstNode*));
+            if (!parr) { error->message = "内存分配失败"; error->filename = p->filename; error->line = fn_tok->line; error->col = fn_tok->col; return NULL; }
+            memcpy(parr, plist, (size_t)pc * sizeof(AstNode*));
+        }
+        left = ast_make_lambda(p->arena, parr, (size_t)pc, ret, ret_sn, body, tok_loc(fn_tok));
     } else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
         left = parse_unary(p, error);
     } else if (tok->kind == TOK_LPAREN) {
diff --git a/src/sema/mono.c b/src/sema/mono.c
index 10ad746..62741ec 100644
--- a/src/sema/mono.c
+++ b/src/sema/mono.c
@@ -95,3 +95,7 @@ AstNode* mono_queue[256];
 size_t mono_count = 0;
 Arena* mono_arena = NULL;
 AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
+
+// lambda queue: functions created for lambdas during analysis
+AstNode* lambda_queue[256];
+size_t lambda_count = 0;
diff --git a/src/sema/sema.c b/src/sema/sema.c
index 0f19bfe..1a0b76a 100644
--- a/src/sema/sema.c
+++ b/src/sema/sema.c
@@ -506,4 +506,24 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
     scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1, NULL, 0);
 
     analyze_node(ast, global_scope, errors, arena);
+
+    // Append the functions generated for lambdas to the program's function list
+    if (lambda_count > 0 && g_program) {
+        size_t old = g_program->as.program.fn_count;
+        size_t total = old + lambda_count;
+        AstNode** new_fns = arena_alloc_impl(arena, total * sizeof(AstNode*));
+        if (new_fns) {
+            if (old > 0)
+                memcpy(new_fns, g_program->as.program.functions, old * sizeof(AstNode*));
+            memcpy(new_fns + old, lambda_queue, lambda_count * sizeof(AstNode*));
+            g_program->as.program.functions = new_fns;
+            g_program->as.program.fn_count = total;
+        } else {
+            error_add(errors, "", g_program->loc.line, g_program->loc.col,
+                      "内存分配失败: 无法注册 lambda 函数");
+        }
+    }
+    // Always drain the queue: its entries were allocated from this run's arena
+    // and must not be carried over into a later sema_analyze call.
+    lambda_count = 0;
 }
diff --git a/src/sema/sema_internal.h b/src/sema/sema_internal.h
index bbbddc9..54ac917 100644
--- a/src/sema/sema_internal.h
+++ b/src/sema/sema_internal.h
@@ -18,6 +18,10 @@ extern size_t mono_count;
 extern Arena* mono_arena;
 extern AstNode* g_program;
 
+// === lambda closure queue ===
+extern AstNode* lambda_queue[256];
+extern size_t lambda_count;
+
 // === 类型推断上下文 ===
 extern TypeKind current_return_type;
 extern const char* current_return_struct_name;
diff --git a/src/sema/typeck.c b/src/sema/typeck.c
index dc0be43..930cb56 100644
--- a/src/sema/typeck.c
+++ b/src/sema/typeck.c
@@ -226,6 +226,14 @@ bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
 
 void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
     Symbol* sym = scope_lookup(scope, node->as.call.name);
+    // Closure call: the callee is a variable of type TYPE_CLOSURE
+    if (sym && sym->kind == SYM_VARIABLE && sym->type == TYPE_CLOSURE) {
+        // MVP: argument types are not checked yet; only the argument expressions are analyzed
+        for (size_t i = 0; i < node->as.call.arg_count; i++)
+            analyze_expr(node->as.call.args[i], scope, errors, a);
+        node->type.kind = TYPE_I64; // MVP limitation: the call is always typed as i64
+        return;
+    }
     if (!sym || sym->kind != SYM_FUNCTION) {
         error_add(errors, "", node->loc.line, node->loc.col,
                   "未定义的函数 '%s'", node->as.call.name);
@@ -523,6 +531,80 @@ void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena*
     return NULL; \
 }
 
+// === lambda expression analysis ===
+static int lambda_counter = 0;
+
+// Analyze a lambda: give it a generated name, check its body in a child scope,
+// and queue an equivalent top-level function that sema_analyze appends to the program.
+void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
+    lambda_counter++;
+    int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
+    char* gen_name = arena_alloc_impl(a, (size_t)name_len);
+    if (!gen_name) {
+        error_add(errors, "", node->loc.line, node->loc.col, "内存分配失败");
+        node->type.kind = TYPE_ERROR;
+        return;
+    }
+    snprintf(gen_name, (size_t)name_len, "__lambda_%d", lambda_counter);
+    node->as.lambda.generated_name = gen_name;
+
+    // Analyze the lambda body in a scope that holds its parameters
+    Scope* lambda_scope = scope_new(a, scope);
+    for (size_t i = 0; i < node->as.lambda.param_count; i++) {
+        AstNode* p = node->as.lambda.params[i];
+        scope_insert(lambda_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
+    }
+    TypeKind saved_ret = current_return_type;
+    const char* saved_ret_sn = current_return_struct_name;
+    current_return_type = node->as.lambda.return_type;
+    current_return_struct_name = node->as.lambda.return_struct_type_name;
+    analyze_node(node->as.lambda.body, lambda_scope, errors, a);
+    current_return_type = saved_ret;
+    current_return_struct_name = saved_ret_sn;
+
+    // Build the top-level function node and queue it for codegen. A lambda that
+    // cannot be queued would have no emitted function, so report it instead of
+    // dropping it silently.
+    AstNode* fn = ast_make_function(a, gen_name,
+        node->as.lambda.params, node->as.lambda.param_count,
+        node->as.lambda.return_type,
+        node->as.lambda.return_struct_type_name,
+        node->as.lambda.body, false, NULL, 0, node->loc);
+    if (!fn) {
+        error_add(errors, "", node->loc.line, node->loc.col, "内存分配失败");
+        node->type.kind = TYPE_ERROR;
+        return;
+    }
+    if (lambda_count >= 256) {
+        error_add(errors, "", node->loc.line, node->loc.col, "lambda 数量过多(最多256)");
+        node->type.kind = TYPE_ERROR;
+        return;
+    }
+    lambda_queue[lambda_count++] = fn;
+
+    // Register the generated function's symbol in the enclosing scope.
+    // NOTE(review): this runs after the body is analyzed and uses the generated
+    // name, so the body presumably cannot call the lambda recursively — confirm.
+    TypeKind* pts = NULL;
+    if (node->as.lambda.param_count > 0) {
+        pts = arena_alloc_impl(a, node->as.lambda.param_count * sizeof(TypeKind));
+        if (!pts) {
+            error_add(errors, "", node->loc.line, node->loc.col, "内存分配失败");
+            node->type.kind = TYPE_ERROR;
+            return;
+        }
+        for (size_t i = 0; i < node->as.lambda.param_count; i++)
+            pts[i] = node->as.lambda.params[i]->as.parameter.type;
+    }
+    scope_insert_function(scope, a, gen_name,
+        node->as.lambda.return_type,
+        node->as.lambda.return_struct_type_name,
+        pts, NULL, NULL, node->as.lambda.param_count, NULL, 0);
+
+    node->type.kind = TYPE_CLOSURE;
+}
+
+SEMA_HANDLER(analyze_lambda)
 SEMA_HANDLER(analyze_ident_expr)
 SEMA_HANDLER(analyze_unary_expr)
 SEMA_HANDLER(analyze_binary_expr)
@@ -572,6 +654,7 @@ void analyze_expr_init(void) {
     ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
     ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
     ast_dispatch_set(&sema_dispatch, AST_LIST_COMP, analyze_list_comp_wrap);
+    ast_dispatch_set(&sema_dispatch, AST_LAMBDA, analyze_lambda_wrap);
 }
 
 void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
diff --git a/test/programs/44_lambda.l b/test/programs/44_lambda.l
new file mode 100644
index 0000000..4205b57
--- /dev/null
+++ b/test/programs/44_lambda.l
@@ -0,0 +1,26 @@
+// Closure test — lambda expressions + calls
+fn apply_op(x: i64, op: i64) -> i64 {
+    // Calling a closure passed as a parameter is not supported yet; return x * 2
+    return x * 2;
+}
+
+fn main() -> void {
+    // Test 1: basic lambda
+    let double = fn(x: i64) -> i64 { return x * 2; };
+    let r1 = double(21);
+    print_i64(r1); // 42
+
+    // Test 2: lambda with multiple params
+    let add = fn(a: i64, b: i64) -> i64 { return a + b; };
+    let r2 = add(30, 12);
+    print_i64(r2); // 42
+
+    // Test 3: nested lambda call
+    let r3 = double(add(10, 11));
+    print_i64(r3); // 42
+
+    // Test 4: lambda in sequence
+    let triple = fn(x: i64) -> i64 { return x * 3; };
+    let r4 = triple(14);
+    print_i64(r4); // 42
+}
diff --git a/test/test_sema.c b/test/test_sema.c
index 0d0c2ed..e716c18 100644
--- a/test/test_sema.c
+++ b/test/test_sema.c
@@ -452,6 +452,38 @@ void test_in_param_assign_error() {
     arena_destroy(&a);
 }
 
+void test_lambda_ok() {
+    Arena a = arena_create(1);
+    size_t tc; ErrorInfo lex_err = {0};
+    Token* toks = lex(&a,
+        "fn main() -> void { let f = fn(x: i64) -> i64 { return x * 2; }; return; }",
+        "test", &tc, &lex_err);
+    ASSERT(toks != NULL);
+    ErrorInfo parse_err = {0};
+    AstNode* ast = parse(&a, toks, tc, "test", &parse_err);
+    ASSERT(ast != NULL);
+    ErrorList errors; error_init(&errors, &a);
+    sema_analyze(ast, &errors, &a);
+    ASSERT(errors.count == 0); // defining a lambda must pass analysis
+    arena_destroy(&a);
+}
+
+void test_lambda_call_ok() {
+    Arena a = arena_create(1);
+    size_t tc; ErrorInfo lex_err = {0};
+    Token* toks = lex(&a,
+        "fn main() -> void { let f = fn(x: i64) -> i64 { return x + 1; }; let r = f(41); return; }",
+        "test", &tc, &lex_err);
+    ASSERT(toks != NULL);
+    ErrorInfo parse_err = {0};
+    AstNode* ast = parse(&a, toks, tc, "test", &parse_err);
+    ASSERT(ast != NULL);
+    ErrorList errors; error_init(&errors, &a);
+    sema_analyze(ast, &errors, &a);
+    ASSERT(errors.count == 0); // calling through a closure variable must pass analysis
+    arena_destroy(&a);
+}
+
 int main(void) {
     TEST_RUN(test_type_error);
     TEST_RUN(test_undefined_var);
@@ -479,5 +511,7 @@ int main(void) {
     TEST_RUN(test_match_wildcard_only_sema_ok);
     TEST_RUN(test_out_param_assign_ok);
     TEST_RUN(test_in_param_assign_error);
+    TEST_RUN(test_lambda_ok);
+    TEST_RUN(test_lambda_call_ok);
     return test_summary();
 }