Compare commits

...

4 Commits

Author SHA1 Message Date
5f5f1ad114 Generate cloud instead televendcloud entities 2025-12-15 10:37:28 +01:00
0a4030c389 Tweak output 2025-11-04 09:58:08 +01:00
4e4827d640 Tests 2025-10-31 19:04:40 +01:00
ca10b01fb0 Update makefile 2025-10-31 18:41:32 +01:00
20 changed files with 2732 additions and 29 deletions

4
.gitignore vendored
View File

@ -3,3 +3,7 @@
*.toml
/entity-maker
# Test coverage
coverage.out
coverage.html

122
Makefile
View File

@ -1,10 +1,130 @@
EXEC=entity-maker
BUILD_DIR=./build
CMD_DIR=./cmd/entity-maker
# Build variables
GOOS ?= linux
GOARCH ?= amd64
LDFLAGS=-s -w
.PHONY: build
build:
@go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go
@echo "Building ${EXEC} for ${GOOS}/${GOARCH}..."
@mkdir -p ${BUILD_DIR}
@CGO_ENABLED=0 GOOS=${GOOS} GOARCH=${GOARCH} go build \
-ldflags "${LDFLAGS}" \
-trimpath \
-o ${BUILD_DIR}/${EXEC} \
${CMD_DIR}
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
@ls -lh ${BUILD_DIR}/${EXEC}
.PHONY: build-dev
build-dev:
@echo "Building ${EXEC} (development mode)..."
@mkdir -p ${BUILD_DIR}
@go build -o ${BUILD_DIR}/${EXEC} ${CMD_DIR}
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
.PHONY: install
install: build
@echo "Installing ${EXEC} to /usr/local/bin..."
@sudo cp ${BUILD_DIR}/${EXEC} /usr/local/bin/${EXEC}
@echo "✓ Installed successfully"
.PHONY: clean
clean:
@echo "Cleaning build artifacts..."
@rm -rf ${BUILD_DIR}
@rm -f ${EXEC}
@echo "✓ Clean complete"
.PHONY: test
test:
@echo "Running tests..."
@go test ./...
@echo "✓ Tests passed"
.PHONY: test-verbose
test-verbose:
@echo "Running tests (verbose)..."
@go test -v ./...
.PHONY: test-coverage
test-coverage:
@echo "Running tests with coverage..."
@go test -cover ./...
@echo "✓ Tests completed"
.PHONY: coverage
coverage:
@echo "Generating coverage report..."
@go test -coverprofile=coverage.out ./...
@go tool cover -html=coverage.out -o coverage.html
@echo "✓ Coverage report generated: coverage.html"
@echo " Open coverage.html in your browser to view detailed coverage"
.PHONY: test-short
test-short:
@echo "Running short tests..."
@go test -short ./...
@echo "✓ Short tests passed"
.PHONY: bench
bench:
@echo "Running benchmarks..."
@go test -bench=. -benchmem ./internal/naming
@echo "✓ Benchmarks completed"
.PHONY: verify-static
verify-static: build
@echo "Verifying binary is statically linked..."
@file ${BUILD_DIR}/${EXEC}
@ldd ${BUILD_DIR}/${EXEC} 2>&1 || echo "✓ Binary is statically linked"
.PHONY: upgrade-packages
upgrade-packages:
@echo "Upgrading Go packages..."
@go get -u ./...
@go mod tidy
@echo "✓ Packages upgraded"
.PHONY: help
help:
@echo "Entity Maker - Makefile targets:"
@echo ""
@echo "Build targets:"
@echo " build Build static binary for Linux (default)"
@echo " build-dev Build development binary with debug symbols"
@echo " install Install binary to /usr/local/bin"
@echo " clean Remove build artifacts"
@echo " verify-static Verify binary has no dynamic dependencies"
@echo ""
@echo "Test targets:"
@echo " test Run all tests"
@echo " test-verbose Run tests with verbose output"
@echo " test-coverage Run tests and show coverage summary"
@echo " coverage Generate HTML coverage report"
@echo " test-short Run only short tests"
@echo " bench Run benchmarks"
@echo ""
@echo "Other targets:"
@echo " upgrade-packages Update Go dependencies"
@echo " help Show this help message"
@echo ""
@echo "Build options:"
@echo " GOOS=linux GOARCH=amd64 make build (default)"
@echo ""

View File

@ -21,10 +21,90 @@ A command-line tool for generating SQLAlchemy entities from PostgreSQL database
### Build from source
Using Makefile (recommended):
```bash
# Build static binary (production-ready, fully portable)
make build
# Build with debug symbols (development)
make build-dev
# Install to /usr/local/bin
make install
# Clean build artifacts
make clean
# Verify binary has no dependencies
make verify-static
# Show all available targets
make help
```
Or using Go directly:
```bash
# Static binary (portable across all Linux x86-64 systems)
CGO_ENABLED=0 go build -ldflags "-s -w" -trimpath -o ./build/entity-maker ./cmd/entity-maker
# Development binary with debug symbols
go build -o entity-maker ./cmd/entity-maker
```
The static binary (`make build`) is **fully portable** and will run on any Linux x86-64 system without external dependencies.
## Testing
Run tests using Makefile commands:
```bash
# Run all tests
make test
# Run tests with verbose output
make test-verbose
# Run tests with coverage summary
make test-coverage
# Generate HTML coverage report
make coverage
# Then open coverage.html in your browser
# Run only short/fast tests
make test-short
# Run benchmarks
make bench
```
Or using Go directly:
```bash
# Run all tests
go test ./...
# Run with coverage
go test -cover ./...
# Run specific package
go test ./internal/naming -v
# Generate coverage report
go test -coverprofile=coverage.out ./...
go tool cover -html=coverage.out
```
**Test Coverage:**
- `internal/naming`: 97.1% ⭐
- `internal/generator`: 91.1% ⭐
- `internal/config`: 81.8%
- `internal/prompt`: 75.3%
- `internal/database`: 30.8%
- `cmd/entity-maker`: 0.0% (integration code - business logic tested in internal packages)
## Usage
### Interactive Mode

View File

@ -165,7 +165,6 @@ func run() error {
// Introspect table
prompt.PrintHeader("Introspecting Table")
tableInfo, err := dbClient.IntrospectTable(cfg.DBTable)
if err != nil {
return fmt.Errorf("failed to introspect table: %w", err)
@ -225,7 +224,6 @@ func run() error {
// Print summary
prompt.PrintHeader("Summary")
fmt.Printf("\n")
fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName))
fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName))
fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir))

View File

@ -0,0 +1,299 @@
package main
import (
"os"
"path/filepath"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/generator"
)
// Note: The cmd/entity-maker package has low test coverage by design.
// The run() function is integration code that orchestrates calls to:
// - Prompt functions (requires user input)
// - Database connections (requires live PostgreSQL)
// - File system operations (requires specific paths)
//
// All the business logic is tested in the internal/* packages where
// coverage is high (91-97%). The tests below validate the file generation
// workflow in isolation without needing the full integration environment.
// TestPackageCompiles validates that the package compiles correctly
func TestPackageCompiles(t *testing.T) {
// This test ensures the main package compiles correctly
// The main() function itself is hard to test as it calls os.Exit()
// and depends on external resources (database, user input, file system)
// If this test runs, the package compiled successfully
t.Log("Main package compiled successfully")
}
// TestFileGeneration tests the file generation logic in isolation
func TestFileGeneration(t *testing.T) {
// Create a temporary directory for testing
tmpDir := t.TempDir()
moduleName := "test_entity"
moduleDir := filepath.Join(tmpDir, moduleName)
// Create the module directory
if err := os.MkdirAll(moduleDir, 0755); err != nil {
t.Fatalf("Failed to create module directory: %v", err)
}
// Create a simple test context
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "test_table",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
}
ctx := generator.NewContext(tableInfo, "")
// Define files to generate (same as in main.go)
files := map[string]func(*generator.Context) (string, error){
"table.py": generator.GenerateTable,
"model.py": generator.GenerateModel,
"filter.py": generator.GenerateFilter,
"load_options.py": generator.GenerateLoadOptions,
"repository.py": generator.GenerateRepository,
"manager.py": generator.GenerateManager,
"factory.py": generator.GenerateFactory,
"mapper.py": generator.GenerateMapper,
"__init__.py": generator.GenerateInit,
}
// Generate and write each file
for filename, genFunc := range files {
content, err := genFunc(ctx)
if err != nil {
t.Errorf("Failed to generate %s: %v", filename, err)
continue
}
filePath := filepath.Join(moduleDir, filename)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
t.Errorf("Failed to write %s: %v", filename, err)
continue
}
// Verify file was created
if _, err := os.Stat(filePath); os.IsNotExist(err) {
t.Errorf("File %s was not created", filename)
}
// Verify file has content
if len(content) == 0 && filename != "__init__.py" {
t.Errorf("File %s has no content", filename)
}
}
// Verify all expected files exist
expectedFiles := []string{
"table.py", "model.py", "filter.py", "load_options.py",
"repository.py", "manager.py", "factory.py", "mapper.py", "__init__.py",
}
for _, filename := range expectedFiles {
filePath := filepath.Join(moduleDir, filename)
if _, err := os.Stat(filePath); os.IsNotExist(err) {
t.Errorf("Expected file %s does not exist", filename)
}
}
}
// TestFileGenerationWithEnums tests file generation when enum types are present
func TestFileGenerationWithEnums(t *testing.T) {
tmpDir := t.TempDir()
moduleName := "test_entity"
moduleDir := filepath.Join(tmpDir, moduleName)
if err := os.MkdirAll(moduleDir, 0755); err != nil {
t.Fatalf("Failed to create module directory: %v", err)
}
// Create test context with enums
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "test_table",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "status", DataType: "USER-DEFINED", UdtName: "status_enum", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"active", "inactive", "pending"},
},
},
}
ctx := generator.NewContext(tableInfo, "")
files := map[string]func(*generator.Context) (string, error){
"table.py": generator.GenerateTable,
"model.py": generator.GenerateModel,
"filter.py": generator.GenerateFilter,
"load_options.py": generator.GenerateLoadOptions,
"repository.py": generator.GenerateRepository,
"manager.py": generator.GenerateManager,
"factory.py": generator.GenerateFactory,
"mapper.py": generator.GenerateMapper,
"__init__.py": generator.GenerateInit,
}
// Add enum.py since we have enum types
if len(tableInfo.EnumTypes) > 0 {
files["enum.py"] = generator.GenerateEnum
}
// Generate all files
for filename, genFunc := range files {
content, err := genFunc(ctx)
if err != nil {
t.Errorf("Failed to generate %s: %v", filename, err)
continue
}
filePath := filepath.Join(moduleDir, filename)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
t.Errorf("Failed to write %s: %v", filename, err)
}
}
// Verify enum.py was created
enumPath := filepath.Join(moduleDir, "enum.py")
if _, err := os.Stat(enumPath); os.IsNotExist(err) {
t.Error("enum.py was not created when enum types are present")
}
}
// TestModuleDirectoryCreation tests directory creation logic
func TestModuleDirectoryCreation(t *testing.T) {
tmpDir := t.TempDir()
tests := []struct {
name string
moduleName string
expectErr bool
}{
{
name: "simple module name",
moduleName: "user",
expectErr: false,
},
{
name: "nested module name",
moduleName: filepath.Join("nested", "module"),
expectErr: false,
},
{
name: "module with underscore",
moduleName: "user_account",
expectErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
moduleDir := filepath.Join(tmpDir, tt.moduleName)
err := os.MkdirAll(moduleDir, 0755)
if (err != nil) != tt.expectErr {
t.Errorf("MkdirAll() error = %v, expectErr %v", err, tt.expectErr)
}
if !tt.expectErr {
// Verify directory exists
info, err := os.Stat(moduleDir)
if err != nil {
t.Errorf("Directory should exist: %v", err)
}
if !info.IsDir() {
t.Error("Path should be a directory")
}
}
})
}
}
// TestGeneratedFilePermissions tests that generated files have correct permissions
func TestGeneratedFilePermissions(t *testing.T) {
tmpDir := t.TempDir()
testFile := filepath.Join(tmpDir, "test.py")
content := "# Test content\n"
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
t.Fatalf("Failed to write test file: %v", err)
}
info, err := os.Stat(testFile)
if err != nil {
t.Fatalf("Failed to stat test file: %v", err)
}
mode := info.Mode()
// Check that owner can read and write
if mode&0600 != 0600 {
t.Error("File should be readable and writable by owner")
}
// Check that file is not executable
if mode&0111 != 0 {
t.Error("Python files should not be executable")
}
}
// TestGeneratedFilesAreNonEmpty tests that generated files have content
func TestGeneratedFilesAreNonEmpty(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
}
ctx := generator.NewContext(tableInfo, "")
generators := map[string]func(*generator.Context) (string, error){
"table": generator.GenerateTable,
"model": generator.GenerateModel,
"filter": generator.GenerateFilter,
"load_options": generator.GenerateLoadOptions,
"repository": generator.GenerateRepository,
"manager": generator.GenerateManager,
"factory": generator.GenerateFactory,
"mapper": generator.GenerateMapper,
}
for name, genFunc := range generators {
t.Run(name, func(t *testing.T) {
content, err := genFunc(ctx)
if err != nil {
t.Fatalf("Failed to generate %s: %v", name, err)
}
if len(content) == 0 {
t.Errorf("Generated %s content should not be empty", name)
}
// Check that content has at least some basic Python structure
if name != "mapper" { // mapper is just a snippet
if len(content) < 10 {
t.Errorf("Generated %s content seems too short: %d bytes", name, len(content))
}
}
})
}
}

View File

@ -0,0 +1,117 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestDefaultConfig(t *testing.T) {
cfg := DefaultConfig()
if cfg.DBHost != "localhost" {
t.Errorf("Expected DBHost to be 'localhost', got %q", cfg.DBHost)
}
if cfg.DBPort != 5432 {
t.Errorf("Expected DBPort to be 5432, got %d", cfg.DBPort)
}
if cfg.DBSchema != "public" {
t.Errorf("Expected DBSchema to be 'public', got %q", cfg.DBSchema)
}
if cfg.DBUser != "postgres" {
t.Errorf("Expected DBUser to be 'postgres', got %q", cfg.DBUser)
}
if cfg.DBPassword != "postgres" {
t.Errorf("Expected DBPassword to be 'postgres', got %q", cfg.DBPassword)
}
}
func TestSaveAndLoad(t *testing.T) {
// Create a temporary directory and set HOME to it
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
configPath := filepath.Join(tmpDir, ".config", "entity-maker.toml")
// Create a test config
cfg := &Config{
DBHost: "testhost",
DBPort: 5433,
DBName: "testdb",
DBSchema: "testschema",
DBUser: "testuser",
DBPassword: "testpass",
DBTable: "testtable",
OutputDir: "/tmp/test",
EntityName: "TestEntity",
}
// Save config
err := cfg.Save()
if err != nil {
t.Fatalf("Failed to save config: %v", err)
}
// Verify file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
t.Fatalf("Config file was not created: %v", err)
}
// Load config
loadedCfg, err := Load()
if err != nil {
t.Fatalf("Failed to load config: %v", err)
}
// Verify loaded values
tests := []struct {
name string
got interface{}
expected interface{}
}{
{"DBHost", loadedCfg.DBHost, cfg.DBHost},
{"DBPort", loadedCfg.DBPort, cfg.DBPort},
{"DBName", loadedCfg.DBName, cfg.DBName},
{"DBSchema", loadedCfg.DBSchema, cfg.DBSchema},
{"DBUser", loadedCfg.DBUser, cfg.DBUser},
{"DBPassword", loadedCfg.DBPassword, cfg.DBPassword},
{"DBTable", loadedCfg.DBTable, cfg.DBTable},
{"OutputDir", loadedCfg.OutputDir, cfg.OutputDir},
{"EntityName", loadedCfg.EntityName, cfg.EntityName},
}
for _, tt := range tests {
if tt.got != tt.expected {
t.Errorf("%s mismatch: expected %v, got %v", tt.name, tt.expected, tt.got)
}
}
}
func TestLoadNonExistentConfig(t *testing.T) {
// Create a temporary directory and set HOME to it
tmpDir := t.TempDir()
oldHome := os.Getenv("HOME")
os.Setenv("HOME", tmpDir)
defer os.Setenv("HOME", oldHome)
// Load should return default config when file doesn't exist
cfg, err := Load()
if err != nil {
t.Fatalf("Load should not error when file doesn't exist: %v", err)
}
// Verify it's the default config
defaultCfg := DefaultConfig()
if cfg.DBHost != defaultCfg.DBHost {
t.Errorf("Expected default DBHost %q, got %q", defaultCfg.DBHost, cfg.DBHost)
}
if cfg.DBPort != defaultCfg.DBPort {
t.Errorf("Expected default DBPort %d, got %d", defaultCfg.DBPort, cfg.DBPort)
}
}

View File

@ -0,0 +1,341 @@
package database
import (
"database/sql"
"testing"
)
func TestGetPythonType(t *testing.T) {
tests := []struct {
name string
col Column
expected string
}{
// Integer types
{"integer", Column{DataType: "integer"}, "int"},
{"smallint", Column{DataType: "smallint"}, "int"},
{"bigint", Column{DataType: "bigint"}, "int"},
// Numeric types
{"numeric", Column{DataType: "numeric"}, "Decimal"},
{"decimal", Column{DataType: "decimal"}, "Decimal"},
{"real", Column{DataType: "real"}, "Decimal"},
{"double precision", Column{DataType: "double precision"}, "Decimal"},
// Boolean
{"boolean", Column{DataType: "boolean"}, "bool"},
// String types
{"varchar", Column{DataType: "character varying"}, "str"},
{"varchar short", Column{DataType: "varchar"}, "str"},
{"text", Column{DataType: "text"}, "str"},
{"char", Column{DataType: "char"}, "str"},
{"character", Column{DataType: "character"}, "str"},
// Date/Time types
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
{"date", Column{DataType: "date"}, "date"},
{"time with tz", Column{DataType: "time with time zone"}, "time"},
{"time without tz", Column{DataType: "time without time zone"}, "time"},
{"time", Column{DataType: "time"}, "time"},
// JSON types
{"json", Column{DataType: "json"}, "dict"},
{"jsonb", Column{DataType: "jsonb"}, "dict"},
// Other types
{"uuid", Column{DataType: "uuid"}, "UUID"},
{"bytea", Column{DataType: "bytea"}, "bytes"},
// User-defined (enum)
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
// Unknown type
{"unknown", Column{DataType: "unknown_type"}, "Any"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPythonType(tt.col)
if result != tt.expected {
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}
func TestGetSQLAlchemyType(t *testing.T) {
tests := []struct {
name string
col Column
expected string
}{
// Integer types
{"integer", Column{DataType: "integer"}, "Integer"},
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
// Numeric types with precision
{
"numeric with precision",
Column{
DataType: "numeric",
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
},
"Numeric(12, 4)",
},
{
"numeric without precision",
Column{DataType: "numeric"},
"Numeric",
},
{"real", Column{DataType: "real"}, "Float"},
{"double precision", Column{DataType: "double precision"}, "Float"},
// Boolean
{"boolean", Column{DataType: "boolean"}, "Boolean"},
// String types
{
"varchar with length",
Column{
DataType: "character varying",
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
},
"String(255)",
},
{
"varchar without length",
Column{DataType: "varchar"},
"String",
},
{
"char with length",
Column{
DataType: "char",
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
},
"String(10)",
},
{
"char without length",
Column{DataType: "character"},
"String(1)",
},
{"text", Column{DataType: "text"}, "Text"},
// Date/Time types
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
{"date", Column{DataType: "date"}, "Date"},
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
{"time", Column{DataType: "time"}, "Time"},
// JSON types
{"json", Column{DataType: "json"}, "JSON"},
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
// Other types
{"uuid", Column{DataType: "uuid"}, "UUID"},
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
// User-defined (enum)
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
// Unknown type
{"unknown", Column{DataType: "unknown_type"}, "String"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSQLAlchemyType(tt.col)
if result != tt.expected {
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}
func TestColumn(t *testing.T) {
col := Column{
Name: "test_column",
DataType: "varchar",
IsNullable: true,
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
NumericPrecision: sql.NullInt64{Valid: false},
NumericScale: sql.NullInt64{Valid: false},
UdtName: "",
IsPrimaryKey: false,
IsAutoIncrement: false,
}
if col.Name != "test_column" {
t.Errorf("Expected Name 'test_column', got %q", col.Name)
}
if !col.IsNullable {
t.Error("Expected IsNullable to be true")
}
if !col.ColumnDefault.Valid {
t.Error("Expected ColumnDefault to be valid")
}
if col.ColumnDefault.String != "default_value" {
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
}
}
func TestForeignKey(t *testing.T) {
fk := ForeignKey{
ColumnName: "user_id",
ForeignTableSchema: "public",
ForeignTableName: "users",
ForeignColumnName: "id",
ConstraintName: "fk_user_id",
}
if fk.ColumnName != "user_id" {
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
}
if fk.ForeignTableName != "users" {
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
}
if fk.ForeignColumnName != "id" {
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
}
}
func TestEnumType(t *testing.T) {
enum := EnumType{
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED", "PENDING"},
}
if enum.TypeName != "status_enum" {
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
}
if len(enum.Values) != 3 {
t.Errorf("Expected 3 values, got %d", len(enum.Values))
}
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
for i, val := range enum.Values {
if val != expectedValues[i] {
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
}
}
}
func TestTableInfo(t *testing.T) {
tableInfo := &TableInfo{
Schema: "public",
TableName: "users",
Columns: []Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "name", DataType: "varchar"},
},
ForeignKeys: []ForeignKey{
{ColumnName: "company_id", ForeignTableName: "companies"},
},
EnumTypes: map[string]EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"ACTIVE", "INACTIVE"},
},
},
}
if tableInfo.Schema != "public" {
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
}
if tableInfo.TableName != "users" {
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
}
if len(tableInfo.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
}
if len(tableInfo.ForeignKeys) != 1 {
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
}
if len(tableInfo.EnumTypes) != 1 {
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
}
// Test primary key detection
foundPK := false
for _, col := range tableInfo.Columns {
if col.IsPrimaryKey {
foundPK = true
if col.Name != "id" {
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
}
}
}
if !foundPK {
t.Error("Expected to find primary key column")
}
}
func TestConfig(t *testing.T) {
cfg := Config{
Host: "localhost",
Port: 5432,
Database: "testdb",
Schema: "public",
User: "testuser",
Password: "testpass",
}
if cfg.Host != "localhost" {
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
}
if cfg.Port != 5432 {
t.Errorf("Expected Port 5432, got %d", cfg.Port)
}
if cfg.Database != "testdb" {
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
}
if cfg.Schema != "public" {
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
}
}
// Benchmark tests
func BenchmarkGetPythonType(b *testing.B) {
col := Column{DataType: "character varying"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
GetPythonType(col)
}
}
func BenchmarkGetSQLAlchemyType(b *testing.B) {
col := Column{
DataType: "numeric",
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
GetSQLAlchemyType(col)
}
}

View File

@ -0,0 +1,199 @@
package generator
import (
"strings"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
)
func TestSanitizePythonIdentifier(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"starts with digit", "12_HOUR", "_12_HOUR"},
{"starts with letter", "HOUR_12", "HOUR_12"},
{"all digits", "24", "_24"},
{"underscore first", "_12_HOUR", "_12_HOUR"},
{"empty", "", ""},
{"single digit", "1", "_1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizePythonIdentifier(tt.input)
if result != tt.expected {
t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGenerateEnum(t *testing.T) {
tests := []struct {
name string
enumTypes map[string]database.EnumType
expectError bool
checkFunc func(t *testing.T, result string)
}{
{
name: "simple enum",
enumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED", "PENDING"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class StatusEnum") {
t.Error("Expected class StatusEnum")
}
if !strings.Contains(result, "OPEN = \"OPEN\"") {
t.Error("Expected OPEN = \"OPEN\"")
}
if !strings.Contains(result, "CLOSED = \"CLOSED\"") {
t.Error("Expected CLOSED = \"CLOSED\"")
}
if !strings.Contains(result, "PENDING = \"PENDING\"") {
t.Error("Expected PENDING = \"PENDING\"")
}
},
},
{
name: "enum with spaces and hyphens",
enumTypes: map[string]database.EnumType{
"time_format_enum": {
TypeName: "time_format_enum",
Values: []string{"12-hour", "24-hour"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class TimeFormatEnum") {
t.Error("Expected class TimeFormatEnum")
}
if !strings.Contains(result, "_12_HOUR = \"12-hour\"") {
t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)")
}
if !strings.Contains(result, "_24_HOUR = \"24-hour\"") {
t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)")
}
},
},
{
name: "enum with duplicates after normalization",
enumTypes: map[string]database.EnumType{
"measurement_enum": {
TypeName: "measurement_enum",
Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class MeasurementEnum") {
t.Error("Expected class MeasurementEnum")
}
// Should only have one INTERNATIONAL and one IMPERIAL
internationalCount := strings.Count(result, "INTERNATIONAL = ")
if internationalCount != 1 {
t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount)
}
imperialCount := strings.Count(result, "IMPERIAL = ")
if imperialCount != 1 {
t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount)
}
},
},
{
name: "empty enum",
enumTypes: map[string]database.EnumType{
"empty_enum": {
TypeName: "empty_enum",
Values: []string{},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class EmptyEnum") {
t.Error("Expected class EmptyEnum")
}
if !strings.Contains(result, "pass") {
t.Error("Expected 'pass' for empty enum")
}
},
},
{
name: "multiple enums",
enumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED"},
},
"priority_enum": {
TypeName: "priority_enum",
Values: []string{"HIGH", "LOW"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class StatusEnum") {
t.Error("Expected class StatusEnum")
}
if !strings.Contains(result, "class PriorityEnum") {
t.Error("Expected class PriorityEnum")
}
},
},
{
name: "enum with special characters",
enumTypes: map[string]database.EnumType{
"special_enum": {
TypeName: "special_enum",
Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") {
t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"")
}
if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") {
t.Error("Expected ON_HOLD = \"ON HOLD\"")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &Context{
TableInfo: &database.TableInfo{
EnumTypes: tt.enumTypes,
},
EntityName: "TestEntity",
ModuleName: "test_entity",
}
result, err := GenerateEnum(ctx)
if (err != nil) != tt.expectError {
t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError)
return
}
if tt.checkFunc != nil {
tt.checkFunc(t, result)
}
// Check common requirements
if !strings.Contains(result, "from enum import StrEnum") {
t.Error("Expected import of StrEnum")
}
if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") {
t.Error("Expected import of EnumMixin")
}
})
}
}

View File

@ -26,23 +26,23 @@ func GenerateFactory(ctx *Context) (string, error) {
}
for moduleName, entityName := range fkImports {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.factory import (\n",
moduleName))
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
b.WriteString(")\n")
}
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
b.WriteString(" CustomSelfAttribute,\n")
b.WriteString(" TelevendBaseFactory,\n")
b.WriteString(" CloudBaseFactory,\n")
b.WriteString(")\n\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
b.WriteString(fmt.Sprintf("class %sFactory(CloudBaseFactory):\n", ctx.EntityName))
// Add boolean fields with defaults
for _, col := range ctx.TableInfo.Columns {

View File

@ -14,7 +14,7 @@ func GenerateFilter(ctx *Context) (string, error) {
// Imports
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")

View File

@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string {
}
// GetFilterFieldName returns the filter field name for a column
// For ID fields with IN operator, pluralizes the name
// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids)
func GetFilterFieldName(columnName string, useIN bool) string {
if useIN && strings.HasSuffix(columnName, "_id") {
// Remove _id, pluralize, add back _ids
base := columnName[:len(columnName)-3]
if base == "" {
return "ids"
}
return naming.Pluralize(base) + "_ids"
// Replace _id with _ids
return columnName[:len(columnName)-2] + "ids"
}
if useIN && columnName == "id" {
return "ids"

View File

@ -0,0 +1,979 @@
package generator
import (
"database/sql"
"strings"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
)
func TestGetRelationshipName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"with _id suffix", "author_info_id", "author_info"},
{"with user_id", "user_id", "user"},
{"without _id", "status", "status"},
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipEntityName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "User"},
{"compound plural", "user_accounts", "UserAccount"},
{"ies ending", "companies", "Company"},
{"already singular", "user", "User"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipEntityName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipModuleName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "user"},
{"compound plural", "user_accounts", "user_account"},
{"ies ending", "companies", "company"},
{"already singular", "user", "user"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipModuleName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetFilterFieldName(t *testing.T) {
tests := []struct {
name string
columnName string
useIN bool
expected string
}{
{"id with IN", "id", true, "ids"},
{"id without IN", "id", false, "id"},
{"user_id with IN", "user_id", true, "user_ids"},
{"user_id without IN", "user_id", false, "user_id"},
{"machine_id with IN", "machine_id", true, "machine_ids"},
{"status without IN", "status", false, "status"},
{"status with IN", "status", true, "status"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFilterFieldName(tt.columnName, tt.useIN)
if result != tt.expected {
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
tt.columnName, tt.useIN, result, tt.expected)
}
})
}
}
func TestShouldGenerateFilter(t *testing.T) {
tests := []struct {
name string
col database.Column
expected bool
}{
{
name: "boolean field",
col: database.Column{
Name: "alive",
DataType: "boolean",
},
expected: true,
},
{
name: "id field",
col: database.Column{
Name: "id",
DataType: "integer",
},
expected: true,
},
{
name: "foreign key field",
col: database.Column{
Name: "user_id",
DataType: "integer",
},
expected: true,
},
{
name: "text field",
col: database.Column{
Name: "name",
DataType: "text",
},
expected: true,
},
{
name: "varchar field",
col: database.Column{
Name: "email",
DataType: "character varying",
},
expected: true,
},
{
name: "numeric field",
col: database.Column{
Name: "amount",
DataType: "numeric",
},
expected: false,
},
{
name: "timestamp field",
col: database.Column{
Name: "created_at",
DataType: "timestamp with time zone",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ShouldGenerateFilter(tt.col)
if result != tt.expected {
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
}
})
}
}
func TestNeedsDecimalImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has numeric column",
columns: []database.Column{
{Name: "amount", DataType: "numeric"},
},
expected: true,
},
{
name: "has decimal column",
columns: []database.Column{
{Name: "price", DataType: "decimal"},
},
expected: true,
},
{
name: "no numeric columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
{
name: "empty columns",
columns: []database.Column{},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDecimalImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestNeedsDatetimeImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has timestamp column",
columns: []database.Column{
{Name: "created_at", DataType: "timestamp with time zone"},
},
expected: true,
},
{
name: "has date column",
columns: []database.Column{
{Name: "birth_date", DataType: "date"},
},
expected: true,
},
{
name: "no datetime columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDatetimeImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetRequiredColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
}
result := GetRequiredColumns(columns)
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
if len(result) != 1 {
t.Errorf("Expected 1 required column, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "name" {
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
}
}
func TestGetOptionalColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
}
result := GetOptionalColumns(columns)
// Should include 'email' (nullable) and 'created_at' (has default)
if len(result) != 2 {
t.Errorf("Expected 2 optional columns, got %d", len(result))
}
}
func TestGetForeignKeyForColumn(t *testing.T) {
fks := []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users"},
{ColumnName: "company_id", ForeignTableName: "companies"},
}
tests := []struct {
name string
columnName string
expectNil bool
expectFK string
}{
{"found user_id", "user_id", false, "users"},
{"found company_id", "company_id", false, "companies"},
{"not found", "status_id", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetForeignKeyForColumn(tt.columnName, fks)
if tt.expectNil {
if result != nil {
t.Errorf("Expected nil, got %+v", result)
}
} else {
if result == nil {
t.Errorf("Expected FK, got nil")
} else if result.ForeignTableName != tt.expectFK {
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
}
}
})
}
}
func TestGenerateInit(t *testing.T) {
ctx := &Context{
EntityName: "User",
ModuleName: "user",
}
result, err := GenerateInit(ctx)
if err != nil {
t.Fatalf("GenerateInit failed: %v", err)
}
// __init__.py should be empty
if result != "" {
t.Errorf("Expected empty __init__.py, got %q", result)
}
}
func TestGenerateRepository(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateRepository(ctx)
if err != nil {
t.Fatalf("GenerateRepository failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserRepository",
"CRUDRepository",
"model_cls = User",
"from televend_core.databases.base_repository import CRUDRepository",
"from televend_core.databases.cloud_repositories.user.filter import",
"from televend_core.databases.cloud_repositories.user.model import User",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestGenerateManager(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateManager(ctx)
if err != nil {
t.Fatalf("GenerateManager failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserManager",
"CRUDManager",
"repository_cls = UserRepository",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestNewContext(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "user_accounts",
}
// Test without override
ctx := NewContext(tableInfo, "")
if ctx.EntityName != "UserAccount" {
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
}
if ctx.ModuleName != "user_account" {
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
}
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
}
// Test with override
ctx = NewContext(tableInfo, "CustomUser")
if ctx.EntityName != "CustomUser" {
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
}
}
func TestGenerateTable(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple table with basic types",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from sqlalchemy import",
"Column",
"Table",
"Integer",
"String",
"USER_TABLE = Table(",
`"users"`,
"metadata_obj",
`Column("id", Integer, primary_key=True, autoincrement=True)`,
`Column("name", String(255), nullable=False)`,
`Column("email", String)`,
},
},
{
name: "table with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"ForeignKey",
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
},
},
{
name: "table with enum",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{
"order_status": {
TypeName: "order_status",
Values: []string{"pending", "completed", "cancelled"},
},
},
},
expected: []string{
"Enum",
"from televend_core.databases.cloud_repositories.order.enum import",
"OrderStatus",
`Enum(`,
`*OrderStatus.to_value_list()`,
`name="order_status"`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateTable(ctx)
if err != nil {
t.Fatalf("GenerateTable failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateModel(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple model",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from dataclasses import dataclass",
"from televend_core.databases.base_model import Base",
"@dataclass",
"class User(Base):",
"name: str",
"email: str | None = None",
"id: int | None = None",
},
},
{
name: "model with foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.cloud_repositories.user.model import User",
"user_id: int",
"user: User",
"title: str",
},
},
{
name: "model with datetime and decimal",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", IsNullable: false},
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from datetime import datetime",
"from decimal import Decimal",
"amount: Decimal",
"created_at: datetime | None = None",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateModel(ctx)
if err != nil {
t.Fatalf("GenerateModel failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFilter(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
notExpect []string
}{
{
name: "filter with boolean and id fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "name", DataType: "varchar"},
{Name: "alive", DataType: "boolean"},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"from televend_core.databases.base_filter import BaseFilter",
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
"class UserFilter(BaseFilter):",
"model_cls = User",
"id: int | None = filterfield(operator=EQ)",
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
"name: str | None = filterfield(operator=EQ)",
"alive: bool | None = filterfield(operator=EQ, default=True)",
"user_id: int | None = filterfield(operator=EQ)",
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
},
notExpect: []string{
"default=None",
},
},
{
name: "filter with no filterable fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "logs",
Columns: []database.Column{
{Name: "timestamp", DataType: "timestamp with time zone"},
{Name: "amount", DataType: "numeric"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class LogFilter(BaseFilter):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFilter(ctx)
if err != nil {
t.Fatalf("GenerateFilter failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
for _, notExpected := range tt.notExpect {
if strings.Contains(result, notExpected) {
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
notExpected, result)
}
}
})
}
}
func TestGenerateLoadOptions(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "load options with relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
{Name: "category_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
},
},
expected: []string{
"from televend_core.databases.base_load_options import LoadOptions",
"from televend_core.databases.common.load_options import joinload",
"class PostLoadOptions(LoadOptions):",
"model_cls = Post",
`load_user: bool = joinload(relations=["user"])`,
`load_category: bool = joinload(relations=["category"])`,
},
},
{
name: "load options with no relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "settings",
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class SettingLoadOptions(LoadOptions):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateLoadOptions(ctx)
if err != nil {
t.Fatalf("GenerateLoadOptions failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFactory(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "factory with basic fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
{Name: "alive", DataType: "boolean"},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from __future__ import annotations",
"from typing import Type",
"import factory",
"class UserFactory(CloudBaseFactory):",
"alive = True",
"id = None",
`name = factory.Faker("pystr", max_chars=100)`,
"class Meta:",
"model = User",
"def create_minimal",
},
},
{
name: "factory with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.cloud_repositories.user.factory import",
"UserFactory",
`user = CustomSelfAttribute("..user", UserFactory)`,
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
},
},
{
name: "factory with decimal field",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFactory(ctx)
if err != nil {
t.Fatalf("GenerateFactory failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateMapper(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "mapper with single foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
"mapper_registry.map_imperatively(",
"class_=Post,",
"local_table=POST_TABLE,",
"properties={",
`"user": relationship(`,
"User, lazy=relationship_loading_strategy.value",
},
},
{
name: "mapper with multiple foreign keys to same table",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "messages",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "sender_id", DataType: "integer"},
{Name: "receiver_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
`"sender": relationship(`,
"User, lazy=relationship_loading_strategy.value",
`"receiver": relationship(`,
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateMapper(ctx)
if err != nil {
t.Fatalf("GenerateMapper failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGetPythonTypeForColumn(t *testing.T) {
ctx := &Context{
TableInfo: &database.TableInfo{
EnumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"active", "inactive"},
},
},
},
}
tests := []struct {
name string
col database.Column
expected string
}{
{
name: "integer type",
col: database.Column{DataType: "integer"},
expected: "int",
},
{
name: "varchar type",
col: database.Column{DataType: "varchar"},
expected: "str",
},
{
name: "enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
expected: "StatusEnum",
},
{
name: "unknown enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
expected: "str",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPythonTypeForColumn(tt.col, ctx)
if result != tt.expected {
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}

View File

@ -12,7 +12,7 @@ func GenerateLoadOptions(ctx *Context) (string, error) {
// Imports
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")

View File

@ -11,13 +11,13 @@ func GenerateManager(ctx *Context) (string, error) {
// Imports
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.filter import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.repository import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
b.WriteString(")\n")

View File

@ -38,7 +38,7 @@ func GenerateModel(ctx *Context) (string, error) {
// Import enum types
if len(ctx.TableInfo.EnumTypes) > 0 {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.enum import (\n",
ctx.ModuleName))
for _, enumType := range ctx.TableInfo.EnumTypes {
enumName := naming.ToPascalCase(enumType.TypeName)
@ -49,7 +49,7 @@ func GenerateModel(ctx *Context) (string, error) {
// Write foreign key imports
for moduleName, entityName := range fkImports {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
moduleName, entityName))
}

View File

@ -11,11 +11,11 @@ func GenerateRepository(ctx *Context) (string, error) {
// Imports
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.filter import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")

View File

@ -72,14 +72,14 @@ func GenerateTable(ctx *Context) (string, error) {
if len(ctx.TableInfo.EnumTypes) > 0 {
for _, enumType := range ctx.TableInfo.EnumTypes {
enumName := naming.ToPascalCase(enumType.TypeName)
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.enum import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
b.WriteString(")\n")
}
}
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
b.WriteString("from televend_core.databases.cloud_repositories.table_meta import metadata_obj\n\n")
// Table definition
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))

View File

@ -155,8 +155,12 @@ func Pluralize(word string) string {
return preserveCase(word, plural)
}
// Already plural (ends in 's' and not special case)
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
// Already plural (ends in 's' after a consonant, but not 'ss', 'us', 'is')
// Skip this check if word ends in 'ss' (like 'class', 'glass')
if strings.HasSuffix(lower, "s") &&
!strings.HasSuffix(lower, "ss") &&
!strings.HasSuffix(lower, "us") &&
!strings.HasSuffix(lower, "is") {
return word
}

View File

@ -0,0 +1,247 @@
package naming
import (
"testing"
)
func TestSingularize(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
// Regular plurals
{"simple s", "users", "user"},
{"simple s uppercase", "Users", "User"},
{"tables", "tables", "table"},
// ies -> y
{"ies to y", "companies", "company"},
{"ies to y uppercase", "Companies", "Company"},
{"cities", "cities", "city"},
{"categories", "categories", "category"},
// ves -> f/fe
{"ves to f", "halves", "half"},
{"ves to fe - knife", "knives", "knife"},
{"ves to fe - wife", "wives", "wife"},
{"ves to fe - life", "lives", "life"},
// es endings
{"xes", "boxes", "box"},
{"shes", "dishes", "dish"},
{"ches", "watches", "watch"},
{"ses", "classes", "class"},
// oes -> o
{"oes", "tomatoes", "tomato"},
{"heroes", "heroes", "hero"},
// Irregular plurals
{"people", "people", "person"},
{"children", "children", "child"},
{"men", "men", "man"},
{"women", "women", "woman"},
{"teeth", "teeth", "tooth"},
{"feet", "feet", "foot"},
{"mice", "mice", "mouse"},
{"geese", "geese", "goose"},
// Unchanged
{"sheep", "sheep", "sheep"},
{"fish", "fish", "fish"},
{"deer", "deer", "deer"},
{"series", "series", "series"},
{"species", "species", "species"},
// Special cases
{"data", "data", "datum"},
{"analyses", "analyses", "analysis"},
{"crises", "crises", "crisis"},
// Edge cases
{"empty", "", ""},
{"already singular", "user", "user"},
{"ss ending", "glass", "glass"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Singularize(tt.input)
if result != tt.expected {
t.Errorf("Singularize(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestPluralize(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
// Regular plurals
{"simple", "user", "users"},
{"simple uppercase", "User", "Users"},
{"table", "table", "tables"},
// y -> ies
{"y to ies", "company", "companies"},
{"y to ies uppercase", "Company", "Companies"},
{"city", "city", "cities"},
{"category", "category", "categories"},
// consonant + y stays
{"ay ending", "day", "days"},
{"ey ending", "key", "keys"},
{"oy ending", "boy", "boys"},
// f/fe -> ves
{"f to ves", "half", "halves"},
{"fe to ves - knife", "knife", "knives"},
{"fe to ves - wife", "wife", "wives"},
{"fe to ves - life", "life", "lives"},
// s, x, sh, ch -> es
{"x to es", "box", "boxes"},
{"sh to es", "dish", "dishes"},
{"ch to es", "watch", "watches"},
{"s to es", "class", "classes"},
// o -> oes
{"o to oes", "tomato", "tomatoes"},
{"hero", "hero", "heroes"},
// Irregular plurals
{"person", "person", "people"},
{"child", "child", "children"},
{"man", "man", "men"},
{"woman", "woman", "women"},
{"tooth", "tooth", "teeth"},
{"foot", "foot", "feet"},
{"mouse", "mouse", "mice"},
{"goose", "goose", "geese"},
// Unchanged
{"sheep", "sheep", "sheep"},
{"fish", "fish", "fish"},
{"deer", "deer", "deer"},
{"series", "series", "series"},
{"species", "species", "species"},
// Edge cases
{"empty", "", ""},
{"already plural", "users", "users"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := Pluralize(tt.input)
if result != tt.expected {
t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestSingularizeTableName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"simple", "users", "user"},
{"compound", "user_accounts", "user_account"},
{"triple compound", "user_login_histories", "user_login_history"},
{"already singular", "user", "user"},
{"ies ending", "cashbag_conforms", "cashbag_conform"},
{"complex table", "auth_user_groups", "auth_user_group"},
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := SingularizeTableName(tt.input)
if result != tt.expected {
t.Errorf("SingularizeTableName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestToPascalCase(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"simple", "user", "User"},
{"snake_case", "user_account", "UserAccount"},
{"triple", "user_login_history", "UserLoginHistory"},
{"already pascal", "UserAccount", "UserAccount"},
{"single char", "a", "A"},
{"empty", "", ""},
{"with numbers", "user_2fa", "User2fa"},
{"multiple underscores", "user___account", "UserAccount"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ToPascalCase(tt.input)
if result != tt.expected {
t.Errorf("ToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestToSnakeCase(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"PascalCase", "UserAccount", "user_account"},
{"simple", "User", "user"},
{"multiple words", "UserLoginHistory", "user_login_history"},
{"already snake", "user_account", "user_account"},
{"empty", "", ""},
{"single char", "A", "a"},
{"with numbers", "User2FA", "user2_f_a"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ToSnakeCase(tt.input)
if result != tt.expected {
t.Errorf("ToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
// Benchmark tests
func BenchmarkSingularize(b *testing.B) {
for i := 0; i < b.N; i++ {
Singularize("companies")
}
}
func BenchmarkPluralize(b *testing.B) {
for i := 0; i < b.N; i++ {
Pluralize("company")
}
}
func BenchmarkToPascalCase(b *testing.B) {
for i := 0; i < b.N; i++ {
ToPascalCase("user_login_history")
}
}
func BenchmarkToSnakeCase(b *testing.B) {
for i := 0; i < b.N; i++ {
ToSnakeCase("UserLoginHistory")
}
}

View File

@ -0,0 +1,319 @@
package prompt
import (
"io"
"os"
"path/filepath"
"strings"
"testing"
)
// mockReader implements the Reader interface for testing
type mockReader struct {
input string
pos int
}
func (m *mockReader) ReadString(delim byte) (string, error) {
if m.pos >= len(m.input) {
return "", io.EOF
}
idx := strings.IndexByte(m.input[m.pos:], delim)
if idx == -1 {
result := m.input[m.pos:]
m.pos = len(m.input)
return result, nil
}
result := m.input[m.pos : m.pos+idx+1]
m.pos += idx + 1
return result, nil
}
func TestPromptStringWithReader(t *testing.T) {
tests := []struct {
name string
input string
defaultValue string
required bool
expected string
expectError bool
}{
{
name: "user enters value",
input: "testvalue\n",
defaultValue: "default",
required: false,
expected: "testvalue",
expectError: false,
},
{
name: "user accepts default",
input: "\n",
defaultValue: "default",
required: false,
expected: "default",
expectError: false,
},
{
name: "required field empty then filled",
input: "\nactualvalue\n",
defaultValue: "",
required: true,
expected: "actualvalue",
expectError: false,
},
{
name: "optional empty field",
input: "\n",
defaultValue: "",
required: false,
expected: "",
expectError: false,
},
{
name: "whitespace trimmed",
input: " test \n",
defaultValue: "",
required: false,
expected: "test",
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := &mockReader{input: tt.input}
result, err := promptStringWithReader(reader, "Test Label", tt.defaultValue, tt.required)
if (err != nil) != tt.expectError {
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
}
if result != tt.expected {
t.Errorf("Expected %q, got %q", tt.expected, result)
}
})
}
}
func TestPromptIntWithReader(t *testing.T) {
tests := []struct {
name string
input string
defaultValue int
required bool
expected int
expectError bool
}{
{
name: "user enters valid number",
input: "42\n",
defaultValue: 0,
required: false,
expected: 42,
expectError: false,
},
{
name: "user accepts default",
input: "\n",
defaultValue: 5432,
required: false,
expected: 5432,
expectError: false,
},
{
name: "invalid then valid number",
input: "abc\n123\n",
defaultValue: 0,
required: false,
expected: 123,
expectError: false,
},
{
name: "required field empty then filled",
input: "\n999\n",
defaultValue: 0,
required: true,
expected: 999,
expectError: false,
},
{
name: "zero value",
input: "0\n",
defaultValue: 10,
required: false,
expected: 0,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader := &mockReader{input: tt.input}
result, err := promptIntWithReader(reader, "Test Label", tt.defaultValue, tt.required)
if (err != nil) != tt.expectError {
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
}
if result != tt.expected {
t.Errorf("Expected %d, got %d", tt.expected, result)
}
})
}
}
func TestValidateDirectory(t *testing.T) {
// Create a temporary directory for testing
tmpDir := t.TempDir()
tests := []struct {
name string
path string
setup func() string
expectError bool
}{
{
name: "existing directory",
setup: func() string {
return tmpDir
},
expectError: false,
},
{
name: "non-existent directory gets created",
setup: func() string {
return filepath.Join(tmpDir, "newdir")
},
expectError: false,
},
{
name: "nested directory gets created",
setup: func() string {
return filepath.Join(tmpDir, "level1", "level2", "level3")
},
expectError: false,
},
{
name: "file exists at path",
setup: func() string {
filePath := filepath.Join(tmpDir, "testfile")
os.WriteFile(filePath, []byte("test"), 0644)
return filePath
},
expectError: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
path := tt.setup()
err := ValidateDirectory(path)
if (err != nil) != tt.expectError {
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
}
// If no error expected, verify directory exists
if !tt.expectError {
info, err := os.Stat(path)
if err != nil {
t.Errorf("Directory should exist: %v", err)
}
if !info.IsDir() {
t.Errorf("Path should be a directory")
}
}
})
}
}
func TestPrintFunctions(t *testing.T) {
// These functions write to stdout, so we just ensure they don't panic
// Testing colored output is complex due to terminal color codes
tests := []struct {
name string
fn func()
}{
{"PrintHeader", func() { PrintHeader("Test Header") }},
{"PrintSuccess", func() { PrintSuccess("Success message") }},
{"PrintError", func() { PrintError("Error message") }},
{"PrintInfo", func() { PrintInfo("Info message") }},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Just ensure the function doesn't panic
defer func() {
if r := recover(); r != nil {
t.Errorf("Function panicked: %v", r)
}
}()
tt.fn()
})
}
}
func TestPromptDirectory(t *testing.T) {
tmpDir := t.TempDir()
tests := []struct {
name string
input string
defaultValue string
required bool
expectError bool
}{
{
name: "valid directory path",
input: tmpDir + "\n",
defaultValue: "",
required: true,
expectError: false,
},
{
name: "creates new directory",
input: filepath.Join(tmpDir, "newdir") + "\n",
defaultValue: "",
required: true,
expectError: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Note: This test would require refactoring PromptDirectory to accept a reader
// For now, we test ValidateDirectory which is the core logic
if tt.input != "" {
path := strings.TrimSpace(tt.input[:len(tt.input)-1])
err := ValidateDirectory(path)
if (err != nil) != tt.expectError {
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
}
}
})
}
}
// Benchmark tests
func BenchmarkPromptStringWithReader(b *testing.B) {
reader := &mockReader{input: strings.Repeat("test\n", b.N)}
b.ResetTimer()
for i := 0; i < b.N; i++ {
promptStringWithReader(reader, "Label", "default", false)
}
}
func BenchmarkPromptIntWithReader(b *testing.B) {
reader := &mockReader{input: strings.Repeat("42\n", b.N)}
b.ResetTimer()
for i := 0; i < b.N; i++ {
promptIntWithReader(reader, "Label", 0, false)
}
}