feat: 闭包变量捕获 — 环境结构体 + 堆分配
lambda 可捕获外层变量, 自动构建环境结构体:
let base = 100;
let f = fn(x: i64) -> i64 { return x + base; }; // 捕获 base
f(50); // → 150
全流水线实现:
- Sema: collect_free_vars 遍历 AST 收集自由变量
- AST function: captured/cap_types/cap_count 字段存储捕获信息
- Codegen: 闭包类型改为 struct {fn_ptr: i64, env_ptr: ptr}
- Codegen: lambda 表达式 malloc 环境结构体 + 存储捕获值
- Codegen: 生成函数签名添加 env_ptr 首个参数 (capturing only)
- Codegen: 函数体内通过 GEP 注册捕获变量到 var_table
- Codegen: 闭包调用自动提取 fn_ptr/env_ptr, 条件传递 env
非捕获 lambda 兼容: env_ptr=NULL, 不额外传参
嵌套 lambda 正确处理: 内层不穿透捕获外层变量
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -44,6 +44,9 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size
|
|||||||
n->as.function.multi_ret_types = NULL;
|
n->as.function.multi_ret_types = NULL;
|
||||||
n->as.function.multi_ret_snames = NULL;
|
n->as.function.multi_ret_snames = NULL;
|
||||||
n->as.function.multi_ret_count = 0;
|
n->as.function.multi_ret_count = 0;
|
||||||
|
n->as.function.captured = NULL;
|
||||||
|
n->as.function.cap_types = NULL;
|
||||||
|
n->as.function.cap_count = 0;
|
||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
+2
-1
@@ -73,7 +73,8 @@ struct AstNode {
|
|||||||
TypeKind return_type; const char* return_struct_type_name;
|
TypeKind return_type; const char* return_struct_type_name;
|
||||||
struct AstNode* body; bool is_pub;
|
struct AstNode* body; bool is_pub;
|
||||||
const char** type_params; size_t type_param_count;
|
const char** type_params; size_t type_param_count;
|
||||||
TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count; } function;
|
TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count;
|
||||||
|
const char** captured; TypeKind* cap_types; size_t cap_count; } function;
|
||||||
// AST_PARAMETER (也用作结构体字段: name + type)
|
// AST_PARAMETER (也用作结构体字段: name + type)
|
||||||
struct { const char* name; TypeKind type; const char* struct_type_name; bool is_out; } parameter;
|
struct { const char* name; TypeKind type; const char* struct_type_name; bool is_out; } parameter;
|
||||||
// AST_BLOCK
|
// AST_BLOCK
|
||||||
|
|||||||
+103
-16
@@ -11,8 +11,13 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
|
|||||||
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
||||||
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
||||||
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
||||||
case TYPE_CLOSURE:
|
case TYPE_CLOSURE: {
|
||||||
return LLVMInt64TypeInContext(ctx->context); // 函数指针
|
LLVMTypeRef cls_fields[] = {
|
||||||
|
LLVMInt64TypeInContext(ctx->context), // fn_ptr
|
||||||
|
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0) // env_ptr
|
||||||
|
};
|
||||||
|
return LLVMStructTypeInContext(ctx->context, cls_fields, 2, false);
|
||||||
|
}
|
||||||
case TYPE_STRUCT:
|
case TYPE_STRUCT:
|
||||||
case TYPE_ENUM: {
|
case TYPE_ENUM: {
|
||||||
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
||||||
@@ -239,22 +244,27 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
}
|
}
|
||||||
LLVMTypeRef fn_ty = NULL;
|
LLVMTypeRef fn_ty = NULL;
|
||||||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||||||
|
LLVMValueRef closure_env = NULL;
|
||||||
|
LLVMValueRef gen_fn = NULL; // 闭包对应的生成函数
|
||||||
if (fn) {
|
if (fn) {
|
||||||
fn_ty = LLVMGlobalGetValueType(fn); // 普通函数: 获取函数类型
|
fn_ty = LLVMGlobalGetValueType(fn);
|
||||||
} else {
|
} else {
|
||||||
// 闭包调用: 函数名在变量表中 (TYPE_CLOSURE)
|
|
||||||
VarEntry* cve = NULL;
|
VarEntry* cve = NULL;
|
||||||
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||||||
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
|
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
|
||||||
if (cve && cve->closure_fn) {
|
if (cve && cve->closure_fn) {
|
||||||
LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn);
|
gen_fn = find_fn(ctx, cve->closure_fn);
|
||||||
if (gen_fn) {
|
if (gen_fn) {
|
||||||
fn_ty = LLVMGlobalGetValueType(gen_fn); // 获取函数类型
|
fn_ty = LLVMGlobalGetValueType(gen_fn);
|
||||||
LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder,
|
LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
|
||||||
LLVMInt64TypeInContext(ctx->context),
|
LLVMValueRef cls_val = LLVMBuildLoad2(ctx->builder, cls_ty,
|
||||||
cve->alloca, "fn_ptr");
|
cve->alloca, "cls_val");
|
||||||
fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr,
|
LLVMValueRef fn_val = LLVMBuildExtractValue(ctx->builder,
|
||||||
|
cls_val, 0, "fn_ptr_i64");
|
||||||
|
fn = LLVMBuildIntToPtr(ctx->builder, fn_val,
|
||||||
LLVMPointerType(fn_ty, 0), "fn_cast");
|
LLVMPointerType(fn_ty, 0), "fn_cast");
|
||||||
|
closure_env = LLVMBuildExtractValue(ctx->builder,
|
||||||
|
cls_val, 1, "env_ptr");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -291,9 +301,23 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
}
|
}
|
||||||
if (!args[i]) return NULL;
|
if (!args[i]) return NULL;
|
||||||
}
|
}
|
||||||
|
// 闭包调用: 若函数有 env 参数, 将其作为第一个参数
|
||||||
|
LLVMValueRef call_args[17];
|
||||||
|
unsigned call_argc = (unsigned)node->as.call.arg_count;
|
||||||
|
// 用生成函数的参数数判断是否需要 pass env
|
||||||
|
bool need_env = closure_env && gen_fn && LLVMCountParams(gen_fn) > call_argc;
|
||||||
|
if (need_env) {
|
||||||
|
call_args[0] = closure_env;
|
||||||
|
for (unsigned i = 0; i < call_argc; i++)
|
||||||
|
call_args[i + 1] = args[i];
|
||||||
|
call_argc++;
|
||||||
|
} else {
|
||||||
|
for (unsigned i = 0; i < call_argc; i++)
|
||||||
|
call_args[i] = args[i];
|
||||||
|
}
|
||||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
return LLVMBuildCall2(ctx->builder, fn_ty, fn, call_args,
|
||||||
(unsigned)node->as.call.arg_count,
|
call_argc,
|
||||||
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
|
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
|
||||||
}
|
}
|
||||||
CG_HANDLER(cg_call)
|
CG_HANDLER(cg_call)
|
||||||
@@ -519,12 +543,75 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
CG_HANDLER(cg_list_comp)
|
CG_HANDLER(cg_list_comp)
|
||||||
|
|
||||||
// Lower a lambda expression to a closure value of type { i64 fn_ptr, i8* env_ptr }.
//
// fn_ptr is the address of the generated top-level function named by
// node->as.lambda.generated_name, converted to i64. env_ptr is a null pointer
// unless the matching function AST in g_program records captured variables; in
// that case it points to a block obtained through ctx->malloc_fn that holds a
// copy of each captured variable's value at the point the lambda is evaluated,
// laid out in capture order.
//
// Returns NULL if the generated function cannot be found, or if the storage of a
// captured variable cannot be resolved with find_var.
//
// NOTE(review): nothing visible here releases the environment block — confirm
// who owns it.
static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
    LLVMValueRef gen_fn = find_fn(ctx, node->as.lambda.generated_name);
    if (!gen_fn) return NULL;

    LLVMTypeRef i8_ptr_ty = LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
    LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
    LLVMValueRef fn_ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
        LLVMInt64TypeInContext(ctx->context), "fn_ptr");

    // Non-capturing lambdas keep a null environment pointer.
    LLVMValueRef env_ptr = LLVMConstNull(i8_ptr_ty);

    // Capture information lives on the generated function's AST node, so look
    // that node up by name in the program.
    AstNode* fn_ast = NULL;
    if (g_program) {
        for (size_t i = 0; i < g_program->as.program.fn_count; i++) {
            AstNode* cand = g_program->as.program.functions[i];
            if (cand->as.function.name &&
                strcmp(cand->as.function.name, node->as.lambda.generated_name) == 0) {
                fn_ast = cand;
                break;
            }
        }
    }

    if (fn_ast && fn_ast->as.function.cap_count > 0) {
        size_t nc = fn_ast->as.function.cap_count;

        // Environment struct type: { cap_type_0, cap_type_1, ... }.
        LLVMTypeRef* ef = arena_alloc(ctx->arena, nc * sizeof(LLVMTypeRef));
        for (size_t i = 0; i < nc; i++)
            ef[i] = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]);
        LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx->context, ef, (unsigned)nc, false);

        // Allocate the environment block.
        LLVMValueRef env_size = LLVMSizeOf(env_ty);
        LLVMValueRef malloc_args[] = { env_size };
        LLVMValueRef env = LLVMBuildCall2(ctx->builder,
            LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
            malloc_args, 1, "env");
        env_ptr = LLVMBuildBitCast(ctx->builder, env, i8_ptr_ty, "env_cast");
        LLVMValueRef typed_env = LLVMBuildBitCast(ctx->builder, env,
            LLVMPointerType(env_ty, 0), "env_typed");

        // Copy the current value of every captured variable into its field.
        for (size_t i = 0; i < nc; i++) {
            LLVMValueRef cap_alloca = find_var(ctx, fn_ast->as.function.captured[i]);
            if (!cap_alloca) return NULL;
            LLVMValueRef cap_val = LLVMBuildLoad2(ctx->builder, ef[i],
                cap_alloca, "cap_val");
            LLVMValueRef gep_idx[] = {
                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
                LLVMConstInt(LLVMInt32TypeInContext(ctx->context), (unsigned)i, false)
            };
            LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, env_ty,
                typed_env, gep_idx, 2, "cap_gep");
            LLVMBuildStore(ctx->builder, cap_val, field_ptr);
        }
    }

    // Assemble { fn_ptr, env_ptr } directly as an SSA aggregate. Going through an
    // alloca emitted at the current insertion point would allocate a fresh stack
    // slot every time a lambda expression inside a loop is evaluated.
    LLVMValueRef closure = LLVMGetUndef(cls_ty);
    closure = LLVMBuildInsertValue(ctx->builder, closure, fn_ptr, 0, "cl_f0");
    return LLVMBuildInsertValue(ctx->builder, closure, env_ptr, 1, "closure_val");
}
|
||||||
CG_HANDLER(cg_lambda)
|
CG_HANDLER(cg_lambda)
|
||||||
|
|
||||||
|
|||||||
+45
-11
@@ -391,17 +391,25 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
|||||||
// 第一遍:声明所有 L 函数
|
// 第一遍:声明所有 L 函数
|
||||||
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
|
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
|
||||||
AstNode* fn = ast->as.program.functions[i];
|
AstNode* fn = ast->as.program.functions[i];
|
||||||
|
bool has_env = fn->as.function.cap_count > 0;
|
||||||
|
size_t total_params = fn->as.function.param_count + (has_env ? 1 : 0);
|
||||||
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
|
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
|
||||||
fn->as.function.param_count * sizeof(LLVMTypeRef));
|
total_params * sizeof(LLVMTypeRef));
|
||||||
bool* out_params = NULL;
|
bool* out_params = NULL;
|
||||||
if (fn->as.function.param_count > 0) {
|
if (total_params > 0) {
|
||||||
out_params = arena_alloc(ctx.arena,
|
out_params = arena_alloc(ctx.arena, total_params * sizeof(bool));
|
||||||
fn->as.function.param_count * sizeof(bool));
|
for (size_t j = 0; j < total_params; j++) out_params[j] = false;
|
||||||
|
}
|
||||||
|
// 若有捕获, 第一个参数是 env_ptr
|
||||||
|
size_t poff = 0;
|
||||||
|
if (has_env) {
|
||||||
|
ptypes[0] = LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0);
|
||||||
|
poff = 1;
|
||||||
}
|
}
|
||||||
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
||||||
AstNode* param = fn->as.function.params[j];
|
AstNode* param = fn->as.function.params[j];
|
||||||
bool is_out = param->as.parameter.is_out;
|
bool is_out = param->as.parameter.is_out;
|
||||||
if (out_params) out_params[j] = is_out;
|
if (out_params) out_params[j + poff] = is_out;
|
||||||
LLVMTypeRef inner_ty;
|
LLVMTypeRef inner_ty;
|
||||||
if (param->as.parameter.type == TYPE_STRUCT &&
|
if (param->as.parameter.type == TYPE_STRUCT &&
|
||||||
param->as.parameter.struct_type_name) {
|
param->as.parameter.struct_type_name) {
|
||||||
@@ -409,7 +417,7 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
|||||||
} else {
|
} else {
|
||||||
inner_ty = to_llvm_type(&ctx, param->as.parameter.type);
|
inner_ty = to_llvm_type(&ctx, param->as.parameter.type);
|
||||||
}
|
}
|
||||||
ptypes[j] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty;
|
ptypes[j + poff] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty;
|
||||||
}
|
}
|
||||||
LLVMTypeRef ret_ty;
|
LLVMTypeRef ret_ty;
|
||||||
if (fn->as.function.return_type == TYPE_STRUCT &&
|
if (fn->as.function.return_type == TYPE_STRUCT &&
|
||||||
@@ -419,10 +427,9 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
|||||||
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
|
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
|
||||||
}
|
}
|
||||||
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
|
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
|
||||||
ptypes, (unsigned)fn->as.function.param_count, false);
|
ptypes, (unsigned)total_params, false);
|
||||||
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
|
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
|
||||||
add_fn(&ctx, fn->as.function.name, lfn, out_params,
|
add_fn(&ctx, fn->as.function.name, lfn, out_params, total_params);
|
||||||
fn->as.function.param_count);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 第二遍:生成函数体
|
// 第二遍:生成函数体
|
||||||
@@ -435,9 +442,37 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
|||||||
// 清空变量表(每个函数独立作用域)
|
// 清空变量表(每个函数独立作用域)
|
||||||
ctx.var_table = NULL;
|
ctx.var_table = NULL;
|
||||||
|
|
||||||
|
// 捕获变量: 第一个参数是 env_ptr, 通过 GEP 注册捕获变量
|
||||||
|
bool has_env = fn->as.function.cap_count > 0;
|
||||||
|
LLVMValueRef env_ptr = NULL;
|
||||||
|
size_t param_offset = 0;
|
||||||
|
if (has_env) {
|
||||||
|
env_ptr = LLVMGetParam(lfn, 0);
|
||||||
|
param_offset = 1;
|
||||||
|
// 生成 env struct 类型并注册捕获变量
|
||||||
|
LLVMTypeRef* ef = arena_alloc(ctx.arena,
|
||||||
|
fn->as.function.cap_count * sizeof(LLVMTypeRef));
|
||||||
|
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++)
|
||||||
|
ef[ci] = to_llvm_type(&ctx, fn->as.function.cap_types[ci]);
|
||||||
|
LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx.context, ef,
|
||||||
|
(unsigned)fn->as.function.cap_count, false);
|
||||||
|
LLVMValueRef typed_env = LLVMBuildBitCast(ctx.builder, env_ptr,
|
||||||
|
LLVMPointerType(env_ty, 0), "env_typed");
|
||||||
|
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++) {
|
||||||
|
LLVMValueRef gep_idx[] = {
|
||||||
|
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), 0, false),
|
||||||
|
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), (unsigned)ci, false)
|
||||||
|
};
|
||||||
|
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx.builder, env_ty,
|
||||||
|
typed_env, gep_idx, 2, "cap_ptr");
|
||||||
|
add_var(&ctx, fn->as.function.captured[ci], field_ptr,
|
||||||
|
to_llvm_type(&ctx, fn->as.function.cap_types[ci]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 将参数注册为变量
|
// 将参数注册为变量
|
||||||
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
for (size_t j = 0; j < fn->as.function.param_count; j++) {
|
||||||
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
|
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)(j + param_offset));
|
||||||
AstNode* pnode = fn->as.function.params[j];
|
AstNode* pnode = fn->as.function.params[j];
|
||||||
LLVMTypeRef param_ty;
|
LLVMTypeRef param_ty;
|
||||||
if (pnode->as.parameter.type == TYPE_STRUCT &&
|
if (pnode->as.parameter.type == TYPE_STRUCT &&
|
||||||
@@ -447,7 +482,6 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
|
|||||||
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
|
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
|
||||||
}
|
}
|
||||||
if (pnode->as.parameter.is_out) {
|
if (pnode->as.parameter.is_out) {
|
||||||
// out 参数: param 已是指向调用者变量的指针, 直接用作 alloca
|
|
||||||
add_var(&ctx, pnode->as.parameter.name, param, param_ty);
|
add_var(&ctx, pnode->as.parameter.name, param, param_ty);
|
||||||
} else {
|
} else {
|
||||||
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
|
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
|
||||||
|
|||||||
@@ -15,6 +15,9 @@
|
|||||||
extern int codegen_depth;
|
extern int codegen_depth;
|
||||||
#define MAX_CODEGEN_DEPTH 1000
|
#define MAX_CODEGEN_DEPTH 1000
|
||||||
|
|
||||||
|
// AST program (由 sema 设置, codegen 读取)
|
||||||
|
extern AstNode* g_program;
|
||||||
|
|
||||||
// === 内部状态 ===
|
// === 内部状态 ===
|
||||||
typedef struct VarEntry {
|
typedef struct VarEntry {
|
||||||
const char* name;
|
const char* name;
|
||||||
|
|||||||
@@ -534,6 +534,101 @@ void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena*
|
|||||||
// === lambda 表达式分析 ===
|
// === lambda 表达式分析 ===
|
||||||
static int lambda_counter = 0;
|
static int lambda_counter = 0;
|
||||||
|
|
||||||
|
// 遍历 AST 收集自由变量(不在 lambda_scope 但在 parent_scope 中的变量)
|
||||||
|
static void collect_free_vars_impl(AstNode* body, Scope* lambda_scope, Scope* parent_scope,
|
||||||
|
const char** names, TypeKind* types, size_t* count,
|
||||||
|
const char** locals, size_t* local_count) {
|
||||||
|
if (!body || *count >= 16) return;
|
||||||
|
switch (body->kind) {
|
||||||
|
case AST_IDENT_EXPR: {
|
||||||
|
const char* name = body->as.ident.name;
|
||||||
|
// 跳过 lambda 参数和局部变量
|
||||||
|
for (size_t i = 0; i < *local_count; i++)
|
||||||
|
if (strcmp(locals[i], name) == 0) return;
|
||||||
|
// 在父作用域中查找(不在 lambda scope 中)
|
||||||
|
Symbol* sym = scope_lookup(parent_scope, name);
|
||||||
|
if (sym && (sym->kind == SYM_VARIABLE || sym->kind == SYM_PARAMETER)) {
|
||||||
|
for (size_t i = 0; i < *count; i++)
|
||||||
|
if (strcmp(names[i], name) == 0) return;
|
||||||
|
names[*count] = name;
|
||||||
|
types[*count] = sym->type;
|
||||||
|
(*count)++;
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
case AST_LAMBDA: return;
|
||||||
|
case AST_LET_STMT:
|
||||||
|
// 先分析 init (可能引用外部变量)
|
||||||
|
collect_free_vars_impl(body->as.let_stmt.init, lambda_scope, parent_scope,
|
||||||
|
names, types, count, locals, local_count);
|
||||||
|
// 将新声明的变量加入 locals
|
||||||
|
if (*local_count < 16)
|
||||||
|
locals[(*local_count)++] = body->as.let_stmt.name;
|
||||||
|
return;
|
||||||
|
case AST_BLOCK:
|
||||||
|
for (size_t i = 0; i < body->as.block.stmt_count; i++)
|
||||||
|
collect_free_vars_impl(body->as.block.stmts[i], lambda_scope, parent_scope,
|
||||||
|
names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_BINARY_EXPR:
|
||||||
|
collect_free_vars_impl(body->as.binary.left, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
collect_free_vars_impl(body->as.binary.right, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_UNARY_EXPR:
|
||||||
|
collect_free_vars_impl(body->as.unary.operand, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_CALL_EXPR:
|
||||||
|
for (size_t i = 0; i < body->as.call.arg_count; i++)
|
||||||
|
collect_free_vars_impl(body->as.call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_METHOD_CALL:
|
||||||
|
collect_free_vars_impl(body->as.method_call.receiver, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
for (size_t i = 0; i < body->as.method_call.arg_count; i++)
|
||||||
|
collect_free_vars_impl(body->as.method_call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_RETURN_STMT:
|
||||||
|
collect_free_vars_impl(body->as.return_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_EXPR_STMT:
|
||||||
|
collect_free_vars_impl(body->as.expr_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_ASSIGN_STMT:
|
||||||
|
collect_free_vars_impl(body->as.assign_stmt.value, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_IF_STMT:
|
||||||
|
collect_free_vars_impl(body->as.if_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
collect_free_vars_impl(body->as.if_stmt.then_block, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
collect_free_vars_impl(body->as.if_stmt.else_block, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_WHILE_STMT:
|
||||||
|
collect_free_vars_impl(body->as.while_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
collect_free_vars_impl(body->as.while_stmt.body, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_INDEX_EXPR:
|
||||||
|
collect_free_vars_impl(body->as.index_expr.array, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
collect_free_vars_impl(body->as.index_expr.index, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_FIELD_ACCESS:
|
||||||
|
collect_free_vars_impl(body->as.field_access.object, lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
case AST_STRUCT_INIT:
|
||||||
|
for (size_t i = 0; i < body->as.struct_init.field_count; i++)
|
||||||
|
collect_free_vars_impl(body->as.struct_init.field_values[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
|
||||||
|
return;
|
||||||
|
default: return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry point for free-variable collection over a lambda body.
// Every symbol found in the lambda's own scope `ls` is pre-registered as a local
// (at most 32 of them), so it is never reported as a capture; the AST walk is
// then delegated to collect_free_vars_impl, which resolves candidates through `ps`.
static void collect_free_vars(AstNode* body, Scope* ls, Scope* ps,
                              const char** names, TypeKind* types, size_t* count) {
    const char* known_locals[32];
    size_t known_count = 0;

    Symbol* sym = ls->head;
    while (sym) {
        // Symbols beyond the buffer capacity are dropped.
        if (known_count < 32) {
            known_locals[known_count] = sym->name;
            known_count += 1;
        }
        sym = sym->next;
    }

    collect_free_vars_impl(body, ls, ps, names, types, count,
                           known_locals, &known_count);
}
|
||||||
|
|
||||||
void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||||
lambda_counter++;
|
lambda_counter++;
|
||||||
int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
|
int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
|
||||||
@@ -555,12 +650,28 @@ void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
|||||||
current_return_type = saved_ret;
|
current_return_type = saved_ret;
|
||||||
current_return_struct_name = saved_ret_sn;
|
current_return_struct_name = saved_ret_sn;
|
||||||
|
|
||||||
|
// 收集被捕获的外部变量
|
||||||
|
const char* cap_names[16]; TypeKind cap_types[16]; size_t cap_count = 0;
|
||||||
|
collect_free_vars(node->as.lambda.body, lambda_scope, scope, cap_names, cap_types, &cap_count);
|
||||||
|
|
||||||
// 创建顶层函数 AST 节点, 加入队列供 codegen 使用
|
// 创建顶层函数 AST 节点, 加入队列供 codegen 使用
|
||||||
AstNode* fn = ast_make_function(a, gen_name,
|
AstNode* fn = ast_make_function(a, gen_name,
|
||||||
node->as.lambda.params, node->as.lambda.param_count,
|
node->as.lambda.params, node->as.lambda.param_count,
|
||||||
node->as.lambda.return_type,
|
node->as.lambda.return_type,
|
||||||
node->as.lambda.return_struct_type_name,
|
node->as.lambda.return_struct_type_name,
|
||||||
node->as.lambda.body, false, NULL, 0, node->loc);
|
node->as.lambda.body, false, NULL, 0, node->loc);
|
||||||
|
// 设置捕获信息 (复制到 arena, 局部数组会出作用域)
|
||||||
|
if (cap_count > 0) {
|
||||||
|
const char** cap_names_a = arena_alloc_impl(a, cap_count * sizeof(const char*));
|
||||||
|
TypeKind* cap_types_a = arena_alloc_impl(a, cap_count * sizeof(TypeKind));
|
||||||
|
memcpy(cap_names_a, cap_names, cap_count * sizeof(const char*));
|
||||||
|
memcpy(cap_types_a, cap_types, cap_count * sizeof(TypeKind));
|
||||||
|
fn->as.function.captured = cap_names_a;
|
||||||
|
fn->as.function.cap_types = cap_types_a;
|
||||||
|
fn->as.function.cap_count = cap_count;
|
||||||
|
node->as.lambda.captured = cap_names_a;
|
||||||
|
node->as.lambda.captured_count = cap_count;
|
||||||
|
}
|
||||||
if (lambda_count < 256)
|
if (lambda_count < 256)
|
||||||
lambda_queue[lambda_count++] = fn;
|
lambda_queue[lambda_count++] = fn;
|
||||||
|
|
||||||
|
|||||||
+18
-18
@@ -1,26 +1,26 @@
|
|||||||
// Closure test — lambda expressions + calls
|
// Closure test — non-capturing lambdas + variable capture
|
||||||
fn apply_op(x: i64, op: i64) -> i64 {
|
fn make_adder(base: i64) -> i64 {
|
||||||
// Calling a closure passed as a parameter is not supported yet; return x * 2
|
let adder = fn(x: i64) -> i64 { return x + base; };
|
||||||
return x * 2;
|
return adder(50);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> void {
|
fn main() -> void {
|
||||||
// Test 1: basic lambda
|
// Test 1: non-capturing lambda
|
||||||
let double = fn(x: i64) -> i64 { return x * 2; };
|
let double = fn(x: i64) -> i64 { return x * 2; };
|
||||||
let r1 = double(21);
|
print_i64(double(21)); // 42
|
||||||
print_i64(r1); // 42
|
|
||||||
|
|
||||||
// Test 2: lambda with multiple params
|
// Test 2: capture a single variable
|
||||||
let add = fn(a: i64, b: i64) -> i64 { return a + b; };
|
let base = 100;
|
||||||
let r2 = add(30, 12);
|
let add = fn(x: i64) -> i64 { return x + base; };
|
||||||
print_i64(r2); // 42
|
print_i64(add(50)); // 150
|
||||||
|
|
||||||
// Test 3: nested lambda call
|
// Test 3: capture multiple variables
|
||||||
let r3 = double(add(10, 11));
|
let a = 10;
|
||||||
print_i64(r3); // 42
|
let b = 20;
|
||||||
|
let sum3 = fn(x: i64) -> i64 { return x + a + b; };
|
||||||
|
print_i64(sum3(5)); // 35
|
||||||
|
|
||||||
// Test 4: lambda in sequence
|
// Test 4: closure created inside a function
|
||||||
let triple = fn(x: i64) -> i64 { return x * 3; };
|
let r = make_adder(200);
|
||||||
let r4 = triple(14);
|
print_i64(r); // 250
|
||||||
print_i64(r4); // 42
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user