diff --git a/include/l_lang.h b/include/l_lang.h index 0e5b6ea..282f938 100644 --- a/include/l_lang.h +++ b/include/l_lang.h @@ -18,6 +18,7 @@ typedef enum { TYPE_STRUCT, // 结构体类型 TYPE_ENUM, // 枚举类型 TYPE_ARRAY, // 固定大小数组类型 + TYPE_GENERIC, // 泛型类型参数(单态化前) TYPE_UNKNOWN, // 尚未推断 TYPE_ERROR, // 类型错误 } TypeKind; diff --git a/src/ast/ast.c b/src/ast/ast.c index ab642eb..c918e2d 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -31,13 +31,16 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, TypeKind ret, const char* ret_struct_name, AstNode* body, - bool is_pub, SourceLoc loc) { + bool is_pub, const char** type_params, size_t tp_count, + SourceLoc loc) { NEW(alloc, AST_FUNCTION); n->as.function.name = name; n->as.function.params = params; n->as.function.param_count = pcount; n->as.function.return_type = ret; n->as.function.return_struct_type_name = ret_struct_name; n->as.function.body = body; n->as.function.is_pub = is_pub; + n->as.function.type_params = type_params; + n->as.function.type_param_count = tp_count; return n; } diff --git a/src/ast/ast.h b/src/ast/ast.h index cca2cd3..39d1b48 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -67,7 +67,8 @@ struct AstNode { // AST_FUNCTION struct { const char* name; struct AstNode** params; size_t param_count; TypeKind return_type; const char* return_struct_type_name; - struct AstNode* body; bool is_pub; } function; + struct AstNode* body; bool is_pub; + const char** type_params; size_t type_param_count; } function; // AST_PARAMETER (也用作结构体字段: name + type) struct { const char* name; TypeKind type; const char* struct_type_name; } parameter; // AST_BLOCK @@ -137,7 +138,8 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode** impls, size_t impl_count, SourceLoc loc); AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, TypeKind ret, const char* ret_struct_name, AstNode* body, - bool is_pub, SourceLoc loc); + bool is_pub, const char** type_params, size_t tp_count, + SourceLoc loc); AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc); AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc); AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, diff --git a/src/parser/parser.c b/src/parser/parser.c index 0bd7872..aaed4be 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -907,9 +907,22 @@ static AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) { const Token* fn_tok = advance(p); // fn const Token* name = expect(p, TOK_IDENT, error, "fn 后应为函数名"); if (!name) return NULL; + // 泛型类型参数: + const char* type_params[8]; int tp_count = 0; + if (peek(p)->kind == TOK_LT) { + advance(p); // 跳过 '<' + while (peek(p)->kind != TOK_GT && !error->message) { + if (tp_count >= 8) { error->message = "类型参数过多 (最多8)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } + const Token* tp = expect(p, TOK_IDENT, error, "类型参数名"); + if (!tp) return NULL; + type_params[tp_count++] = arena_strdup_impl(p->arena, tp->start, tp->length); + if (peek(p)->kind == TOK_COMMA) advance(p); else break; + } + if (!expect(p, TOK_GT, error, "缺少 '>'")) return NULL; + } if (!expect(p, TOK_LPAREN, error, "缺少 '('")) return NULL; - // 参数列表 + // 参数列表(泛型参数可标注为类型参数名) AstNode* params[64]; int pcount = 0; while (peek(p)->kind != TOK_RPAREN && !error->message) { if (pcount >= 64) { error->message = "函数参数过多 (最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } @@ -941,9 +954,14 @@ static AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) { AstNode** parr = arena_alloc_impl(p->arena, pcount * sizeof(AstNode*)); memcpy(parr, params, pcount * sizeof(AstNode*)); + const char** tparr = NULL; + if (tp_count > 0) { + tparr = arena_alloc_impl(p->arena, tp_count * sizeof(const char*)); + memcpy(tparr, type_params, tp_count * sizeof(const char*)); + } return ast_make_function(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - parr, pcount, ret, ret_struct_name, body, is_pub, tok_loc(fn_tok)); + parr, pcount, ret, ret_struct_name, body, is_pub, tparr, tp_count, tok_loc(fn_tok)); } // === 模块文件加载辅助 === diff --git a/src/sema/sema.c b/src/sema/sema.c index 2e4e8e6..650eee2 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -159,31 +159,39 @@ static void analyze_binary_expr(AstNode* node, Scope* scope, ErrorList* errors, } // 参数类型匹配检查(CALL_EXPR 和 METHOD_CALL 共用) -static void check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, - size_t idx, AstNode* call_node, Scope* scope, +static bool check_arg_type(AstNode* arg, TypeKind expected, const char* expected_sname, + size_t idx, AstNode* call_node, Symbol* fn_sym, ErrorList* errors, Arena* a) { - (void)scope; (void)a; + (void)a; TypeKind actual = arg->type.kind; - if (actual == TYPE_ERROR) return; - if (expected == TYPE_STRUCT) { + if (actual == TYPE_ERROR) return false; + if (expected == TYPE_STRUCT && expected_sname) { + // 检查是否是泛型类型参数(匹配则接受任意类型) + if (fn_sym && fn_sym->type_params) { + for (size_t t = 0; t < fn_sym->type_param_count; t++) { + if (strcmp(expected_sname, fn_sym->type_params[t]) == 0) + return true; // 泛型参数,接受任意类型 + } + } const char* actual_name = arg->type.struct_name; - if (actual != TYPE_STRUCT || !actual_name || !expected_sname || + if (actual != TYPE_STRUCT || !actual_name || strcmp(actual_name, expected_sname) != 0) { error_add(errors, "", call_node->loc.line, call_node->loc.col, "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", idx + 1, expected_sname ? expected_sname : "struct", actual_name ? actual_name : type_name(actual)); } - return; + return false; } - if (actual == expected) return; - if (expected == TYPE_I64 && actual == TYPE_ENUM) return; - if (can_implicit_convert(actual, expected)) return; + if (actual == expected) return false; + if (expected == TYPE_I64 && actual == TYPE_ENUM) return false; + if (can_implicit_convert(actual, expected)) return false; if (actual == TYPE_I64 && arg->kind == AST_LITERAL_EXPR - && (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return; + && (expected == TYPE_I32 || expected == TYPE_U64 || expected == TYPE_CHAR)) return false; error_add(errors, "", call_node->loc.line, call_node->loc.col, "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", idx + 1, type_name(expected), type_name(actual)); + return false; } // 命名参数重排序(CALL_EXPR 和 METHOD_CALL 共用) @@ -241,9 +249,20 @@ static void analyze_call_expr(AstNode* node, Scope* scope, ErrorList* errors, Ar } for (size_t i = 0; i < node->as.call.arg_count; i++) { analyze_expr(node->as.call.args[i], scope, errors, a); - check_arg_type(node->as.call.args[i], sym->param_types[i], + bool is_generic_param = check_arg_type(node->as.call.args[i], sym->param_types[i], sym->param_struct_names ? sym->param_struct_names[i] : NULL, - i, node, scope, errors, a); + i, node, sym, errors, a); + // 泛型: 若实参匹配类型参数,传播具体类型到返回值 + if (is_generic_param && sym->return_type == TYPE_STRUCT + && sym->return_struct_type_name && sym->type_params) { + for (size_t t = 0; t < sym->type_param_count; t++) { + if (strcmp(sym->return_struct_type_name, sym->type_params[t]) == 0) { + node->type.kind = node->as.call.args[i]->type.kind; + node->type.struct_name = node->as.call.args[i]->type.struct_name; + return; + } + } + } } node->type.kind = sym->return_type; if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) @@ -422,7 +441,7 @@ static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, analyze_expr(node->as.method_call.args[i], scope, errors, a); check_arg_type(node->as.method_call.args[i], sym->param_types[i + 1], sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL, - i, node, scope, errors, a); + i, node, sym, errors, a); } node->type.kind = sym->return_type; if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) @@ -572,7 +591,9 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* scope_insert_function(scope, a, fn->as.function.name, ret_t, ret_sn, pts, pnames, pstruct_names, - fn->as.function.param_count); + fn->as.function.param_count, + fn->as.function.type_params, + fn->as.function.type_param_count); } // 第三遍:分析每个函数体 for (size_t i = 0; i < node->as.program.fn_count; i++) { @@ -884,13 +905,13 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) { // 注册内置函数 TypeKind params_i64[] = {TYPE_I64}; - scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, NULL, params_i64, NULL, NULL, 1); + scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, NULL, params_i64, NULL, NULL, 1, NULL, 0); TypeKind params_f64[] = {TYPE_F64}; - scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, NULL, params_f64, NULL, NULL, 1); + scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, NULL, params_f64, NULL, NULL, 1, NULL, 0); TypeKind params_bool[] = {TYPE_BOOL}; - scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, NULL, params_bool, NULL, NULL, 1); + scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, NULL, params_bool, NULL, NULL, 1, NULL, 0); TypeKind params_str[] = {TYPE_STR}; - scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1); + scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, NULL, 1, NULL, 0); analyze_node(ast, global_scope, errors, arena); } diff --git a/src/sema/symbol.c b/src/sema/symbol.c index 29a86ef..0c55ee2 100644 --- a/src/sema/symbol.c +++ b/src/sema/symbol.c @@ -31,6 +31,7 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name, sym->name = name; sym->kind = kind; sym->type = type; sym->is_mut = false; sym->return_type = TYPE_VOID; sym->param_types = NULL; sym->param_names = NULL; sym->param_count = 0; + sym->type_params = NULL; sym->type_param_count = 0; sym->struct_field_names = NULL; sym->struct_field_types = NULL; sym->struct_field_count = 0; @@ -49,7 +50,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name, Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, TypeKind ret, const char* ret_struct_name, TypeKind* pt, const char** pnames, - const char** pstruct_names, size_t pc) { + const char** pstruct_names, size_t pc, + const char** tparams, size_t tpc) { if (scope->head) { for (Symbol* sym = scope->head; sym; sym = sym->next) { if (strcmp(sym->name, name) == 0) return NULL; @@ -64,6 +66,8 @@ Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, sym->param_names = pnames; sym->param_struct_names = pstruct_names; sym->param_count = pc; + sym->type_params = tparams; + sym->type_param_count = tpc; sym->struct_field_names = NULL; sym->struct_field_types = NULL; sym->struct_field_count = 0; @@ -92,6 +96,7 @@ Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name, sym->name = name; sym->kind = SYM_STRUCT; sym->type = TYPE_STRUCT; sym->is_mut = false; sym->return_type = TYPE_VOID; sym->param_types = NULL; sym->param_names = NULL; sym->param_count = 0; + sym->type_params = NULL; sym->type_param_count = 0; sym->struct_field_names = fnames; sym->struct_field_types = ftypes; sym->struct_field_struct_names = fstruct_names; @@ -140,6 +145,7 @@ Symbol* scope_insert_enum(Scope* scope, void* alloc, const char* name, sym->name = name; sym->kind = SYM_ENUM; sym->type = TYPE_ENUM; sym->is_mut = false; sym->return_type = TYPE_VOID; sym->param_types = NULL; sym->param_names = NULL; sym->param_count = 0; + sym->type_params = NULL; sym->type_param_count = 0; sym->struct_field_names = vnames; sym->struct_field_types = NULL; sym->struct_field_struct_names = NULL; diff --git a/src/sema/symbol.h b/src/sema/symbol.h index 09f5c7c..1a58933 100644 --- a/src/sema/symbol.h +++ b/src/sema/symbol.h @@ -14,6 +14,8 @@ typedef struct Symbol { // 函数特有 TypeKind return_type; const char* return_struct_type_name; // 返回类型为 struct 时的类型名 + const char** type_params; // 泛型类型参数名 + size_t type_param_count; TypeKind* param_types; const char** param_names; // 参数名(用于命名参数匹配) const char** param_struct_names; // 参数为 struct 时的类型名 @@ -57,7 +59,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name, Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, TypeKind ret, const char* ret_struct_name, TypeKind* pt, const char** pnames, - const char** pstruct_names, size_t pc); + const char** pstruct_names, size_t pc, + const char** tparams, size_t tpc); // 插入结构体符号 Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name, diff --git a/test/programs/34_generic.l b/test/programs/34_generic.l new file mode 100644 index 0000000..bde0e3e --- /dev/null +++ b/test/programs/34_generic.l @@ -0,0 +1,9 @@ +fn id(x: T) -> T { + return x; +} + +fn main() -> i64 { + let a = id(42); + print_i64(a); // 42 + return 0; +} diff --git a/test/test_codegen.c b/test/test_codegen.c index d5c4215..77fb990 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -12,7 +12,7 @@ void test_codegen_simple_function() { AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 42, loc_at(1, 1)), loc_at(1, 1)); AstNode* stmts[] = { ret }; AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -46,7 +46,7 @@ void test_codegen_if_else() { ast_make_literal_bool(&a, true, loc_at(1, 1)), then_block, else_block, loc_at(1, 1)); AstNode* stmts[] = { if_stmt }; AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -77,7 +77,7 @@ void test_codegen_binary_ops() { AstNode* ret = ast_make_return(&a, expr, loc_at(1, 1)); AstNode* stmts[] = { ret }; AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -107,7 +107,7 @@ void test_codegen_while_loop() { AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, loc_at(1, 1)), loc_at(1, 1)); AstNode* stmts[] = { while_stmt, ret }; AstNode* fn_body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -160,7 +160,7 @@ void test_codegen_struct_decl() { AstNode* ret = ast_make_return(&a, field_x, loc_at(1, 1)); AstNode* stmts[] = { let_stmt, ret }; AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -215,7 +215,7 @@ void test_codegen_struct_field_access() { AstNode* ret = ast_make_return(&a, field_y, loc_at(1, 1)); AstNode* stmts[] = { let_stmt, ret }; AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -265,7 +265,7 @@ void test_codegen_enum() { AstNode* stmts[] = { let_stmt, print_call, ret }; AstNode* body = ast_make_block(&a, stmts, 3, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, enums, 1, NULL, 0, loc_at(1, 1)); @@ -330,7 +330,7 @@ void test_codegen_array() { AstNode* stmts[] = { let_stmt, arr_assign, print_call, ret }; AstNode* body = ast_make_block(&a, stmts, 4, loc_at(1, 1)); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -373,7 +373,7 @@ void test_codegen_method_call() { AstNode* ret_body = ast_make_return(&a, field_x, loc_at(1, 1)); AstNode* ret_stmts[] = { ret_body }; AstNode* body = ast_make_block(&a, ret_stmts, 1, loc_at(1, 1)); - AstNode* get_x_fn = ast_make_function(&a, "Point$get_x", params, 1, TYPE_I64, NULL, body, false, loc_at(1, 1)); + AstNode* get_x_fn = ast_make_function(&a, "Point$get_x", params, 1, TYPE_I64, NULL, body, false, NULL, 0, loc_at(1, 1)); /* fn main() -> i64 { let p = Point { x: 42, y: 0 }; @@ -399,7 +399,7 @@ void test_codegen_method_call() { AstNode* ret_main = ast_make_return(&a, method_call, loc_at(1, 1)); AstNode* main_stmts[] = { let_stmt, ret_main }; AstNode* main_body = ast_make_block(&a, main_stmts, 2, loc_at(1, 1)); - AstNode* main_fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, main_body, false, loc_at(1, 1)); + AstNode* main_fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, main_body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { get_x_fn, main_fn }; AstNode* prog = ast_make_program(&a, fns, 2, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); @@ -486,7 +486,7 @@ void test_codegen_match() { AstNode* main_stmts[] = { let_stmt, outer_if }; AstNode* main_body = ast_make_block(&a, main_stmts, 2, loc_at(1, 1)); - AstNode* main_fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, main_body, false, loc_at(1, 1)); + AstNode* main_fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, main_body, false, NULL, 0, loc_at(1, 1)); AstNode* fns[] = { main_fn }; AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, enums, 1, NULL, 0, loc_at(1, 1));