feat: 闭包/lambda — 匿名函数表达式
fn(x: T) -> R { body } 作为表达式, 可赋值给变量并间接调用。
全流水线实现:
- Parser: TOK_FN 前缀 → AST_LAMBDA 节点
- Sema: 自动生成 __lambda_N 顶层函数 + 符号注册
- Sema: analyze_call_expr 支持 TYPE_CLOSURE 变量调用
- Codegen: lambda 表达式返回函数指针(i64), 调用点载入+IntToPtr+间接call
- VarEntry.closure_fn 追踪闭包变量对应的生成函数
限制(MVP v0.1): 非捕获 lambda, 返回类型固定 i64
+6 sema 测试 + 1 集成测试, 209 测试全部通过
Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
@@ -18,6 +18,7 @@ typedef enum {
|
|||||||
TYPE_STRUCT, // 结构体类型
|
TYPE_STRUCT, // 结构体类型
|
||||||
TYPE_ENUM, // 枚举类型
|
TYPE_ENUM, // 枚举类型
|
||||||
TYPE_ARRAY, // 固定大小数组类型
|
TYPE_ARRAY, // 固定大小数组类型
|
||||||
|
TYPE_CLOSURE, // 闭包类型 (函数指针 + 环境指针)
|
||||||
TYPE_GENERIC, // 泛型类型参数(单态化前)
|
TYPE_GENERIC, // 泛型类型参数(单态化前)
|
||||||
TYPE_UNKNOWN, // 尚未推断
|
TYPE_UNKNOWN, // 尚未推断
|
||||||
TYPE_ERROR, // 类型错误
|
TYPE_ERROR, // 类型错误
|
||||||
@@ -36,6 +37,7 @@ static inline const char* type_name(TypeKind kind) {
|
|||||||
case TYPE_STRUCT: return "struct";
|
case TYPE_STRUCT: return "struct";
|
||||||
case TYPE_ENUM: return "enum";
|
case TYPE_ENUM: return "enum";
|
||||||
case TYPE_ARRAY: return "array";
|
case TYPE_ARRAY: return "array";
|
||||||
|
case TYPE_CLOSURE: return "closure";
|
||||||
default: return "<unknown>";
|
default: return "<unknown>";
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -280,6 +280,20 @@ AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method
|
|||||||
return n;
|
return n;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
|
||||||
|
TypeKind ret, const char* ret_struct_name,
|
||||||
|
AstNode* body, SourceLoc loc) {
|
||||||
|
NEW(alloc, AST_LAMBDA);
|
||||||
|
n->as.lambda.params = params; n->as.lambda.param_count = pcount;
|
||||||
|
n->as.lambda.return_type = ret;
|
||||||
|
n->as.lambda.return_struct_type_name = ret_struct_name;
|
||||||
|
n->as.lambda.body = body;
|
||||||
|
n->as.lambda.generated_name = NULL;
|
||||||
|
n->as.lambda.captured = NULL;
|
||||||
|
n->as.lambda.captured_count = 0;
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
|
||||||
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc) {
|
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc) {
|
||||||
NEW(alloc, AST_MOD_DECL);
|
NEW(alloc, AST_MOD_DECL);
|
||||||
n->as.mod_decl.name = name;
|
n->as.mod_decl.name = name;
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ typedef enum {
|
|||||||
AST_ARRAY_ASSIGN_STMT,// arr[i] = expr
|
AST_ARRAY_ASSIGN_STMT,// arr[i] = expr
|
||||||
AST_IMPL_BLOCK, // impl StructName { fn method(...) ... }
|
AST_IMPL_BLOCK, // impl StructName { fn method(...) ... }
|
||||||
AST_METHOD_CALL, // receiver.method(args)
|
AST_METHOD_CALL, // receiver.method(args)
|
||||||
|
AST_LAMBDA, // fn(x: T) -> R { body } 匿名函数/闭包
|
||||||
AST_MOD_DECL, // mod foo;
|
AST_MOD_DECL, // mod foo;
|
||||||
AST_USE_DECL, // use foo::bar;
|
AST_USE_DECL, // use foo::bar;
|
||||||
AST_TRAIT_DECL, // trait Name { fn ... }
|
AST_TRAIT_DECL, // trait Name { fn ... }
|
||||||
@@ -131,6 +132,12 @@ struct AstNode {
|
|||||||
struct { const char* struct_name; struct AstNode** methods; size_t method_count; } impl_block;
|
struct { const char* struct_name; struct AstNode** methods; size_t method_count; } impl_block;
|
||||||
// AST_METHOD_CALL
|
// AST_METHOD_CALL
|
||||||
struct { struct AstNode* receiver; const char* method_name; struct AstNode** args; const char** arg_names; size_t arg_count; } method_call;
|
struct { struct AstNode* receiver; const char* method_name; struct AstNode** args; const char** arg_names; size_t arg_count; } method_call;
|
||||||
|
// AST_LAMBDA
|
||||||
|
struct { struct AstNode** params; size_t param_count;
|
||||||
|
TypeKind return_type; const char* return_struct_type_name;
|
||||||
|
struct AstNode* body;
|
||||||
|
const char* generated_name; // 自动生成的顶层函数名
|
||||||
|
const char** captured; size_t captured_count; } lambda;
|
||||||
// AST_MOD_DECL
|
// AST_MOD_DECL
|
||||||
struct { const char* name; struct AstNode* ast; } mod_decl;
|
struct { const char* name; struct AstNode* ast; } mod_decl;
|
||||||
// AST_USE_DECL
|
// AST_USE_DECL
|
||||||
@@ -185,6 +192,9 @@ AstNode* ast_make_index_expr(void* alloc, AstNode* array, AstNode* index, Source
|
|||||||
AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, AstNode* value, SourceLoc loc);
|
AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, AstNode* value, SourceLoc loc);
|
||||||
AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, size_t count, SourceLoc loc);
|
AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, size_t count, SourceLoc loc);
|
||||||
AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc);
|
AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc);
|
||||||
|
AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
|
||||||
|
TypeKind ret, const char* ret_struct_name,
|
||||||
|
AstNode* body, SourceLoc loc);
|
||||||
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc);
|
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc);
|
||||||
AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc);
|
AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc);
|
||||||
AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc);
|
AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc);
|
||||||
|
|||||||
+1
-1
@@ -8,7 +8,7 @@ typedef void* (*VisitFn)(void* ctx, AstNode* node);
|
|||||||
|
|
||||||
// 遍历表 — 按 AstKind 索引, 未处理的条目为 NULL
|
// 遍历表 — 按 AstKind 索引, 未处理的条目为 NULL
|
||||||
// 新增 AST 节点: 在此表新增一条目, 编译器会警告未初始化的函数指针
|
// 新增 AST 节点: 在此表新增一条目, 编译器会警告未初始化的函数指针
|
||||||
enum { VISIT_TABLE_SIZE = 28 };
|
enum { VISIT_TABLE_SIZE = 29 };
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
void* ctx;
|
void* ctx;
|
||||||
|
|||||||
+34
-2
@@ -11,6 +11,8 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
|
|||||||
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
|
||||||
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
|
||||||
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
|
||||||
|
case TYPE_CLOSURE:
|
||||||
|
return LLVMInt64TypeInContext(ctx->context); // 函数指针
|
||||||
case TYPE_STRUCT:
|
case TYPE_STRUCT:
|
||||||
case TYPE_ENUM: {
|
case TYPE_ENUM: {
|
||||||
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
|
||||||
@@ -235,8 +237,28 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
|
||||||
(LLVMValueRef[]){fmt, arg}, 2, "");
|
(LLVMValueRef[]){fmt, arg}, 2, "");
|
||||||
}
|
}
|
||||||
|
LLVMTypeRef fn_ty = NULL;
|
||||||
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
|
||||||
if (!fn) return NULL;
|
if (fn) {
|
||||||
|
fn_ty = LLVMGlobalGetValueType(fn); // 普通函数: 获取函数类型
|
||||||
|
} else {
|
||||||
|
// 闭包调用: 函数名在变量表中 (TYPE_CLOSURE)
|
||||||
|
VarEntry* cve = NULL;
|
||||||
|
for (VarEntry* e = ctx->var_table; e; e = e->next)
|
||||||
|
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
|
||||||
|
if (cve && cve->closure_fn) {
|
||||||
|
LLVMValueRef gen_fn = find_fn(ctx, cve->closure_fn);
|
||||||
|
if (gen_fn) {
|
||||||
|
fn_ty = LLVMGlobalGetValueType(gen_fn); // 获取函数类型
|
||||||
|
LLVMValueRef closure_ptr = LLVMBuildLoad2(ctx->builder,
|
||||||
|
LLVMInt64TypeInContext(ctx->context),
|
||||||
|
cve->alloca, "fn_ptr");
|
||||||
|
fn = LLVMBuildIntToPtr(ctx->builder, closure_ptr,
|
||||||
|
LLVMPointerType(fn_ty, 0), "fn_cast");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (!fn || !fn_ty) return NULL;
|
||||||
LLVMValueRef args[16];
|
LLVMValueRef args[16];
|
||||||
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
|
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
|
||||||
FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name);
|
FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name);
|
||||||
@@ -269,7 +291,6 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
}
|
}
|
||||||
if (!args[i]) return NULL;
|
if (!args[i]) return NULL;
|
||||||
}
|
}
|
||||||
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
|
|
||||||
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
|
||||||
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
|
||||||
(unsigned)node->as.call.arg_count,
|
(unsigned)node->as.call.arg_count,
|
||||||
@@ -497,6 +518,16 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
|
|||||||
}
|
}
|
||||||
CG_HANDLER(cg_list_comp)
|
CG_HANDLER(cg_list_comp)
|
||||||
|
|
||||||
|
static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
|
||||||
|
// 返回生成函数的指针(作为 i64)
|
||||||
|
LLVMValueRef gen_fn = find_fn(ctx, node->as.lambda.generated_name);
|
||||||
|
if (!gen_fn) return NULL;
|
||||||
|
LLVMValueRef ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
|
||||||
|
LLVMInt64TypeInContext(ctx->context), "lambda_fn");
|
||||||
|
return ptr;
|
||||||
|
}
|
||||||
|
CG_HANDLER(cg_lambda)
|
||||||
|
|
||||||
void codegen_expr_init(void) {
|
void codegen_expr_init(void) {
|
||||||
ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal);
|
ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal);
|
||||||
ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident);
|
ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident);
|
||||||
@@ -511,6 +542,7 @@ void codegen_expr_init(void) {
|
|||||||
ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block);
|
ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block);
|
||||||
ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr);
|
ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr);
|
||||||
ast_dispatch_set(&cg_dispatch, AST_LIST_COMP, cg_list_comp);
|
ast_dispatch_set(&cg_dispatch, AST_LIST_COMP, cg_list_comp);
|
||||||
|
ast_dispatch_set(&cg_dispatch, AST_LAMBDA, cg_lambda);
|
||||||
}
|
}
|
||||||
|
|
||||||
// === 统一入口 ===
|
// === 统一入口 ===
|
||||||
|
|||||||
+11
-4
@@ -9,11 +9,14 @@ LLVMValueRef find_var(CgCtx* ctx, const char* name) {
|
|||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
|
||||||
VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
|
||||||
if (!e) return;
|
if (!e) return NULL;
|
||||||
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table;
|
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type;
|
||||||
|
e->closure_fn = NULL;
|
||||||
|
e->next = ctx->var_table;
|
||||||
ctx->var_table = e;
|
ctx->var_table = e;
|
||||||
|
return e;
|
||||||
}
|
}
|
||||||
|
|
||||||
// === 函数表 ===
|
// === 函数表 ===
|
||||||
@@ -134,7 +137,11 @@ void codegen_stmt(CgCtx* ctx, AstNode* node) {
|
|||||||
} else {
|
} else {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
add_var(ctx, node->as.let_stmt.name, alloca, var_type);
|
VarEntry* ve = add_var(ctx, node->as.let_stmt.name, alloca, var_type);
|
||||||
|
// 若 init 是 lambda, 记录闭包函数名供后续调用
|
||||||
|
if (node->as.let_stmt.init &&
|
||||||
|
node->as.let_stmt.init->kind == AST_LAMBDA && ve)
|
||||||
|
ve->closure_fn = node->as.let_stmt.init->as.lambda.generated_name;
|
||||||
|
|
||||||
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
|
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
|
||||||
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
|
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ typedef struct VarEntry {
|
|||||||
const char* name;
|
const char* name;
|
||||||
LLVMValueRef alloca;
|
LLVMValueRef alloca;
|
||||||
LLVMTypeRef alloca_type;
|
LLVMTypeRef alloca_type;
|
||||||
|
const char* closure_fn; // 闭包对应的生成函数名
|
||||||
struct VarEntry* next;
|
struct VarEntry* next;
|
||||||
} VarEntry;
|
} VarEntry;
|
||||||
|
|
||||||
@@ -70,7 +71,7 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMT
|
|||||||
|
|
||||||
// === 表操作 ===
|
// === 表操作 ===
|
||||||
LLVMValueRef find_var(CgCtx* ctx, const char* name);
|
LLVMValueRef find_var(CgCtx* ctx, const char* name);
|
||||||
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
|
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
|
||||||
LLVMValueRef find_fn(CgCtx* ctx, const char* name);
|
LLVMValueRef find_fn(CgCtx* ctx, const char* name);
|
||||||
FnEntry* find_fn_entry(CgCtx* ctx, const char* name);
|
FnEntry* find_fn_entry(CgCtx* ctx, const char* name);
|
||||||
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc);
|
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc);
|
||||||
|
|||||||
@@ -322,6 +322,38 @@ AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error) {
|
|||||||
left = ast_make_list_comp(p->arena,
|
left = ast_make_list_comp(p->arena,
|
||||||
arena_strdup_impl(p->arena, vname->start, vname->length),
|
arena_strdup_impl(p->arena, vname->start, vname->length),
|
||||||
arr, body, tok_loc(tok));
|
arr, body, tok_loc(tok));
|
||||||
|
} else if (tok->kind == TOK_FN) {
|
||||||
|
// lambda: fn(params) -> RetType { body }
|
||||||
|
const Token* fn_tok = advance(p); // 跳过 fn
|
||||||
|
// 泛型参数暂不支持(lambda用捕获替代)
|
||||||
|
if (!expect(p, TOK_LPAREN, error, "缺少 '('")) return NULL;
|
||||||
|
AstNode* plist[64]; int pc = 0;
|
||||||
|
while (peek(p)->kind != TOK_RPAREN && !error->message) {
|
||||||
|
if (pc >= 64) { error->message = "lambda 参数过多(最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; }
|
||||||
|
bool is_out = match(p, TOK_OUT);
|
||||||
|
const Token* pname = expect(p, TOK_IDENT, error, "参数名");
|
||||||
|
if (!pname) return NULL;
|
||||||
|
if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL;
|
||||||
|
TypeInfo pti = parse_type_expr(p, error);
|
||||||
|
if (pti.kind == TYPE_ERROR) return NULL;
|
||||||
|
plist[pc++] = ast_make_parameter(p->arena,
|
||||||
|
arena_strdup_impl(p->arena, pname->start, pname->length),
|
||||||
|
pti.kind, pti.struct_name, is_out, tok_loc(pname));
|
||||||
|
if (match(p, TOK_COMMA)) continue; else break;
|
||||||
|
}
|
||||||
|
if (!expect(p, TOK_RPAREN, error, "缺少 ')'")) return NULL;
|
||||||
|
TypeKind ret = TYPE_VOID;
|
||||||
|
const char* ret_sn = NULL;
|
||||||
|
if (match(p, TOK_ARROW)) {
|
||||||
|
TypeInfo rti = parse_type_expr(p, error);
|
||||||
|
if (rti.kind == TYPE_ERROR) return NULL;
|
||||||
|
ret = rti.kind; ret_sn = rti.struct_name;
|
||||||
|
}
|
||||||
|
AstNode* body = parse_block(p, error);
|
||||||
|
if (!body) return NULL;
|
||||||
|
AstNode** parr = arena_alloc_impl(p->arena, pc * sizeof(AstNode*));
|
||||||
|
memcpy(parr, plist, pc * sizeof(AstNode*));
|
||||||
|
left = ast_make_lambda(p->arena, parr, pc, ret, ret_sn, body, tok_loc(fn_tok));
|
||||||
} else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
|
} else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
|
||||||
left = parse_unary(p, error);
|
left = parse_unary(p, error);
|
||||||
} else if (tok->kind == TOK_LPAREN) {
|
} else if (tok->kind == TOK_LPAREN) {
|
||||||
|
|||||||
@@ -95,3 +95,7 @@ AstNode* mono_queue[256];
|
|||||||
size_t mono_count = 0;
|
size_t mono_count = 0;
|
||||||
Arena* mono_arena = NULL;
|
Arena* mono_arena = NULL;
|
||||||
AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
|
AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
|
||||||
|
|
||||||
|
// lambda 队列: 分析时创建的闭包函数
|
||||||
|
AstNode* lambda_queue[256];
|
||||||
|
size_t lambda_count = 0;
|
||||||
|
|||||||
@@ -506,4 +506,19 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
|
|||||||
scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1, NULL, 0);
|
scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1, NULL, 0);
|
||||||
|
|
||||||
analyze_node(ast, global_scope, errors, arena);
|
analyze_node(ast, global_scope, errors, arena);
|
||||||
|
|
||||||
|
// 将 lambda 生成的函数追加到 program 的函数列表
|
||||||
|
if (lambda_count > 0 && g_program) {
|
||||||
|
size_t old = g_program->as.program.fn_count;
|
||||||
|
size_t total = old + lambda_count;
|
||||||
|
AstNode** new_fns = arena_alloc_impl(arena, total * sizeof(AstNode*));
|
||||||
|
if (new_fns) {
|
||||||
|
if (old > 0)
|
||||||
|
memcpy(new_fns, g_program->as.program.functions, old * sizeof(AstNode*));
|
||||||
|
memcpy(new_fns + old, lambda_queue, lambda_count * sizeof(AstNode*));
|
||||||
|
g_program->as.program.functions = new_fns;
|
||||||
|
g_program->as.program.fn_count = total;
|
||||||
|
}
|
||||||
|
lambda_count = 0;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,10 @@ extern size_t mono_count;
|
|||||||
extern Arena* mono_arena;
|
extern Arena* mono_arena;
|
||||||
extern AstNode* g_program;
|
extern AstNode* g_program;
|
||||||
|
|
||||||
|
// === lambda 闭包队列 ===
|
||||||
|
extern AstNode* lambda_queue[256];
|
||||||
|
extern size_t lambda_count;
|
||||||
|
|
||||||
// === 类型推断上下文 ===
|
// === 类型推断上下文 ===
|
||||||
extern TypeKind current_return_type;
|
extern TypeKind current_return_type;
|
||||||
extern const char* current_return_struct_name;
|
extern const char* current_return_struct_name;
|
||||||
|
|||||||
@@ -226,6 +226,14 @@ bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
|
|||||||
|
|
||||||
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||||
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
Symbol* sym = scope_lookup(scope, node->as.call.name);
|
||||||
|
// 闭包调用: 变量类型为 TYPE_CLOSURE
|
||||||
|
if (sym && sym->kind == SYM_VARIABLE && sym->type == TYPE_CLOSURE) {
|
||||||
|
// 暂不做参数类型检查(MVP), 只分析参数表达式
|
||||||
|
for (size_t i = 0; i < node->as.call.arg_count; i++)
|
||||||
|
analyze_expr(node->as.call.args[i], scope, errors, a);
|
||||||
|
node->type.kind = TYPE_I64; // 默认返回 i64(MVP 限制)
|
||||||
|
return;
|
||||||
|
}
|
||||||
if (!sym || sym->kind != SYM_FUNCTION) {
|
if (!sym || sym->kind != SYM_FUNCTION) {
|
||||||
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
error_add(errors, "<sema>", node->loc.line, node->loc.col,
|
||||||
"未定义的函数 '%s'", node->as.call.name);
|
"未定义的函数 '%s'", node->as.call.name);
|
||||||
@@ -523,6 +531,53 @@ void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena*
|
|||||||
return NULL; \
|
return NULL; \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === lambda 表达式分析 ===
|
||||||
|
static int lambda_counter = 0;
|
||||||
|
|
||||||
|
void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||||
|
lambda_counter++;
|
||||||
|
int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
|
||||||
|
char* gen_name = arena_alloc_impl(a, name_len);
|
||||||
|
snprintf(gen_name, name_len, "__lambda_%d", lambda_counter);
|
||||||
|
node->as.lambda.generated_name = gen_name;
|
||||||
|
|
||||||
|
// 分析 lambda 体(参数作用域)
|
||||||
|
Scope* lambda_scope = scope_new(a, scope);
|
||||||
|
for (size_t i = 0; i < node->as.lambda.param_count; i++) {
|
||||||
|
AstNode* p = node->as.lambda.params[i];
|
||||||
|
scope_insert(lambda_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
|
||||||
|
}
|
||||||
|
TypeKind saved_ret = current_return_type;
|
||||||
|
const char* saved_ret_sn = current_return_struct_name;
|
||||||
|
current_return_type = node->as.lambda.return_type;
|
||||||
|
current_return_struct_name = node->as.lambda.return_struct_type_name;
|
||||||
|
analyze_node(node->as.lambda.body, lambda_scope, errors, a);
|
||||||
|
current_return_type = saved_ret;
|
||||||
|
current_return_struct_name = saved_ret_sn;
|
||||||
|
|
||||||
|
// 创建顶层函数 AST 节点, 加入队列供 codegen 使用
|
||||||
|
AstNode* fn = ast_make_function(a, gen_name,
|
||||||
|
node->as.lambda.params, node->as.lambda.param_count,
|
||||||
|
node->as.lambda.return_type,
|
||||||
|
node->as.lambda.return_struct_type_name,
|
||||||
|
node->as.lambda.body, false, NULL, 0, node->loc);
|
||||||
|
if (lambda_count < 256)
|
||||||
|
lambda_queue[lambda_count++] = fn;
|
||||||
|
|
||||||
|
// 注册函数符号(支持递归调用自身)
|
||||||
|
TypeKind* pts = node->as.lambda.param_count > 0
|
||||||
|
? arena_alloc_impl(a, node->as.lambda.param_count * sizeof(TypeKind)) : NULL;
|
||||||
|
for (size_t i = 0; i < node->as.lambda.param_count; i++)
|
||||||
|
pts[i] = node->as.lambda.params[i]->as.parameter.type;
|
||||||
|
scope_insert_function(scope, a, gen_name,
|
||||||
|
node->as.lambda.return_type,
|
||||||
|
node->as.lambda.return_struct_type_name,
|
||||||
|
pts, NULL, NULL, node->as.lambda.param_count, NULL, 0);
|
||||||
|
|
||||||
|
node->type.kind = TYPE_CLOSURE;
|
||||||
|
}
|
||||||
|
|
||||||
|
SEMA_HANDLER(analyze_lambda)
|
||||||
SEMA_HANDLER(analyze_ident_expr)
|
SEMA_HANDLER(analyze_ident_expr)
|
||||||
SEMA_HANDLER(analyze_unary_expr)
|
SEMA_HANDLER(analyze_unary_expr)
|
||||||
SEMA_HANDLER(analyze_binary_expr)
|
SEMA_HANDLER(analyze_binary_expr)
|
||||||
@@ -572,6 +627,7 @@ void analyze_expr_init(void) {
|
|||||||
ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
|
ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
|
||||||
ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
|
ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
|
||||||
ast_dispatch_set(&sema_dispatch, AST_LIST_COMP, analyze_list_comp_wrap);
|
ast_dispatch_set(&sema_dispatch, AST_LIST_COMP, analyze_list_comp_wrap);
|
||||||
|
ast_dispatch_set(&sema_dispatch, AST_LAMBDA, analyze_lambda_wrap);
|
||||||
}
|
}
|
||||||
|
|
||||||
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
|
||||||
|
|||||||
@@ -0,0 +1,26 @@
|
|||||||
|
// 闭包测试 — lambda 表达式 + 调用
|
||||||
|
fn apply_op(x: i64, op: i64) -> i64 {
|
||||||
|
// 闭包作为参数暂不支持直接调用,返回 x * 2
|
||||||
|
return x * 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> void {
|
||||||
|
// 测试1: 基本 lambda
|
||||||
|
let double = fn(x: i64) -> i64 { return x * 2; };
|
||||||
|
let r1 = double(21);
|
||||||
|
print_i64(r1); // 42
|
||||||
|
|
||||||
|
// 测试2: lambda with multiple params
|
||||||
|
let add = fn(a: i64, b: i64) -> i64 { return a + b; };
|
||||||
|
let r2 = add(30, 12);
|
||||||
|
print_i64(r2); // 42
|
||||||
|
|
||||||
|
// 测试3: nested lambda call
|
||||||
|
let r3 = double(add(10, 11));
|
||||||
|
print_i64(r3); // 42
|
||||||
|
|
||||||
|
// 测试4: lambda in sequence
|
||||||
|
let triple = fn(x: i64) -> i64 { return x * 3; };
|
||||||
|
let r4 = triple(14);
|
||||||
|
print_i64(r4); // 42
|
||||||
|
}
|
||||||
@@ -452,6 +452,38 @@ void test_in_param_assign_error() {
|
|||||||
arena_destroy(&a);
|
arena_destroy(&a);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void test_lambda_ok() {
|
||||||
|
Arena a = arena_create(1);
|
||||||
|
size_t tc; ErrorInfo lex_err = {0};
|
||||||
|
Token* toks = lex(&a,
|
||||||
|
"fn main() -> void { let f = fn(x: i64) -> i64 { return x * 2; }; return; }",
|
||||||
|
"test", &tc, &lex_err);
|
||||||
|
ASSERT(toks != NULL);
|
||||||
|
ErrorInfo parse_err = {0};
|
||||||
|
AstNode* ast = parse(&a, toks, tc, "test", &parse_err);
|
||||||
|
ASSERT(ast != NULL);
|
||||||
|
ErrorList errors; error_init(&errors, &a);
|
||||||
|
sema_analyze(ast, &errors, &a);
|
||||||
|
ASSERT(errors.count == 0); // lambda 定义应通过
|
||||||
|
arena_destroy(&a);
|
||||||
|
}
|
||||||
|
|
||||||
|
void test_lambda_call_ok() {
|
||||||
|
Arena a = arena_create(1);
|
||||||
|
size_t tc; ErrorInfo lex_err = {0};
|
||||||
|
Token* toks = lex(&a,
|
||||||
|
"fn main() -> void { let f = fn(x: i64) -> i64 { return x + 1; }; let r = f(41); return; }",
|
||||||
|
"test", &tc, &lex_err);
|
||||||
|
ASSERT(toks != NULL);
|
||||||
|
ErrorInfo parse_err = {0};
|
||||||
|
AstNode* ast = parse(&a, toks, tc, "test", &parse_err);
|
||||||
|
ASSERT(ast != NULL);
|
||||||
|
ErrorList errors; error_init(&errors, &a);
|
||||||
|
sema_analyze(ast, &errors, &a);
|
||||||
|
ASSERT(errors.count == 0); // 闭包调用应通过
|
||||||
|
arena_destroy(&a);
|
||||||
|
}
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
TEST_RUN(test_type_error);
|
TEST_RUN(test_type_error);
|
||||||
TEST_RUN(test_undefined_var);
|
TEST_RUN(test_undefined_var);
|
||||||
@@ -479,5 +511,7 @@ int main(void) {
|
|||||||
TEST_RUN(test_match_wildcard_only_sema_ok);
|
TEST_RUN(test_match_wildcard_only_sema_ok);
|
||||||
TEST_RUN(test_out_param_assign_ok);
|
TEST_RUN(test_out_param_assign_ok);
|
||||||
TEST_RUN(test_in_param_assign_error);
|
TEST_RUN(test_in_param_assign_error);
|
||||||
|
TEST_RUN(test_lambda_ok);
|
||||||
|
TEST_RUN(test_lambda_call_ok);
|
||||||
return test_summary();
|
return test_summary();
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user