diff --git a/include/l_lang.h b/include/l_lang.h index 728eb67..04494c6 100644 --- a/include/l_lang.h +++ b/include/l_lang.h @@ -13,6 +13,7 @@ typedef enum { TYPE_STR, TYPE_VOID, TYPE_STRUCT, // 结构体类型 + TYPE_ENUM, // 枚举类型 TYPE_UNKNOWN, // 尚未推断 TYPE_ERROR, // 类型错误 } TypeKind; @@ -25,6 +26,7 @@ static inline const char* type_name(TypeKind kind) { case TYPE_STR: return "str"; case TYPE_VOID: return "void"; case TYPE_STRUCT: return "struct"; + case TYPE_ENUM: return "enum"; default: return ""; } } diff --git a/src/ast/ast.c b/src/ast/ast.c index c7eae06..c189dd1 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -10,7 +10,8 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode** structs, size_t struct_count, - AstNode** aliases, size_t alias_count, SourceLoc loc) { + AstNode** aliases, size_t alias_count, + AstNode** enums, size_t enum_count, SourceLoc loc) { NEW(alloc, AST_PROGRAM); n->as.program.functions = fns; n->as.program.fn_count = fn_count; @@ -18,6 +19,8 @@ AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, n->as.program.struct_count = struct_count; n->as.program.type_aliases = aliases; n->as.program.alias_count = alias_count; + n->as.program.enums = enums; + n->as.program.enum_count = enum_count; return n; } @@ -177,3 +180,23 @@ AstNode* ast_make_type_alias(void* alloc, const char* name, TypeKind aliased, n->as.type_alias.aliased_struct_name = aliased_struct; return n; } + +// === 枚举相关工厂函数 === + +AstNode* ast_make_enum_decl(void* alloc, const char* name, const char** variants, + size_t count, SourceLoc loc) { + NEW(alloc, AST_ENUM_DECL); + n->as.enum_decl.name = name; + n->as.enum_decl.variants = variants; + n->as.enum_decl.variant_count = count; + return n; +} + +AstNode* ast_make_enum_variant(void* alloc, const char* enum_name, + const char* variant_name, SourceLoc loc) { + NEW(alloc, AST_ENUM_VARIANT); + n->as.enum_variant.enum_name = enum_name; + n->as.enum_variant.variant_name = variant_name; + n->as.enum_variant.variant_index = -1; + return n; +} diff --git a/src/ast/ast.h b/src/ast/ast.h index db4ec80..40fada6 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -24,6 +24,8 @@ typedef enum { AST_STRUCT_INIT, // Point { x: 10, y: 20 } AST_FIELD_ACCESS, // p.x AST_TYPE_ALIAS, // type Meters = i64 + AST_ENUM_DECL, // enum Color { Red, Green, Blue } + AST_ENUM_VARIANT, // Color::Red } AstKind; typedef enum { @@ -50,7 +52,8 @@ struct AstNode { // AST_PROGRAM struct { struct AstNode** functions; size_t fn_count; struct AstNode** structs; size_t struct_count; - struct AstNode** type_aliases; size_t alias_count; } program; + struct AstNode** type_aliases; size_t alias_count; + struct AstNode** enums; size_t enum_count; } program; // AST_FUNCTION struct { const char* name; struct AstNode** params; size_t param_count; TypeKind return_type; const char* return_struct_type_name; @@ -91,13 +94,18 @@ struct AstNode { struct { struct AstNode* object; const char* field; int field_index; } field_access; // AST_TYPE_ALIAS struct { const char* name; TypeKind aliased_type; const char* aliased_struct_name; } type_alias; + // AST_ENUM_DECL + struct { const char* name; const char** variants; size_t variant_count; } enum_decl; + // AST_ENUM_VARIANT + struct { const char* enum_name; const char* variant_name; int variant_index; } enum_variant; } as; }; // 创建节点的辅助函数(内存来自 arena,通过 void* 传递避免循环依赖) AstNode* ast_make_program(void* alloc, AstNode** fns, size_t fn_count, AstNode** structs, size_t struct_count, - AstNode** aliases, size_t alias_count, SourceLoc loc); + AstNode** aliases, size_t alias_count, + AstNode** enums, size_t enum_count, SourceLoc loc); AstNode* ast_make_function(void* alloc, const char* name, AstNode** params, size_t pcount, TypeKind ret, const char* ret_struct_name, AstNode* body, SourceLoc loc); AstNode* ast_make_parameter(void* alloc, const char* name, TypeKind type, const char* struct_type_name, SourceLoc loc); @@ -122,5 +130,7 @@ AstNode* ast_make_struct_init(void* alloc, const char* type_name, const char** f AstNode* ast_make_field_access(void* alloc, AstNode* object, const char* field, SourceLoc loc); AstNode* ast_make_type_alias(void* alloc, const char* name, TypeKind aliased, const char* aliased_struct, SourceLoc loc); +AstNode* ast_make_enum_decl(void* alloc, const char* name, const char** variants, size_t count, SourceLoc loc); +AstNode* ast_make_enum_variant(void* alloc, const char* enum_name, const char* variant_name, SourceLoc loc); #endif diff --git a/src/codegen/codegen.c b/src/codegen/codegen.c index 58ac2e0..fe9f2bb 100644 --- a/src/codegen/codegen.c +++ b/src/codegen/codegen.c @@ -63,6 +63,7 @@ static LLVMTypeRef to_llvm_type(CgCtx* ctx, TypeKind kind) { case TYPE_BOOL: return LLVMInt1TypeInContext(ctx->context); case TYPE_STR: return LLVMPointerType(LLVMInt8TypeInContext(ctx->context), 0); case TYPE_STRUCT: + case TYPE_ENUM: return LLVMInt64TypeInContext(ctx->context); case TYPE_UNKNOWN: case TYPE_ERROR: default: return LLVMVoidTypeInContext(ctx->context); @@ -347,6 +348,10 @@ static LLVMValueRef codegen_expr(CgCtx* ctx, AstNode* node) { return LLVMBuildLoad2(ctx->builder, struct_ty, alloca, "struct_val"); } + case AST_ENUM_VARIANT: + return LLVMConstInt(LLVMInt64TypeInContext(ctx->context), + (unsigned long long)node->as.enum_variant.variant_index, true); + default: return NULL; } diff --git a/src/lexer/lexer.c b/src/lexer/lexer.c index ec9571a..4d493df 100644 --- a/src/lexer/lexer.c +++ b/src/lexer/lexer.c @@ -63,6 +63,7 @@ static TokenKind check_keyword(const Token* tok) { KW("bool", TOK_BOOL); KW("str", TOK_STR); KW("void", TOK_VOID); KW("struct", TOK_STRUCT); KW("type", TOK_TYPE); + KW("enum", TOK_ENUM); KW("true", TOK_TRUE); KW("false", TOK_FALSE); #undef KW return TOK_IDENT; @@ -141,6 +142,7 @@ Token* lex(Arena* a, const char* source, const char* filename, else if (c == '{') { tokens[idx++] = make_token(&l, TOK_LBRACE, l.pos, 1); advance(&l); } else if (c == '}') { tokens[idx++] = make_token(&l, TOK_RBRACE, l.pos, 1); advance(&l); } else if (c == ',') { tokens[idx++] = make_token(&l, TOK_COMMA, l.pos, 1); advance(&l); } + else if (c == ':' && peek_next(&l) == ':') { tokens[idx++] = make_token(&l, TOK_COLON_COLON, l.pos, 2); advance(&l); advance(&l); } else if (c == ':') { tokens[idx++] = make_token(&l, TOK_COLON, l.pos, 1); advance(&l); } else if (c == ';') { tokens[idx++] = make_token(&l, TOK_SEMICOLON, l.pos, 1); advance(&l); } else { diff --git a/src/lexer/token.c b/src/lexer/token.c index 7c1ff70..aac1bb7 100644 --- a/src/lexer/token.c +++ b/src/lexer/token.c @@ -7,7 +7,7 @@ static const char* NAMES[] = { [TOK_FN] = "fn", [TOK_LET] = "let", [TOK_MUT] = "mut", [TOK_IF] = "if", [TOK_ELSE] = "else", [TOK_WHILE] = "while", [TOK_FOR] = "for", [TOK_IN] = "in", [TOK_RETURN] = "return", - [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", + [TOK_STRUCT] = "struct", [TOK_TYPE] = "type", [TOK_ENUM] = "enum", [TOK_I64] = "i64", [TOK_F64] = "f64", [TOK_BOOL] = "bool", [TOK_STR] = "str", [TOK_VOID] = "void", [TOK_INT_LIT] = "整数", [TOK_FLOAT_LIT] = "浮点数", [TOK_STR_LIT] = "字符串", [TOK_TRUE] = "true", [TOK_FALSE] = "false", @@ -23,7 +23,7 @@ static const char* NAMES[] = { [TOK_LBRACE] = "{", [TOK_RBRACE] = "}", [TOK_COMMA] = ",", [TOK_COLON] = ":", [TOK_SEMICOLON] = ";", [TOK_ASSIGN] = "=", - [TOK_DOT] = ".", + [TOK_DOT] = ".", [TOK_COLON_COLON] = "::", [TOK_EOF] = "EOF", [TOK_ERROR] = "错误", }; diff --git a/src/lexer/token.h b/src/lexer/token.h index cbf30ee..85dd797 100644 --- a/src/lexer/token.h +++ b/src/lexer/token.h @@ -7,7 +7,7 @@ typedef enum { // 关键字 TOK_FN, TOK_LET, TOK_MUT, TOK_IF, TOK_ELSE, TOK_WHILE, TOK_FOR, TOK_IN, TOK_RETURN, - TOK_STRUCT, TOK_TYPE, + TOK_STRUCT, TOK_TYPE, TOK_ENUM, // 类型关键字 TOK_I64, TOK_F64, TOK_BOOL, TOK_STR, TOK_VOID, // 字面量 @@ -24,7 +24,7 @@ typedef enum { TOK_LPAREN, TOK_RPAREN, TOK_LBRACE, TOK_RBRACE, TOK_COMMA, TOK_COLON, TOK_SEMICOLON, TOK_ASSIGN, // 特殊 - TOK_DOT, + TOK_DOT, TOK_COLON_COLON, TOK_EOF, TOK_ERROR, } TokenKind; diff --git a/src/parser/parser.c b/src/parser/parser.c index e963d4e..7e9e6c1 100644 --- a/src/parser/parser.c +++ b/src/parser/parser.c @@ -144,6 +144,17 @@ static AstNode* parse_struct_init(Parser* p, const Token* name, ErrorInfo* error static AstNode* parse_ident_or_call(Parser* p, ErrorInfo* error) { const Token* name = advance(p); + // 枚举变体引用: Name::Variant + if (peek(p)->kind == TOK_COLON_COLON) { + advance(p); // 跳过 :: + const Token* variant = expect(p, TOK_IDENT, error, "枚举变体名"); + if (!variant) return NULL; + return ast_make_enum_variant(p->arena, + arena_strdup_impl(p->arena, name->start, name->length), + arena_strdup_impl(p->arena, variant->start, variant->length), + tok_loc(name)); + } + // 结构体初始化: Name { field: val, ... } // 用提前看来区别 struct init 和 block: // struct init → { IDENT COLON ... ;block → { 可能是 let/if/while/... @@ -553,6 +564,7 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, AstNode* functions[256]; int fn_count = 0; AstNode* structs[64]; int struct_count = 0; AstNode* aliases[64]; int alias_count = 0; + AstNode* enums[64]; int enum_count = 0; while (peek(&p)->kind != TOK_EOF && !error->message) { if (peek(&p)->kind == TOK_STRUCT) { if (struct_count >= 64) { error->message = "结构体过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } @@ -569,11 +581,30 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, aliases[alias_count++] = ast_make_type_alias(a, arena_strdup_impl(a, alias_name->start, alias_name->length), rti.kind, rti.struct_name, tok_loc(type_tok)); + } else if (peek(&p)->kind == TOK_ENUM) { + advance(&p); + const Token* name = expect(&p, TOK_IDENT, error, "enum 后应为枚举名"); + if (!name) return NULL; + if (!expect(&p, TOK_LBRACE, error, "缺少 '{'")) return NULL; + const char* variants[64]; int vcount = 0; + while (peek(&p)->kind != TOK_RBRACE && !error->message) { + if (vcount >= 64) { error->message = "枚举变体过多(最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + const Token* vname = expect(&p, TOK_IDENT, error, "变体名"); + if (!vname) return NULL; + variants[vcount++] = arena_strdup_impl(p.arena, vname->start, vname->length); + if (peek(&p)->kind == TOK_COMMA) advance(&p); else break; + } + if (!expect(&p, TOK_RBRACE, error, "缺少 '}'")) return NULL; + const char** v_arr = arena_alloc_impl(p.arena, vcount * sizeof(const char*)); + memcpy(v_arr, variants, vcount * sizeof(const char*)); + AstNode* enum_decl = ast_make_enum_decl(p.arena, arena_strdup_impl(p.arena, name->start, name->length), v_arr, vcount, tok_loc(name)); + if (enum_count >= 64) { error->message = "枚举过多 (最多64)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } + enums[enum_count++] = enum_decl; } else if (peek(&p)->kind == TOK_FN) { if (fn_count >= 256) { error->message = "函数过多 (最多256)"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; return NULL; } functions[fn_count++] = parse_function(&p, error); } else { - error->message = "顶层只允许 fn、struct 或 type"; + error->message = "顶层只允许 fn、struct、type 或 enum"; error->filename = p.filename; error->line = peek(&p)->line; error->col = peek(&p)->col; @@ -587,6 +618,8 @@ AstNode* parse(Arena* a, const Token* tokens, size_t count, memcpy(st_arr, structs, struct_count * sizeof(AstNode*)); AstNode** al_arr = arena_alloc_impl(a, alias_count * sizeof(AstNode*)); memcpy(al_arr, aliases, alias_count * sizeof(AstNode*)); + AstNode** en_arr = arena_alloc_impl(a, enum_count * sizeof(AstNode*)); + memcpy(en_arr, enums, enum_count * sizeof(AstNode*)); return ast_make_program(a, fn_arr, fn_count, st_arr, struct_count, - al_arr, alias_count, loc_at(0, 0)); + al_arr, alias_count, en_arr, enum_count, loc_at(0, 0)); } diff --git a/src/sema/sema.c b/src/sema/sema.c index f459c1f..3f8d631 100644 --- a/src/sema/sema.c +++ b/src/sema/sema.c @@ -6,14 +6,22 @@ static TypeKind current_return_type = TYPE_VOID; static const char* current_return_struct_name = NULL; static TypeKind promote(TypeKind a, TypeKind b) { + // 枚举在算术运算中视为 i64 + if (a == TYPE_ENUM) a = TYPE_I64; + if (b == TYPE_ENUM) b = TYPE_I64; if (a == TYPE_F64 || b == TYPE_F64) return TYPE_F64; if (a == TYPE_I64 || b == TYPE_I64) return TYPE_I64; if (a == TYPE_BOOL || b == TYPE_BOOL) return TYPE_BOOL; return TYPE_ERROR; } -static bool is_numeric(TypeKind t) { return t == TYPE_I64 || t == TYPE_F64; } -static bool is_comparable(TypeKind a, TypeKind b) { return a == b; } +static bool is_numeric(TypeKind t) { return t == TYPE_I64 || t == TYPE_F64 || t == TYPE_ENUM; } +static bool is_comparable(TypeKind a, TypeKind b) { + if (a == b) return true; + // 枚举可以参与整数比较 + if ((a == TYPE_I64 && b == TYPE_ENUM) || (a == TYPE_ENUM && b == TYPE_I64)) return true; + return false; +} // === 向前声明 === static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* a); @@ -169,7 +177,8 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* expected_name ? expected_name : "struct", actual_name ? actual_name : type_name(actual)); } - } else if (actual != expected) { + } else if (actual != expected && + !(expected == TYPE_I64 && actual == TYPE_ENUM)) { error_add(errors, "", node->loc.line, node->loc.col, "参数 %zu 类型不匹配: 期望 '%s',得到 '%s'", i + 1, type_name(expected), type_name(actual)); @@ -287,6 +296,26 @@ static void analyze_expr(AstNode* node, Scope* scope, ErrorList* errors, Arena* break; } + case AST_ENUM_VARIANT: { + Symbol* enum_sym = scope_lookup_struct(scope, node->as.enum_variant.enum_name); + if (!enum_sym || enum_sym->kind != SYM_ENUM) { + error_add(errors, "", node->loc.line, node->loc.col, + "未定义的枚举 '%s'", node->as.enum_variant.enum_name); + node->type.kind = TYPE_ERROR; break; + } + int vi = scope_enum_variant_index(enum_sym, node->as.enum_variant.variant_name); + if (vi < 0) { + error_add(errors, "", node->loc.line, node->loc.col, + "枚举 '%s' 没有变体 '%s'", + node->as.enum_variant.enum_name, + node->as.enum_variant.variant_name); + node->type.kind = TYPE_ERROR; break; + } + node->as.enum_variant.variant_index = vi; + node->type.kind = TYPE_ENUM; + break; + } + default: break; } } @@ -308,6 +337,13 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* } } } + // 注册枚举 + for (size_t i = 0; i < node->as.program.enum_count; i++) { + AstNode* ed = node->as.program.enums[i]; + scope_insert_enum(scope, a, ed->as.enum_decl.name, + ed->as.enum_decl.variants, + ed->as.enum_decl.variant_count); + } // 第一遍:收集所有结构体定义 for (size_t i = 0; i < node->as.program.struct_count; i++) { AstNode* sd = node->as.program.structs[i]; @@ -548,7 +584,8 @@ static void analyze_node(AstNode* node, Scope* scope, ErrorList* errors, Arena* expected_name ? expected_name : "struct", actual_name ? actual_name : type_name(actual)); } - } else if (actual != expected) { + } else if (actual != expected && + !(expected == TYPE_I64 && actual == TYPE_ENUM)) { error_add(errors, "", node->loc.line, node->loc.col, "返回类型不匹配: 期望 '%s',得到 '%s'", type_name(expected), type_name(actual)); diff --git a/src/sema/symbol.c b/src/sema/symbol.c index 0caebf5..90469c3 100644 --- a/src/sema/symbol.c +++ b/src/sema/symbol.c @@ -94,7 +94,7 @@ Symbol* scope_insert_struct(Scope* scope, void* alloc, const char* name, Symbol* scope_lookup_struct(const Scope* scope, const char* name) { for (const Scope* s = scope; s; s = s->parent) { for (Symbol* sym = s->head; sym; sym = sym->next) { - if (sym->kind == SYM_STRUCT && strcmp(sym->name, name) == 0) + if ((sym->kind == SYM_STRUCT || sym->kind == SYM_ENUM) && strcmp(sym->name, name) == 0) return sym; } } @@ -109,3 +109,35 @@ int scope_struct_field_index(const Symbol* sym, const char* field_name) { } return -1; } + +Symbol* scope_insert_enum(Scope* scope, void* alloc, const char* name, + const char** vnames, size_t vc) { + if (scope->head) { + for (Symbol* sym = scope->head; sym; sym = sym->next) { + if (strcmp(sym->name, name) == 0) return NULL; + } + } + Symbol* sym = (Symbol*)arena_alloc_impl(alloc, sizeof(Symbol)); + if (!sym) return NULL; + sym->name = name; sym->kind = SYM_ENUM; sym->type = TYPE_ENUM; + sym->is_mut = false; sym->return_type = TYPE_VOID; + sym->param_types = NULL; sym->param_count = 0; + sym->struct_field_names = vnames; + sym->struct_field_types = NULL; + sym->struct_field_struct_names = NULL; + sym->struct_field_count = vc; + sym->struct_type_name = NULL; + sym->is_type_alias = false; + sym->next = scope->head; + scope->head = sym; + return sym; +} + +int scope_enum_variant_index(const Symbol* sym, const char* variant_name) { + if (sym->kind != SYM_ENUM) return -1; + for (size_t i = 0; i < sym->struct_field_count; i++) { + if (strcmp(sym->struct_field_names[i], variant_name) == 0) + return (int)i; + } + return -1; +} diff --git a/src/sema/symbol.h b/src/sema/symbol.h index 85b38ff..11ae22c 100644 --- a/src/sema/symbol.h +++ b/src/sema/symbol.h @@ -4,7 +4,7 @@ #include "l_lang.h" #include "ast.h" -typedef enum { SYM_VARIABLE, SYM_PARAMETER, SYM_FUNCTION, SYM_STRUCT } SymbolKind; +typedef enum { SYM_VARIABLE, SYM_PARAMETER, SYM_FUNCTION, SYM_STRUCT, SYM_ENUM } SymbolKind; typedef struct Symbol { const char* name; @@ -61,4 +61,11 @@ Symbol* scope_lookup_struct(const Scope* scope, const char* name); // 在结构体符号中查找字段索引(返回 -1 表示未找到) int scope_struct_field_index(const Symbol* sym, const char* field_name); +// 插入枚举符号 +Symbol* scope_insert_enum(Scope* scope, void* alloc, const char* name, + const char** vnames, size_t vc); + +// 在枚举符号中查找变体索引(返回 -1 表示未找到) +int scope_enum_variant_index(const Symbol* sym, const char* variant_name); + #endif diff --git a/test/programs/17_enum.l b/test/programs/17_enum.l new file mode 100644 index 0000000..ab99bde --- /dev/null +++ b/test/programs/17_enum.l @@ -0,0 +1,6 @@ +enum Color { Red, Green, Blue } +fn main() -> i64 { + let c = Color::Green; + print_i64(c); + return 0; +} diff --git a/test/test_codegen.c b/test/test_codegen.c index 41fc6bd..354af46 100644 --- a/test/test_codegen.c +++ b/test/test_codegen.c @@ -13,7 +13,7 @@ void test_codegen_simple_function() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -47,7 +47,7 @@ void test_codegen_if_else() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx2 = NULL; @@ -78,7 +78,7 @@ void test_codegen_binary_ops() { AstNode* body = ast_make_block(&a, stmts, 1, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx3 = NULL; @@ -108,7 +108,7 @@ void test_codegen_while_loop() { AstNode* fn_body = ast_make_block(&a, stmts, 2, loc_at(1, 1)); AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, fn_body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx4 = NULL; @@ -162,7 +162,7 @@ void test_codegen_struct_decl() { AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -217,7 +217,7 @@ void test_codegen_struct_field_access() { AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); AstNode* fns[] = { fn }; - AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, loc_at(1, 1)); + AstNode* prog = ast_make_program(&a, fns, 1, structs, 1, NULL, 0, NULL, 0, loc_at(1, 1)); const char* err = NULL; LLVMContextRef ctx = NULL; @@ -234,6 +234,56 @@ void test_codegen_struct_field_access() { arena_destroy(&a); } +/* === enum 代码生成测试 === */ + +void test_codegen_enum() { + Arena a = arena_create(1); + + /* 构造 AST: enum Color { Red, Green, Blue } + fn main() -> i64 { let c = Color::Green; print_i64(c); return 0; } */ + const char* cvariants[] = {"Red", "Green", "Blue"}; + AstNode* enum_decl = ast_make_enum_decl(&a, "Color", cvariants, 3, loc_at(1, 1)); + AstNode* enums[] = { enum_decl }; + + /* let c = Color::Green; */ + AstNode* cv = ast_make_enum_variant(&a, "Color", "Green", loc_at(1, 1)); + cv->as.enum_variant.variant_index = 1; /* Green = index 1 */ + cv->type.kind = TYPE_ENUM; + + AstNode* let_stmt = ast_make_let(&a, "c", TYPE_UNKNOWN, false, false, + cv, NULL, loc_at(1, 1)); + + /* print_i64(c); */ + AstNode* c_ident = ast_make_ident(&a, "c", loc_at(1, 1)); + c_ident->type.kind = TYPE_ENUM; + AstNode* args[] = { c_ident }; + AstNode* print_call = ast_make_call(&a, "print_i64", args, 1, loc_at(1, 1)); + + /* return 0; */ + AstNode* ret = ast_make_return(&a, ast_make_literal_i64(&a, 0, loc_at(1, 1)), loc_at(1, 1)); + + AstNode* stmts[] = { let_stmt, print_call, ret }; + AstNode* body = ast_make_block(&a, stmts, 3, loc_at(1, 1)); + AstNode* fn = ast_make_function(&a, "main", NULL, 0, TYPE_I64, NULL, body, loc_at(1, 1)); + AstNode* fns[] = { fn }; + + AstNode* prog = ast_make_program(&a, fns, 1, NULL, 0, NULL, 0, enums, 1, loc_at(1, 1)); + + const char* err = NULL; + LLVMContextRef ctx = NULL; + LLVMModuleRef mod = codegen_module(prog, &a, "test_enum", &err, &ctx); + ASSERT(mod != NULL); + ASSERT(err == NULL); + + char* verify_err = NULL; + int failed = LLVMVerifyModule(mod, LLVMReturnStatusAction, &verify_err); + ASSERT(!failed); + + LLVMDisposeModule(mod); + LLVMContextDispose(ctx); + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_codegen_simple_function); TEST_RUN(test_codegen_if_else); @@ -241,5 +291,6 @@ int main(void) { TEST_RUN(test_codegen_while_loop); TEST_RUN(test_codegen_struct_decl); TEST_RUN(test_codegen_struct_field_access); + TEST_RUN(test_codegen_enum); return test_summary(); } diff --git a/test/test_sema.c b/test/test_sema.c index e3df622..6b90a52 100644 --- a/test/test_sema.c +++ b/test/test_sema.c @@ -223,6 +223,42 @@ void test_struct_nested_type_ok() { arena_destroy(&a); } +/* === enum 语义分析测试 === */ + +void test_enum_ok() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "enum Color { Red, Green, Blue } fn main() { let c = Color::Green; print_i64(c); return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count == 0); + arena_destroy(&a); +} + +void test_enum_bad_variant() { + Arena a = arena_create(1); + size_t tc; ErrorInfo lex_err = {0}; + Token* toks = lex(&a, + "enum Color { Red, Green, Blue } fn main() { let c = Color::Yellow; return; }", + "test", &tc, &lex_err); + ASSERT(toks != NULL); + ErrorInfo parse_err = {0}; + AstNode* ast = parse(&a, toks, tc, "test", &parse_err); + ASSERT(ast != NULL); + + ErrorList errors; error_init(&errors); + sema_analyze(ast, &errors, &a); + ASSERT(errors.count > 0); // Yellow 不存在 + arena_destroy(&a); +} + int main(void) { TEST_RUN(test_type_error); TEST_RUN(test_undefined_var); @@ -237,5 +273,7 @@ int main(void) { TEST_RUN(test_struct_nested_type_ok); TEST_RUN(test_type_alias_ok); TEST_RUN(test_type_alias_struct); + TEST_RUN(test_enum_ok); + TEST_RUN(test_enum_bad_variant); return test_summary(); }