feat: trait 接口系统 — trait Show { fn method } + extend Trait Struct { }

This commit is contained in:
2026-06-06 16:41:21 +08:00
parent 9169796b77
commit b3b3d285f9
9 changed files with 126 additions and 12 deletions
+8
View File
@@ -274,3 +274,11 @@ AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, Sour
n->as.use_decl.item = item; n->as.use_decl.item = item;
return n; return n;
} }
AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc) {
NEW(alloc, AST_TRAIT_DECL);
n->as.trait_decl.name = name;
n->as.trait_decl.methods = methods;
n->as.trait_decl.method_count = count;
return n;
}
+4
View File
@@ -32,6 +32,7 @@ typedef enum {
AST_METHOD_CALL, // receiver.method(args) AST_METHOD_CALL, // receiver.method(args)
AST_MOD_DECL, // mod foo; AST_MOD_DECL, // mod foo;
AST_USE_DECL, // use foo::bar; AST_USE_DECL, // use foo::bar;
AST_TRAIT_DECL, // trait Name { fn ... }
} AstKind; } AstKind;
typedef enum { typedef enum {
@@ -127,6 +128,8 @@ struct AstNode {
struct { const char* name; struct AstNode* ast; } mod_decl; struct { const char* name; struct AstNode* ast; } mod_decl;
// AST_USE_DECL // AST_USE_DECL
struct { const char* path; const char* item; } use_decl; struct { const char* path; const char* item; } use_decl;
// AST_TRAIT_DECL
struct { const char* name; struct AstNode** methods; size_t method_count; } trait_decl;
} as; } as;
}; };
@@ -175,5 +178,6 @@ AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** met
AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc); AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, const char** arg_names, size_t count, SourceLoc loc);
AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc); AstNode* ast_make_mod_decl(void* alloc, const char* name, AstNode* sub_ast, SourceLoc loc);
AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc); AstNode* ast_make_use_decl(void* alloc, const char* path, const char* item, SourceLoc loc);
AstNode* ast_make_trait_decl(void* alloc, const char* name, AstNode** methods, size_t count, SourceLoc loc);
#endif #endif
+6 -2
View File
@@ -431,8 +431,12 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) {
case AST_METHOD_CALL: { case AST_METHOD_CALL: {
const char* struct_name = node->as.method_call.receiver->type.struct_name; const char* struct_name = node->as.method_call.receiver->type.struct_name;
char mangled[256]; char mangled[256];
snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, // 若 method_name 已含 $trait 方法,sema 已设置全限定名),直接用
node->as.method_call.method_name); if (strchr(node->as.method_call.method_name, '$'))
snprintf(mangled, sizeof(mangled), "%s", node->as.method_call.method_name);
else
snprintf(mangled, sizeof(mangled), "%s$%s", struct_name,
node->as.method_call.method_name);
LLVMValueRef fn = find_fn(ctx, mangled); LLVMValueRef fn = find_fn(ctx, mangled);
if (!fn) return NULL; if (!fn) return NULL;
// 参数列表: [receiver, 用户参数...] // 参数列表: [receiver, 用户参数...]
+1
View File
@@ -67,6 +67,7 @@ static TokenKind check_keyword(const Token* tok) {
KW("struct", TOK_STRUCT); KW("type", TOK_TYPE); KW("struct", TOK_STRUCT); KW("type", TOK_TYPE);
KW("enum", TOK_ENUM); KW("extend", TOK_EXTEND); KW("match", TOK_MATCH); KW("enum", TOK_ENUM); KW("extend", TOK_EXTEND); KW("match", TOK_MATCH);
KW("pub", TOK_PUB); KW("mod", TOK_MOD); KW("use", TOK_USE); KW("pub", TOK_PUB); KW("mod", TOK_MOD); KW("use", TOK_USE);
KW("trait", TOK_TRAIT); KW("Self", TOK_SELF);
KW("_", TOK_UNDERSCORE); KW("_", TOK_UNDERSCORE);
KW("true", TOK_TRUE); KW("false", TOK_FALSE); KW("true", TOK_TRUE); KW("false", TOK_FALSE);
#undef KW #undef KW
+1
View File
@@ -7,6 +7,7 @@
static const char* NAMES[] = { static const char* NAMES[] = {
[TOK_FN] = "fn", [TOK_LET] = "let", [TOK_VAR] = "var", [TOK_IF] = "if", [TOK_GUARD] = "guard", [TOK_FN] = "fn", [TOK_LET] = "let", [TOK_VAR] = "var", [TOK_IF] = "if", [TOK_GUARD] = "guard",
[TOK_PUB] = "pub", [TOK_MOD] = "mod", [TOK_USE] = "use", [TOK_PUB] = "pub", [TOK_MOD] = "mod", [TOK_USE] = "use",
[TOK_TRAIT] = "trait", [TOK_SELF] = "Self",
[TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return", [TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return",
[TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_EXTEND] = "extend", [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_EXTEND] = "extend",
[TOK_MATCH] = "match", [TOK_MATCH] = "match",
+1
View File
@@ -8,6 +8,7 @@ typedef enum {
// 关键字 // 关键字
TOK_FN, TOK_LET, TOK_VAR, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, TOK_GUARD, TOK_FN, TOK_LET, TOK_VAR, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, TOK_GUARD,
TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_EXTEND, TOK_MATCH, TOK_PUB, TOK_MOD, TOK_USE, TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_EXTEND, TOK_MATCH, TOK_PUB, TOK_MOD, TOK_USE,
TOK_TRAIT, TOK_SELF,
// 类型关键字 // 类型关键字
TOK_I32, TOK_I64, TOK_U64, TOK_F64, TOK_BOOL, TOK_CHAR, TOK_STR, TOK_VOID, TOK_I32, TOK_I64, TOK_U64, TOK_F64, TOK_BOOL, TOK_CHAR, TOK_STR, TOK_VOID,
// 字面量 // 字面量
+63 -9
View File
@@ -499,6 +499,14 @@ static TypeInfo parse_type_expr(Parser* p, ErrorInfo* error) {
const Token* t = peek(p); const Token* t = peek(p);
TypeInfo ti = {0}; TypeInfo ti = {0};
// Self 类型(trait 中引用实现者自身类型)
if (t->kind == TOK_SELF) {
advance(p);
ti.kind = TYPE_STRUCT;
ti.struct_name = "Self";
return ti;
}
// 解析基础类型 // 解析基础类型
if (tok_is_type(t->kind)) { if (tok_is_type(t->kind)) {
advance(p); advance(p);
@@ -949,8 +957,14 @@ static AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) {
ret_struct_name = rti.struct_name; ret_struct_name = rti.struct_name;
} }
AstNode* body = parse_block(p, error); // trait 方法签名或普通函数体
if (!body) return NULL; AstNode* body = NULL;
if (match(p, TOK_SEMICOLON)) {
body = NULL; // trait 方法签名,无实现
} else {
body = parse_block(p, error);
if (!body) return NULL;
}
AstNode** parr = arena_alloc_impl(p->arena, pcount * sizeof(AstNode*)); AstNode** parr = arena_alloc_impl(p->arena, pcount * sizeof(AstNode*));
memcpy(parr, params, pcount * sizeof(AstNode*)); memcpy(parr, params, pcount * sizeof(AstNode*));
@@ -1026,7 +1040,31 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count,
if (peek(&p)->kind == TOK_PUB) { if (peek(&p)->kind == TOK_PUB) {
is_pub = true; advance(&p); is_pub = true; advance(&p);
} }
if (peek(&p)->kind == TOK_STRUCT) { if (peek(&p)->kind == TOK_TRAIT) {
const Token* tt = advance(&p);
const Token* tname = expect(&p, TOK_IDENT, error, "trait 后应为接口名");
if (!tname) return NULL;
if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL;
AstNode* methods[64]; int mcount = 0;
while (peek(&p)->kind != TOK_RBRACE && !error->message) {
if (mcount >= 64) { error->message = "trait 方法过多(最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
if (peek(&p)->kind != TOK_FN) { error->message = "trait 内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
// trait 方法只解析签名(body 为空)
AstNode* m = parse_function(&p, false, error);
if (!m) return NULL;
m->as.function.body = NULL; // trait 方法无实现
methods[mcount++] = m;
if (peek(&p)->kind == TOK_COMMA) advance(&p);
}
if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL;
AstNode** marr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*));
memcpy(marr, methods, mcount * sizeof(AstNode*));
// 复用 impl_count 存储 trait(共用计数)
if (impl_count >= 64) { error->message = "trait 过多(最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
impls[impl_count++] = ast_make_trait_decl(p.arena,
arena_strdup_impl(p.arena, tname->start, tname->length),
marr, mcount, tok_loc(tt));
} else if (peek(&p)->kind == TOK_STRUCT) {
if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
structs[struct_count++] = parse_struct_decl(&p, error); structs[struct_count++] = parse_struct_decl(&p, error);
} else if (peek(&p)->kind == TOK_TYPE) { } else if (peek(&p)->kind == TOK_TYPE) {
@@ -1083,22 +1121,38 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count,
enums[enum_count++] = enum_decl; enums[enum_count++] = enum_decl;
} else if (peek(&p)->kind == TOK_EXTEND) { } else if (peek(&p)->kind == TOK_EXTEND) {
const Token* i_tok = advance(&p); const Token* i_tok = advance(&p);
const Token* st_name = expect(&p, TOK_IDENT, error, "extend 后应为结构体名"); const Token* first = expect(&p, TOK_IDENT, error, "extend 后应为结构体名");
if (!st_name) return NULL; if (!first) return NULL;
const char* trait_name = NULL;
const char* struct_name;
// extend Trait Struct { ... }trait 实现:两个标识符)
if (peek(&p)->kind == TOK_IDENT) {
trait_name = arena_strdup_impl(p.arena, first->start, first->length);
struct_name = arena_strdup_impl(p.arena, peek(&p)->start, peek(&p)->length);
advance(&p);
} else {
struct_name = arena_strdup_impl(p.arena, first->start, first->length);
}
if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL; if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL;
AstNode* methods[64]; int mcount = 0; AstNode* methods[64]; int mcount = 0;
while (peek(&p)->kind != TOK_RBRACE && !error->message) { while (peek(&p)->kind != TOK_RBRACE && !error->message) {
if (mcount >= 64) { error->message = "方法过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } if (mcount >= 64) { error->message = "方法过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
if (peek(&p)->kind != TOK_FN) { error->message = "extend 块内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } if (peek(&p)->kind != TOK_FN) { error->message = "extend 块内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
methods[mcount++] = parse_function(&p, false, error); AstNode* m = parse_function(&p, false, error);
if (!m) return NULL;
// trait 实现: 方法名 mangled 为 TraitName$methodName
if (trait_name) {
char* mn = arena_alloc_impl(p.arena, strlen(trait_name) + strlen(m->as.function.name) + 4);
sprintf(mn, "%s$%s", trait_name, m->as.function.name);
m->as.function.name = mn;
}
methods[mcount++] = m;
} }
if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL; if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL;
AstNode** m_arr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*)); AstNode** m_arr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*));
memcpy(m_arr, methods, mcount * sizeof(AstNode*)); memcpy(m_arr, methods, mcount * sizeof(AstNode*));
if (impl_count >= 64) { error->message = "extend 块过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } if (impl_count >= 64) { error->message = "extend 块过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; }
impls[impl_count++] = ast_make_impl_block(p.arena, impls[impl_count++] = ast_make_impl_block(p.arena, struct_name, m_arr, mcount, tok_loc(i_tok));
arena_strdup_impl(p.arena, st_name->start, st_name->length),
m_arr, mcount, tok_loc(i_tok));
} else if (peek(&p)->kind == TOK_MOD) { } else if (peek(&p)->kind == TOK_MOD) {
advance(&p); advance(&p);
const Token* mn = expect(&p, TOK_IDENT, error, "mod 后应为模块名"); const Token* mn = expect(&p, TOK_IDENT, error, "mod 后应为模块名");
+24 -1
View File
@@ -558,6 +558,26 @@ static void analyze_method_call(AstNode* node, Scope* scope, ErrorList* errors,
snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct, snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct,
node->as.method_call.method_name); node->as.method_call.method_name);
Symbol* sym = scope_lookup(scope, mangled); Symbol* sym = scope_lookup(scope, mangled);
// trait 方法 fallback: 搜索所有作用域中以 $method_name 结尾的符号
if (!sym || sym->kind != SYM_FUNCTION) {
char suffix[256];
snprintf(suffix, sizeof(suffix), "$%s", node->as.method_call.method_name);
size_t suf_len = strlen(suffix);
for (const Scope* sc = scope; sc; sc = sc->parent) {
for (Symbol* s = sc->head; s; s = s->next) {
if (s->kind == SYM_FUNCTION) {
size_t name_len = strlen(s->name);
if (name_len > suf_len && strcmp(s->name + name_len - suf_len, suffix) == 0) {
sym = s;
// 更新 method_name 为找到的完整函数名(codegen 需要)
node->as.method_call.method_name = s->name;
break;
}
}
}
if (sym) break;
}
}
if (!sym || sym->kind != SYM_FUNCTION) { if (!sym || sym->kind != SYM_FUNCTION) {
error_add(errors, "<sema>", node->loc.line, node->loc.col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"结构体 '%s' 没有方法 '%s'", recv_struct, "结构体 '%s' 没有方法 '%s'", recv_struct,
@@ -657,6 +677,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
size_t extra_fn = 0; size_t extra_fn = 0;
for (size_t i = 0; i < node->as.program.impl_count; i++) { for (size_t i = 0; i < node->as.program.impl_count; i++) {
AstNode* impl = node->as.program.impls[i]; AstNode* impl = node->as.program.impls[i];
if (impl->kind == AST_TRAIT_DECL) continue;
extra_fn += impl->as.impl_block.method_count; extra_fn += impl->as.impl_block.method_count;
} }
if (extra_fn > 0) { if (extra_fn > 0) {
@@ -668,6 +689,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
for (size_t i = 0; i < node->as.program.impl_count; i++) { for (size_t i = 0; i < node->as.program.impl_count; i++) {
AstNode* impl = node->as.program.impls[i]; AstNode* impl = node->as.program.impls[i];
if (impl->kind == AST_TRAIT_DECL) continue; // 跳过 trait 声明
const char* st_name = impl->as.impl_block.struct_name; const char* st_name = impl->as.impl_block.struct_name;
// 验证目标结构体存在 // 验证目标结构体存在
Symbol* st_sym = scope_lookup_struct(scope, st_name); Symbol* st_sym = scope_lookup_struct(scope, st_name);
@@ -802,7 +824,8 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
const char* saved_name = current_return_struct_name; const char* saved_name = current_return_struct_name;
current_return_type = node->as.function.return_type; current_return_type = node->as.function.return_type;
current_return_struct_name = node->as.function.return_struct_type_name; current_return_struct_name = node->as.function.return_struct_type_name;
analyze_node(node->as.function.body, fn_scope, errors, a); if (node->as.function.body)
analyze_node(node->as.function.body, fn_scope, errors, a);
current_return_type = saved; current_return_type = saved;
current_return_struct_name = saved_name; current_return_struct_name = saved_name;
break; break;
+18
View File
@@ -0,0 +1,18 @@
trait Show {
fn show(self: Self) -> void;
}
struct Point { x: i64, y: i64 }
extend Show Point {
fn show(self: Point) -> void {
print_i64(self.x);
print_i64(self.y);
}
}
fn main() -> i64 {
let p = Point { x: 10, y: 20 };
p.show();
return 0;
}