diff --git a/src/ast/ast.c b/src/ast/ast.c index c918e2d..3730d54 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -274,3 +274,11 @@ AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, Sour n->as.use_decl.item = item; return n; } + +AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc) { + NEW(alloc, AST_TRAIT_DECL); + n->as.trait_decl.name = name; + n->as.trait_decl.methods = methods; + n->as.trait_decl.method_count = count; + return n; +} diff --git a/src/ast/ast.h b/src/ast/ast.h index 39d1b48..74907eb 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -32,6 +32,7 @@ typedef enum { AST_METHOD_CALL, // receiver.method(args) AST_MOD_DECL, // mod foo; AST_USE_DECL, // use foo::bar; + AST_TRAIT_DECL, // trait Name { fn ... } } AstKind; typedef enum { @@ -127,6 +128,8 @@ struct AstNode { struct { const char* name; struct AstNode* ast; } mod_decl; // AST_USE_DECL struct { const char* path; const char* item; } use_decl; + // AST_TRAIT_DECL + struct { const char* name; struct AstNode** methods; size_t method_count; } trait_decl; } as; }; @@ -175,5 +178,6 @@ AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** met AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc); AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc); AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc); +AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc); #endif diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index b3f7f3d..187a61d 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -431,8 +431,12 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { case AST_METHOD_CALL: { const char* struct_name = node->as.method_call.receiver->type.struct_name; char mangled[256]; - snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, - node->as.method_call.method_name); + // 若 method_name 已含 $(trait 方法,sema 已设置全限定名),直接用 + if (strchr(node->as.method_call.method_name, '$')) + snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name); + else + snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, + node->as.method_call.method_name); LLVMValueRef fn = find_fn(ctx, mangled); if (!fn) return NULL; // 参数列表: [receiver, 用户参数...] diff --git a/src/lexer/lexer.c b/src/lexer/lexer.c index 5e5ca65..28a3722 100644 --- a/src/lexer/lexer.c +++ b/src/lexer/lexer.c @@ -67,6 +67,7 @@ static TokenKind check_keyword(const Token* tok) { KW("struct", TOK_STRUCT); KW("type", TOK_TYPE); KW("enum", TOK_ENUM); KW("extend", TOK_EXTEND); KW("match", TOK_MATCH); KW("pub", TOK_PUB); KW("mod", TOK_MOD); KW("use", TOK_USE); + KW("trait", TOK_TRAIT); KW("Self", TOK_SELF); KW("_", TOK_UNDERSCORE); KW("true", TOK_TRUE); KW("false", TOK_FALSE); #undef KW diff --git a/src/lexer/token.c b/src/lexer/token.c index fcc3a56..7779244 100644 --- a/src/lexer/token.c +++ b/src/lexer/token.c @@ -7,6 +7,7 @@ static const char* NAMES[] = { [TOK_FN] = "fn", [TOK_LET] = "let", [TOK_VAR] = "var", [TOK_IF] = "if", [TOK_GUARD] = "guard", [TOK_PUB] = "pub", [TOK_MOD] = "mod", [TOK_USE] = "use", + [TOK_TRAIT] = "trait", [TOK_SELF] = "Self", [TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return", [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_EXTEND] = "extend", [TOK_MATCH] = "match", diff --git a/src/lexer/token.h b/src/lexer/token.h index 2f5eddb..5b3c81f 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -8,6 +8,7 @@ typedef enum { // 关键字 TOK_FN, TOK_LET, TOK_VAR, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, TOK_GUARD, TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_EXTEND, TOK_MATCH, TOK_PUB, TOK_MOD, TOK_USE, + TOK_TRAIT, TOK_SELF, // 类型关键字 TOK_I32, TOK_I64, TOK_U64, TOK_F64, TOK_BOOL, TOK_CHAR, TOK_STR, TOK_VOID, // 字面量 diff --git a/src/parser/parser.c b/src/parser/parser.c index aaed4be..642e575 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -499,6 +499,14 @@ static TypeInfo parse_type_expr(Parser* p, ErrorInfo* error) { const Token* t = peek(p); TypeInfo ti = {0}; + // Self 类型(trait 中引用实现者自身类型) + if (t->kind == TOK_SELF) { + advance(p); + ti.kind = TYPE_STRUCT; + ti.struct_name = "Self"; + return ti; + } + // 解析基础类型 if (tok_is_type(t->kind)) { advance(p); @@ -949,8 +957,14 @@ static AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) { ret_struct_name = rti.struct_name; } - AstNode* body = parse_block(p, error); - if (!body) return NULL; + // trait 方法签名或普通函数体 + AstNode* body = NULL; + if (match(p, TOK_SEMICOLON)) { + body = NULL; // trait 方法签名,无实现 + } else { + body = parse_block(p, error); + if (!body) return NULL; + } AstNode** parr = arena_alloc_impl(p->arena, pcount * sizeof(AstNode*)); memcpy(parr, params, pcount * sizeof(AstNode*)); @@ -1026,7 +1040,31 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, if (peek(&p)->kind == TOK_PUB) { is_pub = true; advance(&p); } - if (peek(&p)->kind == TOK_STRUCT) { + if (peek(&p)->kind == TOK_TRAIT) { + const Token* tt = advance(&p); + const Token* tname = expect(&p, TOK_IDENT, error, "trait 后应为接口名"); + if (!tname) return NULL; + if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL; + AstNode* methods[64]; int mcount = 0; + while (peek(&p)->kind != TOK_RBRACE && !error->message) { + if (mcount >= 64) { error->message = "trait 方法过多(最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + if (peek(&p)->kind != TOK_FN) { error->message = "trait 内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + // trait 方法只解析签名(body 为空) + AstNode* m = parse_function(&p, false, error); + if (!m) return NULL; + m->as.function.body = NULL; // trait 方法无实现 + methods[mcount++] = m; + if (peek(&p)->kind == TOK_COMMA) advance(&p); + } + if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL; + AstNode** marr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*)); + memcpy(marr, methods, mcount * sizeof(AstNode*)); + // 复用 impl_count 存储 trait(共用计数) + if (impl_count >= 64) { error->message = "trait 过多(最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + impls[impl_count++] = ast_make_trait_decl(p.arena, + arena_strdup_impl(p.arena, tname->start, tname->length), + marr, mcount, tok_loc(tt)); + } else if (peek(&p)->kind == TOK_STRUCT) { if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } structs[struct_count++] = parse_struct_decl(&p, error); } else if (peek(&p)->kind == TOK_TYPE) { @@ -1083,22 +1121,38 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, enums[enum_count++] = enum_decl; } else if (peek(&p)->kind == TOK_EXTEND) { const Token* i_tok = advance(&p); - const Token* st_name = expect(&p, TOK_IDENT, error, "extend 后应为结构体名"); - if (!st_name) return NULL; + const Token* first = expect(&p, TOK_IDENT, error, "extend 后应为结构体名"); + if (!first) return NULL; + const char* trait_name = NULL; + const char* struct_name; + // extend Trait Struct { ... }(trait 实现:两个标识符) + if (peek(&p)->kind == TOK_IDENT) { + trait_name = arena_strdup_impl(p.arena, first->start, first->length); + struct_name = arena_strdup_impl(p.arena, peek(&p)->start, peek(&p)->length); + advance(&p); + } else { + struct_name = arena_strdup_impl(p.arena, first->start, first->length); + } if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL; AstNode* methods[64]; int mcount = 0; while (peek(&p)->kind != TOK_RBRACE && !error->message) { if (mcount >= 64) { error->message = "方法过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } if (peek(&p)->kind != TOK_FN) { error->message = "extend 块内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } - methods[mcount++] = parse_function(&p, false, error); + AstNode* m = parse_function(&p, false, error); + if (!m) return NULL; + // trait 实现: 方法名 mangled 为 TraitName$methodName + if (trait_name) { + char* mn = arena_alloc_impl(p.arena, strlen(trait_name) + strlen(m->as.function.name) + 4); + sprintf(mn, "%s$%s", trait_name, m->as.function.name); + m->as.function.name = mn; + } + methods[mcount++] = m; } if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL; AstNode** m_arr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*)); memcpy(m_arr, methods, mcount * sizeof(AstNode*)); if (impl_count >= 64) { error->message = "extend 块过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } - impls[impl_count++] = ast_make_impl_block(p.arena, - arena_strdup_impl(p.arena, st_name->start, st_name->length), - m_arr, mcount, tok_loc(i_tok)); + impls[impl_count++] = ast_make_impl_block(p.arena, struct_name, m_arr, mcount, tok_loc(i_tok)); } else if (peek(&p)->kind == TOK_MOD) { advance(&p); const Token* mn = expect(&p, TOK_IDENT, error, "mod 后应为模块名"); diff --git a/src/sema/sema.c b/src/sema/sema.c index 0c4eb0a..75d9652 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -558,6 +558,26 @@ static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors, snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct, node->as.method_call.method_name); Symbol* sym = scope_lookup(scope, mangled); + // trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾的符号 + if (!sym || sym->kind != SYM_FUNCTION) { + char suffix[256]; + snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name); + size_t suf_len = strlen(suffix); + for (const Scope* sc = scope; sc; sc = sc->parent) { + for (Symbol* s = sc->head; s; s = s->next) { + if (s->kind == SYM_FUNCTION) { + size_t name_len = strlen(s->name); + if (name_len > suf_len && strcmp(s->name + name_len - suf_len, suffix) == 0) { + sym = s; + // 更新 method_name 为找到的完整函数名(codegen 需要) + node->as.method_call.method_name = s->name; + break; + } + } + } + if (sym) break; + } + } if (!sym || sym->kind != SYM_FUNCTION) { error_add(errors, "", node->loc.line, node->loc.col, "结构体 '%s' 没有方法 '%s'", recv_struct, @@ -657,6 +677,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* size_t extra_fn = 0; for (size_t i = 0; i < node->as.program.impl_count; i++) { AstNode* impl = node->as.program.impls[i]; + if (impl->kind == AST_TRAIT_DECL) continue; extra_fn += impl->as.impl_block.method_count; } if (extra_fn > 0) { @@ -668,6 +689,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* for (size_t i = 0; i < node->as.program.impl_count; i++) { AstNode* impl = node->as.program.impls[i]; + if (impl->kind == AST_TRAIT_DECL) continue; // 跳过 trait 声明 const char* st_name = impl->as.impl_block.struct_name; // 验证目标结构体存在 Symbol* st_sym = scope_lookup_struct(scope, st_name); @@ -802,7 +824,8 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* const char* saved_name = current_return_struct_name; current_return_type = node->as.function.return_type; current_return_struct_name = node->as.function.return_struct_type_name; - analyze_node(node->as.function.body, fn_scope, errors, a); + if (node->as.function.body) + analyze_node(node->as.function.body, fn_scope, errors, a); current_return_type = saved; current_return_struct_name = saved_name; break; diff --git a/test/programs/35_trait.l b/test/programs/35_trait.l new file mode 100644 index 0000000..3899f67 --- /dev/null +++ b/test/programs/35_trait.l @@ -0,0 +1,18 @@ +trait Show { + fn show(self: Self) -> void; +} + +struct Point { x: i64, y: i64 } + +extend Show Point { + fn show(self: Point) -> void { + print_i64(self.x); + print_i64(self.y); + } +} + +fn main() -> i64 { + let p = Point { x: 10, y: 20 }; + p.show(); + return 0; +}