feat: struct参数/返回值 + SourceLoc + 测试补全

- struct 可作函数参数和返回值 (fn make_point -> Point)
- SourceLoc 抽象: 所有 ast_make_* 参数 + AstNode 迁移完毕
- sema: +4 struct 类型检查测试 (字段类型/未定义/数量/嵌套)
- codegen: +2 struct IR 生成测试 (decl + field_access)
- 新增集成测试 14_struct_fn.l

测试: 104 单元 + 14 集成 = 全部通过
This commit is contained in:
2026-06-05 13:29:31 +08:00
parent 4046ab1875
commit da9a7065dd
12 changed files with 481 additions and 168 deletions
+11
View File
@@ -29,6 +29,17 @@ static inline const char* type_name(TypeKind kind) {
} }
} }
// === Source location (replaces separately passed line/col arguments) ===
typedef struct {
    int line;  // line number
    int col;   // column number
} SourceLoc;

// Build a SourceLoc by hand from an explicit line/column pair.
static inline SourceLoc loc_at(int line, int col) {
    SourceLoc where;
    where.line = line;
    where.col = col;
    return where;
}
// === 向前声明 === // === 向前声明 ===
typedef struct Token Token; typedef struct Token Token;
typedef struct AstNode AstNode; typedef struct AstNode AstNode;
+23 -22
View File
@@ -6,10 +6,10 @@
AstNode* n = (AstNode*)arena_alloc_impl(alloc, sizeof(AstNode)); \ AstNode* n = (AstNode*)arena_alloc_impl(alloc, sizeof(AstNode)); \
if (!n) return NULL; \ if (!n) return NULL; \
n->kind = (k); n->type.kind = TYPE_UNKNOWN; n->type.struct_name = NULL; \ n->kind = (k); n->type.kind = TYPE_UNKNOWN; n->type.struct_name = NULL; \
n->line = line; n->col = col n->loc = loc
AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count,
AstNode** structs, size_t struct_count, int line, int col) { AstNode** structs, size_t struct_count, SourceLoc loc) {
NEW(alloc, AST_PROGRAM); NEW(alloc, AST_PROGRAM);
n->as.program.functions = fns; n->as.program.functions = fns;
n->as.program.fn_count = fn_count; n->as.program.fn_count = fn_count;
@@ -19,30 +19,31 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count,
} }
AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount,
TypeKind ret, AstNode* body, int line, int col) { TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc) {
NEW(alloc, AST_FUNCTION); NEW(alloc, AST_FUNCTION);
n->as.function.name = name; n->as.function.params = params; n->as.function.name = name; n->as.function.params = params;
n->as.function.param_count = pcount; n->as.function.return_type = ret; n->as.function.param_count = pcount; n->as.function.return_type = ret;
n->as.function.return_struct_type_name = ret_struct_name;
n->as.function.body = body; n->as.function.body = body;
return n; return n;
} }
AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type,
const char* struct_type_name, int line, int col) { const char* struct_type_name, SourceLoc loc) {
NEW(alloc, AST_PARAMETER); NEW(alloc, AST_PARAMETER);
n->as.parameter.name = name; n->as.parameter.type = type; n->as.parameter.name = name; n->as.parameter.type = type;
n->as.parameter.struct_type_name = struct_type_name; n->as.parameter.struct_type_name = struct_type_name;
return n; return n;
} }
AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, int line, int col) { AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc) {
NEW(alloc, AST_BLOCK); NEW(alloc, AST_BLOCK);
n->as.block.stmts = stmts; n->as.block.stmt_count = count; n->as.block.stmts = stmts; n->as.block.stmt_count = count;
return n; return n;
} }
AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot,
bool is_mut, AstNode* init, const char* struct_type_name, int line, int col) { bool is_mut, AstNode* init, const char* struct_type_name, SourceLoc loc) {
NEW(alloc, AST_LET_STMT); NEW(alloc, AST_LET_STMT);
n->as.let_stmt.name = name; n->as.let_stmt.annot_type = annot_type; n->as.let_stmt.name = name; n->as.let_stmt.annot_type = annot_type;
n->as.let_stmt.has_type_annot = has_type_annot; n->as.let_stmt.is_mut = is_mut; n->as.let_stmt.has_type_annot = has_type_annot; n->as.let_stmt.is_mut = is_mut;
@@ -51,84 +52,84 @@ AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool h
return n; return n;
} }
AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, int line, int col) { AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, SourceLoc loc) {
NEW(alloc, AST_ASSIGN_STMT); NEW(alloc, AST_ASSIGN_STMT);
n->as.assign_stmt.name = name; n->as.assign_stmt.value = value; n->as.assign_stmt.name = name; n->as.assign_stmt.value = value;
return n; return n;
} }
AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, int line, int col) { AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, SourceLoc loc) {
NEW(alloc, AST_IF_STMT); NEW(alloc, AST_IF_STMT);
n->as.if_stmt.cond = cond; n->as.if_stmt.then_block = then_b; n->as.if_stmt.cond = cond; n->as.if_stmt.then_block = then_b;
n->as.if_stmt.else_block = else_b; n->as.if_stmt.else_block = else_b;
return n; return n;
} }
AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, int line, int col) { AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, SourceLoc loc) {
NEW(alloc, AST_WHILE_STMT); NEW(alloc, AST_WHILE_STMT);
n->as.while_stmt.cond = cond; n->as.while_stmt.body = body; n->as.while_stmt.cond = cond; n->as.while_stmt.body = body;
return n; return n;
} }
AstNode* ast_make_return(void* alloc, AstNode* expr, int line, int col) { AstNode* ast_make_return(void* alloc, AstNode* expr, SourceLoc loc) {
NEW(alloc, AST_RETURN_STMT); NEW(alloc, AST_RETURN_STMT);
n->as.return_stmt.expr = expr; n->as.return_stmt.expr = expr;
return n; return n;
} }
AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, int line, int col) { AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, SourceLoc loc) {
NEW(alloc, AST_EXPR_STMT); NEW(alloc, AST_EXPR_STMT);
n->as.expr_stmt.expr = expr; n->as.expr_stmt.expr = expr;
return n; return n;
} }
AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, int line, int col) { AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, SourceLoc loc) {
NEW(alloc, AST_BINARY_EXPR); NEW(alloc, AST_BINARY_EXPR);
n->as.binary.op = op; n->as.binary.left = left; n->as.binary.right = right; n->as.binary.op = op; n->as.binary.left = left; n->as.binary.right = right;
return n; return n;
} }
AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, int line, int col) { AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, SourceLoc loc) {
NEW(alloc, AST_UNARY_EXPR); NEW(alloc, AST_UNARY_EXPR);
n->as.unary.op = op; n->as.unary.operand = operand; n->as.unary.op = op; n->as.unary.operand = operand;
return n; return n;
} }
AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, int line, int col) { AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, SourceLoc loc) {
NEW(alloc, AST_CALL_EXPR); NEW(alloc, AST_CALL_EXPR);
n->as.call.name = name; n->as.call.args = args; n->as.call.arg_count = count; n->as.call.name = name; n->as.call.args = args; n->as.call.arg_count = count;
return n; return n;
} }
AstNode* ast_make_literal_i64(void* alloc, int64_t val, int line, int col) { AstNode* ast_make_literal_i64(void* alloc, int64_t val, SourceLoc loc) {
NEW(alloc, AST_LITERAL_EXPR); NEW(alloc, AST_LITERAL_EXPR);
n->as.literal.lit_type = TYPE_I64; n->as.literal.i64_val = val; n->as.literal.lit_type = TYPE_I64; n->as.literal.i64_val = val;
n->type.kind = TYPE_I64; n->type.kind = TYPE_I64;
return n; return n;
} }
AstNode* ast_make_literal_f64(void* alloc, double val, int line, int col) { AstNode* ast_make_literal_f64(void* alloc, double val, SourceLoc loc) {
NEW(alloc, AST_LITERAL_EXPR); NEW(alloc, AST_LITERAL_EXPR);
n->as.literal.lit_type = TYPE_F64; n->as.literal.f64_val = val; n->as.literal.lit_type = TYPE_F64; n->as.literal.f64_val = val;
n->type.kind = TYPE_F64; n->type.kind = TYPE_F64;
return n; return n;
} }
AstNode* ast_make_literal_bool(void* alloc, bool val, int line, int col) { AstNode* ast_make_literal_bool(void* alloc, bool val, SourceLoc loc) {
NEW(alloc, AST_LITERAL_EXPR); NEW(alloc, AST_LITERAL_EXPR);
n->as.literal.lit_type = TYPE_BOOL; n->as.literal.bool_val = val; n->as.literal.lit_type = TYPE_BOOL; n->as.literal.bool_val = val;
n->type.kind = TYPE_BOOL; n->type.kind = TYPE_BOOL;
return n; return n;
} }
AstNode* ast_make_literal_str(void* alloc, const char* val, int line, int col) { AstNode* ast_make_literal_str(void* alloc, const char* val, SourceLoc loc) {
NEW(alloc, AST_LITERAL_EXPR); NEW(alloc, AST_LITERAL_EXPR);
n->as.literal.lit_type = TYPE_STR; n->as.literal.str_val = val; n->as.literal.lit_type = TYPE_STR; n->as.literal.str_val = val;
n->type.kind = TYPE_STR; n->type.kind = TYPE_STR;
return n; return n;
} }
AstNode* ast_make_ident(void* alloc, const char* name, int line, int col) { AstNode* ast_make_ident(void* alloc, const char* name, SourceLoc loc) {
NEW(alloc, AST_IDENT_EXPR); NEW(alloc, AST_IDENT_EXPR);
n->as.ident.name = name; n->as.ident.name = name;
return n; return n;
@@ -137,7 +138,7 @@ AstNode* ast_make_ident(void* alloc, const char* name, int line, int col) {
// === 结构体相关工厂函数 === // === 结构体相关工厂函数 ===
AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields,
size_t count, int line, int col) { size_t count, SourceLoc loc) {
NEW(alloc, AST_STRUCT_DECL); NEW(alloc, AST_STRUCT_DECL);
n->as.struct_decl.name = name; n->as.struct_decl.name = name;
n->as.struct_decl.fields = fields; n->as.struct_decl.fields = fields;
@@ -147,7 +148,7 @@ AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields,
AstNode* ast_make_struct_init(void* alloc, const char* type_name, AstNode* ast_make_struct_init(void* alloc, const char* type_name,
const char** fnames, AstNode** fvals, const char** fnames, AstNode** fvals,
size_t count, int line, int col) { size_t count, SourceLoc loc) {
NEW(alloc, AST_STRUCT_INIT); NEW(alloc, AST_STRUCT_INIT);
n->as.struct_init.type_name = type_name; n->as.struct_init.type_name = type_name;
n->as.struct_init.field_names = fnames; n->as.struct_init.field_names = fnames;
@@ -157,7 +158,7 @@ AstNode* ast_make_struct_init(void* alloc, const char* type_name,
} }
AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field,
int line, int col) { SourceLoc loc) {
NEW(alloc, AST_FIELD_ACCESS); NEW(alloc, AST_FIELD_ACCESS);
n->as.field_access.object = object; n->as.field_access.object = object;
n->as.field_access.field = field; n->as.field_access.field = field;
+24 -24
View File
@@ -42,8 +42,7 @@ typedef struct {
struct AstNode { struct AstNode {
AstKind kind; AstKind kind;
TypeInfo type; // 语义分析后填充,默认为 TYPE_UNKNOWN TypeInfo type; // 语义分析后填充,默认为 TYPE_UNKNOWN
int line; // 源文件行号 SourceLoc loc; // 源码位置
int col; // 源文件列号
// 节点特有数据(按 kind 解释) // 节点特有数据(按 kind 解释)
union { union {
@@ -52,7 +51,8 @@ struct AstNode {
struct AstNode** structs; size_t struct_count; } program; struct AstNode** structs; size_t struct_count; } program;
// AST_FUNCTION // AST_FUNCTION
struct { const char* name; struct AstNode** params; size_t param_count; struct { const char* name; struct AstNode** params; size_t param_count;
TypeKind return_type; struct AstNode* body; } function; TypeKind return_type; const char* return_struct_type_name;
struct AstNode* body; } function;
// AST_PARAMETER (也用作结构体字段: name + type) // AST_PARAMETER (也用作结构体字段: name + type)
struct { const char* name; TypeKind type; const char* struct_type_name; } parameter; struct { const char* name; TypeKind type; const char* struct_type_name; } parameter;
// AST_BLOCK // AST_BLOCK
@@ -92,28 +92,28 @@ struct AstNode {
// 创建节点的辅助函数(内存来自 arena,通过 void* 传递避免循环依赖) // 创建节点的辅助函数(内存来自 arena,通过 void* 传递避免循环依赖)
AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count,
AstNode** structs, size_t struct_count, int line, int col); AstNode** structs, size_t struct_count, SourceLoc loc);
AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount,
TypeKind ret, AstNode* body, int line, int col); TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc);
AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, int line, int col); AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc);
AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, int line, int col); AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc);
AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot,
bool is_mut, AstNode* init, const char* struct_type_name, int line, int col); bool is_mut, AstNode* init, const char* struct_type_name, SourceLoc loc);
AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, int line, int col); AstNode* ast_make_assign(void* alloc, const char* name, AstNode* value, SourceLoc loc);
AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, int line, int col); AstNode* ast_make_if(void* alloc, AstNode* cond, AstNode* then_b, AstNode* else_b, SourceLoc loc);
AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, int line, int col); AstNode* ast_make_while(void* alloc, AstNode* cond, AstNode* body, SourceLoc loc);
AstNode* ast_make_return(void* alloc, AstNode* expr, int line, int col); AstNode* ast_make_return(void* alloc, AstNode* expr, SourceLoc loc);
AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, int line, int col); AstNode* ast_make_expr_stmt(void* alloc, AstNode* expr, SourceLoc loc);
AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, int line, int col); AstNode* ast_make_binary(void* alloc, BinaryOp op, AstNode* left, AstNode* right, SourceLoc loc);
AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, int line, int col); AstNode* ast_make_unary(void* alloc, BinaryOp op, AstNode* operand, SourceLoc loc);
AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, int line, int col); AstNode* ast_make_call(void* alloc, const char* name, AstNode** args, size_t count, SourceLoc loc);
AstNode* ast_make_literal_i64(void* alloc, int64_t val, int line, int col); AstNode* ast_make_literal_i64(void* alloc, int64_t val, SourceLoc loc);
AstNode* ast_make_literal_f64(void* alloc, double val, int line, int col); AstNode* ast_make_literal_f64(void* alloc, double val, SourceLoc loc);
AstNode* ast_make_literal_bool(void* alloc, bool val, int line, int col); AstNode* ast_make_literal_bool(void* alloc, bool val, SourceLoc loc);
AstNode* ast_make_literal_str(void* alloc, const char* val, int line, int col); AstNode* ast_make_literal_str(void* alloc, const char* val, SourceLoc loc);
AstNode* ast_make_ident(void* alloc, const char* name, int line, int col); AstNode* ast_make_ident(void* alloc, const char* name, SourceLoc loc);
AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, size_t count, int line, int col); AstNode* ast_make_struct_decl(void* alloc, const char* name, AstNode** fields, size_t count, SourceLoc loc);
AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** fnames, AstNode** fvals, size_t count, int line, int col); AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** fnames, AstNode** fvals, size_t count, SourceLoc loc);
AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, int line, int col); AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, SourceLoc loc);
#endif #endif
+33 -7
View File
@@ -608,10 +608,23 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
AstNode* fn = ast->as.program.functions[i]; AstNode* fn = ast->as.program.functions[i];
LLVMTypeRef* ptypes = arena_alloc(ctx.arena, LLVMTypeRef* ptypes = arena_alloc(ctx.arena,
fn->as.function.param_count * sizeof(LLVMTypeRef)); fn->as.function.param_count * sizeof(LLVMTypeRef));
for (size_t j = 0; j < fn->as.function.param_count; j++) for (size_t j = 0; j < fn->as.function.param_count; j++) {
ptypes[j] = to_llvm_type(&ctx, fn->as.function.params[j]->as.parameter.type); AstNode* param = fn->as.function.params[j];
LLVMTypeRef fty = LLVMFunctionType( if (param->as.parameter.type == TYPE_STRUCT &&
to_llvm_type(&ctx, fn->as.function.return_type), param->as.parameter.struct_type_name) {
ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name);
} else {
ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type);
}
}
LLVMTypeRef ret_ty;
if (fn->as.function.return_type == TYPE_STRUCT &&
fn->as.function.return_struct_type_name) {
ret_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
} else {
ret_ty = to_llvm_type(&ctx, fn->as.function.return_type);
}
LLVMTypeRef fty = LLVMFunctionType(ret_ty,
ptypes, (unsigned)fn->as.function.param_count, false); ptypes, (unsigned)fn->as.function.param_count, false);
LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty); LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty);
add_fn(&ctx, fn->as.function.name, lfn); add_fn(&ctx, fn->as.function.name, lfn);
@@ -630,11 +643,18 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
// 将参数注册为变量 // 将参数注册为变量
for (size_t j = 0; j < fn->as.function.param_count; j++) { for (size_t j = 0; j < fn->as.function.param_count; j++) {
LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j); LLVMValueRef param = LLVMGetParam(lfn, (unsigned)j);
AstNode* pnode = fn->as.function.params[j];
LLVMTypeRef param_ty;
if (pnode->as.parameter.type == TYPE_STRUCT &&
pnode->as.parameter.struct_type_name) {
param_ty = find_struct_type(&ctx, pnode->as.parameter.struct_type_name);
} else {
param_ty = to_llvm_type(&ctx, pnode->as.parameter.type);
}
LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder, LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder,
to_llvm_type(&ctx, fn->as.function.params[j]->as.parameter.type), param_ty, pnode->as.parameter.name);
fn->as.function.params[j]->as.parameter.name);
LLVMBuildStore(ctx.builder, param, alloca); LLVMBuildStore(ctx.builder, param, alloca);
add_var(&ctx, fn->as.function.params[j]->as.parameter.name, alloca); add_var(&ctx, pnode->as.parameter.name, alloca);
} }
codegen_stmt(&ctx, fn->as.function.body); codegen_stmt(&ctx, fn->as.function.body);
@@ -645,6 +665,12 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena,
cleanup_emit(&ctx, 0); cleanup_emit(&ctx, 0);
if (fn->as.function.return_type == TYPE_VOID) if (fn->as.function.return_type == TYPE_VOID)
LLVMBuildRetVoid(ctx.builder); LLVMBuildRetVoid(ctx.builder);
else if (fn->as.function.return_type == TYPE_STRUCT &&
fn->as.function.return_struct_type_name) {
LLVMTypeRef st_ty = find_struct_type(&ctx, fn->as.function.return_struct_type_name);
LLVMBuildRet(ctx.builder, st_ty ? LLVMConstNull(st_ty) :
LLVMConstInt(to_llvm_type(&ctx, TYPE_I64), 0, false));
}
else else
LLVMBuildRet(ctx.builder, LLVMBuildRet(ctx.builder,
(fn->as.function.return_type == TYPE_F64 (fn->as.function.return_type == TYPE_F64
+5
View File
@@ -41,6 +41,11 @@ struct Token {
const char* tok_name(TokenKind kind); const char* tok_name(TokenKind kind);
bool tok_is_type(TokenKind kind); bool tok_is_type(TokenKind kind);
// 从 Token 创建 SourceLoc
static inline SourceLoc tok_loc(const Token* tok) {
return (SourceLoc){ .line = tok->line, .col = tok->col };
}
// 从 Token 提取值 // 从 Token 提取值
int64_t tok_int_value(const Token* tok); int64_t tok_int_value(const Token* tok);
double tok_float_value(const Token* tok); double tok_float_value(const Token* tok);
+52 -40
View File
@@ -78,7 +78,7 @@ static AstNode* parse_unary(Parser* p, ErrorInfo* error) {
AstNode* operand = parse_expr_prec(p, PREC_UNARY, error); AstNode* operand = parse_expr_prec(p, PREC_UNARY, error);
if (!operand) return NULL; if (!operand) return NULL;
BinaryOp uop = (op->kind == TOK_MINUS) ? OP_NEG : OP_NOT; BinaryOp uop = (op->kind == TOK_MINUS) ? OP_NEG : OP_NOT;
return ast_make_unary(p->arena, uop, operand, op->line, op->col); return ast_make_unary(p->arena, uop, operand, tok_loc(op));
} }
static AstNode* parse_group(Parser* p, ErrorInfo* error) { static AstNode* parse_group(Parser* p, ErrorInfo* error) {
@@ -92,15 +92,15 @@ static AstNode* parse_group(Parser* p, ErrorInfo* error) {
static AstNode* parse_literal(Parser* p) { static AstNode* parse_literal(Parser* p) {
const Token* t = advance(p); const Token* t = advance(p);
switch (t->kind) { switch (t->kind) {
case TOK_INT_LIT: return ast_make_literal_i64(p->arena, tok_int_value(t), t->line, t->col); case TOK_INT_LIT: return ast_make_literal_i64(p->arena, tok_int_value(t), tok_loc(t));
case TOK_FLOAT_LIT: return ast_make_literal_f64(p->arena, tok_float_value(t), t->line, t->col); case TOK_FLOAT_LIT: return ast_make_literal_f64(p->arena, tok_float_value(t), tok_loc(t));
case TOK_TRUE: return ast_make_literal_bool(p->arena, true, t->line, t->col); case TOK_TRUE: return ast_make_literal_bool(p->arena, true, tok_loc(t));
case TOK_FALSE: return ast_make_literal_bool(p->arena, false, t->line, t->col); case TOK_FALSE: return ast_make_literal_bool(p->arena, false, tok_loc(t));
case TOK_STR_LIT: { case TOK_STR_LIT: {
char* str = arena_alloc_impl(p->arena, t->length + 1); char* str = arena_alloc_impl(p->arena, t->length + 1);
memcpy(str, t->start, t->length); memcpy(str, t->start, t->length);
str[t->length] = '\0'; str[t->length] = '\0';
return ast_make_literal_str(p->arena, str, t->line, t->col); return ast_make_literal_str(p->arena, str, tok_loc(t));
} }
default: return NULL; default: return NULL;
} }
@@ -137,7 +137,7 @@ static AstNode* parse_struct_init(Parser* p, const Token* name, ErrorInfo* error
return ast_make_struct_init(p->arena, return ast_make_struct_init(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
n_arr, v_arr, fcount, name->line, name->col); n_arr, v_arr, fcount, tok_loc(name));
} }
// === 标识符 / 函数调用 / 结构体初始化 === // === 标识符 / 函数调用 / 结构体初始化 ===
@@ -175,11 +175,11 @@ static AstNode* parse_ident_or_call(Parser* p, ErrorInfo* error) {
AstNode** arg_arr = arena_alloc_impl(p->arena, arg_count * sizeof(AstNode*)); AstNode** arg_arr = arena_alloc_impl(p->arena, arg_count * sizeof(AstNode*));
memcpy(arg_arr, args, arg_count * sizeof(AstNode*)); memcpy(arg_arr, args, arg_count * sizeof(AstNode*));
return ast_make_call(p->arena, arena_strdup_impl(p->arena, name->start, name->length), return ast_make_call(p->arena, arena_strdup_impl(p->arena, name->start, name->length),
arg_arr, arg_count, name->line, name->col); arg_arr, arg_count, tok_loc(name));
} }
return ast_make_ident(p->arena, return ast_make_ident(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
name->line, name->col); tok_loc(name));
} }
// === Pratt 主循环 === // === Pratt 主循环 ===
@@ -216,7 +216,7 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error
if (!field) return NULL; if (!field) return NULL;
left = ast_make_field_access(p->arena, left, left = ast_make_field_access(p->arena, left,
arena_strdup_impl(p->arena, field->start, field->length), arena_strdup_impl(p->arena, field->start, field->length),
field->line, field->col); tok_loc(field));
continue; continue;
} }
@@ -227,7 +227,7 @@ static AstNode* parse_expr_prec(Parser* p, Precedence min_prec, ErrorInfo* error
const Token* op = advance(p); const Token* op = advance(p);
AstNode* right = parse_expr_prec(p, prec, error); AstNode* right = parse_expr_prec(p, prec, error);
if (!right) return NULL; if (!right) return NULL;
left = ast_make_binary(p->arena, tok_to_binop(kind), left, right, op->line, op->col); left = ast_make_binary(p->arena, tok_to_binop(kind), left, right, tok_loc(op));
} }
return left; return left;
@@ -271,7 +271,7 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) {
} }
fields[fcount++] = ast_make_parameter(p->arena, fields[fcount++] = ast_make_parameter(p->arena,
arena_strdup_impl(p->arena, fname->start, fname->length), arena_strdup_impl(p->arena, fname->start, fname->length),
field_kind, field_struct_name, fname->line, fname->col); field_kind, field_struct_name, tok_loc(fname));
if (peek(p)->kind == TOK_COMMA) advance(p); if (peek(p)->kind == TOK_COMMA) advance(p);
else break; else break;
} }
@@ -281,7 +281,7 @@ static AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) {
memcpy(farr, fields, fcount * sizeof(AstNode*)); memcpy(farr, fields, fcount * sizeof(AstNode*));
return ast_make_struct_decl(p->arena, return ast_make_struct_decl(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
farr, fcount, s_tok->line, s_tok->col); farr, fcount, tok_loc(s_tok));
} }
// === 语句解析 === // === 语句解析 ===
@@ -305,7 +305,7 @@ static AstNode* parse_block(Parser* p, ErrorInfo* error) {
AstNode** arr = arena_alloc_impl(p->arena, count * sizeof(AstNode*)); AstNode** arr = arena_alloc_impl(p->arena, count * sizeof(AstNode*));
memcpy(arr, stmts, count * sizeof(AstNode*)); memcpy(arr, stmts, count * sizeof(AstNode*));
parse_depth--; parse_depth--;
return ast_make_block(p->arena, arr, count, open->line, open->col); return ast_make_block(p->arena, arr, count, tok_loc(open));
} }
static AstNode* parse_statement(Parser* p, ErrorInfo* error) { static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
@@ -342,7 +342,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL;
return ast_make_let(p->arena, return ast_make_let(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
annot_type, has_type_annot, is_mut, init, struct_type_name, t->line, t->col); annot_type, has_type_annot, is_mut, init, struct_type_name, tok_loc(t));
} }
if (t->kind == TOK_IF) { if (t->kind == TOK_IF) {
@@ -360,7 +360,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
} }
if (!else_block) return NULL; if (!else_block) return NULL;
} }
return ast_make_if(p->arena, cond, then_block, else_block, t->line, t->col); return ast_make_if(p->arena, cond, then_block, else_block, tok_loc(t));
} }
if (t->kind == TOK_WHILE) { if (t->kind == TOK_WHILE) {
@@ -369,7 +369,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
if (!cond) return NULL; if (!cond) return NULL;
AstNode* body = parse_block(p, error); AstNode* body = parse_block(p, error);
if (!body) return NULL; if (!body) return NULL;
return ast_make_while(p->arena, cond, body, t->line, t->col); return ast_make_while(p->arena, cond, body, tok_loc(t));
} }
if (t->kind == TOK_FOR) { if (t->kind == TOK_FOR) {
@@ -403,20 +403,20 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
const char* vname = arena_strdup_impl(p->arena, var_name->start, var_name->length); const char* vname = arena_strdup_impl(p->arena, var_name->start, var_name->length);
// 构建: let mut i = start; // 构建: let mut i = start;
AstNode* let_stmt = ast_make_let(p->arena, vname, TYPE_UNKNOWN, false, true, start_expr, NULL, var_name->line, var_name->col); AstNode* let_stmt = ast_make_let(p->arena, vname, TYPE_UNKNOWN, false, true, start_expr, NULL, tok_loc(var_name));
// 构建: i < end (while 条件) // 构建: i < end (while 条件)
AstNode* cond = ast_make_binary(p->arena, OP_LT, AstNode* cond = ast_make_binary(p->arena, OP_LT,
ast_make_ident(p->arena, vname, var_name->line, var_name->col), ast_make_ident(p->arena, vname, tok_loc(var_name)),
end_expr, var_name->line, var_name->col); end_expr, tok_loc(var_name));
// 构建: i = i + 1 (循环增量) // 构建: i = i + 1 (循环增量)
AstNode* incr = ast_make_assign(p->arena, vname, AstNode* incr = ast_make_assign(p->arena, vname,
ast_make_binary(p->arena, OP_ADD, ast_make_binary(p->arena, OP_ADD,
ast_make_ident(p->arena, vname, var_name->line, var_name->col), ast_make_ident(p->arena, vname, tok_loc(var_name)),
ast_make_literal_i64(p->arena, 1, var_name->line, var_name->col), ast_make_literal_i64(p->arena, 1, tok_loc(var_name)),
var_name->line, var_name->col), tok_loc(var_name)),
var_name->line, var_name->col); tok_loc(var_name));
// 将增量追加到循环体末尾 // 将增量追加到循环体末尾
AstNode** new_stmts = arena_alloc_impl(p->arena, AstNode** new_stmts = arena_alloc_impl(p->arena,
@@ -424,27 +424,27 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
memcpy(new_stmts, body->as.block.stmts, body->as.block.stmt_count * sizeof(AstNode*)); memcpy(new_stmts, body->as.block.stmts, body->as.block.stmt_count * sizeof(AstNode*));
new_stmts[body->as.block.stmt_count] = incr; new_stmts[body->as.block.stmt_count] = incr;
AstNode* new_body = ast_make_block(p->arena, new_stmts, AstNode* new_body = ast_make_block(p->arena, new_stmts,
body->as.block.stmt_count + 1, body->line, body->col); body->as.block.stmt_count + 1, body->loc);
// 构建: while i < end { ... body ... ; i = i + 1; } // 构建: while i < end { ... body ... ; i = i + 1; }
AstNode* while_loop = ast_make_while(p->arena, cond, new_body, t->line, t->col); AstNode* while_loop = ast_make_while(p->arena, cond, new_body, tok_loc(t));
// 包装: { let mut i = start; while i < end { ... } } // 包装: { let mut i = start; while i < end { ... } }
AstNode* stmts_arr[2] = { let_stmt, while_loop }; AstNode* stmts_arr[2] = { let_stmt, while_loop };
AstNode** stmts = arena_alloc_impl(p->arena, 2 * sizeof(AstNode*)); AstNode** stmts = arena_alloc_impl(p->arena, 2 * sizeof(AstNode*));
memcpy(stmts, stmts_arr, 2 * sizeof(AstNode*)); memcpy(stmts, stmts_arr, 2 * sizeof(AstNode*));
return ast_make_block(p->arena, stmts, 2, t->line, t->col); return ast_make_block(p->arena, stmts, 2, tok_loc(t));
} }
if (t->kind == TOK_RETURN) { if (t->kind == TOK_RETURN) {
advance(p); advance(p);
if (match(p, TOK_SEMICOLON)) { if (match(p, TOK_SEMICOLON)) {
return ast_make_return(p->arena, NULL, t->line, t->col); return ast_make_return(p->arena, NULL, tok_loc(t));
} }
AstNode* expr = parse_expr(p, error); AstNode* expr = parse_expr(p, error);
if (!expr) return NULL; if (!expr) return NULL;
if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL;
return ast_make_return(p->arena, expr, t->line, t->col); return ast_make_return(p->arena, expr, tok_loc(t));
} }
// 赋值语句: ident = expr ; // 赋值语句: ident = expr ;
@@ -456,7 +456,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL;
return ast_make_assign(p->arena, return ast_make_assign(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
value, name->line, name->col); value, tok_loc(name));
} }
// 复合赋值: ident += expr → ident = ident + expr // 复合赋值: ident += expr → ident = ident + expr
@@ -481,12 +481,12 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
AstNode* lhs_ident = ast_make_ident(p->arena, AstNode* lhs_ident = ast_make_ident(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
name->line, name->col); tok_loc(name));
AstNode* bin_expr = ast_make_binary(p->arena, binop, lhs_ident, rhs, AstNode* bin_expr = ast_make_binary(p->arena, binop, lhs_ident, rhs,
name->line, name->col); tok_loc(name));
return ast_make_assign(p->arena, return ast_make_assign(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
bin_expr, name->line, name->col); bin_expr, tok_loc(name));
} }
} }
@@ -494,7 +494,7 @@ static AstNode* parse_statement(Parser* p, ErrorInfo* error) {
AstNode* expr = parse_expr(p, error); AstNode* expr = parse_expr(p, error);
if (!expr) return NULL; if (!expr) return NULL;
if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL; if (!expect(p, TOK_SEMICOLON, error, "缺少 ';'")) return NULL;
return ast_make_expr_stmt(p->arena, expr, t->line, t->col); return ast_make_expr_stmt(p->arena, expr, tok_loc(t));
} }
// === 函数解析 === // === 函数解析 ===
@@ -512,13 +512,20 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) {
if (!pname) return NULL; if (!pname) return NULL;
if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL;
const Token* ptype = advance(p); const Token* ptype = advance(p);
if (!tok_is_type(ptype->kind)) { TypeKind param_kind;
const char* param_struct_name = NULL;
if (tok_is_type(ptype->kind)) {
param_kind = token_to_type(ptype->kind);
} else if (ptype->kind == TOK_IDENT) {
param_kind = TYPE_STRUCT;
param_struct_name = arena_strdup_impl(p->arena, ptype->start, ptype->length);
} else {
error->message = "无效的参数类型"; error->filename = p->filename; error->message = "无效的参数类型"; error->filename = p->filename;
error->line = ptype->line; error->col = ptype->col; return NULL; error->line = ptype->line; error->col = ptype->col; return NULL;
} }
params[pcount++] = ast_make_parameter(p->arena, params[pcount++] = ast_make_parameter(p->arena,
arena_strdup_impl(p->arena, pname->start, pname->length), arena_strdup_impl(p->arena, pname->start, pname->length),
token_to_type(ptype->kind), NULL, pname->line, pname->col); param_kind, param_struct_name, tok_loc(pname));
if (match(p, TOK_COMMA)) continue; if (match(p, TOK_COMMA)) continue;
else break; else break;
} }
@@ -526,13 +533,18 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) {
// 返回类型 // 返回类型
TypeKind ret = TYPE_VOID; TypeKind ret = TYPE_VOID;
const char* ret_struct_name = NULL;
if (match(p, TOK_ARROW)) { if (match(p, TOK_ARROW)) {
const Token* rt = advance(p); const Token* rt = advance(p);
if (!tok_is_type(rt->kind)) { if (tok_is_type(rt->kind)) {
ret = token_to_type(rt->kind);
} else if (rt->kind == TOK_IDENT) {
ret = TYPE_STRUCT;
ret_struct_name = arena_strdup_impl(p->arena, rt->start, rt->length);
} else {
error->message = "无效的返回类型"; error->filename = p->filename; error->message = "无效的返回类型"; error->filename = p->filename;
error->line = rt->line; error->col = rt->col; return NULL; error->line = rt->line; error->col = rt->col; return NULL;
} }
ret = token_to_type(rt->kind);
} }
AstNode* body = parse_block(p, error); AstNode* body = parse_block(p, error);
@@ -542,7 +554,7 @@ static AstNode* parse_function(Parser* p, ErrorInfo* error) {
memcpy(parr, params, pcount * sizeof(AstNode*)); memcpy(parr, params, pcount * sizeof(AstNode*));
return ast_make_function(p->arena, return ast_make_function(p->arena,
arena_strdup_impl(p->arena, name->start, name->length), arena_strdup_impl(p->arena, name->start, name->length),
parr, pcount, ret, body, fn_tok->line, fn_tok->col); parr, pcount, ret, ret_struct_name, body, tok_loc(fn_tok));
} }
// === 程序入口 === // === 程序入口 ===
@@ -572,5 +584,5 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count,
memcpy(fn_arr, functions, fn_count * sizeof(AstNode*)); memcpy(fn_arr, functions, fn_count * sizeof(AstNode*));
AstNode** st_arr = arena_alloc_impl(a, struct_count * sizeof(AstNode*)); AstNode** st_arr = arena_alloc_impl(a, struct_count * sizeof(AstNode*));
memcpy(st_arr, structs, struct_count * sizeof(AstNode*)); memcpy(st_arr, structs, struct_count * sizeof(AstNode*));
return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, 0, 0); return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, loc_at(0, 0));
} }
+83 -41
View File
@@ -3,6 +3,7 @@
// === 类型关系 === // === 类型关系 ===
static TypeKind current_return_type = TYPE_VOID; static TypeKind current_return_type = TYPE_VOID;
static const char* current_return_struct_name = NULL;
static TypeKind promote(TypeKind a, TypeKind b) { static TypeKind promote(TypeKind a, TypeKind b) {
if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64;
@@ -26,11 +27,11 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
case AST_IDENT_EXPR: { case AST_IDENT_EXPR: {
Symbol* sym = scope_lookup(scope, node->as.ident.name); Symbol* sym = scope_lookup(scope, node->as.ident.name);
if (!sym) { if (!sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的变量 '%s'", node->as.ident.name); "未定义的变量 '%s'", node->as.ident.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else if (sym->kind == SYM_FUNCTION) { } else if (sym->kind == SYM_FUNCTION) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"'%s' 是函数,不能作为表达式使用", node->as.ident.name); "'%s' 是函数,不能作为表达式使用", node->as.ident.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -47,7 +48,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
TypeKind inner = node->as.unary.operand->type.kind; TypeKind inner = node->as.unary.operand->type.kind;
if (node->as.unary.op == OP_NEG) { if (node->as.unary.op == OP_NEG) {
if (!is_numeric(inner)) { if (!is_numeric(inner)) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"一元 '-' 只能用于数值类型"); "一元 '-' 只能用于数值类型");
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -55,7 +56,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
} }
} else { // OP_NOT } else { // OP_NOT
if (inner != TYPE_BOOL) { if (inner != TYPE_BOOL) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"'!' 只能用于布尔类型,得到 '%s'", type_name(inner)); "'!' 只能用于布尔类型,得到 '%s'", type_name(inner));
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -77,7 +78,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
if (l == TYPE_STR || r == TYPE_STR) { if (l == TYPE_STR || r == TYPE_STR) {
// 字符串拼接:两边都必须是 str 类型 // 字符串拼接:两边都必须是 str 类型
if (l != TYPE_STR || r != TYPE_STR) { if (l != TYPE_STR || r != TYPE_STR) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'", "字符串拼接需要两边都是 str 类型,得到 '%s' + '%s'",
type_name(l), type_name(r)); type_name(l), type_name(r));
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
@@ -85,7 +86,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
node->type.kind = TYPE_STR; node->type.kind = TYPE_STR;
} }
} else if (!is_numeric(l) || !is_numeric(r)) { } else if (!is_numeric(l) || !is_numeric(r)) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"算术运算需要数值类型"); "算术运算需要数值类型");
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -94,7 +95,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
break; break;
case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD: case OP_SUB: case OP_MUL: case OP_DIV: case OP_MOD:
if (!is_numeric(l) || !is_numeric(r)) { if (!is_numeric(l) || !is_numeric(r)) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"算术运算需要数值类型"); "算术运算需要数值类型");
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -103,7 +104,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
break; break;
case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE: case OP_EQ: case OP_NE: case OP_LT: case OP_GT: case OP_LE: case OP_GE:
if (!is_comparable(l, r)) { if (!is_comparable(l, r)) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r)); "类型 '%s' 和 '%s' 无法比较", type_name(l), type_name(r));
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -112,7 +113,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
break; break;
case OP_AND: case OP_OR: case OP_AND: case OP_OR:
if (l != TYPE_BOOL || r != TYPE_BOOL) { if (l != TYPE_BOOL || r != TYPE_BOOL) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"逻辑运算需要布尔类型"); "逻辑运算需要布尔类型");
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
} else { } else {
@@ -127,7 +128,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
case AST_CALL_EXPR: { case AST_CALL_EXPR: {
Symbol* sym = scope_lookup(scope, node->as.call.name); Symbol* sym = scope_lookup(scope, node->as.call.name);
if (!sym || sym->kind != SYM_FUNCTION) { if (!sym || sym->kind != SYM_FUNCTION) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的函数 '%s'", node->as.call.name); "未定义的函数 '%s'", node->as.call.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
// 即使函数未定义,也要分析参数表达式(它们可能有更多错误) // 即使函数未定义,也要分析参数表达式(它们可能有更多错误)
@@ -137,7 +138,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
break; break;
} }
if (node->as.call.arg_count != sym->param_count) { if (node->as.call.arg_count != sym->param_count) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"函数 '%s' 需要 %zu 个参数,但提供了 %zu 个", "函数 '%s' 需要 %zu 个参数,但提供了 %zu 个",
node->as.call.name, sym->param_count, node->as.call.arg_count); node->as.call.name, sym->param_count, node->as.call.arg_count);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
@@ -151,13 +152,30 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
analyze_expr(node->as.call.args[i], scope, errors, a); analyze_expr(node->as.call.args[i], scope, errors, a);
TypeKind actual = node->as.call.args[i]->type.kind; TypeKind actual = node->as.call.args[i]->type.kind;
TypeKind expected = sym->param_types[i]; TypeKind expected = sym->param_types[i];
if (actual != TYPE_ERROR && actual != expected) { if (actual != TYPE_ERROR) {
error_add(errors, "<sema>", node->line, node->col, if (expected == TYPE_STRUCT) {
// 结构体参数:比较具体类型名
const char* actual_name = node->as.call.args[i]->type.struct_name;
const char* expected_name = sym->param_struct_names ? sym->param_struct_names[i] : NULL;
if (actual != TYPE_STRUCT || !actual_name || !expected_name ||
strcmp(actual_name, expected_name) != 0) {
error_add(errors, "<sema>", node->loc.line, node->loc.col,
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
i + 1,
expected_name ? expected_name : "struct",
actual_name ? actual_name : type_name(actual));
}
} else if (actual != expected) {
error_add(errors, "<sema>", node->loc.line, node->loc.col,
"参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'",
i + 1, type_name(expected), type_name(actual)); i + 1, type_name(expected), type_name(actual));
} }
} }
}
node->type.kind = sym->return_type; node->type.kind = sym->return_type;
if (sym->return_type == TYPE_STRUCT && sym->return_struct_type_name) {
node->type.struct_name = sym->return_struct_type_name;
}
break; break;
} }
@@ -169,7 +187,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
break; break;
} }
if (obj->type.kind != TYPE_STRUCT) { if (obj->type.kind != TYPE_STRUCT) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"类型 '%s' 不是结构体,不能访问字段 '%s'", "类型 '%s' 不是结构体,不能访问字段 '%s'",
type_name(obj->type.kind), node->as.field_access.field); type_name(obj->type.kind), node->as.field_access.field);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
@@ -178,21 +196,21 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
// 查找结构体定义 // 查找结构体定义
const char* struct_name = obj->type.struct_name; const char* struct_name = obj->type.struct_name;
if (!struct_name) { if (!struct_name) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"无法确定结构体类型"); "无法确定结构体类型");
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
} }
Symbol* struct_sym = scope_lookup_struct(scope, struct_name); Symbol* struct_sym = scope_lookup_struct(scope, struct_name);
if (!struct_sym) { if (!struct_sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的结构体 '%s'", struct_name); "未定义的结构体 '%s'", struct_name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
} }
int fi = scope_struct_field_index(struct_sym, node->as.field_access.field); int fi = scope_struct_field_index(struct_sym, node->as.field_access.field);
if (fi < 0) { if (fi < 0) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field); "结构体 '%s' 没有字段 '%s'", struct_name, node->as.field_access.field);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
@@ -211,13 +229,13 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
case AST_STRUCT_INIT: { case AST_STRUCT_INIT: {
Symbol* struct_sym = scope_lookup_struct(scope, node->as.struct_init.type_name); Symbol* struct_sym = scope_lookup_struct(scope, node->as.struct_init.type_name);
if (!struct_sym) { if (!struct_sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的结构体类型 '%s'", node->as.struct_init.type_name); "未定义的结构体类型 '%s'", node->as.struct_init.type_name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
} }
if (node->as.struct_init.field_count != struct_sym->struct_field_count) { if (node->as.struct_init.field_count != struct_sym->struct_field_count) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"结构体 '%s' 有 %zu 个字段,但提供了 %zu 个", "结构体 '%s' 有 %zu 个字段,但提供了 %zu 个",
node->as.struct_init.type_name, node->as.struct_init.type_name,
struct_sym->struct_field_count, struct_sym->struct_field_count,
@@ -233,7 +251,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
int fi = scope_struct_field_index(struct_sym, fname); int fi = scope_struct_field_index(struct_sym, fname);
if (fi < 0) { if (fi < 0) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"结构体 '%s' 没有字段 '%s'", "结构体 '%s' 没有字段 '%s'",
node->as.struct_init.type_name, fname); node->as.struct_init.type_name, fname);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
@@ -242,7 +260,7 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena*
TypeKind expected = struct_sym->struct_field_types[fi]; TypeKind expected = struct_sym->struct_field_types[fi];
TypeKind actual = fval->type.kind; TypeKind actual = fval->type.kind;
if (actual != TYPE_ERROR && actual != expected) { if (actual != TYPE_ERROR && actual != expected) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'", "字段 '%s' 类型不匹配: 期望 '%s',得到 '%s'",
fname, type_name(expected), type_name(actual)); fname, type_name(expected), type_name(actual));
} }
@@ -285,11 +303,15 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
for (size_t i = 0; i < node->as.program.fn_count; i++) { for (size_t i = 0; i < node->as.program.fn_count; i++) {
AstNode* fn = node->as.program.functions[i]; AstNode* fn = node->as.program.functions[i];
TypeKind* pts = (TypeKind*)arena_alloc_impl(a, fn->as.function.param_count * sizeof(TypeKind)); TypeKind* pts = (TypeKind*)arena_alloc_impl(a, fn->as.function.param_count * sizeof(TypeKind));
const char** pstruct_names = (const char**)arena_alloc_impl(a, fn->as.function.param_count * sizeof(const char*));
for (size_t j = 0; j < fn->as.function.param_count; j++) { for (size_t j = 0; j < fn->as.function.param_count; j++) {
pts[j] = fn->as.function.params[j]->as.parameter.type; pts[j] = fn->as.function.params[j]->as.parameter.type;
pstruct_names[j] = fn->as.function.params[j]->as.parameter.struct_type_name;
} }
scope_insert_function(scope, a, fn->as.function.name, scope_insert_function(scope, a, fn->as.function.name,
fn->as.function.return_type, pts, fn->as.function.return_type,
fn->as.function.return_struct_type_name,
pts, pstruct_names,
fn->as.function.param_count); fn->as.function.param_count);
} }
// 第三遍:分析每个函数体 // 第三遍:分析每个函数体
@@ -303,12 +325,18 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
// 注册参数 // 注册参数
for (size_t i = 0; i < node->as.function.param_count; i++) { for (size_t i = 0; i < node->as.function.param_count; i++) {
AstNode* p = node->as.function.params[i]; AstNode* p = node->as.function.params[i];
scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type);
if (sym && p->as.parameter.type == TYPE_STRUCT && p->as.parameter.struct_type_name) {
sym->struct_type_name = p->as.parameter.struct_type_name;
}
} }
TypeKind saved = current_return_type; TypeKind saved = current_return_type;
const char* saved_name = current_return_struct_name;
current_return_type = node->as.function.return_type; current_return_type = node->as.function.return_type;
current_return_struct_name = node->as.function.return_struct_type_name;
analyze_node(node->as.function.body, fn_scope, errors, a); analyze_node(node->as.function.body, fn_scope, errors, a);
current_return_type = saved; current_return_type = saved;
current_return_struct_name = saved_name;
break; break;
} }
@@ -330,7 +358,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
// struct 类型标注 // struct 类型标注
Symbol* st_sym = scope_lookup_struct(scope, annot_struct); Symbol* st_sym = scope_lookup_struct(scope, annot_struct);
if (!st_sym) { if (!st_sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的结构体类型 '%s'", annot_struct); "未定义的结构体类型 '%s'", annot_struct);
break; break;
} }
@@ -340,7 +368,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
var_type = node->as.let_stmt.annot_type; var_type = node->as.let_stmt.annot_type;
} }
if (inferred != TYPE_ERROR && inferred != var_type) { if (inferred != TYPE_ERROR && inferred != var_type) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"变量 '%s' 类型标注为 '%s',但初始化表达式类型为 '%s'", "变量 '%s' 类型标注为 '%s',但初始化表达式类型为 '%s'",
node->as.let_stmt.name, node->as.let_stmt.name,
annot_struct ? annot_struct : type_name(var_type), annot_struct ? annot_struct : type_name(var_type),
@@ -349,7 +377,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
} else { } else {
// 类型推断 // 类型推断
if (inferred == TYPE_ERROR || inferred == TYPE_VOID) { if (inferred == TYPE_ERROR || inferred == TYPE_VOID) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"无法从表达式推断变量 '%s' 的类型", node->as.let_stmt.name); "无法从表达式推断变量 '%s' 的类型", node->as.let_stmt.name);
break; break;
} }
@@ -363,7 +391,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
node->type.struct_name = var_struct_name; node->type.struct_name = var_struct_name;
Symbol* sym = scope_insert(scope, a, node->as.let_stmt.name, SYM_VARIABLE, var_type); Symbol* sym = scope_insert(scope, a, node->as.let_stmt.name, SYM_VARIABLE, var_type);
if (!sym) { if (!sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"变量 '%s' 重复定义", node->as.let_stmt.name); "变量 '%s' 重复定义", node->as.let_stmt.name);
} else { } else {
sym->is_mut = node->as.let_stmt.is_mut; sym->is_mut = node->as.let_stmt.is_mut;
@@ -378,19 +406,19 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
case AST_ASSIGN_STMT: { case AST_ASSIGN_STMT: {
Symbol* sym = scope_lookup(scope, node->as.assign_stmt.name); Symbol* sym = scope_lookup(scope, node->as.assign_stmt.name);
if (!sym) { if (!sym) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"未定义的变量 '%s'", node->as.assign_stmt.name); "未定义的变量 '%s'", node->as.assign_stmt.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
} }
if (sym->kind != SYM_VARIABLE) { if (sym->kind != SYM_VARIABLE) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"'%s' 不是变量,不能赋值", node->as.assign_stmt.name); "'%s' 不是变量,不能赋值", node->as.assign_stmt.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
break; break;
} }
if (!sym->is_mut) { if (!sym->is_mut) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"不能对不可变变量 '%s' 赋值(需用 let mut 声明)", "不能对不可变变量 '%s' 赋值(需用 let mut 声明)",
node->as.assign_stmt.name); node->as.assign_stmt.name);
node->type.kind = TYPE_ERROR; node->type.kind = TYPE_ERROR;
@@ -399,7 +427,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
analyze_expr(node->as.assign_stmt.value, scope, errors, a); analyze_expr(node->as.assign_stmt.value, scope, errors, a);
TypeKind value_ty = node->as.assign_stmt.value->type.kind; TypeKind value_ty = node->as.assign_stmt.value->type.kind;
if (value_ty != TYPE_ERROR && value_ty != sym->type) { if (value_ty != TYPE_ERROR && value_ty != sym->type) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"赋值类型不匹配: 变量 '%s' 类型为 '%s',但表达式类型为 '%s'", "赋值类型不匹配: 变量 '%s' 类型为 '%s',但表达式类型为 '%s'",
node->as.assign_stmt.name, type_name(sym->type), type_name(value_ty)); node->as.assign_stmt.name, type_name(sym->type), type_name(value_ty));
} }
@@ -411,7 +439,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
analyze_expr(node->as.if_stmt.cond, scope, errors, a); analyze_expr(node->as.if_stmt.cond, scope, errors, a);
if (node->as.if_stmt.cond->type.kind != TYPE_BOOL && if (node->as.if_stmt.cond->type.kind != TYPE_BOOL &&
node->as.if_stmt.cond->type.kind != TYPE_ERROR) { node->as.if_stmt.cond->type.kind != TYPE_ERROR) {
error_add(errors, "<sema>", node->line, node->col, "if 条件必须是布尔类型"); error_add(errors, "<sema>", node->loc.line, node->loc.col, "if 条件必须是布尔类型");
} }
analyze_node(node->as.if_stmt.then_block, scope, errors, a); analyze_node(node->as.if_stmt.then_block, scope, errors, a);
if (node->as.if_stmt.else_block) { if (node->as.if_stmt.else_block) {
@@ -423,7 +451,7 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
analyze_expr(node->as.while_stmt.cond, scope, errors, a); analyze_expr(node->as.while_stmt.cond, scope, errors, a);
if (node->as.while_stmt.cond->type.kind != TYPE_BOOL && if (node->as.while_stmt.cond->type.kind != TYPE_BOOL &&
node->as.while_stmt.cond->type.kind != TYPE_ERROR) { node->as.while_stmt.cond->type.kind != TYPE_ERROR) {
error_add(errors, "<sema>", node->line, node->col, "while 条件必须是布尔类型"); error_add(errors, "<sema>", node->loc.line, node->loc.col, "while 条件必须是布尔类型");
} }
analyze_node(node->as.while_stmt.body, scope, errors, a); analyze_node(node->as.while_stmt.body, scope, errors, a);
break; break;
@@ -434,14 +462,28 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena*
node->type.kind = node->as.return_stmt.expr->type.kind; node->type.kind = node->as.return_stmt.expr->type.kind;
TypeKind actual = node->as.return_stmt.expr->type.kind; TypeKind actual = node->as.return_stmt.expr->type.kind;
TypeKind expected = current_return_type; TypeKind expected = current_return_type;
if (actual != TYPE_ERROR && expected != TYPE_VOID && actual != expected) { if (actual != TYPE_ERROR && expected != TYPE_VOID) {
error_add(errors, "<sema>", node->line, node->col, if (expected == TYPE_STRUCT) {
// 结构体返回类型:比较具体类型名
const char* actual_name = node->as.return_stmt.expr->type.struct_name;
const char* expected_name = current_return_struct_name;
if (actual != TYPE_STRUCT || !actual_name || !expected_name ||
strcmp(actual_name, expected_name) != 0) {
error_add(errors, "<sema>", node->loc.line, node->loc.col,
"返回类型不匹配: 期望 '%s',得到 '%s'",
expected_name ? expected_name : "struct",
actual_name ? actual_name : type_name(actual));
}
} else if (actual != expected) {
error_add(errors, "<sema>", node->loc.line, node->loc.col,
"返回类型不匹配: 期望 '%s',得到 '%s'", "返回类型不匹配: 期望 '%s',得到 '%s'",
type_name(expected), type_name(actual)); type_name(expected), type_name(actual));
} }
}
} else if (current_return_type != TYPE_VOID) { } else if (current_return_type != TYPE_VOID) {
error_add(errors, "<sema>", node->line, node->col, error_add(errors, "<sema>", node->loc.line, node->loc.col,
"函数应返回值类型 '%s'", type_name(current_return_type)); "函数应返回值类型 '%s'",
current_return_struct_name ? current_return_struct_name : type_name(current_return_type));
} }
break; break;
@@ -460,13 +502,13 @@ void sema_analyze(AstNode* ast, ErrorList* errors, Arena* arena) {
// 注册内置函数 // 注册内置函数
TypeKind params_i64[] = {TYPE_I64}; TypeKind params_i64[] = {TYPE_I64};
scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, params_i64, 1); scope_insert_function(global_scope, arena, "print_i64", TYPE_VOID, NULL, params_i64, NULL, 1);
TypeKind params_f64[] = {TYPE_F64}; TypeKind params_f64[] = {TYPE_F64};
scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, params_f64, 1); scope_insert_function(global_scope, arena, "print_f64", TYPE_VOID, NULL, params_f64, NULL, 1);
TypeKind params_bool[] = {TYPE_BOOL}; TypeKind params_bool[] = {TYPE_BOOL};
scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, params_bool, 1); scope_insert_function(global_scope, arena, "print_bool", TYPE_VOID, NULL, params_bool, NULL, 1);
TypeKind params_str[] = {TYPE_STR}; TypeKind params_str[] = {TYPE_STR};
scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, params_str, 1); scope_insert_function(global_scope, arena, "print_str", TYPE_VOID, NULL, params_str, NULL, 1);
analyze_node(ast, global_scope, errors, arena); analyze_node(ast, global_scope, errors, arena);
} }
+7 -2
View File
@@ -41,7 +41,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name,
} }
Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name,
TypeKind ret, TypeKind* pt, size_t pc) { TypeKind ret, const char* ret_struct_name,
TypeKind* pt, const char** pstruct_names, size_t pc) {
if (scope->head) { if (scope->head) {
for (Symbol* sym = scope->head; sym; sym = sym->next) { for (Symbol* sym = scope->head; sym; sym = sym->next) {
if (strcmp(sym->name, name) == 0) return NULL; if (strcmp(sym->name, name) == 0) return NULL;
@@ -50,7 +51,11 @@ Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name,
Symbol* sym = (Symbol*)arena_alloc_impl(alloc, sizeof(Symbol)); Symbol* sym = (Symbol*)arena_alloc_impl(alloc, sizeof(Symbol));
if (!sym) return NULL; if (!sym) return NULL;
sym->name = name; sym->kind = SYM_FUNCTION; sym->type = TYPE_VOID; sym->name = name; sym->kind = SYM_FUNCTION; sym->type = TYPE_VOID;
sym->return_type = ret; sym->param_types = pt; sym->param_count = pc; sym->return_type = ret;
sym->return_struct_type_name = ret_struct_name;
sym->param_types = pt;
sym->param_struct_names = pstruct_names;
sym->param_count = pc;
sym->struct_field_names = NULL; sym->struct_field_names = NULL;
sym->struct_field_types = NULL; sym->struct_field_types = NULL;
sym->struct_field_count = 0; sym->struct_field_count = 0;
+4 -1
View File
@@ -13,7 +13,9 @@ typedef struct Symbol {
bool is_mut; // 变量是否可变(可被赋值) bool is_mut; // 变量是否可变(可被赋值)
// 函数特有 // 函数特有
TypeKind return_type; TypeKind return_type;
const char* return_struct_type_name; // 返回类型为 struct 时的类型名
TypeKind* param_types; TypeKind* param_types;
const char** param_struct_names; // 参数为 struct 时的类型名
size_t param_count; size_t param_count;
// 结构体特有(SYM_STRUCT // 结构体特有(SYM_STRUCT
const char** struct_field_names; const char** struct_field_names;
@@ -43,7 +45,8 @@ Symbol* scope_insert(Scope* scope, void* alloc, const char* name,
// 插入函数符号 // 插入函数符号
Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name, Symbol* scope_insert_function(Scope* scope, void* alloc, const char* name,
TypeKind ret, TypeKind* pt, size_t pc); TypeKind ret, const char* ret_struct_name,
TypeKind* pt, const char** pstruct_names, size_t pc);
// 插入结构体符号 // 插入结构体符号
Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name, Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name,
+19
View File
@@ -0,0 +1,19 @@
struct Point {
x: i64,
y: i64,
}
fn make_point(x: i64, y: i64) -> Point {
return Point { x: x, y: y };
}
fn print_point(p: Point) -> void {
print_i64(p.x);
print_i64(p.y);
}
fn main() -> i64 {
let p: Point = make_point(3, 4);
print_point(p);
return 0;
}
+142 -27
View File
@@ -8,12 +8,12 @@ void test_codegen_simple_function() {
Arena a = arena_create(1); Arena a = arena_create(1);
// 构造 AST: fn main() -> i64 { return 42; } // 构造 AST: fn main() -> i64 { return 42; }
AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 42, 1, 1), 1, 1); AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 42, loc_at(1, 1)), loc_at(1, 1));
AstNode* stmts[] = { ret }; AstNode* stmts[] = { ret };
AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1));
AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1));
AstNode* fns[] = { fn }; AstNode* fns[] = { fn };
AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1));
const char* err = NULL; const char* err = NULL;
LLVMContextRef ctx = NULL; LLVMContextRef ctx = NULL;
@@ -35,19 +35,19 @@ void test_codegen_if_else() {
Arena a = arena_create(1); Arena a = arena_create(1);
// fn main() -> i64 { if true { return 1; } else { return 0; } } // fn main() -> i64 { if true { return 1; } else { return 0; } }
AstNode* then_ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, 1, 1), 1, 1); AstNode* then_ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, loc_at(1, 1)), loc_at(1, 1));
AstNode* then_stmts[] = { then_ret }; AstNode* then_stmts[] = { then_ret };
AstNode* then_block = ast_make_block(&a, then_stmts, 1, 1, 1); AstNode* then_block = ast_make_block(&a, then_stmts, 1, loc_at(1, 1));
AstNode* else_ret = ast_make_return(&a, ast_make_literal_i64(&a, 0, 1, 1), 1, 1); AstNode* else_ret = ast_make_return(&a, ast_make_literal_i64(&a, 0, loc_at(1, 1)), loc_at(1, 1));
AstNode* else_stmts[] = { else_ret }; AstNode* else_stmts[] = { else_ret };
AstNode* else_block = ast_make_block(&a, else_stmts, 1, 1, 1); AstNode* else_block = ast_make_block(&a, else_stmts, 1, loc_at(1, 1));
AstNode* if_stmt = ast_make_if(&a, AstNode* if_stmt = ast_make_if(&a,
ast_make_literal_bool(&a, true, 1, 1), then_block, else_block, 1, 1); ast_make_literal_bool(&a, true, loc_at(1, 1)), then_block, else_block, loc_at(1, 1));
AstNode* stmts[] = { if_stmt }; AstNode* stmts[] = { if_stmt };
AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1));
AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1));
AstNode* fns[] = { fn }; AstNode* fns[] = { fn };
AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1));
const char* err = NULL; const char* err = NULL;
LLVMContextRef ctx2 = NULL; LLVMContextRef ctx2 = NULL;
@@ -68,17 +68,17 @@ void test_codegen_binary_ops() {
// fn main() -> i64 { return 1 + 2 * 3; } // fn main() -> i64 { return 1 + 2 * 3; }
AstNode* expr = ast_make_binary(&a, OP_ADD, AstNode* expr = ast_make_binary(&a, OP_ADD,
ast_make_literal_i64(&a, 1, 1, 1), ast_make_literal_i64(&a, 1, loc_at(1, 1)),
ast_make_binary(&a, OP_MUL, ast_make_binary(&a, OP_MUL,
ast_make_literal_i64(&a, 2, 1, 1), ast_make_literal_i64(&a, 2, loc_at(1, 1)),
ast_make_literal_i64(&a, 3, 1, 1), 1, 1), ast_make_literal_i64(&a, 3, loc_at(1, 1)), loc_at(1, 1)),
1, 1); loc_at(1, 1));
AstNode* ret = ast_make_return(&a, expr, 1, 1); AstNode* ret = ast_make_return(&a, expr, loc_at(1, 1));
AstNode* stmts[] = { ret }; AstNode* stmts[] = { ret };
AstNode* body = ast_make_block(&a, stmts, 1, 1, 1); AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1));
AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, body, 1, 1); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1));
AstNode* fns[] = { fn }; AstNode* fns[] = { fn };
AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1));
const char* err = NULL; const char* err = NULL;
LLVMContextRef ctx3 = NULL; LLVMContextRef ctx3 = NULL;
@@ -98,17 +98,17 @@ void test_codegen_while_loop() {
Arena a = arena_create(1); Arena a = arena_create(1);
// fn main() -> i64 { while true { return 0; } return 1; } // fn main() -> i64 { while true { return 0; } return 1; }
AstNode* while_body_stmts[] = { AstNode* while_body_stmts[] = {
ast_make_return(&a, ast_make_literal_i64(&a, 0, 1, 1), 1, 1) ast_make_return(&a, ast_make_literal_i64(&a, 0, loc_at(1, 1)), loc_at(1, 1))
}; };
AstNode* while_body = ast_make_block(&a, while_body_stmts, 1, 1, 1); AstNode* while_body = ast_make_block(&a, while_body_stmts, 1, loc_at(1, 1));
AstNode* while_stmt = ast_make_while(&a, AstNode* while_stmt = ast_make_while(&a,
ast_make_literal_bool(&a, true, 1, 1), while_body, 1, 1); ast_make_literal_bool(&a, true, loc_at(1, 1)), while_body, loc_at(1, 1));
AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, 1, 1), 1, 1); AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 1, loc_at(1, 1)), loc_at(1, 1));
AstNode* stmts[] = { while_stmt, ret }; AstNode* stmts[] = { while_stmt, ret };
AstNode* fn_body = ast_make_block(&a, stmts, 2, 1, 1); AstNode* fn_body = ast_make_block(&a, stmts, 2, loc_at(1, 1));
AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, fn_body, 1, 1); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, loc_at(1, 1));
AstNode* fns[] = { fn }; AstNode* fns[] = { fn };
AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, 1, 1); AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, loc_at(1, 1));
const char* err = NULL; const char* err = NULL;
LLVMContextRef ctx4 = NULL; LLVMContextRef ctx4 = NULL;
@@ -121,10 +121,125 @@ void test_codegen_while_loop() {
arena_destroy(&a); arena_destroy(&a);
} }
/* === struct IR generation tests === */
/*
 * Codegen test: a program containing a struct declaration and a struct-typed
 * local must produce an LLVM module that passes LLVMVerifyModule.
 *
 * The AST is assembled by hand (no lexer/parser/sema run) for:
 *     struct Point { x: i64, y: i64 }
 *     fn main() -> i64 { let p = Point { x: 1, y: 2 }; return p.x; }
 */
void test_codegen_struct_decl() {
    Arena a = arena_create(1);

    /* Build AST: struct Point { x: i64, y: i64 } */
    AstNode* fields[2];
    fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1));
    fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1));
    AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1));
    AstNode* structs[] = { struct_decl };

    /* Build fn main() -> i64 { let p = Point { x: 1, y: 2 }; return p.x; } */
    const char* fnames[] = {"x", "y"};
    AstNode* fvals[] = {
        ast_make_literal_i64(&a, 1, loc_at(1, 1)),
        ast_make_literal_i64(&a, 2, loc_at(1, 1))
    };
    AstNode* init = ast_make_struct_init(&a, "Point", fnames, fvals, 2, loc_at(1, 1));
    /* Set the type by hand: sema is bypassed here, and codegen needs to read
       the node's type field. */
    init->type.kind = TYPE_STRUCT;
    init->type.struct_name = "Point";
    AstNode* let_stmt = ast_make_let(&a, "p", TYPE_UNKNOWN, false, false,
                                     init, NULL, loc_at(1, 1));

    /* return p.x; */
    AstNode* p_ident = ast_make_ident(&a, "p", loc_at(1, 1));
    p_ident->type.kind = TYPE_STRUCT;
    p_ident->type.struct_name = "Point";
    AstNode* field_x = ast_make_field_access(&a, p_ident, "x", loc_at(1, 1));
    field_x->as.field_access.field_index = 0; /* x is field number 0 */
    field_x->type.kind = TYPE_I64;
    AstNode* ret = ast_make_return(&a, field_x, loc_at(1, 1));

    AstNode* stmts[] = { let_stmt, ret };
    AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1));
    AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1));
    AstNode* fns[] = { fn };
    AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, loc_at(1, 1));

    const char* err = NULL;
    LLVMContextRef ctx = NULL;
    LLVMModuleRef mod = codegen_module(prog, &a, "test_struct_decl", &err, &ctx);
    ASSERT(mod != NULL);
    ASSERT(err == NULL);

    char* verify_err = NULL;
    int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err);
    /* The verifier's message buffer is owned by the caller and must be
       released with LLVMDisposeMessage. Do it before the assertion so the
       buffer is freed even if ASSERT leaves the function early. */
    LLVMDisposeMessage(verify_err);
    ASSERT(!failed);

    LLVMDisposeModule(mod);
    LLVMContextDispose(ctx);
    arena_destroy(&a);
}
/*
 * Codegen test: reading a non-first struct field (index 1) must produce an
 * LLVM module that passes LLVMVerifyModule.
 *
 * The AST is assembled by hand (no lexer/parser/sema run) for:
 *     struct Point { x: i64, y: i64 }
 *     fn main() -> i64 { let p = Point { x: 5, y: 10 }; return p.y; }
 */
void test_codegen_struct_field_access() {
    Arena a = arena_create(1);

    /* Build AST: struct Point { x: i64, y: i64 } */
    AstNode* fields[2];
    fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1));
    fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1));
    AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1));
    AstNode* structs[] = { struct_decl };

    /* Build fn main() -> i64 { let p = Point { x: 5, y: 10 }; return p.y; } */
    const char* fnames[] = {"x", "y"};
    AstNode* fvals[] = {
        ast_make_literal_i64(&a, 5, loc_at(1, 1)),
        ast_make_literal_i64(&a, 10, loc_at(1, 1))
    };
    AstNode* init = ast_make_struct_init(&a, "Point", fnames, fvals, 2, loc_at(1, 1));
    /* Types are filled in by hand because no sema pass runs in this test. */
    init->type.kind = TYPE_STRUCT;
    init->type.struct_name = "Point";
    AstNode* let_stmt = ast_make_let(&a, "p", TYPE_UNKNOWN, false, false,
                                     init, NULL, loc_at(1, 1));

    /* return p.y; */
    AstNode* p_ident = ast_make_ident(&a, "p", loc_at(1, 1));
    p_ident->type.kind = TYPE_STRUCT;
    p_ident->type.struct_name = "Point";
    AstNode* field_y = ast_make_field_access(&a, p_ident, "y", loc_at(1, 1));
    field_y->as.field_access.field_index = 1; /* y is field number 1 */
    field_y->type.kind = TYPE_I64;
    AstNode* ret = ast_make_return(&a, field_y, loc_at(1, 1));

    AstNode* stmts[] = { let_stmt, ret };
    AstNode* body = ast_make_block(&a, stmts, 2, loc_at(1, 1));
    AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1));
    AstNode* fns[] = { fn };
    AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, loc_at(1, 1));

    const char* err = NULL;
    LLVMContextRef ctx = NULL;
    LLVMModuleRef mod = codegen_module(prog, &a, "test_struct_field", &err, &ctx);
    ASSERT(mod != NULL);
    ASSERT(err == NULL);

    char* verify_err = NULL;
    int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err);
    /* The verifier's message buffer is owned by the caller and must be
       released with LLVMDisposeMessage. Do it before the assertion so the
       buffer is freed even if ASSERT leaves the function early. */
    LLVMDisposeMessage(verify_err);
    ASSERT(!failed);

    LLVMDisposeModule(mod);
    LLVMContextDispose(ctx);
    arena_destroy(&a);
}
int main(void) { int main(void) {
TEST_RUN(test_codegen_simple_function); TEST_RUN(test_codegen_simple_function);
TEST_RUN(test_codegen_if_else); TEST_RUN(test_codegen_if_else);
TEST_RUN(test_codegen_binary_ops); TEST_RUN(test_codegen_binary_ops);
TEST_RUN(test_codegen_while_loop); TEST_RUN(test_codegen_while_loop);
TEST_RUN(test_codegen_struct_decl);
TEST_RUN(test_codegen_struct_field_access);
return test_summary(); return test_summary();
} }
+74
View File
@@ -115,6 +115,76 @@ void test_str_concat_type_ok() {
arena_destroy(&a); arena_destroy(&a);
} }
/* === struct type-checking tests === */
/*
 * Sema test: a struct initializer that gives a bool literal to a field
 * declared as i64 must make sema_analyze record at least one error.
 */
void test_struct_field_type_mismatch() {
    const char* src =
        "struct Point { x: i64, y: i64 } fn main() { let p: Point = Point { x: 10, y: true }; return; }";
    Arena arena = arena_create(1);

    /* Lex */
    size_t token_count;
    ErrorInfo lex_error = {0};
    Token* toks = lex(&arena, src, "test", &token_count, &lex_error);
    ASSERT(toks != NULL);

    /* Parse */
    ErrorInfo parse_error = {0};
    AstNode* ast = parse(&arena, toks, token_count, "test", &parse_error);
    ASSERT(ast != NULL);

    /* Analyze */
    ErrorList errors;
    error_init(&errors);
    sema_analyze(ast, &errors, &arena);
    ASSERT(errors.count > 0); // field y type mismatch: true is bool, not i64

    arena_destroy(&arena);
}
/*
 * Sema test: using a struct name that was never declared, both as a variable
 * type and in an initializer, must make sema_analyze record at least one error.
 */
void test_struct_undefined() {
    const char* src =
        "fn main() { let p: Unknown = Unknown { x: 1 }; return; }";
    Arena arena = arena_create(1);

    /* Lex */
    size_t token_count;
    ErrorInfo lex_error = {0};
    Token* toks = lex(&arena, src, "test", &token_count, &lex_error);
    ASSERT(toks != NULL);

    /* Parse */
    ErrorInfo parse_error = {0};
    AstNode* ast = parse(&arena, toks, token_count, "test", &parse_error);
    ASSERT(ast != NULL);

    /* Analyze */
    ErrorList errors;
    error_init(&errors);
    sema_analyze(ast, &errors, &arena);
    ASSERT(errors.count > 0); // Unknown is not a declared struct

    arena_destroy(&arena);
}
/*
 * Sema test: a struct initializer that omits one of the declared fields must
 * make sema_analyze record at least one error.
 */
void test_struct_field_count_mismatch() {
    const char* src =
        "struct Point { x: i64, y: i64 } fn main() { let p: Point = Point { x: 10 }; return; }";
    Arena arena = arena_create(1);

    /* Lex */
    size_t token_count;
    ErrorInfo lex_error = {0};
    Token* toks = lex(&arena, src, "test", &token_count, &lex_error);
    ASSERT(toks != NULL);

    /* Parse */
    ErrorInfo parse_error = {0};
    AstNode* ast = parse(&arena, toks, token_count, "test", &parse_error);
    ASSERT(ast != NULL);

    /* Analyze */
    ErrorList errors;
    error_init(&errors);
    sema_analyze(ast, &errors, &arena);
    ASSERT(errors.count > 0); // field 'y' is missing

    arena_destroy(&arena);
}
/*
 * Sema test: a struct whose fields are themselves structs, initialized with
 * nested initializers and read through a chained field access (r.tl.x), must
 * pass sema_analyze with zero errors.
 */
void test_struct_nested_type_ok() {
    const char* src =
        "struct Point { x: i64, y: i64 } struct Rect { tl: Point, br: Point } fn main() { let r: Rect = Rect { tl: Point { x: 0, y: 0 }, br: Point { x: 1, y: 1 } }; print_i64(r.tl.x); return; }";
    Arena arena = arena_create(1);

    /* Lex */
    size_t token_count;
    ErrorInfo lex_error = {0};
    Token* toks = lex(&arena, src, "test", &token_count, &lex_error);
    ASSERT(toks != NULL);

    /* Parse */
    ErrorInfo parse_error = {0};
    AstNode* ast = parse(&arena, toks, token_count, "test", &parse_error);
    ASSERT(ast != NULL);

    /* Analyze */
    ErrorList errors;
    error_init(&errors);
    sema_analyze(ast, &errors, &arena);
    ASSERT(errors.count == 0); // nested struct type checking should pass

    arena_destroy(&arena);
}
int main(void) { int main(void) {
TEST_RUN(test_type_error); TEST_RUN(test_type_error);
TEST_RUN(test_undefined_var); TEST_RUN(test_undefined_var);
@@ -123,5 +193,9 @@ int main(void) {
TEST_RUN(test_assign_immutable_error); TEST_RUN(test_assign_immutable_error);
TEST_RUN(test_str_type_ok); TEST_RUN(test_str_type_ok);
TEST_RUN(test_str_concat_type_ok); TEST_RUN(test_str_concat_type_ok);
TEST_RUN(test_struct_field_type_mismatch);
TEST_RUN(test_struct_undefined);
TEST_RUN(test_struct_field_count_mismatch);
TEST_RUN(test_struct_nested_type_ok);
return test_summary(); return test_summary();
} }