package generator import ( "database/sql" "strings" "testing" "github.com/entity-maker/entity-maker/internal/database" ) func TestGetRelationshipName(t *testing.T) { tests := []struct { name string input string expected string }{ {"with _id suffix", "author_info_id", "author_info"}, {"with user_id", "user_id", "user"}, {"without _id", "status", "status"}, {"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is {"empty", "", ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetRelationshipName(tt.input) if result != tt.expected { t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected) } }) } } func TestGetRelationshipEntityName(t *testing.T) { tests := []struct { name string input string expected string }{ {"plural table", "users", "User"}, {"compound plural", "user_accounts", "UserAccount"}, {"ies ending", "companies", "Company"}, {"already singular", "user", "User"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetRelationshipEntityName(tt.input) if result != tt.expected { t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected) } }) } } func TestGetRelationshipModuleName(t *testing.T) { tests := []struct { name string input string expected string }{ {"plural table", "users", "user"}, {"compound plural", "user_accounts", "user_account"}, {"ies ending", "companies", "company"}, {"already singular", "user", "user"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetRelationshipModuleName(tt.input) if result != tt.expected { t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected) } }) } } func TestGetFilterFieldName(t *testing.T) { tests := []struct { name string columnName string useIN bool expected string }{ {"id with IN", "id", true, "ids"}, {"id without IN", "id", false, "id"}, {"user_id with IN", "user_id", true, "user_ids"}, {"user_id without IN", "user_id", false, "user_id"}, {"machine_id with IN", "machine_id", true, "machine_ids"}, {"status without IN", "status", false, "status"}, {"status with IN", "status", true, "status"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetFilterFieldName(tt.columnName, tt.useIN) if result != tt.expected { t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q", tt.columnName, tt.useIN, result, tt.expected) } }) } } func TestShouldGenerateFilter(t *testing.T) { tests := []struct { name string col database.Column expected bool }{ { name: "boolean field", col: database.Column{ Name: "alive", DataType: "boolean", }, expected: true, }, { name: "id field", col: database.Column{ Name: "id", DataType: "integer", }, expected: true, }, { name: "foreign key field", col: database.Column{ Name: "user_id", DataType: "integer", }, expected: true, }, { name: "text field", col: database.Column{ Name: "name", DataType: "text", }, expected: true, }, { name: "varchar field", col: database.Column{ Name: "email", DataType: "character varying", }, expected: true, }, { name: "numeric field", col: database.Column{ Name: "amount", DataType: "numeric", }, expected: false, }, { name: "timestamp field", col: database.Column{ Name: "created_at", DataType: "timestamp with time zone", }, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := ShouldGenerateFilter(tt.col) if result != tt.expected { t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected) } }) } } func TestNeedsDecimalImport(t *testing.T) { tests := []struct { name string columns []database.Column expected bool }{ { name: "has numeric column", columns: []database.Column{ {Name: "amount", DataType: "numeric"}, }, expected: true, }, { name: "has decimal column", columns: []database.Column{ {Name: "price", DataType: "decimal"}, }, expected: true, }, { name: "no numeric columns", columns: []database.Column{ {Name: "id", DataType: "integer"}, {Name: "name", DataType: "varchar"}, }, expected: false, }, { name: "empty columns", columns: []database.Column{}, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := NeedsDecimalImport(tt.columns) if result != tt.expected { t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected) } }) } } func TestNeedsDatetimeImport(t *testing.T) { tests := []struct { name string columns []database.Column expected bool }{ { name: "has timestamp column", columns: []database.Column{ {Name: "created_at", DataType: "timestamp with time zone"}, }, expected: true, }, { name: "has date column", columns: []database.Column{ {Name: "birth_date", DataType: "date"}, }, expected: true, }, { name: "no datetime columns", columns: []database.Column{ {Name: "id", DataType: "integer"}, {Name: "name", DataType: "varchar"}, }, expected: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := NeedsDatetimeImport(tt.columns) if result != tt.expected { t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected) } }) } } func TestGetRequiredColumns(t *testing.T) { columns := []database.Column{ {Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true}, {Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}}, {Name: "email", DataType: "varchar", IsNullable: true}, {Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, {Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true}, } result := GetRequiredColumns(columns) // Should only include 'name' (not nullable, no default, not PK, not auto-increment) if len(result) != 1 { t.Errorf("Expected 1 required column, got %d", len(result)) } if len(result) > 0 && result[0].Name != "name" { t.Errorf("Expected required column to be 'name', got %q", result[0].Name) } } func TestGetOptionalColumns(t *testing.T) { columns := []database.Column{ {Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true}, {Name: "name", DataType: "varchar", IsNullable: false}, {Name: "email", DataType: "varchar", IsNullable: true}, {Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, } result := GetOptionalColumns(columns) // Should include 'email' (nullable) and 'created_at' (has default) if len(result) != 2 { t.Errorf("Expected 2 optional columns, got %d", len(result)) } } func TestGetForeignKeyForColumn(t *testing.T) { fks := []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users"}, {ColumnName: "company_id", ForeignTableName: "companies"}, } tests := []struct { name string columnName string expectNil bool expectFK string }{ {"found user_id", "user_id", false, "users"}, {"found company_id", "company_id", false, "companies"}, {"not found", "status_id", true, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetForeignKeyForColumn(tt.columnName, fks) if tt.expectNil { if result != nil { t.Errorf("Expected nil, got %+v", result) } } else { if result == nil { t.Errorf("Expected FK, got nil") } else if result.ForeignTableName != tt.expectFK { t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName) } } }) } } func TestGenerateInit(t *testing.T) { ctx := &Context{ EntityName: "User", ModuleName: "user", } result, err := GenerateInit(ctx) if err != nil { t.Fatalf("GenerateInit failed: %v", err) } // __init__.py should be empty if result != "" { t.Errorf("Expected empty __init__.py, got %q", result) } } func TestGenerateRepository(t *testing.T) { tableInfo := &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{}, } ctx := NewContext(tableInfo, "") result, err := GenerateRepository(ctx) if err != nil { t.Fatalf("GenerateRepository failed: %v", err) } // Check for expected content expectedStrings := []string{ "class UserRepository", "CRUDRepository", "model_cls = User", "from televend_core.databases.base_repository import CRUDRepository", "from televend_core.databases.cloud_repositories.user.filter import", "from televend_core.databases.cloud_repositories.user.model import User", } for _, expected := range expectedStrings { if !strings.Contains(result, expected) { t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } } func TestGenerateManager(t *testing.T) { tableInfo := &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{}, } ctx := NewContext(tableInfo, "") result, err := GenerateManager(ctx) if err != nil { t.Fatalf("GenerateManager failed: %v", err) } // Check for expected content expectedStrings := []string{ "class UserManager", "CRUDManager", "repository_cls = UserRepository", } for _, expected := range expectedStrings { if !strings.Contains(result, expected) { t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } } func TestNewContext(t *testing.T) { tableInfo := &database.TableInfo{ Schema: "public", TableName: "user_accounts", } // Test without override ctx := NewContext(tableInfo, "") if ctx.EntityName != "UserAccount" { t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName) } if ctx.ModuleName != "user_account" { t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName) } if ctx.TableConstant != "USER_ACCOUNT_TABLE" { t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant) } // Test with override ctx = NewContext(tableInfo, "CustomUser") if ctx.EntityName != "CustomUser" { t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName) } } func TestGenerateTable(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string }{ { name: "simple table with basic types", tableInfo: &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false}, {Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false}, {Name: "email", DataType: "varchar", IsNullable: true}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from sqlalchemy import", "Column", "Table", "Integer", "String", "USER_TABLE = Table(", `"users"`, "metadata_obj", `Column("id", Integer, primary_key=True, autoincrement=True)`, `Column("name", String(255), nullable=False)`, `Column("email", String)`, }, }, { name: "table with foreign keys", tableInfo: &database.TableInfo{ Schema: "public", TableName: "posts", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "user_id", DataType: "integer", IsNullable: false}, {Name: "title", DataType: "text", IsNullable: false}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, }, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "ForeignKey", `ForeignKey("users.id", deferrable=True, initially="DEFERRED")`, }, }, { name: "table with enum", tableInfo: &database.TableInfo{ Schema: "public", TableName: "orders", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{ "order_status": { TypeName: "order_status", Values: []string{"pending", "completed", "cancelled"}, }, }, }, expected: []string{ "Enum", "from televend_core.databases.cloud_repositories.order.enum import", "OrderStatus", `Enum(`, `*OrderStatus.to_value_list()`, `name="order_status"`, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateTable(ctx) if err != nil { t.Fatalf("GenerateTable failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } }) } } func TestGenerateModel(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string }{ { name: "simple model", tableInfo: &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "name", DataType: "varchar", IsNullable: false}, {Name: "email", DataType: "varchar", IsNullable: true}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from dataclasses import dataclass", "from televend_core.databases.base_model import Base", "@dataclass", "class User(Base):", "name: str", "email: str | None = None", "id: int | None = None", }, }, { name: "model with foreign key", tableInfo: &database.TableInfo{ Schema: "public", TableName: "posts", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "user_id", DataType: "integer", IsNullable: false}, {Name: "title", DataType: "text", IsNullable: false}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, }, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from televend_core.databases.cloud_repositories.user.model import User", "user_id: int", "user: User", "title: str", }, }, { name: "model with datetime and decimal", tableInfo: &database.TableInfo{ Schema: "public", TableName: "orders", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "amount", DataType: "numeric", IsNullable: false}, {Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from datetime import datetime", "from decimal import Decimal", "amount: Decimal", "created_at: datetime | None = None", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateModel(ctx) if err != nil { t.Fatalf("GenerateModel failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } }) } } func TestGenerateFilter(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string notExpect []string }{ { name: "filter with boolean and id fields", tableInfo: &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true}, {Name: "name", DataType: "varchar"}, {Name: "alive", DataType: "boolean"}, {Name: "user_id", DataType: "integer"}, }, ForeignKeys: []database.ForeignKey{}, }, expected: []string{ "from televend_core.databases.base_filter import BaseFilter", "from televend_core.databases.common.filters.filters import EQ, IN, filterfield", "class UserFilter(BaseFilter):", "model_cls = User", "id: int | None = filterfield(operator=EQ)", "ids: list[int] | None = filterfield(field=\"id\", operator=IN)", "name: str | None = filterfield(operator=EQ)", "alive: bool | None = filterfield(operator=EQ, default=True)", "user_id: int | None = filterfield(operator=EQ)", "user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)", }, notExpect: []string{ "default=None", }, }, { name: "filter with no filterable fields", tableInfo: &database.TableInfo{ Schema: "public", TableName: "logs", Columns: []database.Column{ {Name: "timestamp", DataType: "timestamp with time zone"}, {Name: "amount", DataType: "numeric"}, }, ForeignKeys: []database.ForeignKey{}, }, expected: []string{ "class LogFilter(BaseFilter):", "pass", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateFilter(ctx) if err != nil { t.Fatalf("GenerateFilter failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } for _, notExpected := range tt.notExpect { if strings.Contains(result, notExpected) { t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s", notExpected, result) } } }) } } func TestGenerateLoadOptions(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string }{ { name: "load options with relationships", tableInfo: &database.TableInfo{ Schema: "public", TableName: "posts", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true}, {Name: "user_id", DataType: "integer"}, {Name: "category_id", DataType: "integer"}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, {ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"}, }, }, expected: []string{ "from televend_core.databases.base_load_options import LoadOptions", "from televend_core.databases.common.load_options import joinload", "class PostLoadOptions(LoadOptions):", "model_cls = Post", `load_user: bool = joinload(relations=["user"])`, `load_category: bool = joinload(relations=["category"])`, }, }, { name: "load options with no relationships", tableInfo: &database.TableInfo{ Schema: "public", TableName: "settings", Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}}, ForeignKeys: []database.ForeignKey{}, }, expected: []string{ "class SettingLoadOptions(LoadOptions):", "pass", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateLoadOptions(ctx) if err != nil { t.Fatalf("GenerateLoadOptions failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } }) } } func TestGenerateFactory(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string }{ { name: "factory with basic fields", tableInfo: &database.TableInfo{ Schema: "public", TableName: "users", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}}, {Name: "alive", DataType: "boolean"}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from __future__ import annotations", "from typing import Type", "import factory", "class UserFactory(CloudBaseFactory):", "alive = True", "id = None", `name = factory.Faker("pystr", max_chars=100)`, "class Meta:", "model = User", "def create_minimal", }, }, { name: "factory with foreign keys", tableInfo: &database.TableInfo{ Schema: "public", TableName: "posts", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "user_id", DataType: "integer", IsNullable: false}, {Name: "title", DataType: "text"}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, }, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ "from televend_core.databases.cloud_repositories.user.factory import", "UserFactory", `user = CustomSelfAttribute("..user", UserFactory)`, "user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)", `"user": kwargs.pop("user", None) or UserFactory.create_minimal()`, }, }, { name: "factory with decimal field", tableInfo: &database.TableInfo{ Schema: "public", TableName: "orders", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, {Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}}, }, ForeignKeys: []database.ForeignKey{}, EnumTypes: map[string]database.EnumType{}, }, expected: []string{ `amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`, }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateFactory(ctx) if err != nil { t.Fatalf("GenerateFactory failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } }) } } func TestGenerateMapper(t *testing.T) { tests := []struct { name string tableInfo *database.TableInfo expected []string }{ { name: "mapper with single foreign key", tableInfo: &database.TableInfo{ Schema: "public", TableName: "posts", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true}, {Name: "user_id", DataType: "integer"}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, }, }, expected: []string{ "mapper_registry.map_imperatively(", "class_=Post,", "local_table=POST_TABLE,", "properties={", `"user": relationship(`, "User, lazy=relationship_loading_strategy.value", }, }, { name: "mapper with multiple foreign keys to same table", tableInfo: &database.TableInfo{ Schema: "public", TableName: "messages", Columns: []database.Column{ {Name: "id", DataType: "integer", IsPrimaryKey: true}, {Name: "sender_id", DataType: "integer"}, {Name: "receiver_id", DataType: "integer"}, }, ForeignKeys: []database.ForeignKey{ {ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"}, {ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"}, }, }, expected: []string{ `"sender": relationship(`, "User, lazy=relationship_loading_strategy.value", `"receiver": relationship(`, "foreign_keys=MESSAGE_TABLE.columns.receiver_id,", }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := NewContext(tt.tableInfo, "") result, err := GenerateMapper(ctx) if err != nil { t.Fatalf("GenerateMapper failed: %v", err) } for _, expected := range tt.expected { if !strings.Contains(result, expected) { t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s", expected, result) } } }) } } func TestGetPythonTypeForColumn(t *testing.T) { ctx := &Context{ TableInfo: &database.TableInfo{ EnumTypes: map[string]database.EnumType{ "status_enum": { TypeName: "status_enum", Values: []string{"active", "inactive"}, }, }, }, } tests := []struct { name string col database.Column expected string }{ { name: "integer type", col: database.Column{DataType: "integer"}, expected: "int", }, { name: "varchar type", col: database.Column{DataType: "varchar"}, expected: "str", }, { name: "enum type", col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, expected: "StatusEnum", }, { name: "unknown enum type", col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"}, expected: "str", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { result := GetPythonTypeForColumn(tt.col, ctx) if result != tt.expected { t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected) } }) } }