From da9a7065dd81596db6da5a447d095b65afb098a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E8=88=AA=E5=AE=87?= <3364451258@qq.com> Date: Fri, 5 Jun 2026 13:29:31 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20struct=E5=8F=82=E6=95=B0/=E8=BF=94?= =?UTF-8?q?=E5=9B=9E=E5=80=BC=20+=20SourceLoc=20+=20=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E8=A1=A5=E5=85=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - struct 可作函数参数和返回值 (fn make_point -> Point) - SourceLoc 抽象: 所有 ast_make_* 参数 + AstNode 迁移完毕 - sema: +4 struct 类型检查测试 (字段类型/未定义/数量/嵌套) - codegen: +2 struct IR 生成测试 (decl + field_access) - 新增集成测试 14_struct_fn.l 测试: 104 单元 + 14 集成 = 全部通过 --- include/l_lang.h | 11 +++ src/ast/ast.c | 45 +++++----- src/ast/ast.h | 48 +++++----- src/codegen/codegen.c | 40 +++++++-- src/lexer/token.h | 5 ++ src/parser/parser.c | 92 ++++++++++--------- src/sema/sema.c | 132 +++++++++++++++++---------- src/sema/symbol.c | 9 +- src/sema/symbol.h | 5 +- test/programs/14_struct_fn.l | 19 ++++ test/test_codegen.c | 169 +++++++++++++++++++++++++++++------ test/test_sema.c | 74 +++++++++++++++ 12 files changed, 481 insertions(+), 168 deletions(-) create mode 100644 test/programs/14_struct_fn.l diff --git a/include/l_lang.h b/include/l_lang.h index 628f3b3..728eb67 100644 --- a/include/l_lang.h +++ b/include/l_lang.h @@ -29,6 +29,17 @@ static inline const char* type_name(TypeKind kind) { } } +// === 源码位置(替代分散的 line/col 参数)=== +typedef struct { + int line; + int col; +} SourceLoc; + +// 手动创建 SourceLoc +static inline SourceLoc loc_at(int line, int col) { + return (SourceLoc){ .line = line, .col = col }; +} + // === 向前声明 === typedef struct Token Token; typedef struct AstNode AstNode; diff --git a/src/ast/ast.c b/src/ast/ast.c index b985695..a9e552a 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -6,10 +6,10 @@ AstNode* n = (AstNode*)arena_alloc_impl(alloc, sizeof(AstNode)); \ if (!n) return NULL; \ n->kind = (k); n->type.kind = TYPE_UNKNOWN; n->type.struct_name = NULL; \ - n->line = line; n->col = col + n->loc = loc AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, - AstNode** structs, size_t struct_count, int line, int col) { + AstNode** structs, size_t struct_count, SourceLoc loc) { NEW(alloc, AST_PROGRAM); n->as.program.functions = fns; n->as.program.fn_count = fn_count; @@ -19,30 +19,31 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, } AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, - TypeKind ret, AstNode* body, int line, int col) { + TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc) { NEW(alloc, AST_FUNCTION); n->as.function.name = name; n->as.function.params = params; n->as.function.param_count = pcount; n->as.function.return_type = ret; + n->as.function.return_struct_type_name = ret_struct_name; n->as.function.body = body; return n; } AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, - const char* struct_type_name, int line, int col) { + const char* struct_type_name, SourceLoc loc) { NEW(alloc, AST_PARAMETER); n->as.parameter.name = name; n->as.parameter.type = type; n->as.parameter.struct_type_name = struct_type_name; return n; } -AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, int line, int col) { +AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc) { NEW(alloc, AST_BLOCK); n->as.block.stmts = stmts; n->as.block.stmt_count = count; return n; } AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, - bool is_mut, AstNode* init, const char* struct_type_name, int line, int col) { + bool is_mut, AstNode* init, const char* struct_type_name, SourceLoc loc) { NEW(alloc, AST_LET_STMT); n->as.let_stmt.name = name; n->as.let_stmt.annot_type = annot_type; n->as.let_stmt.has_type_annot = has_type_annot; n->as.let_stmt.is_mut = is_mut; @@ -51,84 +52,84 @@ AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool h return n; } -AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, int line, int col) { +AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, SourceLoc loc) { NEW(alloc, AST_ASSIGN_STMT); n->as.assign_stmt.name = name; n->as.assign_stmt.value = value; return n; } -AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, int line, int col) { +AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, SourceLoc loc) { NEW(alloc, AST_IF_STMT); n->as.if_stmt.cond = cond; n->as.if_stmt.then_block = then_b; n->as.if_stmt.else_block = else_b; return n; } -AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, int line, int col) { +AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, SourceLoc loc) { NEW(alloc, AST_WHILE_STMT); n->as.while_stmt.cond = cond; n->as.while_stmt.body = body; return n; } -AstNode* ast_make_return(void* alloc, AstNode* expr, int line, int col) { +AstNode* ast_make_return(void* alloc, AstNode* expr, SourceLoc loc) { NEW(alloc, AST_RETURN_STMT); n->as.return_stmt.expr = expr; return n; } -AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, int line, int col) { +AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, SourceLoc loc) { NEW(alloc, AST_EXPR_STMT); n->as.expr_stmt.expr = expr; return n; } -AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, int line, int col) { +AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, SourceLoc loc) { NEW(alloc, AST_BINARY_EXPR); n->as.binary.op = op; n->as.binary.left = left; n->as.binary.right = right; return n; } -AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, int line, int col) { +AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, SourceLoc loc) { NEW(alloc, AST_UNARY_EXPR); n->as.unary.op = op; n->as.unary.operand = operand; return n; } -AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, int line, int col) { +AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, SourceLoc loc) { NEW(alloc, AST_CALL_EXPR); n->as.call.name = name; n->as.call.args = args; n->as.call.arg_count = count; return n; } -AstNode* ast_make_literal_i64(void* alloc, int64_t val, int line, int col) { +AstNode* ast_make_literal_i64(void* alloc, int64_t val, SourceLoc loc) { NEW(alloc, AST_LITERAL_EXPR); n->as.literal.lit_type = TYPE_I64; n->as.literal.i64_val = val; n->type.kind = TYPE_I64; return n; } -AstNode* ast_make_literal_f64(void* alloc, double val, int line, int col) { +AstNode* ast_make_literal_f64(void* alloc, double val, SourceLoc loc) { NEW(alloc, AST_LITERAL_EXPR); n->as.literal.lit_type = TYPE_F64; n->as.literal.f64_val = val; n->type.kind = TYPE_F64; return n; } -AstNode* ast_make_literal_bool(void* alloc, bool val, int line, int col) { +AstNode* ast_make_literal_bool(void* alloc, bool val, SourceLoc loc) { NEW(alloc, AST_LITERAL_EXPR); n->as.literal.lit_type = TYPE_BOOL; n->as.literal.bool_val = val; n->type.kind = TYPE_BOOL; return n; } -AstNode* ast_make_literal_str(void* alloc, const char* val, int line, int col) { +AstNode* ast_make_literal_str(void* alloc, const char* val, SourceLoc loc) { NEW(alloc, AST_LITERAL_EXPR); n->as.literal.lit_type = TYPE_STR; n->as.literal.str_val = val; n->type.kind = TYPE_STR; return n; } -AstNode* ast_make_ident(void* alloc, const char* name, int line, int col) { +AstNode* ast_make_ident(void* alloc, const char* name, SourceLoc loc) { NEW(alloc, AST_IDENT_EXPR); n->as.ident.name = name; return n; @@ -137,7 +138,7 @@ AstNode* ast_make_ident(void* alloc, const char* name, int line, int col) { // === 结构体相关工厂函数 === AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, - size_t count, int line, int col) { + size_t count, SourceLoc loc) { NEW(alloc, AST_STRUCT_DECL); n->as.struct_decl.name = name; n->as.struct_decl.fields = fields; @@ -147,7 +148,7 @@ AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** fnames, AstNode** fvals, - size_t count, int line, int col) { + size_t count, SourceLoc loc) { NEW(alloc, AST_STRUCT_INIT); n->as.struct_init.type_name = type_name; n->as.struct_init.field_names = fnames; @@ -157,7 +158,7 @@ AstNode* ast_make_struct_init(void* alloc, const char* type_name, } AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, - int line, int col) { + SourceLoc loc) { NEW(alloc, AST_FIELD_ACCESS); n->as.field_access.object = object; n->as.field_access.field = field; diff --git a/src/ast/ast.h b/src/ast/ast.h index 285bc1c..fd147ad 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -42,8 +42,7 @@ typedef struct { struct AstNode { AstKind kind; TypeInfo type; // 语义分析后填充,默认为 TYPE_UNKNOWN - int line; // 源文件行号 - int col; // 源文件列号 + SourceLoc loc; // 源码位置 // 节点特有数据(按 kind 解释) union { @@ -52,7 +51,8 @@ struct AstNode { struct AstNode** structs; size_t struct_count; } program; // AST_FUNCTION struct { const char* name; struct AstNode** params; size_t param_count; - TypeKind return_type; struct AstNode* body; } function; + TypeKind return_type; const char* return_struct_type_name; + struct AstNode* body; } function; // AST_PARAMETER (也用作结构体字段: name + type) struct { const char* name; TypeKind type; const char* struct_type_name; } parameter; // AST_BLOCK @@ -92,28 +92,28 @@ struct AstNode { // 创建节点的辅助函数(内存来自 arena,通过 void* 传递避免循环依赖) AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, - AstNode** structs, size_t struct_count, int line, int col); + AstNode** structs, size_t struct_count, SourceLoc loc); AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, - TypeKind ret, AstNode* body, int line, int col); -AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, int line, int col); -AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, int line, int col); + TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc); +AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc); +AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc); AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, - bool is_mut, AstNode* init, const char* struct_type_name, int line, int col); -AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, int line, int col); -AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, int line, int col); -AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, int line, int col); -AstNode* ast_make_return(void* alloc, AstNode* expr, int line, int col); -AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, int line, int col); -AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, int line, int col); -AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, int line, int col); -AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, int line, int col); -AstNode* ast_make_literal_i64(void* alloc, int64_t val, int line, int col); -AstNode* ast_make_literal_f64(void* alloc, double val, int line, int col); -AstNode* ast_make_literal_bool(void* alloc, bool val, int line, int col); -AstNode* ast_make_literal_str(void* alloc, const char* val, int line, int col); -AstNode* ast_make_ident(void* alloc, const char* name, int line, int col); -AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, size_t count, int line, int col); -AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** fnames, AstNode** fvals, size_t count, int line, int col); -AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, int line, int col); + bool is_mut, AstNode* init, const char* struct_type_name, SourceLoc loc); +AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, SourceLoc loc); +AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, SourceLoc loc); +AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, SourceLoc loc); +AstNode* ast_make_return(void* alloc, AstNode* expr, SourceLoc loc); +AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, SourceLoc loc); +AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, SourceLoc loc); +AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, SourceLoc loc); +AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, SourceLoc loc); +AstNode* ast_make_literal_i64(void* alloc, int64_t val, SourceLoc loc); +AstNode* ast_make_literal_f64(void* alloc, double val, SourceLoc loc); +AstNode* ast_make_literal_bool(void* alloc, bool val, SourceLoc loc); +AstNode* ast_make_literal_str(void* alloc, const char* val, SourceLoc loc); +AstNode* ast_make_ident(void* alloc, const char* name, SourceLoc loc); +AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, size_t count, SourceLoc loc); +AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** fnames, AstNode** fvals, size_t count, SourceLoc loc); +AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, SourceLoc loc); #endif diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 4ff32c6..58ac2e0 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -608,10 +608,23 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, AstNode* fn = ast->as.program.functions[i]; LLVMTypeRef* ptypes = arena_alloc(ctx.arena, fn->as.function.param_count * sizeof(LLVMTypeRef)); - for (size_t j = 0; j < fn->as.function.param_count; j++) - ptypes[j] = to_llvm_type(&ctx, fn->as.function.params[j]->as.parameter.type); - LLVMTypeRef fty = LLVMFunctionType( - to_llvm_type(&ctx, fn->as.function.return_type), + for (size_t j = 0; j < fn->as.function.param_count; j++) { + AstNode* param = fn->as.function.params[j]; + if (param->as.parameter.type == TYPE_STRUCT && + param->as.parameter.struct_type_name) { + ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name); + } else { + ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type); + } + } + LLVMTypeRef ret_ty; + if (fn->as.function.return_type == TYPE_STRUCT && + fn->as.function.return_struct_type_name) { + ret_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name); + } else { + ret_ty = to_llvm_type(&ctx, fn->as.function.return_type); + } + LLVMTypeRef fty = LLVMFunctionType(ret_ty, ptypes, (unsigned)fn->as.function.param_count, false); LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty); add_fn(&ctx, fn->as.function.name, lfn); @@ -630,11 +643,18 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, // 将参数注册为变量 for (size_t j = 0; j < fn->as.function.param_count; j++) { LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j); + AstNode* pnode = fn->as.function.params[j]; + LLVMTypeRef param_ty; + if (pnode->as.parameter.type == TYPE_STRUCT && + pnode->as.parameter.struct_type_name) { + param_ty = find_struct_type(&ctx, pnode->as.parameter.struct_type_name); + } else { + param_ty = to_llvm_type(&ctx, pnode->as.parameter.type); + } LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder, - to_llvm_type(&ctx, fn->as.function.params[j]->as.parameter.type), - fn->as.function.params[j]->as.parameter.name); + param_ty, pnode->as.parameter.name); LLVMBuildStore(ctx.builder, param, alloca); - add_var(&ctx, fn->as.function.params[j]->as.parameter.name, alloca); + add_var(&ctx, pnode->as.parameter.name, alloca); } codegen_stmt(&ctx, fn->as.function.body); @@ -645,6 +665,12 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, cleanup_emit(&ctx, 0); if (fn->as.function.return_type == TYPE_VOID) LLVMBuildRetVoid(ctx.builder); + else if (fn->as.function.return_type == TYPE_STRUCT && + fn->as.function.return_struct_type_name) { + LLVMTypeRef st_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name); + LLVMBuildRet(ctx.builder, st_ty ? LLVMConstNull(st_ty) : + LLVMConstInt(to_llvm_type(&ctx, TYPE_I64), 0, false)); + } else LLVMBuildRet(ctx.builder, (fn->as.function.return_type == TYPE_F64 diff --git a/src/lexer/token.h b/src/lexer/token.h index be14be4..6b15264 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -41,6 +41,11 @@ struct Token { const char* tok_name(TokenKind kind); bool tok_is_type(TokenKind kind); +// 从 Token 创建 SourceLoc +static inline SourceLoc tok_loc(const Token* tok) { + return (SourceLoc){ .line = tok->line, .col = tok->col }; +} + // 从 Token 提取值 int64_t tok_int_value(const Token* tok); double tok_float_value(const Token* tok); diff --git a/src/parser/parser.c b/src/parser/parser.c index 559bf4a..ce1bb3e 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -78,7 +78,7 @@ static AstNode* parse_unary(Parser* p, ErrorInfo* error) { AstNode* operand = parse_expr_prec(p, PREC_UNARY, error); if (!operand) return NULL; BinaryOp uop = (op->kind == TOK_MINUS) ? OP_NEG : OP_NOT; - return ast_make_unary(p->arena, uop, operand, op->line, op->col); + return ast_make_unary(p->arena, uop, operand, tok_loc(op)); } static AstNode* parse_group(Parser* p, ErrorInfo* error) { @@ -92,15 +92,15 @@ static AstNode* parse_group(Parser* p, ErrorInfo* error) { static AstNode* parse_literal(Parser* p) { const Token* t = advance(p); switch (t->kind) { - case TOK_INT_LIT: return ast_make_literal_i64(p->arena, tok_int_value(t), t->line, t->col); - case TOK_FLOAT_LIT: return ast_make_literal_f64(p->arena, tok_float_value(t), t->line, t->col); - case TOK_TRUE: return ast_make_literal_bool(p->arena, true, t->line, t->col); - case TOK_FALSE: return ast_make_literal_bool(p->arena, false, t->line, t->col); + case TOK_INT_LIT: return ast_make_literal_i64(p->arena, tok_int_value(t), tok_loc(t)); + case TOK_FLOAT_LIT: return ast_make_literal_f64(p->arena, tok_float_value(t), tok_loc(t)); + case TOK_TRUE: return ast_make_literal_bool(p->arena, true, tok_loc(t)); + case TOK_FALSE: return ast_make_literal_bool(p->arena, false, tok_loc(t)); case TOK_STR_LIT: { char* str = arena_alloc_impl(p->arena, t->length + 1); memcpy(str, t->start, t->length); str[t->length] = '\0'; - return ast_make_literal_str(p->arena, str, t->line, t->col); + return ast_make_literal_str(p->arena, str, tok_loc(t)); } default: return NULL; } @@ -137,7 +137,7 @@ static AstNode* parse_struct_init(Parser* p, const Token* name, ErrorInfo* error return ast_make_struct_init(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - n_arr, v_arr, fcount, name->line, name->col); + n_arr, v_arr, fcount, tok_loc(name)); } // === 标识符 / 函数调用 / 结构体初始化 === @@ -175,11 +175,11 @@ static AstNode* parse_ident_or_call(Parser* p, ErrorInfo* error) { AstNode** arg_arr = arena_alloc_impl(p->arena, arg_count * sizeof(AstNode*)); memcpy(arg_arr, args, arg_count * sizeof(AstNode*)); return ast_make_call(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - arg_arr, arg_count, name->line, name->col); + arg_arr, arg_count, tok_loc(name)); } return ast_make_ident(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - name->line, name->col); + tok_loc(name)); } // === Pratt 主循环 === @@ -216,7 +216,7 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error if (!field) return NULL; left = ast_make_field_access(p->arena, left, arena_strdup_impl(p->arena, field->start, field->length), - field->line, field->col); + tok_loc(field)); continue; } @@ -227,7 +227,7 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error const Token* op = advance(p); AstNode* right = parse_expr_prec(p, prec, error); if (!right) return NULL; - left = ast_make_binary(p->arena, tok_to_binop(kind), left, right, op->line, op->col); + left = ast_make_binary(p->arena, tok_to_binop(kind), left, right, tok_loc(op)); } return left; @@ -271,7 +271,7 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) { } fields[fcount++] = ast_make_parameter(p->arena, arena_strdup_impl(p->arena, fname->start, fname->length), - field_kind, field_struct_name, fname->line, fname->col); + field_kind, field_struct_name, tok_loc(fname)); if (peek(p)->kind == TOK_COMMA) advance(p); else break; } @@ -281,7 +281,7 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) { memcpy(farr, fields, fcount * sizeof(AstNode*)); return ast_make_struct_decl(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - farr, fcount, s_tok->line, s_tok->col); + farr, fcount, tok_loc(s_tok)); } // === 语句解析 === @@ -305,7 +305,7 @@ static AstNode* parse_block(Parser* p, ErrorInfo* error) { AstNode** arr = arena_alloc_impl(p->arena, count * sizeof(AstNode*)); memcpy(arr, stmts, count * sizeof(AstNode*)); parse_depth--; - return ast_make_block(p->arena, arr, count, open->line, open->col); + return ast_make_block(p->arena, arr, count, tok_loc(open)); } static AstNode* parse_statement(Parser* p, ErrorInfo* error) { @@ -342,7 +342,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; return ast_make_let(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - annot_type, has_type_annot, is_mut, init, struct_type_name, t->line, t->col); + annot_type, has_type_annot, is_mut, init, struct_type_name, tok_loc(t)); } if (t->kind == TOK_IF) { @@ -360,7 +360,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { } if (!else_block) return NULL; } - return ast_make_if(p->arena, cond, then_block, else_block, t->line, t->col); + return ast_make_if(p->arena, cond, then_block, else_block, tok_loc(t)); } if (t->kind == TOK_WHILE) { @@ -369,7 +369,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { if (!cond) return NULL; AstNode* body = parse_block(p, error); if (!body) return NULL; - return ast_make_while(p->arena, cond, body, t->line, t->col); + return ast_make_while(p->arena, cond, body, tok_loc(t)); } if (t->kind == TOK_FOR) { @@ -403,20 +403,20 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { const char* vname = arena_strdup_impl(p->arena, var_name->start, var_name->length); // 构建: let mut i = start; - AstNode* let_stmt = ast_make_let(p->arena, vname, TYPE_UNKNOWN, false, true, start_expr, NULL, var_name->line, var_name->col); + AstNode* let_stmt = ast_make_let(p->arena, vname, TYPE_UNKNOWN, false, true, start_expr, NULL, tok_loc(var_name)); // 构建: i < end (while 条件) AstNode* cond = ast_make_binary(p->arena, OP_LT, - ast_make_ident(p->arena, vname, var_name->line, var_name->col), - end_expr, var_name->line, var_name->col); + ast_make_ident(p->arena, vname, tok_loc(var_name)), + end_expr, tok_loc(var_name)); // 构建: i = i + 1 (循环增量) AstNode* incr = ast_make_assign(p->arena, vname, ast_make_binary(p->arena, OP_ADD, - ast_make_ident(p->arena, vname, var_name->line, var_name->col), - ast_make_literal_i64(p->arena, 1, var_name->line, var_name->col), - var_name->line, var_name->col), - var_name->line, var_name->col); + ast_make_ident(p->arena, vname, tok_loc(var_name)), + ast_make_literal_i64(p->arena, 1, tok_loc(var_name)), + tok_loc(var_name)), + tok_loc(var_name)); // 将增量追加到循环体末尾 AstNode** new_stmts = arena_alloc_impl(p->arena, @@ -424,27 +424,27 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { memcpy(new_stmts, body->as.block.stmts, body->as.block.stmt_count * sizeof(AstNode*)); new_stmts[body->as.block.stmt_count] = incr; AstNode* new_body = ast_make_block(p->arena, new_stmts, - body->as.block.stmt_count + 1, body->line, body->col); + body->as.block.stmt_count + 1, body->loc); // 构建: while i < end { ... body ... ; i = i + 1; } - AstNode* while_loop = ast_make_while(p->arena, cond, new_body, t->line, t->col); + AstNode* while_loop = ast_make_while(p->arena, cond, new_body, tok_loc(t)); // 包装: { let mut i = start; while i < end { ... } } AstNode* stmts_arr[2] = { let_stmt, while_loop }; AstNode** stmts = arena_alloc_impl(p->arena, 2 * sizeof(AstNode*)); memcpy(stmts, stmts_arr, 2 * sizeof(AstNode*)); - return ast_make_block(p->arena, stmts, 2, t->line, t->col); + return ast_make_block(p->arena, stmts, 2, tok_loc(t)); } if (t->kind == TOK_RETURN) { advance(p); if (match(p, TOK_SEMICOLON)) { - return ast_make_return(p->arena, NULL, t->line, t->col); + return ast_make_return(p->arena, NULL, tok_loc(t)); } AstNode* expr = parse_expr(p, error); if (!expr) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; - return ast_make_return(p->arena, expr, t->line, t->col); + return ast_make_return(p->arena, expr, tok_loc(t)); } // 赋值语句: ident = expr ; @@ -456,7 +456,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; return ast_make_assign(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - value, name->line, name->col); + value, tok_loc(name)); } // 复合赋值: ident += expr → ident = ident + expr @@ -481,12 +481,12 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { AstNode* lhs_ident = ast_make_ident(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - name->line, name->col); + tok_loc(name)); AstNode* bin_expr = ast_make_binary(p->arena, binop, lhs_ident, rhs, - name->line, name->col); + tok_loc(name)); return ast_make_assign(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - bin_expr, name->line, name->col); + bin_expr, tok_loc(name)); } } @@ -494,7 +494,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) { AstNode* expr = parse_expr(p, error); if (!expr) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; - return ast_make_expr_stmt(p->arena, expr, t->line, t->col); + return ast_make_expr_stmt(p->arena, expr, tok_loc(t)); } // === 函数解析 === @@ -512,13 +512,20 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) { if (!pname) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; const Token* ptype = advance(p); - if (!tok_is_type(ptype->kind)) { + TypeKind param_kind; + const char* param_struct_name = NULL; + if (tok_is_type(ptype->kind)) { + param_kind = token_to_type(ptype->kind); + } else if (ptype->kind == TOK_IDENT) { + param_kind = TYPE_STRUCT; + param_struct_name = arena_strdup_impl(p->arena, ptype->start, ptype->length); + } else { error->message = "无效的参数类型"; error->filename = p->filename; error->line = ptype->line; error->col = ptype->col; return NULL; } params[pcount++] = ast_make_parameter(p->arena, arena_strdup_impl(p->arena, pname->start, pname->length), - token_to_type(ptype->kind), NULL, pname->line, pname->col); + param_kind, param_struct_name, tok_loc(pname)); if (match(p, TOK_COMMA)) continue; else break; } @@ -526,13 +533,18 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) { // 返回类型 TypeKind ret = TYPE_VOID; + const char* ret_struct_name = NULL; if (match(p, TOK_ARROW)) { const Token* rt = advance(p); - if (!tok_is_type(rt->kind)) { + if (tok_is_type(rt->kind)) { + ret = token_to_type(rt->kind); + } else if (rt->kind == TOK_IDENT) { + ret = TYPE_STRUCT; + ret_struct_name = arena_strdup_impl(p->arena, rt->start, rt->length); + } else { error->message = "无效的返回类型"; error->filename = p->filename; error->line = rt->line; error->col = rt->col; return NULL; } - ret = token_to_type(rt->kind); } AstNode* body = parse_block(p, error); @@ -542,7 +554,7 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) { memcpy(parr, params, pcount * sizeof(AstNode*)); return ast_make_function(p->arena, arena_strdup_impl(p->arena, name->start, name->length), - parr, pcount, ret, body, fn_tok->line, fn_tok->col); + parr, pcount, ret, ret_struct_name, body, tok_loc(fn_tok)); } // === 程序入口 === @@ -572,5 +584,5 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, memcpy(fn_arr, functions, fn_count * sizeof(AstNode*)); AstNode** st_arr = arena_alloc_impl(a, struct_count * sizeof(AstNode*)); memcpy(st_arr, structs, struct_count * sizeof(AstNode*)); - return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, 0, 0); + return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, loc_at(0, 0)); } diff --git a/src/sema/sema.c b/src/sema/sema.c index 3ad1548..a0d2a59 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -3,6 +3,7 @@ // === 类型关系 === static TypeKind current_return_type = TYPE_VOID; +static const char* current_return_struct_name = NULL; static TypeKind promote(TypeKind a, TypeKind b) { if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; @@ -26,11 +27,11 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* case AST_IDENT_EXPR: { Symbol* sym = scope_lookup(scope, node->as.ident.name); if (!sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的变量 '%s'", node->as.ident.name); node->type.kind = TYPE_ERROR; } else if (sym->kind == SYM_FUNCTION) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "'%s' 是函数,不能作为表达式使用", node->as.ident.name); node->type.kind = TYPE_ERROR; } else { @@ -47,7 +48,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* TypeKind inner = node->as.unary.operand->type.kind; if (node->as.unary.op == OP_NEG) { if (!is_numeric(inner)) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "一元 '-' 只能用于数值类型"); node->type.kind = TYPE_ERROR; } else { @@ -55,7 +56,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* } } else { // OP_NOT if (inner != TYPE_BOOL) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "'!' 只能用于布尔类型,得到 '%s'", type_name(inner)); node->type.kind = TYPE_ERROR; } else { @@ -77,7 +78,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* if (l == TYPE_STR || r == TYPE_STR) { // 字符串拼接:两边都必须是 str 类型 if (l != TYPE_STR || r != TYPE_STR) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'", type_name(l), type_name(r)); node->type.kind = TYPE_ERROR; @@ -85,7 +86,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* node->type.kind = TYPE_STR; } } else if (!is_numeric(l) || !is_numeric(r)) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); node->type.kind = TYPE_ERROR; } else { @@ -94,7 +95,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD: if (!is_numeric(l) || !is_numeric(r)) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "算术运算需要数值类型"); node->type.kind = TYPE_ERROR; } else { @@ -103,7 +104,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE: if (!is_comparable(l, r)) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r)); node->type.kind = TYPE_ERROR; } else { @@ -112,7 +113,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; case OP_AND: case OP_OR: if (l != TYPE_BOOL || r != TYPE_BOOL) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "逻辑运算需要布尔类型"); node->type.kind = TYPE_ERROR; } else { @@ -127,7 +128,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* case AST_CALL_EXPR: { Symbol* sym = scope_lookup(scope, node->as.call.name); if (!sym || sym->kind != SYM_FUNCTION) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的函数 '%s'", node->as.call.name); node->type.kind = TYPE_ERROR; // 即使函数未定义,也要分析参数表达式(它们可能有更多错误) @@ -137,7 +138,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; } if (node->as.call.arg_count != sym->param_count) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个", node->as.call.name, sym->param_count, node->as.call.arg_count); node->type.kind = TYPE_ERROR; @@ -151,13 +152,30 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* analyze_expr(node->as.call.args[i], scope, errors, a); TypeKind actual = node->as.call.args[i]->type.kind; TypeKind expected = sym->param_types[i]; - if (actual != TYPE_ERROR && actual != expected) { - error_add(errors, "", node->line, node->col, - "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", - i + 1, type_name(expected), type_name(actual)); + if (actual != TYPE_ERROR) { + if (expected == TYPE_STRUCT) { + // 结构体参数:比较具体类型名 + const char* actual_name = node->as.call.args[i]->type.struct_name; + const char* expected_name = sym->param_struct_names ? sym->param_struct_names[i] : NULL; + if (actual != TYPE_STRUCT || !actual_name || !expected_name || + strcmp(actual_name, expected_name) != 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + i + 1, + expected_name ? expected_name : "struct", + actual_name ? actual_name : type_name(actual)); + } + } else if (actual != expected) { + error_add(errors, "", node->loc.line, node->loc.col, + "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", + i + 1, type_name(expected), type_name(actual)); + } } } node->type.kind = sym->return_type; + if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) { + node->type.struct_name = sym->return_struct_type_name; + } break; } @@ -169,7 +187,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; } if (obj->type.kind != TYPE_STRUCT) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "类型 '%s' 不是结构体,不能访问字段 '%s'", type_name(obj->type.kind), node->as.field_access.field); node->type.kind = TYPE_ERROR; @@ -178,21 +196,21 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* // 查找结构体定义 const char* struct_name = obj->type.struct_name; if (!struct_name) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "无法确定结构体类型"); node->type.kind = TYPE_ERROR; break; } Symbol* struct_sym = scope_lookup_struct(scope, struct_name); if (!struct_sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的结构体 '%s'", struct_name); node->type.kind = TYPE_ERROR; break; } int fi = scope_struct_field_index(struct_sym, node->as.field_access.field); if (fi < 0) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field); node->type.kind = TYPE_ERROR; break; @@ -211,13 +229,13 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* case AST_STRUCT_INIT: { Symbol* struct_sym = scope_lookup_struct(scope, node->as.struct_init.type_name); if (!struct_sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的结构体类型 '%s'", node->as.struct_init.type_name); node->type.kind = TYPE_ERROR; break; } if (node->as.struct_init.field_count != struct_sym->struct_field_count) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个", node->as.struct_init.type_name, struct_sym->struct_field_count, @@ -233,7 +251,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* int fi = scope_struct_field_index(struct_sym, fname); if (fi < 0) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "结构体 '%s' 没有字段 '%s'", node->as.struct_init.type_name, fname); node->type.kind = TYPE_ERROR; @@ -242,7 +260,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* TypeKind expected = struct_sym->struct_field_types[fi]; TypeKind actual = fval->type.kind; if (actual != TYPE_ERROR && actual != expected) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'", fname, type_name(expected), type_name(actual)); } @@ -285,11 +303,15 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* for (size_t i = 0; i < node->as.program.fn_count; i++) { AstNode* fn = node->as.program.functions[i]; TypeKind* pts = (TypeKind*)arena_alloc_impl(a, fn->as.function.param_count * sizeof(TypeKind)); + const char** pstruct_names = (const char**)arena_alloc_impl(a, fn->as.function.param_count * sizeof(const char*)); for (size_t j = 0; j < fn->as.function.param_count; j++) { pts[j] = fn->as.function.params[j]->as.parameter.type; + pstruct_names[j] = fn->as.function.params[j]->as.parameter.struct_type_name; } scope_insert_function(scope, a, fn->as.function.name, - fn->as.function.return_type, pts, + fn->as.function.return_type, + fn->as.function.return_struct_type_name, + pts, pstruct_names, fn->as.function.param_count); } // 第三遍:分析每个函数体 @@ -303,12 +325,18 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* // 注册参数 for (size_t i = 0; i < node->as.function.param_count; i++) { AstNode* p = node->as.function.params[i]; - scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); + Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); + if (sym && p->as.parameter.type == TYPE_STRUCT && p->as.parameter.struct_type_name) { + sym->struct_type_name = p->as.parameter.struct_type_name; + } } TypeKind saved = current_return_type; + const char* saved_name = current_return_struct_name; current_return_type = node->as.function.return_type; + current_return_struct_name = node->as.function.return_struct_type_name; analyze_node(node->as.function.body, fn_scope, errors, a); current_return_type = saved; + current_return_struct_name = saved_name; break; } @@ -330,7 +358,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* // struct 类型标注 Symbol* st_sym = scope_lookup_struct(scope, annot_struct); if (!st_sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的结构体类型 '%s'", annot_struct); break; } @@ -340,7 +368,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* var_type = node->as.let_stmt.annot_type; } if (inferred != TYPE_ERROR && inferred != var_type) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "变量 '%s' 类型标注为 '%s',但初始化表达式类型为 '%s'", node->as.let_stmt.name, annot_struct ? annot_struct : type_name(var_type), @@ -349,7 +377,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* } else { // 类型推断 if (inferred == TYPE_ERROR || inferred == TYPE_VOID) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "无法从表达式推断变量 '%s' 的类型", node->as.let_stmt.name); break; } @@ -363,7 +391,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* node->type.struct_name = var_struct_name; Symbol* sym = scope_insert(scope, a, node->as.let_stmt.name, SYM_VARIABLE, var_type); if (!sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "变量 '%s' 重复定义", node->as.let_stmt.name); } else { sym->is_mut = node->as.let_stmt.is_mut; @@ -378,19 +406,19 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* case AST_ASSIGN_STMT: { Symbol* sym = scope_lookup(scope, node->as.assign_stmt.name); if (!sym) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "未定义的变量 '%s'", node->as.assign_stmt.name); node->type.kind = TYPE_ERROR; break; } if (sym->kind != SYM_VARIABLE) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "'%s' 不是变量,不能赋值", node->as.assign_stmt.name); node->type.kind = TYPE_ERROR; break; } if (!sym->is_mut) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "不能对不可变变量 '%s' 赋值(需用 let mut 声明)", node->as.assign_stmt.name); node->type.kind = TYPE_ERROR; @@ -399,7 +427,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* analyze_expr(node->as.assign_stmt.value, scope, errors, a); TypeKind value_ty = node->as.assign_stmt.value->type.kind; if (value_ty != TYPE_ERROR && value_ty != sym->type) { - error_add(errors, "", node->line, node->col, + error_add(errors, "", node->loc.line, node->loc.col, "赋值类型不匹配: 变量 '%s' 类型为 '%s',但表达式类型为 '%s'", node->as.assign_stmt.name, type_name(sym->type), type_name(value_ty)); } @@ -411,7 +439,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* analyze_expr(node->as.if_stmt.cond, scope, errors, a); if (node->as.if_stmt.cond->type.kind != TYPE_BOOL && node->as.if_stmt.cond->type.kind != TYPE_ERROR) { - error_add(errors, "", node->line, node->col, "if 条件必须是布尔类型"); + error_add(errors, "", node->loc.line, node->loc.col, "if 条件必须是布尔类型"); } analyze_node(node->as.if_stmt.then_block, scope, errors, a); if (node->as.if_stmt.else_block) { @@ -423,7 +451,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* analyze_expr(node->as.while_stmt.cond, scope, errors, a); if (node->as.while_stmt.cond->type.kind != TYPE_BOOL && node->as.while_stmt.cond->type.kind != TYPE_ERROR) { - error_add(errors, "", node->line, node->col, "while 条件必须是布尔类型"); + error_add(errors, "", node->loc.line, node->loc.col, "while 条件必须是布尔类型"); } analyze_node(node->as.while_stmt.body, scope, errors, a); break; @@ -434,14 +462,28 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* node->type.kind = node->as.return_stmt.expr->type.kind; TypeKind actual = node->as.return_stmt.expr->type.kind; TypeKind expected = current_return_type; - if (actual != TYPE_ERROR && expected != TYPE_VOID && actual != expected) { - error_add(errors, "", node->line, node->col, - "返回类型不匹配: 期望 '%s',得到 '%s'", - type_name(expected), type_name(actual)); + if (actual != TYPE_ERROR && expected != TYPE_VOID) { + if (expected == TYPE_STRUCT) { + // 结构体返回类型:比较具体类型名 + const char* actual_name = node->as.return_stmt.expr->type.struct_name; + const char* expected_name = current_return_struct_name; + if (actual != TYPE_STRUCT || !actual_name || !expected_name || + strcmp(actual_name, expected_name) != 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "返回类型不匹配: 期望 '%s',得到 '%s'", + expected_name ? expected_name : "struct", + actual_name ? actual_name : type_name(actual)); + } + } else if (actual != expected) { + error_add(errors, "", node->loc.line, node->loc.col, + "返回类型不匹配: 期望 '%s',得到 '%s'", + type_name(expected), type_name(actual)); + } } } else if (current_return_type != TYPE_VOID) { - error_add(errors, "", node->line, node->col, - "函数应返回值类型 '%s'", type_name(current_return_type)); + error_add(errors, "", node->loc.line, node->loc.col, + "函数应返回值类型 '%s'", + current_return_struct_name ? current_return_struct_name : type_name(current_return_type)); } break; @@ -460,13 +502,13 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) { // 注册内置函数 TypeKind params_i64[] = {TYPE_I64}; - scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, params_i64, 1); + scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, NULL, params_i64, NULL, 1); TypeKind params_f64[] = {TYPE_F64}; - scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, params_f64, 1); + scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, NULL, params_f64, NULL, 1); TypeKind params_bool[] = {TYPE_BOOL}; - scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, params_bool, 1); + scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, NULL, params_bool, NULL, 1); TypeKind params_str[] = {TYPE_STR}; - scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, params_str, 1); + scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, 1); analyze_node(ast, global_scope, errors, arena); } diff --git a/src/sema/symbol.c b/src/sema/symbol.c index 7bccd92..4219512 100644 --- a/src/sema/symbol.c +++ b/src/sema/symbol.c @@ -41,7 +41,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name, } Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, - TypeKind ret, TypeKind* pt, size_t pc) { + TypeKind ret, const char* ret_struct_name, + TypeKind* pt, const char** pstruct_names, size_t pc) { if (scope->head) { for (Symbol* sym = scope->head; sym; sym = sym->next) { if (strcmp(sym->name, name) == 0) return NULL; @@ -50,7 +51,11 @@ Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, Symbol* sym = (Symbol*)arena_alloc_impl(alloc, sizeof(Symbol)); if (!sym) return NULL; sym->name = name; sym->kind = SYM_FUNCTION; sym->type = TYPE_VOID; - sym->return_type = ret; sym->param_types = pt; sym->param_count = pc; + sym->return_type = ret; + sym->return_struct_type_name = ret_struct_name; + sym->param_types = pt; + sym->param_struct_names = pstruct_names; + sym->param_count = pc; sym->struct_field_names = NULL; sym->struct_field_types = NULL; sym->struct_field_count = 0; diff --git a/src/sema/symbol.h b/src/sema/symbol.h index 7454c15..d02a4b4 100644 --- a/src/sema/symbol.h +++ b/src/sema/symbol.h @@ -13,7 +13,9 @@ typedef struct Symbol { bool is_mut; // 变量是否可变(可被赋值) // 函数特有 TypeKind return_type; + const char* return_struct_type_name; // 返回类型为 struct 时的类型名 TypeKind* param_types; + const char** param_struct_names; // 参数为 struct 时的类型名 size_t param_count; // 结构体特有(SYM_STRUCT) const char** struct_field_names; @@ -43,7 +45,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name, // 插入函数符号 Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, - TypeKind ret, TypeKind* pt, size_t pc); + TypeKind ret, const char* ret_struct_name, + TypeKind* pt, const char** pstruct_names, size_t pc); // 插入结构体符号 Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name, diff --git a/test/programs/14_struct_fn.l b/test/programs/14_struct_fn.l new file mode 100644 index 0000000..c8422d3 --- /dev/null +++ b/test/programs/14_struct_fn.l @@ -0,0 +1,19 @@ +struct Point { + x: i64, + y: i64, +} + +fn make_point(x: i64, y: i64) -> Point { + return Point { x: x, y: y }; +} + +fn print_point(p: Point) -> void { + print_i64(p.x); + print_i64(p.y); +} + +fn main() -> i64 { + let p: Point = make_point(3, 4); + print_point(p); + return 0; +} diff --git a/test/test_codegen.c b/test/test_codegen.c index 0a386ea..f20789d 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -8,12 +8,12 @@ void test_codegen_simple_function() { Arena a = arena_create(1); // 构造 AST: fn main() -> i64 { return 42; } - AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 42, 1, 1), 1, 1); + AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 42, loc_at(1, 1)), loc_at(1, 1)); AstNode* stmts[] = { ret }; - AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); + AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -35,19 +35,19 @@ void test_codegen_if_else() { Arena a = arena_create(1); // fn main() -> i64 { if true { return 1; } else { return 0; } } - AstNode* then_ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, 1, 1), 1, 1); + AstNode* then_ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, loc_at(1, 1)), loc_at(1, 1)); AstNode* then_stmts[] = { then_ret }; - AstNode* then_block = ast_make_block(&a, then_stmts, 1, 1, 1); - AstNode* else_ret = ast_make_return(&a, ast_make_literal_i64(&a, 0, 1, 1), 1, 1); + AstNode* then_block = ast_make_block(&a, then_stmts, 1, loc_at(1, 1)); + AstNode* else_ret = ast_make_return(&a, ast_make_literal_i64(&a, 0, loc_at(1, 1)), loc_at(1, 1)); AstNode* else_stmts[] = { else_ret }; - AstNode* else_block = ast_make_block(&a, else_stmts, 1, 1, 1); + AstNode* else_block = ast_make_block(&a, else_stmts, 1, loc_at(1, 1)); AstNode* if_stmt = ast_make_if(&a, - ast_make_literal_bool(&a, true, 1, 1), then_block, else_block, 1, 1); + ast_make_literal_bool(&a, true, loc_at(1, 1)), then_block, else_block, loc_at(1, 1)); AstNode* stmts[] = { if_stmt }; - AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); + AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx2 = NULL; @@ -68,17 +68,17 @@ void test_codegen_binary_ops() { // fn main() -> i64 { return 1 + 2 * 3; } AstNode* expr = ast_make_binary(&a, OP_ADD, - ast_make_literal_i64(&a, 1, 1, 1), + ast_make_literal_i64(&a, 1, loc_at(1, 1)), ast_make_binary(&a, OP_MUL, - ast_make_literal_i64(&a, 2, 1, 1), - ast_make_literal_i64(&a, 3, 1, 1), 1, 1), - 1, 1); - AstNode* ret = ast_make_return(&a, expr, 1, 1); + ast_make_literal_i64(&a, 2, loc_at(1, 1)), + ast_make_literal_i64(&a, 3, loc_at(1, 1)), loc_at(1, 1)), + loc_at(1, 1)); + AstNode* ret = ast_make_return(&a, expr, loc_at(1, 1)); AstNode* stmts[] = { ret }; - AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); + AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx3 = NULL; @@ -98,17 +98,17 @@ void test_codegen_while_loop() { Arena a = arena_create(1); // fn main() -> i64 { while true { return 0; } return 1; } AstNode* while_body_stmts[] = { - ast_make_return(&a, ast_make_literal_i64(&a, 0, 1, 1), 1, 1) + ast_make_return(&a, ast_make_literal_i64(&a, 0, loc_at(1, 1)), loc_at(1, 1)) }; - AstNode* while_body = ast_make_block(&a, while_body_stmts, 1, 1, 1); + AstNode* while_body = ast_make_block(&a, while_body_stmts, 1, loc_at(1, 1)); AstNode* while_stmt = ast_make_while(&a, - ast_make_literal_bool(&a, true, 1, 1), while_body, 1, 1); - AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, 1, 1), 1, 1); + ast_make_literal_bool(&a, true, loc_at(1, 1)), while_body, loc_at(1, 1)); + AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, loc_at(1, 1)), loc_at(1, 1)); AstNode* stmts[] = { while_stmt, ret }; - AstNode* fn_body = ast_make_block(&a, stmts, 2, 1, 1); - AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, fn_body, 1, 1); + AstNode* fn_body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx4 = NULL; @@ -121,10 +121,125 @@ void test_codegen_while_loop() { arena_destroy(&a); } +/* === struct IR 生成测试 === */ + +void test_codegen_struct_decl() { + Arena a = arena_create(1); + + /* 构造 AST: struct Point { x: i64, y: i64 } */ + AstNode* fields[2]; + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); + AstNode* structs[] = { struct_decl }; + + /* 构造 fn main() -> i64 { let p = Point { x: 1, y: 2 }; return p.x; } */ + const char* fnames[] = {"x", "y"}; + AstNode* fvals[] = { + ast_make_literal_i64(&a, 1, loc_at(1, 1)), + ast_make_literal_i64(&a, 2, loc_at(1, 1)) + }; + AstNode* init = ast_make_struct_init(&a, "Point", fnames, fvals, 2, loc_at(1, 1)); + /* 手动设置类型(绕过 sema,codegen 需要读取 type 字段) */ + init->type.kind = TYPE_STRUCT; + init->type.struct_name = "Point"; + + AstNode* let_stmt = ast_make_let(&a, "p", TYPE_UNKNOWN, false, false, + init, NULL, loc_at(1, 1)); + + /* return p.x; */ + AstNode* p_ident = ast_make_ident(&a, "p", loc_at(1, 1)); + p_ident->type.kind = TYPE_STRUCT; + p_ident->type.struct_name = "Point"; + + AstNode* field_x = ast_make_field_access(&a, p_ident, "x", loc_at(1, 1)); + field_x->as.field_access.field_index = 0; /* x 是第 0 个字段 */ + field_x->type.kind = TYPE_I64; + + AstNode* ret = ast_make_return(&a, field_x, loc_at(1, 1)); + AstNode* stmts[] = { let_stmt, ret }; + AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); + AstNode* fns[] = { fn }; + + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, loc_at(1, 1)); + + const char* err = NULL; + LLVMContextRef ctx = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_struct_decl", &err, &ctx); + ASSERT(mod != NULL); + ASSERT(err == NULL); + + char* verify_err = NULL; + int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err); + ASSERT(!failed); + + LLVMDisposeModule(mod); + LLVMContextDispose(ctx); + arena_destroy(&a); +} + +void test_codegen_struct_field_access() { + Arena a = arena_create(1); + + /* 构造 AST: struct Point { x: i64, y: i64 } */ + AstNode* fields[2]; + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); + AstNode* structs[] = { struct_decl }; + + /* 构造 fn main() -> i64 { let p = Point { x: 5, y: 10 }; return p.y; } */ + const char* fnames[] = {"x", "y"}; + AstNode* fvals[] = { + ast_make_literal_i64(&a, 5, loc_at(1, 1)), + ast_make_literal_i64(&a, 10, loc_at(1, 1)) + }; + AstNode* init = ast_make_struct_init(&a, "Point", fnames, fvals, 2, loc_at(1, 1)); + init->type.kind = TYPE_STRUCT; + init->type.struct_name = "Point"; + + AstNode* let_stmt = ast_make_let(&a, "p", TYPE_UNKNOWN, false, false, + init, NULL, loc_at(1, 1)); + + /* return p.y; */ + AstNode* p_ident = ast_make_ident(&a, "p", loc_at(1, 1)); + p_ident->type.kind = TYPE_STRUCT; + p_ident->type.struct_name = "Point"; + + AstNode* field_y = ast_make_field_access(&a, p_ident, "y", loc_at(1, 1)); + field_y->as.field_access.field_index = 1; /* y 是第 1 个字段 */ + field_y->type.kind = TYPE_I64; + + AstNode* ret = ast_make_return(&a, field_y, loc_at(1, 1)); + AstNode* stmts[] = { let_stmt, ret }; + AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); + AstNode* fns[] = { fn }; + + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, loc_at(1, 1)); + + const char* err = NULL; + LLVMContextRef ctx = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_struct_field", &err, &ctx); + ASSERT(mod != NULL); + ASSERT(err == NULL); + + char* verify_err = NULL; + int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err); + ASSERT(!failed); + + LLVMDisposeModule(mod); + LLVMContextDispose(ctx); + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_codegen_simple_function); TEST_RUN(test_codegen_if_else); TEST_RUN(test_codegen_binary_ops); TEST_RUN(test_codegen_while_loop); + TEST_RUN(test_codegen_struct_decl); + TEST_RUN(test_codegen_struct_field_access); return test_summary(); } diff --git a/test/test_sema.c b/test/test_sema.c index 0a4cf78..cba8d08 100644 --- a/test/test_sema.c +++ b/test/test_sema.c @@ -115,6 +115,76 @@ void test_str_concat_type_ok() { arena_destroy(&a); } +/* === struct 类型检查测试 === */ + +void test_struct_field_type_mismatch() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "struct Point { x: i64, y: i64 } fn main() { let p: Point = Point { x: 10, y: true }; return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // y 字段类型不匹配: true 是 bool, 不是 i64 + arena_destroy(&a); +} + +void test_struct_undefined() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "fn main() { let p: Unknown = Unknown { x: 1 }; return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // Unknown 未定义 + arena_destroy(&a); +} + +void test_struct_field_count_mismatch() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "struct Point { x: i64, y: i64 } fn main() { let p: Point = Point { x: 10 }; return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // 缺少字段 'y' + arena_destroy(&a); +} + +void test_struct_nested_type_ok() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "struct Point { x: i64, y: i64 } struct Rect { tl: Point, br: Point } fn main() { let r: Rect = Rect { tl: Point { x: 0, y: 0 }, br: Point { x: 1, y: 1 } }; print_i64(r.tl.x); return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count == 0); // 嵌套结构体类型检查应通过 + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_type_error); TEST_RUN(test_undefined_var); @@ -123,5 +193,9 @@ int main(void) { TEST_RUN(test_assign_immutable_error); TEST_RUN(test_str_type_ok); TEST_RUN(test_str_concat_type_ok); + TEST_RUN(test_struct_field_type_mismatch); + TEST_RUN(test_struct_undefined); + TEST_RUN(test_struct_field_count_mismatch); + TEST_RUN(test_struct_nested_type_ok); return test_summary(); }