diff --git a/src/ast/ast.c b/src/ast/ast.c index 79a94e5..b63c1a0 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -13,7 +13,8 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode** structs, size_t struct_count, AstNode** aliases, size_t alias_count, - AstNode** enums, size_t enum_count, SourceLoc loc) { + AstNode** enums, size_t enum_count, + AstNode** impls, size_t impl_count, SourceLoc loc) { NEW(alloc, AST_PROGRAM); n->as.program.functions = fns; n->as.program.fn_count = fn_count; @@ -23,6 +24,8 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, n->as.program.alias_count = alias_count; n->as.program.enums = enums; n->as.program.enum_count = enum_count; + n->as.program.impls = impls; + n->as.program.impl_count = impl_count; return n; } @@ -221,3 +224,22 @@ AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, As n->as.array_assign.value = value; return n; } + +AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, + size_t count, SourceLoc loc) { + NEW(alloc, AST_IMPL_BLOCK); + n->as.impl_block.struct_name = struct_name; + n->as.impl_block.methods = methods; + n->as.impl_block.method_count = count; + return n; +} + +AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, + AstNode** args, size_t count, SourceLoc loc) { + NEW(alloc, AST_METHOD_CALL); + n->as.method_call.receiver = receiver; + n->as.method_call.method_name = method; + n->as.method_call.args = args; + n->as.method_call.arg_count = count; + return n; +} diff --git a/src/ast/ast.h b/src/ast/ast.h index 76b2d1b..863d451 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -28,6 +28,8 @@ typedef enum { AST_ENUM_VARIANT, // Color::Red AST_INDEX_EXPR, // arr[i] AST_ARRAY_ASSIGN_STMT,// arr[i] = expr + AST_IMPL_BLOCK, // impl StructName { fn method(...) ... } + AST_METHOD_CALL, // receiver.method(args) } AstKind; typedef enum { @@ -58,7 +60,8 @@ struct AstNode { struct { struct AstNode** functions; size_t fn_count; struct AstNode** structs; size_t struct_count; struct AstNode** type_aliases; size_t alias_count; - struct AstNode** enums; size_t enum_count; } program; + struct AstNode** enums; size_t enum_count; + struct AstNode** impls; size_t impl_count; } program; // AST_FUNCTION struct { const char* name; struct AstNode** params; size_t param_count; TypeKind return_type; const char* return_struct_type_name; @@ -108,6 +111,10 @@ struct AstNode { struct { struct AstNode* array; struct AstNode* index; } index_expr; // AST_ARRAY_ASSIGN_STMT struct { const char* name; struct AstNode* index; struct AstNode* value; } array_assign; + // AST_IMPL_BLOCK + struct { const char* struct_name; struct AstNode** methods; size_t method_count; } impl_block; + // AST_METHOD_CALL + struct { struct AstNode* receiver; const char* method_name; struct AstNode** args; size_t arg_count; } method_call; } as; }; @@ -115,7 +122,8 @@ struct AstNode { AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode** structs, size_t struct_count, AstNode** aliases, size_t alias_count, - AstNode** enums, size_t enum_count, SourceLoc loc); + AstNode** enums, size_t enum_count, + AstNode** impls, size_t impl_count, SourceLoc loc); AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc); AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc); @@ -145,5 +153,7 @@ AstNode* ast_make_enum_decl(void* alloc, const char* name, const char** variants AstNode* ast_make_enum_variant(void* alloc, const char* enum_name, const char* variant_name, SourceLoc loc); AstNode* ast_make_index_expr(void* alloc, AstNode* array, AstNode* index, SourceLoc loc); AstNode* ast_make_array_assign(void* alloc, const char* name, AstNode* index, AstNode* value, SourceLoc loc); +AstNode* ast_make_impl_block(void* alloc, const char* struct_name, AstNode** methods, size_t count, SourceLoc loc); +AstNode* ast_make_method_call(void* alloc, AstNode* receiver, const char* method, AstNode** args, size_t count, SourceLoc loc); #endif diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 666e522..b4919be 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -368,6 +368,28 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { return LLVMConstInt(LLVMInt64TypeInContext(ctx->context), (unsigned long long)node->as.enum_variant.variant_index, true); + case AST_METHOD_CALL: { + const char* struct_name = node->as.method_call.receiver->type.struct_name; + char mangled[256]; + snprintf(mangled, sizeof(mangled), "%s$%s", struct_name, + node->as.method_call.method_name); + LLVMValueRef fn = find_fn(ctx, mangled); + if (!fn) return NULL; + // 参数列表: [receiver, 用户参数...] + LLVMValueRef args[16]; + args[0] = codegen_expr(ctx, node->as.method_call.receiver); + if (!args[0]) return NULL; + for (size_t i = 0; i < node->as.method_call.arg_count; i++) { + args[i + 1] = codegen_expr(ctx, node->as.method_call.args[i]); + if (!args[i + 1]) return NULL; + } + LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); + LLVMTypeRef ret_ty = LLVMGetReturnType(fn_ty); + return LLVMBuildCall2(ctx->builder, fn_ty, fn, args, + (unsigned)(node->as.method_call.arg_count + 1), + ret_ty == LLVMVoidTypeInContext(ctx->context) ? "" : "method_call"); + } + case AST_INDEX_EXPR: { // 获取数组变量的指针 AstNode* arr_node = node->as.index_expr.array; diff --git a/src/lexer/lexer.c b/src/lexer/lexer.c index 7de0302..407e4b8 100644 --- a/src/lexer/lexer.c +++ b/src/lexer/lexer.c @@ -63,7 +63,7 @@ static TokenKind check_keyword(const Token* tok) { KW("bool", TOK_BOOL); KW("str", TOK_STR); KW("void", TOK_VOID); KW("struct", TOK_STRUCT); KW("type", TOK_TYPE); - KW("enum", TOK_ENUM); + KW("enum", TOK_ENUM); KW("impl", TOK_IMPL); KW("true", TOK_TRUE); KW("false", TOK_FALSE); #undef KW return TOK_IDENT; diff --git a/src/lexer/token.c b/src/lexer/token.c index 93c1571..2963ac1 100644 --- a/src/lexer/token.c +++ b/src/lexer/token.c @@ -7,7 +7,7 @@ static const char* NAMES[] = { [TOK_FN] = "fn", [TOK_LET] = "let", [TOK_MUT] = "mut", [TOK_IF] = "if", [TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return", - [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", + [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_IMPL] = "impl", [TOK_I64] = "i64", [TOK_F64] = "f64", [TOK_BOOL] = "bool", [TOK_STR] = "str", [TOK_VOID] = "void", [TOK_INT_LIT] = "整数", [TOK_FLOAT_LIT] = "浮点数", [TOK_STR_LIT] = "字符串", [TOK_TRUE] = "true", [TOK_FALSE] = "false", diff --git a/src/lexer/token.h b/src/lexer/token.h index 6b990a4..79cbef0 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -7,7 +7,7 @@ typedef enum { // 关键字 TOK_FN, TOK_LET, TOK_MUT, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, - TOK_STRUCT, TOK_TYPE, TOK_ENUM, + TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_IMPL, // 类型关键字 TOK_I64, TOK_F64, TOK_BOOL, TOK_STR, TOK_VOID, // 字面量 diff --git a/src/parser/parser.c b/src/parser/parser.c index b4ac6f8..13cc482 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -220,14 +220,30 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error while (!error->message) { TokenKind kind = peek(p)->kind; - // 后置字段访问: expr.field + // 后置字段访问: expr.field 或 expr.method(args) if (kind == TOK_DOT) { advance(p); // 跳过 '.' const Token* field = expect(p, TOK_IDENT, error, "缺少字段名"); if (!field) return NULL; - left = ast_make_field_access(p->arena, left, - arena_strdup_impl(p->arena, field->start, field->length), - tok_loc(field)); + const char* member_name = arena_strdup_impl(p->arena, field->start, field->length); + // 方法调用: expr.method(args) + if (peek(p)->kind == TOK_LPAREN) { + advance(p); // 跳过 '(' + AstNode* args[16]; int arg_count = 0; + while (peek(p)->kind != TOK_RPAREN && !error->message) { + if (arg_count >= 16) { error->message = "参数过多"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } + args[arg_count] = parse_expr(p, error); + if (!args[arg_count]) return NULL; + arg_count++; + if (peek(p)->kind == TOK_COMMA) advance(p); else break; + } + if (!expect(p, TOK_RPAREN, error, "缺少 ')'")) return NULL; + AstNode** arg_arr = arena_alloc_impl(p->arena, arg_count * sizeof(AstNode*)); + memcpy(arg_arr, args, arg_count * sizeof(AstNode*)); + left = ast_make_method_call(p->arena, left, member_name, arg_arr, arg_count, tok_loc(field)); + } else { + left = ast_make_field_access(p->arena, left, member_name, tok_loc(field)); + } continue; } @@ -634,6 +650,7 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, AstNode* structs[64]; int struct_count = 0; AstNode* aliases[64]; int alias_count = 0; AstNode* enums[64]; int enum_count = 0; + AstNode* impls[64]; int impl_count = 0; while (peek(&p)->kind != TOK_EOF && !error->message) { if (peek(&p)->kind == TOK_STRUCT) { if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } @@ -669,11 +686,29 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, AstNode* enum_decl = ast_make_enum_decl(p.arena, arena_strdup_impl(p.arena, name->start, name->length), v_arr, vcount, tok_loc(name)); if (enum_count >= 64) { error->message = "枚举过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } enums[enum_count++] = enum_decl; + } else if (peek(&p)->kind == TOK_IMPL) { + const Token* i_tok = advance(&p); + const Token* st_name = expect(&p, TOK_IDENT, error, "impl 后应为结构体名"); + if (!st_name) return NULL; + if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL; + AstNode* methods[64]; int mcount = 0; + while (peek(&p)->kind != TOK_RBRACE && !error->message) { + if (mcount >= 64) { error->message = "方法过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + if (peek(&p)->kind != TOK_FN) { error->message = "impl 块内只允许 fn"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + methods[mcount++] = parse_function(&p, error); + } + if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL; + AstNode** m_arr = arena_alloc_impl(p.arena, mcount * sizeof(AstNode*)); + memcpy(m_arr, methods, mcount * sizeof(AstNode*)); + if (impl_count >= 64) { error->message = "impl 块过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + impls[impl_count++] = ast_make_impl_block(p.arena, + arena_strdup_impl(p.arena, st_name->start, st_name->length), + m_arr, mcount, tok_loc(i_tok)); } else if (peek(&p)->kind == TOK_FN) { if (fn_count >= 256) { error->message = "函数过多 (最多256)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } functions[fn_count++] = parse_function(&p, error); } else { - error->message = "顶层只允许 fn、struct、type 或 enum"; + error->message = "顶层只允许 fn、struct、type、enum 或 impl"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; @@ -689,6 +724,9 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, memcpy(al_arr, aliases, alias_count * sizeof(AstNode*)); AstNode** en_arr = arena_alloc_impl(a, enum_count * sizeof(AstNode*)); memcpy(en_arr, enums, enum_count * sizeof(AstNode*)); + AstNode** im_arr = arena_alloc_impl(a, impl_count * sizeof(AstNode*)); + memcpy(im_arr, impls, impl_count * sizeof(AstNode*)); return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, - al_arr, alias_count, en_arr, enum_count, loc_at(0, 0)); + al_arr, alias_count, en_arr, enum_count, + im_arr, impl_count, loc_at(0, 0)); } diff --git a/src/sema/sema.c b/src/sema/sema.c index d976b90..2639ce5 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -1,4 +1,5 @@ #include "sema.h" +#include #include // === 类型关系 === @@ -345,6 +346,66 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; } + case AST_METHOD_CALL: { + analyze_expr(node->as.method_call.receiver, scope, errors, a); + const char* recv_struct = node->as.method_call.receiver->type.struct_name; + if (node->as.method_call.receiver->type.kind != TYPE_STRUCT || !recv_struct) { + error_add(errors, "", node->loc.line, node->loc.col, + "只有结构体类型支持方法调用"); + node->type.kind = TYPE_ERROR; break; + } + // 构造改名后的函数名并查找 + char mangled[256]; + snprintf(mangled, sizeof(mangled), "%s$%s", recv_struct, + node->as.method_call.method_name); + Symbol* sym = scope_lookup(scope, mangled); + if (!sym || sym->kind != SYM_FUNCTION) { + error_add(errors, "", node->loc.line, node->loc.col, + "结构体 '%s' 没有方法 '%s'", recv_struct, + node->as.method_call.method_name); + node->type.kind = TYPE_ERROR; break; + } + // 检查参数数量(用户提供的参数 + 隐含的 self) + if (node->as.method_call.arg_count + 1 != sym->param_count) { + error_add(errors, "", node->loc.line, node->loc.col, + "方法 '%s' 需要 %zu 个参数,提供了 %zu 个", + node->as.method_call.method_name, + sym->param_count > 0 ? sym->param_count - 1 : 0, + node->as.method_call.arg_count); + node->type.kind = TYPE_ERROR; break; + } + // 对每个参数进行类型检查(跳过 self 参数,即 sym->param_types[0] 是 self 的类型) + for (size_t i = 0; i < node->as.method_call.arg_count; i++) { + analyze_expr(node->as.method_call.args[i], scope, errors, a); + TypeKind actual = node->as.method_call.args[i]->type.kind; + TypeKind expected = sym->param_types[i + 1]; + if (actual != TYPE_ERROR && actual != expected && + !(expected == TYPE_I64 && actual == TYPE_ENUM)) { + if (expected == TYPE_STRUCT) { + // 结构体类型参数:比较具体类型名 + const char* actual_name = node->as.method_call.args[i]->type.struct_name; + const char* expected_name = sym->param_struct_names ? sym->param_struct_names[i + 1] : NULL; + if (!actual_name || !expected_name || strcmp(actual_name, expected_name) != 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + i + 1, + expected_name ? expected_name : "struct", + actual_name ? actual_name : type_name(actual)); + } + } else { + error_add(errors, "", node->loc.line, node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + i + 1, type_name(expected), type_name(actual)); + } + } + } + node->type.kind = sym->return_type; + if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) { + node->type.struct_name = sym->return_struct_type_name; + } + break; + } + default: break; } } @@ -391,6 +452,51 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* fnames, ftypes, fstruct_names, sd->as.struct_decl.field_count); } + // 处理 impl 块:将方法名改写为 StructName$methodName, + // 并自动添加 self 参数(第一个参数),然后注册为普通函数。 + // 同时将改写后的方法追加到程序 functions 数组方便后续 codegen。 + { + // 先统计需要新增多少个函数(impl 中的方法总数) + size_t extra_fn = 0; + for (size_t i = 0; i < node->as.program.impl_count; i++) { + AstNode* impl = node->as.program.impls[i]; + extra_fn += impl->as.impl_block.method_count; + } + if (extra_fn > 0) { + AstNode** new_fns = (AstNode**)arena_alloc_impl(a, + (node->as.program.fn_count + extra_fn) * sizeof(AstNode*)); + memcpy(new_fns, node->as.program.functions, + node->as.program.fn_count * sizeof(AstNode*)); + size_t write_pos = node->as.program.fn_count; + + for (size_t i = 0; i < node->as.program.impl_count; i++) { + AstNode* impl = node->as.program.impls[i]; + const char* st_name = impl->as.impl_block.struct_name; + // 验证目标结构体存在 + Symbol* st_sym = scope_lookup_struct(scope, st_name); + if (!st_sym) { + error_add(errors, "", impl->loc.line, impl->loc.col, + "impl 的目标结构体 '%s' 未定义", st_name); + continue; + } + for (size_t j = 0; j < impl->as.impl_block.method_count; j++) { + AstNode* method = impl->as.impl_block.methods[j]; + // 构造改名后的函数名 + char mangled[256]; + snprintf(mangled, sizeof(mangled), "%s$%s", st_name, + method->as.function.name); + method->as.function.name = arena_strdup_impl(a, mangled, + strlen(mangled)); + // 追加到新 functions 数组 + new_fns[write_pos++] = method; + } + } + // 更新程序节点 + node->as.program.functions = new_fns; + node->as.program.fn_count = node->as.program.fn_count + extra_fn; + } + } + // 第二遍:收集所有函数签名 for (size_t i = 0; i < node->as.program.fn_count; i++) { AstNode* fn = node->as.program.functions[i]; diff --git a/test/programs/19_struct_method.l b/test/programs/19_struct_method.l new file mode 100644 index 0000000..c6f2c8b --- /dev/null +++ b/test/programs/19_struct_method.l @@ -0,0 +1,13 @@ +struct Point { x: i64, y: i64 } + +impl Point { + fn get_x(self: Point) -> i64 { + return self.x; + } +} + +fn main() -> i64 { + let p: Point = Point { x: 42, y: 0 }; + print_i64(p.get_x()); + return 0; +} diff --git a/test/test_codegen.c b/test/test_codegen.c index 3b16466..7c6e4f4 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -13,7 +13,7 @@ void test_codegen_simple_function() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -47,7 +47,7 @@ void test_codegen_if_else() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx2 = NULL; @@ -78,7 +78,7 @@ void test_codegen_binary_ops() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx3 = NULL; @@ -108,7 +108,7 @@ void test_codegen_while_loop() { AstNode* fn_body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx4 = NULL; @@ -162,7 +162,7 @@ void test_codegen_struct_decl() { AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -217,7 +217,7 @@ void test_codegen_struct_field_access() { AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -267,7 +267,7 @@ void test_codegen_enum() { AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, enums, 1, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, enums, 1, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -331,7 +331,7 @@ void test_codegen_array() { AstNode* body = ast_make_block(&a, stmts, 4, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -348,6 +348,76 @@ void test_codegen_array() { arena_destroy(&a); } +/* === 方法调用代码生成测试 === */ + +void test_codegen_method_call() { + Arena a = arena_create(1); + + /* struct Point { x: i64, y: i64 } */ + AstNode* fields[2]; + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); + AstNode* structs[] = { struct_decl }; + + /* fn Point$get_x(self: Point) -> i64 { return self.x; } */ + AstNode* self_param = ast_make_parameter(&a, "self", TYPE_STRUCT, "Point", loc_at(1, 1)); + AstNode* params[] = { self_param }; + AstNode* self_ident = ast_make_ident(&a, "self", loc_at(1, 1)); + self_ident->type.kind = TYPE_STRUCT; + self_ident->type.struct_name = "Point"; + AstNode* field_x = ast_make_field_access(&a, self_ident, "x", loc_at(1, 1)); + field_x->as.field_access.field_index = 0; + field_x->type.kind = TYPE_I64; + AstNode* ret_body = ast_make_return(&a, field_x, loc_at(1, 1)); + AstNode* ret_stmts[] = { ret_body }; + AstNode* body = ast_make_block(&a, ret_stmts, 1, loc_at(1, 1)); + AstNode* get_x_fn = ast_make_function(&a, "Point$get_x", params, 1, TYPE_I64, NULL, body, loc_at(1, 1)); + + /* fn main() -> i64 { + let p = Point { x: 42, y: 0 }; + return p.get_x(); + } */ + const char* fnames[] = {"x", "y"}; + AstNode* fvals[] = { + ast_make_literal_i64(&a, 42, loc_at(1, 1)), + ast_make_literal_i64(&a, 0, loc_at(1, 1)) + }; + AstNode* init = ast_make_struct_init(&a, "Point", fnames, fvals, 2, loc_at(1, 1)); + init->type.kind = TYPE_STRUCT; + init->type.struct_name = "Point"; + AstNode* let_stmt = ast_make_let(&a, "p", TYPE_UNKNOWN, false, false, + init, NULL, 0, NULL, 0, loc_at(1, 1)); + + AstNode* p_ident = ast_make_ident(&a, "p", loc_at(1, 1)); + p_ident->type.kind = TYPE_STRUCT; + p_ident->type.struct_name = "Point"; + AstNode* method_call = ast_make_method_call(&a, p_ident, "get_x", NULL, 0, loc_at(1, 1)); + method_call->type.kind = TYPE_I64; + + AstNode* ret_main = ast_make_return(&a, method_call, loc_at(1, 1)); + AstNode* main_stmts[] = { let_stmt, ret_main }; + AstNode* main_body = ast_make_block(&a, main_stmts, 2, loc_at(1, 1)); + AstNode* main_fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, main_body, loc_at(1, 1)); + + AstNode* fns[] = { get_x_fn, main_fn }; + AstNode* prog = ast_make_program(&a, fns, 2, structs, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); + + const char* err = NULL; + LLVMContextRef ctx = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_method", &err, &ctx); + ASSERT(mod != NULL); + ASSERT(err == NULL); + + char* verify_err = NULL; + int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err); + ASSERT(!failed); + + LLVMDisposeModule(mod); + LLVMContextDispose(ctx); + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_codegen_simple_function); TEST_RUN(test_codegen_if_else); @@ -357,5 +427,6 @@ int main(void) { TEST_RUN(test_codegen_struct_field_access); TEST_RUN(test_codegen_enum); TEST_RUN(test_codegen_array); + TEST_RUN(test_codegen_method_call); return test_summary(); } diff --git a/test/test_sema.c b/test/test_sema.c index 7d1f78b..6c3ed17 100644 --- a/test/test_sema.c +++ b/test/test_sema.c @@ -329,6 +329,42 @@ void test_array_assign_ok() { arena_destroy(&a); } +/* === 方法调用语义分析测试 === */ + +void test_method_call_ok() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "struct Point { x: i64, y: i64 } impl Point { fn get_x(self: Point) -> i64 { return self.x; } } fn main() -> i64 { let p: Point = Point { x: 42, y: 0 }; return p.get_x(); }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count == 0); + arena_destroy(&a); +} + +void test_method_undefined() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "struct Point { x: i64, y: i64 } fn main() -> i64 { let p: Point = Point { x: 42, y: 0 }; return p.nope(); }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // nope 不是 Point 的方法 + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_type_error); TEST_RUN(test_undefined_var); @@ -349,5 +385,7 @@ int main(void) { TEST_RUN(test_array_index_type_error); TEST_RUN(test_array_not_indexable); TEST_RUN(test_array_assign_ok); + TEST_RUN(test_method_call_ok); + TEST_RUN(test_method_undefined); return test_summary(); }