From f5c0650a97747e23d70c8fdc234ae946c55f7602 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sun, 7 Jun 2026 15:24:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E9=97=AD=E5=8C=85=E5=8F=98=E9=87=8F?= =?UTF-8?q?=E6=8D=95=E8=8E=B7=20=E2=80=94=20=E7=8E=AF=E5=A2=83=E7=BB=93?= =?UTF-8?q?=E6=9E=84=E4=BD=93=20+=20=E5=A0=86=E5=88=86=E9=85=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit lambda 可捕获外层变量, 自动构建环境结构体: let base = 100; let f = fn(x: i64) -> i64 { return x + base; }; // 捕获 base f(50); // → 150 全流水线实现: - Sema: collect_free_vars 遍历 AST 收集自由变量 - AST function: captured/cap_types/cap_count 字段存储捕获信息 - Codegen: 闭包类型改为 struct {fn_ptr: i64, env_ptr: ptr} - Codegen: lambda 表达式 malloc 环境结构体 + 存储捕获值 - Codegen: 生成函数签名添加 env_ptr 首个参数 (capturing only) - Codegen: 函数体内通过 GEP 注册捕获变量到 var_table - Codegen: 闭包调用自动提取 fn_ptr/env_ptr, 条件传递 env 非捕获 lambda 兼容: env_ptr=NULL, 不额外传参 嵌套 lambda 正确处理: 内层不穿透捕获外层变量 Co-Authored-By: Claude Opus 4.7 --- src/ast/ast.c | 3 + src/ast/ast.h | 3 +- src/codegen/cg_expr.c | 119 ++++++++++++++++++++++++++++----- src/codegen/codegen.c | 56 +++++++++++++--- src/codegen/codegen_internal.h | 3 + src/sema/typeck.c | 111 ++++++++++++++++++++++++++++++ test/programs/44_lambda.l | 36 +++++----- 7 files changed, 285 insertions(+), 46 deletions(-) diff --git a/src/ast/ast.c b/src/ast/ast.c index 8e923d3..5fbe7d9 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -44,6 +44,9 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size n->as.function.multi_ret_types = NULL; n->as.function.multi_ret_snames = NULL; n->as.function.multi_ret_count = 0; + n->as.function.captured = NULL; + n->as.function.cap_types = NULL; + n->as.function.cap_count = 0; return n; } diff --git a/src/ast/ast.h b/src/ast/ast.h index 9a25397..30bc6c5 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -73,7 +73,8 @@ struct AstNode { TypeKind return_type; const char* return_struct_type_name; struct AstNode* body; bool is_pub; const char** type_params; size_t type_param_count; - TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count; } function; + TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count; + const char** captured; TypeKind* cap_types; size_t cap_count; } function; // AST_PARAMETER (也用作结构体字段: name + type) struct { const char* name; TypeKind type; const char* struct_type_name; bool is_out; } parameter; // AST_BLOCK diff --git a/src/codegen/cg_expr.c b/src/codegen/cg_expr.c index a5e44fd..bbe08bc 100644 --- a/src/codegen/cg_expr.c +++ b/src/codegen/cg_expr.c @@ -11,8 +11,13 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context); case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context); case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0); - case TYPE_CLOSURE: - return LLVMInt64TypeInContext(ctx->context); // 函数指针 + case TYPE_CLOSURE: { + LLVMTypeRef cls_fields[] = { + LLVMInt64TypeInContext(ctx->context), // fn_ptr + LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0) // env_ptr + }; + return LLVMStructTypeInContext(ctx->context, cls_fields, 2, false); + } case TYPE_STRUCT: case TYPE_ENUM: { LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context), @@ -239,22 +244,27 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) { } LLVMTypeRef fn_ty = NULL; LLVMValueRef fn = find_fn(ctx, node->as.call.name); + LLVMValueRef closure_env = NULL; + LLVMValueRef gen_fn = NULL; // 闭包对应的生成函数 if (fn) { - fn_ty = LLVMGlobalGetValueType(fn); // 普通函数: 获取函数类型 + fn_ty = LLVMGlobalGetValueType(fn); } else { - // 闭包调用: 函数名在变量表中 (TYPE_CLOSURE) VarEntry* cve = NULL; for (VarEntry* e = ctx->var_table; e; e = e->next) if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; } if (cve && cve->closure_fn) { - LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn); + gen_fn = find_fn(ctx, cve->closure_fn); if (gen_fn) { - fn_ty = LLVMGlobalGetValueType(gen_fn); // 获取函数类型 - LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder, - LLVMInt64TypeInContext(ctx->context), - cve->alloca, "fn_ptr"); - fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr, + fn_ty = LLVMGlobalGetValueType(gen_fn); + LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE); + LLVMValueRef cls_val = LLVMBuildLoad2(ctx->builder, cls_ty, + cve->alloca, "cls_val"); + LLVMValueRef fn_val = LLVMBuildExtractValue(ctx->builder, + cls_val, 0, "fn_ptr_i64"); + fn = LLVMBuildIntToPtr(ctx->builder, fn_val, LLVMPointerType(fn_ty, 0), "fn_cast"); + closure_env = LLVMBuildExtractValue(ctx->builder, + cls_val, 1, "env_ptr"); } } } @@ -291,9 +301,23 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) { } if (!args[i]) return NULL; } + // 闭包调用: 若函数有 env 参数, 将其作为第一个参数 + LLVMValueRef call_args[17]; + unsigned call_argc = (unsigned)node->as.call.arg_count; + // 用生成函数的参数数判断是否需要 pass env + bool need_env = closure_env && gen_fn && LLVMCountParams(gen_fn) > call_argc; + if (need_env) { + call_args[0] = closure_env; + for (unsigned i = 0; i < call_argc; i++) + call_args[i + 1] = args[i]; + call_argc++; + } else { + for (unsigned i = 0; i < call_argc; i++) + call_args[i] = args[i]; + } LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); - return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, - (unsigned)node->as.call.arg_count, + return LLVMBuildCall2(ctx->builder, fn_ty, fn, call_args, + call_argc, ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call"); } CG_HANDLER(cg_call) @@ -519,12 +543,75 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) { CG_HANDLER(cg_list_comp) static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) { - // 返回生成函数的指针(作为 i64) + // 获取生成函数 LLVMValueRef gen_fn = find_fn(ctx, node->as.lambda.generated_name); if (!gen_fn) return NULL; - LLVMValueRef ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn, - LLVMInt64TypeInContext(ctx->context), "lambda_fn"); - return ptr; + + // 闭包结构体: { i64 fn_ptr, i8* env_ptr } + LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE); + LLVMValueRef closure = LLVMBuildAlloca(ctx->builder, cls_ty, "closure"); + LLVMValueRef fn_ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn, + LLVMInt64TypeInContext(ctx->context), "fn_ptr"); + + LLVMValueRef env_ptr = LLVMConstNull( + LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0)); + + // 若有捕获变量, 构建环境结构体 + AstNode* fn_ast = NULL; + // 在 g_program 中查找对应的函数 AST 获取捕获信息 + if (g_program) { + for (size_t i = 0; i < g_program->as.program.fn_count; i++) { + if (g_program->as.program.functions[i]->as.function.name && + strcmp(g_program->as.program.functions[i]->as.function.name, + node->as.lambda.generated_name) == 0) { + fn_ast = g_program->as.program.functions[i]; + break; + } + } + } + if (fn_ast && fn_ast->as.function.cap_count > 0) { + size_t nc = fn_ast->as.function.cap_count; + // 构建 env struct: { cap_type_0, cap_type_1, ... } + LLVMTypeRef* ef = arena_alloc(ctx->arena, nc * sizeof(LLVMTypeRef)); + for (size_t i = 0; i < nc; i++) + ef[i] = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]); + LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx->context, ef, (unsigned)nc, false); + + // malloc env + LLVMValueRef env_size = LLVMSizeOf(env_ty); + LLVMValueRef malloc_args[] = { env_size }; + LLVMValueRef env = LLVMBuildCall2(ctx->builder, + LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn, + malloc_args, 1, "env"); + env_ptr = LLVMBuildBitCast(ctx->builder, env, + LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0), "env_cast"); + LLVMValueRef typed_env = LLVMBuildBitCast(ctx->builder, env, + LLVMPointerType(env_ty, 0), "env_typed"); + + // 存储捕获的值到 env struct + for (size_t i = 0; i < nc; i++) { + LLVMValueRef cap_alloca = find_var(ctx, fn_ast->as.function.captured[i]); + if (!cap_alloca) return NULL; + LLVMTypeRef cap_ty = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]); + LLVMValueRef cap_val = LLVMBuildLoad2(ctx->builder, cap_ty, + cap_alloca, "cap_val"); + LLVMValueRef gep_idx[] = { + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned)i, false) + }; + LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, env_ty, + typed_env, gep_idx, 2, "cap_gep"); + LLVMBuildStore(ctx->builder, cap_val, field_ptr); + } + } + + // 填充 closure struct: { fn_ptr, env_ptr } + LLVMValueRef f0 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 0, "cl_f0"); + LLVMBuildStore(ctx->builder, fn_ptr, f0); + LLVMValueRef f1 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 1, "cl_f1"); + LLVMBuildStore(ctx->builder, env_ptr, f1); + + return LLVMBuildLoad2(ctx->builder, cls_ty, closure, "closure_val"); } CG_HANDLER(cg_lambda) diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 0e88aad..8b68e7d 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -391,17 +391,25 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, // 第一遍:声明所有 L 函数 for (size_t i = 0; i < ast->as.program.fn_count; i++) { AstNode* fn = ast->as.program.functions[i]; + bool has_env = fn->as.function.cap_count > 0; + size_t total_params = fn->as.function.param_count + (has_env ? 1 : 0); LLVMTypeRef* ptypes = arena_alloc(ctx.arena, - fn->as.function.param_count * sizeof(LLVMTypeRef)); + total_params * sizeof(LLVMTypeRef)); bool* out_params = NULL; - if (fn->as.function.param_count > 0) { - out_params = arena_alloc(ctx.arena, - fn->as.function.param_count * sizeof(bool)); + if (total_params > 0) { + out_params = arena_alloc(ctx.arena, total_params * sizeof(bool)); + for (size_t j = 0; j < total_params; j++) out_params[j] = false; + } + // 若有捕获, 第一个参数是 env_ptr + size_t poff = 0; + if (has_env) { + ptypes[0] = LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0); + poff = 1; } for (size_t j = 0; j < fn->as.function.param_count; j++) { AstNode* param = fn->as.function.params[j]; bool is_out = param->as.parameter.is_out; - if (out_params) out_params[j] = is_out; + if (out_params) out_params[j + poff] = is_out; LLVMTypeRef inner_ty; if (param->as.parameter.type == TYPE_STRUCT && param->as.parameter.struct_type_name) { @@ -409,7 +417,7 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, } else { inner_ty = to_llvm_type(&ctx, param->as.parameter.type); } - ptypes[j] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty; + ptypes[j + poff] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty; } LLVMTypeRef ret_ty; if (fn->as.function.return_type == TYPE_STRUCT && @@ -419,10 +427,9 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, ret_ty = to_llvm_type(&ctx, fn->as.function.return_type); } LLVMTypeRef fty = LLVMFunctionType(ret_ty, - ptypes, (unsigned)fn->as.function.param_count, false); + ptypes, (unsigned)total_params, false); LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty); - add_fn(&ctx, fn->as.function.name, lfn, out_params, - fn->as.function.param_count); + add_fn(&ctx, fn->as.function.name, lfn, out_params, total_params); } // 第二遍:生成函数体 @@ -435,9 +442,37 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, // 清空变量表(每个函数独立作用域) ctx.var_table = NULL; + // 捕获变量: 第一个参数是 env_ptr, 通过 GEP 注册捕获变量 + bool has_env = fn->as.function.cap_count > 0; + LLVMValueRef env_ptr = NULL; + size_t param_offset = 0; + if (has_env) { + env_ptr = LLVMGetParam(lfn, 0); + param_offset = 1; + // 生成 env struct 类型并注册捕获变量 + LLVMTypeRef* ef = arena_alloc(ctx.arena, + fn->as.function.cap_count * sizeof(LLVMTypeRef)); + for (size_t ci = 0; ci < fn->as.function.cap_count; ci++) + ef[ci] = to_llvm_type(&ctx, fn->as.function.cap_types[ci]); + LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx.context, ef, + (unsigned)fn->as.function.cap_count, false); + LLVMValueRef typed_env = LLVMBuildBitCast(ctx.builder, env_ptr, + LLVMPointerType(env_ty, 0), "env_typed"); + for (size_t ci = 0; ci < fn->as.function.cap_count; ci++) { + LLVMValueRef gep_idx[] = { + LLVMConstInt(LLVMInt32TypeInContext(ctx.context), 0, false), + LLVMConstInt(LLVMInt32TypeInContext(ctx.context), (unsigned)ci, false) + }; + LLVMValueRef field_ptr = LLVMBuildGEP2(ctx.builder, env_ty, + typed_env, gep_idx, 2, "cap_ptr"); + add_var(&ctx, fn->as.function.captured[ci], field_ptr, + to_llvm_type(&ctx, fn->as.function.cap_types[ci])); + } + } + // 将参数注册为变量 for (size_t j = 0; j < fn->as.function.param_count; j++) { - LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j); + LLVMValueRef param = LLVMGetParam(lfn, (unsigned)(j + param_offset)); AstNode* pnode = fn->as.function.params[j]; LLVMTypeRef param_ty; if (pnode->as.parameter.type == TYPE_STRUCT && @@ -447,7 +482,6 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, param_ty = to_llvm_type(&ctx, pnode->as.parameter.type); } if (pnode->as.parameter.is_out) { - // out 参数: param 已是指向调用者变量的指针, 直接用作 alloca add_var(&ctx, pnode->as.parameter.name, param, param_ty); } else { LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder, diff --git a/src/codegen/codegen_internal.h b/src/codegen/codegen_internal.h index e4ee173..b67e95c 100644 --- a/src/codegen/codegen_internal.h +++ b/src/codegen/codegen_internal.h @@ -15,6 +15,9 @@ extern int codegen_depth; #define MAX_CODEGEN_DEPTH 1000 +// AST program (由 sema 设置, codegen 读取) +extern AstNode* g_program; + // === 内部状态 === typedef struct VarEntry { const char* name; diff --git a/src/sema/typeck.c b/src/sema/typeck.c index 930cb56..9c424cf 100644 --- a/src/sema/typeck.c +++ b/src/sema/typeck.c @@ -534,6 +534,101 @@ void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena* // === lambda 表达式分析 === static int lambda_counter = 0; +// 遍历 AST 收集自由变量(不在 lambda_scope 但在 parent_scope 中的变量) +static void collect_free_vars_impl(AstNode* body, Scope* lambda_scope, Scope* parent_scope, + const char** names, TypeKind* types, size_t* count, + const char** locals, size_t* local_count) { + if (!body || *count >= 16) return; + switch (body->kind) { + case AST_IDENT_EXPR: { + const char* name = body->as.ident.name; + // 跳过 lambda 参数和局部变量 + for (size_t i = 0; i < *local_count; i++) + if (strcmp(locals[i], name) == 0) return; + // 在父作用域中查找(不在 lambda scope 中) + Symbol* sym = scope_lookup(parent_scope, name); + if (sym && (sym->kind == SYM_VARIABLE || sym->kind == SYM_PARAMETER)) { + for (size_t i = 0; i < *count; i++) + if (strcmp(names[i], name) == 0) return; + names[*count] = name; + types[*count] = sym->type; + (*count)++; + } + return; + } + case AST_LAMBDA: return; + case AST_LET_STMT: + // 先分析 init (可能引用外部变量) + collect_free_vars_impl(body->as.let_stmt.init, lambda_scope, parent_scope, + names, types, count, locals, local_count); + // 将新声明的变量加入 locals + if (*local_count < 16) + locals[(*local_count)++] = body->as.let_stmt.name; + return; + case AST_BLOCK: + for (size_t i = 0; i < body->as.block.stmt_count; i++) + collect_free_vars_impl(body->as.block.stmts[i], lambda_scope, parent_scope, + names, types, count, locals, local_count); + return; + case AST_BINARY_EXPR: + collect_free_vars_impl(body->as.binary.left, lambda_scope, parent_scope, names, types, count, locals, local_count); + collect_free_vars_impl(body->as.binary.right, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_UNARY_EXPR: + collect_free_vars_impl(body->as.unary.operand, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_CALL_EXPR: + for (size_t i = 0; i < body->as.call.arg_count; i++) + collect_free_vars_impl(body->as.call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_METHOD_CALL: + collect_free_vars_impl(body->as.method_call.receiver, lambda_scope, parent_scope, names, types, count, locals, local_count); + for (size_t i = 0; i < body->as.method_call.arg_count; i++) + collect_free_vars_impl(body->as.method_call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_RETURN_STMT: + collect_free_vars_impl(body->as.return_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_EXPR_STMT: + collect_free_vars_impl(body->as.expr_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_ASSIGN_STMT: + collect_free_vars_impl(body->as.assign_stmt.value, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_IF_STMT: + collect_free_vars_impl(body->as.if_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count); + collect_free_vars_impl(body->as.if_stmt.then_block, lambda_scope, parent_scope, names, types, count, locals, local_count); + collect_free_vars_impl(body->as.if_stmt.else_block, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_WHILE_STMT: + collect_free_vars_impl(body->as.while_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count); + collect_free_vars_impl(body->as.while_stmt.body, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_INDEX_EXPR: + collect_free_vars_impl(body->as.index_expr.array, lambda_scope, parent_scope, names, types, count, locals, local_count); + collect_free_vars_impl(body->as.index_expr.index, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_FIELD_ACCESS: + collect_free_vars_impl(body->as.field_access.object, lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + case AST_STRUCT_INIT: + for (size_t i = 0; i < body->as.struct_init.field_count; i++) + collect_free_vars_impl(body->as.struct_init.field_values[i], lambda_scope, parent_scope, names, types, count, locals, local_count); + return; + default: return; + } +} + +static void collect_free_vars(AstNode* body, Scope* ls, Scope* ps, + const char** names, TypeKind* types, size_t* count) { + // 将 lambda 参数也当作局部变量(不会被捕获) + const char* locals[32]; size_t local_count = 0; + for (Symbol* s = ls->head; s; s = s->next) { + if (local_count < 32) locals[local_count++] = s->name; + } + collect_free_vars_impl(body, ls, ps, names, types, count, locals, &local_count); +} + void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { lambda_counter++; int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1; @@ -555,12 +650,28 @@ void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { current_return_type = saved_ret; current_return_struct_name = saved_ret_sn; + // 收集被捕获的外部变量 + const char* cap_names[16]; TypeKind cap_types[16]; size_t cap_count = 0; + collect_free_vars(node->as.lambda.body, lambda_scope, scope, cap_names, cap_types, &cap_count); + // 创建顶层函数 AST 节点, 加入队列供 codegen 使用 AstNode* fn = ast_make_function(a, gen_name, node->as.lambda.params, node->as.lambda.param_count, node->as.lambda.return_type, node->as.lambda.return_struct_type_name, node->as.lambda.body, false, NULL, 0, node->loc); + // 设置捕获信息 (复制到 arena, 局部数组会出作用域) + if (cap_count > 0) { + const char** cap_names_a = arena_alloc_impl(a, cap_count * sizeof(const char*)); + TypeKind* cap_types_a = arena_alloc_impl(a, cap_count * sizeof(TypeKind)); + memcpy(cap_names_a, cap_names, cap_count * sizeof(const char*)); + memcpy(cap_types_a, cap_types, cap_count * sizeof(TypeKind)); + fn->as.function.captured = cap_names_a; + fn->as.function.cap_types = cap_types_a; + fn->as.function.cap_count = cap_count; + node->as.lambda.captured = cap_names_a; + node->as.lambda.captured_count = cap_count; + } if (lambda_count < 256) lambda_queue[lambda_count++] = fn; diff --git a/test/programs/44_lambda.l b/test/programs/44_lambda.l index 4205b57..2fa431f 100644 --- a/test/programs/44_lambda.l +++ b/test/programs/44_lambda.l @@ -1,26 +1,26 @@ -// 闭包测试 — lambda 表达式 + 调用 -fn apply_op(x: i64, op: i64) -> i64 { - // 闭包作为参数暂不支持直接调用,返回 x * 2 - return x * 2; +// 闭包测试 — 非捕获 lambda + 变量捕获 +fn make_adder(base: i64) -> i64 { + let adder = fn(x: i64) -> i64 { return x + base; }; + return adder(50); } fn main() -> void { - // 测试1: 基本 lambda + // 测试1: 非捕获 lambda let double = fn(x: i64) -> i64 { return x * 2; }; - let r1 = double(21); - print_i64(r1); // 42 + print_i64(double(21)); // 42 - // 测试2: lambda with multiple params - let add = fn(a: i64, b: i64) -> i64 { return a + b; }; - let r2 = add(30, 12); - print_i64(r2); // 42 + // 测试2: 捕获单个变量 + let base = 100; + let add = fn(x: i64) -> i64 { return x + base; }; + print_i64(add(50)); // 150 - // 测试3: nested lambda call - let r3 = double(add(10, 11)); - print_i64(r3); // 42 + // 测试3: 捕获多个变量 + let a = 10; + let b = 20; + let sum3 = fn(x: i64) -> i64 { return x + a + b; }; + print_i64(sum3(5)); // 35 - // 测试4: lambda in sequence - let triple = fn(x: i64) -> i64 { return x * 3; }; - let r4 = triple(14); - print_i64(r4); // 42 + // 测试4: 函数内创建闭包 + let r = make_adder(200); + print_i64(r); // 250 }