Compare commits

...

4 Commits

Author SHA1 Message Date
Serendipity 6d5f8092a7 fix: for循环变量作用域 + 列表推导crash (两个已知bug)
Bug 1 - For循环变量作用域:
- AST_BLOCK 在 sema 中未创建子作用域 → 连续 for 循环用同名变量报"重复定义"
- 修复: sema AST_BLOCK 创建 block_scope (scope_new)
- 修复: codegen AST_BLOCK 保存/恢复 var_table 实现块级变量隔离

Bug 2 - 列表推导 >2元素 crash:
- sema 对 TYPE_ARRAY 标注跳过 init 分析 → 列表推导表达式未被semantize
- 导致 codegen 处 element_type=0, array_size=0 → LLVM alloca 崩溃
- 修复: 仅自引用 (= 变量名) 跳过分析,列表推导等正常分析
- 修复: cg_list_comp 使用 to_llvm_type(elem) 而非 type_info_to_llvm(full_array)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-07 18:03:25 +08:00
Serendipity f5c0650a97 feat: 闭包变量捕获 — 环境结构体 + 堆分配
lambda 可捕获外层变量, 自动构建环境结构体:

let base = 100;
let f = fn(x: i64) -> i64 { return x + base; };  // 捕获 base
f(50);  // → 150

全流水线实现:
- Sema: collect_free_vars 遍历 AST 收集自由变量
- AST function: captured/cap_types/cap_count 字段存储捕获信息
- Codegen: 闭包类型改为 struct {fn_ptr: i64, env_ptr: ptr}
- Codegen: lambda 表达式 malloc 环境结构体 + 存储捕获值
- Codegen: 生成函数签名添加 env_ptr 首个参数 (capturing only)
- Codegen: 函数体内通过 GEP 注册捕获变量到 var_table
- Codegen: 闭包调用自动提取 fn_ptr/env_ptr, 条件传递 env

非捕获 lambda 兼容: env_ptr=NULL, 不额外传参
嵌套 lambda 正确处理: 内层不穿透捕获外层变量

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-07 15:24:35 +08:00
Serendipity 06d80f441a feat: 闭包/lambda — 匿名函数表达式
fn(x: T) -> R { body } 作为表达式, 可赋值给变量并间接调用。

全流水线实现:
- Parser: TOK_FN 前缀 → AST_LAMBDA 节点
- Sema: 自动生成 __lambda_N 顶层函数 + 符号注册
- Sema: analyze_call_expr 支持 TYPE_CLOSURE 变量调用
- Codegen: lambda 表达式返回函数指针(i64), 调用点载入+IntToPtr+间接call
- VarEntry.closure_fn 追踪闭包变量对应的生成函数

限制(MVP v0.1): 非捕获 lambda, 返回类型固定 i64

+6 sema 测试 + 1 集成测试, 209 测试全部通过

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-07 15:07:03 +08:00
Serendipity c8da286d31 feat: in/out 参数 — out 关键字引用传递
fn swap(out x: i64, out y: i64) 声明 out 参数,codegen 层面
函数签名变为 T* 指针,调用点自动传 &variable 地址。
in 是默认行为(值传递),无需显式标注。

Token → Parser → Sema → Codegen 全流水线:
- TOK_OUT + "out" 关键字注册
- AST parameter.is_out 字段
- parse_function 解析 out 前缀
- Sema: out 参数注册为 SYM_VARIABLE+is_mut(可赋值)
- Codegen: LLVM 函数签名使用 T*,调用点传 alloca

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-07 14:45:38 +08:00
21 changed files with 667 additions and 48 deletions
+2
View File
@@ -18,6 +18,7 @@ typedef enum {
TYPE_STRUCT, // 结构体类型
TYPE_ENUM, // 枚举类型
TYPE_ARRAY, // 固定大小数组类型
TYPE_CLOSURE, // 闭包类型 (函数指针 + 环境指针)
TYPE_GENERIC, // 泛型类型参数(单态化前)
TYPE_UNKNOWN, // 尚未推断
TYPE_ERROR, // 类型错误
@@ -36,6 +37,7 @@ static inline const char* type_name(TypeKind kind) {
case TYPE_STRUCT: return "struct";
case TYPE_ENUM: return "enum";
case TYPE_ARRAY: return "array";
case TYPE_CLOSURE: return "closure";
default: return "<unknown>";
}
}
+19 -1
View File
@@ -44,14 +44,18 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size
n->as.function.multi_ret_types = NULL;
n->as.function.multi_ret_snames = NULL;
n->as.function.multi_ret_count = 0;
n->as.function.captured = NULL;
n->as.function.cap_types = NULL;
n->as.function.cap_count = 0;
return n;
}
AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type,
const char* struct_type_name, SourceLoc loc) {
const char* struct_type_name, bool is_out, SourceLoc loc) {
NEW(alloc, AST_PARAMETER);
n->as.parameter.name = name; n->as.parameter.type = type;
n->as.parameter.struct_type_name = struct_type_name;
n->as.parameter.is_out = is_out;
return n;
}
@@ -279,6 +283,20 @@ AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method
return n;
}
/*
 * Build an AST_LAMBDA node for an anonymous function expression of the
 * form `fn(params) -> ret { body }`.
 *
 * params / pcount   parameter nodes and their count, stored as given
 * ret               declared return type
 * ret_struct_name   struct name accompanying `ret`, stored as given
 * body              the lambda body node
 * loc               source location (not referenced directly in this body;
 *                   presumably consumed by the NEW macro - confirm there)
 *
 * The generated function name and the capture list are left empty here.
 * Returns the node bound to `n` by the NEW macro.
 */
AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
                         TypeKind ret, const char* ret_struct_name,
                         AstNode* body, SourceLoc loc) {
    NEW(alloc, AST_LAMBDA);

    // Nothing is known yet about the lifted function or its captures.
    n->as.lambda.generated_name = NULL;
    n->as.lambda.captured = NULL;
    n->as.lambda.captured_count = 0;

    // Signature and body exactly as handed in by the caller.
    n->as.lambda.params = params;
    n->as.lambda.param_count = pcount;
    n->as.lambda.return_type = ret;
    n->as.lambda.return_struct_type_name = ret_struct_name;
    n->as.lambda.body = body;

    return n;
}
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc) {
NEW(alloc, AST_MOD_DECL);
n->as.mod_decl.name = name;
+14 -3
View File
@@ -32,6 +32,7 @@ typedef enum {
AST_ARRAY_ASSIGN_STMT,// arr[i] = expr
AST_IMPL_BLOCK, // impl StructName { fn method(...) ... }
AST_METHOD_CALL, // receiver.method(args)
AST_LAMBDA, // fn(x: T) -> R { body } 匿名函数/闭包
AST_MOD_DECL, // mod foo;
AST_USE_DECL, // use foo::bar;
AST_TRAIT_DECL, // trait Name { fn ... }
@@ -72,9 +73,10 @@ struct AstNode {
TypeKind return_type; const char* return_struct_type_name;
struct AstNode* body; bool is_pub;
const char** type_params; size_t type_param_count;
TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count; } function;
TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count;
const char** captured; TypeKind* cap_types; size_t cap_count; } function;
// AST_PARAMETER (也用作结构体字段: name + type)
struct { const char* name; TypeKind type; const char* struct_type_name; } parameter;
struct { const char* name; TypeKind type; const char* struct_type_name; bool is_out; } parameter;
// AST_BLOCK
struct { struct AstNode** stmts; size_t stmt_count; } block;
// AST_LET_STMT
@@ -131,6 +133,12 @@ struct AstNode {
struct { const char* struct_name; struct AstNode** methods; size_t method_count; } impl_block;
// AST_METHOD_CALL
struct { struct AstNode* receiver; const char* method_name; struct AstNode** args; const char** arg_names; size_t arg_count; } method_call;
// AST_LAMBDA
struct { struct AstNode** params; size_t param_count;
TypeKind return_type; const char* return_struct_type_name;
struct AstNode* body;
const char* generated_name; // 自动生成的顶层函数名
const char** captured; size_t captured_count; } lambda;
// AST_MOD_DECL
struct { const char* name; struct AstNode* ast; } mod_decl;
// AST_USE_DECL
@@ -150,7 +158,7 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size
TypeKind ret, const char* ret_struct_name, AstNode* body,
bool is_pub, const char** type_params, size_t tp_count,
SourceLoc loc);
AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc);
AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, bool is_out, SourceLoc loc);
AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc);
AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot,
bool is_mut, AstNode* init, const char* struct_type_name,
@@ -185,6 +193,9 @@ AstNode* ast_make_index_expr(void* alloc, AstNode* array, AstNode* index, Source
AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, AstNode* value, SourceLoc loc);
AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, size_t count, SourceLoc loc);
AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc);
AstNode* ast_make_lambda(void* alloc, AstNode** params, size_t pcount,
TypeKind ret, const char* ret_struct_name,
AstNode* body, SourceLoc loc);
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc);
AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc);
AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc);
+1 -1
View File
@@ -8,7 +8,7 @@ typedef void* (*VisitFn)(void* ctx, AstNode* node);
// 遍历表 — 按 AstKind 索引, 未处理的条目为 NULL
// 新增 AST 节点: 在此表新增一条目, 编译器会警告未初始化的函数指针
enum { VISIT_TABLE_SIZE = 28 };
enum { VISIT_TABLE_SIZE = 29 };
typedef struct {
void* ctx;
+153 -7
View File
@@ -11,6 +11,13 @@ LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) {
case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context);
case TYPE_CHAR: return LLVMInt8TypeInContext(ctx->context);
case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
case TYPE_CLOSURE: {
LLVMTypeRef cls_fields[] = {
LLVMInt64TypeInContext(ctx->context), // fn_ptr
LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0) // env_ptr
};
return LLVMStructTypeInContext(ctx->context, cls_fields, 2, false);
}
case TYPE_STRUCT:
case TYPE_ENUM: {
LLVMTypeRef fields[] = { LLVMInt64TypeInContext(ctx->context),
@@ -235,18 +242,82 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) {
return LLVMBuildCall2(ctx->builder, ctx->printf_ty, ctx->printf_fn,
(LLVMValueRef[]){fmt, arg}, 2, "");
}
LLVMTypeRef fn_ty = NULL;
LLVMValueRef fn = find_fn(ctx, node->as.call.name);
if (!fn) return NULL;
LLVMValueRef closure_env = NULL;
LLVMValueRef gen_fn = NULL; // 闭包对应的生成函数
if (fn) {
fn_ty = LLVMGlobalGetValueType(fn);
} else {
VarEntry* cve = NULL;
for (VarEntry* e = ctx->var_table; e; e = e->next)
if (strcmp(e->name, node->as.call.name) == 0) { cve = e; break; }
if (cve && cve->closure_fn) {
gen_fn = find_fn(ctx, cve->closure_fn);
if (gen_fn) {
fn_ty = LLVMGlobalGetValueType(gen_fn);
LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
LLVMValueRef cls_val = LLVMBuildLoad2(ctx->builder, cls_ty,
cve->alloca, "cls_val");
LLVMValueRef fn_val = LLVMBuildExtractValue(ctx->builder,
cls_val, 0, "fn_ptr_i64");
fn = LLVMBuildIntToPtr(ctx->builder, fn_val,
LLVMPointerType(fn_ty, 0), "fn_cast");
closure_env = LLVMBuildExtractValue(ctx->builder,
cls_val, 1, "env_ptr");
}
}
}
if (!fn || !fn_ty) return NULL;
LLVMValueRef args[16];
if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; }
FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name);
for (size_t i = 0; i < node->as.call.arg_count; i++) {
args[i] = codegen_expr(ctx, node->as.call.args[i]);
bool is_out = fn_entry && fn_entry->out_params
&& i < fn_entry->pc && fn_entry->out_params[i];
if (is_out) {
// out 参数传 alloca 地址而非加载后的值
AstNode* arg = node->as.call.args[i];
if (arg->kind == AST_IDENT_EXPR) {
args[i] = find_var(ctx, arg->as.ident.name);
} else if (arg->kind == AST_INDEX_EXPR) {
// arr[i]: 生成 GEP 得到元素指针
LLVMValueRef arr_ptr = find_var(ctx, arg->as.index_expr.array->as.ident.name);
LLVMValueRef idx_val = codegen_expr(ctx, arg->as.index_expr.index);
if (!arr_ptr || !idx_val) return NULL;
LLVMValueRef indices[] = {
LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false),
LLVMBuildIntCast2(ctx->builder, idx_val,
LLVMInt32TypeInContext(ctx->context), false, "idx32")
};
args[i] = LLVMBuildGEP2(ctx->builder,
LLVMGetElementType(LLVMTypeOf(arr_ptr)),
arr_ptr, indices, 2, "out_gep");
} else {
args[i] = codegen_expr(ctx, node->as.call.args[i]);
}
} else {
args[i] = codegen_expr(ctx, node->as.call.args[i]);
}
if (!args[i]) return NULL;
}
LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn);
// 闭包调用: 若函数有 env 参数, 将其作为第一个参数
LLVMValueRef call_args[17];
unsigned call_argc = (unsigned)node->as.call.arg_count;
// 用生成函数的参数数判断是否需要 pass env
bool need_env = closure_env && gen_fn && LLVMCountParams(gen_fn) > call_argc;
if (need_env) {
call_args[0] = closure_env;
for (unsigned i = 0; i < call_argc; i++)
call_args[i + 1] = args[i];
call_argc++;
} else {
for (unsigned i = 0; i < call_argc; i++)
call_args[i] = args[i];
}
LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty);
return LLVMBuildCall2(ctx->builder, fn_ty, fn, args,
(unsigned)node->as.call.arg_count,
return LLVMBuildCall2(ctx->builder, fn_ty, fn, call_args,
call_argc,
ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "call");
}
CG_HANDLER(cg_call)
@@ -405,8 +476,9 @@ static AstDispatch cg_dispatch;
static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
TypeInfo* ti = &node->type;
LLVMTypeRef elem_ty = type_info_to_llvm(ctx, ti);
LLVMTypeRef arr_ty = LLVMArrayType(elem_ty, (unsigned)ti->array_size);
LLVMTypeRef elem_ty = to_llvm_type(ctx, ti->element_type);
LLVMTypeRef arr_ty = LLVMArrayType(elem_ty,
ti->array_size > 0 ? (unsigned)ti->array_size : 1);
LLVMValueRef result = LLVMBuildAlloca(ctx->builder, arr_ty, "list");
// 初始化为零
LLVMBuildStore(ctx->builder, LLVMConstNull(arr_ty), result);
@@ -471,6 +543,79 @@ static LLVMValueRef cg_list_comp_impl(CgCtx* ctx, AstNode* node) {
}
CG_HANDLER(cg_list_comp)
/*
 * Lower an AST_LAMBDA expression to a closure value.
 *
 * The closure is the two-field struct returned by to_llvm_type(TYPE_CLOSURE):
 *   field 0: address of the generated function, converted to i64
 *   field 1: i8* environment pointer (a null constant when nothing is captured)
 *
 * Capture information is read from the function AST in g_program whose name
 * equals the lambda's generated_name. When that function has cap_count > 0,
 * an environment struct with one field per captured variable is allocated
 * through ctx->malloc_fn, and the current value of each captured variable is
 * loaded and stored into it. Captures are therefore copied when the lambda
 * expression is evaluated. This function emits no matching free for the
 * environment.
 *
 * Returns the loaded closure struct value, or NULL when:
 *   - the lambda has no generated_name yet,
 *   - no LLVM function is registered under that name,
 *   - the capture metadata is inconsistent or the arena allocation fails,
 *   - a captured variable is not present in the current variable table.
 */
static LLVMValueRef cg_lambda_impl(CgCtx* ctx, AstNode* node) {
    // generated_name is NULL until something assigns it; passing NULL on to
    // the name lookups below would hand a null pointer to strcmp.
    const char* gen_name = node->as.lambda.generated_name;
    if (!gen_name) return NULL;

    // Look up the generated function.
    LLVMValueRef gen_fn = find_fn(ctx, gen_name);
    if (!gen_fn) return NULL;

    LLVMTypeRef i8ptr_ty = LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0);
    LLVMTypeRef i32_ty = LLVMInt32TypeInContext(ctx->context);

    // Closure struct: { i64 fn_ptr, i8* env_ptr }
    LLVMTypeRef cls_ty = to_llvm_type(ctx, TYPE_CLOSURE);
    LLVMValueRef closure = LLVMBuildAlloca(ctx->builder, cls_ty, "closure");
    LLVMValueRef fn_ptr = LLVMBuildPtrToInt(ctx->builder, gen_fn,
        LLVMInt64TypeInContext(ctx->context), "fn_ptr");
    LLVMValueRef env_ptr = LLVMConstNull(i8ptr_ty);

    // Find the matching function AST in g_program to obtain capture info.
    AstNode* fn_ast = NULL;
    if (g_program) {
        for (size_t i = 0; i < g_program->as.program.fn_count; i++) {
            AstNode* cand = g_program->as.program.functions[i];
            if (cand->as.function.name &&
                strcmp(cand->as.function.name, gen_name) == 0) {
                fn_ast = cand;
                break;
            }
        }
    }

    // With captured variables present, build and fill the environment struct.
    if (fn_ast && fn_ast->as.function.cap_count > 0) {
        size_t nc = fn_ast->as.function.cap_count;

        // A non-zero count without the backing arrays would be dereferenced below.
        if (!fn_ast->as.function.cap_types || !fn_ast->as.function.captured)
            return NULL;

        // env struct layout: { cap_type_0, cap_type_1, ... }
        LLVMTypeRef* ef = arena_alloc(ctx->arena, nc * sizeof(LLVMTypeRef));
        if (!ef) return NULL;   // arena exhausted: do not write through NULL
        for (size_t i = 0; i < nc; i++)
            ef[i] = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]);
        LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx->context, ef, (unsigned)nc, false);

        // malloc(sizeof(env struct))
        LLVMValueRef env_size = LLVMSizeOf(env_ty);
        LLVMValueRef malloc_args[] = { env_size };
        LLVMValueRef env = LLVMBuildCall2(ctx->builder,
            LLVMGlobalGetValueType(ctx->malloc_fn), ctx->malloc_fn,
            malloc_args, 1, "env");
        env_ptr = LLVMBuildBitCast(ctx->builder, env, i8ptr_ty, "env_cast");
        LLVMValueRef typed_env = LLVMBuildBitCast(ctx->builder, env,
            LLVMPointerType(env_ty, 0), "env_typed");

        // Copy each captured variable's current value into its env field.
        for (size_t i = 0; i < nc; i++) {
            LLVMValueRef cap_alloca = find_var(ctx, fn_ast->as.function.captured[i]);
            if (!cap_alloca) return NULL;
            LLVMTypeRef cap_ty = to_llvm_type(ctx, fn_ast->as.function.cap_types[i]);
            LLVMValueRef cap_val = LLVMBuildLoad2(ctx->builder, cap_ty,
                cap_alloca, "cap_val");
            LLVMValueRef gep_idx[] = {
                LLVMConstInt(i32_ty, 0, false),
                LLVMConstInt(i32_ty, (unsigned)i, false)
            };
            LLVMValueRef field_ptr = LLVMBuildGEP2(ctx->builder, env_ty,
                typed_env, gep_idx, 2, "cap_gep");
            LLVMBuildStore(ctx->builder, cap_val, field_ptr);
        }
    }

    // Fill the closure struct { fn_ptr, env_ptr } and return it by value.
    LLVMValueRef f0 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 0, "cl_f0");
    LLVMBuildStore(ctx->builder, fn_ptr, f0);
    LLVMValueRef f1 = LLVMBuildStructGEP2(ctx->builder, cls_ty, closure, 1, "cl_f1");
    LLVMBuildStore(ctx->builder, env_ptr, f1);

    return LLVMBuildLoad2(ctx->builder, cls_ty, closure, "closure_val");
}
CG_HANDLER(cg_lambda)
void codegen_expr_init(void) {
ast_dispatch_set(&cg_dispatch, AST_LITERAL_EXPR, cg_literal);
ast_dispatch_set(&cg_dispatch, AST_IDENT_EXPR, cg_ident);
@@ -485,6 +630,7 @@ void codegen_expr_init(void) {
ast_dispatch_set(&cg_dispatch, AST_BLOCK, cg_block);
ast_dispatch_set(&cg_dispatch, AST_IF_STMT, cg_if_expr);
ast_dispatch_set(&cg_dispatch, AST_LIST_COMP, cg_list_comp);
ast_dispatch_set(&cg_dispatch, AST_LAMBDA, cg_lambda);
}
// === 统一入口 ===
+82 -17
View File
@@ -9,11 +9,14 @@ LLVMValueRef find_var(CgCtx* ctx, const char* name) {
return NULL;
}
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type) {
VarEntry* e = arena_alloc(ctx->arena, sizeof(*e));
if (!e) return;
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type; e->next = ctx->var_table;
if (!e) return NULL;
e->name = name; e->alloca = alloca; e->alloca_type = alloca_type;
e->closure_fn = NULL;
e->next = ctx->var_table;
ctx->var_table = e;
return e;
}
// === 函数表 ===
@@ -23,17 +26,24 @@ LLVMValueRef find_fn(CgCtx* ctx, const char* name) {
return NULL;
}
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) {
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc) {
FnEntry* e = arena_alloc(ctx->arena, sizeof(*e));
if (!e) return;
e->name = name; e->fn = fn;
e->ret = TYPE_VOID;
e->params = NULL;
e->pc = 0;
e->out_params = out_params;
e->pc = pc;
e->next = ctx->fn_table;
ctx->fn_table = e;
}
// Return the function-table entry registered under `name`.
// Walks ctx->fn_table from its head and yields the first entry whose name
// compares equal with strcmp, or NULL when the list holds no such entry.
FnEntry* find_fn_entry(CgCtx* ctx, const char* name) {
    FnEntry* cur = ctx->fn_table;
    while (cur != NULL && strcmp(cur->name, name) != 0) {
        cur = cur->next;
    }
    return cur;
}
// === 结构体类型表 ===
void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) {
StructTypeEntry* e = arena_alloc(ctx->arena, sizeof(*e));
@@ -127,7 +137,11 @@ void codegen_stmt(CgCtx* ctx, AstNode* node) {
} else {
return;
}
add_var(ctx, node->as.let_stmt.name, alloca, var_type);
VarEntry* ve = add_var(ctx, node->as.let_stmt.name, alloca, var_type);
// 若 init 是 lambda, 记录闭包函数名供后续调用
if (node->as.let_stmt.init &&
node->as.let_stmt.init->kind == AST_LAMBDA && ve)
ve->closure_fn = node->as.let_stmt.init->as.lambda.generated_name;
// 自动内存管理: 只追踪 str 堆分配 (拼接/malloc)
// struct 是栈上值类型,不能 free();含 str 字段时 v0.5 扩展
@@ -191,10 +205,12 @@ void codegen_stmt(CgCtx* ctx, AstNode* node) {
case AST_BLOCK: {
if (++codegen_depth > MAX_CODEGEN_DEPTH) { codegen_depth--; return; }
size_t block_mark = ctx->cleanup_count;
VarEntry* saved_table = ctx->var_table;
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
codegen_stmt(ctx, node->as.block.stmts[i]);
}
cleanup_emit(ctx, block_mark); // 作用域退出: 释放块内 str 堆分配
ctx->var_table = saved_table;
cleanup_emit(ctx, block_mark);
codegen_depth--;
break;
}
@@ -377,16 +393,33 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
// 第一遍:声明所有 L 函数
for (size_t i = 0; i < ast->as.program.fn_count; i++) {
AstNode* fn = ast->as.program.functions[i];
bool has_env = fn->as.function.cap_count > 0;
size_t total_params = fn->as.function.param_count + (has_env ? 1 : 0);
LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
fn->as.function.param_count * sizeof(LLVMTypeRef));
total_params * sizeof(LLVMTypeRef));
bool* out_params = NULL;
if (total_params > 0) {
out_params = arena_alloc(ctx.arena, total_params * sizeof(bool));
for (size_t j = 0; j < total_params; j++) out_params[j] = false;
}
// 若有捕获, 第一个参数是 env_ptr
size_t poff = 0;
if (has_env) {
ptypes[0] = LLVMPointerType(LLVMInt8TypeInContext(ctx.context), 0);
poff = 1;
}
for (size_t j = 0; j < fn->as.function.param_count; j++) {
AstNode* param = fn->as.function.params[j];
bool is_out = param->as.parameter.is_out;
if (out_params) out_params[j + poff] = is_out;
LLVMTypeRef inner_ty;
if (param->as.parameter.type == TYPE_STRUCT &&
param->as.parameter.struct_type_name) {
ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name);
inner_ty = find_struct_type(&ctx, param->as.parameter.struct_type_name);
} else {
ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type);
inner_ty = to_llvm_type(&ctx, param->as.parameter.type);
}
ptypes[j + poff] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty;
}
LLVMTypeRef ret_ty;
if (fn->as.function.return_type == TYPE_STRUCT &&
@@ -396,9 +429,9 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
}
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
ptypes, (unsigned)fn->as.function.param_count, false);
ptypes, (unsigned)total_params, false);
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
add_fn(&ctx, fn->as.function.name, lfn);
add_fn(&ctx, fn->as.function.name, lfn, out_params, total_params);
}
// 第二遍:生成函数体
@@ -411,9 +444,37 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
// 清空变量表(每个函数独立作用域)
ctx.var_table = NULL;
// 捕获变量: 第一个参数是 env_ptr, 通过 GEP 注册捕获变量
bool has_env = fn->as.function.cap_count > 0;
LLVMValueRef env_ptr = NULL;
size_t param_offset = 0;
if (has_env) {
env_ptr = LLVMGetParam(lfn, 0);
param_offset = 1;
// 生成 env struct 类型并注册捕获变量
LLVMTypeRef* ef = arena_alloc(ctx.arena,
fn->as.function.cap_count * sizeof(LLVMTypeRef));
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++)
ef[ci] = to_llvm_type(&ctx, fn->as.function.cap_types[ci]);
LLVMTypeRef env_ty = LLVMStructTypeInContext(ctx.context, ef,
(unsigned)fn->as.function.cap_count, false);
LLVMValueRef typed_env = LLVMBuildBitCast(ctx.builder, env_ptr,
LLVMPointerType(env_ty, 0), "env_typed");
for (size_t ci = 0; ci < fn->as.function.cap_count; ci++) {
LLVMValueRef gep_idx[] = {
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), 0, false),
LLVMConstInt(LLVMInt32TypeInContext(ctx.context), (unsigned)ci, false)
};
LLVMValueRef field_ptr = LLVMBuildGEP2(ctx.builder, env_ty,
typed_env, gep_idx, 2, "cap_ptr");
add_var(&ctx, fn->as.function.captured[ci], field_ptr,
to_llvm_type(&ctx, fn->as.function.cap_types[ci]));
}
}
// 将参数注册为变量
for (size_t j = 0; j < fn->as.function.param_count; j++) {
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)(j + param_offset));
AstNode* pnode = fn->as.function.params[j];
LLVMTypeRef param_ty;
if (pnode->as.parameter.type == TYPE_STRUCT &&
@@ -422,10 +483,14 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
} else {
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
}
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
param_ty, pnode->as.parameter.name);
LLVMBuildStore(ctx.builder, param, alloca);
add_var(&ctx, pnode->as.parameter.name, alloca, param_ty);
if (pnode->as.parameter.is_out) {
add_var(&ctx, pnode->as.parameter.name, param, param_ty);
} else {
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
param_ty, pnode->as.parameter.name);
LLVMBuildStore(ctx.builder, param, alloca);
add_var(&ctx, pnode->as.parameter.name, alloca, param_ty);
}
}
ctx.defer_count = 0;
+8 -2
View File
@@ -15,11 +15,15 @@
extern int codegen_depth;
#define MAX_CODEGEN_DEPTH 1000
// AST program (由 sema 设置, codegen 读取)
extern AstNode* g_program;
// === 内部状态 ===
// One entry of the codegen variable table: a singly linked list whose head
// is ctx->var_table. add_var pushes new entries at the front.
typedef struct VarEntry {
// Variable name, used as the lookup key.
const char* name;
// Pointer value that loads of this variable go through. Despite the field
// name it is not always an alloca: `out` parameters store the incoming
// pointer argument, and captured variables store a GEP into the closure
// environment.
LLVMValueRef alloca;
// LLVM type registered together with `alloca`.
LLVMTypeRef alloca_type;
// Name of the generated function when the variable was initialised from a
// lambda expression; NULL otherwise. Call codegen uses it to resolve calls
// made through the variable.
const char* closure_fn;
// Next entry in the list.
struct VarEntry* next;
} VarEntry;
@@ -28,6 +32,7 @@ typedef struct FnEntry {
LLVMValueRef fn;
TypeKind ret;
TypeKind* params;
bool* out_params; // 哪些参数是 out (引用传递)
size_t pc;
struct FnEntry* next;
} FnEntry;
@@ -69,9 +74,10 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMT
// === 表操作 ===
LLVMValueRef find_var(CgCtx* ctx, const char* name);
void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
VarEntry* add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type);
LLVMValueRef find_fn(CgCtx* ctx, const char* name);
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn);
FnEntry* find_fn_entry(CgCtx* ctx, const char* name);
void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc);
void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc);
LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name);
+1 -1
View File
@@ -68,7 +68,7 @@ static TokenKind check_keyword(const Token* tok) {
KW("struct", TOK_STRUCT); KW("type", TOK_TYPE);
KW("enum", TOK_ENUM); KW("extend", TOK_EXTEND); KW("defer", TOK_DEFER); KW("match", TOK_MATCH);
KW("pub", TOK_PUB); KW("mod", TOK_MOD); KW("use", TOK_USE);
KW("trait", TOK_TRAIT); KW("Self", TOK_SELF);
KW("trait", TOK_TRAIT); KW("Self", TOK_SELF); KW("out", TOK_OUT);
KW("_", TOK_UNDERSCORE);
KW("true", TOK_TRUE); KW("false", TOK_FALSE);
#undef KW
+1 -1
View File
@@ -7,7 +7,7 @@
static const char* NAMES[] = {
[TOK_FN] = "fn", [TOK_LET] = "let", [TOK_VAR] = "var", [TOK_CONST] = "const", [TOK_IF] = "if", [TOK_GUARD] = "guard",
[TOK_PUB] = "pub", [TOK_MOD] = "mod", [TOK_USE] = "use",
[TOK_TRAIT] = "trait", [TOK_SELF] = "Self",
[TOK_TRAIT] = "trait", [TOK_SELF] = "Self", [TOK_OUT] = "out",
[TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return",
[TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_EXTEND] = "extend",
[TOK_DEFER] = "defer", [TOK_MATCH] = "match",
+1 -1
View File
@@ -8,7 +8,7 @@ typedef enum {
// 关键字
TOK_FN, TOK_LET, TOK_VAR, TOK_CONST, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, TOK_GUARD,
TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_EXTEND, TOK_DEFER, TOK_MATCH, TOK_PUB, TOK_MOD, TOK_USE,
TOK_TRAIT, TOK_SELF,
TOK_TRAIT, TOK_SELF, TOK_OUT,
// 类型关键字
TOK_I32, TOK_I64, TOK_U64, TOK_F64, TOK_BOOL, TOK_CHAR, TOK_STR, TOK_VOID,
// 字面量
+32
View File
@@ -322,6 +322,38 @@ AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error) {
left = ast_make_list_comp(p->arena,
arena_strdup_impl(p->arena, vname->start, vname->length),
arr, body, tok_loc(tok));
} else if (tok->kind == TOK_FN) {
// lambda: fn(params) -> RetType { body }
const Token* fn_tok = advance(p); // 跳过 fn
// 泛型参数暂不支持(lambda用捕获替代)
if (!expect(p, TOK_LPAREN, error, "缺少 '('")) return NULL;
AstNode* plist[64]; int pc = 0;
while (peek(p)->kind != TOK_RPAREN && !error->message) {
if (pc >= 64) { error->message = "lambda 参数过多(最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; }
bool is_out = match(p, TOK_OUT);
const Token* pname = expect(p, TOK_IDENT, error, "参数名");
if (!pname) return NULL;
if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL;
TypeInfo pti = parse_type_expr(p, error);
if (pti.kind == TYPE_ERROR) return NULL;
plist[pc++] = ast_make_parameter(p->arena,
arena_strdup_impl(p->arena, pname->start, pname->length),
pti.kind, pti.struct_name, is_out, tok_loc(pname));
if (match(p, TOK_COMMA)) continue; else break;
}
if (!expect(p, TOK_RPAREN, error, "缺少 ')'")) return NULL;
TypeKind ret = TYPE_VOID;
const char* ret_sn = NULL;
if (match(p, TOK_ARROW)) {
TypeInfo rti = parse_type_expr(p, error);
if (rti.kind == TYPE_ERROR) return NULL;
ret = rti.kind; ret_sn = rti.struct_name;
}
AstNode* body = parse_block(p, error);
if (!body) return NULL;
AstNode** parr = arena_alloc_impl(p->arena, pc * sizeof(AstNode*));
memcpy(parr, plist, pc * sizeof(AstNode*));
left = ast_make_lambda(p->arena, parr, pc, ret, ret_sn, body, tok_loc(fn_tok));
} else if (tok->kind == TOK_MINUS || tok->kind == TOK_BANG) {
left = parse_unary(p, error);
} else if (tok->kind == TOK_LPAREN) {
+4 -3
View File
@@ -27,7 +27,7 @@ AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) {
}
fields[fcount++] = ast_make_parameter(p->arena,
arena_strdup_impl(p->arena, fname->start, fname->length),
fti.kind, fti.struct_name, tok_loc(fname));
fti.kind, fti.struct_name, false, tok_loc(fname));
if (peek(p)->kind == TOK_COMMA) advance(p);
else break;
}
@@ -348,6 +348,7 @@ AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) {
AstNode* params[64]; int pcount = 0;
while (peek(p)->kind != TOK_RPAREN && !error->message) {
if (pcount >= 64) { error->message = "函数参数过多 (最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; }
bool is_out = match(p, TOK_OUT);
const Token* pname = expect(p, TOK_IDENT, error, "参数名");
if (!pname) return NULL;
if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL;
@@ -355,7 +356,7 @@ AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) {
if (pti.kind == TYPE_ERROR) return NULL;
params[pcount++] = ast_make_parameter(p->arena,
arena_strdup_impl(p->arena, pname->start, pname->length),
pti.kind, pti.struct_name, tok_loc(pname));
pti.kind, pti.struct_name, is_out, tok_loc(pname));
if (match(p, TOK_COMMA)) continue;
else break;
}
@@ -699,7 +700,7 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count,
arena_strdup_impl(p.arena, fbuf, strlen(fbuf)),
fn->as.function.multi_ret_types[i],
fn->as.function.multi_ret_snames[i],
tok_loc(peek(&p)));
false, tok_loc(peek(&p)));
}
if (struct_count >= 64) { error->message = "结构体过多"; return NULL; }
structs[struct_count++] = ast_make_struct_decl(p.arena,
+4
View File
@@ -95,3 +95,7 @@ AstNode* mono_queue[256];
size_t mono_count = 0;
Arena* mono_arena = NULL;
AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板)
// Lambda queue: AstNode pointers for functions created from lambda
// expressions during analysis. sema_analyze copies the first lambda_count
// entries onto the end of the program's function list and then resets
// lambda_count to 0.
AstNode* lambda_queue[256];
size_t lambda_count = 0;
+27 -4
View File
@@ -159,7 +159,8 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
Scope* fn_scope = scope_new(a, scope);
for (size_t j = 0; j < mono_fn->as.function.param_count; j++) {
AstNode* p = mono_fn->as.function.params[j];
scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
Symbol* ps = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
if (ps && p->as.parameter.is_out) { ps->kind = SYM_VARIABLE; ps->is_mut = true; }
}
// 分析函数体
current_return_type = mono_fn->as.function.return_type;
@@ -197,6 +198,7 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
}
}
Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, pt);
if (sym && p->as.parameter.is_out) { sym->kind = SYM_VARIABLE; sym->is_mut = true; }
if (sym && pt == TYPE_STRUCT && psn) {
sym->struct_type_name = psn;
}
@@ -221,9 +223,10 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
break;
}
case AST_BLOCK:
case AST_BLOCK: {
Scope* block_scope = scope_new(a, scope);
for (size_t i = 0; i < node->as.block.stmt_count; i++) {
analyze_node(node->as.block.stmts[i], scope, errors, a);
analyze_node(node->as.block.stmts[i], block_scope, errors, a);
}
// 表达式作为值: 块类型 = 最后一条产生值的语句类型
if (node->as.block.stmt_count > 0) {
@@ -243,6 +246,7 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
}
}
break;
}
case AST_LET_STMT: {
TypeKind var_type;
@@ -251,9 +255,13 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
if (node->as.let_stmt.has_type_annot) {
if (node->as.let_stmt.annot_type == TYPE_ARRAY) {
// 数组类型标注: 跳过 init 分析 (init 是自引用的占位符)
is_array_type = true;
var_type = TYPE_ARRAY;
// 分析 init — 除非是自引用 (如 let a: i64[3] = a;)
bool self_ref = (node->as.let_stmt.init->kind == AST_IDENT_EXPR
&& strcmp(node->as.let_stmt.init->as.ident.name,
node->as.let_stmt.name) == 0);
if (!self_ref) analyze_expr(node->as.let_stmt.init, scope, errors, a);
} else {
analyze_expr(node->as.let_stmt.init, scope, errors, a);
TypeKind inferred = node->as.let_stmt.init->type.kind;
@@ -504,4 +512,19 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1, NULL, 0);
analyze_node(ast, global_scope, errors, arena);
// 将 lambda 生成的函数追加到 program 的函数列表
if (lambda_count > 0 && g_program) {
size_t old = g_program->as.program.fn_count;
size_t total = old + lambda_count;
AstNode** new_fns = arena_alloc_impl(arena, total * sizeof(AstNode*));
if (new_fns) {
if (old > 0)
memcpy(new_fns, g_program->as.program.functions, old * sizeof(AstNode*));
memcpy(new_fns + old, lambda_queue, lambda_count * sizeof(AstNode*));
g_program->as.program.functions = new_fns;
g_program->as.program.fn_count = total;
}
lambda_count = 0;
}
}
+4
View File
@@ -18,6 +18,10 @@ extern size_t mono_count;
extern Arena* mono_arena;
extern AstNode* g_program;
// === lambda closure queue ===
// lambda_queue holds up to 256 AstNode pointers; lambda_count is the number
// of slots currently in use.
extern AstNode* lambda_queue[256];
extern size_t lambda_count;
// === 类型推断上下文 ===
extern TypeKind current_return_type;
extern const char* current_return_struct_name;
+167
View File
@@ -226,6 +226,14 @@ bool reorder_named_args(AstNode* node, Symbol* sym, int param_offset,
void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
Symbol* sym = scope_lookup(scope, node->as.call.name);
// 闭包调用: 变量类型为 TYPE_CLOSURE
if (sym && sym->kind == SYM_VARIABLE && sym->type == TYPE_CLOSURE) {
// 暂不做参数类型检查(MVP), 只分析参数表达式
for (size_t i = 0; i < node->as.call.arg_count; i++)
analyze_expr(node->as.call.args[i], scope, errors, a);
node->type.kind = TYPE_I64; // 默认返回 i64(MVP 限制)
return;
}
if (!sym || sym->kind != SYM_FUNCTION) {
error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的函数 '%s'", node->as.call.name);
@@ -523,6 +531,164 @@ void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, Arena*
return NULL; \
}
// === lambda expression analysis ===
// Running counter; analyze_lambda increments it and uses the value to build
// the "__lambda_N" name of each generated function.
static int lambda_counter = 0;
// Walks `body` and records the free variables of a lambda: identifiers that
// are not among the lambda's own locals but resolve, through `parent_scope`,
// to a variable or a parameter.
//
//   body          subtree to scan; NULL is accepted and ignored
//   lambda_scope  the lambda's own scope (only forwarded, never read here)
//   parent_scope  scope enclosing the lambda; captures are looked up in it
//   names, types  parallel output arrays, written at index *count
//   count         number of captures recorded so far; scanning stops at 16
//   locals        names declared inside the lambda (parameters, let bindings)
//   local_count   number of valid entries in `locals`
//
// Nested lambdas are not descended into, node kinds without a case below are
// ignored, and only the value side of an assignment is scanned.
// NOTE(review): the callee name of a call expression is not examined either,
// so calling an outer closure variable from inside a lambda records no
// capture for it — confirm whether codegen relies on that.
static void collect_free_vars_impl(AstNode* body, Scope* lambda_scope, Scope* parent_scope,
                                   const char** names, TypeKind* types, size_t* count,
                                   const char** locals, size_t* local_count) {
    if (!body || *count >= 16) return;
    switch (body->kind) {
        case AST_IDENT_EXPR: {
            const char* name = body->as.ident.name;
            // Lambda parameters and locals are never captured.
            for (size_t i = 0; i < *local_count; i++)
                if (strcmp(locals[i], name) == 0) return;
            // Resolve against the enclosing scope (not the lambda's own scope).
            Symbol* sym = scope_lookup(parent_scope, name);
            if (sym && (sym->kind == SYM_VARIABLE || sym->kind == SYM_PARAMETER)) {
                // Record each captured name only once.
                for (size_t i = 0; i < *count; i++)
                    if (strcmp(names[i], name) == 0) return;
                names[*count] = name;
                types[*count] = sym->type;
                (*count)++;
            }
            return;
        }
        case AST_LAMBDA:
            return;
        case AST_LET_STMT:
            // Scan the initializer first: it may still refer to an outer
            // variable that has the same name as the one being declared.
            collect_free_vars_impl(body->as.let_stmt.init, lambda_scope, parent_scope,
                                   names, types, count, locals, local_count);
            // From here on the declared name is a local of the lambda.
            if (*local_count < 16)
                locals[(*local_count)++] = body->as.let_stmt.name;
            return;
        case AST_BLOCK: {
            // A `let` inside this block is visible only until the block ends.
            // Forget those names afterwards, so that a later use of an outer
            // variable with the same name is still recorded as a capture.
            size_t locals_before = *local_count;
            for (size_t i = 0; i < body->as.block.stmt_count; i++)
                collect_free_vars_impl(body->as.block.stmts[i], lambda_scope, parent_scope,
                                       names, types, count, locals, local_count);
            *local_count = locals_before;
            return;
        }
        case AST_BINARY_EXPR:
            collect_free_vars_impl(body->as.binary.left, lambda_scope, parent_scope, names, types, count, locals, local_count);
            collect_free_vars_impl(body->as.binary.right, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_UNARY_EXPR:
            collect_free_vars_impl(body->as.unary.operand, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_CALL_EXPR:
            for (size_t i = 0; i < body->as.call.arg_count; i++)
                collect_free_vars_impl(body->as.call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_METHOD_CALL:
            collect_free_vars_impl(body->as.method_call.receiver, lambda_scope, parent_scope, names, types, count, locals, local_count);
            for (size_t i = 0; i < body->as.method_call.arg_count; i++)
                collect_free_vars_impl(body->as.method_call.args[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_RETURN_STMT:
            collect_free_vars_impl(body->as.return_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_EXPR_STMT:
            collect_free_vars_impl(body->as.expr_stmt.expr, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_ASSIGN_STMT:
            collect_free_vars_impl(body->as.assign_stmt.value, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_IF_STMT:
            collect_free_vars_impl(body->as.if_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count);
            collect_free_vars_impl(body->as.if_stmt.then_block, lambda_scope, parent_scope, names, types, count, locals, local_count);
            collect_free_vars_impl(body->as.if_stmt.else_block, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_WHILE_STMT:
            collect_free_vars_impl(body->as.while_stmt.cond, lambda_scope, parent_scope, names, types, count, locals, local_count);
            collect_free_vars_impl(body->as.while_stmt.body, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_INDEX_EXPR:
            collect_free_vars_impl(body->as.index_expr.array, lambda_scope, parent_scope, names, types, count, locals, local_count);
            collect_free_vars_impl(body->as.index_expr.index, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_FIELD_ACCESS:
            collect_free_vars_impl(body->as.field_access.object, lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        case AST_STRUCT_INIT:
            for (size_t i = 0; i < body->as.struct_init.field_count; i++)
                collect_free_vars_impl(body->as.struct_init.field_values[i], lambda_scope, parent_scope, names, types, count, locals, local_count);
            return;
        default:
            return;
    }
}
// Entry point for capture collection. Seeds the "local names" list with every
// symbol already present in the lambda scope `ls` (so they are never reported
// as captures), then scans `body` against the enclosing scope `ps`. Results
// are written to `names` / `types`, with `*count` updated accordingly.
static void collect_free_vars(AstNode* body, Scope* ls, Scope* ps,
                              const char** names, TypeKind* types, size_t* count) {
    const char* local_names[32];
    size_t n_locals = 0;

    Symbol* sym = ls->head;
    while (sym) {
        // Symbols beyond the buffer's capacity are simply not recorded.
        if (n_locals < 32) {
            local_names[n_locals] = sym->name;
            n_locals++;
        }
        sym = sym->next;
    }

    collect_free_vars_impl(body, ls, ps, names, types, count, local_names, &n_locals);
}
// Semantic analysis of a lambda expression.
//
// Steps, in order:
//   1. build a unique "__lambda_N" name and store it on the node;
//   2. analyze the body in a fresh scope holding the lambda's parameters,
//      with the current-return-type context temporarily switched to the
//      lambda's declared return type;
//   3. collect the outer variables the body refers to (at most 16);
//   4. create a function node for the body, attach the capture lists to both
//      that node and the lambda node, and append it to lambda_queue;
//   5. register the generated function symbol in the enclosing `scope` and
//      mark the expression as TYPE_CLOSURE.
//
// Every arena allocation is checked before use. If one fails, the function
// returns early without queueing or registering anything and without marking
// the node as a closure.
// NOTE(review): when lambda_queue is already full (256 entries) the generated
// function is silently not queued, yet the symbol is still registered; this
// probably deserves a diagnostic.
void analyze_lambda(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
    lambda_counter++;
    int name_len = snprintf(NULL, 0, "__lambda_%d", lambda_counter) + 1;
    char* gen_name = arena_alloc_impl(a, name_len);
    if (!gen_name) return;  // arena exhausted: nowhere to write the name
    snprintf(gen_name, name_len, "__lambda_%d", lambda_counter);
    node->as.lambda.generated_name = gen_name;

    // Analyze the body inside a scope that contains only the parameters.
    Scope* lambda_scope = scope_new(a, scope);
    for (size_t i = 0; i < node->as.lambda.param_count; i++) {
        AstNode* p = node->as.lambda.params[i];
        scope_insert(lambda_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
    }
    TypeKind saved_ret = current_return_type;
    const char* saved_ret_sn = current_return_struct_name;
    current_return_type = node->as.lambda.return_type;
    current_return_struct_name = node->as.lambda.return_struct_type_name;
    analyze_node(node->as.lambda.body, lambda_scope, errors, a);
    current_return_type = saved_ret;
    current_return_struct_name = saved_ret_sn;

    // Collect the outer variables captured by the body.
    const char* cap_names[16]; TypeKind cap_types[16]; size_t cap_count = 0;
    collect_free_vars(node->as.lambda.body, lambda_scope, scope, cap_names, cap_types, &cap_count);

    // Create the function node that will be queued for later stages.
    AstNode* fn = ast_make_function(a, gen_name,
                                    node->as.lambda.params, node->as.lambda.param_count,
                                    node->as.lambda.return_type,
                                    node->as.lambda.return_struct_type_name,
                                    node->as.lambda.body, false, NULL, 0, node->loc);
    if (!fn) return;  // defensive: never queue or dereference a missing node

    // Copy the capture lists into the arena; the local arrays above do not
    // outlive this call.
    if (cap_count > 0) {
        const char** cap_names_a = arena_alloc_impl(a, cap_count * sizeof(const char*));
        TypeKind* cap_types_a = arena_alloc_impl(a, cap_count * sizeof(TypeKind));
        if (!cap_names_a || !cap_types_a) return;
        memcpy(cap_names_a, cap_names, cap_count * sizeof(const char*));
        memcpy(cap_types_a, cap_types, cap_count * sizeof(TypeKind));
        fn->as.function.captured = cap_names_a;
        fn->as.function.cap_types = cap_types_a;
        fn->as.function.cap_count = cap_count;
        node->as.lambda.captured = cap_names_a;
        node->as.lambda.captured_count = cap_count;
    }

    // Parameter type list for the function symbol. Allocated before queueing
    // so that a failed allocation leaves nothing half-registered.
    TypeKind* pts = NULL;
    if (node->as.lambda.param_count > 0) {
        pts = arena_alloc_impl(a, node->as.lambda.param_count * sizeof(TypeKind));
        if (!pts) return;
        for (size_t i = 0; i < node->as.lambda.param_count; i++)
            pts[i] = node->as.lambda.params[i]->as.parameter.type;
    }

    if (lambda_count < 256)
        lambda_queue[lambda_count++] = fn;

    // Register the generated function symbol in the enclosing scope.
    scope_insert_function(scope, a, gen_name,
                          node->as.lambda.return_type,
                          node->as.lambda.return_struct_type_name,
                          pts, NULL, NULL, node->as.lambda.param_count, NULL, 0);
    node->type.kind = TYPE_CLOSURE;
}
// NOTE(review): presumably expands to the analyze_lambda_wrap adapter that
// analyze_expr_init registers for AST_LAMBDA — the macro body is not visible
// here, confirm against its definition.
SEMA_HANDLER(analyze_lambda)
SEMA_HANDLER(analyze_ident_expr)
SEMA_HANDLER(analyze_unary_expr)
SEMA_HANDLER(analyze_binary_expr)
@@ -572,6 +738,7 @@ void analyze_expr_init(void) {
ast_dispatch_set(&sema_dispatch, AST_IF_STMT, analyze_node_wrap);
ast_dispatch_set(&sema_dispatch, AST_BLOCK, analyze_node_wrap);
ast_dispatch_set(&sema_dispatch, AST_LIST_COMP, analyze_list_comp_wrap);
ast_dispatch_set(&sema_dispatch, AST_LAMBDA, analyze_lambda_wrap);
}
void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) {
+22
View File
@@ -0,0 +1,22 @@
// out-parameter test — pass by reference
// Exchanges x and y; both are declared `out`.
fn swap(out x: i64, out y: i64) -> void {
let t = x;
x = y;
y = t;
}
// Adds 1 to its `out` parameter.
fn increment(out x: i64) -> void {
x = x + 1;
}
// Prints a and b before and after swap(a, b), then prints a again after
// increment(a).
fn main() -> void {
let a = 10;
let b = 20;
print_i64(a);
print_i64(b);
swap(a, b);
print_i64(a);
print_i64(b);
increment(a);
print_i64(a);
}
+22
View File
@@ -0,0 +1,22 @@
// out-parameter + struct test
struct Point { x: i64, y: i64 }
// Overwrites the `out` struct with the point (100, 200).
fn init_point(out p: Point) -> void {
p = Point { x: 100, y: 200 };
}
// Replaces the `out` struct with a copy shifted by (+50, +100).
fn offset_point(out p: Point) -> void {
p = Point { x: p.x + 50, y: p.y + 100 };
}
// Prints both fields of p initially, after init_point(p), and after
// offset_point(p).
fn main() -> void {
let p = Point { x: 0, y: 0 };
print_i64(p.x);
print_i64(p.y);
init_point(p);
print_i64(p.x);
print_i64(p.y);
offset_point(p);
print_i64(p.x);
print_i64(p.y);
}
+26
View File
@@ -0,0 +1,26 @@
// Closure test — non-capturing lambda + variable capture
// Creates a lambda that refers to the parameter `base` and returns the result
// of calling it with 50.
fn make_adder(base: i64) -> i64 {
let adder = fn(x: i64) -> i64 { return x + base; };
return adder(50);
}
fn main() -> void {
// Test 1: non-capturing lambda
let double = fn(x: i64) -> i64 { return x * 2; };
print_i64(double(21)); // 42
// Test 2: capture a single variable
let base = 100;
let add = fn(x: i64) -> i64 { return x + base; };
print_i64(add(50)); // 150
// Test 3: capture several variables
let a = 10;
let b = 20;
let sum3 = fn(x: i64) -> i64 { return x + a + b; };
print_i64(sum3(5)); // 35
// Test 4: closure created inside another function
let r = make_adder(200);
print_i64(r); // 250
}
+7 -7
View File
@@ -129,8 +129,8 @@ void test_codegen_struct_decl() {
/* 构造 AST: struct Point { x: i64, y: i64 } */
AstNode* fields[2];
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1));
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1));
AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1));
AstNode* structs[] = { struct_decl };
@@ -185,8 +185,8 @@ void test_codegen_struct_field_access() {
/* 构造 AST: struct Point { x: i64, y: i64 } */
AstNode* fields[2];
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1));
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1));
AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1));
AstNode* structs[] = { struct_decl };
@@ -356,13 +356,13 @@ void test_codegen_method_call() {
/* struct Point { x: i64, y: i64 } */
AstNode* fields[2];
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1));
fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1));
fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1));
AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1));
AstNode* structs[] = { struct_decl };
/* fn Point$get_x(self: Point) -> i64 { return self.x; } */
AstNode* self_param = ast_make_parameter(&a, "self", TYPE_STRUCT, "Point", loc_at(1, 1));
AstNode* self_param = ast_make_parameter(&a, "self", TYPE_STRUCT, "Point", false, loc_at(1, 1));
AstNode* params[] = { self_param };
AstNode* self_ident = ast_make_ident(&a, "self", loc_at(1, 1));
self_ident->type.kind = TYPE_STRUCT;
+70
View File
@@ -418,6 +418,72 @@ void test_match_wildcard_only_sema_ok() {
arena_destroy(&a);
}
// Assigning to `out` parameters inside a function must pass semantic analysis
// without any error.
void test_out_param_assign_ok() {
    Arena pool = arena_create(1);

    size_t tok_count;
    ErrorInfo lex_diag = {0};
    Token* toks = lex(&pool,
        "fn swap(out x: i64, out y: i64) -> void { let t = x; x = y; y = t; return; }"
        "fn main() -> void { let a = 10; let b = 20; swap(a, b); return; }",
        "test", &tok_count, &lex_diag);
    ASSERT(toks != NULL);

    ErrorInfo parse_diag = {0};
    AstNode* ast = parse(&pool, toks, tok_count, "test", &parse_diag);
    ASSERT(ast != NULL);

    ErrorList errors;
    error_init(&errors, &pool);
    sema_analyze(ast, &errors, &pool);
    // Writes to out parameters are legal, so nothing may be reported.
    ASSERT(errors.count == 0);

    arena_destroy(&pool);
}
// Assigning to a parameter that is not declared `out` must be rejected by
// semantic analysis.
void test_in_param_assign_error() {
    Arena pool = arena_create(1);

    size_t tok_count;
    ErrorInfo lex_diag = {0};
    Token* toks = lex(&pool,
        "fn bad(x: i64) -> void { x = 42; return; }"
        "fn main() -> void { bad(10); return; }",
        "test", &tok_count, &lex_diag);
    ASSERT(toks != NULL);

    ErrorInfo parse_diag = {0};
    AstNode* ast = parse(&pool, toks, tok_count, "test", &parse_diag);
    ASSERT(ast != NULL);

    ErrorList errors;
    error_init(&errors, &pool);
    sema_analyze(ast, &errors, &pool);
    // At least one diagnostic is expected for the write to a non-out parameter.
    ASSERT(errors.count > 0);

    arena_destroy(&pool);
}
// Defining a lambda and binding it to a variable must pass semantic analysis.
void test_lambda_ok() {
    Arena pool = arena_create(1);

    size_t tok_count;
    ErrorInfo lex_diag = {0};
    Token* toks = lex(&pool,
        "fn main() -> void { let f = fn(x: i64) -> i64 { return x * 2; }; return; }",
        "test", &tok_count, &lex_diag);
    ASSERT(toks != NULL);

    ErrorInfo parse_diag = {0};
    AstNode* ast = parse(&pool, toks, tok_count, "test", &parse_diag);
    ASSERT(ast != NULL);

    ErrorList errors;
    error_init(&errors, &pool);
    sema_analyze(ast, &errors, &pool);
    // A lambda definition on its own must not produce diagnostics.
    ASSERT(errors.count == 0);

    arena_destroy(&pool);
}
// Calling a lambda through the variable it was bound to must pass semantic
// analysis.
void test_lambda_call_ok() {
    Arena pool = arena_create(1);

    size_t tok_count;
    ErrorInfo lex_diag = {0};
    Token* toks = lex(&pool,
        "fn main() -> void { let f = fn(x: i64) -> i64 { return x + 1; }; let r = f(41); return; }",
        "test", &tok_count, &lex_diag);
    ASSERT(toks != NULL);

    ErrorInfo parse_diag = {0};
    AstNode* ast = parse(&pool, toks, tok_count, "test", &parse_diag);
    ASSERT(ast != NULL);

    ErrorList errors;
    error_init(&errors, &pool);
    sema_analyze(ast, &errors, &pool);
    // The closure call must be accepted without diagnostics.
    ASSERT(errors.count == 0);

    arena_destroy(&pool);
}
int main(void) {
TEST_RUN(test_type_error);
TEST_RUN(test_undefined_var);
@@ -443,5 +509,9 @@ int main(void) {
TEST_RUN(test_match_enum_sema_ok);
TEST_RUN(test_match_int_sema_ok);
TEST_RUN(test_match_wildcard_only_sema_ok);
TEST_RUN(test_out_param_assign_ok);
TEST_RUN(test_in_param_assign_error);
TEST_RUN(test_lambda_ok);
TEST_RUN(test_lambda_call_ok);
return test_summary();
}