feat: 闭包/lambda — 匿名函数表达式
fn(x: T) -> R { body } 作为表达式, 可赋值给变量并间接调用。
全流水线实现:
- Parser: TOK_FN 前缀 → AST_LAMBDA 节点
- Sema: 自动生成 __lambda_N 顶层函数 + 符号注册
- Sema: analyze_call_expr 支持 TYPE_CLOSURE 变量调用
- Codegen: lambda 表达式返回函数指针(i64), 调用点载入+IntToPtr+间接call
- VarEntry.closure_fn 追踪闭包变量对应的生成函数
限制(MVP v0.1): 非捕获 lambda, 返回类型固定 i64
+6 sema 测试 + 1 集成测试, 209 测试全部通过
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+34
-2
@@ -11,6 +11,8 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
|
||||
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
||||
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
||||
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
||||
case TYPE_CLOSURE:
|
||||
return LLVMInt64TypeInContext(ctx->context); // 函数指针
|
||||
case TYPE_STRUCT:
|
||||
case TYPE_ENUM: {
|
||||
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
||||
@@ -235,8 +237,28 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||
(LLVMValueRef[]){fmt, arg}, 2, "");
|
||||
}
|
||||
LLVMTypeRef fn_ty = NULL;
|
||||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||||
if (!fn) return NULL;
|
||||
if (fn) {
|
||||
fn_ty = LLVMGlobalGetValueType(fn); // 普通函数: 获取函数类型
|
||||
} else {
|
||||
// 闭包调用: 函数名在变量表中 (TYPE_CLOSURE)
|
||||
VarEntry* cve = NULL;
|
||||
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||||
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
|
||||
if (cve && cve->closure_fn) {
|
||||
LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn);
|
||||
if (gen_fn) {
|
||||
fn_ty = LLVMGlobalGetValueType(gen_fn); // 获取函数类型
|
||||
LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder,
|
||||
LLVMInt64TypeInContext(ctx->context),
|
||||
cve->alloca, "fn_ptr");
|
||||
fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr,
|
||||
LLVMPointerType(fn_ty, 0), "fn_cast");
|
||||
}
|
||||
}
|
||||
}
|
||||
if (!fn || !fn_ty) return NULL;
|
||||
LLVMValueRef args[16];
|
||||
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
|
||||
FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name);
|
||||
@@ -269,7 +291,6 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
||||
}
|
||||
if (!args[i]) return NULL;
|
||||
}
|
||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
||||
(unsigned)node->as.call.arg_count,
|
||||
@@ -497,6 +518,16 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
|
||||
}
|
||||
CG_HANDLER(cg_list_comp)
|
||||
|
||||
static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
|
||||
// 返回生成函数的指针(作为 i64)
|
||||
LLVMValueRef gen_fn = find_fn(ctx, node->as.lambda.generated_name);
|
||||
if (!gen_fn) return NULL;
|
||||
LLVMValueRef ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
|
||||
LLVMInt64TypeInContext(ctx->context), "lambda_fn");
|
||||
return ptr;
|
||||
}
|
||||
CG_HANDLER(cg_lambda)
|
||||
|
||||
void codegen_expr_init(void) {
|
||||
ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal);
|
||||
ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident);
|
||||
@@ -511,6 +542,7 @@ void codegen_expr_init(void) {
|
||||
ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block);
|
||||
ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr);
|
||||
ast_dispatch_set(&cg_dispatch, AST_LIST_COMP, cg_list_comp);
|
||||
ast_dispatch_set(&cg_dispatch, AST_LAMBDA, cg_lambda);
|
||||
}
|
||||
|
||||
// === 统一入口 ===
|
||||
|
||||
+11
-4
@@ -9,11 +9,14 @@ LLVMValueRef find_var(CgCtx* ctx, const char* name) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
||||
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
||||
VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||||
if (!e) return;
|
||||
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table;
|
||||
if (!e) return NULL;
|
||||
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type;
|
||||
e->closure_fn = NULL;
|
||||
e->next = ctx->var_table;
|
||||
ctx->var_table = e;
|
||||
return e;
|
||||
}
|
||||
|
||||
// === 函数表 ===
|
||||
@@ -134,7 +137,11 @@ void codegen_stmt(CgCtx* ctx, AstNode* node) {
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
add_var(ctx, node->as.let_stmt.name, alloca, var_type);
|
||||
VarEntry* ve = add_var(ctx, node->as.let_stmt.name, alloca, var_type);
|
||||
// 若 init 是 lambda, 记录闭包函数名供后续调用
|
||||
if (node->as.let_stmt.init &&
|
||||
node->as.let_stmt.init->kind == AST_LAMBDA && ve)
|
||||
ve->closure_fn = node->as.let_stmt.init->as.lambda.generated_name;
|
||||
|
||||
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
|
||||
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
|
||||
|
||||
@@ -20,6 +20,7 @@ typedef struct VarEntry {
|
||||
const char* name;
|
||||
LLVMValueRef alloca;
|
||||
LLVMTypeRef alloca_type;
|
||||
const char* closure_fn; // 闭包对应的生成函数名
|
||||
struct VarEntry* next;
|
||||
} VarEntry;
|
||||
|
||||
@@ -70,7 +71,7 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMT
|
||||
|
||||
// === 表操作 ===
|
||||
LLVMValueRef find_var(CgCtx* ctx, const char* name);
|
||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
|
||||
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
|
||||
LLVMValueRef find_fn(CgCtx* ctx, const char* name);
|
||||
FnEntry* find_fn_entry(CgCtx* ctx, const char* name);
|
||||
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc);
|
||||
|
||||
Reference in New Issue
Block a user