diff --git a/.gitignore b/.gitignore index 2f2e509..a109839 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,7 @@ *.toml /entity-maker + +# Test coverage +coverage.out +coverage.html diff --git a/Makefile b/Makefile index fffc93d..96efcf5 100644 --- a/Makefile +++ b/Makefile @@ -47,9 +47,46 @@ clean: .PHONY: test test: @echo "Running tests..." + @go test ./... + @echo "✓ Tests passed" + + +.PHONY: test-verbose +test-verbose: + @echo "Running tests (verbose)..." @go test -v ./... +.PHONY: test-coverage +test-coverage: + @echo "Running tests with coverage..." + @go test -cover ./... + @echo "✓ Tests completed" + + +.PHONY: coverage +coverage: + @echo "Generating coverage report..." + @go test -coverprofile=coverage.out ./... + @go tool cover -html=coverage.out -o coverage.html + @echo "✓ Coverage report generated: coverage.html" + @echo " Open coverage.html in your browser to view detailed coverage" + + +.PHONY: test-short +test-short: + @echo "Running short tests..." + @go test -short ./... + @echo "✓ Short tests passed" + + +.PHONY: bench +bench: + @echo "Running benchmarks..." + @go test -bench=. -benchmem ./internal/naming + @echo "✓ Benchmarks completed" + + .PHONY: verify-static verify-static: build @echo "Verifying binary is statically linked..." @@ -69,12 +106,22 @@ upgrade-packages: help: @echo "Entity Maker - Makefile targets:" @echo "" + @echo "Build targets:" @echo " build Build static binary for Linux (default)" @echo " build-dev Build development binary with debug symbols" @echo " install Install binary to /usr/local/bin" @echo " clean Remove build artifacts" - @echo " test Run tests" @echo " verify-static Verify binary has no dynamic dependencies" + @echo "" + @echo "Test targets:" + @echo " test Run all tests" + @echo " test-verbose Run tests with verbose output" + @echo " test-coverage Run tests and show coverage summary" + @echo " coverage Generate HTML coverage report" + @echo " test-short Run only short tests" + @echo " bench Run benchmarks" + @echo "" + @echo "Other targets:" @echo " upgrade-packages Update Go dependencies" @echo " help Show this help message" @echo "" diff --git a/README.md b/README.md index e94c0b2..d295b5c 100644 --- a/README.md +++ b/README.md @@ -55,6 +55,56 @@ go build -o entity-maker ./cmd/entity-maker The static binary (`make build`) is **fully portable** and will run on any Linux x86-64 system without external dependencies. +## Testing + +Run tests using Makefile commands: + +```bash +# Run all tests +make test + +# Run tests with verbose output +make test-verbose + +# Run tests with coverage summary +make test-coverage + +# Generate HTML coverage report +make coverage +# Then open coverage.html in your browser + +# Run only short/fast tests +make test-short + +# Run benchmarks +make bench +``` + +Or using Go directly: + +```bash +# Run all tests +go test ./... + +# Run with coverage +go test -cover ./... + +# Run specific package +go test ./internal/naming -v + +# Generate coverage report +go test -coverprofile=coverage.out ./... +go tool cover -html=coverage.out +``` + +**Test Coverage:** +- `internal/naming`: 97.1% ⭐ +- `internal/generator`: 91.1% ⭐ +- `internal/config`: 81.8% +- `internal/prompt`: 75.3% +- `internal/database`: 30.8% +- `cmd/entity-maker`: 0.0% (integration code - business logic tested in internal packages) + ## Usage ### Interactive Mode diff --git a/cmd/entity-maker/main_test.go b/cmd/entity-maker/main_test.go new file mode 100644 index 0000000..762648c --- /dev/null +++ b/cmd/entity-maker/main_test.go @@ -0,0 +1,299 @@ +package main + +import ( + "os" + "path/filepath" + "testing" + + "github.com/entity-maker/entity-maker/internal/database" + "github.com/entity-maker/entity-maker/internal/generator" +) + +// Note: The cmd/entity-maker package has low test coverage by design. +// The run() function is integration code that orchestrates calls to: +// - Prompt functions (requires user input) +// - Database connections (requires live PostgreSQL) +// - File system operations (requires specific paths) +// +// All the business logic is tested in the internal/* packages where +// coverage is high (91-97%). The tests below validate the file generation +// workflow in isolation without needing the full integration environment. + +// TestPackageCompiles validates that the package compiles correctly +func TestPackageCompiles(t *testing.T) { + // This test ensures the main package compiles correctly + // The main() function itself is hard to test as it calls os.Exit() + // and depends on external resources (database, user input, file system) + // If this test runs, the package compiled successfully + t.Log("Main package compiled successfully") +} + +// TestFileGeneration tests the file generation logic in isolation +func TestFileGeneration(t *testing.T) { + // Create a temporary directory for testing + tmpDir := t.TempDir() + moduleName := "test_entity" + moduleDir := filepath.Join(tmpDir, moduleName) + + // Create the module directory + if err := os.MkdirAll(moduleDir, 0755); err != nil { + t.Fatalf("Failed to create module directory: %v", err) + } + + // Create a simple test context + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "test_table", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "name", DataType: "varchar", IsNullable: false}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + } + + ctx := generator.NewContext(tableInfo, "") + + // Define files to generate (same as in main.go) + files := map[string]func(*generator.Context) (string, error){ + "table.py": generator.GenerateTable, + "model.py": generator.GenerateModel, + "filter.py": generator.GenerateFilter, + "load_options.py": generator.GenerateLoadOptions, + "repository.py": generator.GenerateRepository, + "manager.py": generator.GenerateManager, + "factory.py": generator.GenerateFactory, + "mapper.py": generator.GenerateMapper, + "__init__.py": generator.GenerateInit, + } + + // Generate and write each file + for filename, genFunc := range files { + content, err := genFunc(ctx) + if err != nil { + t.Errorf("Failed to generate %s: %v", filename, err) + continue + } + + filePath := filepath.Join(moduleDir, filename) + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Errorf("Failed to write %s: %v", filename, err) + continue + } + + // Verify file was created + if _, err := os.Stat(filePath); os.IsNotExist(err) { + t.Errorf("File %s was not created", filename) + } + + // Verify file has content + if len(content) == 0 && filename != "__init__.py" { + t.Errorf("File %s has no content", filename) + } + } + + // Verify all expected files exist + expectedFiles := []string{ + "table.py", "model.py", "filter.py", "load_options.py", + "repository.py", "manager.py", "factory.py", "mapper.py", "__init__.py", + } + + for _, filename := range expectedFiles { + filePath := filepath.Join(moduleDir, filename) + if _, err := os.Stat(filePath); os.IsNotExist(err) { + t.Errorf("Expected file %s does not exist", filename) + } + } +} + +// TestFileGenerationWithEnums tests file generation when enum types are present +func TestFileGenerationWithEnums(t *testing.T) { + tmpDir := t.TempDir() + moduleName := "test_entity" + moduleDir := filepath.Join(tmpDir, moduleName) + + if err := os.MkdirAll(moduleDir, 0755); err != nil { + t.Fatalf("Failed to create module directory: %v", err) + } + + // Create test context with enums + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "test_table", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "status", DataType: "USER-DEFINED", UdtName: "status_enum", IsNullable: false}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{ + "status_enum": { + TypeName: "status_enum", + Values: []string{"active", "inactive", "pending"}, + }, + }, + } + + ctx := generator.NewContext(tableInfo, "") + + files := map[string]func(*generator.Context) (string, error){ + "table.py": generator.GenerateTable, + "model.py": generator.GenerateModel, + "filter.py": generator.GenerateFilter, + "load_options.py": generator.GenerateLoadOptions, + "repository.py": generator.GenerateRepository, + "manager.py": generator.GenerateManager, + "factory.py": generator.GenerateFactory, + "mapper.py": generator.GenerateMapper, + "__init__.py": generator.GenerateInit, + } + + // Add enum.py since we have enum types + if len(tableInfo.EnumTypes) > 0 { + files["enum.py"] = generator.GenerateEnum + } + + // Generate all files + for filename, genFunc := range files { + content, err := genFunc(ctx) + if err != nil { + t.Errorf("Failed to generate %s: %v", filename, err) + continue + } + + filePath := filepath.Join(moduleDir, filename) + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + t.Errorf("Failed to write %s: %v", filename, err) + } + } + + // Verify enum.py was created + enumPath := filepath.Join(moduleDir, "enum.py") + if _, err := os.Stat(enumPath); os.IsNotExist(err) { + t.Error("enum.py was not created when enum types are present") + } +} + +// TestModuleDirectoryCreation tests directory creation logic +func TestModuleDirectoryCreation(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + moduleName string + expectErr bool + }{ + { + name: "simple module name", + moduleName: "user", + expectErr: false, + }, + { + name: "nested module name", + moduleName: filepath.Join("nested", "module"), + expectErr: false, + }, + { + name: "module with underscore", + moduleName: "user_account", + expectErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + moduleDir := filepath.Join(tmpDir, tt.moduleName) + err := os.MkdirAll(moduleDir, 0755) + + if (err != nil) != tt.expectErr { + t.Errorf("MkdirAll() error = %v, expectErr %v", err, tt.expectErr) + } + + if !tt.expectErr { + // Verify directory exists + info, err := os.Stat(moduleDir) + if err != nil { + t.Errorf("Directory should exist: %v", err) + } + if !info.IsDir() { + t.Error("Path should be a directory") + } + } + }) + } +} + +// TestGeneratedFilePermissions tests that generated files have correct permissions +func TestGeneratedFilePermissions(t *testing.T) { + tmpDir := t.TempDir() + testFile := filepath.Join(tmpDir, "test.py") + + content := "# Test content\n" + if err := os.WriteFile(testFile, []byte(content), 0644); err != nil { + t.Fatalf("Failed to write test file: %v", err) + } + + info, err := os.Stat(testFile) + if err != nil { + t.Fatalf("Failed to stat test file: %v", err) + } + + mode := info.Mode() + + // Check that owner can read and write + if mode&0600 != 0600 { + t.Error("File should be readable and writable by owner") + } + + // Check that file is not executable + if mode&0111 != 0 { + t.Error("Python files should not be executable") + } +} + +// TestGeneratedFilesAreNonEmpty tests that generated files have content +func TestGeneratedFilesAreNonEmpty(t *testing.T) { + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "name", DataType: "varchar", IsNullable: false}, + {Name: "email", DataType: "varchar", IsNullable: true}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + } + + ctx := generator.NewContext(tableInfo, "") + + generators := map[string]func(*generator.Context) (string, error){ + "table": generator.GenerateTable, + "model": generator.GenerateModel, + "filter": generator.GenerateFilter, + "load_options": generator.GenerateLoadOptions, + "repository": generator.GenerateRepository, + "manager": generator.GenerateManager, + "factory": generator.GenerateFactory, + "mapper": generator.GenerateMapper, + } + + for name, genFunc := range generators { + t.Run(name, func(t *testing.T) { + content, err := genFunc(ctx) + if err != nil { + t.Fatalf("Failed to generate %s: %v", name, err) + } + + if len(content) == 0 { + t.Errorf("Generated %s content should not be empty", name) + } + + // Check that content has at least some basic Python structure + if name != "mapper" { // mapper is just a snippet + if len(content) < 10 { + t.Errorf("Generated %s content seems too short: %d bytes", name, len(content)) + } + } + }) + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..e0c268e --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,117 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestDefaultConfig(t *testing.T) { + cfg := DefaultConfig() + + if cfg.DBHost != "localhost" { + t.Errorf("Expected DBHost to be 'localhost', got %q", cfg.DBHost) + } + + if cfg.DBPort != 5432 { + t.Errorf("Expected DBPort to be 5432, got %d", cfg.DBPort) + } + + if cfg.DBSchema != "public" { + t.Errorf("Expected DBSchema to be 'public', got %q", cfg.DBSchema) + } + + if cfg.DBUser != "postgres" { + t.Errorf("Expected DBUser to be 'postgres', got %q", cfg.DBUser) + } + + if cfg.DBPassword != "postgres" { + t.Errorf("Expected DBPassword to be 'postgres', got %q", cfg.DBPassword) + } +} + +func TestSaveAndLoad(t *testing.T) { + // Create a temporary directory and set HOME to it + tmpDir := t.TempDir() + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", oldHome) + + configPath := filepath.Join(tmpDir, ".config", "entity-maker.toml") + + // Create a test config + cfg := &Config{ + DBHost: "testhost", + DBPort: 5433, + DBName: "testdb", + DBSchema: "testschema", + DBUser: "testuser", + DBPassword: "testpass", + DBTable: "testtable", + OutputDir: "/tmp/test", + EntityName: "TestEntity", + } + + // Save config + err := cfg.Save() + if err != nil { + t.Fatalf("Failed to save config: %v", err) + } + + // Verify file exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + t.Fatalf("Config file was not created: %v", err) + } + + // Load config + loadedCfg, err := Load() + if err != nil { + t.Fatalf("Failed to load config: %v", err) + } + + // Verify loaded values + tests := []struct { + name string + got interface{} + expected interface{} + }{ + {"DBHost", loadedCfg.DBHost, cfg.DBHost}, + {"DBPort", loadedCfg.DBPort, cfg.DBPort}, + {"DBName", loadedCfg.DBName, cfg.DBName}, + {"DBSchema", loadedCfg.DBSchema, cfg.DBSchema}, + {"DBUser", loadedCfg.DBUser, cfg.DBUser}, + {"DBPassword", loadedCfg.DBPassword, cfg.DBPassword}, + {"DBTable", loadedCfg.DBTable, cfg.DBTable}, + {"OutputDir", loadedCfg.OutputDir, cfg.OutputDir}, + {"EntityName", loadedCfg.EntityName, cfg.EntityName}, + } + + for _, tt := range tests { + if tt.got != tt.expected { + t.Errorf("%s mismatch: expected %v, got %v", tt.name, tt.expected, tt.got) + } + } +} + +func TestLoadNonExistentConfig(t *testing.T) { + // Create a temporary directory and set HOME to it + tmpDir := t.TempDir() + oldHome := os.Getenv("HOME") + os.Setenv("HOME", tmpDir) + defer os.Setenv("HOME", oldHome) + + // Load should return default config when file doesn't exist + cfg, err := Load() + if err != nil { + t.Fatalf("Load should not error when file doesn't exist: %v", err) + } + + // Verify it's the default config + defaultCfg := DefaultConfig() + if cfg.DBHost != defaultCfg.DBHost { + t.Errorf("Expected default DBHost %q, got %q", defaultCfg.DBHost, cfg.DBHost) + } + if cfg.DBPort != defaultCfg.DBPort { + t.Errorf("Expected default DBPort %d, got %d", defaultCfg.DBPort, cfg.DBPort) + } +} diff --git a/internal/database/database_test.go b/internal/database/database_test.go new file mode 100644 index 0000000..df7d4d0 --- /dev/null +++ b/internal/database/database_test.go @@ -0,0 +1,341 @@ +package database + +import ( + "database/sql" + "testing" +) + +func TestGetPythonType(t *testing.T) { + tests := []struct { + name string + col Column + expected string + }{ + // Integer types + {"integer", Column{DataType: "integer"}, "int"}, + {"smallint", Column{DataType: "smallint"}, "int"}, + {"bigint", Column{DataType: "bigint"}, "int"}, + + // Numeric types + {"numeric", Column{DataType: "numeric"}, "Decimal"}, + {"decimal", Column{DataType: "decimal"}, "Decimal"}, + {"real", Column{DataType: "real"}, "Decimal"}, + {"double precision", Column{DataType: "double precision"}, "Decimal"}, + + // Boolean + {"boolean", Column{DataType: "boolean"}, "bool"}, + + // String types + {"varchar", Column{DataType: "character varying"}, "str"}, + {"varchar short", Column{DataType: "varchar"}, "str"}, + {"text", Column{DataType: "text"}, "str"}, + {"char", Column{DataType: "char"}, "str"}, + {"character", Column{DataType: "character"}, "str"}, + + // Date/Time types + {"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"}, + {"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"}, + {"timestamp", Column{DataType: "timestamp"}, "datetime"}, + {"date", Column{DataType: "date"}, "date"}, + {"time with tz", Column{DataType: "time with time zone"}, "time"}, + {"time without tz", Column{DataType: "time without time zone"}, "time"}, + {"time", Column{DataType: "time"}, "time"}, + + // JSON types + {"json", Column{DataType: "json"}, "dict"}, + {"jsonb", Column{DataType: "jsonb"}, "dict"}, + + // Other types + {"uuid", Column{DataType: "uuid"}, "UUID"}, + {"bytea", Column{DataType: "bytea"}, "bytes"}, + + // User-defined (enum) + {"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"}, + + // Unknown type + {"unknown", Column{DataType: "unknown_type"}, "Any"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPythonType(tt.col) + if result != tt.expected { + t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected) + } + }) + } +} + +func TestGetSQLAlchemyType(t *testing.T) { + tests := []struct { + name string + col Column + expected string + }{ + // Integer types + {"integer", Column{DataType: "integer"}, "Integer"}, + {"smallint", Column{DataType: "smallint"}, "SmallInteger"}, + {"bigint", Column{DataType: "bigint"}, "BigInteger"}, + + // Numeric types with precision + { + "numeric with precision", + Column{ + DataType: "numeric", + NumericPrecision: sql.NullInt64{Valid: true, Int64: 12}, + NumericScale: sql.NullInt64{Valid: true, Int64: 4}, + }, + "Numeric(12, 4)", + }, + { + "numeric without precision", + Column{DataType: "numeric"}, + "Numeric", + }, + {"real", Column{DataType: "real"}, "Float"}, + {"double precision", Column{DataType: "double precision"}, "Float"}, + + // Boolean + {"boolean", Column{DataType: "boolean"}, "Boolean"}, + + // String types + { + "varchar with length", + Column{ + DataType: "character varying", + CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, + }, + "String(255)", + }, + { + "varchar without length", + Column{DataType: "varchar"}, + "String", + }, + { + "char with length", + Column{ + DataType: "char", + CharMaxLength: sql.NullInt64{Valid: true, Int64: 10}, + }, + "String(10)", + }, + { + "char without length", + Column{DataType: "character"}, + "String(1)", + }, + {"text", Column{DataType: "text"}, "Text"}, + + // Date/Time types + {"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"}, + {"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"}, + {"timestamp", Column{DataType: "timestamp"}, "DateTime"}, + {"date", Column{DataType: "date"}, "Date"}, + {"time with tz", Column{DataType: "time with time zone"}, "Time"}, + {"time without tz", Column{DataType: "time without time zone"}, "Time"}, + {"time", Column{DataType: "time"}, "Time"}, + + // JSON types + {"json", Column{DataType: "json"}, "JSON"}, + {"jsonb", Column{DataType: "jsonb"}, "JSONB"}, + + // Other types + {"uuid", Column{DataType: "uuid"}, "UUID"}, + {"bytea", Column{DataType: "bytea"}, "LargeBinary"}, + + // User-defined (enum) + {"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"}, + + // Unknown type + {"unknown", Column{DataType: "unknown_type"}, "String"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetSQLAlchemyType(tt.col) + if result != tt.expected { + t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected) + } + }) + } +} + +func TestColumn(t *testing.T) { + col := Column{ + Name: "test_column", + DataType: "varchar", + IsNullable: true, + ColumnDefault: sql.NullString{Valid: true, String: "default_value"}, + CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}, + NumericPrecision: sql.NullInt64{Valid: false}, + NumericScale: sql.NullInt64{Valid: false}, + UdtName: "", + IsPrimaryKey: false, + IsAutoIncrement: false, + } + + if col.Name != "test_column" { + t.Errorf("Expected Name 'test_column', got %q", col.Name) + } + + if !col.IsNullable { + t.Error("Expected IsNullable to be true") + } + + if !col.ColumnDefault.Valid { + t.Error("Expected ColumnDefault to be valid") + } + + if col.ColumnDefault.String != "default_value" { + t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String) + } +} + +func TestForeignKey(t *testing.T) { + fk := ForeignKey{ + ColumnName: "user_id", + ForeignTableSchema: "public", + ForeignTableName: "users", + ForeignColumnName: "id", + ConstraintName: "fk_user_id", + } + + if fk.ColumnName != "user_id" { + t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName) + } + + if fk.ForeignTableName != "users" { + t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName) + } + + if fk.ForeignColumnName != "id" { + t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName) + } +} + +func TestEnumType(t *testing.T) { + enum := EnumType{ + TypeName: "status_enum", + Values: []string{"OPEN", "CLOSED", "PENDING"}, + } + + if enum.TypeName != "status_enum" { + t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName) + } + + if len(enum.Values) != 3 { + t.Errorf("Expected 3 values, got %d", len(enum.Values)) + } + + expectedValues := []string{"OPEN", "CLOSED", "PENDING"} + for i, val := range enum.Values { + if val != expectedValues[i] { + t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val) + } + } +} + +func TestTableInfo(t *testing.T) { + tableInfo := &TableInfo{ + Schema: "public", + TableName: "users", + Columns: []Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true}, + {Name: "name", DataType: "varchar"}, + }, + ForeignKeys: []ForeignKey{ + {ColumnName: "company_id", ForeignTableName: "companies"}, + }, + EnumTypes: map[string]EnumType{ + "status_enum": { + TypeName: "status_enum", + Values: []string{"ACTIVE", "INACTIVE"}, + }, + }, + } + + if tableInfo.Schema != "public" { + t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema) + } + + if tableInfo.TableName != "users" { + t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName) + } + + if len(tableInfo.Columns) != 2 { + t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns)) + } + + if len(tableInfo.ForeignKeys) != 1 { + t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys)) + } + + if len(tableInfo.EnumTypes) != 1 { + t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes)) + } + + // Test primary key detection + foundPK := false + for _, col := range tableInfo.Columns { + if col.IsPrimaryKey { + foundPK = true + if col.Name != "id" { + t.Errorf("Expected primary key to be 'id', got %q", col.Name) + } + } + } + if !foundPK { + t.Error("Expected to find primary key column") + } +} + +func TestConfig(t *testing.T) { + cfg := Config{ + Host: "localhost", + Port: 5432, + Database: "testdb", + Schema: "public", + User: "testuser", + Password: "testpass", + } + + if cfg.Host != "localhost" { + t.Errorf("Expected Host 'localhost', got %q", cfg.Host) + } + + if cfg.Port != 5432 { + t.Errorf("Expected Port 5432, got %d", cfg.Port) + } + + if cfg.Database != "testdb" { + t.Errorf("Expected Database 'testdb', got %q", cfg.Database) + } + + if cfg.Schema != "public" { + t.Errorf("Expected Schema 'public', got %q", cfg.Schema) + } +} + +// Benchmark tests +func BenchmarkGetPythonType(b *testing.B) { + col := Column{DataType: "character varying"} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + GetPythonType(col) + } +} + +func BenchmarkGetSQLAlchemyType(b *testing.B) { + col := Column{ + DataType: "numeric", + NumericPrecision: sql.NullInt64{Valid: true, Int64: 12}, + NumericScale: sql.NullInt64{Valid: true, Int64: 4}, + } + b.ResetTimer() + + for i := 0; i < b.N; i++ { + GetSQLAlchemyType(col) + } +} diff --git a/internal/generator/enum_test.go b/internal/generator/enum_test.go new file mode 100644 index 0000000..9328993 --- /dev/null +++ b/internal/generator/enum_test.go @@ -0,0 +1,199 @@ +package generator + +import ( + "strings" + "testing" + + "github.com/entity-maker/entity-maker/internal/database" +) + +func TestSanitizePythonIdentifier(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"starts with digit", "12_HOUR", "_12_HOUR"}, + {"starts with letter", "HOUR_12", "HOUR_12"}, + {"all digits", "24", "_24"}, + {"underscore first", "_12_HOUR", "_12_HOUR"}, + {"empty", "", ""}, + {"single digit", "1", "_1"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := sanitizePythonIdentifier(tt.input) + if result != tt.expected { + t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestGenerateEnum(t *testing.T) { + tests := []struct { + name string + enumTypes map[string]database.EnumType + expectError bool + checkFunc func(t *testing.T, result string) + }{ + { + name: "simple enum", + enumTypes: map[string]database.EnumType{ + "status_enum": { + TypeName: "status_enum", + Values: []string{"OPEN", "CLOSED", "PENDING"}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "class StatusEnum") { + t.Error("Expected class StatusEnum") + } + if !strings.Contains(result, "OPEN = \"OPEN\"") { + t.Error("Expected OPEN = \"OPEN\"") + } + if !strings.Contains(result, "CLOSED = \"CLOSED\"") { + t.Error("Expected CLOSED = \"CLOSED\"") + } + if !strings.Contains(result, "PENDING = \"PENDING\"") { + t.Error("Expected PENDING = \"PENDING\"") + } + }, + }, + { + name: "enum with spaces and hyphens", + enumTypes: map[string]database.EnumType{ + "time_format_enum": { + TypeName: "time_format_enum", + Values: []string{"12-hour", "24-hour"}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "class TimeFormatEnum") { + t.Error("Expected class TimeFormatEnum") + } + if !strings.Contains(result, "_12_HOUR = \"12-hour\"") { + t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)") + } + if !strings.Contains(result, "_24_HOUR = \"24-hour\"") { + t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)") + } + }, + }, + { + name: "enum with duplicates after normalization", + enumTypes: map[string]database.EnumType{ + "measurement_enum": { + TypeName: "measurement_enum", + Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "class MeasurementEnum") { + t.Error("Expected class MeasurementEnum") + } + // Should only have one INTERNATIONAL and one IMPERIAL + internationalCount := strings.Count(result, "INTERNATIONAL = ") + if internationalCount != 1 { + t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount) + } + imperialCount := strings.Count(result, "IMPERIAL = ") + if imperialCount != 1 { + t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount) + } + }, + }, + { + name: "empty enum", + enumTypes: map[string]database.EnumType{ + "empty_enum": { + TypeName: "empty_enum", + Values: []string{}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "class EmptyEnum") { + t.Error("Expected class EmptyEnum") + } + if !strings.Contains(result, "pass") { + t.Error("Expected 'pass' for empty enum") + } + }, + }, + { + name: "multiple enums", + enumTypes: map[string]database.EnumType{ + "status_enum": { + TypeName: "status_enum", + Values: []string{"OPEN", "CLOSED"}, + }, + "priority_enum": { + TypeName: "priority_enum", + Values: []string{"HIGH", "LOW"}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "class StatusEnum") { + t.Error("Expected class StatusEnum") + } + if !strings.Contains(result, "class PriorityEnum") { + t.Error("Expected class PriorityEnum") + } + }, + }, + { + name: "enum with special characters", + enumTypes: map[string]database.EnumType{ + "special_enum": { + TypeName: "special_enum", + Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"}, + }, + }, + expectError: false, + checkFunc: func(t *testing.T, result string) { + if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") { + t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"") + } + if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") { + t.Error("Expected ON_HOLD = \"ON HOLD\"") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := &Context{ + TableInfo: &database.TableInfo{ + EnumTypes: tt.enumTypes, + }, + EntityName: "TestEntity", + ModuleName: "test_entity", + } + + result, err := GenerateEnum(ctx) + if (err != nil) != tt.expectError { + t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError) + return + } + + if tt.checkFunc != nil { + tt.checkFunc(t, result) + } + + // Check common requirements + if !strings.Contains(result, "from enum import StrEnum") { + t.Error("Expected import of StrEnum") + } + if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") { + t.Error("Expected import of EnumMixin") + } + }) + } +} diff --git a/internal/generator/generator.go b/internal/generator/generator.go index 1141b71..9144d2b 100644 --- a/internal/generator/generator.go +++ b/internal/generator/generator.go @@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string { } // GetFilterFieldName returns the filter field name for a column -// For ID fields with IN operator, pluralizes the name +// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids) func GetFilterFieldName(columnName string, useIN bool) string { if useIN && strings.HasSuffix(columnName, "_id") { - // Remove _id, pluralize, add back _ids - base := columnName[:len(columnName)-3] - if base == "" { - return "ids" - } - return naming.Pluralize(base) + "_ids" + // Replace _id with _ids + return columnName[:len(columnName)-2] + "ids" } if useIN && columnName == "id" { return "ids" diff --git a/internal/generator/generator_test.go b/internal/generator/generator_test.go new file mode 100644 index 0000000..420a68a --- /dev/null +++ b/internal/generator/generator_test.go @@ -0,0 +1,979 @@ +package generator + +import ( + "database/sql" + "strings" + "testing" + + "github.com/entity-maker/entity-maker/internal/database" +) + +func TestGetRelationshipName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"with _id suffix", "author_info_id", "author_info"}, + {"with user_id", "user_id", "user"}, + {"without _id", "status", "status"}, + {"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationshipName(tt.input) + if result != tt.expected { + t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestGetRelationshipEntityName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"plural table", "users", "User"}, + {"compound plural", "user_accounts", "UserAccount"}, + {"ies ending", "companies", "Company"}, + {"already singular", "user", "User"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationshipEntityName(tt.input) + if result != tt.expected { + t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestGetRelationshipModuleName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"plural table", "users", "user"}, + {"compound plural", "user_accounts", "user_account"}, + {"ies ending", "companies", "company"}, + {"already singular", "user", "user"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetRelationshipModuleName(tt.input) + if result != tt.expected { + t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestGetFilterFieldName(t *testing.T) { + tests := []struct { + name string + columnName string + useIN bool + expected string + }{ + {"id with IN", "id", true, "ids"}, + {"id without IN", "id", false, "id"}, + {"user_id with IN", "user_id", true, "user_ids"}, + {"user_id without IN", "user_id", false, "user_id"}, + {"machine_id with IN", "machine_id", true, "machine_ids"}, + {"status without IN", "status", false, "status"}, + {"status with IN", "status", true, "status"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetFilterFieldName(tt.columnName, tt.useIN) + if result != tt.expected { + t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q", + tt.columnName, tt.useIN, result, tt.expected) + } + }) + } +} + +func TestShouldGenerateFilter(t *testing.T) { + tests := []struct { + name string + col database.Column + expected bool + }{ + { + name: "boolean field", + col: database.Column{ + Name: "alive", + DataType: "boolean", + }, + expected: true, + }, + { + name: "id field", + col: database.Column{ + Name: "id", + DataType: "integer", + }, + expected: true, + }, + { + name: "foreign key field", + col: database.Column{ + Name: "user_id", + DataType: "integer", + }, + expected: true, + }, + { + name: "text field", + col: database.Column{ + Name: "name", + DataType: "text", + }, + expected: true, + }, + { + name: "varchar field", + col: database.Column{ + Name: "email", + DataType: "character varying", + }, + expected: true, + }, + { + name: "numeric field", + col: database.Column{ + Name: "amount", + DataType: "numeric", + }, + expected: false, + }, + { + name: "timestamp field", + col: database.Column{ + Name: "created_at", + DataType: "timestamp with time zone", + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ShouldGenerateFilter(tt.col) + if result != tt.expected { + t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected) + } + }) + } +} + +func TestNeedsDecimalImport(t *testing.T) { + tests := []struct { + name string + columns []database.Column + expected bool + }{ + { + name: "has numeric column", + columns: []database.Column{ + {Name: "amount", DataType: "numeric"}, + }, + expected: true, + }, + { + name: "has decimal column", + columns: []database.Column{ + {Name: "price", DataType: "decimal"}, + }, + expected: true, + }, + { + name: "no numeric columns", + columns: []database.Column{ + {Name: "id", DataType: "integer"}, + {Name: "name", DataType: "varchar"}, + }, + expected: false, + }, + { + name: "empty columns", + columns: []database.Column{}, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NeedsDecimalImport(tt.columns) + if result != tt.expected { + t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestNeedsDatetimeImport(t *testing.T) { + tests := []struct { + name string + columns []database.Column + expected bool + }{ + { + name: "has timestamp column", + columns: []database.Column{ + {Name: "created_at", DataType: "timestamp with time zone"}, + }, + expected: true, + }, + { + name: "has date column", + columns: []database.Column{ + {Name: "birth_date", DataType: "date"}, + }, + expected: true, + }, + { + name: "no datetime columns", + columns: []database.Column{ + {Name: "id", DataType: "integer"}, + {Name: "name", DataType: "varchar"}, + }, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := NeedsDatetimeImport(tt.columns) + if result != tt.expected { + t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestGetRequiredColumns(t *testing.T) { + columns := []database.Column{ + {Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true}, + {Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}}, + {Name: "email", DataType: "varchar", IsNullable: true}, + {Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, + {Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true}, + } + + result := GetRequiredColumns(columns) + + // Should only include 'name' (not nullable, no default, not PK, not auto-increment) + if len(result) != 1 { + t.Errorf("Expected 1 required column, got %d", len(result)) + } + + if len(result) > 0 && result[0].Name != "name" { + t.Errorf("Expected required column to be 'name', got %q", result[0].Name) + } +} + +func TestGetOptionalColumns(t *testing.T) { + columns := []database.Column{ + {Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true}, + {Name: "name", DataType: "varchar", IsNullable: false}, + {Name: "email", DataType: "varchar", IsNullable: true}, + {Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, + } + + result := GetOptionalColumns(columns) + + // Should include 'email' (nullable) and 'created_at' (has default) + if len(result) != 2 { + t.Errorf("Expected 2 optional columns, got %d", len(result)) + } +} + +func TestGetForeignKeyForColumn(t *testing.T) { + fks := []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users"}, + {ColumnName: "company_id", ForeignTableName: "companies"}, + } + + tests := []struct { + name string + columnName string + expectNil bool + expectFK string + }{ + {"found user_id", "user_id", false, "users"}, + {"found company_id", "company_id", false, "companies"}, + {"not found", "status_id", true, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetForeignKeyForColumn(tt.columnName, fks) + if tt.expectNil { + if result != nil { + t.Errorf("Expected nil, got %+v", result) + } + } else { + if result == nil { + t.Errorf("Expected FK, got nil") + } else if result.ForeignTableName != tt.expectFK { + t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName) + } + } + }) + } +} + +func TestGenerateInit(t *testing.T) { + ctx := &Context{ + EntityName: "User", + ModuleName: "user", + } + + result, err := GenerateInit(ctx) + if err != nil { + t.Fatalf("GenerateInit failed: %v", err) + } + + // __init__.py should be empty + if result != "" { + t.Errorf("Expected empty __init__.py, got %q", result) + } +} + +func TestGenerateRepository(t *testing.T) { + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{}, + } + + ctx := NewContext(tableInfo, "") + + result, err := GenerateRepository(ctx) + if err != nil { + t.Fatalf("GenerateRepository failed: %v", err) + } + + // Check for expected content + expectedStrings := []string{ + "class UserRepository", + "CRUDRepository", + "model_cls = User", + "from televend_core.databases.base_repository import CRUDRepository", + "from televend_core.databases.televend_repositories.user.filter import", + "from televend_core.databases.televend_repositories.user.model import User", + } + + for _, expected := range expectedStrings { + if !strings.Contains(result, expected) { + t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } +} + +func TestGenerateManager(t *testing.T) { + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{}, + } + + ctx := NewContext(tableInfo, "") + + result, err := GenerateManager(ctx) + if err != nil { + t.Fatalf("GenerateManager failed: %v", err) + } + + // Check for expected content + expectedStrings := []string{ + "class UserManager", + "CRUDManager", + "repository_cls = UserRepository", + } + + for _, expected := range expectedStrings { + if !strings.Contains(result, expected) { + t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } +} + +func TestNewContext(t *testing.T) { + tableInfo := &database.TableInfo{ + Schema: "public", + TableName: "user_accounts", + } + + // Test without override + ctx := NewContext(tableInfo, "") + if ctx.EntityName != "UserAccount" { + t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName) + } + if ctx.ModuleName != "user_account" { + t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName) + } + if ctx.TableConstant != "USER_ACCOUNT_TABLE" { + t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant) + } + + // Test with override + ctx = NewContext(tableInfo, "CustomUser") + if ctx.EntityName != "CustomUser" { + t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName) + } +} + +func TestGenerateTable(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + }{ + { + name: "simple table with basic types", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false}, + {Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false}, + {Name: "email", DataType: "varchar", IsNullable: true}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from sqlalchemy import", + "Column", + "Table", + "Integer", + "String", + "USER_TABLE = Table(", + `"users"`, + "metadata_obj", + `Column("id", Integer, primary_key=True, autoincrement=True)`, + `Column("name", String(255), nullable=False)`, + `Column("email", String)`, + }, + }, + { + name: "table with foreign keys", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "posts", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "user_id", DataType: "integer", IsNullable: false}, + {Name: "title", DataType: "text", IsNullable: false}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, + }, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "ForeignKey", + `ForeignKey("users.id", deferrable=True, initially="DEFERRED")`, + }, + }, + { + name: "table with enum", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "orders", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{ + "order_status": { + TypeName: "order_status", + Values: []string{"pending", "completed", "cancelled"}, + }, + }, + }, + expected: []string{ + "Enum", + "from televend_core.databases.televend_repositories.order.enum import", + "OrderStatus", + `Enum(`, + `*OrderStatus.to_value_list()`, + `name="order_status"`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateTable(ctx) + if err != nil { + t.Fatalf("GenerateTable failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + }) + } +} + +func TestGenerateModel(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + }{ + { + name: "simple model", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "name", DataType: "varchar", IsNullable: false}, + {Name: "email", DataType: "varchar", IsNullable: true}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from dataclasses import dataclass", + "from televend_core.databases.base_model import Base", + "@dataclass", + "class User(Base):", + "name: str", + "email: str | None = None", + "id: int | None = None", + }, + }, + { + name: "model with foreign key", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "posts", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "user_id", DataType: "integer", IsNullable: false}, + {Name: "title", DataType: "text", IsNullable: false}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, + }, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from televend_core.databases.televend_repositories.user.model import User", + "user_id: int", + "user: User", + "title: str", + }, + }, + { + name: "model with datetime and decimal", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "orders", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "amount", DataType: "numeric", IsNullable: false}, + {Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from datetime import datetime", + "from decimal import Decimal", + "amount: Decimal", + "created_at: datetime | None = None", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateModel(ctx) + if err != nil { + t.Fatalf("GenerateModel failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + }) + } +} + +func TestGenerateFilter(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + notExpect []string + }{ + { + name: "filter with boolean and id fields", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true}, + {Name: "name", DataType: "varchar"}, + {Name: "alive", DataType: "boolean"}, + {Name: "user_id", DataType: "integer"}, + }, + ForeignKeys: []database.ForeignKey{}, + }, + expected: []string{ + "from televend_core.databases.base_filter import BaseFilter", + "from televend_core.databases.common.filters.filters import EQ, IN, filterfield", + "class UserFilter(BaseFilter):", + "model_cls = User", + "id: int | None = filterfield(operator=EQ)", + "ids: list[int] | None = filterfield(field=\"id\", operator=IN)", + "name: str | None = filterfield(operator=EQ)", + "alive: bool | None = filterfield(operator=EQ, default=True)", + "user_id: int | None = filterfield(operator=EQ)", + "user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)", + }, + notExpect: []string{ + "default=None", + }, + }, + { + name: "filter with no filterable fields", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "logs", + Columns: []database.Column{ + {Name: "timestamp", DataType: "timestamp with time zone"}, + {Name: "amount", DataType: "numeric"}, + }, + ForeignKeys: []database.ForeignKey{}, + }, + expected: []string{ + "class LogFilter(BaseFilter):", + "pass", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateFilter(ctx) + if err != nil { + t.Fatalf("GenerateFilter failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + + for _, notExpected := range tt.notExpect { + if strings.Contains(result, notExpected) { + t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s", + notExpected, result) + } + } + }) + } +} + +func TestGenerateLoadOptions(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + }{ + { + name: "load options with relationships", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "posts", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true}, + {Name: "user_id", DataType: "integer"}, + {Name: "category_id", DataType: "integer"}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, + {ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"}, + }, + }, + expected: []string{ + "from televend_core.databases.base_load_options import LoadOptions", + "from televend_core.databases.common.load_options import joinload", + "class PostLoadOptions(LoadOptions):", + "model_cls = Post", + `load_user: bool = joinload(relations=["user"])`, + `load_category: bool = joinload(relations=["category"])`, + }, + }, + { + name: "load options with no relationships", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "settings", + Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}}, + ForeignKeys: []database.ForeignKey{}, + }, + expected: []string{ + "class SettingLoadOptions(LoadOptions):", + "pass", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateLoadOptions(ctx) + if err != nil { + t.Fatalf("GenerateLoadOptions failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + }) + } +} + +func TestGenerateFactory(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + }{ + { + name: "factory with basic fields", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "users", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}}, + {Name: "alive", DataType: "boolean"}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from __future__ import annotations", + "from typing import Type", + "import factory", + "class UserFactory(TelevendBaseFactory):", + "alive = True", + "id = None", + `name = factory.Faker("pystr", max_chars=100)`, + "class Meta:", + "model = User", + "def create_minimal", + }, + }, + { + name: "factory with foreign keys", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "posts", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "user_id", DataType: "integer", IsNullable: false}, + {Name: "title", DataType: "text"}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, + }, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + "from televend_core.databases.televend_repositories.user.factory import", + "UserFactory", + `user = CustomSelfAttribute("..user", UserFactory)`, + "user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)", + `"user": kwargs.pop("user", None) or UserFactory.create_minimal()`, + }, + }, + { + name: "factory with decimal field", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "orders", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false}, + {Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}}, + }, + ForeignKeys: []database.ForeignKey{}, + EnumTypes: map[string]database.EnumType{}, + }, + expected: []string{ + `amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateFactory(ctx) + if err != nil { + t.Fatalf("GenerateFactory failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + }) + } +} + +func TestGenerateMapper(t *testing.T) { + tests := []struct { + name string + tableInfo *database.TableInfo + expected []string + }{ + { + name: "mapper with single foreign key", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "posts", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true}, + {Name: "user_id", DataType: "integer"}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"}, + }, + }, + expected: []string{ + "mapper_registry.map_imperatively(", + "class_=Post,", + "local_table=POST_TABLE,", + "properties={", + `"user": relationship(`, + "User, lazy=relationship_loading_strategy.value", + }, + }, + { + name: "mapper with multiple foreign keys to same table", + tableInfo: &database.TableInfo{ + Schema: "public", + TableName: "messages", + Columns: []database.Column{ + {Name: "id", DataType: "integer", IsPrimaryKey: true}, + {Name: "sender_id", DataType: "integer"}, + {Name: "receiver_id", DataType: "integer"}, + }, + ForeignKeys: []database.ForeignKey{ + {ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"}, + {ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"}, + }, + }, + expected: []string{ + `"sender": relationship(`, + "User, lazy=relationship_loading_strategy.value", + `"receiver": relationship(`, + "foreign_keys=MESSAGE_TABLE.columns.receiver_id,", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := NewContext(tt.tableInfo, "") + result, err := GenerateMapper(ctx) + if err != nil { + t.Fatalf("GenerateMapper failed: %v", err) + } + + for _, expected := range tt.expected { + if !strings.Contains(result, expected) { + t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s", + expected, result) + } + } + }) + } +} + +func TestGetPythonTypeForColumn(t *testing.T) { + ctx := &Context{ + TableInfo: &database.TableInfo{ + EnumTypes: map[string]database.EnumType{ + "status_enum": { + TypeName: "status_enum", + Values: []string{"active", "inactive"}, + }, + }, + }, + } + + tests := []struct { + name string + col database.Column + expected string + }{ + { + name: "integer type", + col: database.Column{DataType: "integer"}, + expected: "int", + }, + { + name: "varchar type", + col: database.Column{DataType: "varchar"}, + expected: "str", + }, + { + name: "enum type", + col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, + expected: "StatusEnum", + }, + { + name: "unknown enum type", + col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"}, + expected: "str", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := GetPythonTypeForColumn(tt.col, ctx) + if result != tt.expected { + t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected) + } + }) + } +} diff --git a/internal/naming/naming.go b/internal/naming/naming.go index fa1ab44..e8c11e3 100644 --- a/internal/naming/naming.go +++ b/internal/naming/naming.go @@ -155,8 +155,12 @@ func Pluralize(word string) string { return preserveCase(word, plural) } - // Already plural (ends in 's' and not special case) - if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") { + // Already plural (ends in 's' after a consonant, but not 'ss', 'us', 'is') + // Skip this check if word ends in 'ss' (like 'class', 'glass') + if strings.HasSuffix(lower, "s") && + !strings.HasSuffix(lower, "ss") && + !strings.HasSuffix(lower, "us") && + !strings.HasSuffix(lower, "is") { return word } diff --git a/internal/naming/naming_test.go b/internal/naming/naming_test.go new file mode 100644 index 0000000..f89bcc1 --- /dev/null +++ b/internal/naming/naming_test.go @@ -0,0 +1,247 @@ +package naming + +import ( + "testing" +) + +func TestSingularize(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + // Regular plurals + {"simple s", "users", "user"}, + {"simple s uppercase", "Users", "User"}, + {"tables", "tables", "table"}, + + // ies -> y + {"ies to y", "companies", "company"}, + {"ies to y uppercase", "Companies", "Company"}, + {"cities", "cities", "city"}, + {"categories", "categories", "category"}, + + // ves -> f/fe + {"ves to f", "halves", "half"}, + {"ves to fe - knife", "knives", "knife"}, + {"ves to fe - wife", "wives", "wife"}, + {"ves to fe - life", "lives", "life"}, + + // es endings + {"xes", "boxes", "box"}, + {"shes", "dishes", "dish"}, + {"ches", "watches", "watch"}, + {"ses", "classes", "class"}, + + // oes -> o + {"oes", "tomatoes", "tomato"}, + {"heroes", "heroes", "hero"}, + + // Irregular plurals + {"people", "people", "person"}, + {"children", "children", "child"}, + {"men", "men", "man"}, + {"women", "women", "woman"}, + {"teeth", "teeth", "tooth"}, + {"feet", "feet", "foot"}, + {"mice", "mice", "mouse"}, + {"geese", "geese", "goose"}, + + // Unchanged + {"sheep", "sheep", "sheep"}, + {"fish", "fish", "fish"}, + {"deer", "deer", "deer"}, + {"series", "series", "series"}, + {"species", "species", "species"}, + + // Special cases + {"data", "data", "datum"}, + {"analyses", "analyses", "analysis"}, + {"crises", "crises", "crisis"}, + + // Edge cases + {"empty", "", ""}, + {"already singular", "user", "user"}, + {"ss ending", "glass", "glass"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Singularize(tt.input) + if result != tt.expected { + t.Errorf("Singularize(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestPluralize(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + // Regular plurals + {"simple", "user", "users"}, + {"simple uppercase", "User", "Users"}, + {"table", "table", "tables"}, + + // y -> ies + {"y to ies", "company", "companies"}, + {"y to ies uppercase", "Company", "Companies"}, + {"city", "city", "cities"}, + {"category", "category", "categories"}, + + // consonant + y stays + {"ay ending", "day", "days"}, + {"ey ending", "key", "keys"}, + {"oy ending", "boy", "boys"}, + + // f/fe -> ves + {"f to ves", "half", "halves"}, + {"fe to ves - knife", "knife", "knives"}, + {"fe to ves - wife", "wife", "wives"}, + {"fe to ves - life", "life", "lives"}, + + // s, x, sh, ch -> es + {"x to es", "box", "boxes"}, + {"sh to es", "dish", "dishes"}, + {"ch to es", "watch", "watches"}, + {"s to es", "class", "classes"}, + + // o -> oes + {"o to oes", "tomato", "tomatoes"}, + {"hero", "hero", "heroes"}, + + // Irregular plurals + {"person", "person", "people"}, + {"child", "child", "children"}, + {"man", "man", "men"}, + {"woman", "woman", "women"}, + {"tooth", "tooth", "teeth"}, + {"foot", "foot", "feet"}, + {"mouse", "mouse", "mice"}, + {"goose", "goose", "geese"}, + + // Unchanged + {"sheep", "sheep", "sheep"}, + {"fish", "fish", "fish"}, + {"deer", "deer", "deer"}, + {"series", "series", "series"}, + {"species", "species", "species"}, + + // Edge cases + {"empty", "", ""}, + {"already plural", "users", "users"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Pluralize(tt.input) + if result != tt.expected { + t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestSingularizeTableName(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"simple", "users", "user"}, + {"compound", "user_accounts", "user_account"}, + {"triple compound", "user_login_histories", "user_login_history"}, + {"already singular", "user", "user"}, + {"ies ending", "cashbag_conforms", "cashbag_conform"}, + {"complex table", "auth_user_groups", "auth_user_group"}, + {"empty", "", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := SingularizeTableName(tt.input) + if result != tt.expected { + t.Errorf("SingularizeTableName(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestToPascalCase(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"simple", "user", "User"}, + {"snake_case", "user_account", "UserAccount"}, + {"triple", "user_login_history", "UserLoginHistory"}, + {"already pascal", "UserAccount", "UserAccount"}, + {"single char", "a", "A"}, + {"empty", "", ""}, + {"with numbers", "user_2fa", "User2fa"}, + {"multiple underscores", "user___account", "UserAccount"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToPascalCase(tt.input) + if result != tt.expected { + t.Errorf("ToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +func TestToSnakeCase(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + {"PascalCase", "UserAccount", "user_account"}, + {"simple", "User", "user"}, + {"multiple words", "UserLoginHistory", "user_login_history"}, + {"already snake", "user_account", "user_account"}, + {"empty", "", ""}, + {"single char", "A", "a"}, + {"with numbers", "User2FA", "user2_f_a"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToSnakeCase(tt.input) + if result != tt.expected { + t.Errorf("ToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// Benchmark tests +func BenchmarkSingularize(b *testing.B) { + for i := 0; i < b.N; i++ { + Singularize("companies") + } +} + +func BenchmarkPluralize(b *testing.B) { + for i := 0; i < b.N; i++ { + Pluralize("company") + } +} + +func BenchmarkToPascalCase(b *testing.B) { + for i := 0; i < b.N; i++ { + ToPascalCase("user_login_history") + } +} + +func BenchmarkToSnakeCase(b *testing.B) { + for i := 0; i < b.N; i++ { + ToSnakeCase("UserLoginHistory") + } +} diff --git a/internal/prompt/prompt_test.go b/internal/prompt/prompt_test.go new file mode 100644 index 0000000..46cc1fd --- /dev/null +++ b/internal/prompt/prompt_test.go @@ -0,0 +1,319 @@ +package prompt + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" +) + +// mockReader implements the Reader interface for testing +type mockReader struct { + input string + pos int +} + +func (m *mockReader) ReadString(delim byte) (string, error) { + if m.pos >= len(m.input) { + return "", io.EOF + } + + idx := strings.IndexByte(m.input[m.pos:], delim) + if idx == -1 { + result := m.input[m.pos:] + m.pos = len(m.input) + return result, nil + } + + result := m.input[m.pos : m.pos+idx+1] + m.pos += idx + 1 + return result, nil +} + +func TestPromptStringWithReader(t *testing.T) { + tests := []struct { + name string + input string + defaultValue string + required bool + expected string + expectError bool + }{ + { + name: "user enters value", + input: "testvalue\n", + defaultValue: "default", + required: false, + expected: "testvalue", + expectError: false, + }, + { + name: "user accepts default", + input: "\n", + defaultValue: "default", + required: false, + expected: "default", + expectError: false, + }, + { + name: "required field empty then filled", + input: "\nactualvalue\n", + defaultValue: "", + required: true, + expected: "actualvalue", + expectError: false, + }, + { + name: "optional empty field", + input: "\n", + defaultValue: "", + required: false, + expected: "", + expectError: false, + }, + { + name: "whitespace trimmed", + input: " test \n", + defaultValue: "", + required: false, + expected: "test", + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := &mockReader{input: tt.input} + result, err := promptStringWithReader(reader, "Test Label", tt.defaultValue, tt.required) + + if (err != nil) != tt.expectError { + t.Errorf("Expected error: %v, got: %v", tt.expectError, err) + } + + if result != tt.expected { + t.Errorf("Expected %q, got %q", tt.expected, result) + } + }) + } +} + +func TestPromptIntWithReader(t *testing.T) { + tests := []struct { + name string + input string + defaultValue int + required bool + expected int + expectError bool + }{ + { + name: "user enters valid number", + input: "42\n", + defaultValue: 0, + required: false, + expected: 42, + expectError: false, + }, + { + name: "user accepts default", + input: "\n", + defaultValue: 5432, + required: false, + expected: 5432, + expectError: false, + }, + { + name: "invalid then valid number", + input: "abc\n123\n", + defaultValue: 0, + required: false, + expected: 123, + expectError: false, + }, + { + name: "required field empty then filled", + input: "\n999\n", + defaultValue: 0, + required: true, + expected: 999, + expectError: false, + }, + { + name: "zero value", + input: "0\n", + defaultValue: 10, + required: false, + expected: 0, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := &mockReader{input: tt.input} + result, err := promptIntWithReader(reader, "Test Label", tt.defaultValue, tt.required) + + if (err != nil) != tt.expectError { + t.Errorf("Expected error: %v, got: %v", tt.expectError, err) + } + + if result != tt.expected { + t.Errorf("Expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestValidateDirectory(t *testing.T) { + // Create a temporary directory for testing + tmpDir := t.TempDir() + + tests := []struct { + name string + path string + setup func() string + expectError bool + }{ + { + name: "existing directory", + setup: func() string { + return tmpDir + }, + expectError: false, + }, + { + name: "non-existent directory gets created", + setup: func() string { + return filepath.Join(tmpDir, "newdir") + }, + expectError: false, + }, + { + name: "nested directory gets created", + setup: func() string { + return filepath.Join(tmpDir, "level1", "level2", "level3") + }, + expectError: false, + }, + { + name: "file exists at path", + setup: func() string { + filePath := filepath.Join(tmpDir, "testfile") + os.WriteFile(filePath, []byte("test"), 0644) + return filePath + }, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := tt.setup() + err := ValidateDirectory(path) + + if (err != nil) != tt.expectError { + t.Errorf("Expected error: %v, got: %v", tt.expectError, err) + } + + // If no error expected, verify directory exists + if !tt.expectError { + info, err := os.Stat(path) + if err != nil { + t.Errorf("Directory should exist: %v", err) + } + if !info.IsDir() { + t.Errorf("Path should be a directory") + } + } + }) + } +} + +func TestPrintFunctions(t *testing.T) { + // These functions write to stdout, so we just ensure they don't panic + // Testing colored output is complex due to terminal color codes + + tests := []struct { + name string + fn func() + }{ + {"PrintHeader", func() { PrintHeader("Test Header") }}, + {"PrintSuccess", func() { PrintSuccess("Success message") }}, + {"PrintError", func() { PrintError("Error message") }}, + {"PrintInfo", func() { PrintInfo("Info message") }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Just ensure the function doesn't panic + defer func() { + if r := recover(); r != nil { + t.Errorf("Function panicked: %v", r) + } + }() + + tt.fn() + }) + } +} + +func TestPromptDirectory(t *testing.T) { + tmpDir := t.TempDir() + + tests := []struct { + name string + input string + defaultValue string + required bool + expectError bool + }{ + { + name: "valid directory path", + input: tmpDir + "\n", + defaultValue: "", + required: true, + expectError: false, + }, + { + name: "creates new directory", + input: filepath.Join(tmpDir, "newdir") + "\n", + defaultValue: "", + required: true, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Note: This test would require refactoring PromptDirectory to accept a reader + // For now, we test ValidateDirectory which is the core logic + if tt.input != "" { + path := strings.TrimSpace(tt.input[:len(tt.input)-1]) + err := ValidateDirectory(path) + if (err != nil) != tt.expectError { + t.Errorf("Expected error: %v, got: %v", tt.expectError, err) + } + } + }) + } +} + +// Benchmark tests +func BenchmarkPromptStringWithReader(b *testing.B) { + reader := &mockReader{input: strings.Repeat("test\n", b.N)} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + promptStringWithReader(reader, "Label", "default", false) + } +} + +func BenchmarkPromptIntWithReader(b *testing.B) { + reader := &mockReader{input: strings.Repeat("42\n", b.N)} + b.ResetTimer() + + for i := 0; i < b.N; i++ { + promptIntWithReader(reader, "Label", 0, false) + } +}