diff --git a/src/ast/ast.c b/src/ast/ast.c index ab79c1b..d7cdd89 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -48,10 +48,11 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size } AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, - const char* struct_type_name, SourceLoc loc) { + const char* struct_type_name, bool is_out, SourceLoc loc) { NEW(alloc, AST_PARAMETER); n->as.parameter.name = name; n->as.parameter.type = type; n->as.parameter.struct_type_name = struct_type_name; + n->as.parameter.is_out = is_out; return n; } diff --git a/src/ast/ast.h b/src/ast/ast.h index 435f5cc..543a1b7 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -74,7 +74,7 @@ struct AstNode { const char** type_params; size_t type_param_count; TypeKind* multi_ret_types; const char** multi_ret_snames; size_t multi_ret_count; } function; // AST_PARAMETER (也用作结构体字段: name + type) - struct { const char* name; TypeKind type; const char* struct_type_name; } parameter; + struct { const char* name; TypeKind type; const char* struct_type_name; bool is_out; } parameter; // AST_BLOCK struct { struct AstNode** stmts; size_t stmt_count; } block; // AST_LET_STMT @@ -150,7 +150,7 @@ AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size TypeKind ret, const char* ret_struct_name, AstNode* body, bool is_pub, const char** type_params, size_t tp_count, SourceLoc loc); -AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc); +AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, bool is_out, SourceLoc loc); AstNode* ast_make_block(void* alloc, AstNode** stmts, size_t count, SourceLoc loc); AstNode* ast_make_let(void* alloc, const char* name, TypeKind annot_type, bool has_type_annot, bool is_mut, AstNode* init, const char* struct_type_name, diff --git a/src/codegen/cg_expr.c b/src/codegen/cg_expr.c index e19c31f..83559cf 100644 --- a/src/codegen/cg_expr.c +++ b/src/codegen/cg_expr.c @@ -239,8 +239,34 @@ static LLVMValueRef cg_call_impl(CgCtx* ctx, AstNode* node) { if (!fn) return NULL; LLVMValueRef args[16]; if (node->as.call.arg_count > 16) { ctx->error = "函数参数过多(最多16)"; return NULL; } + FnEntry* fn_entry = find_fn_entry(ctx, node->as.call.name); for (size_t i = 0; i < node->as.call.arg_count; i++) { - args[i] = codegen_expr(ctx, node->as.call.args[i]); + bool is_out = fn_entry && fn_entry->out_params + && i < fn_entry->pc && fn_entry->out_params[i]; + if (is_out) { + // out 参数传 alloca 地址而非加载后的值 + AstNode* arg = node->as.call.args[i]; + if (arg->kind == AST_IDENT_EXPR) { + args[i] = find_var(ctx, arg->as.ident.name); + } else if (arg->kind == AST_INDEX_EXPR) { + // arr[i]: 生成 GEP 得到元素指针 + LLVMValueRef arr_ptr = find_var(ctx, arg->as.index_expr.array->as.ident.name); + LLVMValueRef idx_val = codegen_expr(ctx, arg->as.index_expr.index); + if (!arr_ptr || !idx_val) return NULL; + LLVMValueRef indices[] = { + LLVMConstInt(LLVMInt32TypeInContext(ctx->context), 0, false), + LLVMBuildIntCast2(ctx->builder, idx_val, + LLVMInt32TypeInContext(ctx->context), false, "idx32") + }; + args[i] = LLVMBuildGEP2(ctx->builder, + LLVMGetElementType(LLVMTypeOf(arr_ptr)), + arr_ptr, indices, 2, "out_gep"); + } else { + args[i] = codegen_expr(ctx, node->as.call.args[i]); + } + } else { + args[i] = codegen_expr(ctx, node->as.call.args[i]); + } if (!args[i]) return NULL; } LLVMTypeRef fn_ty = LLVMGlobalGetValueType(fn); diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 7db4f4e..15ff79d 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -23,17 +23,24 @@ LLVMValueRef find_fn(CgCtx* ctx, const char* name) { return NULL; } -void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn) { +void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc) { FnEntry* e = arena_alloc(ctx->arena, sizeof(*e)); if (!e) return; e->name = name; e->fn = fn; e->ret = TYPE_VOID; e->params = NULL; - e->pc = 0; + e->out_params = out_params; + e->pc = pc; e->next = ctx->fn_table; ctx->fn_table = e; } +FnEntry* find_fn_entry(CgCtx* ctx, const char* name) { + for (FnEntry* e = ctx->fn_table; e; e = e->next) + if (strcmp(e->name, name) == 0) return e; + return NULL; +} + // === 结构体类型表 === void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc) { StructTypeEntry* e = arena_alloc(ctx->arena, sizeof(*e)); @@ -379,14 +386,23 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, AstNode* fn = ast->as.program.functions[i]; LLVMTypeRef* ptypes = arena_alloc(ctx.arena, fn->as.function.param_count * sizeof(LLVMTypeRef)); + bool* out_params = NULL; + if (fn->as.function.param_count > 0) { + out_params = arena_alloc(ctx.arena, + fn->as.function.param_count * sizeof(bool)); + } for (size_t j = 0; j < fn->as.function.param_count; j++) { AstNode* param = fn->as.function.params[j]; + bool is_out = param->as.parameter.is_out; + if (out_params) out_params[j] = is_out; + LLVMTypeRef inner_ty; if (param->as.parameter.type == TYPE_STRUCT && param->as.parameter.struct_type_name) { - ptypes[j] = find_struct_type(&ctx, param->as.parameter.struct_type_name); + inner_ty = find_struct_type(&ctx, param->as.parameter.struct_type_name); } else { - ptypes[j] = to_llvm_type(&ctx, param->as.parameter.type); + inner_ty = to_llvm_type(&ctx, param->as.parameter.type); } + ptypes[j] = is_out ? LLVMPointerType(inner_ty, 0) : inner_ty; } LLVMTypeRef ret_ty; if (fn->as.function.return_type == TYPE_STRUCT && @@ -398,7 +414,8 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, LLVMTypeRef fty = LLVMFunctionType(ret_ty, ptypes, (unsigned)fn->as.function.param_count, false); LLVMValueRef lfn = LLVMAddFunction(ctx.module, fn->as.function.name, fty); - add_fn(&ctx, fn->as.function.name, lfn); + add_fn(&ctx, fn->as.function.name, lfn, out_params, + fn->as.function.param_count); } // 第二遍:生成函数体 @@ -422,10 +439,15 @@ LLVMModuleRef codegen_module(AstNode* ast, Arena* codegen_arena, } else { param_ty = to_llvm_type(&ctx, pnode->as.parameter.type); } - LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder, - param_ty, pnode->as.parameter.name); - LLVMBuildStore(ctx.builder, param, alloca); - add_var(&ctx, pnode->as.parameter.name, alloca, param_ty); + if (pnode->as.parameter.is_out) { + // out 参数: param 已是指向调用者变量的指针, 直接用作 alloca + add_var(&ctx, pnode->as.parameter.name, param, param_ty); + } else { + LLVMValueRef alloca = LLVMBuildAlloca(ctx.builder, + param_ty, pnode->as.parameter.name); + LLVMBuildStore(ctx.builder, param, alloca); + add_var(&ctx, pnode->as.parameter.name, alloca, param_ty); + } } ctx.defer_count = 0; diff --git a/src/codegen/codegen_internal.h b/src/codegen/codegen_internal.h index 0667119..98cbfee 100644 --- a/src/codegen/codegen_internal.h +++ b/src/codegen/codegen_internal.h @@ -28,6 +28,7 @@ typedef struct FnEntry { LLVMValueRef fn; TypeKind ret; TypeKind* params; + bool* out_params; // 哪些参数是 out (引用传递) size_t pc; struct FnEntry* next; } FnEntry; @@ -71,7 +72,8 @@ LLVMValueRef coerce_int(CgCtx* ctx, LLVMValueRef val, LLVMTypeRef from_ty, LLVMT LLVMValueRef find_var(CgCtx* ctx, const char* name); void add_var(CgCtx* ctx, const char* name, LLVMValueRef alloca, LLVMTypeRef alloca_type); LLVMValueRef find_fn(CgCtx* ctx, const char* name); -void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn); +FnEntry* find_fn_entry(CgCtx* ctx, const char* name); +void add_fn(CgCtx* ctx, const char* name, LLVMValueRef fn, bool* out_params, size_t pc); void add_struct_type(CgCtx* ctx, const char* name, LLVMTypeRef ty, size_t fc); LLVMTypeRef find_struct_type(CgCtx* ctx, const char* name); diff --git a/src/lexer/lexer.c b/src/lexer/lexer.c index 203b304..689a779 100644 --- a/src/lexer/lexer.c +++ b/src/lexer/lexer.c @@ -68,7 +68,7 @@ static TokenKind check_keyword(const Token* tok) { KW("struct", TOK_STRUCT); KW("type", TOK_TYPE); KW("enum", TOK_ENUM); KW("extend", TOK_EXTEND); KW("defer", TOK_DEFER); KW("match", TOK_MATCH); KW("pub", TOK_PUB); KW("mod", TOK_MOD); KW("use", TOK_USE); - KW("trait", TOK_TRAIT); KW("Self", TOK_SELF); + KW("trait", TOK_TRAIT); KW("Self", TOK_SELF); KW("out", TOK_OUT); KW("_", TOK_UNDERSCORE); KW("true", TOK_TRUE); KW("false", TOK_FALSE); #undef KW diff --git a/src/lexer/token.c b/src/lexer/token.c index 6d34c2e..c7421a2 100644 --- a/src/lexer/token.c +++ b/src/lexer/token.c @@ -7,7 +7,7 @@ static const char* NAMES[] = { [TOK_FN] = "fn", [TOK_LET] = "let", [TOK_VAR] = "var", [TOK_CONST] = "const", [TOK_IF] = "if", [TOK_GUARD] = "guard", [TOK_PUB] = "pub", [TOK_MOD] = "mod", [TOK_USE] = "use", - [TOK_TRAIT] = "trait", [TOK_SELF] = "Self", + [TOK_TRAIT] = "trait", [TOK_SELF] = "Self", [TOK_OUT] = "out", [TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return", [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_EXTEND] = "extend", [TOK_DEFER] = "defer", [TOK_MATCH] = "match", diff --git a/src/lexer/token.h b/src/lexer/token.h index 56a346b..6785900 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -8,7 +8,7 @@ typedef enum { // 关键字 TOK_FN, TOK_LET, TOK_VAR, TOK_CONST, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, TOK_GUARD, TOK_STRUCT, TOK_TYPE, TOK_ENUM, TOK_EXTEND, TOK_DEFER, TOK_MATCH, TOK_PUB, TOK_MOD, TOK_USE, - TOK_TRAIT, TOK_SELF, + TOK_TRAIT, TOK_SELF, TOK_OUT, // 类型关键字 TOK_I32, TOK_I64, TOK_U64, TOK_F64, TOK_BOOL, TOK_CHAR, TOK_STR, TOK_VOID, // 字面量 diff --git a/src/parser/parser.c b/src/parser/parser.c index c6c8128..17b89be 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -27,7 +27,7 @@ AstNode* parse_struct_decl(Parser* p, ErrorInfo* error) { } fields[fcount++] = ast_make_parameter(p->arena, arena_strdup_impl(p->arena, fname->start, fname->length), - fti.kind, fti.struct_name, tok_loc(fname)); + fti.kind, fti.struct_name, false, tok_loc(fname)); if (peek(p)->kind == TOK_COMMA) advance(p); else break; } @@ -348,6 +348,7 @@ AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) { AstNode* params[64]; int pcount = 0; while (peek(p)->kind != TOK_RPAREN && !error->message) { if (pcount >= 64) { error->message = "函数参数过多 (最多64)"; error->filename = p->filename; error->line = peek(p)->line; error->col = peek(p)->col; return NULL; } + bool is_out = match(p, TOK_OUT); const Token* pname = expect(p, TOK_IDENT, error, "参数名"); if (!pname) return NULL; if (!expect(p, TOK_COLON, error, "缺少 ':'")) return NULL; @@ -355,7 +356,7 @@ AstNode* parse_function(Parser* p, bool is_pub, ErrorInfo* error) { if (pti.kind == TYPE_ERROR) return NULL; params[pcount++] = ast_make_parameter(p->arena, arena_strdup_impl(p->arena, pname->start, pname->length), - pti.kind, pti.struct_name, tok_loc(pname)); + pti.kind, pti.struct_name, is_out, tok_loc(pname)); if (match(p, TOK_COMMA)) continue; else break; } @@ -699,7 +700,7 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, arena_strdup_impl(p.arena, fbuf, strlen(fbuf)), fn->as.function.multi_ret_types[i], fn->as.function.multi_ret_snames[i], - tok_loc(peek(&p))); + false, tok_loc(peek(&p))); } if (struct_count >= 64) { error->message = "结构体过多"; return NULL; } structs[struct_count++] = ast_make_struct_decl(p.arena, diff --git a/src/sema/sema.c b/src/sema/sema.c index f117ff4..0f19bfe 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -159,7 +159,8 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { Scope* fn_scope = scope_new(a, scope); for (size_t j = 0; j < mono_fn->as.function.param_count; j++) { AstNode* p = mono_fn->as.function.params[j]; - scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); + Symbol* ps = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, p->as.parameter.type); + if (ps && p->as.parameter.is_out) { ps->kind = SYM_VARIABLE; ps->is_mut = true; } } // 分析函数体 current_return_type = mono_fn->as.function.return_type; @@ -197,6 +198,7 @@ void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a) { } } Symbol* sym = scope_insert(fn_scope, a, p->as.parameter.name, SYM_PARAMETER, pt); + if (sym && p->as.parameter.is_out) { sym->kind = SYM_VARIABLE; sym->is_mut = true; } if (sym && pt == TYPE_STRUCT && psn) { sym->struct_type_name = psn; } diff --git a/test/programs/42_out_param.l b/test/programs/42_out_param.l new file mode 100644 index 0000000..271e608 --- /dev/null +++ b/test/programs/42_out_param.l @@ -0,0 +1,22 @@ +// out 参数测试 — 引用传递 +fn swap(out x: i64, out y: i64) -> void { + let t = x; + x = y; + y = t; +} + +fn increment(out x: i64) -> void { + x = x + 1; +} + +fn main() -> void { + let a = 10; + let b = 20; + print_i64(a); + print_i64(b); + swap(a, b); + print_i64(a); + print_i64(b); + increment(a); + print_i64(a); +} diff --git a/test/programs/43_out_param_struct.l b/test/programs/43_out_param_struct.l new file mode 100644 index 0000000..3b8f00b --- /dev/null +++ b/test/programs/43_out_param_struct.l @@ -0,0 +1,22 @@ +// out 参数 + 结构体测试 +struct Point { x: i64, y: i64 } + +fn init_point(out p: Point) -> void { + p = Point { x: 100, y: 200 }; +} + +fn offset_point(out p: Point) -> void { + p = Point { x: p.x + 50, y: p.y + 100 }; +} + +fn main() -> void { + let p = Point { x: 0, y: 0 }; + print_i64(p.x); + print_i64(p.y); + init_point(p); + print_i64(p.x); + print_i64(p.y); + offset_point(p); + print_i64(p.x); + print_i64(p.y); +} diff --git a/test/test_codegen.c b/test/test_codegen.c index 74e356c..3dfb557 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -129,8 +129,8 @@ void test_codegen_struct_decl() { /* 构造 AST: struct Point { x: i64, y: i64 } */ AstNode* fields[2]; - fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); - fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1)); AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); AstNode* structs[] = { struct_decl }; @@ -185,8 +185,8 @@ void test_codegen_struct_field_access() { /* 构造 AST: struct Point { x: i64, y: i64 } */ AstNode* fields[2]; - fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); - fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1)); AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); AstNode* structs[] = { struct_decl }; @@ -356,13 +356,13 @@ void test_codegen_method_call() { /* struct Point { x: i64, y: i64 } */ AstNode* fields[2]; - fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, loc_at(1, 1)); - fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, loc_at(1, 1)); + fields[0] = ast_make_parameter(&a, "x", TYPE_I64, NULL, false, loc_at(1, 1)); + fields[1] = ast_make_parameter(&a, "y", TYPE_I64, NULL, false, loc_at(1, 1)); AstNode* struct_decl = ast_make_struct_decl(&a, "Point", fields, 2, loc_at(1, 1)); AstNode* structs[] = { struct_decl }; /* fn Point$get_x(self: Point) -> i64 { return self.x; } */ - AstNode* self_param = ast_make_parameter(&a, "self", TYPE_STRUCT, "Point", loc_at(1, 1)); + AstNode* self_param = ast_make_parameter(&a, "self", TYPE_STRUCT, "Point", false, loc_at(1, 1)); AstNode* params[] = { self_param }; AstNode* self_ident = ast_make_ident(&a, "self", loc_at(1, 1)); self_ident->type.kind = TYPE_STRUCT; diff --git a/test/test_sema.c b/test/test_sema.c index d2f9ae4..0d0c2ed 100644 --- a/test/test_sema.c +++ b/test/test_sema.c @@ -418,6 +418,40 @@ void test_match_wildcard_only_sema_ok() { arena_destroy(&a); } +void test_out_param_assign_ok() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "fn swap(out x: i64, out y: i64) -> void { let t = x; x = y; y = t; return; }" + "fn main() -> void { let a = 10; let b = 20; swap(a, b); return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + ErrorList errors; error_init(&errors, &a); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count == 0); // out 参数赋值不应报错 + arena_destroy(&a); +} + +void test_in_param_assign_error() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "fn bad(x: i64) -> void { x = 42; return; }" + "fn main() -> void { bad(10); return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + ErrorList errors; error_init(&errors, &a); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // 非 out 参数赋值应报错 + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_type_error); TEST_RUN(test_undefined_var); @@ -443,5 +477,7 @@ int main(void) { TEST_RUN(test_match_enum_sema_ok); TEST_RUN(test_match_int_sema_ok); TEST_RUN(test_match_wildcard_only_sema_ok); + TEST_RUN(test_out_param_assign_ok); + TEST_RUN(test_in_param_assign_error); return test_summary(); }