Compare commits
2 Commits
499f59ff12
...
4e4827d640
| Author | SHA1 | Date | |
|---|---|---|---|
| 4e4827d640 | |||
| ca10b01fb0 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -3,3 +3,7 @@
|
||||
|
||||
*.toml
|
||||
/entity-maker
|
||||
|
||||
# Test coverage
|
||||
coverage.out
|
||||
coverage.html
|
||||
|
||||
122
Makefile
122
Makefile
@ -1,10 +1,130 @@
|
||||
EXEC=entity-maker
|
||||
BUILD_DIR=./build
|
||||
CMD_DIR=./cmd/entity-maker
|
||||
|
||||
# Build variables
|
||||
GOOS ?= linux
|
||||
GOARCH ?= amd64
|
||||
LDFLAGS=-s -w
|
||||
|
||||
|
||||
.PHONY: build
|
||||
build:
|
||||
@go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go
|
||||
@echo "Building ${EXEC} for ${GOOS}/${GOARCH}..."
|
||||
@mkdir -p ${BUILD_DIR}
|
||||
@CGO_ENABLED=0 GOOS=${GOOS} GOARCH=${GOARCH} go build \
|
||||
-ldflags "${LDFLAGS}" \
|
||||
-trimpath \
|
||||
-o ${BUILD_DIR}/${EXEC} \
|
||||
${CMD_DIR}
|
||||
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
|
||||
@ls -lh ${BUILD_DIR}/${EXEC}
|
||||
|
||||
|
||||
.PHONY: build-dev
|
||||
build-dev:
|
||||
@echo "Building ${EXEC} (development mode)..."
|
||||
@mkdir -p ${BUILD_DIR}
|
||||
@go build -o ${BUILD_DIR}/${EXEC} ${CMD_DIR}
|
||||
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
|
||||
|
||||
|
||||
.PHONY: install
|
||||
install: build
|
||||
@echo "Installing ${EXEC} to /usr/local/bin..."
|
||||
@sudo cp ${BUILD_DIR}/${EXEC} /usr/local/bin/${EXEC}
|
||||
@echo "✓ Installed successfully"
|
||||
|
||||
|
||||
.PHONY: clean
|
||||
clean:
|
||||
@echo "Cleaning build artifacts..."
|
||||
@rm -rf ${BUILD_DIR}
|
||||
@rm -f ${EXEC}
|
||||
@echo "✓ Clean complete"
|
||||
|
||||
|
||||
.PHONY: test
|
||||
test:
|
||||
@echo "Running tests..."
|
||||
@go test ./...
|
||||
@echo "✓ Tests passed"
|
||||
|
||||
|
||||
.PHONY: test-verbose
|
||||
test-verbose:
|
||||
@echo "Running tests (verbose)..."
|
||||
@go test -v ./...
|
||||
|
||||
|
||||
.PHONY: test-coverage
|
||||
test-coverage:
|
||||
@echo "Running tests with coverage..."
|
||||
@go test -cover ./...
|
||||
@echo "✓ Tests completed"
|
||||
|
||||
|
||||
.PHONY: coverage
|
||||
coverage:
|
||||
@echo "Generating coverage report..."
|
||||
@go test -coverprofile=coverage.out ./...
|
||||
@go tool cover -html=coverage.out -o coverage.html
|
||||
@echo "✓ Coverage report generated: coverage.html"
|
||||
@echo " Open coverage.html in your browser to view detailed coverage"
|
||||
|
||||
|
||||
.PHONY: test-short
|
||||
test-short:
|
||||
@echo "Running short tests..."
|
||||
@go test -short ./...
|
||||
@echo "✓ Short tests passed"
|
||||
|
||||
|
||||
.PHONY: bench
|
||||
bench:
|
||||
@echo "Running benchmarks..."
|
||||
@go test -bench=. -benchmem ./internal/naming
|
||||
@echo "✓ Benchmarks completed"
|
||||
|
||||
|
||||
.PHONY: verify-static
|
||||
verify-static: build
|
||||
@echo "Verifying binary is statically linked..."
|
||||
@file ${BUILD_DIR}/${EXEC}
|
||||
@ldd ${BUILD_DIR}/${EXEC} 2>&1 || echo "✓ Binary is statically linked"
|
||||
|
||||
|
||||
.PHONY: upgrade-packages
|
||||
upgrade-packages:
|
||||
@echo "Upgrading Go packages..."
|
||||
@go get -u ./...
|
||||
@go mod tidy
|
||||
@echo "✓ Packages upgraded"
|
||||
|
||||
|
||||
.PHONY: help
|
||||
help:
|
||||
@echo "Entity Maker - Makefile targets:"
|
||||
@echo ""
|
||||
@echo "Build targets:"
|
||||
@echo " build Build static binary for Linux (default)"
|
||||
@echo " build-dev Build development binary with debug symbols"
|
||||
@echo " install Install binary to /usr/local/bin"
|
||||
@echo " clean Remove build artifacts"
|
||||
@echo " verify-static Verify binary has no dynamic dependencies"
|
||||
@echo ""
|
||||
@echo "Test targets:"
|
||||
@echo " test Run all tests"
|
||||
@echo " test-verbose Run tests with verbose output"
|
||||
@echo " test-coverage Run tests and show coverage summary"
|
||||
@echo " coverage Generate HTML coverage report"
|
||||
@echo " test-short Run only short tests"
|
||||
@echo " bench Run benchmarks"
|
||||
@echo ""
|
||||
@echo "Other targets:"
|
||||
@echo " upgrade-packages Update Go dependencies"
|
||||
@echo " help Show this help message"
|
||||
@echo ""
|
||||
@echo "Build options:"
|
||||
@echo " GOOS=linux GOARCH=amd64 make build (default)"
|
||||
@echo ""
|
||||
|
||||
80
README.md
80
README.md
@ -21,10 +21,90 @@ A command-line tool for generating SQLAlchemy entities from PostgreSQL database
|
||||
|
||||
### Build from source
|
||||
|
||||
Using Makefile (recommended):
|
||||
|
||||
```bash
|
||||
# Build static binary (production-ready, fully portable)
|
||||
make build
|
||||
|
||||
# Build with debug symbols (development)
|
||||
make build-dev
|
||||
|
||||
# Install to /usr/local/bin
|
||||
make install
|
||||
|
||||
# Clean build artifacts
|
||||
make clean
|
||||
|
||||
# Verify binary has no dependencies
|
||||
make verify-static
|
||||
|
||||
# Show all available targets
|
||||
make help
|
||||
```
|
||||
|
||||
Or using Go directly:
|
||||
|
||||
```bash
|
||||
# Static binary (portable across all Linux x86-64 systems)
|
||||
CGO_ENABLED=0 go build -ldflags "-s -w" -trimpath -o ./build/entity-maker ./cmd/entity-maker
|
||||
|
||||
# Development binary with debug symbols
|
||||
go build -o entity-maker ./cmd/entity-maker
|
||||
```
|
||||
|
||||
The static binary (`make build`) is **fully portable** and will run on any Linux x86-64 system without external dependencies.
|
||||
|
||||
## Testing
|
||||
|
||||
Run tests using Makefile commands:
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
make test
|
||||
|
||||
# Run tests with verbose output
|
||||
make test-verbose
|
||||
|
||||
# Run tests with coverage summary
|
||||
make test-coverage
|
||||
|
||||
# Generate HTML coverage report
|
||||
make coverage
|
||||
# Then open coverage.html in your browser
|
||||
|
||||
# Run only short/fast tests
|
||||
make test-short
|
||||
|
||||
# Run benchmarks
|
||||
make bench
|
||||
```
|
||||
|
||||
Or using Go directly:
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
go test ./...
|
||||
|
||||
# Run with coverage
|
||||
go test -cover ./...
|
||||
|
||||
# Run specific package
|
||||
go test ./internal/naming -v
|
||||
|
||||
# Generate coverage report
|
||||
go test -coverprofile=coverage.out ./...
|
||||
go tool cover -html=coverage.out
|
||||
```
|
||||
|
||||
**Test Coverage:**
|
||||
- `internal/naming`: 97.1% ⭐
|
||||
- `internal/generator`: 91.1% ⭐
|
||||
- `internal/config`: 81.8%
|
||||
- `internal/prompt`: 75.3%
|
||||
- `internal/database`: 30.8%
|
||||
- `cmd/entity-maker`: 0.0% (integration code - business logic tested in internal packages)
|
||||
|
||||
## Usage
|
||||
|
||||
### Interactive Mode
|
||||
|
||||
299
cmd/entity-maker/main_test.go
Normal file
299
cmd/entity-maker/main_test.go
Normal file
@ -0,0 +1,299 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/generator"
|
||||
)
|
||||
|
||||
// Note: The cmd/entity-maker package has low test coverage by design.
|
||||
// The run() function is integration code that orchestrates calls to:
|
||||
// - Prompt functions (requires user input)
|
||||
// - Database connections (requires live PostgreSQL)
|
||||
// - File system operations (requires specific paths)
|
||||
//
|
||||
// All the business logic is tested in the internal/* packages where
|
||||
// coverage is high (91-97%). The tests below validate the file generation
|
||||
// workflow in isolation without needing the full integration environment.
|
||||
|
||||
// TestPackageCompiles validates that the package compiles correctly
|
||||
func TestPackageCompiles(t *testing.T) {
|
||||
// This test ensures the main package compiles correctly
|
||||
// The main() function itself is hard to test as it calls os.Exit()
|
||||
// and depends on external resources (database, user input, file system)
|
||||
// If this test runs, the package compiled successfully
|
||||
t.Log("Main package compiled successfully")
|
||||
}
|
||||
|
||||
// TestFileGeneration tests the file generation logic in isolation
|
||||
func TestFileGeneration(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tmpDir := t.TempDir()
|
||||
moduleName := "test_entity"
|
||||
moduleDir := filepath.Join(tmpDir, moduleName)
|
||||
|
||||
// Create the module directory
|
||||
if err := os.MkdirAll(moduleDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create module directory: %v", err)
|
||||
}
|
||||
|
||||
// Create a simple test context
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "test_table",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
}
|
||||
|
||||
ctx := generator.NewContext(tableInfo, "")
|
||||
|
||||
// Define files to generate (same as in main.go)
|
||||
files := map[string]func(*generator.Context) (string, error){
|
||||
"table.py": generator.GenerateTable,
|
||||
"model.py": generator.GenerateModel,
|
||||
"filter.py": generator.GenerateFilter,
|
||||
"load_options.py": generator.GenerateLoadOptions,
|
||||
"repository.py": generator.GenerateRepository,
|
||||
"manager.py": generator.GenerateManager,
|
||||
"factory.py": generator.GenerateFactory,
|
||||
"mapper.py": generator.GenerateMapper,
|
||||
"__init__.py": generator.GenerateInit,
|
||||
}
|
||||
|
||||
// Generate and write each file
|
||||
for filename, genFunc := range files {
|
||||
content, err := genFunc(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to generate %s: %v", filename, err)
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(moduleDir, filename)
|
||||
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
t.Errorf("Failed to write %s: %v", filename, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Verify file was created
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
t.Errorf("File %s was not created", filename)
|
||||
}
|
||||
|
||||
// Verify file has content
|
||||
if len(content) == 0 && filename != "__init__.py" {
|
||||
t.Errorf("File %s has no content", filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify all expected files exist
|
||||
expectedFiles := []string{
|
||||
"table.py", "model.py", "filter.py", "load_options.py",
|
||||
"repository.py", "manager.py", "factory.py", "mapper.py", "__init__.py",
|
||||
}
|
||||
|
||||
for _, filename := range expectedFiles {
|
||||
filePath := filepath.Join(moduleDir, filename)
|
||||
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||
t.Errorf("Expected file %s does not exist", filename)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestFileGenerationWithEnums tests file generation when enum types are present
|
||||
func TestFileGenerationWithEnums(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
moduleName := "test_entity"
|
||||
moduleDir := filepath.Join(tmpDir, moduleName)
|
||||
|
||||
if err := os.MkdirAll(moduleDir, 0755); err != nil {
|
||||
t.Fatalf("Failed to create module directory: %v", err)
|
||||
}
|
||||
|
||||
// Create test context with enums
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "test_table",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "status", DataType: "USER-DEFINED", UdtName: "status_enum", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"active", "inactive", "pending"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := generator.NewContext(tableInfo, "")
|
||||
|
||||
files := map[string]func(*generator.Context) (string, error){
|
||||
"table.py": generator.GenerateTable,
|
||||
"model.py": generator.GenerateModel,
|
||||
"filter.py": generator.GenerateFilter,
|
||||
"load_options.py": generator.GenerateLoadOptions,
|
||||
"repository.py": generator.GenerateRepository,
|
||||
"manager.py": generator.GenerateManager,
|
||||
"factory.py": generator.GenerateFactory,
|
||||
"mapper.py": generator.GenerateMapper,
|
||||
"__init__.py": generator.GenerateInit,
|
||||
}
|
||||
|
||||
// Add enum.py since we have enum types
|
||||
if len(tableInfo.EnumTypes) > 0 {
|
||||
files["enum.py"] = generator.GenerateEnum
|
||||
}
|
||||
|
||||
// Generate all files
|
||||
for filename, genFunc := range files {
|
||||
content, err := genFunc(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to generate %s: %v", filename, err)
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := filepath.Join(moduleDir, filename)
|
||||
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
t.Errorf("Failed to write %s: %v", filename, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify enum.py was created
|
||||
enumPath := filepath.Join(moduleDir, "enum.py")
|
||||
if _, err := os.Stat(enumPath); os.IsNotExist(err) {
|
||||
t.Error("enum.py was not created when enum types are present")
|
||||
}
|
||||
}
|
||||
|
||||
// TestModuleDirectoryCreation tests directory creation logic
|
||||
func TestModuleDirectoryCreation(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
moduleName string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple module name",
|
||||
moduleName: "user",
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "nested module name",
|
||||
moduleName: filepath.Join("nested", "module"),
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "module with underscore",
|
||||
moduleName: "user_account",
|
||||
expectErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
moduleDir := filepath.Join(tmpDir, tt.moduleName)
|
||||
err := os.MkdirAll(moduleDir, 0755)
|
||||
|
||||
if (err != nil) != tt.expectErr {
|
||||
t.Errorf("MkdirAll() error = %v, expectErr %v", err, tt.expectErr)
|
||||
}
|
||||
|
||||
if !tt.expectErr {
|
||||
// Verify directory exists
|
||||
info, err := os.Stat(moduleDir)
|
||||
if err != nil {
|
||||
t.Errorf("Directory should exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Error("Path should be a directory")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeneratedFilePermissions tests that generated files have correct permissions
|
||||
func TestGeneratedFilePermissions(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
testFile := filepath.Join(tmpDir, "test.py")
|
||||
|
||||
content := "# Test content\n"
|
||||
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
|
||||
t.Fatalf("Failed to write test file: %v", err)
|
||||
}
|
||||
|
||||
info, err := os.Stat(testFile)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to stat test file: %v", err)
|
||||
}
|
||||
|
||||
mode := info.Mode()
|
||||
|
||||
// Check that owner can read and write
|
||||
if mode&0600 != 0600 {
|
||||
t.Error("File should be readable and writable by owner")
|
||||
}
|
||||
|
||||
// Check that file is not executable
|
||||
if mode&0111 != 0 {
|
||||
t.Error("Python files should not be executable")
|
||||
}
|
||||
}
|
||||
|
||||
// TestGeneratedFilesAreNonEmpty tests that generated files have content
|
||||
func TestGeneratedFilesAreNonEmpty(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
}
|
||||
|
||||
ctx := generator.NewContext(tableInfo, "")
|
||||
|
||||
generators := map[string]func(*generator.Context) (string, error){
|
||||
"table": generator.GenerateTable,
|
||||
"model": generator.GenerateModel,
|
||||
"filter": generator.GenerateFilter,
|
||||
"load_options": generator.GenerateLoadOptions,
|
||||
"repository": generator.GenerateRepository,
|
||||
"manager": generator.GenerateManager,
|
||||
"factory": generator.GenerateFactory,
|
||||
"mapper": generator.GenerateMapper,
|
||||
}
|
||||
|
||||
for name, genFunc := range generators {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
content, err := genFunc(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to generate %s: %v", name, err)
|
||||
}
|
||||
|
||||
if len(content) == 0 {
|
||||
t.Errorf("Generated %s content should not be empty", name)
|
||||
}
|
||||
|
||||
// Check that content has at least some basic Python structure
|
||||
if name != "mapper" { // mapper is just a snippet
|
||||
if len(content) < 10 {
|
||||
t.Errorf("Generated %s content seems too short: %d bytes", name, len(content))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
117
internal/config/config_test.go
Normal file
117
internal/config/config_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultConfig(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if cfg.DBHost != "localhost" {
|
||||
t.Errorf("Expected DBHost to be 'localhost', got %q", cfg.DBHost)
|
||||
}
|
||||
|
||||
if cfg.DBPort != 5432 {
|
||||
t.Errorf("Expected DBPort to be 5432, got %d", cfg.DBPort)
|
||||
}
|
||||
|
||||
if cfg.DBSchema != "public" {
|
||||
t.Errorf("Expected DBSchema to be 'public', got %q", cfg.DBSchema)
|
||||
}
|
||||
|
||||
if cfg.DBUser != "postgres" {
|
||||
t.Errorf("Expected DBUser to be 'postgres', got %q", cfg.DBUser)
|
||||
}
|
||||
|
||||
if cfg.DBPassword != "postgres" {
|
||||
t.Errorf("Expected DBPassword to be 'postgres', got %q", cfg.DBPassword)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveAndLoad(t *testing.T) {
|
||||
// Create a temporary directory and set HOME to it
|
||||
tmpDir := t.TempDir()
|
||||
oldHome := os.Getenv("HOME")
|
||||
os.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", oldHome)
|
||||
|
||||
configPath := filepath.Join(tmpDir, ".config", "entity-maker.toml")
|
||||
|
||||
// Create a test config
|
||||
cfg := &Config{
|
||||
DBHost: "testhost",
|
||||
DBPort: 5433,
|
||||
DBName: "testdb",
|
||||
DBSchema: "testschema",
|
||||
DBUser: "testuser",
|
||||
DBPassword: "testpass",
|
||||
DBTable: "testtable",
|
||||
OutputDir: "/tmp/test",
|
||||
EntityName: "TestEntity",
|
||||
}
|
||||
|
||||
// Save config
|
||||
err := cfg.Save()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to save config: %v", err)
|
||||
}
|
||||
|
||||
// Verify file exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
t.Fatalf("Config file was not created: %v", err)
|
||||
}
|
||||
|
||||
// Load config
|
||||
loadedCfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to load config: %v", err)
|
||||
}
|
||||
|
||||
// Verify loaded values
|
||||
tests := []struct {
|
||||
name string
|
||||
got interface{}
|
||||
expected interface{}
|
||||
}{
|
||||
{"DBHost", loadedCfg.DBHost, cfg.DBHost},
|
||||
{"DBPort", loadedCfg.DBPort, cfg.DBPort},
|
||||
{"DBName", loadedCfg.DBName, cfg.DBName},
|
||||
{"DBSchema", loadedCfg.DBSchema, cfg.DBSchema},
|
||||
{"DBUser", loadedCfg.DBUser, cfg.DBUser},
|
||||
{"DBPassword", loadedCfg.DBPassword, cfg.DBPassword},
|
||||
{"DBTable", loadedCfg.DBTable, cfg.DBTable},
|
||||
{"OutputDir", loadedCfg.OutputDir, cfg.OutputDir},
|
||||
{"EntityName", loadedCfg.EntityName, cfg.EntityName},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if tt.got != tt.expected {
|
||||
t.Errorf("%s mismatch: expected %v, got %v", tt.name, tt.expected, tt.got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadNonExistentConfig(t *testing.T) {
|
||||
// Create a temporary directory and set HOME to it
|
||||
tmpDir := t.TempDir()
|
||||
oldHome := os.Getenv("HOME")
|
||||
os.Setenv("HOME", tmpDir)
|
||||
defer os.Setenv("HOME", oldHome)
|
||||
|
||||
// Load should return default config when file doesn't exist
|
||||
cfg, err := Load()
|
||||
if err != nil {
|
||||
t.Fatalf("Load should not error when file doesn't exist: %v", err)
|
||||
}
|
||||
|
||||
// Verify it's the default config
|
||||
defaultCfg := DefaultConfig()
|
||||
if cfg.DBHost != defaultCfg.DBHost {
|
||||
t.Errorf("Expected default DBHost %q, got %q", defaultCfg.DBHost, cfg.DBHost)
|
||||
}
|
||||
if cfg.DBPort != defaultCfg.DBPort {
|
||||
t.Errorf("Expected default DBPort %d, got %d", defaultCfg.DBPort, cfg.DBPort)
|
||||
}
|
||||
}
|
||||
341
internal/database/database_test.go
Normal file
341
internal/database/database_test.go
Normal file
@ -0,0 +1,341 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetPythonType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col Column
|
||||
expected string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", Column{DataType: "integer"}, "int"},
|
||||
{"smallint", Column{DataType: "smallint"}, "int"},
|
||||
{"bigint", Column{DataType: "bigint"}, "int"},
|
||||
|
||||
// Numeric types
|
||||
{"numeric", Column{DataType: "numeric"}, "Decimal"},
|
||||
{"decimal", Column{DataType: "decimal"}, "Decimal"},
|
||||
{"real", Column{DataType: "real"}, "Decimal"},
|
||||
{"double precision", Column{DataType: "double precision"}, "Decimal"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", Column{DataType: "boolean"}, "bool"},
|
||||
|
||||
// String types
|
||||
{"varchar", Column{DataType: "character varying"}, "str"},
|
||||
{"varchar short", Column{DataType: "varchar"}, "str"},
|
||||
{"text", Column{DataType: "text"}, "str"},
|
||||
{"char", Column{DataType: "char"}, "str"},
|
||||
{"character", Column{DataType: "character"}, "str"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
|
||||
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
|
||||
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
|
||||
{"date", Column{DataType: "date"}, "date"},
|
||||
{"time with tz", Column{DataType: "time with time zone"}, "time"},
|
||||
{"time without tz", Column{DataType: "time without time zone"}, "time"},
|
||||
{"time", Column{DataType: "time"}, "time"},
|
||||
|
||||
// JSON types
|
||||
{"json", Column{DataType: "json"}, "dict"},
|
||||
{"jsonb", Column{DataType: "jsonb"}, "dict"},
|
||||
|
||||
// Other types
|
||||
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||
{"bytea", Column{DataType: "bytea"}, "bytes"},
|
||||
|
||||
// User-defined (enum)
|
||||
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
|
||||
|
||||
// Unknown type
|
||||
{"unknown", Column{DataType: "unknown_type"}, "Any"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPythonType(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLAlchemyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col Column
|
||||
expected string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", Column{DataType: "integer"}, "Integer"},
|
||||
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
|
||||
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
|
||||
|
||||
// Numeric types with precision
|
||||
{
|
||||
"numeric with precision",
|
||||
Column{
|
||||
DataType: "numeric",
|
||||
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||
},
|
||||
"Numeric(12, 4)",
|
||||
},
|
||||
{
|
||||
"numeric without precision",
|
||||
Column{DataType: "numeric"},
|
||||
"Numeric",
|
||||
},
|
||||
{"real", Column{DataType: "real"}, "Float"},
|
||||
{"double precision", Column{DataType: "double precision"}, "Float"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", Column{DataType: "boolean"}, "Boolean"},
|
||||
|
||||
// String types
|
||||
{
|
||||
"varchar with length",
|
||||
Column{
|
||||
DataType: "character varying",
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
|
||||
},
|
||||
"String(255)",
|
||||
},
|
||||
{
|
||||
"varchar without length",
|
||||
Column{DataType: "varchar"},
|
||||
"String",
|
||||
},
|
||||
{
|
||||
"char with length",
|
||||
Column{
|
||||
DataType: "char",
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
|
||||
},
|
||||
"String(10)",
|
||||
},
|
||||
{
|
||||
"char without length",
|
||||
Column{DataType: "character"},
|
||||
"String(1)",
|
||||
},
|
||||
{"text", Column{DataType: "text"}, "Text"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
|
||||
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
|
||||
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
|
||||
{"date", Column{DataType: "date"}, "Date"},
|
||||
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
|
||||
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
|
||||
{"time", Column{DataType: "time"}, "Time"},
|
||||
|
||||
// JSON types
|
||||
{"json", Column{DataType: "json"}, "JSON"},
|
||||
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
|
||||
|
||||
// Other types
|
||||
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
|
||||
|
||||
// User-defined (enum)
|
||||
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
|
||||
|
||||
// Unknown type
|
||||
{"unknown", Column{DataType: "unknown_type"}, "String"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetSQLAlchemyType(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumn(t *testing.T) {
|
||||
col := Column{
|
||||
Name: "test_column",
|
||||
DataType: "varchar",
|
||||
IsNullable: true,
|
||||
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
|
||||
NumericPrecision: sql.NullInt64{Valid: false},
|
||||
NumericScale: sql.NullInt64{Valid: false},
|
||||
UdtName: "",
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
}
|
||||
|
||||
if col.Name != "test_column" {
|
||||
t.Errorf("Expected Name 'test_column', got %q", col.Name)
|
||||
}
|
||||
|
||||
if !col.IsNullable {
|
||||
t.Error("Expected IsNullable to be true")
|
||||
}
|
||||
|
||||
if !col.ColumnDefault.Valid {
|
||||
t.Error("Expected ColumnDefault to be valid")
|
||||
}
|
||||
|
||||
if col.ColumnDefault.String != "default_value" {
|
||||
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKey(t *testing.T) {
|
||||
fk := ForeignKey{
|
||||
ColumnName: "user_id",
|
||||
ForeignTableSchema: "public",
|
||||
ForeignTableName: "users",
|
||||
ForeignColumnName: "id",
|
||||
ConstraintName: "fk_user_id",
|
||||
}
|
||||
|
||||
if fk.ColumnName != "user_id" {
|
||||
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
|
||||
}
|
||||
|
||||
if fk.ForeignTableName != "users" {
|
||||
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
|
||||
}
|
||||
|
||||
if fk.ForeignColumnName != "id" {
|
||||
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumType(t *testing.T) {
|
||||
enum := EnumType{
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||
}
|
||||
|
||||
if enum.TypeName != "status_enum" {
|
||||
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
|
||||
}
|
||||
|
||||
if len(enum.Values) != 3 {
|
||||
t.Errorf("Expected 3 values, got %d", len(enum.Values))
|
||||
}
|
||||
|
||||
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
|
||||
for i, val := range enum.Values {
|
||||
if val != expectedValues[i] {
|
||||
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableInfo(t *testing.T) {
|
||||
tableInfo := &TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
ForeignKeys: []ForeignKey{
|
||||
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||
},
|
||||
EnumTypes: map[string]EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"ACTIVE", "INACTIVE"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if tableInfo.Schema != "public" {
|
||||
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
|
||||
}
|
||||
|
||||
if tableInfo.TableName != "users" {
|
||||
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
|
||||
}
|
||||
|
||||
if len(tableInfo.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
|
||||
}
|
||||
|
||||
if len(tableInfo.ForeignKeys) != 1 {
|
||||
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
|
||||
}
|
||||
|
||||
if len(tableInfo.EnumTypes) != 1 {
|
||||
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
|
||||
}
|
||||
|
||||
// Test primary key detection
|
||||
foundPK := false
|
||||
for _, col := range tableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
foundPK = true
|
||||
if col.Name != "id" {
|
||||
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundPK {
|
||||
t.Error("Expected to find primary key column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
cfg := Config{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "testdb",
|
||||
Schema: "public",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
if cfg.Host != "localhost" {
|
||||
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
|
||||
}
|
||||
|
||||
if cfg.Port != 5432 {
|
||||
t.Errorf("Expected Port 5432, got %d", cfg.Port)
|
||||
}
|
||||
|
||||
if cfg.Database != "testdb" {
|
||||
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
|
||||
}
|
||||
|
||||
if cfg.Schema != "public" {
|
||||
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGetPythonType(b *testing.B) {
|
||||
col := Column{DataType: "character varying"}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetPythonType(col)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetSQLAlchemyType(b *testing.B) {
|
||||
col := Column{
|
||||
DataType: "numeric",
|
||||
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetSQLAlchemyType(col)
|
||||
}
|
||||
}
|
||||
199
internal/generator/enum_test.go
Normal file
199
internal/generator/enum_test.go
Normal file
@ -0,0 +1,199 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
func TestSanitizePythonIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"starts with digit", "12_HOUR", "_12_HOUR"},
|
||||
{"starts with letter", "HOUR_12", "HOUR_12"},
|
||||
{"all digits", "24", "_24"},
|
||||
{"underscore first", "_12_HOUR", "_12_HOUR"},
|
||||
{"empty", "", ""},
|
||||
{"single digit", "1", "_1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizePythonIdentifier(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateEnum(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enumTypes map[string]database.EnumType
|
||||
expectError bool
|
||||
checkFunc func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "simple enum",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class StatusEnum") {
|
||||
t.Error("Expected class StatusEnum")
|
||||
}
|
||||
if !strings.Contains(result, "OPEN = \"OPEN\"") {
|
||||
t.Error("Expected OPEN = \"OPEN\"")
|
||||
}
|
||||
if !strings.Contains(result, "CLOSED = \"CLOSED\"") {
|
||||
t.Error("Expected CLOSED = \"CLOSED\"")
|
||||
}
|
||||
if !strings.Contains(result, "PENDING = \"PENDING\"") {
|
||||
t.Error("Expected PENDING = \"PENDING\"")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with spaces and hyphens",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"time_format_enum": {
|
||||
TypeName: "time_format_enum",
|
||||
Values: []string{"12-hour", "24-hour"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class TimeFormatEnum") {
|
||||
t.Error("Expected class TimeFormatEnum")
|
||||
}
|
||||
if !strings.Contains(result, "_12_HOUR = \"12-hour\"") {
|
||||
t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)")
|
||||
}
|
||||
if !strings.Contains(result, "_24_HOUR = \"24-hour\"") {
|
||||
t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with duplicates after normalization",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"measurement_enum": {
|
||||
TypeName: "measurement_enum",
|
||||
Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class MeasurementEnum") {
|
||||
t.Error("Expected class MeasurementEnum")
|
||||
}
|
||||
// Should only have one INTERNATIONAL and one IMPERIAL
|
||||
internationalCount := strings.Count(result, "INTERNATIONAL = ")
|
||||
if internationalCount != 1 {
|
||||
t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount)
|
||||
}
|
||||
imperialCount := strings.Count(result, "IMPERIAL = ")
|
||||
if imperialCount != 1 {
|
||||
t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty enum",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"empty_enum": {
|
||||
TypeName: "empty_enum",
|
||||
Values: []string{},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class EmptyEnum") {
|
||||
t.Error("Expected class EmptyEnum")
|
||||
}
|
||||
if !strings.Contains(result, "pass") {
|
||||
t.Error("Expected 'pass' for empty enum")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple enums",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED"},
|
||||
},
|
||||
"priority_enum": {
|
||||
TypeName: "priority_enum",
|
||||
Values: []string{"HIGH", "LOW"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class StatusEnum") {
|
||||
t.Error("Expected class StatusEnum")
|
||||
}
|
||||
if !strings.Contains(result, "class PriorityEnum") {
|
||||
t.Error("Expected class PriorityEnum")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with special characters",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"special_enum": {
|
||||
TypeName: "special_enum",
|
||||
Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") {
|
||||
t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"")
|
||||
}
|
||||
if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") {
|
||||
t.Error("Expected ON_HOLD = \"ON HOLD\"")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &Context{
|
||||
TableInfo: &database.TableInfo{
|
||||
EnumTypes: tt.enumTypes,
|
||||
},
|
||||
EntityName: "TestEntity",
|
||||
ModuleName: "test_entity",
|
||||
}
|
||||
|
||||
result, err := GenerateEnum(ctx)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.checkFunc != nil {
|
||||
tt.checkFunc(t, result)
|
||||
}
|
||||
|
||||
// Check common requirements
|
||||
if !strings.Contains(result, "from enum import StrEnum") {
|
||||
t.Error("Expected import of StrEnum")
|
||||
}
|
||||
if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") {
|
||||
t.Error("Expected import of EnumMixin")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string {
|
||||
}
|
||||
|
||||
// GetFilterFieldName returns the filter field name for a column
|
||||
// For ID fields with IN operator, pluralizes the name
|
||||
// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids)
|
||||
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||
// Remove _id, pluralize, add back _ids
|
||||
base := columnName[:len(columnName)-3]
|
||||
if base == "" {
|
||||
return "ids"
|
||||
}
|
||||
return naming.Pluralize(base) + "_ids"
|
||||
// Replace _id with _ids
|
||||
return columnName[:len(columnName)-2] + "ids"
|
||||
}
|
||||
if useIN && columnName == "id" {
|
||||
return "ids"
|
||||
|
||||
979
internal/generator/generator_test.go
Normal file
979
internal/generator/generator_test.go
Normal file
@ -0,0 +1,979 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
func TestGetRelationshipName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"with _id suffix", "author_info_id", "author_info"},
|
||||
{"with user_id", "user_id", "user"},
|
||||
{"without _id", "status", "status"},
|
||||
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRelationshipEntityName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"plural table", "users", "User"},
|
||||
{"compound plural", "user_accounts", "UserAccount"},
|
||||
{"ies ending", "companies", "Company"},
|
||||
{"already singular", "user", "User"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipEntityName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRelationshipModuleName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"plural table", "users", "user"},
|
||||
{"compound plural", "user_accounts", "user_account"},
|
||||
{"ies ending", "companies", "company"},
|
||||
{"already singular", "user", "user"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipModuleName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFilterFieldName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
useIN bool
|
||||
expected string
|
||||
}{
|
||||
{"id with IN", "id", true, "ids"},
|
||||
{"id without IN", "id", false, "id"},
|
||||
{"user_id with IN", "user_id", true, "user_ids"},
|
||||
{"user_id without IN", "user_id", false, "user_id"},
|
||||
{"machine_id with IN", "machine_id", true, "machine_ids"},
|
||||
{"status without IN", "status", false, "status"},
|
||||
{"status with IN", "status", true, "status"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetFilterFieldName(tt.columnName, tt.useIN)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
|
||||
tt.columnName, tt.useIN, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldGenerateFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "boolean field",
|
||||
col: database.Column{
|
||||
Name: "alive",
|
||||
DataType: "boolean",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "id field",
|
||||
col: database.Column{
|
||||
Name: "id",
|
||||
DataType: "integer",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "foreign key field",
|
||||
col: database.Column{
|
||||
Name: "user_id",
|
||||
DataType: "integer",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "text field",
|
||||
col: database.Column{
|
||||
Name: "name",
|
||||
DataType: "text",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "varchar field",
|
||||
col: database.Column{
|
||||
Name: "email",
|
||||
DataType: "character varying",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "numeric field",
|
||||
col: database.Column{
|
||||
Name: "amount",
|
||||
DataType: "numeric",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "timestamp field",
|
||||
col: database.Column{
|
||||
Name: "created_at",
|
||||
DataType: "timestamp with time zone",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ShouldGenerateFilter(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDecimalImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columns []database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has numeric column",
|
||||
columns: []database.Column{
|
||||
{Name: "amount", DataType: "numeric"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has decimal column",
|
||||
columns: []database.Column{
|
||||
{Name: "price", DataType: "decimal"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no numeric columns",
|
||||
columns: []database.Column{
|
||||
{Name: "id", DataType: "integer"},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty columns",
|
||||
columns: []database.Column{},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NeedsDecimalImport(tt.columns)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDatetimeImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columns []database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has timestamp column",
|
||||
columns: []database.Column{
|
||||
{Name: "created_at", DataType: "timestamp with time zone"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has date column",
|
||||
columns: []database.Column{
|
||||
{Name: "birth_date", DataType: "date"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no datetime columns",
|
||||
columns: []database.Column{
|
||||
{Name: "id", DataType: "integer"},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NeedsDatetimeImport(tt.columns)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequiredColumns(t *testing.T) {
|
||||
columns := []database.Column{
|
||||
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
|
||||
}
|
||||
|
||||
result := GetRequiredColumns(columns)
|
||||
|
||||
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
|
||||
if len(result) != 1 {
|
||||
t.Errorf("Expected 1 required column, got %d", len(result))
|
||||
}
|
||||
|
||||
if len(result) > 0 && result[0].Name != "name" {
|
||||
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalColumns(t *testing.T) {
|
||||
columns := []database.Column{
|
||||
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
}
|
||||
|
||||
result := GetOptionalColumns(columns)
|
||||
|
||||
// Should include 'email' (nullable) and 'created_at' (has default)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("Expected 2 optional columns, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetForeignKeyForColumn(t *testing.T) {
|
||||
fks := []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users"},
|
||||
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expectNil bool
|
||||
expectFK string
|
||||
}{
|
||||
{"found user_id", "user_id", false, "users"},
|
||||
{"found company_id", "company_id", false, "companies"},
|
||||
{"not found", "status_id", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetForeignKeyForColumn(tt.columnName, fks)
|
||||
if tt.expectNil {
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %+v", result)
|
||||
}
|
||||
} else {
|
||||
if result == nil {
|
||||
t.Errorf("Expected FK, got nil")
|
||||
} else if result.ForeignTableName != tt.expectFK {
|
||||
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateInit(t *testing.T) {
|
||||
ctx := &Context{
|
||||
EntityName: "User",
|
||||
ModuleName: "user",
|
||||
}
|
||||
|
||||
result, err := GenerateInit(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateInit failed: %v", err)
|
||||
}
|
||||
|
||||
// __init__.py should be empty
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty __init__.py, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRepository(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{},
|
||||
}
|
||||
|
||||
ctx := NewContext(tableInfo, "")
|
||||
|
||||
result, err := GenerateRepository(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRepository failed: %v", err)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
expectedStrings := []string{
|
||||
"class UserRepository",
|
||||
"CRUDRepository",
|
||||
"model_cls = User",
|
||||
"from televend_core.databases.base_repository import CRUDRepository",
|
||||
"from televend_core.databases.televend_repositories.user.filter import",
|
||||
"from televend_core.databases.televend_repositories.user.model import User",
|
||||
}
|
||||
|
||||
for _, expected := range expectedStrings {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateManager(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{},
|
||||
}
|
||||
|
||||
ctx := NewContext(tableInfo, "")
|
||||
|
||||
result, err := GenerateManager(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateManager failed: %v", err)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
expectedStrings := []string{
|
||||
"class UserManager",
|
||||
"CRUDManager",
|
||||
"repository_cls = UserRepository",
|
||||
}
|
||||
|
||||
for _, expected := range expectedStrings {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "user_accounts",
|
||||
}
|
||||
|
||||
// Test without override
|
||||
ctx := NewContext(tableInfo, "")
|
||||
if ctx.EntityName != "UserAccount" {
|
||||
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
|
||||
}
|
||||
if ctx.ModuleName != "user_account" {
|
||||
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
|
||||
}
|
||||
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
|
||||
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
|
||||
}
|
||||
|
||||
// Test with override
|
||||
ctx = NewContext(tableInfo, "CustomUser")
|
||||
if ctx.EntityName != "CustomUser" {
|
||||
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple table with basic types",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
|
||||
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from sqlalchemy import",
|
||||
"Column",
|
||||
"Table",
|
||||
"Integer",
|
||||
"String",
|
||||
"USER_TABLE = Table(",
|
||||
`"users"`,
|
||||
"metadata_obj",
|
||||
`Column("id", Integer, primary_key=True, autoincrement=True)`,
|
||||
`Column("name", String(255), nullable=False)`,
|
||||
`Column("email", String)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "table with foreign keys",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"ForeignKey",
|
||||
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "table with enum",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{
|
||||
"order_status": {
|
||||
TypeName: "order_status",
|
||||
Values: []string{"pending", "completed", "cancelled"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"Enum",
|
||||
"from televend_core.databases.televend_repositories.order.enum import",
|
||||
"OrderStatus",
|
||||
`Enum(`,
|
||||
`*OrderStatus.to_value_list()`,
|
||||
`name="order_status"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateTable(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTable failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple model",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from dataclasses import dataclass",
|
||||
"from televend_core.databases.base_model import Base",
|
||||
"@dataclass",
|
||||
"class User(Base):",
|
||||
"name: str",
|
||||
"email: str | None = None",
|
||||
"id: int | None = None",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with foreign key",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.televend_repositories.user.model import User",
|
||||
"user_id: int",
|
||||
"user: User",
|
||||
"title: str",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with datetime and decimal",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "amount", DataType: "numeric", IsNullable: false},
|
||||
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from datetime import datetime",
|
||||
"from decimal import Decimal",
|
||||
"amount: Decimal",
|
||||
"created_at: datetime | None = None",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateModel(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateModel failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
notExpect []string
|
||||
}{
|
||||
{
|
||||
name: "filter with boolean and id fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
{Name: "alive", DataType: "boolean"},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.base_filter import BaseFilter",
|
||||
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
|
||||
"class UserFilter(BaseFilter):",
|
||||
"model_cls = User",
|
||||
"id: int | None = filterfield(operator=EQ)",
|
||||
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
|
||||
"name: str | None = filterfield(operator=EQ)",
|
||||
"alive: bool | None = filterfield(operator=EQ, default=True)",
|
||||
"user_id: int | None = filterfield(operator=EQ)",
|
||||
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
|
||||
},
|
||||
notExpect: []string{
|
||||
"default=None",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter with no filterable fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "logs",
|
||||
Columns: []database.Column{
|
||||
{Name: "timestamp", DataType: "timestamp with time zone"},
|
||||
{Name: "amount", DataType: "numeric"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"class LogFilter(BaseFilter):",
|
||||
"pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateFilter(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFilter failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
for _, notExpected := range tt.notExpect {
|
||||
if strings.Contains(result, notExpected) {
|
||||
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
|
||||
notExpected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLoadOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "load options with relationships",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
{Name: "category_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.base_load_options import LoadOptions",
|
||||
"from televend_core.databases.common.load_options import joinload",
|
||||
"class PostLoadOptions(LoadOptions):",
|
||||
"model_cls = Post",
|
||||
`load_user: bool = joinload(relations=["user"])`,
|
||||
`load_category: bool = joinload(relations=["category"])`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "load options with no relationships",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "settings",
|
||||
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"class SettingLoadOptions(LoadOptions):",
|
||||
"pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateLoadOptions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateLoadOptions failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFactory(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "factory with basic fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
|
||||
{Name: "alive", DataType: "boolean"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from __future__ import annotations",
|
||||
"from typing import Type",
|
||||
"import factory",
|
||||
"class UserFactory(TelevendBaseFactory):",
|
||||
"alive = True",
|
||||
"id = None",
|
||||
`name = factory.Faker("pystr", max_chars=100)`,
|
||||
"class Meta:",
|
||||
"model = User",
|
||||
"def create_minimal",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "factory with foreign keys",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.televend_repositories.user.factory import",
|
||||
"UserFactory",
|
||||
`user = CustomSelfAttribute("..user", UserFactory)`,
|
||||
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
|
||||
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "factory with decimal field",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateFactory(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFactory failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMapper(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "mapper with single foreign key",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"mapper_registry.map_imperatively(",
|
||||
"class_=Post,",
|
||||
"local_table=POST_TABLE,",
|
||||
"properties={",
|
||||
`"user": relationship(`,
|
||||
"User, lazy=relationship_loading_strategy.value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mapper with multiple foreign keys to same table",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "messages",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "sender_id", DataType: "integer"},
|
||||
{Name: "receiver_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
`"sender": relationship(`,
|
||||
"User, lazy=relationship_loading_strategy.value",
|
||||
`"receiver": relationship(`,
|
||||
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateMapper(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateMapper failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPythonTypeForColumn(t *testing.T) {
|
||||
ctx := &Context{
|
||||
TableInfo: &database.TableInfo{
|
||||
EnumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"active", "inactive"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
col database.Column
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "integer type",
|
||||
col: database.Column{DataType: "integer"},
|
||||
expected: "int",
|
||||
},
|
||||
{
|
||||
name: "varchar type",
|
||||
col: database.Column{DataType: "varchar"},
|
||||
expected: "str",
|
||||
},
|
||||
{
|
||||
name: "enum type",
|
||||
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
|
||||
expected: "StatusEnum",
|
||||
},
|
||||
{
|
||||
name: "unknown enum type",
|
||||
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
|
||||
expected: "str",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPythonTypeForColumn(tt.col, ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -155,8 +155,12 @@ func Pluralize(word string) string {
|
||||
return preserveCase(word, plural)
|
||||
}
|
||||
|
||||
// Already plural (ends in 's' and not special case)
|
||||
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
|
||||
// Already plural (ends in 's' after a consonant, but not 'ss', 'us', 'is')
|
||||
// Skip this check if word ends in 'ss' (like 'class', 'glass')
|
||||
if strings.HasSuffix(lower, "s") &&
|
||||
!strings.HasSuffix(lower, "ss") &&
|
||||
!strings.HasSuffix(lower, "us") &&
|
||||
!strings.HasSuffix(lower, "is") {
|
||||
return word
|
||||
}
|
||||
|
||||
|
||||
247
internal/naming/naming_test.go
Normal file
247
internal/naming/naming_test.go
Normal file
@ -0,0 +1,247 @@
|
||||
package naming
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSingularize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// Regular plurals
|
||||
{"simple s", "users", "user"},
|
||||
{"simple s uppercase", "Users", "User"},
|
||||
{"tables", "tables", "table"},
|
||||
|
||||
// ies -> y
|
||||
{"ies to y", "companies", "company"},
|
||||
{"ies to y uppercase", "Companies", "Company"},
|
||||
{"cities", "cities", "city"},
|
||||
{"categories", "categories", "category"},
|
||||
|
||||
// ves -> f/fe
|
||||
{"ves to f", "halves", "half"},
|
||||
{"ves to fe - knife", "knives", "knife"},
|
||||
{"ves to fe - wife", "wives", "wife"},
|
||||
{"ves to fe - life", "lives", "life"},
|
||||
|
||||
// es endings
|
||||
{"xes", "boxes", "box"},
|
||||
{"shes", "dishes", "dish"},
|
||||
{"ches", "watches", "watch"},
|
||||
{"ses", "classes", "class"},
|
||||
|
||||
// oes -> o
|
||||
{"oes", "tomatoes", "tomato"},
|
||||
{"heroes", "heroes", "hero"},
|
||||
|
||||
// Irregular plurals
|
||||
{"people", "people", "person"},
|
||||
{"children", "children", "child"},
|
||||
{"men", "men", "man"},
|
||||
{"women", "women", "woman"},
|
||||
{"teeth", "teeth", "tooth"},
|
||||
{"feet", "feet", "foot"},
|
||||
{"mice", "mice", "mouse"},
|
||||
{"geese", "geese", "goose"},
|
||||
|
||||
// Unchanged
|
||||
{"sheep", "sheep", "sheep"},
|
||||
{"fish", "fish", "fish"},
|
||||
{"deer", "deer", "deer"},
|
||||
{"series", "series", "series"},
|
||||
{"species", "species", "species"},
|
||||
|
||||
// Special cases
|
||||
{"data", "data", "datum"},
|
||||
{"analyses", "analyses", "analysis"},
|
||||
{"crises", "crises", "crisis"},
|
||||
|
||||
// Edge cases
|
||||
{"empty", "", ""},
|
||||
{"already singular", "user", "user"},
|
||||
{"ss ending", "glass", "glass"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := Singularize(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Singularize(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPluralize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
// Regular plurals
|
||||
{"simple", "user", "users"},
|
||||
{"simple uppercase", "User", "Users"},
|
||||
{"table", "table", "tables"},
|
||||
|
||||
// y -> ies
|
||||
{"y to ies", "company", "companies"},
|
||||
{"y to ies uppercase", "Company", "Companies"},
|
||||
{"city", "city", "cities"},
|
||||
{"category", "category", "categories"},
|
||||
|
||||
// consonant + y stays
|
||||
{"ay ending", "day", "days"},
|
||||
{"ey ending", "key", "keys"},
|
||||
{"oy ending", "boy", "boys"},
|
||||
|
||||
// f/fe -> ves
|
||||
{"f to ves", "half", "halves"},
|
||||
{"fe to ves - knife", "knife", "knives"},
|
||||
{"fe to ves - wife", "wife", "wives"},
|
||||
{"fe to ves - life", "life", "lives"},
|
||||
|
||||
// s, x, sh, ch -> es
|
||||
{"x to es", "box", "boxes"},
|
||||
{"sh to es", "dish", "dishes"},
|
||||
{"ch to es", "watch", "watches"},
|
||||
{"s to es", "class", "classes"},
|
||||
|
||||
// o -> oes
|
||||
{"o to oes", "tomato", "tomatoes"},
|
||||
{"hero", "hero", "heroes"},
|
||||
|
||||
// Irregular plurals
|
||||
{"person", "person", "people"},
|
||||
{"child", "child", "children"},
|
||||
{"man", "man", "men"},
|
||||
{"woman", "woman", "women"},
|
||||
{"tooth", "tooth", "teeth"},
|
||||
{"foot", "foot", "feet"},
|
||||
{"mouse", "mouse", "mice"},
|
||||
{"goose", "goose", "geese"},
|
||||
|
||||
// Unchanged
|
||||
{"sheep", "sheep", "sheep"},
|
||||
{"fish", "fish", "fish"},
|
||||
{"deer", "deer", "deer"},
|
||||
{"series", "series", "series"},
|
||||
{"species", "species", "species"},
|
||||
|
||||
// Edge cases
|
||||
{"empty", "", ""},
|
||||
{"already plural", "users", "users"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := Pluralize(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSingularizeTableName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "users", "user"},
|
||||
{"compound", "user_accounts", "user_account"},
|
||||
{"triple compound", "user_login_histories", "user_login_history"},
|
||||
{"already singular", "user", "user"},
|
||||
{"ies ending", "cashbag_conforms", "cashbag_conform"},
|
||||
{"complex table", "auth_user_groups", "auth_user_group"},
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := SingularizeTableName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("SingularizeTableName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToPascalCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"simple", "user", "User"},
|
||||
{"snake_case", "user_account", "UserAccount"},
|
||||
{"triple", "user_login_history", "UserLoginHistory"},
|
||||
{"already pascal", "UserAccount", "UserAccount"},
|
||||
{"single char", "a", "A"},
|
||||
{"empty", "", ""},
|
||||
{"with numbers", "user_2fa", "User2fa"},
|
||||
{"multiple underscores", "user___account", "UserAccount"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ToPascalCase(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToSnakeCase(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"PascalCase", "UserAccount", "user_account"},
|
||||
{"simple", "User", "user"},
|
||||
{"multiple words", "UserLoginHistory", "user_login_history"},
|
||||
{"already snake", "user_account", "user_account"},
|
||||
{"empty", "", ""},
|
||||
{"single char", "A", "a"},
|
||||
{"with numbers", "User2FA", "user2_f_a"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ToSnakeCase(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkSingularize(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Singularize("companies")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPluralize(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Pluralize("company")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToPascalCase(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ToPascalCase("user_login_history")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkToSnakeCase(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ToSnakeCase("UserLoginHistory")
|
||||
}
|
||||
}
|
||||
319
internal/prompt/prompt_test.go
Normal file
319
internal/prompt/prompt_test.go
Normal file
@ -0,0 +1,319 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockReader implements the Reader interface for testing
|
||||
type mockReader struct {
|
||||
input string
|
||||
pos int
|
||||
}
|
||||
|
||||
func (m *mockReader) ReadString(delim byte) (string, error) {
|
||||
if m.pos >= len(m.input) {
|
||||
return "", io.EOF
|
||||
}
|
||||
|
||||
idx := strings.IndexByte(m.input[m.pos:], delim)
|
||||
if idx == -1 {
|
||||
result := m.input[m.pos:]
|
||||
m.pos = len(m.input)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
result := m.input[m.pos : m.pos+idx+1]
|
||||
m.pos += idx + 1
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func TestPromptStringWithReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultValue string
|
||||
required bool
|
||||
expected string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "user enters value",
|
||||
input: "testvalue\n",
|
||||
defaultValue: "default",
|
||||
required: false,
|
||||
expected: "testvalue",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "user accepts default",
|
||||
input: "\n",
|
||||
defaultValue: "default",
|
||||
required: false,
|
||||
expected: "default",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "required field empty then filled",
|
||||
input: "\nactualvalue\n",
|
||||
defaultValue: "",
|
||||
required: true,
|
||||
expected: "actualvalue",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "optional empty field",
|
||||
input: "\n",
|
||||
defaultValue: "",
|
||||
required: false,
|
||||
expected: "",
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "whitespace trimmed",
|
||||
input: " test \n",
|
||||
defaultValue: "",
|
||||
required: false,
|
||||
expected: "test",
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := &mockReader{input: tt.input}
|
||||
result, err := promptStringWithReader(reader, "Test Label", tt.defaultValue, tt.required)
|
||||
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptIntWithReader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultValue int
|
||||
required bool
|
||||
expected int
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "user enters valid number",
|
||||
input: "42\n",
|
||||
defaultValue: 0,
|
||||
required: false,
|
||||
expected: 42,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "user accepts default",
|
||||
input: "\n",
|
||||
defaultValue: 5432,
|
||||
required: false,
|
||||
expected: 5432,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "invalid then valid number",
|
||||
input: "abc\n123\n",
|
||||
defaultValue: 0,
|
||||
required: false,
|
||||
expected: 123,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "required field empty then filled",
|
||||
input: "\n999\n",
|
||||
defaultValue: 0,
|
||||
required: true,
|
||||
expected: 999,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "zero value",
|
||||
input: "0\n",
|
||||
defaultValue: 10,
|
||||
required: false,
|
||||
expected: 0,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
reader := &mockReader{input: tt.input}
|
||||
result, err := promptIntWithReader(reader, "Test Label", tt.defaultValue, tt.required)
|
||||
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
|
||||
if result != tt.expected {
|
||||
t.Errorf("Expected %d, got %d", tt.expected, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateDirectory(t *testing.T) {
|
||||
// Create a temporary directory for testing
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
setup func() string
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "existing directory",
|
||||
setup: func() string {
|
||||
return tmpDir
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "non-existent directory gets created",
|
||||
setup: func() string {
|
||||
return filepath.Join(tmpDir, "newdir")
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "nested directory gets created",
|
||||
setup: func() string {
|
||||
return filepath.Join(tmpDir, "level1", "level2", "level3")
|
||||
},
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "file exists at path",
|
||||
setup: func() string {
|
||||
filePath := filepath.Join(tmpDir, "testfile")
|
||||
os.WriteFile(filePath, []byte("test"), 0644)
|
||||
return filePath
|
||||
},
|
||||
expectError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
path := tt.setup()
|
||||
err := ValidateDirectory(path)
|
||||
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
|
||||
// If no error expected, verify directory exists
|
||||
if !tt.expectError {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
t.Errorf("Directory should exist: %v", err)
|
||||
}
|
||||
if !info.IsDir() {
|
||||
t.Errorf("Path should be a directory")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrintFunctions(t *testing.T) {
|
||||
// These functions write to stdout, so we just ensure they don't panic
|
||||
// Testing colored output is complex due to terminal color codes
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fn func()
|
||||
}{
|
||||
{"PrintHeader", func() { PrintHeader("Test Header") }},
|
||||
{"PrintSuccess", func() { PrintSuccess("Success message") }},
|
||||
{"PrintError", func() { PrintError("Error message") }},
|
||||
{"PrintInfo", func() { PrintInfo("Info message") }},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Just ensure the function doesn't panic
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("Function panicked: %v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
tt.fn()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPromptDirectory(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
defaultValue string
|
||||
required bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "valid directory path",
|
||||
input: tmpDir + "\n",
|
||||
defaultValue: "",
|
||||
required: true,
|
||||
expectError: false,
|
||||
},
|
||||
{
|
||||
name: "creates new directory",
|
||||
input: filepath.Join(tmpDir, "newdir") + "\n",
|
||||
defaultValue: "",
|
||||
required: true,
|
||||
expectError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Note: This test would require refactoring PromptDirectory to accept a reader
|
||||
// For now, we test ValidateDirectory which is the core logic
|
||||
if tt.input != "" {
|
||||
path := strings.TrimSpace(tt.input[:len(tt.input)-1])
|
||||
err := ValidateDirectory(path)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkPromptStringWithReader(b *testing.B) {
|
||||
reader := &mockReader{input: strings.Repeat("test\n", b.N)}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
promptStringWithReader(reader, "Label", "default", false)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkPromptIntWithReader(b *testing.B) {
|
||||
reader := &mockReader{input: strings.Repeat("42\n", b.N)}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
promptIntWithReader(reader, "Label", 0, false)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user