Compare commits
4 Commits
499f59ff12
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 5f5f1ad114 | |||
| 0a4030c389 | |||
| 4e4827d640 | |||
| ca10b01fb0 |
4
.gitignore
vendored
4
.gitignore
vendored
@ -3,3 +3,7 @@
|
|||||||
|
|
||||||
*.toml
|
*.toml
|
||||||
/entity-maker
|
/entity-maker
|
||||||
|
|
||||||
|
# Test coverage
|
||||||
|
coverage.out
|
||||||
|
coverage.html
|
||||||
|
|||||||
122
Makefile
122
Makefile
@ -1,10 +1,130 @@
|
|||||||
EXEC=entity-maker
|
EXEC=entity-maker
|
||||||
|
BUILD_DIR=./build
|
||||||
|
CMD_DIR=./cmd/entity-maker
|
||||||
|
|
||||||
|
# Build variables
|
||||||
|
GOOS ?= linux
|
||||||
|
GOARCH ?= amd64
|
||||||
|
LDFLAGS=-s -w
|
||||||
|
|
||||||
|
|
||||||
.PHONY: build
|
.PHONY: build
|
||||||
build:
|
build:
|
||||||
@go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go
|
@echo "Building ${EXEC} for ${GOOS}/${GOARCH}..."
|
||||||
|
@mkdir -p ${BUILD_DIR}
|
||||||
|
@CGO_ENABLED=0 GOOS=${GOOS} GOARCH=${GOARCH} go build \
|
||||||
|
-ldflags "${LDFLAGS}" \
|
||||||
|
-trimpath \
|
||||||
|
-o ${BUILD_DIR}/${EXEC} \
|
||||||
|
${CMD_DIR}
|
||||||
|
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
|
||||||
|
@ls -lh ${BUILD_DIR}/${EXEC}
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: build-dev
|
||||||
|
build-dev:
|
||||||
|
@echo "Building ${EXEC} (development mode)..."
|
||||||
|
@mkdir -p ${BUILD_DIR}
|
||||||
|
@go build -o ${BUILD_DIR}/${EXEC} ${CMD_DIR}
|
||||||
|
@echo "✓ Binary created: ${BUILD_DIR}/${EXEC}"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: install
|
||||||
|
install: build
|
||||||
|
@echo "Installing ${EXEC} to /usr/local/bin..."
|
||||||
|
@sudo cp ${BUILD_DIR}/${EXEC} /usr/local/bin/${EXEC}
|
||||||
|
@echo "✓ Installed successfully"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: clean
|
||||||
|
clean:
|
||||||
|
@echo "Cleaning build artifacts..."
|
||||||
|
@rm -rf ${BUILD_DIR}
|
||||||
|
@rm -f ${EXEC}
|
||||||
|
@echo "✓ Clean complete"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test
|
||||||
|
test:
|
||||||
|
@echo "Running tests..."
|
||||||
|
@go test ./...
|
||||||
|
@echo "✓ Tests passed"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test-verbose
|
||||||
|
test-verbose:
|
||||||
|
@echo "Running tests (verbose)..."
|
||||||
|
@go test -v ./...
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test-coverage
|
||||||
|
test-coverage:
|
||||||
|
@echo "Running tests with coverage..."
|
||||||
|
@go test -cover ./...
|
||||||
|
@echo "✓ Tests completed"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: coverage
|
||||||
|
coverage:
|
||||||
|
@echo "Generating coverage report..."
|
||||||
|
@go test -coverprofile=coverage.out ./...
|
||||||
|
@go tool cover -html=coverage.out -o coverage.html
|
||||||
|
@echo "✓ Coverage report generated: coverage.html"
|
||||||
|
@echo " Open coverage.html in your browser to view detailed coverage"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: test-short
|
||||||
|
test-short:
|
||||||
|
@echo "Running short tests..."
|
||||||
|
@go test -short ./...
|
||||||
|
@echo "✓ Short tests passed"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: bench
|
||||||
|
bench:
|
||||||
|
@echo "Running benchmarks..."
|
||||||
|
@go test -bench=. -benchmem ./internal/naming
|
||||||
|
@echo "✓ Benchmarks completed"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: verify-static
|
||||||
|
verify-static: build
|
||||||
|
@echo "Verifying binary is statically linked..."
|
||||||
|
@file ${BUILD_DIR}/${EXEC}
|
||||||
|
@ldd ${BUILD_DIR}/${EXEC} 2>&1 || echo "✓ Binary is statically linked"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: upgrade-packages
|
||||||
upgrade-packages:
|
upgrade-packages:
|
||||||
|
@echo "Upgrading Go packages..."
|
||||||
@go get -u ./...
|
@go get -u ./...
|
||||||
|
@go mod tidy
|
||||||
|
@echo "✓ Packages upgraded"
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: help
|
||||||
|
help:
|
||||||
|
@echo "Entity Maker - Makefile targets:"
|
||||||
|
@echo ""
|
||||||
|
@echo "Build targets:"
|
||||||
|
@echo " build Build static binary for Linux (default)"
|
||||||
|
@echo " build-dev Build development binary with debug symbols"
|
||||||
|
@echo " install Install binary to /usr/local/bin"
|
||||||
|
@echo " clean Remove build artifacts"
|
||||||
|
@echo " verify-static Verify binary has no dynamic dependencies"
|
||||||
|
@echo ""
|
||||||
|
@echo "Test targets:"
|
||||||
|
@echo " test Run all tests"
|
||||||
|
@echo " test-verbose Run tests with verbose output"
|
||||||
|
@echo " test-coverage Run tests and show coverage summary"
|
||||||
|
@echo " coverage Generate HTML coverage report"
|
||||||
|
@echo " test-short Run only short tests"
|
||||||
|
@echo " bench Run benchmarks"
|
||||||
|
@echo ""
|
||||||
|
@echo "Other targets:"
|
||||||
|
@echo " upgrade-packages Update Go dependencies"
|
||||||
|
@echo " help Show this help message"
|
||||||
|
@echo ""
|
||||||
|
@echo "Build options:"
|
||||||
|
@echo " GOOS=linux GOARCH=amd64 make build (default)"
|
||||||
|
@echo ""
|
||||||
|
|||||||
80
README.md
80
README.md
@ -21,10 +21,90 @@ A command-line tool for generating SQLAlchemy entities from PostgreSQL database
|
|||||||
|
|
||||||
### Build from source
|
### Build from source
|
||||||
|
|
||||||
|
Using Makefile (recommended):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
# Build static binary (production-ready, fully portable)
|
||||||
|
make build
|
||||||
|
|
||||||
|
# Build with debug symbols (development)
|
||||||
|
make build-dev
|
||||||
|
|
||||||
|
# Install to /usr/local/bin
|
||||||
|
make install
|
||||||
|
|
||||||
|
# Clean build artifacts
|
||||||
|
make clean
|
||||||
|
|
||||||
|
# Verify binary has no dependencies
|
||||||
|
make verify-static
|
||||||
|
|
||||||
|
# Show all available targets
|
||||||
|
make help
|
||||||
|
```
|
||||||
|
|
||||||
|
Or using Go directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Static binary (portable across all Linux x86-64 systems)
|
||||||
|
CGO_ENABLED=0 go build -ldflags "-s -w" -trimpath -o ./build/entity-maker ./cmd/entity-maker
|
||||||
|
|
||||||
|
# Development binary with debug symbols
|
||||||
go build -o entity-maker ./cmd/entity-maker
|
go build -o entity-maker ./cmd/entity-maker
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The static binary (`make build`) is **fully portable** and will run on any Linux x86-64 system without external dependencies.
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
Run tests using Makefile commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
make test
|
||||||
|
|
||||||
|
# Run tests with verbose output
|
||||||
|
make test-verbose
|
||||||
|
|
||||||
|
# Run tests with coverage summary
|
||||||
|
make test-coverage
|
||||||
|
|
||||||
|
# Generate HTML coverage report
|
||||||
|
make coverage
|
||||||
|
# Then open coverage.html in your browser
|
||||||
|
|
||||||
|
# Run only short/fast tests
|
||||||
|
make test-short
|
||||||
|
|
||||||
|
# Run benchmarks
|
||||||
|
make bench
|
||||||
|
```
|
||||||
|
|
||||||
|
Or using Go directly:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
go test ./...
|
||||||
|
|
||||||
|
# Run with coverage
|
||||||
|
go test -cover ./...
|
||||||
|
|
||||||
|
# Run specific package
|
||||||
|
go test ./internal/naming -v
|
||||||
|
|
||||||
|
# Generate coverage report
|
||||||
|
go test -coverprofile=coverage.out ./...
|
||||||
|
go tool cover -html=coverage.out
|
||||||
|
```
|
||||||
|
|
||||||
|
**Test Coverage:**
|
||||||
|
- `internal/naming`: 97.1% ⭐
|
||||||
|
- `internal/generator`: 91.1% ⭐
|
||||||
|
- `internal/config`: 81.8%
|
||||||
|
- `internal/prompt`: 75.3%
|
||||||
|
- `internal/database`: 30.8%
|
||||||
|
- `cmd/entity-maker`: 0.0% (integration code - business logic tested in internal packages)
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
### Interactive Mode
|
### Interactive Mode
|
||||||
|
|||||||
@ -165,7 +165,6 @@ func run() error {
|
|||||||
|
|
||||||
// Introspect table
|
// Introspect table
|
||||||
prompt.PrintHeader("Introspecting Table")
|
prompt.PrintHeader("Introspecting Table")
|
||||||
|
|
||||||
tableInfo, err := dbClient.IntrospectTable(cfg.DBTable)
|
tableInfo, err := dbClient.IntrospectTable(cfg.DBTable)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to introspect table: %w", err)
|
return fmt.Errorf("failed to introspect table: %w", err)
|
||||||
@ -225,7 +224,6 @@ func run() error {
|
|||||||
|
|
||||||
// Print summary
|
// Print summary
|
||||||
prompt.PrintHeader("Summary")
|
prompt.PrintHeader("Summary")
|
||||||
fmt.Printf("\n")
|
|
||||||
fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName))
|
fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName))
|
||||||
fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName))
|
fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName))
|
||||||
fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir))
|
fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir))
|
||||||
|
|||||||
299
cmd/entity-maker/main_test.go
Normal file
299
cmd/entity-maker/main_test.go
Normal file
@ -0,0 +1,299 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/generator"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Note: The cmd/entity-maker package has low test coverage by design.
|
||||||
|
// The run() function is integration code that orchestrates calls to:
|
||||||
|
// - Prompt functions (requires user input)
|
||||||
|
// - Database connections (requires live PostgreSQL)
|
||||||
|
// - File system operations (requires specific paths)
|
||||||
|
//
|
||||||
|
// All the business logic is tested in the internal/* packages where
|
||||||
|
// coverage is high (91-97%). The tests below validate the file generation
|
||||||
|
// workflow in isolation without needing the full integration environment.
|
||||||
|
|
||||||
|
// TestPackageCompiles validates that the package compiles correctly
|
||||||
|
func TestPackageCompiles(t *testing.T) {
|
||||||
|
// This test ensures the main package compiles correctly
|
||||||
|
// The main() function itself is hard to test as it calls os.Exit()
|
||||||
|
// and depends on external resources (database, user input, file system)
|
||||||
|
// If this test runs, the package compiled successfully
|
||||||
|
t.Log("Main package compiled successfully")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFileGeneration tests the file generation logic in isolation
|
||||||
|
func TestFileGeneration(t *testing.T) {
|
||||||
|
// Create a temporary directory for testing
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
moduleName := "test_entity"
|
||||||
|
moduleDir := filepath.Join(tmpDir, moduleName)
|
||||||
|
|
||||||
|
// Create the module directory
|
||||||
|
if err := os.MkdirAll(moduleDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create module directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a simple test context
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "test_table",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := generator.NewContext(tableInfo, "")
|
||||||
|
|
||||||
|
// Define files to generate (same as in main.go)
|
||||||
|
files := map[string]func(*generator.Context) (string, error){
|
||||||
|
"table.py": generator.GenerateTable,
|
||||||
|
"model.py": generator.GenerateModel,
|
||||||
|
"filter.py": generator.GenerateFilter,
|
||||||
|
"load_options.py": generator.GenerateLoadOptions,
|
||||||
|
"repository.py": generator.GenerateRepository,
|
||||||
|
"manager.py": generator.GenerateManager,
|
||||||
|
"factory.py": generator.GenerateFactory,
|
||||||
|
"mapper.py": generator.GenerateMapper,
|
||||||
|
"__init__.py": generator.GenerateInit,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate and write each file
|
||||||
|
for filename, genFunc := range files {
|
||||||
|
content, err := genFunc(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to generate %s: %v", filename, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(moduleDir, filename)
|
||||||
|
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||||
|
t.Errorf("Failed to write %s: %v", filename, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file was created
|
||||||
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("File %s was not created", filename)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file has content
|
||||||
|
if len(content) == 0 && filename != "__init__.py" {
|
||||||
|
t.Errorf("File %s has no content", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all expected files exist
|
||||||
|
expectedFiles := []string{
|
||||||
|
"table.py", "model.py", "filter.py", "load_options.py",
|
||||||
|
"repository.py", "manager.py", "factory.py", "mapper.py", "__init__.py",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, filename := range expectedFiles {
|
||||||
|
filePath := filepath.Join(moduleDir, filename)
|
||||||
|
if _, err := os.Stat(filePath); os.IsNotExist(err) {
|
||||||
|
t.Errorf("Expected file %s does not exist", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestFileGenerationWithEnums tests file generation when enum types are present
|
||||||
|
func TestFileGenerationWithEnums(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
moduleName := "test_entity"
|
||||||
|
moduleDir := filepath.Join(tmpDir, moduleName)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(moduleDir, 0755); err != nil {
|
||||||
|
t.Fatalf("Failed to create module directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test context with enums
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "test_table",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "status", DataType: "USER-DEFINED", UdtName: "status_enum", IsNullable: false},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{
|
||||||
|
"status_enum": {
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"active", "inactive", "pending"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := generator.NewContext(tableInfo, "")
|
||||||
|
|
||||||
|
files := map[string]func(*generator.Context) (string, error){
|
||||||
|
"table.py": generator.GenerateTable,
|
||||||
|
"model.py": generator.GenerateModel,
|
||||||
|
"filter.py": generator.GenerateFilter,
|
||||||
|
"load_options.py": generator.GenerateLoadOptions,
|
||||||
|
"repository.py": generator.GenerateRepository,
|
||||||
|
"manager.py": generator.GenerateManager,
|
||||||
|
"factory.py": generator.GenerateFactory,
|
||||||
|
"mapper.py": generator.GenerateMapper,
|
||||||
|
"__init__.py": generator.GenerateInit,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add enum.py since we have enum types
|
||||||
|
if len(tableInfo.EnumTypes) > 0 {
|
||||||
|
files["enum.py"] = generator.GenerateEnum
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate all files
|
||||||
|
for filename, genFunc := range files {
|
||||||
|
content, err := genFunc(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Failed to generate %s: %v", filename, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(moduleDir, filename)
|
||||||
|
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||||
|
t.Errorf("Failed to write %s: %v", filename, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify enum.py was created
|
||||||
|
enumPath := filepath.Join(moduleDir, "enum.py")
|
||||||
|
if _, err := os.Stat(enumPath); os.IsNotExist(err) {
|
||||||
|
t.Error("enum.py was not created when enum types are present")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestModuleDirectoryCreation tests directory creation logic
|
||||||
|
func TestModuleDirectoryCreation(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
moduleName string
|
||||||
|
expectErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple module name",
|
||||||
|
moduleName: "user",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested module name",
|
||||||
|
moduleName: filepath.Join("nested", "module"),
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "module with underscore",
|
||||||
|
moduleName: "user_account",
|
||||||
|
expectErr: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
moduleDir := filepath.Join(tmpDir, tt.moduleName)
|
||||||
|
err := os.MkdirAll(moduleDir, 0755)
|
||||||
|
|
||||||
|
if (err != nil) != tt.expectErr {
|
||||||
|
t.Errorf("MkdirAll() error = %v, expectErr %v", err, tt.expectErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tt.expectErr {
|
||||||
|
// Verify directory exists
|
||||||
|
info, err := os.Stat(moduleDir)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Directory should exist: %v", err)
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
t.Error("Path should be a directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGeneratedFilePermissions tests that generated files have correct permissions
|
||||||
|
func TestGeneratedFilePermissions(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
testFile := filepath.Join(tmpDir, "test.py")
|
||||||
|
|
||||||
|
content := "# Test content\n"
|
||||||
|
if err := os.WriteFile(testFile, []byte(content), 0644); err != nil {
|
||||||
|
t.Fatalf("Failed to write test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
info, err := os.Stat(testFile)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to stat test file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
mode := info.Mode()
|
||||||
|
|
||||||
|
// Check that owner can read and write
|
||||||
|
if mode&0600 != 0600 {
|
||||||
|
t.Error("File should be readable and writable by owner")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that file is not executable
|
||||||
|
if mode&0111 != 0 {
|
||||||
|
t.Error("Python files should not be executable")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGeneratedFilesAreNonEmpty tests that generated files have content
|
||||||
|
func TestGeneratedFilesAreNonEmpty(t *testing.T) {
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||||
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := generator.NewContext(tableInfo, "")
|
||||||
|
|
||||||
|
generators := map[string]func(*generator.Context) (string, error){
|
||||||
|
"table": generator.GenerateTable,
|
||||||
|
"model": generator.GenerateModel,
|
||||||
|
"filter": generator.GenerateFilter,
|
||||||
|
"load_options": generator.GenerateLoadOptions,
|
||||||
|
"repository": generator.GenerateRepository,
|
||||||
|
"manager": generator.GenerateManager,
|
||||||
|
"factory": generator.GenerateFactory,
|
||||||
|
"mapper": generator.GenerateMapper,
|
||||||
|
}
|
||||||
|
|
||||||
|
for name, genFunc := range generators {
|
||||||
|
t.Run(name, func(t *testing.T) {
|
||||||
|
content, err := genFunc(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to generate %s: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(content) == 0 {
|
||||||
|
t.Errorf("Generated %s content should not be empty", name)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that content has at least some basic Python structure
|
||||||
|
if name != "mapper" { // mapper is just a snippet
|
||||||
|
if len(content) < 10 {
|
||||||
|
t.Errorf("Generated %s content seems too short: %d bytes", name, len(content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
117
internal/config/config_test.go
Normal file
117
internal/config/config_test.go
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDefaultConfig(t *testing.T) {
|
||||||
|
cfg := DefaultConfig()
|
||||||
|
|
||||||
|
if cfg.DBHost != "localhost" {
|
||||||
|
t.Errorf("Expected DBHost to be 'localhost', got %q", cfg.DBHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.DBPort != 5432 {
|
||||||
|
t.Errorf("Expected DBPort to be 5432, got %d", cfg.DBPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.DBSchema != "public" {
|
||||||
|
t.Errorf("Expected DBSchema to be 'public', got %q", cfg.DBSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.DBUser != "postgres" {
|
||||||
|
t.Errorf("Expected DBUser to be 'postgres', got %q", cfg.DBUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.DBPassword != "postgres" {
|
||||||
|
t.Errorf("Expected DBPassword to be 'postgres', got %q", cfg.DBPassword)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndLoad(t *testing.T) {
|
||||||
|
// Create a temporary directory and set HOME to it
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
oldHome := os.Getenv("HOME")
|
||||||
|
os.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", oldHome)
|
||||||
|
|
||||||
|
configPath := filepath.Join(tmpDir, ".config", "entity-maker.toml")
|
||||||
|
|
||||||
|
// Create a test config
|
||||||
|
cfg := &Config{
|
||||||
|
DBHost: "testhost",
|
||||||
|
DBPort: 5433,
|
||||||
|
DBName: "testdb",
|
||||||
|
DBSchema: "testschema",
|
||||||
|
DBUser: "testuser",
|
||||||
|
DBPassword: "testpass",
|
||||||
|
DBTable: "testtable",
|
||||||
|
OutputDir: "/tmp/test",
|
||||||
|
EntityName: "TestEntity",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save config
|
||||||
|
err := cfg.Save()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to save config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify file exists
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("Config file was not created: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load config
|
||||||
|
loadedCfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify loaded values
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
got interface{}
|
||||||
|
expected interface{}
|
||||||
|
}{
|
||||||
|
{"DBHost", loadedCfg.DBHost, cfg.DBHost},
|
||||||
|
{"DBPort", loadedCfg.DBPort, cfg.DBPort},
|
||||||
|
{"DBName", loadedCfg.DBName, cfg.DBName},
|
||||||
|
{"DBSchema", loadedCfg.DBSchema, cfg.DBSchema},
|
||||||
|
{"DBUser", loadedCfg.DBUser, cfg.DBUser},
|
||||||
|
{"DBPassword", loadedCfg.DBPassword, cfg.DBPassword},
|
||||||
|
{"DBTable", loadedCfg.DBTable, cfg.DBTable},
|
||||||
|
{"OutputDir", loadedCfg.OutputDir, cfg.OutputDir},
|
||||||
|
{"EntityName", loadedCfg.EntityName, cfg.EntityName},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
if tt.got != tt.expected {
|
||||||
|
t.Errorf("%s mismatch: expected %v, got %v", tt.name, tt.expected, tt.got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadNonExistentConfig(t *testing.T) {
|
||||||
|
// Create a temporary directory and set HOME to it
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
oldHome := os.Getenv("HOME")
|
||||||
|
os.Setenv("HOME", tmpDir)
|
||||||
|
defer os.Setenv("HOME", oldHome)
|
||||||
|
|
||||||
|
// Load should return default config when file doesn't exist
|
||||||
|
cfg, err := Load()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Load should not error when file doesn't exist: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's the default config
|
||||||
|
defaultCfg := DefaultConfig()
|
||||||
|
if cfg.DBHost != defaultCfg.DBHost {
|
||||||
|
t.Errorf("Expected default DBHost %q, got %q", defaultCfg.DBHost, cfg.DBHost)
|
||||||
|
}
|
||||||
|
if cfg.DBPort != defaultCfg.DBPort {
|
||||||
|
t.Errorf("Expected default DBPort %d, got %d", defaultCfg.DBPort, cfg.DBPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
341
internal/database/database_test.go
Normal file
341
internal/database/database_test.go
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetPythonType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
col Column
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer", Column{DataType: "integer"}, "int"},
|
||||||
|
{"smallint", Column{DataType: "smallint"}, "int"},
|
||||||
|
{"bigint", Column{DataType: "bigint"}, "int"},
|
||||||
|
|
||||||
|
// Numeric types
|
||||||
|
{"numeric", Column{DataType: "numeric"}, "Decimal"},
|
||||||
|
{"decimal", Column{DataType: "decimal"}, "Decimal"},
|
||||||
|
{"real", Column{DataType: "real"}, "Decimal"},
|
||||||
|
{"double precision", Column{DataType: "double precision"}, "Decimal"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean", Column{DataType: "boolean"}, "bool"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{"varchar", Column{DataType: "character varying"}, "str"},
|
||||||
|
{"varchar short", Column{DataType: "varchar"}, "str"},
|
||||||
|
{"text", Column{DataType: "text"}, "str"},
|
||||||
|
{"char", Column{DataType: "char"}, "str"},
|
||||||
|
{"character", Column{DataType: "character"}, "str"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
|
||||||
|
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
|
||||||
|
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
|
||||||
|
{"date", Column{DataType: "date"}, "date"},
|
||||||
|
{"time with tz", Column{DataType: "time with time zone"}, "time"},
|
||||||
|
{"time without tz", Column{DataType: "time without time zone"}, "time"},
|
||||||
|
{"time", Column{DataType: "time"}, "time"},
|
||||||
|
|
||||||
|
// JSON types
|
||||||
|
{"json", Column{DataType: "json"}, "dict"},
|
||||||
|
{"jsonb", Column{DataType: "jsonb"}, "dict"},
|
||||||
|
|
||||||
|
// Other types
|
||||||
|
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||||
|
{"bytea", Column{DataType: "bytea"}, "bytes"},
|
||||||
|
|
||||||
|
// User-defined (enum)
|
||||||
|
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
|
||||||
|
|
||||||
|
// Unknown type
|
||||||
|
{"unknown", Column{DataType: "unknown_type"}, "Any"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPythonType(tt.col)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetSQLAlchemyType(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
col Column
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// Integer types
|
||||||
|
{"integer", Column{DataType: "integer"}, "Integer"},
|
||||||
|
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
|
||||||
|
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
|
||||||
|
|
||||||
|
// Numeric types with precision
|
||||||
|
{
|
||||||
|
"numeric with precision",
|
||||||
|
Column{
|
||||||
|
DataType: "numeric",
|
||||||
|
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||||
|
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||||
|
},
|
||||||
|
"Numeric(12, 4)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"numeric without precision",
|
||||||
|
Column{DataType: "numeric"},
|
||||||
|
"Numeric",
|
||||||
|
},
|
||||||
|
{"real", Column{DataType: "real"}, "Float"},
|
||||||
|
{"double precision", Column{DataType: "double precision"}, "Float"},
|
||||||
|
|
||||||
|
// Boolean
|
||||||
|
{"boolean", Column{DataType: "boolean"}, "Boolean"},
|
||||||
|
|
||||||
|
// String types
|
||||||
|
{
|
||||||
|
"varchar with length",
|
||||||
|
Column{
|
||||||
|
DataType: "character varying",
|
||||||
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
|
||||||
|
},
|
||||||
|
"String(255)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"varchar without length",
|
||||||
|
Column{DataType: "varchar"},
|
||||||
|
"String",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"char with length",
|
||||||
|
Column{
|
||||||
|
DataType: "char",
|
||||||
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
|
||||||
|
},
|
||||||
|
"String(10)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"char without length",
|
||||||
|
Column{DataType: "character"},
|
||||||
|
"String(1)",
|
||||||
|
},
|
||||||
|
{"text", Column{DataType: "text"}, "Text"},
|
||||||
|
|
||||||
|
// Date/Time types
|
||||||
|
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
|
||||||
|
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
|
||||||
|
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
|
||||||
|
{"date", Column{DataType: "date"}, "Date"},
|
||||||
|
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
|
||||||
|
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
|
||||||
|
{"time", Column{DataType: "time"}, "Time"},
|
||||||
|
|
||||||
|
// JSON types
|
||||||
|
{"json", Column{DataType: "json"}, "JSON"},
|
||||||
|
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
|
||||||
|
|
||||||
|
// Other types
|
||||||
|
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||||
|
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
|
||||||
|
|
||||||
|
// User-defined (enum)
|
||||||
|
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
|
||||||
|
|
||||||
|
// Unknown type
|
||||||
|
{"unknown", Column{DataType: "unknown_type"}, "String"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetSQLAlchemyType(tt.col)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestColumn(t *testing.T) {
|
||||||
|
col := Column{
|
||||||
|
Name: "test_column",
|
||||||
|
DataType: "varchar",
|
||||||
|
IsNullable: true,
|
||||||
|
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
|
||||||
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
|
||||||
|
NumericPrecision: sql.NullInt64{Valid: false},
|
||||||
|
NumericScale: sql.NullInt64{Valid: false},
|
||||||
|
UdtName: "",
|
||||||
|
IsPrimaryKey: false,
|
||||||
|
IsAutoIncrement: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.Name != "test_column" {
|
||||||
|
t.Errorf("Expected Name 'test_column', got %q", col.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !col.IsNullable {
|
||||||
|
t.Error("Expected IsNullable to be true")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !col.ColumnDefault.Valid {
|
||||||
|
t.Error("Expected ColumnDefault to be valid")
|
||||||
|
}
|
||||||
|
|
||||||
|
if col.ColumnDefault.String != "default_value" {
|
||||||
|
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForeignKey(t *testing.T) {
|
||||||
|
fk := ForeignKey{
|
||||||
|
ColumnName: "user_id",
|
||||||
|
ForeignTableSchema: "public",
|
||||||
|
ForeignTableName: "users",
|
||||||
|
ForeignColumnName: "id",
|
||||||
|
ConstraintName: "fk_user_id",
|
||||||
|
}
|
||||||
|
|
||||||
|
if fk.ColumnName != "user_id" {
|
||||||
|
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fk.ForeignTableName != "users" {
|
||||||
|
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fk.ForeignColumnName != "id" {
|
||||||
|
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnumType(t *testing.T) {
|
||||||
|
enum := EnumType{
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if enum.TypeName != "status_enum" {
|
||||||
|
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(enum.Values) != 3 {
|
||||||
|
t.Errorf("Expected 3 values, got %d", len(enum.Values))
|
||||||
|
}
|
||||||
|
|
||||||
|
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
|
||||||
|
for i, val := range enum.Values {
|
||||||
|
if val != expectedValues[i] {
|
||||||
|
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTableInfo(t *testing.T) {
|
||||||
|
tableInfo := &TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||||
|
{Name: "name", DataType: "varchar"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []ForeignKey{
|
||||||
|
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||||
|
},
|
||||||
|
EnumTypes: map[string]EnumType{
|
||||||
|
"status_enum": {
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"ACTIVE", "INACTIVE"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableInfo.Schema != "public" {
|
||||||
|
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tableInfo.TableName != "users" {
|
||||||
|
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tableInfo.Columns) != 2 {
|
||||||
|
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tableInfo.ForeignKeys) != 1 {
|
||||||
|
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tableInfo.EnumTypes) != 1 {
|
||||||
|
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test primary key detection
|
||||||
|
foundPK := false
|
||||||
|
for _, col := range tableInfo.Columns {
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
foundPK = true
|
||||||
|
if col.Name != "id" {
|
||||||
|
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundPK {
|
||||||
|
t.Error("Expected to find primary key column")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConfig(t *testing.T) {
|
||||||
|
cfg := Config{
|
||||||
|
Host: "localhost",
|
||||||
|
Port: 5432,
|
||||||
|
Database: "testdb",
|
||||||
|
Schema: "public",
|
||||||
|
User: "testuser",
|
||||||
|
Password: "testpass",
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Host != "localhost" {
|
||||||
|
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Port != 5432 {
|
||||||
|
t.Errorf("Expected Port 5432, got %d", cfg.Port)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Database != "testdb" {
|
||||||
|
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Schema != "public" {
|
||||||
|
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkGetPythonType(b *testing.B) {
|
||||||
|
col := Column{DataType: "character varying"}
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
GetPythonType(col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkGetSQLAlchemyType(b *testing.B) {
|
||||||
|
col := Column{
|
||||||
|
DataType: "numeric",
|
||||||
|
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||||
|
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||||
|
}
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
GetSQLAlchemyType(col)
|
||||||
|
}
|
||||||
|
}
|
||||||
199
internal/generator/enum_test.go
Normal file
199
internal/generator/enum_test.go
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSanitizePythonIdentifier(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"starts with digit", "12_HOUR", "_12_HOUR"},
|
||||||
|
{"starts with letter", "HOUR_12", "HOUR_12"},
|
||||||
|
{"all digits", "24", "_24"},
|
||||||
|
{"underscore first", "_12_HOUR", "_12_HOUR"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"single digit", "1", "_1"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := sanitizePythonIdentifier(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateEnum(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
enumTypes map[string]database.EnumType
|
||||||
|
expectError bool
|
||||||
|
checkFunc func(t *testing.T, result string)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple enum",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"status_enum": {
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "class StatusEnum") {
|
||||||
|
t.Error("Expected class StatusEnum")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "OPEN = \"OPEN\"") {
|
||||||
|
t.Error("Expected OPEN = \"OPEN\"")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "CLOSED = \"CLOSED\"") {
|
||||||
|
t.Error("Expected CLOSED = \"CLOSED\"")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "PENDING = \"PENDING\"") {
|
||||||
|
t.Error("Expected PENDING = \"PENDING\"")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enum with spaces and hyphens",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"time_format_enum": {
|
||||||
|
TypeName: "time_format_enum",
|
||||||
|
Values: []string{"12-hour", "24-hour"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "class TimeFormatEnum") {
|
||||||
|
t.Error("Expected class TimeFormatEnum")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "_12_HOUR = \"12-hour\"") {
|
||||||
|
t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "_24_HOUR = \"24-hour\"") {
|
||||||
|
t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enum with duplicates after normalization",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"measurement_enum": {
|
||||||
|
TypeName: "measurement_enum",
|
||||||
|
Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "class MeasurementEnum") {
|
||||||
|
t.Error("Expected class MeasurementEnum")
|
||||||
|
}
|
||||||
|
// Should only have one INTERNATIONAL and one IMPERIAL
|
||||||
|
internationalCount := strings.Count(result, "INTERNATIONAL = ")
|
||||||
|
if internationalCount != 1 {
|
||||||
|
t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount)
|
||||||
|
}
|
||||||
|
imperialCount := strings.Count(result, "IMPERIAL = ")
|
||||||
|
if imperialCount != 1 {
|
||||||
|
t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty enum",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"empty_enum": {
|
||||||
|
TypeName: "empty_enum",
|
||||||
|
Values: []string{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "class EmptyEnum") {
|
||||||
|
t.Error("Expected class EmptyEnum")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "pass") {
|
||||||
|
t.Error("Expected 'pass' for empty enum")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple enums",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"status_enum": {
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"OPEN", "CLOSED"},
|
||||||
|
},
|
||||||
|
"priority_enum": {
|
||||||
|
TypeName: "priority_enum",
|
||||||
|
Values: []string{"HIGH", "LOW"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "class StatusEnum") {
|
||||||
|
t.Error("Expected class StatusEnum")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "class PriorityEnum") {
|
||||||
|
t.Error("Expected class PriorityEnum")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enum with special characters",
|
||||||
|
enumTypes: map[string]database.EnumType{
|
||||||
|
"special_enum": {
|
||||||
|
TypeName: "special_enum",
|
||||||
|
Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
checkFunc: func(t *testing.T, result string) {
|
||||||
|
if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") {
|
||||||
|
t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") {
|
||||||
|
t.Error("Expected ON_HOLD = \"ON HOLD\"")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := &Context{
|
||||||
|
TableInfo: &database.TableInfo{
|
||||||
|
EnumTypes: tt.enumTypes,
|
||||||
|
},
|
||||||
|
EntityName: "TestEntity",
|
||||||
|
ModuleName: "test_entity",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := GenerateEnum(ctx)
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.checkFunc != nil {
|
||||||
|
tt.checkFunc(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check common requirements
|
||||||
|
if !strings.Contains(result, "from enum import StrEnum") {
|
||||||
|
t.Error("Expected import of StrEnum")
|
||||||
|
}
|
||||||
|
if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") {
|
||||||
|
t.Error("Expected import of EnumMixin")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -26,23 +26,23 @@ func GenerateFactory(ctx *Context) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for moduleName, entityName := range fkImports {
|
for moduleName, entityName := range fkImports {
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.factory import (\n",
|
||||||
moduleName))
|
moduleName))
|
||||||
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
||||||
b.WriteString(" CustomSelfAttribute,\n")
|
b.WriteString(" CustomSelfAttribute,\n")
|
||||||
b.WriteString(" TelevendBaseFactory,\n")
|
b.WriteString(" CloudBaseFactory,\n")
|
||||||
b.WriteString(")\n\n\n")
|
b.WriteString(")\n\n\n")
|
||||||
|
|
||||||
// Class definition
|
// Class definition
|
||||||
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
|
b.WriteString(fmt.Sprintf("class %sFactory(CloudBaseFactory):\n", ctx.EntityName))
|
||||||
|
|
||||||
// Add boolean fields with defaults
|
// Add boolean fields with defaults
|
||||||
for _, col := range ctx.TableInfo.Columns {
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
|||||||
@ -14,7 +14,7 @@ func GenerateFilter(ctx *Context) (string, error) {
|
|||||||
// Imports
|
// Imports
|
||||||
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
|
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
|
||||||
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
|
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
|
||||||
ctx.ModuleName, ctx.EntityName))
|
ctx.ModuleName, ctx.EntityName))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
@ -64,8 +64,8 @@ func GenerateFilter(ctx *Context) (string, error) {
|
|||||||
|
|
||||||
// Text fields
|
// Text fields
|
||||||
if (col.DataType == "character varying" || col.DataType == "varchar" ||
|
if (col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
|
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
|
||||||
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
|
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
|
||||||
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
|
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
|
||||||
col.Name))
|
col.Name))
|
||||||
}
|
}
|
||||||
|
|||||||
@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetFilterFieldName returns the filter field name for a column
|
// GetFilterFieldName returns the filter field name for a column
|
||||||
// For ID fields with IN operator, pluralizes the name
|
// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids)
|
||||||
func GetFilterFieldName(columnName string, useIN bool) string {
|
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||||
if useIN && strings.HasSuffix(columnName, "_id") {
|
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||||
// Remove _id, pluralize, add back _ids
|
// Replace _id with _ids
|
||||||
base := columnName[:len(columnName)-3]
|
return columnName[:len(columnName)-2] + "ids"
|
||||||
if base == "" {
|
|
||||||
return "ids"
|
|
||||||
}
|
|
||||||
return naming.Pluralize(base) + "_ids"
|
|
||||||
}
|
}
|
||||||
if useIN && columnName == "id" {
|
if useIN && columnName == "id" {
|
||||||
return "ids"
|
return "ids"
|
||||||
|
|||||||
979
internal/generator/generator_test.go
Normal file
979
internal/generator/generator_test.go
Normal file
@ -0,0 +1,979 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetRelationshipName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"with _id suffix", "author_info_id", "author_info"},
|
||||||
|
{"with user_id", "user_id", "user"},
|
||||||
|
{"without _id", "status", "status"},
|
||||||
|
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
|
||||||
|
{"empty", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetRelationshipName(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRelationshipEntityName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"plural table", "users", "User"},
|
||||||
|
{"compound plural", "user_accounts", "UserAccount"},
|
||||||
|
{"ies ending", "companies", "Company"},
|
||||||
|
{"already singular", "user", "User"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetRelationshipEntityName(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRelationshipModuleName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"plural table", "users", "user"},
|
||||||
|
{"compound plural", "user_accounts", "user_account"},
|
||||||
|
{"ies ending", "companies", "company"},
|
||||||
|
{"already singular", "user", "user"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetRelationshipModuleName(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetFilterFieldName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columnName string
|
||||||
|
useIN bool
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"id with IN", "id", true, "ids"},
|
||||||
|
{"id without IN", "id", false, "id"},
|
||||||
|
{"user_id with IN", "user_id", true, "user_ids"},
|
||||||
|
{"user_id without IN", "user_id", false, "user_id"},
|
||||||
|
{"machine_id with IN", "machine_id", true, "machine_ids"},
|
||||||
|
{"status without IN", "status", false, "status"},
|
||||||
|
{"status with IN", "status", true, "status"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetFilterFieldName(tt.columnName, tt.useIN)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
|
||||||
|
tt.columnName, tt.useIN, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldGenerateFilter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
col database.Column
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "boolean field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "alive",
|
||||||
|
DataType: "boolean",
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "id field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "id",
|
||||||
|
DataType: "integer",
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "foreign key field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "user_id",
|
||||||
|
DataType: "integer",
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "text field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "name",
|
||||||
|
DataType: "text",
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "varchar field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "email",
|
||||||
|
DataType: "character varying",
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "numeric field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "amount",
|
||||||
|
DataType: "numeric",
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "timestamp field",
|
||||||
|
col: database.Column{
|
||||||
|
Name: "created_at",
|
||||||
|
DataType: "timestamp with time zone",
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ShouldGenerateFilter(tt.col)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsDecimalImport(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columns []database.Column
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "has numeric column",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "amount", DataType: "numeric"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "has decimal column",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "price", DataType: "decimal"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no numeric columns",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer"},
|
||||||
|
{Name: "name", DataType: "varchar"},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty columns",
|
||||||
|
columns: []database.Column{},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := NeedsDecimalImport(tt.columns)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNeedsDatetimeImport(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columns []database.Column
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "has timestamp column",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "created_at", DataType: "timestamp with time zone"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "has date column",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "birth_date", DataType: "date"},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no datetime columns",
|
||||||
|
columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer"},
|
||||||
|
{Name: "name", DataType: "varchar"},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := NeedsDatetimeImport(tt.columns)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRequiredColumns(t *testing.T) {
|
||||||
|
columns := []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||||
|
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
|
||||||
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||||
|
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||||
|
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := GetRequiredColumns(columns)
|
||||||
|
|
||||||
|
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
|
||||||
|
if len(result) != 1 {
|
||||||
|
t.Errorf("Expected 1 required column, got %d", len(result))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) > 0 && result[0].Name != "name" {
|
||||||
|
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOptionalColumns(t *testing.T) {
|
||||||
|
columns := []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||||
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||||
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||||
|
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
result := GetOptionalColumns(columns)
|
||||||
|
|
||||||
|
// Should include 'email' (nullable) and 'created_at' (has default)
|
||||||
|
if len(result) != 2 {
|
||||||
|
t.Errorf("Expected 2 optional columns, got %d", len(result))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetForeignKeyForColumn(t *testing.T) {
|
||||||
|
fks := []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users"},
|
||||||
|
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
columnName string
|
||||||
|
expectNil bool
|
||||||
|
expectFK string
|
||||||
|
}{
|
||||||
|
{"found user_id", "user_id", false, "users"},
|
||||||
|
{"found company_id", "company_id", false, "companies"},
|
||||||
|
{"not found", "status_id", true, ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetForeignKeyForColumn(tt.columnName, fks)
|
||||||
|
if tt.expectNil {
|
||||||
|
if result != nil {
|
||||||
|
t.Errorf("Expected nil, got %+v", result)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if result == nil {
|
||||||
|
t.Errorf("Expected FK, got nil")
|
||||||
|
} else if result.ForeignTableName != tt.expectFK {
|
||||||
|
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateInit(t *testing.T) {
|
||||||
|
ctx := &Context{
|
||||||
|
EntityName: "User",
|
||||||
|
ModuleName: "user",
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := GenerateInit(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateInit failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// __init__.py should be empty
|
||||||
|
if result != "" {
|
||||||
|
t.Errorf("Expected empty __init__.py, got %q", result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateRepository(t *testing.T) {
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := NewContext(tableInfo, "")
|
||||||
|
|
||||||
|
result, err := GenerateRepository(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateRepository failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for expected content
|
||||||
|
expectedStrings := []string{
|
||||||
|
"class UserRepository",
|
||||||
|
"CRUDRepository",
|
||||||
|
"model_cls = User",
|
||||||
|
"from televend_core.databases.base_repository import CRUDRepository",
|
||||||
|
"from televend_core.databases.cloud_repositories.user.filter import",
|
||||||
|
"from televend_core.databases.cloud_repositories.user.model import User",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedStrings {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateManager(t *testing.T) {
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := NewContext(tableInfo, "")
|
||||||
|
|
||||||
|
result, err := GenerateManager(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateManager failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for expected content
|
||||||
|
expectedStrings := []string{
|
||||||
|
"class UserManager",
|
||||||
|
"CRUDManager",
|
||||||
|
"repository_cls = UserRepository",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range expectedStrings {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewContext(t *testing.T) {
|
||||||
|
tableInfo := &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "user_accounts",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test without override
|
||||||
|
ctx := NewContext(tableInfo, "")
|
||||||
|
if ctx.EntityName != "UserAccount" {
|
||||||
|
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
|
||||||
|
}
|
||||||
|
if ctx.ModuleName != "user_account" {
|
||||||
|
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
|
||||||
|
}
|
||||||
|
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
|
||||||
|
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with override
|
||||||
|
ctx = NewContext(tableInfo, "CustomUser")
|
||||||
|
if ctx.EntityName != "CustomUser" {
|
||||||
|
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateTable(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple table with basic types",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
|
||||||
|
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
|
||||||
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from sqlalchemy import",
|
||||||
|
"Column",
|
||||||
|
"Table",
|
||||||
|
"Integer",
|
||||||
|
"String",
|
||||||
|
"USER_TABLE = Table(",
|
||||||
|
`"users"`,
|
||||||
|
"metadata_obj",
|
||||||
|
`Column("id", Integer, primary_key=True, autoincrement=True)`,
|
||||||
|
`Column("name", String(255), nullable=False)`,
|
||||||
|
`Column("email", String)`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "table with foreign keys",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "posts",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||||
|
{Name: "title", DataType: "text", IsNullable: false},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"ForeignKey",
|
||||||
|
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "table with enum",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "orders",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{
|
||||||
|
"order_status": {
|
||||||
|
TypeName: "order_status",
|
||||||
|
Values: []string{"pending", "completed", "cancelled"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"Enum",
|
||||||
|
"from televend_core.databases.cloud_repositories.order.enum import",
|
||||||
|
"OrderStatus",
|
||||||
|
`Enum(`,
|
||||||
|
`*OrderStatus.to_value_list()`,
|
||||||
|
`name="order_status"`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateTable(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateTable failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateModel(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple model",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||||
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from dataclasses import dataclass",
|
||||||
|
"from televend_core.databases.base_model import Base",
|
||||||
|
"@dataclass",
|
||||||
|
"class User(Base):",
|
||||||
|
"name: str",
|
||||||
|
"email: str | None = None",
|
||||||
|
"id: int | None = None",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with foreign key",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "posts",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||||
|
{Name: "title", DataType: "text", IsNullable: false},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from televend_core.databases.cloud_repositories.user.model import User",
|
||||||
|
"user_id: int",
|
||||||
|
"user: User",
|
||||||
|
"title: str",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model with datetime and decimal",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "orders",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "amount", DataType: "numeric", IsNullable: false},
|
||||||
|
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from datetime import datetime",
|
||||||
|
"from decimal import Decimal",
|
||||||
|
"amount: Decimal",
|
||||||
|
"created_at: datetime | None = None",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateModel(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateModel failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateFilter(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
notExpect []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "filter with boolean and id fields",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||||
|
{Name: "name", DataType: "varchar"},
|
||||||
|
{Name: "alive", DataType: "boolean"},
|
||||||
|
{Name: "user_id", DataType: "integer"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from televend_core.databases.base_filter import BaseFilter",
|
||||||
|
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
|
||||||
|
"class UserFilter(BaseFilter):",
|
||||||
|
"model_cls = User",
|
||||||
|
"id: int | None = filterfield(operator=EQ)",
|
||||||
|
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
|
||||||
|
"name: str | None = filterfield(operator=EQ)",
|
||||||
|
"alive: bool | None = filterfield(operator=EQ, default=True)",
|
||||||
|
"user_id: int | None = filterfield(operator=EQ)",
|
||||||
|
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
|
||||||
|
},
|
||||||
|
notExpect: []string{
|
||||||
|
"default=None",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "filter with no filterable fields",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "logs",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "timestamp", DataType: "timestamp with time zone"},
|
||||||
|
{Name: "amount", DataType: "numeric"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"class LogFilter(BaseFilter):",
|
||||||
|
"pass",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateFilter(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateFilter failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, notExpected := range tt.notExpect {
|
||||||
|
if strings.Contains(result, notExpected) {
|
||||||
|
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
|
||||||
|
notExpected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateLoadOptions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "load options with relationships",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "posts",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||||
|
{Name: "user_id", DataType: "integer"},
|
||||||
|
{Name: "category_id", DataType: "integer"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from televend_core.databases.base_load_options import LoadOptions",
|
||||||
|
"from televend_core.databases.common.load_options import joinload",
|
||||||
|
"class PostLoadOptions(LoadOptions):",
|
||||||
|
"model_cls = Post",
|
||||||
|
`load_user: bool = joinload(relations=["user"])`,
|
||||||
|
`load_category: bool = joinload(relations=["category"])`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "load options with no relationships",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "settings",
|
||||||
|
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"class SettingLoadOptions(LoadOptions):",
|
||||||
|
"pass",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateLoadOptions(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateLoadOptions failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateFactory(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "factory with basic fields",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "users",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
|
||||||
|
{Name: "alive", DataType: "boolean"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from __future__ import annotations",
|
||||||
|
"from typing import Type",
|
||||||
|
"import factory",
|
||||||
|
"class UserFactory(CloudBaseFactory):",
|
||||||
|
"alive = True",
|
||||||
|
"id = None",
|
||||||
|
`name = factory.Faker("pystr", max_chars=100)`,
|
||||||
|
"class Meta:",
|
||||||
|
"model = User",
|
||||||
|
"def create_minimal",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "factory with foreign keys",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "posts",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||||
|
{Name: "title", DataType: "text"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"from televend_core.databases.cloud_repositories.user.factory import",
|
||||||
|
"UserFactory",
|
||||||
|
`user = CustomSelfAttribute("..user", UserFactory)`,
|
||||||
|
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
|
||||||
|
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "factory with decimal field",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "orders",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||||
|
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{},
|
||||||
|
EnumTypes: map[string]database.EnumType{},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateFactory(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateFactory failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateMapper(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
tableInfo *database.TableInfo
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "mapper with single foreign key",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "posts",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||||
|
{Name: "user_id", DataType: "integer"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
"mapper_registry.map_imperatively(",
|
||||||
|
"class_=Post,",
|
||||||
|
"local_table=POST_TABLE,",
|
||||||
|
"properties={",
|
||||||
|
`"user": relationship(`,
|
||||||
|
"User, lazy=relationship_loading_strategy.value",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mapper with multiple foreign keys to same table",
|
||||||
|
tableInfo: &database.TableInfo{
|
||||||
|
Schema: "public",
|
||||||
|
TableName: "messages",
|
||||||
|
Columns: []database.Column{
|
||||||
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||||
|
{Name: "sender_id", DataType: "integer"},
|
||||||
|
{Name: "receiver_id", DataType: "integer"},
|
||||||
|
},
|
||||||
|
ForeignKeys: []database.ForeignKey{
|
||||||
|
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: []string{
|
||||||
|
`"sender": relationship(`,
|
||||||
|
"User, lazy=relationship_loading_strategy.value",
|
||||||
|
`"receiver": relationship(`,
|
||||||
|
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
ctx := NewContext(tt.tableInfo, "")
|
||||||
|
result, err := GenerateMapper(ctx)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GenerateMapper failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, expected := range tt.expected {
|
||||||
|
if !strings.Contains(result, expected) {
|
||||||
|
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||||
|
expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPythonTypeForColumn(t *testing.T) {
|
||||||
|
ctx := &Context{
|
||||||
|
TableInfo: &database.TableInfo{
|
||||||
|
EnumTypes: map[string]database.EnumType{
|
||||||
|
"status_enum": {
|
||||||
|
TypeName: "status_enum",
|
||||||
|
Values: []string{"active", "inactive"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
col database.Column
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "integer type",
|
||||||
|
col: database.Column{DataType: "integer"},
|
||||||
|
expected: "int",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "varchar type",
|
||||||
|
col: database.Column{DataType: "varchar"},
|
||||||
|
expected: "str",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "enum type",
|
||||||
|
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
|
||||||
|
expected: "StatusEnum",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown enum type",
|
||||||
|
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
|
||||||
|
expected: "str",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := GetPythonTypeForColumn(tt.col, ctx)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@ -12,7 +12,7 @@ func GenerateLoadOptions(ctx *Context) (string, error) {
|
|||||||
// Imports
|
// Imports
|
||||||
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
|
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
|
||||||
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
|
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
|
||||||
ctx.ModuleName, ctx.EntityName))
|
ctx.ModuleName, ctx.EntityName))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
|||||||
@ -11,13 +11,13 @@ func GenerateManager(ctx *Context) (string, error) {
|
|||||||
|
|
||||||
// Imports
|
// Imports
|
||||||
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
|
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.filter import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
|
||||||
ctx.ModuleName, ctx.EntityName))
|
ctx.ModuleName, ctx.EntityName))
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.repository import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
|
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
|
|||||||
@ -38,7 +38,7 @@ func GenerateModel(ctx *Context) (string, error) {
|
|||||||
|
|
||||||
// Import enum types
|
// Import enum types
|
||||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.enum import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
@ -49,7 +49,7 @@ func GenerateModel(ctx *Context) (string, error) {
|
|||||||
|
|
||||||
// Write foreign key imports
|
// Write foreign key imports
|
||||||
for moduleName, entityName := range fkImports {
|
for moduleName, entityName := range fkImports {
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
|
||||||
moduleName, entityName))
|
moduleName, entityName))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@ -11,11 +11,11 @@ func GenerateRepository(ctx *Context) (string, error) {
|
|||||||
|
|
||||||
// Imports
|
// Imports
|
||||||
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
|
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.filter import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.model import %s\n",
|
||||||
ctx.ModuleName, ctx.EntityName))
|
ctx.ModuleName, ctx.EntityName))
|
||||||
b.WriteString("\n\n")
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
|||||||
@ -72,14 +72,14 @@ func GenerateTable(ctx *Context) (string, error) {
|
|||||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.enum import (\n",
|
||||||
ctx.ModuleName))
|
ctx.ModuleName))
|
||||||
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||||
b.WriteString(")\n")
|
b.WriteString(")\n")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
|
b.WriteString("from televend_core.databases.cloud_repositories.table_meta import metadata_obj\n\n")
|
||||||
|
|
||||||
// Table definition
|
// Table definition
|
||||||
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
||||||
|
|||||||
@ -155,8 +155,12 @@ func Pluralize(word string) string {
|
|||||||
return preserveCase(word, plural)
|
return preserveCase(word, plural)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Already plural (ends in 's' and not special case)
|
// Already plural (ends in 's' after a consonant, but not 'ss', 'us', 'is')
|
||||||
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
|
// Skip this check if word ends in 'ss' (like 'class', 'glass')
|
||||||
|
if strings.HasSuffix(lower, "s") &&
|
||||||
|
!strings.HasSuffix(lower, "ss") &&
|
||||||
|
!strings.HasSuffix(lower, "us") &&
|
||||||
|
!strings.HasSuffix(lower, "is") {
|
||||||
return word
|
return word
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
247
internal/naming/naming_test.go
Normal file
247
internal/naming/naming_test.go
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
package naming
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSingularize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// Regular plurals
|
||||||
|
{"simple s", "users", "user"},
|
||||||
|
{"simple s uppercase", "Users", "User"},
|
||||||
|
{"tables", "tables", "table"},
|
||||||
|
|
||||||
|
// ies -> y
|
||||||
|
{"ies to y", "companies", "company"},
|
||||||
|
{"ies to y uppercase", "Companies", "Company"},
|
||||||
|
{"cities", "cities", "city"},
|
||||||
|
{"categories", "categories", "category"},
|
||||||
|
|
||||||
|
// ves -> f/fe
|
||||||
|
{"ves to f", "halves", "half"},
|
||||||
|
{"ves to fe - knife", "knives", "knife"},
|
||||||
|
{"ves to fe - wife", "wives", "wife"},
|
||||||
|
{"ves to fe - life", "lives", "life"},
|
||||||
|
|
||||||
|
// es endings
|
||||||
|
{"xes", "boxes", "box"},
|
||||||
|
{"shes", "dishes", "dish"},
|
||||||
|
{"ches", "watches", "watch"},
|
||||||
|
{"ses", "classes", "class"},
|
||||||
|
|
||||||
|
// oes -> o
|
||||||
|
{"oes", "tomatoes", "tomato"},
|
||||||
|
{"heroes", "heroes", "hero"},
|
||||||
|
|
||||||
|
// Irregular plurals
|
||||||
|
{"people", "people", "person"},
|
||||||
|
{"children", "children", "child"},
|
||||||
|
{"men", "men", "man"},
|
||||||
|
{"women", "women", "woman"},
|
||||||
|
{"teeth", "teeth", "tooth"},
|
||||||
|
{"feet", "feet", "foot"},
|
||||||
|
{"mice", "mice", "mouse"},
|
||||||
|
{"geese", "geese", "goose"},
|
||||||
|
|
||||||
|
// Unchanged
|
||||||
|
{"sheep", "sheep", "sheep"},
|
||||||
|
{"fish", "fish", "fish"},
|
||||||
|
{"deer", "deer", "deer"},
|
||||||
|
{"series", "series", "series"},
|
||||||
|
{"species", "species", "species"},
|
||||||
|
|
||||||
|
// Special cases
|
||||||
|
{"data", "data", "datum"},
|
||||||
|
{"analyses", "analyses", "analysis"},
|
||||||
|
{"crises", "crises", "crisis"},
|
||||||
|
|
||||||
|
// Edge cases
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"already singular", "user", "user"},
|
||||||
|
{"ss ending", "glass", "glass"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := Singularize(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Singularize(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPluralize(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
// Regular plurals
|
||||||
|
{"simple", "user", "users"},
|
||||||
|
{"simple uppercase", "User", "Users"},
|
||||||
|
{"table", "table", "tables"},
|
||||||
|
|
||||||
|
// y -> ies
|
||||||
|
{"y to ies", "company", "companies"},
|
||||||
|
{"y to ies uppercase", "Company", "Companies"},
|
||||||
|
{"city", "city", "cities"},
|
||||||
|
{"category", "category", "categories"},
|
||||||
|
|
||||||
|
// consonant + y stays
|
||||||
|
{"ay ending", "day", "days"},
|
||||||
|
{"ey ending", "key", "keys"},
|
||||||
|
{"oy ending", "boy", "boys"},
|
||||||
|
|
||||||
|
// f/fe -> ves
|
||||||
|
{"f to ves", "half", "halves"},
|
||||||
|
{"fe to ves - knife", "knife", "knives"},
|
||||||
|
{"fe to ves - wife", "wife", "wives"},
|
||||||
|
{"fe to ves - life", "life", "lives"},
|
||||||
|
|
||||||
|
// s, x, sh, ch -> es
|
||||||
|
{"x to es", "box", "boxes"},
|
||||||
|
{"sh to es", "dish", "dishes"},
|
||||||
|
{"ch to es", "watch", "watches"},
|
||||||
|
{"s to es", "class", "classes"},
|
||||||
|
|
||||||
|
// o -> oes
|
||||||
|
{"o to oes", "tomato", "tomatoes"},
|
||||||
|
{"hero", "hero", "heroes"},
|
||||||
|
|
||||||
|
// Irregular plurals
|
||||||
|
{"person", "person", "people"},
|
||||||
|
{"child", "child", "children"},
|
||||||
|
{"man", "man", "men"},
|
||||||
|
{"woman", "woman", "women"},
|
||||||
|
{"tooth", "tooth", "teeth"},
|
||||||
|
{"foot", "foot", "feet"},
|
||||||
|
{"mouse", "mouse", "mice"},
|
||||||
|
{"goose", "goose", "geese"},
|
||||||
|
|
||||||
|
// Unchanged
|
||||||
|
{"sheep", "sheep", "sheep"},
|
||||||
|
{"fish", "fish", "fish"},
|
||||||
|
{"deer", "deer", "deer"},
|
||||||
|
{"series", "series", "series"},
|
||||||
|
{"species", "species", "species"},
|
||||||
|
|
||||||
|
// Edge cases
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"already plural", "users", "users"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := Pluralize(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Pluralize(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSingularizeTableName(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"simple", "users", "user"},
|
||||||
|
{"compound", "user_accounts", "user_account"},
|
||||||
|
{"triple compound", "user_login_histories", "user_login_history"},
|
||||||
|
{"already singular", "user", "user"},
|
||||||
|
{"ies ending", "cashbag_conforms", "cashbag_conform"},
|
||||||
|
{"complex table", "auth_user_groups", "auth_user_group"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := SingularizeTableName(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("SingularizeTableName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToPascalCase(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"simple", "user", "User"},
|
||||||
|
{"snake_case", "user_account", "UserAccount"},
|
||||||
|
{"triple", "user_login_history", "UserLoginHistory"},
|
||||||
|
{"already pascal", "UserAccount", "UserAccount"},
|
||||||
|
{"single char", "a", "A"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"with numbers", "user_2fa", "User2fa"},
|
||||||
|
{"multiple underscores", "user___account", "UserAccount"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ToPascalCase(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ToPascalCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToSnakeCase(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{"PascalCase", "UserAccount", "user_account"},
|
||||||
|
{"simple", "User", "user"},
|
||||||
|
{"multiple words", "UserLoginHistory", "user_login_history"},
|
||||||
|
{"already snake", "user_account", "user_account"},
|
||||||
|
{"empty", "", ""},
|
||||||
|
{"single char", "A", "a"},
|
||||||
|
{"with numbers", "User2FA", "user2_f_a"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := ToSnakeCase(tt.input)
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("ToSnakeCase(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkSingularize(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
Singularize("companies")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPluralize(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
Pluralize("company")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkToPascalCase(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ToPascalCase("user_login_history")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkToSnakeCase(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
ToSnakeCase("UserLoginHistory")
|
||||||
|
}
|
||||||
|
}
|
||||||
319
internal/prompt/prompt_test.go
Normal file
319
internal/prompt/prompt_test.go
Normal file
@ -0,0 +1,319 @@
|
|||||||
|
package prompt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// mockReader implements the Reader interface for testing
|
||||||
|
type mockReader struct {
|
||||||
|
input string
|
||||||
|
pos int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockReader) ReadString(delim byte) (string, error) {
|
||||||
|
if m.pos >= len(m.input) {
|
||||||
|
return "", io.EOF
|
||||||
|
}
|
||||||
|
|
||||||
|
idx := strings.IndexByte(m.input[m.pos:], delim)
|
||||||
|
if idx == -1 {
|
||||||
|
result := m.input[m.pos:]
|
||||||
|
m.pos = len(m.input)
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := m.input[m.pos : m.pos+idx+1]
|
||||||
|
m.pos += idx + 1
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptStringWithReader(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue string
|
||||||
|
required bool
|
||||||
|
expected string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user enters value",
|
||||||
|
input: "testvalue\n",
|
||||||
|
defaultValue: "default",
|
||||||
|
required: false,
|
||||||
|
expected: "testvalue",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user accepts default",
|
||||||
|
input: "\n",
|
||||||
|
defaultValue: "default",
|
||||||
|
required: false,
|
||||||
|
expected: "default",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required field empty then filled",
|
||||||
|
input: "\nactualvalue\n",
|
||||||
|
defaultValue: "",
|
||||||
|
required: true,
|
||||||
|
expected: "actualvalue",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "optional empty field",
|
||||||
|
input: "\n",
|
||||||
|
defaultValue: "",
|
||||||
|
required: false,
|
||||||
|
expected: "",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace trimmed",
|
||||||
|
input: " test \n",
|
||||||
|
defaultValue: "",
|
||||||
|
required: false,
|
||||||
|
expected: "test",
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
reader := &mockReader{input: tt.input}
|
||||||
|
result, err := promptStringWithReader(reader, "Test Label", tt.defaultValue, tt.required)
|
||||||
|
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %q, got %q", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptIntWithReader(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue int
|
||||||
|
required bool
|
||||||
|
expected int
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "user enters valid number",
|
||||||
|
input: "42\n",
|
||||||
|
defaultValue: 0,
|
||||||
|
required: false,
|
||||||
|
expected: 42,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "user accepts default",
|
||||||
|
input: "\n",
|
||||||
|
defaultValue: 5432,
|
||||||
|
required: false,
|
||||||
|
expected: 5432,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid then valid number",
|
||||||
|
input: "abc\n123\n",
|
||||||
|
defaultValue: 0,
|
||||||
|
required: false,
|
||||||
|
expected: 123,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "required field empty then filled",
|
||||||
|
input: "\n999\n",
|
||||||
|
defaultValue: 0,
|
||||||
|
required: true,
|
||||||
|
expected: 999,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero value",
|
||||||
|
input: "0\n",
|
||||||
|
defaultValue: 10,
|
||||||
|
required: false,
|
||||||
|
expected: 0,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
reader := &mockReader{input: tt.input}
|
||||||
|
result, err := promptIntWithReader(reader, "Test Label", tt.defaultValue, tt.required)
|
||||||
|
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if result != tt.expected {
|
||||||
|
t.Errorf("Expected %d, got %d", tt.expected, result)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateDirectory(t *testing.T) {
|
||||||
|
// Create a temporary directory for testing
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
setup func() string
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "existing directory",
|
||||||
|
setup: func() string {
|
||||||
|
return tmpDir
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non-existent directory gets created",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(tmpDir, "newdir")
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nested directory gets created",
|
||||||
|
setup: func() string {
|
||||||
|
return filepath.Join(tmpDir, "level1", "level2", "level3")
|
||||||
|
},
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "file exists at path",
|
||||||
|
setup: func() string {
|
||||||
|
filePath := filepath.Join(tmpDir, "testfile")
|
||||||
|
os.WriteFile(filePath, []byte("test"), 0644)
|
||||||
|
return filePath
|
||||||
|
},
|
||||||
|
expectError: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
path := tt.setup()
|
||||||
|
err := ValidateDirectory(path)
|
||||||
|
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no error expected, verify directory exists
|
||||||
|
if !tt.expectError {
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Directory should exist: %v", err)
|
||||||
|
}
|
||||||
|
if !info.IsDir() {
|
||||||
|
t.Errorf("Path should be a directory")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrintFunctions(t *testing.T) {
|
||||||
|
// These functions write to stdout, so we just ensure they don't panic
|
||||||
|
// Testing colored output is complex due to terminal color codes
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
fn func()
|
||||||
|
}{
|
||||||
|
{"PrintHeader", func() { PrintHeader("Test Header") }},
|
||||||
|
{"PrintSuccess", func() { PrintSuccess("Success message") }},
|
||||||
|
{"PrintError", func() { PrintError("Error message") }},
|
||||||
|
{"PrintInfo", func() { PrintInfo("Info message") }},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Just ensure the function doesn't panic
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Errorf("Function panicked: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
tt.fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPromptDirectory(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
defaultValue string
|
||||||
|
required bool
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid directory path",
|
||||||
|
input: tmpDir + "\n",
|
||||||
|
defaultValue: "",
|
||||||
|
required: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "creates new directory",
|
||||||
|
input: filepath.Join(tmpDir, "newdir") + "\n",
|
||||||
|
defaultValue: "",
|
||||||
|
required: true,
|
||||||
|
expectError: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
// Note: This test would require refactoring PromptDirectory to accept a reader
|
||||||
|
// For now, we test ValidateDirectory which is the core logic
|
||||||
|
if tt.input != "" {
|
||||||
|
path := strings.TrimSpace(tt.input[:len(tt.input)-1])
|
||||||
|
err := ValidateDirectory(path)
|
||||||
|
if (err != nil) != tt.expectError {
|
||||||
|
t.Errorf("Expected error: %v, got: %v", tt.expectError, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark tests
|
||||||
|
func BenchmarkPromptStringWithReader(b *testing.B) {
|
||||||
|
reader := &mockReader{input: strings.Repeat("test\n", b.N)}
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
promptStringWithReader(reader, "Label", "default", false)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkPromptIntWithReader(b *testing.B) {
|
||||||
|
reader := &mockReader{input: strings.Repeat("42\n", b.N)}
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
promptIntWithReader(reader, "Label", 0, false)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user