feat: 闭包变量捕获 — 环境结构体 + 堆分配

lambda 可捕获外层变量, 自动构建环境结构体:

let base = 100;
let f = fn(x: i64) -> i64 { return x + base; };  // 捕获 base
f(50);  // → 150

全流水线实现:
- Sema: collect_free_vars 遍历 AST 收集自由变量
- AST function: captured/cap_types/cap_count 字段存储捕获信息
- Codegen: 闭包类型改为 struct {fn_ptr: i64, env_ptr: ptr}
- Codegen: lambda 表达式 malloc 环境结构体 + 存储捕获值
- Codegen: 生成函数签名添加 env_ptr 首个参数 (capturing only)
- Codegen: 函数体内通过 GEP 注册捕获变量到 var_table
- Codegen: 闭包调用自动提取 fn_ptr/env_ptr, 条件传递 env

非捕获 lambda 兼容: env_ptr=NULL, 不额外传参
嵌套 lambda 正确处理: 内层不穿透捕获外层变量

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-07 15:24:35 +08:00
parent 06d80f441a
commit f5c0650a97
7 changed files with 285 additions and 46 deletions
+103 -16
View File
@@ -11,8 +11,13 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
case TYPE_CLOSURE:
return LLVMInt64TypeInContext(ctx->context); // 函数指针
case TYPE_CLOSURE: {
LLVMTypeRef cls_fields[] = {
LLVMInt64TypeInContext(ctx->context), // fn_ptr
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0) // env_ptr
};
return LLVMStructTypeInContext(ctx->context, cls_fields, 2, false);
}
case TYPE_STRUCT:
case TYPE_ENUM: {
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
@@ -239,22 +244,27 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
}
LLVMTypeRef fn_ty = NULL;
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
LLVMValueRef closure_env = NULL;
LLVMValueRef gen_fn = NULL; // 闭包对应的生成函数
if (fn) {
fn_ty = LLVMGlobalGetValueType(fn); // 普通函数: 获取函数类型
fn_ty = LLVMGlobalGetValueType(fn);
} else {
// 闭包调用: 函数名在变量表中 (TYPE_CLOSURE)
VarEntry* cve = NULL;
for (VarEntry* e = ctx->var_table; e; e = e->next)
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
if (cve && cve->closure_fn) {
LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn);
gen_fn = find_fn(ctx, cve->closure_fn);
if (gen_fn) {
fn_ty = LLVMGlobalGetValueType(gen_fn); // 获取函数类型
LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder,
LLVMInt64TypeInContext(ctx->context),
cve->alloca, "fn_ptr");
fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr,
fn_ty = LLVMGlobalGetValueType(gen_fn);
LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
LLVMValueRef cls_val = LLVMBuildLoad2(ctx->builder, cls_ty,
cve->alloca, "cls_val");
LLVMValueRef fn_val = LLVMBuildExtractValue(ctx->builder,
cls_val, 0, "fn_ptr_i64");
fn = LLVMBuildIntToPtr(ctx->builder, fn_val,
LLVMPointerType(fn_ty, 0), "fn_cast");
closure_env = LLVMBuildExtractValue(ctx->builder,
cls_val, 1, "env_ptr");
}
}
}
@@ -291,9 +301,23 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
}
if (!args[i]) return NULL;
}
// 闭包调用: 若函数有 env 参数, 将其作为第一个参数
LLVMValueRef call_args[17];
unsigned call_argc = (unsigned)node->as.call.arg_count;
// 用生成函数的参数数判断是否需要 pass env
bool need_env = closure_env && gen_fn && LLVMCountParams(gen_fn) > call_argc;
if (need_env) {
call_args[0] = closure_env;
for (unsigned i = 0; i < call_argc; i++)
call_args[i + 1] = args[i];
call_argc++;
} else {
for (unsigned i = 0; i < call_argc; i++)
call_args[i] = args[i];
}
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
(unsigned)node->as.call.arg_count,
return LLVMBuildCall2(ctx->builder, fn_ty, fn, call_args,
call_argc,
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
}
CG_HANDLER(cg_call)
@@ -519,12 +543,75 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
CG_HANDLER(cg_list_comp)
static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
// 返回生成函数的指针(作为 i64
// 获取生成函数
LLVMValueRef gen_fn = find_fn(ctx, node->as.lambda.generated_name);
if (!gen_fn) return NULL;
LLVMValueRef ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
LLVMInt64TypeInContext(ctx->context), "lambda_fn");
return ptr;
// 闭包结构体: { i64 fn_ptr, i8* env_ptr }
LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
LLVMValueRef closure = LLVMBuildAlloca(ctx->builder, cls_ty, "closure");
LLVMValueRef fn_ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
LLVMInt64TypeInContext(ctx->context), "fn_ptr");
LLVMValueRef env_ptr = LLVMConstNull(
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0));
// 若有捕获变量, 构建环境结构体
AstNode* fn_ast = NULL;
// 在 g_program 中查找对应的函数 AST 获取捕获信息
if (g_program) {
for (size_t i = 0; i < g_program->as.program.fn_count; i++) {
if (g_program->as.program.functions[i]->as.function.name &&
strcmp(g_program->as.program.functions[i]->as.function.name,
node->as.lambda.generated_name) == 0) {
fn_ast = g_program->as.program.functions[i];
break;
}
}
}
if (fn_ast && fn_ast->as.function.cap_count > 0) {
size_t nc = fn_ast->as.function.cap_count;
// 构建 env struct: { cap_type_0, cap_type_1, ... }
LLVMTypeRef* ef = arena_alloc(ctx->arena, nc * sizeof(LLVMTypeRef));
for (size_t i = 0; i < nc; i++)
ef[i] = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]);
LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx->context, ef, (unsigned)nc, false);
// malloc env
LLVMValueRef env_size = LLVMSizeOf(env_ty);
LLVMValueRef malloc_args[] = { env_size };
LLVMValueRef env = LLVMBuildCall2(ctx->builder,
LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
malloc_args, 1, "env");
env_ptr = LLVMBuildBitCast(ctx->builder, env,
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0), "env_cast");
LLVMValueRef typed_env = LLVMBuildBitCast(ctx->builder, env,
LLVMPointerType(env_ty, 0), "env_typed");
// 存储捕获的值到 env struct
for (size_t i = 0; i < nc; i++) {
LLVMValueRef cap_alloca = find_var(ctx, fn_ast->as.function.captured[i]);
if (!cap_alloca) return NULL;
LLVMTypeRef cap_ty = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]);
LLVMValueRef cap_val = LLVMBuildLoad2(ctx->builder, cap_ty,
cap_alloca, "cap_val");
LLVMValueRef gep_idx[] = {
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned)i, false)
};
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, env_ty,
typed_env, gep_idx, 2, "cap_gep");
LLVMBuildStore(ctx->builder, cap_val, field_ptr);
}
}
// 填充 closure struct: { fn_ptr, env_ptr }
LLVMValueRef f0 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 0, "cl_f0");
LLVMBuildStore(ctx->builder, fn_ptr, f0);
LLVMValueRef f1 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 1, "cl_f1");
LLVMBuildStore(ctx->builder, env_ptr, f1);
return LLVMBuildLoad2(ctx->builder, cls_ty, closure, "closure_val");
}
CG_HANDLER(cg_lambda)
+45 -11
View File
@@ -391,17 +391,25 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
// 第一遍:声明所有 L 函数
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
AstNode* fn = ast->as.program.functions[i];
bool has_env = fn->as.function.cap_count > 0;
size_t total_params = fn->as.function.param_count + (has_env ? 1 : 0);
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
fn->as.function.param_count * sizeof(LLVMTypeRef));
total_params * sizeof(LLVMTypeRef));
bool* out_params = NULL;
if (fn->as.function.param_count > 0) {
out_params = arena_alloc(ctx.arena,
fn->as.function.param_count * sizeof(bool));
if (total_params > 0) {
out_params = arena_alloc(ctx.arena, total_params * sizeof(bool));
for (size_t j = 0; j < total_params; j++) out_params[j] = false;
}
// 若有捕获, 第一个参数是 env_ptr
size_t poff = 0;
if (has_env) {
ptypes[0] = LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0);
poff = 1;
}
for (size_t j = 0; j < fn->as.function.param_count; j++) {
AstNode* param = fn->as.function.params[j];
bool is_out = param->as.parameter.is_out;
if (out_params) out_params[j] = is_out;
if (out_params) out_params[j + poff] = is_out;
LLVMTypeRef inner_ty;
if (param->as.parameter.type == TYPE_STRUCT &&
param->as.parameter.struct_type_name) {
@@ -409,7 +417,7 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
} else {
inner_ty = to_llvm_type(&ctx, param->as.parameter.type);
}
ptypes[j] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty;
ptypes[j + poff] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty;
}
LLVMTypeRef ret_ty;
if (fn->as.function.return_type == TYPE_STRUCT &&
@@ -419,10 +427,9 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
}
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
ptypes, (unsigned)fn->as.function.param_count, false);
ptypes, (unsigned)total_params, false);
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
add_fn(&ctx, fn->as.function.name, lfn, out_params,
fn->as.function.param_count);
add_fn(&ctx, fn->as.function.name, lfn, out_params, total_params);
}
// 第二遍:生成函数体
@@ -435,9 +442,37 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
// 清空变量表(每个函数独立作用域)
ctx.var_table = NULL;
// 捕获变量: 第一个参数是 env_ptr, 通过 GEP 注册捕获变量
bool has_env = fn->as.function.cap_count > 0;
LLVMValueRef env_ptr = NULL;
size_t param_offset = 0;
if (has_env) {
env_ptr = LLVMGetParam(lfn, 0);
param_offset = 1;
// 生成 env struct 类型并注册捕获变量
LLVMTypeRef* ef = arena_alloc(ctx.arena,
fn->as.function.cap_count * sizeof(LLVMTypeRef));
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++)
ef[ci] = to_llvm_type(&ctx, fn->as.function.cap_types[ci]);
LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx.context, ef,
(unsigned)fn->as.function.cap_count, false);
LLVMValueRef typed_env = LLVMBuildBitCast(ctx.builder, env_ptr,
LLVMPointerType(env_ty, 0), "env_typed");
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++) {
LLVMValueRef gep_idx[] = {
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), 0, false),
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), (unsigned)ci, false)
};
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx.builder, env_ty,
typed_env, gep_idx, 2, "cap_ptr");
add_var(&ctx, fn->as.function.captured[ci], field_ptr,
to_llvm_type(&ctx, fn->as.function.cap_types[ci]));
}
}
// 将参数注册为变量
for (size_t j = 0; j < fn->as.function.param_count; j++) {
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)(j + param_offset));
AstNode* pnode = fn->as.function.params[j];
LLVMTypeRef param_ty;
if (pnode->as.parameter.type == TYPE_STRUCT &&
@@ -447,7 +482,6 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
}
if (pnode->as.parameter.is_out) {
// out 参数: param 已是指向调用者变量的指针, 直接用作 alloca
add_var(&ctx, pnode->as.parameter.name, param, param_ty);
} else {
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
+3
View File
@@ -15,6 +15,9 @@
extern int codegen_depth;
#define MAX_CODEGEN_DEPTH 1000
// AST program (由 sema 设置, codegen 读取)
extern AstNode* g_program;
// === 内部状态 ===
typedef struct VarEntry {
const char* name;