From 9169796b771890a5ff0651f3dddb7ae32b5accc8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Sat, 6 Jun 2026 16:25:18 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B3=9B=E5=9E=8B=E5=8D=95=E6=80=81?= =?UTF-8?q?=E5=8C=96=E5=AE=8C=E6=88=90=20=E2=80=94=20fn=20id(x:=20T)=20?= =?UTF-8?q?->=20T=20=E5=85=A8=E6=B5=81=E6=B0=B4=E7=BA=BF=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/sema/sema.c | 189 +++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 179 insertions(+), 10 deletions(-) diff --git a/src/sema/sema.c b/src/sema/sema.c index 650eee2..0c4eb0a 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -2,10 +2,106 @@ #include #include +// === 泛型单态化: AST 类型替换 === +// 将 AST 中所有匹配 type_param_name 的类型引用替换为 concrete_type +static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname); + +static void subst_type_info(TypeInfo* ti, const char* tparam, TypeKind concrete, const char* concrete_sname) { + if (ti->kind == TYPE_STRUCT && ti->struct_name && strcmp(ti->struct_name, tparam) == 0) { + ti->kind = concrete; + ti->struct_name = concrete_sname; + } +} + +static void subst_ast_types(AstNode* node, const char* tparam, TypeKind concrete, const char* concrete_sname) { + if (!node) return; + // 替换节点自身类型 + subst_type_info(&node->type, tparam, concrete, concrete_sname); + switch (node->kind) { + case AST_PROGRAM: + for (size_t i = 0; i < node->as.program.fn_count; i++) + subst_ast_types(node->as.program.functions[i], tparam, concrete, concrete_sname); + break; + case AST_FUNCTION: + if (node->as.function.return_type == TYPE_STRUCT + && node->as.function.return_struct_type_name + && strcmp(node->as.function.return_struct_type_name, tparam) == 0) { + node->as.function.return_type = concrete; + node->as.function.return_struct_type_name = concrete_sname; + } + for (size_t i = 0; i < node->as.function.param_count; i++) + subst_ast_types(node->as.function.params[i], tparam, concrete, concrete_sname); + subst_ast_types(node->as.function.body, tparam, concrete, concrete_sname); + break; + case AST_PARAMETER: + if (node->as.parameter.type == TYPE_STRUCT && node->as.parameter.struct_type_name + && strcmp(node->as.parameter.struct_type_name, tparam) == 0) { + node->as.parameter.type = concrete; + node->as.parameter.struct_type_name = concrete_sname; + } + break; + case AST_BLOCK: + for (size_t i = 0; i < node->as.block.stmt_count; i++) + subst_ast_types(node->as.block.stmts[i], tparam, concrete, concrete_sname); + break; + case AST_LET_STMT: + if (node->as.let_stmt.annot_type == TYPE_STRUCT && node->as.let_stmt.struct_type_name + && strcmp(node->as.let_stmt.struct_type_name, tparam) == 0) { + node->as.let_stmt.annot_type = concrete; + node->as.let_stmt.struct_type_name = concrete_sname; + } + subst_ast_types(node->as.let_stmt.init, tparam, concrete, concrete_sname); + break; + case AST_IF_STMT: + subst_ast_types(node->as.if_stmt.cond, tparam, concrete, concrete_sname); + subst_ast_types(node->as.if_stmt.then_block, tparam, concrete, concrete_sname); + subst_ast_types(node->as.if_stmt.else_block, tparam, concrete, concrete_sname); + break; + case AST_WHILE_STMT: + subst_ast_types(node->as.while_stmt.cond, tparam, concrete, concrete_sname); + subst_ast_types(node->as.while_stmt.body, tparam, concrete, concrete_sname); + break; + case AST_RETURN_STMT: + subst_ast_types(node->as.return_stmt.expr, tparam, concrete, concrete_sname); + break; + case AST_EXPR_STMT: + subst_ast_types(node->as.expr_stmt.expr, tparam, concrete, concrete_sname); + break; + case AST_BINARY_EXPR: + subst_ast_types(node->as.binary.left, tparam, concrete, concrete_sname); + subst_ast_types(node->as.binary.right, tparam, concrete, concrete_sname); + break; + case AST_UNARY_EXPR: + subst_ast_types(node->as.unary.operand, tparam, concrete, concrete_sname); + break; + case AST_CALL_EXPR: + for (size_t i = 0; i < node->as.call.arg_count; i++) + subst_ast_types(node->as.call.args[i], tparam, concrete, concrete_sname); + break; + case AST_ASSIGN_STMT: + subst_ast_types(node->as.assign_stmt.value, tparam, concrete, concrete_sname); + break; + case AST_FIELD_ACCESS: + subst_ast_types(node->as.field_access.object, tparam, concrete, concrete_sname); + break; + case AST_STRUCT_INIT: + for (size_t i = 0; i < node->as.struct_init.field_count; i++) + subst_ast_types(node->as.struct_init.field_values[i], tparam, concrete, concrete_sname); + break; + default: break; + } +} + // === 类型关系 === static TypeKind current_return_type = TYPE_VOID; static const char* current_return_struct_name = NULL; +// 单态化队列: 泛型函数调用时生成的具象化函数 +static AstNode* mono_queue[256]; +static size_t mono_count = 0; +static Arena* mono_arena = NULL; +static AstNode* g_program = NULL; // 当前 AST_PROGRAM(用于查找泛型函数模板) + static TypeKind promote(TypeKind a, TypeKind b) { // 枚举在算术运算中视为 i64 if (a == TYPE_ENUM) a = TYPE_I64; @@ -252,16 +348,58 @@ static void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Ar bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i], sym->param_struct_names ? sym->param_struct_names[i] : NULL, i, node, sym, errors, a); - // 泛型: 若实参匹配类型参数,传播具体类型到返回值 - if (is_generic_param && sym->return_type == TYPE_STRUCT - && sym->return_struct_type_name && sym->type_params) { - for (size_t t = 0; t < sym->type_param_count; t++) { - if (strcmp(sym->return_struct_type_name, sym->type_params[t]) == 0) { - node->type.kind = node->as.call.args[i]->type.kind; - node->type.struct_name = node->as.call.args[i]->type.struct_name; - return; + // 泛型单态化: 创建具象化函数副本并注册 + if (is_generic_param && sym->type_params && sym->type_param_count > 0) { + TypeKind concrete = node->as.call.args[i]->type.kind; + const char* concrete_sn = node->as.call.args[i]->type.struct_name; + // 构造 mangled 名: fn$concrete_type + const char* ct_name = concrete_sn ? concrete_sn : type_name(concrete); + int mname_len = snprintf(NULL, 0, "%s$%s", node->as.call.name, ct_name) + 1; + char* mname = arena_alloc_impl(a, mname_len); + snprintf(mname, mname_len, "%s$%s", node->as.call.name, ct_name); + // 检查是否已存在 + Symbol* existing = scope_lookup(scope, mname); + if (!existing && g_program) { + // 查找原始泛型函数 AST 节点 + AstNode* generic_fn = NULL; + for (size_t fn_i = 0; fn_i < g_program->as.program.fn_count; fn_i++) { + if (strcmp(g_program->as.program.functions[fn_i]->as.function.name, + node->as.call.name) == 0) { + generic_fn = g_program->as.program.functions[fn_i]; + break; + } + } + if (generic_fn && mono_count < 256) { + // 创建浅拷贝(共享 body,subst_ast_types 修改类型标注) + AstNode* mono_fn = ast_make_function(a, mname, + generic_fn->as.function.params, + generic_fn->as.function.param_count, + generic_fn->as.function.return_type, + generic_fn->as.function.return_struct_type_name, + generic_fn->as.function.body, + false, NULL, 0, + generic_fn->loc); + // 类型替换: T → concrete + subst_ast_types(mono_fn, sym->type_params[0], concrete, concrete_sn); + // 注册到队列 + mono_queue[mono_count++] = mono_fn; + // 注册符号(后续分析会处理函数体) + TypeKind* mpts = mono_fn->as.function.param_count > 0 + ? arena_alloc_impl(a, mono_fn->as.function.param_count * sizeof(TypeKind)) : NULL; + for (size_t pj = 0; pj < mono_fn->as.function.param_count; pj++) { + mpts[pj] = mono_fn->as.function.params[pj]->as.parameter.type; + } + scope_insert_function(scope, a, mname, + mono_fn->as.function.return_type, + mono_fn->as.function.return_struct_type_name, + mpts, NULL, NULL, + mono_fn->as.function.param_count, NULL, 0); } } + // 重定向调用到单态化函数 + node->as.call.name = mname; + sym = scope_lookup(scope, mname); + if (!sym) { node->type.kind = TYPE_ERROR; return; } } } node->type.kind = sym->return_type; @@ -595,9 +733,37 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* fn->as.function.type_params, fn->as.function.type_param_count); } - // 第三遍:分析每个函数体 + // 第三遍:分析每个函数体(跳过泛型模板,它们由单态化副本替代) for (size_t i = 0; i < node->as.program.fn_count; i++) { - analyze_node(node->as.program.functions[i], scope, errors, a); + AstNode* fn = node->as.program.functions[i]; + if (fn->as.function.type_param_count > 0) continue; + analyze_node(fn, scope, errors, a); + } + // 处理单态化队列: 分析新生成的具象化函数体 + for (size_t mi = 0; mi < mono_count; mi++) { + // 先将单态化函数注册到 scope(函数签名已在 analyze_call_expr 中注册) + AstNode* mono_fn = mono_queue[mi]; + // 注册参数到新作用域 + Scope* fn_scope = scope_new(a, scope); + for (size_t j = 0; j < mono_fn->as.function.param_count; j++) { + AstNode* p = mono_fn->as.function.params[j]; + scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); + } + // 分析函数体 + current_return_type = mono_fn->as.function.return_type; + current_return_struct_name = mono_fn->as.function.return_struct_type_name; + if (mono_fn->as.function.body) + analyze_node(mono_fn->as.function.body, fn_scope, errors, a); + // 用单态化函数替换 program 中的泛型模板 + for (size_t fi = 0; fi < node->as.program.fn_count; fi++) { + AstNode* fn = node->as.program.functions[fi]; + if (fn->as.function.type_param_count > 0 && + strncmp(mono_fn->as.function.name, fn->as.function.name, + strlen(fn->as.function.name)) == 0) { + node->as.program.functions[fi] = mono_fn; + break; + } + } } break; @@ -902,6 +1068,9 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) { Scope* global_scope = scope_new(arena, NULL); + g_program = ast; + mono_count = 0; + mono_arena = arena; // 注册内置函数 TypeKind params_i64[] = {TYPE_I64};