First working version
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@ -1 +1,5 @@
|
|||||||
|
/build
|
||||||
/output
|
/output
|
||||||
|
|
||||||
|
*.toml
|
||||||
|
/entity-maker
|
||||||
|
|||||||
10
Makefile
Normal file
10
Makefile
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
EXEC=entity-maker
|
||||||
|
|
||||||
|
|
||||||
|
.PHONY: build
|
||||||
|
build:
|
||||||
|
@go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go
|
||||||
|
|
||||||
|
|
||||||
|
upgrade-packages:
|
||||||
|
@go get -u ./...
|
||||||
136
README.md
Normal file
136
README.md
Normal file
@ -0,0 +1,136 @@
|
|||||||
|
# Entity Maker
|
||||||
|
|
||||||
|
A command-line tool for generating SQLAlchemy entities from PostgreSQL database tables.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Connects to PostgreSQL database and introspects table structure
|
||||||
|
- Generates Python 3.13 code using SQLAlchemy
|
||||||
|
- Creates complete entity scaffolding including:
|
||||||
|
- SQLAlchemy table definition
|
||||||
|
- Dataclass model
|
||||||
|
- Filter class
|
||||||
|
- Load options
|
||||||
|
- Repository
|
||||||
|
- Manager
|
||||||
|
- Factory (for testing)
|
||||||
|
- Mapper
|
||||||
|
- Enum types (if table uses PostgreSQL enums)
|
||||||
|
|
||||||
|
## Installation
|
||||||
|
|
||||||
|
### Build from source
|
||||||
|
|
||||||
|
```bash
|
||||||
|
go build -o entity-maker ./cmd/entity-maker
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
### Interactive Mode
|
||||||
|
|
||||||
|
Run the application without arguments to use interactive prompts:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./entity-maker
|
||||||
|
```
|
||||||
|
|
||||||
|
The application will prompt you for:
|
||||||
|
- Database host (default: localhost)
|
||||||
|
- Database port (default: 5432)
|
||||||
|
- Database name (required)
|
||||||
|
- Database schema (default: public)
|
||||||
|
- Database user (default: postgres)
|
||||||
|
- Database password (default: postgres)
|
||||||
|
- Table name (required)
|
||||||
|
- Output directory (required)
|
||||||
|
- Entity name override (optional)
|
||||||
|
|
||||||
|
### Command-Line Arguments
|
||||||
|
|
||||||
|
You can also provide parameters via command-line flags:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./entity-maker \
|
||||||
|
-host localhost \
|
||||||
|
-port 5432 \
|
||||||
|
-db mydb \
|
||||||
|
-schema public \
|
||||||
|
-user postgres \
|
||||||
|
-password secret \
|
||||||
|
-table users \
|
||||||
|
-output ./output \
|
||||||
|
-entity User
|
||||||
|
```
|
||||||
|
|
||||||
|
### Available Flags
|
||||||
|
|
||||||
|
- `-host` - Database host
|
||||||
|
- `-port` - Database port
|
||||||
|
- `-db` - Database name
|
||||||
|
- `-schema` - Database schema
|
||||||
|
- `-user` - Database user
|
||||||
|
- `-password` - Database password
|
||||||
|
- `-table` - Table name to generate entities from
|
||||||
|
- `-output` - Output directory path
|
||||||
|
- `-entity` - Entity name override (optional, defaults to singularized table name)
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
The application saves your settings to `~/.config/entity-maker.toml` for future use. When you run the application again, it will use these saved values as defaults.
|
||||||
|
|
||||||
|
## Generated Files
|
||||||
|
|
||||||
|
For a table named `users`, the following files will be generated in `output/user/`:
|
||||||
|
|
||||||
|
- `table.py` - SQLAlchemy Table definition
|
||||||
|
- `model.py` - Dataclass model with type hints
|
||||||
|
- `filter.py` - Filter class for queries
|
||||||
|
- `load_options.py` - Relationship loading options
|
||||||
|
- `repository.py` - CRUD repository
|
||||||
|
- `manager.py` - Manager class
|
||||||
|
- `factory.py` - Factory for testing (uses factory_boy)
|
||||||
|
- `mapper.py` - SQLAlchemy mapper configuration
|
||||||
|
- `enum.py` - Enum types (only if table uses PostgreSQL enums)
|
||||||
|
- `__init__.py` - Package initialization
|
||||||
|
|
||||||
|
## Example
|
||||||
|
|
||||||
|
Generate entities for the `cashbag_conforms` table:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
./entity-maker -db mydb -table cashbag_conforms -output ./output
|
||||||
|
```
|
||||||
|
|
||||||
|
This will create files in `./output/cashbag_conform/`:
|
||||||
|
- Entity name: `CashbagConform`
|
||||||
|
- Module name: `cashbag_conform`
|
||||||
|
- All supporting files
|
||||||
|
|
||||||
|
## Requirements
|
||||||
|
|
||||||
|
- PostgreSQL database
|
||||||
|
- Go 1.21+ (for building)
|
||||||
|
- Python 3.13+ (for generated code)
|
||||||
|
- SQLAlchemy (for generated code)
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
entity-maker/
|
||||||
|
├── cmd/
|
||||||
|
│ └── entity-maker/ # Main application
|
||||||
|
├── internal/
|
||||||
|
│ ├── config/ # Configuration management
|
||||||
|
│ ├── database/ # Database connection and introspection
|
||||||
|
│ ├── generator/ # Code generators
|
||||||
|
│ ├── naming/ # Naming utilities (pluralization, case conversion)
|
||||||
|
│ └── prompt/ # Interactive CLI prompts
|
||||||
|
├── example/ # Example output files
|
||||||
|
├── CLAUDE.md # Detailed specification
|
||||||
|
└── README.md # This file
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
See LICENSE file for details.
|
||||||
240
cmd/entity-maker/main.go
Normal file
240
cmd/entity-maker/main.go
Normal file
@ -0,0 +1,240 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/config"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/generator"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/prompt"
|
||||||
|
"github.com/fatih/color"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
if err := run(); err != nil {
|
||||||
|
prompt.PrintError(err.Error())
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func run() error {
|
||||||
|
// Parse command line flags
|
||||||
|
var (
|
||||||
|
dbHost = flag.String("host", "", "Database host")
|
||||||
|
dbPort = flag.Int("port", 0, "Database port")
|
||||||
|
dbName = flag.String("db", "", "Database name")
|
||||||
|
dbSchema = flag.String("schema", "", "Database schema")
|
||||||
|
dbUser = flag.String("user", "", "Database user")
|
||||||
|
dbPassword = flag.String("password", "", "Database password")
|
||||||
|
dbTable = flag.String("table", "", "Database table name")
|
||||||
|
outputDir = flag.String("output", "", "Output directory")
|
||||||
|
entityName = flag.String("entity", "", "Entity name override")
|
||||||
|
)
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
// Print header
|
||||||
|
header := color.New(color.FgCyan, color.Bold)
|
||||||
|
header.Println("\n╔════════════════════════════════════════╗")
|
||||||
|
header.Println("║ Entity Maker for TelevendCore ║")
|
||||||
|
header.Println("╚════════════════════════════════════════╝")
|
||||||
|
|
||||||
|
// Load configuration
|
||||||
|
cfg, err := config.Load()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to load config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prompt for parameters
|
||||||
|
prompt.PrintHeader("Database Connection")
|
||||||
|
|
||||||
|
// Override config with command line flags if provided
|
||||||
|
if *dbHost != "" {
|
||||||
|
cfg.DBHost = *dbHost
|
||||||
|
}
|
||||||
|
if *dbPort != 0 {
|
||||||
|
cfg.DBPort = *dbPort
|
||||||
|
}
|
||||||
|
if *dbName != "" {
|
||||||
|
cfg.DBName = *dbName
|
||||||
|
}
|
||||||
|
if *dbSchema != "" {
|
||||||
|
cfg.DBSchema = *dbSchema
|
||||||
|
}
|
||||||
|
if *dbUser != "" {
|
||||||
|
cfg.DBUser = *dbUser
|
||||||
|
}
|
||||||
|
if *dbPassword != "" {
|
||||||
|
cfg.DBPassword = *dbPassword
|
||||||
|
}
|
||||||
|
if *dbTable != "" {
|
||||||
|
cfg.DBTable = *dbTable
|
||||||
|
}
|
||||||
|
if *outputDir != "" {
|
||||||
|
cfg.OutputDir = *outputDir
|
||||||
|
}
|
||||||
|
if *entityName != "" {
|
||||||
|
cfg.EntityName = *entityName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prompt for missing parameters
|
||||||
|
cfg.DBHost, err = prompt.PromptString("Database host", cfg.DBHost, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBPort, err = prompt.PromptInt("Database port", cfg.DBPort, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBName, err = prompt.PromptString("Database name", cfg.DBName, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBSchema, err = prompt.PromptString("Database schema", cfg.DBSchema, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBUser, err = prompt.PromptString("Database user", cfg.DBUser, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBPassword, err = prompt.PromptString("Database password", cfg.DBPassword, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DBTable, err = prompt.PromptString("Table name", cfg.DBTable, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.OutputDir, err = prompt.PromptDirectory("Output directory", cfg.OutputDir, true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.EntityName, err = prompt.PromptString("Entity name (optional)", cfg.EntityName, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save configuration
|
||||||
|
if err := cfg.Save(); err != nil {
|
||||||
|
prompt.PrintError(fmt.Sprintf("Failed to save config: %v", err))
|
||||||
|
// Don't fail, just warn
|
||||||
|
} else {
|
||||||
|
prompt.PrintSuccess("Configuration saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to database
|
||||||
|
prompt.PrintHeader("Connecting to Database")
|
||||||
|
|
||||||
|
dbClient, err := database.NewClient(database.Config{
|
||||||
|
Host: cfg.DBHost,
|
||||||
|
Port: cfg.DBPort,
|
||||||
|
Database: cfg.DBName,
|
||||||
|
Schema: cfg.DBSchema,
|
||||||
|
User: cfg.DBUser,
|
||||||
|
Password: cfg.DBPassword,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("database connection failed: %w", err)
|
||||||
|
}
|
||||||
|
defer dbClient.Close()
|
||||||
|
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Connected to %s@%s:%d/%s", cfg.DBUser, cfg.DBHost, cfg.DBPort, cfg.DBName))
|
||||||
|
|
||||||
|
// Check if table exists
|
||||||
|
exists, err := dbClient.TableExists(cfg.DBTable)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to check table existence: %w", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
return fmt.Errorf("table '%s' does not exist in schema '%s'", cfg.DBTable, cfg.DBSchema)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Found table '%s'", cfg.DBTable))
|
||||||
|
|
||||||
|
// Introspect table
|
||||||
|
prompt.PrintHeader("Introspecting Table")
|
||||||
|
|
||||||
|
tableInfo, err := dbClient.IntrospectTable(cfg.DBTable)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to introspect table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Found %d columns", len(tableInfo.Columns)))
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Found %d foreign keys", len(tableInfo.ForeignKeys)))
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Found %d enum types", len(tableInfo.EnumTypes)))
|
||||||
|
|
||||||
|
// Create generation context
|
||||||
|
ctx := generator.NewContext(tableInfo, cfg.EntityName)
|
||||||
|
|
||||||
|
// Create output directory
|
||||||
|
moduleName := naming.SingularizeTableName(cfg.DBTable)
|
||||||
|
moduleDir := filepath.Join(cfg.OutputDir, moduleName)
|
||||||
|
|
||||||
|
if err := os.MkdirAll(moduleDir, 0755); err != nil {
|
||||||
|
return fmt.Errorf("failed to create output directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Created directory: %s", moduleDir))
|
||||||
|
|
||||||
|
// Generate files
|
||||||
|
prompt.PrintHeader("Generating Files")
|
||||||
|
|
||||||
|
files := map[string]func(*generator.Context) (string, error){
|
||||||
|
"table.py": generator.GenerateTable,
|
||||||
|
"model.py": generator.GenerateModel,
|
||||||
|
"filter.py": generator.GenerateFilter,
|
||||||
|
"load_options.py": generator.GenerateLoadOptions,
|
||||||
|
"repository.py": generator.GenerateRepository,
|
||||||
|
"manager.py": generator.GenerateManager,
|
||||||
|
"factory.py": generator.GenerateFactory,
|
||||||
|
"mapper.py": generator.GenerateMapper,
|
||||||
|
"__init__.py": generator.GenerateInit,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add enum.py if there are enum types
|
||||||
|
if len(tableInfo.EnumTypes) > 0 {
|
||||||
|
files["enum.py"] = generator.GenerateEnum
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate and write each file
|
||||||
|
for filename, genFunc := range files {
|
||||||
|
content, err := genFunc(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := filepath.Join(moduleDir, filename)
|
||||||
|
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.PrintSuccess(fmt.Sprintf("Generated %s", filename))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print summary
|
||||||
|
prompt.PrintHeader("Summary")
|
||||||
|
fmt.Printf("\n")
|
||||||
|
fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName))
|
||||||
|
fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName))
|
||||||
|
fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir))
|
||||||
|
fmt.Printf(" Files: %s\n", color.GreenString("%d files generated", len(files)))
|
||||||
|
fmt.Printf("\n")
|
||||||
|
|
||||||
|
success := color.New(color.FgGreen, color.Bold)
|
||||||
|
success.Println("✓ Code generation completed successfully!")
|
||||||
|
fmt.Println()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
15
go.mod
Normal file
15
go.mod
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
module github.com/entity-maker/entity-maker
|
||||||
|
|
||||||
|
go 1.25.3
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/BurntSushi/toml v1.5.0
|
||||||
|
github.com/fatih/color v1.18.0
|
||||||
|
github.com/lib/pq v1.10.9
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/mattn/go-colorable v0.1.13 // indirect
|
||||||
|
github.com/mattn/go-isatty v0.0.20 // indirect
|
||||||
|
golang.org/x/sys v0.25.0 // indirect
|
||||||
|
)
|
||||||
15
go.sum
Normal file
15
go.sum
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
|
||||||
|
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
|
||||||
|
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
|
||||||
|
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
|
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
|
||||||
|
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
|
||||||
|
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
|
||||||
|
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
|
||||||
|
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
|
||||||
|
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||||
|
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
|
||||||
|
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
82
internal/config/config.go
Normal file
82
internal/config/config.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/BurntSushi/toml"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config represents the application configuration
|
||||||
|
type Config struct {
|
||||||
|
DBHost string `toml:"db_host"`
|
||||||
|
DBPort int `toml:"db_port"`
|
||||||
|
DBName string `toml:"db_name"`
|
||||||
|
DBSchema string `toml:"db_schema"`
|
||||||
|
DBUser string `toml:"db_user"`
|
||||||
|
DBPassword string `toml:"db_password"`
|
||||||
|
DBTable string `toml:"db_table"`
|
||||||
|
OutputDir string `toml:"output_dir"`
|
||||||
|
EntityName string `toml:"entity_name"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultConfig returns a new config with default values
|
||||||
|
func DefaultConfig() *Config {
|
||||||
|
return &Config{
|
||||||
|
DBHost: "localhost",
|
||||||
|
DBPort: 5432,
|
||||||
|
DBSchema: "public",
|
||||||
|
DBUser: "postgres",
|
||||||
|
DBPassword: "postgres",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load reads the configuration file from ~/.config/entity-maker.toml
|
||||||
|
func Load() (*Config, error) {
|
||||||
|
configPath := getConfigPath()
|
||||||
|
|
||||||
|
config := DefaultConfig()
|
||||||
|
|
||||||
|
// Check if config file exists
|
||||||
|
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and decode the config file
|
||||||
|
if _, err := toml.DecodeFile(configPath, config); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save writes the configuration to ~/.config/entity-maker.toml
|
||||||
|
func (c *Config) Save() error {
|
||||||
|
configPath := getConfigPath()
|
||||||
|
|
||||||
|
// Create config directory if it doesn't exist
|
||||||
|
configDir := filepath.Dir(configPath)
|
||||||
|
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create or truncate the config file
|
||||||
|
f, err := os.Create(configPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
// Encode and write the config
|
||||||
|
encoder := toml.NewEncoder(f)
|
||||||
|
return encoder.Encode(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getConfigPath returns the full path to the config file
|
||||||
|
func getConfigPath() string {
|
||||||
|
homeDir, err := os.UserHomeDir()
|
||||||
|
if err != nil {
|
||||||
|
homeDir = "."
|
||||||
|
}
|
||||||
|
return filepath.Join(homeDir, ".config", "entity-maker.toml")
|
||||||
|
}
|
||||||
67
internal/database/client.go
Normal file
67
internal/database/client.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Client represents a database connection
|
||||||
|
type Client struct {
|
||||||
|
db *sql.DB
|
||||||
|
schema string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config holds database connection parameters
|
||||||
|
type Config struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Database string
|
||||||
|
Schema string
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewClient creates a new database client
|
||||||
|
func NewClient(cfg Config) (*Client, error) {
|
||||||
|
connStr := fmt.Sprintf(
|
||||||
|
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||||
|
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database,
|
||||||
|
)
|
||||||
|
|
||||||
|
db, err := sql.Open("postgres", connStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := db.Ping(); err != nil {
|
||||||
|
db.Close()
|
||||||
|
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
db: db,
|
||||||
|
schema: cfg.Schema,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes the database connection
|
||||||
|
func (c *Client) Close() error {
|
||||||
|
return c.db.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableExists checks if a table exists in the schema
|
||||||
|
func (c *Client) TableExists(tableName string) (bool, error) {
|
||||||
|
query := `
|
||||||
|
SELECT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM information_schema.tables
|
||||||
|
WHERE table_schema = $1 AND table_name = $2
|
||||||
|
)
|
||||||
|
`
|
||||||
|
|
||||||
|
var exists bool
|
||||||
|
err := c.db.QueryRow(query, c.schema, tableName).Scan(&exists)
|
||||||
|
return exists, err
|
||||||
|
}
|
||||||
341
internal/database/introspector.go
Normal file
341
internal/database/introspector.go
Normal file
@ -0,0 +1,341 @@
|
|||||||
|
package database
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Column represents a database column
|
||||||
|
type Column struct {
|
||||||
|
Name string
|
||||||
|
DataType string
|
||||||
|
IsNullable bool
|
||||||
|
ColumnDefault sql.NullString
|
||||||
|
CharMaxLength sql.NullInt64
|
||||||
|
NumericPrecision sql.NullInt64
|
||||||
|
NumericScale sql.NullInt64
|
||||||
|
UdtName string // User-defined type name (for enums)
|
||||||
|
IsPrimaryKey bool
|
||||||
|
IsAutoIncrement bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ForeignKey represents a foreign key constraint
|
||||||
|
type ForeignKey struct {
|
||||||
|
ColumnName string
|
||||||
|
ForeignTableSchema string
|
||||||
|
ForeignTableName string
|
||||||
|
ForeignColumnName string
|
||||||
|
ConstraintName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// EnumType represents a PostgreSQL enum type
|
||||||
|
type EnumType struct {
|
||||||
|
TypeName string
|
||||||
|
Values []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// TableInfo contains all information about a table
|
||||||
|
type TableInfo struct {
|
||||||
|
Schema string
|
||||||
|
TableName string
|
||||||
|
Columns []Column
|
||||||
|
ForeignKeys []ForeignKey
|
||||||
|
EnumTypes map[string]EnumType
|
||||||
|
}
|
||||||
|
|
||||||
|
// IntrospectTable retrieves comprehensive information about a table
|
||||||
|
func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) {
|
||||||
|
info := &TableInfo{
|
||||||
|
Schema: c.schema,
|
||||||
|
TableName: tableName,
|
||||||
|
EnumTypes: make(map[string]EnumType),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get columns
|
||||||
|
columns, err := c.getColumns(tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
info.Columns = columns
|
||||||
|
|
||||||
|
// Get primary keys and mark columns
|
||||||
|
primaryKeys, err := c.getPrimaryKeys(tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for i := range info.Columns {
|
||||||
|
for _, pk := range primaryKeys {
|
||||||
|
if info.Columns[i].Name == pk {
|
||||||
|
info.Columns[i].IsPrimaryKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get foreign keys
|
||||||
|
foreignKeys, err := c.getForeignKeys(tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
info.ForeignKeys = foreignKeys
|
||||||
|
|
||||||
|
// Get enum types for columns that use them
|
||||||
|
for _, col := range info.Columns {
|
||||||
|
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||||
|
if _, exists := info.EnumTypes[col.UdtName]; !exists {
|
||||||
|
enumType, err := c.getEnumType(col.UdtName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
info.EnumTypes[col.UdtName] = enumType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getColumns retrieves column information for a table
|
||||||
|
func (c *Client) getColumns(tableName string) ([]Column, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
column_name,
|
||||||
|
data_type,
|
||||||
|
is_nullable = 'YES' as is_nullable,
|
||||||
|
column_default,
|
||||||
|
character_maximum_length,
|
||||||
|
numeric_precision,
|
||||||
|
numeric_scale,
|
||||||
|
udt_name
|
||||||
|
FROM information_schema.columns
|
||||||
|
WHERE table_schema = $1 AND table_name = $2
|
||||||
|
ORDER BY ordinal_position
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := c.db.Query(query, c.schema, tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var columns []Column
|
||||||
|
for rows.Next() {
|
||||||
|
var col Column
|
||||||
|
err := rows.Scan(
|
||||||
|
&col.Name,
|
||||||
|
&col.DataType,
|
||||||
|
&col.IsNullable,
|
||||||
|
&col.ColumnDefault,
|
||||||
|
&col.CharMaxLength,
|
||||||
|
&col.NumericPrecision,
|
||||||
|
&col.NumericScale,
|
||||||
|
&col.UdtName,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if column is auto-increment
|
||||||
|
if col.ColumnDefault.Valid {
|
||||||
|
defaultVal := strings.ToLower(col.ColumnDefault.String)
|
||||||
|
col.IsAutoIncrement = strings.Contains(defaultVal, "nextval")
|
||||||
|
}
|
||||||
|
|
||||||
|
columns = append(columns, col)
|
||||||
|
}
|
||||||
|
|
||||||
|
return columns, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPrimaryKeys retrieves primary key column names for a table
|
||||||
|
func (c *Client) getPrimaryKeys(tableName string) ([]string, error) {
|
||||||
|
query := `
|
||||||
|
SELECT a.attname
|
||||||
|
FROM pg_index i
|
||||||
|
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||||
|
WHERE i.indrelid = ($1 || '.' || $2)::regclass
|
||||||
|
AND i.indisprimary
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := c.db.Query(query, c.schema, tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query primary keys: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
for rows.Next() {
|
||||||
|
var key string
|
||||||
|
if err := rows.Scan(&key); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getForeignKeys retrieves foreign key information for a table
|
||||||
|
func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) {
|
||||||
|
query := `
|
||||||
|
SELECT
|
||||||
|
kcu.column_name,
|
||||||
|
ccu.table_schema AS foreign_table_schema,
|
||||||
|
ccu.table_name AS foreign_table_name,
|
||||||
|
ccu.column_name AS foreign_column_name,
|
||||||
|
tc.constraint_name
|
||||||
|
FROM information_schema.table_constraints AS tc
|
||||||
|
JOIN information_schema.key_column_usage AS kcu
|
||||||
|
ON tc.constraint_name = kcu.constraint_name
|
||||||
|
AND tc.table_schema = kcu.table_schema
|
||||||
|
JOIN information_schema.constraint_column_usage AS ccu
|
||||||
|
ON ccu.constraint_name = tc.constraint_name
|
||||||
|
AND ccu.table_schema = tc.table_schema
|
||||||
|
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||||
|
AND tc.table_schema = $1
|
||||||
|
AND tc.table_name = $2
|
||||||
|
ORDER BY kcu.ordinal_position
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := c.db.Query(query, c.schema, tableName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to query foreign keys: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var fks []ForeignKey
|
||||||
|
for rows.Next() {
|
||||||
|
var fk ForeignKey
|
||||||
|
err := rows.Scan(
|
||||||
|
&fk.ColumnName,
|
||||||
|
&fk.ForeignTableSchema,
|
||||||
|
&fk.ForeignTableName,
|
||||||
|
&fk.ForeignColumnName,
|
||||||
|
&fk.ConstraintName,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
fks = append(fks, fk)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fks, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEnumType retrieves enum values for a PostgreSQL enum type
|
||||||
|
func (c *Client) getEnumType(typeName string) (EnumType, error) {
|
||||||
|
query := `
|
||||||
|
SELECT e.enumlabel
|
||||||
|
FROM pg_type t
|
||||||
|
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||||
|
WHERE t.typname = $1
|
||||||
|
ORDER BY e.enumsortorder
|
||||||
|
`
|
||||||
|
|
||||||
|
rows, err := c.db.Query(query, typeName)
|
||||||
|
if err != nil {
|
||||||
|
return EnumType{}, fmt.Errorf("failed to query enum type: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var values []string
|
||||||
|
for rows.Next() {
|
||||||
|
var value string
|
||||||
|
if err := rows.Scan(&value); err != nil {
|
||||||
|
return EnumType{}, err
|
||||||
|
}
|
||||||
|
values = append(values, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
return EnumType{
|
||||||
|
TypeName: typeName,
|
||||||
|
Values: values,
|
||||||
|
}, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPythonType converts PostgreSQL data type to Python type
|
||||||
|
func GetPythonType(col Column) string {
|
||||||
|
switch col.DataType {
|
||||||
|
case "integer", "smallint", "bigint":
|
||||||
|
return "int"
|
||||||
|
case "numeric", "decimal", "real", "double precision":
|
||||||
|
return "Decimal"
|
||||||
|
case "boolean":
|
||||||
|
return "bool"
|
||||||
|
case "character varying", "varchar", "text", "char", "character":
|
||||||
|
return "str"
|
||||||
|
case "timestamp with time zone", "timestamp without time zone", "timestamp":
|
||||||
|
return "datetime"
|
||||||
|
case "date":
|
||||||
|
return "date"
|
||||||
|
case "time with time zone", "time without time zone", "time":
|
||||||
|
return "time"
|
||||||
|
case "json", "jsonb":
|
||||||
|
return "dict"
|
||||||
|
case "uuid":
|
||||||
|
return "UUID"
|
||||||
|
case "bytea":
|
||||||
|
return "bytes"
|
||||||
|
case "USER-DEFINED":
|
||||||
|
// This is likely an enum
|
||||||
|
return "str" // Will be replaced with specific enum type in generator
|
||||||
|
default:
|
||||||
|
return "Any"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type
|
||||||
|
func GetSQLAlchemyType(col Column) string {
|
||||||
|
switch col.DataType {
|
||||||
|
case "integer":
|
||||||
|
return "Integer"
|
||||||
|
case "smallint":
|
||||||
|
return "SmallInteger"
|
||||||
|
case "bigint":
|
||||||
|
return "BigInteger"
|
||||||
|
case "numeric", "decimal":
|
||||||
|
if col.NumericPrecision.Valid && col.NumericScale.Valid {
|
||||||
|
return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64)
|
||||||
|
}
|
||||||
|
return "Numeric"
|
||||||
|
case "real":
|
||||||
|
return "Float"
|
||||||
|
case "double precision":
|
||||||
|
return "Float"
|
||||||
|
case "boolean":
|
||||||
|
return "Boolean"
|
||||||
|
case "character varying", "varchar":
|
||||||
|
if col.CharMaxLength.Valid {
|
||||||
|
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||||
|
}
|
||||||
|
return "String"
|
||||||
|
case "char", "character":
|
||||||
|
if col.CharMaxLength.Valid {
|
||||||
|
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||||
|
}
|
||||||
|
return "String(1)"
|
||||||
|
case "text":
|
||||||
|
return "Text"
|
||||||
|
case "timestamp with time zone":
|
||||||
|
return "DateTime(timezone=True)"
|
||||||
|
case "timestamp without time zone", "timestamp":
|
||||||
|
return "DateTime"
|
||||||
|
case "date":
|
||||||
|
return "Date"
|
||||||
|
case "time with time zone", "time without time zone", "time":
|
||||||
|
return "Time"
|
||||||
|
case "json":
|
||||||
|
return "JSON"
|
||||||
|
case "jsonb":
|
||||||
|
return "JSONB"
|
||||||
|
case "uuid":
|
||||||
|
return "UUID"
|
||||||
|
case "bytea":
|
||||||
|
return "LargeBinary"
|
||||||
|
case "USER-DEFINED":
|
||||||
|
// Will be handled separately for enums
|
||||||
|
return "Enum"
|
||||||
|
default:
|
||||||
|
return "String"
|
||||||
|
}
|
||||||
|
}
|
||||||
40
internal/generator/enum.go
Normal file
40
internal/generator/enum.go
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateEnum generates the enum types file
|
||||||
|
func GenerateEnum(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from enum import StrEnum\n\n")
|
||||||
|
b.WriteString("from televend_core.databases.enum import EnumMixin\n\n\n")
|
||||||
|
|
||||||
|
// Generate each enum type
|
||||||
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||||
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf("class %s(EnumMixin, StrEnum):\n", enumName))
|
||||||
|
|
||||||
|
if len(enumType.Values) == 0 {
|
||||||
|
b.WriteString(" pass\n")
|
||||||
|
} else {
|
||||||
|
for _, value := range enumType.Values {
|
||||||
|
// Convert value to valid Python identifier
|
||||||
|
// Usually enum values are already uppercase like "OPEN", "IN_PROGRESS"
|
||||||
|
identifier := strings.ToUpper(strings.ReplaceAll(value, " ", "_"))
|
||||||
|
identifier = strings.ReplaceAll(identifier, "-", "_")
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = \"%s\"\n", identifier, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
203
internal/generator/factory.go
Normal file
203
internal/generator/factory.go
Normal file
@ -0,0 +1,203 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateFactory generates the factory class
|
||||||
|
func GenerateFactory(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from __future__ import annotations\n\n")
|
||||||
|
b.WriteString("from typing import Type\n\n")
|
||||||
|
b.WriteString("import factory\n\n")
|
||||||
|
|
||||||
|
// Import factories for related models
|
||||||
|
fkImports := make(map[string]string) // module_name -> entity_name
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
fkImports[moduleName] = entityName
|
||||||
|
}
|
||||||
|
|
||||||
|
for moduleName, entityName := range fkImports {
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
|
||||||
|
moduleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
||||||
|
b.WriteString(" CustomSelfAttribute,\n")
|
||||||
|
b.WriteString(" TelevendBaseFactory,\n")
|
||||||
|
b.WriteString(")\n\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
|
||||||
|
|
||||||
|
// Add boolean fields with defaults
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
if col.DataType == "boolean" {
|
||||||
|
defaultValue := "True"
|
||||||
|
if col.Name == "alive" {
|
||||||
|
defaultValue = "True"
|
||||||
|
} else {
|
||||||
|
defaultValue = "False"
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add id field
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = None\n", col.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// Generate faker fields for each column
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
if col.IsPrimaryKey || col.DataType == "boolean" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Skip foreign keys, we'll handle them separately
|
||||||
|
if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fakerDef := generateFakerField(col, ctx)
|
||||||
|
if fakerDef != "" {
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate foreign key relationships
|
||||||
|
if len(ctx.TableInfo.ForeignKeys) > 0 {
|
||||||
|
b.WriteString("\n")
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
relationName := GetRelationshipName(fk.ColumnName)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n",
|
||||||
|
relationName, relationName, entityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n",
|
||||||
|
fk.ColumnName, relationName, relationName))
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Meta class
|
||||||
|
b.WriteString(" class Meta:\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName))
|
||||||
|
b.WriteString("\n")
|
||||||
|
|
||||||
|
// create_minimal method
|
||||||
|
b.WriteString(" @classmethod\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n",
|
||||||
|
ctx.EntityName, ctx.EntityName))
|
||||||
|
b.WriteString(" minimal_params = {\n")
|
||||||
|
|
||||||
|
// Add foreign key params
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
relationName := GetRelationshipName(fk.ColumnName)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
|
||||||
|
// Check if this FK is required
|
||||||
|
var col *database.Column
|
||||||
|
for i := range ctx.TableInfo.Columns {
|
||||||
|
if ctx.TableInfo.Columns[i].Name == fk.ColumnName {
|
||||||
|
col = &ctx.TableInfo.Columns[i]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if col != nil && !col.IsNullable {
|
||||||
|
// Required FK
|
||||||
|
b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n",
|
||||||
|
relationName, relationName, entityName))
|
||||||
|
} else {
|
||||||
|
// Optional FK
|
||||||
|
b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(" }\n")
|
||||||
|
b.WriteString(" minimal_params.update(kwargs)\n")
|
||||||
|
b.WriteString(" return cls.create(**minimal_params)\n")
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateFakerField(col database.Column, ctx *Context) string {
|
||||||
|
pythonType := database.GetPythonType(col)
|
||||||
|
|
||||||
|
switch pythonType {
|
||||||
|
case "Decimal":
|
||||||
|
precision := 7
|
||||||
|
scale := 4
|
||||||
|
if col.NumericPrecision.Valid {
|
||||||
|
precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64)
|
||||||
|
}
|
||||||
|
if col.NumericScale.Valid {
|
||||||
|
scale = int(col.NumericScale.Int64)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)",
|
||||||
|
precision, scale)
|
||||||
|
|
||||||
|
case "int":
|
||||||
|
if col.DataType == "bigint" {
|
||||||
|
return "factory.Faker(\"pyint\")"
|
||||||
|
}
|
||||||
|
return "factory.Faker(\"pyint\")"
|
||||||
|
|
||||||
|
case "str":
|
||||||
|
// Check if it's an enum
|
||||||
|
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||||
|
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||||
|
if len(enumType.Values) > 0 {
|
||||||
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
|
return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
maxLen := 255
|
||||||
|
if col.CharMaxLength.Valid {
|
||||||
|
maxLen = int(col.CharMaxLength.Int64)
|
||||||
|
}
|
||||||
|
if col.DataType == "text" {
|
||||||
|
return "factory.Faker(\"text\", max_nb_chars=500)"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen)
|
||||||
|
|
||||||
|
case "datetime":
|
||||||
|
return "factory.Faker(\"date_time\")"
|
||||||
|
|
||||||
|
case "date":
|
||||||
|
return "factory.Faker(\"date\")"
|
||||||
|
|
||||||
|
case "time":
|
||||||
|
return "factory.Faker(\"time\")"
|
||||||
|
|
||||||
|
case "bool":
|
||||||
|
return "factory.Faker(\"boolean\")"
|
||||||
|
|
||||||
|
case "dict":
|
||||||
|
return "factory.Faker(\"pydict\")"
|
||||||
|
|
||||||
|
default:
|
||||||
|
return "None"
|
||||||
|
}
|
||||||
|
}
|
||||||
77
internal/generator/filter.go
Normal file
77
internal/generator/filter.go
Normal file
@ -0,0 +1,77 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateFilter generates the filter class
|
||||||
|
func GenerateFilter(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
|
||||||
|
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||||
|
ctx.ModuleName, ctx.EntityName))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString(fmt.Sprintf("class %sFilter(BaseFilter):\n", ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||||
|
|
||||||
|
// Generate filters based on rules:
|
||||||
|
// - Boolean fields: EQ operator
|
||||||
|
// - ID fields: both EQ and IN operators
|
||||||
|
// - Text fields: EQ operator
|
||||||
|
|
||||||
|
hasFilters := false
|
||||||
|
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
if !ShouldGenerateFilter(col) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
hasFilters = true
|
||||||
|
|
||||||
|
// Boolean fields
|
||||||
|
if col.DataType == "boolean" {
|
||||||
|
defaultVal := "None"
|
||||||
|
if col.Name == "alive" {
|
||||||
|
defaultVal = "True"
|
||||||
|
}
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: bool | None = filterfield(operator=EQ, default=%s)\n",
|
||||||
|
col.Name, defaultVal))
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID fields (both EQ and IN)
|
||||||
|
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||||
|
// Single ID with EQ
|
||||||
|
pythonType := database.GetPythonType(col)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s | None = filterfield(operator=EQ)\n",
|
||||||
|
col.Name, pythonType))
|
||||||
|
|
||||||
|
// Multiple IDs with IN (plural field name)
|
||||||
|
pluralFieldName := GetFilterFieldName(col.Name, true)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: list[%s] | None = filterfield(field=\"%s\", operator=IN)\n",
|
||||||
|
pluralFieldName, pythonType, col.Name))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Text fields
|
||||||
|
if (col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||||
|
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
|
||||||
|
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
|
||||||
|
col.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no filters were generated, add a pass statement
|
||||||
|
if !hasFilters {
|
||||||
|
b.WriteString(" pass\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
195
internal/generator/generator.go
Normal file
195
internal/generator/generator.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Context contains all information needed for code generation
|
||||||
|
type Context struct {
|
||||||
|
TableInfo *database.TableInfo
|
||||||
|
EntityName string // Singular, PascalCase (e.g., "CashbagConform")
|
||||||
|
ModuleName string // Singular, snake_case (e.g., "cashbag_conform")
|
||||||
|
TableConstant string // Uppercase with TABLE suffix (e.g., "CASHBAG_CONFORM_TABLE")
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContext creates a new generation context
|
||||||
|
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
|
||||||
|
moduleName := naming.SingularizeTableName(tableInfo.TableName)
|
||||||
|
|
||||||
|
entityName := entityNameOverride
|
||||||
|
if entityName == "" {
|
||||||
|
entityName = naming.ToPascalCase(moduleName)
|
||||||
|
}
|
||||||
|
|
||||||
|
tableConstant := strings.ToUpper(moduleName) + "_TABLE"
|
||||||
|
|
||||||
|
return &Context{
|
||||||
|
TableInfo: tableInfo,
|
||||||
|
EntityName: entityName,
|
||||||
|
ModuleName: moduleName,
|
||||||
|
TableConstant: tableConstant,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationshipName returns the relationship name for a foreign key
|
||||||
|
// Strips _id suffix and converts to snake_case
|
||||||
|
func GetRelationshipName(fkColumnName string) string {
|
||||||
|
name := fkColumnName
|
||||||
|
if strings.HasSuffix(name, "_id") {
|
||||||
|
name = name[:len(name)-3]
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationshipEntityName returns the entity name for a foreign key's target table
|
||||||
|
func GetRelationshipEntityName(tableName string) string {
|
||||||
|
singular := naming.SingularizeTableName(tableName)
|
||||||
|
return naming.ToPascalCase(singular)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRelationshipModuleName returns the module name for a foreign key's target table
|
||||||
|
func GetRelationshipModuleName(tableName string) string {
|
||||||
|
return naming.SingularizeTableName(tableName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFilterFieldName returns the filter field name for a column
|
||||||
|
// For ID fields with IN operator, pluralizes the name
|
||||||
|
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||||
|
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||||
|
// Remove _id, pluralize, add back _ids
|
||||||
|
base := columnName[:len(columnName)-3]
|
||||||
|
if base == "" {
|
||||||
|
return "ids"
|
||||||
|
}
|
||||||
|
return naming.Pluralize(base) + "_ids"
|
||||||
|
}
|
||||||
|
if useIN && columnName == "id" {
|
||||||
|
return "ids"
|
||||||
|
}
|
||||||
|
return columnName
|
||||||
|
}
|
||||||
|
|
||||||
|
// ShouldGenerateFilter determines if a column should have a filter
|
||||||
|
func ShouldGenerateFilter(col database.Column) bool {
|
||||||
|
// Generate filters for boolean, ID, and text fields
|
||||||
|
if col.DataType == "boolean" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||||
|
col.DataType == "text" || col.DataType == "char" || col.DataType == "character" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPythonTypeForColumn returns the Python type annotation for a column
|
||||||
|
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
|
||||||
|
baseType := database.GetPythonType(col)
|
||||||
|
|
||||||
|
// Handle enum types
|
||||||
|
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||||
|
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||||
|
baseType = naming.ToPascalCase(enumType.TypeName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle Decimal type
|
||||||
|
if baseType == "Decimal" {
|
||||||
|
baseType = "Decimal"
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseType
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsDecimalImport checks if any column uses Decimal type
|
||||||
|
func NeedsDecimalImport(columns []database.Column) bool {
|
||||||
|
for _, col := range columns {
|
||||||
|
if database.GetPythonType(col) == "Decimal" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// NeedsDatetimeImport checks if any column uses datetime type
|
||||||
|
func NeedsDatetimeImport(columns []database.Column) bool {
|
||||||
|
for _, col := range columns {
|
||||||
|
pyType := database.GetPythonType(col)
|
||||||
|
if pyType == "datetime" || pyType == "date" || pyType == "time" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRequiredColumns returns columns that are not nullable and don't have defaults
|
||||||
|
func GetRequiredColumns(columns []database.Column) []database.Column {
|
||||||
|
var required []database.Column
|
||||||
|
for _, col := range columns {
|
||||||
|
if !col.IsNullable && !col.ColumnDefault.Valid && !col.IsAutoIncrement && !col.IsPrimaryKey {
|
||||||
|
required = append(required, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return required
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOptionalColumns returns columns that are nullable or have defaults
|
||||||
|
func GetOptionalColumns(columns []database.Column) []database.Column {
|
||||||
|
var optional []database.Column
|
||||||
|
for _, col := range columns {
|
||||||
|
if col.IsNullable || col.ColumnDefault.Valid {
|
||||||
|
optional = append(optional, col)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return optional
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetForeignKeyForColumn returns the foreign key info for a column, if it exists
|
||||||
|
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
|
||||||
|
for _, fk := range foreignKeys {
|
||||||
|
if fk.ColumnName == columnName {
|
||||||
|
return &fk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateFiles generates all Python files for the entity
|
||||||
|
func GenerateFiles(ctx *Context, outputDir string) error {
|
||||||
|
generators := map[string]func(*Context) (string, error){
|
||||||
|
"table.py": GenerateTable,
|
||||||
|
"model.py": GenerateModel,
|
||||||
|
"filter.py": GenerateFilter,
|
||||||
|
"load_options.py": GenerateLoadOptions,
|
||||||
|
"repository.py": GenerateRepository,
|
||||||
|
"manager.py": GenerateManager,
|
||||||
|
"factory.py": GenerateFactory,
|
||||||
|
"mapper.py": GenerateMapper,
|
||||||
|
"__init__.py": GenerateInit,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate enum.py if there are enum types
|
||||||
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||||
|
generators["enum.py"] = GenerateEnum
|
||||||
|
}
|
||||||
|
|
||||||
|
for filename, generator := range generators {
|
||||||
|
content, err := generator(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to generate %s: %w", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write to file (this will be handled by the main function)
|
||||||
|
// For now, just return the content
|
||||||
|
_ = content
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
7
internal/generator/init.go
Normal file
7
internal/generator/init.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
// GenerateInit generates an empty __init__.py file
|
||||||
|
func GenerateInit(ctx *Context) (string, error) {
|
||||||
|
// Empty __init__.py file
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
38
internal/generator/load_options.go
Normal file
38
internal/generator/load_options.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateLoadOptions generates the load options class
|
||||||
|
func GenerateLoadOptions(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
|
||||||
|
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||||
|
ctx.ModuleName, ctx.EntityName))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString(fmt.Sprintf("class %sLoadOptions(LoadOptions):\n", ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||||
|
|
||||||
|
// Generate load options for all foreign key relationships
|
||||||
|
hasRelationships := false
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
hasRelationships = true
|
||||||
|
relationName := GetRelationshipName(fk.ColumnName)
|
||||||
|
b.WriteString(fmt.Sprintf(" load_%s: bool = joinload(relations=[\"%s\"])\n",
|
||||||
|
relationName, relationName))
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no relationships, add pass
|
||||||
|
if !hasRelationships {
|
||||||
|
b.WriteString(" pass\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
34
internal/generator/manager.go
Normal file
34
internal/generator/manager.go
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateManager generates the manager class
|
||||||
|
func GenerateManager(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||||
|
ctx.ModuleName, ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString(fmt.Sprintf("class %sManager(\n", ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" CRUDManager[%s, %sFilter, %sRepository]\n",
|
||||||
|
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||||
|
b.WriteString("):\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" repository_cls = %sRepository\n", ctx.EntityName))
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
48
internal/generator/mapper.go
Normal file
48
internal/generator/mapper.go
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateMapper generates the mapper snippet (without imports)
|
||||||
|
func GenerateMapper(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Mapper registration (snippet only, no imports)
|
||||||
|
b.WriteString(" mapper_registry.map_imperatively(\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" class_=%s,\n", ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" local_table=%s,\n", ctx.TableConstant))
|
||||||
|
b.WriteString(" properties={\n")
|
||||||
|
|
||||||
|
// Generate relationships for all foreign keys
|
||||||
|
fkRelationships := make(map[string][]string) // entity -> []column_names
|
||||||
|
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
relationName := GetRelationshipName(fk.ColumnName)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
|
||||||
|
// Group by entity name to handle multiple FKs to same table
|
||||||
|
fkRelationships[entityName] = append(fkRelationships[entityName], fk.ColumnName)
|
||||||
|
|
||||||
|
if len(fkRelationships[entityName]) == 1 {
|
||||||
|
// First FK to this table
|
||||||
|
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %s, lazy=relationship_loading_strategy.value\n", entityName))
|
||||||
|
} else {
|
||||||
|
// Multiple FKs to same table, need to specify foreign_keys
|
||||||
|
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %s,\n", entityName))
|
||||||
|
b.WriteString(" lazy=relationship_loading_strategy.value,\n")
|
||||||
|
b.WriteString(fmt.Sprintf(" foreign_keys=%s.columns.%s,\n",
|
||||||
|
ctx.TableConstant, fk.ColumnName))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(" ),\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(" },\n")
|
||||||
|
b.WriteString(" )\n")
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
124
internal/generator/model.go
Normal file
124
internal/generator/model.go
Normal file
@ -0,0 +1,124 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateModel generates the dataclass model
|
||||||
|
func GenerateModel(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from dataclasses import dataclass\n")
|
||||||
|
|
||||||
|
// Check what we need to import
|
||||||
|
needsDatetime := NeedsDatetimeImport(ctx.TableInfo.Columns)
|
||||||
|
needsDecimal := NeedsDecimalImport(ctx.TableInfo.Columns)
|
||||||
|
|
||||||
|
if needsDatetime {
|
||||||
|
b.WriteString("from datetime import datetime\n")
|
||||||
|
}
|
||||||
|
if needsDecimal {
|
||||||
|
b.WriteString("from decimal import Decimal\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n")
|
||||||
|
b.WriteString("from televend_core.databases.base_model import Base\n")
|
||||||
|
|
||||||
|
// Import related models for foreign keys
|
||||||
|
fkImports := make(map[string]string) // module_name -> entity_name
|
||||||
|
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||||
|
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
fkImports[moduleName] = entityName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Import enum types
|
||||||
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||||
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||||
|
}
|
||||||
|
b.WriteString(")\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write foreign key imports
|
||||||
|
for moduleName, entityName := range fkImports {
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||||
|
moduleName, entityName))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString("@dataclass\n")
|
||||||
|
b.WriteString(fmt.Sprintf("class %s(Base):\n", ctx.EntityName))
|
||||||
|
|
||||||
|
// Get required and optional columns
|
||||||
|
requiredCols := GetRequiredColumns(ctx.TableInfo.Columns)
|
||||||
|
optionalCols := GetOptionalColumns(ctx.TableInfo.Columns)
|
||||||
|
|
||||||
|
// Required fields (non-nullable, no default, not auto-increment, not PK)
|
||||||
|
for _, col := range requiredCols {
|
||||||
|
fieldName := col.Name
|
||||||
|
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||||
|
|
||||||
|
// Regular field
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s\n", fieldName, pythonType))
|
||||||
|
|
||||||
|
// Add relationship field if this is a foreign key and column ends with _id
|
||||||
|
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||||
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||||
|
if strings.HasSuffix(col.Name, "_id") {
|
||||||
|
relationName := GetRelationshipName(col.Name)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s\n", relationName, entityName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Empty line between required and optional
|
||||||
|
if len(requiredCols) > 0 && len(optionalCols) > 0 {
|
||||||
|
b.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optional fields
|
||||||
|
for _, col := range optionalCols {
|
||||||
|
// Skip primary key, we'll add it at the end
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fieldName := col.Name
|
||||||
|
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||||
|
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", fieldName, pythonType))
|
||||||
|
|
||||||
|
// Add relationship field if this is a foreign key and column ends with _id
|
||||||
|
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||||
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||||
|
if strings.HasSuffix(col.Name, "_id") {
|
||||||
|
relationName := GetRelationshipName(col.Name)
|
||||||
|
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", relationName, entityName))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add primary key at the end
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
b.WriteString("\n")
|
||||||
|
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||||
|
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", col.Name, pythonType))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
28
internal/generator/repository.go
Normal file
28
internal/generator/repository.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateRepository generates the repository class
|
||||||
|
func GenerateRepository(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||||
|
ctx.ModuleName, ctx.EntityName))
|
||||||
|
b.WriteString("\n\n")
|
||||||
|
|
||||||
|
// Class definition
|
||||||
|
b.WriteString(fmt.Sprintf("class %sRepository(CRUDRepository[%s, %sFilter]):\n",
|
||||||
|
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||||
|
b.WriteString(fmt.Sprintf(" model_cls = %s\n", ctx.EntityName))
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
160
internal/generator/table.go
Normal file
160
internal/generator/table.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package generator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/entity-maker/entity-maker/internal/database"
|
||||||
|
"github.com/entity-maker/entity-maker/internal/naming"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GenerateTable generates the SQLAlchemy table definition
|
||||||
|
func GenerateTable(ctx *Context) (string, error) {
|
||||||
|
var b strings.Builder
|
||||||
|
|
||||||
|
// Imports
|
||||||
|
b.WriteString("from sqlalchemy import (\n")
|
||||||
|
|
||||||
|
// Collect unique imports
|
||||||
|
imports := make(map[string]bool)
|
||||||
|
imports["Column"] = true
|
||||||
|
imports["Table"] = true
|
||||||
|
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
sqlType := database.GetSQLAlchemyType(col)
|
||||||
|
|
||||||
|
// Extract base type name (before parentheses)
|
||||||
|
typeName := strings.Split(sqlType, "(")[0]
|
||||||
|
imports[typeName] = true
|
||||||
|
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
imports["Integer"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for foreign keys
|
||||||
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||||
|
imports["ForeignKey"] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for enums
|
||||||
|
if col.DataType == "USER-DEFINED" {
|
||||||
|
imports["Enum"] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort and write imports
|
||||||
|
importList := []string{}
|
||||||
|
for imp := range imports {
|
||||||
|
importList = append(importList, imp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write imports in a reasonable order
|
||||||
|
orderedImports := []string{
|
||||||
|
"BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum",
|
||||||
|
"Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary",
|
||||||
|
"Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID",
|
||||||
|
}
|
||||||
|
|
||||||
|
first := true
|
||||||
|
for _, imp := range orderedImports {
|
||||||
|
if imports[imp] {
|
||||||
|
if !first {
|
||||||
|
b.WriteString(",\n")
|
||||||
|
} else {
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
b.WriteString(" " + imp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
b.WriteString(",\n)\n\n")
|
||||||
|
|
||||||
|
// Import enum types if they exist
|
||||||
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||||
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||||
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
|
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||||
|
ctx.ModuleName))
|
||||||
|
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||||
|
b.WriteString(")\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
|
||||||
|
|
||||||
|
// Table definition
|
||||||
|
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
||||||
|
b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName))
|
||||||
|
b.WriteString(" metadata_obj,\n")
|
||||||
|
|
||||||
|
// Columns
|
||||||
|
for _, col := range ctx.TableInfo.Columns {
|
||||||
|
b.WriteString(generateColumnDefinition(col, ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(")\n")
|
||||||
|
|
||||||
|
return b.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateColumnDefinition(col database.Column, ctx *Context) string {
|
||||||
|
var parts []string
|
||||||
|
|
||||||
|
// Column name
|
||||||
|
parts = append(parts, fmt.Sprintf("\"%s\"", col.Name))
|
||||||
|
|
||||||
|
// Column type
|
||||||
|
sqlType := database.GetSQLAlchemyType(col)
|
||||||
|
|
||||||
|
// Handle enum types
|
||||||
|
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||||
|
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||||
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||||
|
sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )",
|
||||||
|
enumName, enumType.TypeName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
parts = append(parts, sqlType)
|
||||||
|
|
||||||
|
// Foreign key
|
||||||
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||||
|
fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")",
|
||||||
|
fk.ForeignTableName, fk.ForeignColumnName)
|
||||||
|
parts = append(parts, fkDef)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Primary key
|
||||||
|
if col.IsPrimaryKey {
|
||||||
|
parts = append(parts, "primary_key=True")
|
||||||
|
if col.IsAutoIncrement {
|
||||||
|
parts = append(parts, "autoincrement=True")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Nullable
|
||||||
|
if !col.IsNullable && !col.IsPrimaryKey {
|
||||||
|
parts = append(parts, "nullable=False")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unique
|
||||||
|
// Note: We don't have unique constraint info in our introspection yet
|
||||||
|
// This would need to be added if needed
|
||||||
|
|
||||||
|
// Format the column definition
|
||||||
|
result := " Column("
|
||||||
|
|
||||||
|
// Check if we need multiline formatting (for complex types like Enum)
|
||||||
|
if strings.Contains(sqlType, "\n") {
|
||||||
|
// Multiline format
|
||||||
|
result += parts[0] + ",\n " + parts[1]
|
||||||
|
for _, part := range parts[2:] {
|
||||||
|
result += ",\n " + part
|
||||||
|
}
|
||||||
|
result += ",\n ),\n"
|
||||||
|
} else {
|
||||||
|
// Single line format
|
||||||
|
result += strings.Join(parts, ", ") + "),\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
227
internal/naming/naming.go
Normal file
227
internal/naming/naming.go
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
package naming
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ToSnakeCase converts a string to snake_case
|
||||||
|
func ToSnakeCase(s string) string {
|
||||||
|
var result []rune
|
||||||
|
for i, r := range s {
|
||||||
|
if unicode.IsUpper(r) {
|
||||||
|
if i > 0 {
|
||||||
|
result = append(result, '_')
|
||||||
|
}
|
||||||
|
result = append(result, unicode.ToLower(r))
|
||||||
|
} else {
|
||||||
|
result = append(result, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return string(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToPascalCase converts snake_case to PascalCase
|
||||||
|
func ToPascalCase(s string) string {
|
||||||
|
parts := strings.Split(s, "_")
|
||||||
|
for i, part := range parts {
|
||||||
|
if len(part) > 0 {
|
||||||
|
parts[i] = strings.ToUpper(part[:1]) + part[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Singularize converts a plural word to singular
|
||||||
|
// Uses common English pluralization rules
|
||||||
|
func Singularize(word string) string {
|
||||||
|
// Handle empty string
|
||||||
|
if word == "" {
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common irregular plurals
|
||||||
|
irregulars := map[string]string{
|
||||||
|
"people": "person",
|
||||||
|
"men": "man",
|
||||||
|
"women": "woman",
|
||||||
|
"children": "child",
|
||||||
|
"teeth": "tooth",
|
||||||
|
"feet": "foot",
|
||||||
|
"mice": "mouse",
|
||||||
|
"geese": "goose",
|
||||||
|
"oxen": "ox",
|
||||||
|
"sheep": "sheep",
|
||||||
|
"fish": "fish",
|
||||||
|
"deer": "deer",
|
||||||
|
"series": "series",
|
||||||
|
"species": "species",
|
||||||
|
"quizzes": "quiz",
|
||||||
|
"analyses": "analysis",
|
||||||
|
"diagnoses": "diagnosis",
|
||||||
|
"oases": "oasis",
|
||||||
|
"theses": "thesis",
|
||||||
|
"crises": "crisis",
|
||||||
|
"phenomena": "phenomenon",
|
||||||
|
"criteria": "criterion",
|
||||||
|
"data": "datum",
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(word)
|
||||||
|
if singular, ok := irregulars[lower]; ok {
|
||||||
|
return preserveCase(word, singular)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'ies' -> 'y'
|
||||||
|
if strings.HasSuffix(lower, "ies") && len(word) > 3 {
|
||||||
|
return word[:len(word)-3] + "y"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'ves' -> 'fe' or 'f'
|
||||||
|
if strings.HasSuffix(lower, "ves") && len(word) > 3 {
|
||||||
|
base := word[:len(word)-3]
|
||||||
|
// Common words that end in 'fe'
|
||||||
|
if strings.HasSuffix(strings.ToLower(base), "li") ||
|
||||||
|
strings.HasSuffix(strings.ToLower(base), "wi") ||
|
||||||
|
strings.HasSuffix(strings.ToLower(base), "kni") {
|
||||||
|
return base + "fe"
|
||||||
|
}
|
||||||
|
return base + "f"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'xes', 'ses', 'shes', 'ches' -> remove 'es'
|
||||||
|
if strings.HasSuffix(lower, "xes") ||
|
||||||
|
strings.HasSuffix(lower, "ses") ||
|
||||||
|
strings.HasSuffix(lower, "shes") ||
|
||||||
|
strings.HasSuffix(lower, "ches") {
|
||||||
|
if len(word) > 2 {
|
||||||
|
return word[:len(word)-2]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'oes' -> 'o'
|
||||||
|
if strings.HasSuffix(lower, "oes") && len(word) > 3 {
|
||||||
|
return word[:len(word)-2]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle simple 's' suffix
|
||||||
|
if strings.HasSuffix(lower, "s") && len(word) > 1 {
|
||||||
|
// Don't remove 's' from words that naturally end in 's'
|
||||||
|
if !strings.HasSuffix(lower, "ss") &&
|
||||||
|
!strings.HasSuffix(lower, "us") &&
|
||||||
|
!strings.HasSuffix(lower, "is") {
|
||||||
|
return word[:len(word)-1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pluralize converts a singular word to plural
|
||||||
|
func Pluralize(word string) string {
|
||||||
|
if word == "" {
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common irregular plurals
|
||||||
|
irregulars := map[string]string{
|
||||||
|
"person": "people",
|
||||||
|
"man": "men",
|
||||||
|
"woman": "women",
|
||||||
|
"child": "children",
|
||||||
|
"tooth": "teeth",
|
||||||
|
"foot": "feet",
|
||||||
|
"mouse": "mice",
|
||||||
|
"goose": "geese",
|
||||||
|
"ox": "oxen",
|
||||||
|
"sheep": "sheep",
|
||||||
|
"fish": "fish",
|
||||||
|
"deer": "deer",
|
||||||
|
"series": "series",
|
||||||
|
"species": "species",
|
||||||
|
"quiz": "quizzes",
|
||||||
|
"analysis": "analyses",
|
||||||
|
"diagnosis": "diagnoses",
|
||||||
|
"oasis": "oases",
|
||||||
|
"thesis": "theses",
|
||||||
|
"crisis": "crises",
|
||||||
|
"phenomenon": "phenomena",
|
||||||
|
"criterion": "criteria",
|
||||||
|
"datum": "data",
|
||||||
|
}
|
||||||
|
|
||||||
|
lower := strings.ToLower(word)
|
||||||
|
if plural, ok := irregulars[lower]; ok {
|
||||||
|
return preserveCase(word, plural)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Already plural (ends in 's' and not special case)
|
||||||
|
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
|
||||||
|
return word
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'y' -> 'ies'
|
||||||
|
if strings.HasSuffix(lower, "y") && len(word) > 1 {
|
||||||
|
prevChar := rune(lower[len(lower)-2])
|
||||||
|
if !isVowel(prevChar) {
|
||||||
|
return word[:len(word)-1] + "ies"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'f' or 'fe' -> 'ves'
|
||||||
|
if strings.HasSuffix(lower, "f") {
|
||||||
|
return word[:len(word)-1] + "ves"
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(lower, "fe") {
|
||||||
|
return word[:len(word)-2] + "ves"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'o' -> 'oes'
|
||||||
|
if strings.HasSuffix(lower, "o") && len(word) > 1 {
|
||||||
|
prevChar := rune(lower[len(lower)-2])
|
||||||
|
if !isVowel(prevChar) {
|
||||||
|
return word + "es"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle words ending in 'x', 's', 'sh', 'ch' -> add 'es'
|
||||||
|
if strings.HasSuffix(lower, "x") ||
|
||||||
|
strings.HasSuffix(lower, "s") ||
|
||||||
|
strings.HasSuffix(lower, "sh") ||
|
||||||
|
strings.HasSuffix(lower, "ch") {
|
||||||
|
return word + "es"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default: just add 's'
|
||||||
|
return word + "s"
|
||||||
|
}
|
||||||
|
|
||||||
|
// preserveCase preserves the case pattern of the original word
|
||||||
|
func preserveCase(original, replacement string) string {
|
||||||
|
if len(original) == 0 {
|
||||||
|
return replacement
|
||||||
|
}
|
||||||
|
|
||||||
|
if unicode.IsUpper(rune(original[0])) {
|
||||||
|
return strings.ToUpper(replacement[:1]) + replacement[1:]
|
||||||
|
}
|
||||||
|
return replacement
|
||||||
|
}
|
||||||
|
|
||||||
|
// isVowel checks if a rune is a vowel
|
||||||
|
func isVowel(r rune) bool {
|
||||||
|
vowels := "aeiouAEIOU"
|
||||||
|
return strings.ContainsRune(vowels, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SingularizeTableName converts a table name to its singular form
|
||||||
|
// Handles snake_case table names
|
||||||
|
func SingularizeTableName(tableName string) string {
|
||||||
|
parts := strings.Split(tableName, "_")
|
||||||
|
if len(parts) > 0 {
|
||||||
|
// Only singularize the last part
|
||||||
|
lastIdx := len(parts) - 1
|
||||||
|
parts[lastIdx] = Singularize(parts[lastIdx])
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "_")
|
||||||
|
}
|
||||||
186
internal/prompt/prompt.go
Normal file
186
internal/prompt/prompt.go
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
package prompt
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/fatih/color"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
cyan = color.New(color.FgCyan).SprintFunc()
|
||||||
|
green = color.New(color.FgGreen).SprintFunc()
|
||||||
|
yellow = color.New(color.FgYellow).SprintFunc()
|
||||||
|
red = color.New(color.FgRed).SprintFunc()
|
||||||
|
)
|
||||||
|
|
||||||
|
// Reader interface for testing
|
||||||
|
type Reader interface {
|
||||||
|
ReadString(delim byte) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptString prompts for a string value
|
||||||
|
func PromptString(label string, defaultValue string, required bool) (string, error) {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
return promptStringWithReader(reader, label, defaultValue, required)
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptStringWithReader(reader Reader, label string, defaultValue string, required bool) (string, error) {
|
||||||
|
for {
|
||||||
|
prompt := fmt.Sprintf("%s", cyan(label))
|
||||||
|
if defaultValue != "" {
|
||||||
|
prompt += fmt.Sprintf(" [%s]", green(defaultValue))
|
||||||
|
}
|
||||||
|
prompt += ": "
|
||||||
|
|
||||||
|
fmt.Print(prompt)
|
||||||
|
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
|
||||||
|
// If input is empty, use default value
|
||||||
|
if input == "" {
|
||||||
|
if defaultValue != "" {
|
||||||
|
// Echo the selected default value
|
||||||
|
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultValue), defaultValue)
|
||||||
|
fmt.Printf("\033[1A\r%s\n", fullPrompt)
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
if required {
|
||||||
|
fmt.Println(red("✗ This field is required"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return input, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptInt prompts for an integer value
|
||||||
|
func PromptInt(label string, defaultValue int, required bool) (int, error) {
|
||||||
|
reader := bufio.NewReader(os.Stdin)
|
||||||
|
return promptIntWithReader(reader, label, defaultValue, required)
|
||||||
|
}
|
||||||
|
|
||||||
|
func promptIntWithReader(reader Reader, label string, defaultValue int, required bool) (int, error) {
|
||||||
|
for {
|
||||||
|
defaultStr := ""
|
||||||
|
if defaultValue != 0 {
|
||||||
|
defaultStr = strconv.Itoa(defaultValue)
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt := fmt.Sprintf("%s", cyan(label))
|
||||||
|
if defaultStr != "" {
|
||||||
|
prompt += fmt.Sprintf(" [%s]", green(defaultStr))
|
||||||
|
}
|
||||||
|
prompt += ": "
|
||||||
|
|
||||||
|
fmt.Print(prompt)
|
||||||
|
|
||||||
|
input, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
input = strings.TrimSpace(input)
|
||||||
|
|
||||||
|
// If input is empty, use default value
|
||||||
|
if input == "" {
|
||||||
|
if defaultValue != 0 {
|
||||||
|
// Echo the selected default value
|
||||||
|
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultStr), defaultStr)
|
||||||
|
fmt.Printf("\033[1A\r%s\n", fullPrompt)
|
||||||
|
return defaultValue, nil
|
||||||
|
}
|
||||||
|
if required {
|
||||||
|
fmt.Println(red("✗ This field is required"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse integer
|
||||||
|
value, err := strconv.Atoi(input)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Println(red("✗ Please enter a valid number"))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ValidateDirectory checks if a directory exists or can be created
|
||||||
|
func ValidateDirectory(path string) error {
|
||||||
|
// Check if path exists
|
||||||
|
info, err := os.Stat(path)
|
||||||
|
if err == nil {
|
||||||
|
// Path exists, check if it's a directory
|
||||||
|
if !info.IsDir() {
|
||||||
|
return fmt.Errorf("path exists but is not a directory")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path doesn't exist, try to create it
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
if err := os.MkdirAll(path, 0755); err != nil {
|
||||||
|
return fmt.Errorf("cannot create directory: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// PromptDirectory prompts for a directory path and validates it
|
||||||
|
func PromptDirectory(label string, defaultValue string, required bool) (string, error) {
|
||||||
|
for {
|
||||||
|
path, err := PromptString(label, defaultValue, required)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if path == "" && !required {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate directory
|
||||||
|
if err := ValidateDirectory(path); err != nil {
|
||||||
|
fmt.Printf("%s %s\n", red("✗"), err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return path, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintHeader prints a colored header
|
||||||
|
func PrintHeader(text string) {
|
||||||
|
header := color.New(color.FgCyan, color.Bold)
|
||||||
|
header.Println("\n" + text)
|
||||||
|
header.Println(strings.Repeat("=", len(text)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintSuccess prints a success message
|
||||||
|
func PrintSuccess(text string) {
|
||||||
|
fmt.Printf("%s %s\n", green("✓"), text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintError prints an error message
|
||||||
|
func PrintError(text string) {
|
||||||
|
fmt.Printf("%s %s\n", red("✗"), text)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrintInfo prints an info message
|
||||||
|
func PrintInfo(text string) {
|
||||||
|
fmt.Printf("%s %s\n", yellow("ℹ"), text)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user