From f9f67b6c930897dc55f891f795c39032f9cb79fe Mon Sep 17 00:00:00 2001 From: Eden Kirin Date: Fri, 31 Oct 2025 14:36:41 +0100 Subject: [PATCH] First working version --- .gitignore | 4 + Makefile | 10 + README.md | 136 ++++++++++++ cmd/entity-maker/main.go | 240 ++++++++++++++++++++ go.mod | 15 ++ go.sum | 15 ++ internal/config/config.go | 82 +++++++ internal/database/client.go | 67 ++++++ internal/database/introspector.go | 341 +++++++++++++++++++++++++++++ internal/generator/enum.go | 40 ++++ internal/generator/factory.go | 203 +++++++++++++++++ internal/generator/filter.go | 77 +++++++ internal/generator/generator.go | 195 +++++++++++++++++ internal/generator/init.go | 7 + internal/generator/load_options.go | 38 ++++ internal/generator/manager.go | 34 +++ internal/generator/mapper.go | 48 ++++ internal/generator/model.go | 124 +++++++++++ internal/generator/repository.go | 28 +++ internal/generator/table.go | 160 ++++++++++++++ internal/naming/naming.go | 227 +++++++++++++++++++ internal/prompt/prompt.go | 186 ++++++++++++++++ 22 files changed, 2277 insertions(+) create mode 100644 Makefile create mode 100644 README.md create mode 100644 cmd/entity-maker/main.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/config/config.go create mode 100644 internal/database/client.go create mode 100644 internal/database/introspector.go create mode 100644 internal/generator/enum.go create mode 100644 internal/generator/factory.go create mode 100644 internal/generator/filter.go create mode 100644 internal/generator/generator.go create mode 100644 internal/generator/init.go create mode 100644 internal/generator/load_options.go create mode 100644 internal/generator/manager.go create mode 100644 internal/generator/mapper.go create mode 100644 internal/generator/model.go create mode 100644 internal/generator/repository.go create mode 100644 internal/generator/table.go create mode 100644 internal/naming/naming.go create mode 100644 internal/prompt/prompt.go diff --git a/.gitignore b/.gitignore index 1f57b97..2f2e509 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,5 @@ +/build /output + +*.toml +/entity-maker diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2fd4a6b --- /dev/null +++ b/Makefile @@ -0,0 +1,10 @@ +EXEC=entity-maker + + +.PHONY: build +build: + @go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go + + +upgrade-packages: + @go get -u ./... diff --git a/README.md b/README.md new file mode 100644 index 0000000..7c14ce8 --- /dev/null +++ b/README.md @@ -0,0 +1,136 @@ +# Entity Maker + +A command-line tool for generating SQLAlchemy entities from PostgreSQL database tables. + +## Features + +- Connects to PostgreSQL database and introspects table structure +- Generates Python 3.13 code using SQLAlchemy +- Creates complete entity scaffolding including: + - SQLAlchemy table definition + - Dataclass model + - Filter class + - Load options + - Repository + - Manager + - Factory (for testing) + - Mapper + - Enum types (if table uses PostgreSQL enums) + +## Installation + +### Build from source + +```bash +go build -o entity-maker ./cmd/entity-maker +``` + +## Usage + +### Interactive Mode + +Run the application without arguments to use interactive prompts: + +```bash +./entity-maker +``` + +The application will prompt you for: +- Database host (default: localhost) +- Database port (default: 5432) +- Database name (required) +- Database schema (default: public) +- Database user (default: postgres) +- Database password (default: postgres) +- Table name (required) +- Output directory (required) +- Entity name override (optional) + +### Command-Line Arguments + +You can also provide parameters via command-line flags: + +```bash +./entity-maker \ + -host localhost \ + -port 5432 \ + -db mydb \ + -schema public \ + -user postgres \ + -password secret \ + -table users \ + -output ./output \ + -entity User +``` + +### Available Flags + +- `-host` - Database host +- `-port` - Database port +- `-db` - Database name +- `-schema` - Database schema +- `-user` - Database user +- `-password` - Database password +- `-table` - Table name to generate entities from +- `-output` - Output directory path +- `-entity` - Entity name override (optional, defaults to singularized table name) + +## Configuration + +The application saves your settings to `~/.config/entity-maker.toml` for future use. When you run the application again, it will use these saved values as defaults. + +## Generated Files + +For a table named `users`, the following files will be generated in `output/user/`: + +- `table.py` - SQLAlchemy Table definition +- `model.py` - Dataclass model with type hints +- `filter.py` - Filter class for queries +- `load_options.py` - Relationship loading options +- `repository.py` - CRUD repository +- `manager.py` - Manager class +- `factory.py` - Factory for testing (uses factory_boy) +- `mapper.py` - SQLAlchemy mapper configuration +- `enum.py` - Enum types (only if table uses PostgreSQL enums) +- `__init__.py` - Package initialization + +## Example + +Generate entities for the `cashbag_conforms` table: + +```bash +./entity-maker -db mydb -table cashbag_conforms -output ./output +``` + +This will create files in `./output/cashbag_conform/`: +- Entity name: `CashbagConform` +- Module name: `cashbag_conform` +- All supporting files + +## Requirements + +- PostgreSQL database +- Go 1.21+ (for building) +- Python 3.13+ (for generated code) +- SQLAlchemy (for generated code) + +## Project Structure + +``` +entity-maker/ +├── cmd/ +│ └── entity-maker/ # Main application +├── internal/ +│ ├── config/ # Configuration management +│ ├── database/ # Database connection and introspection +│ ├── generator/ # Code generators +│ ├── naming/ # Naming utilities (pluralization, case conversion) +│ └── prompt/ # Interactive CLI prompts +├── example/ # Example output files +├── CLAUDE.md # Detailed specification +└── README.md # This file +``` + +## License + +See LICENSE file for details. diff --git a/cmd/entity-maker/main.go b/cmd/entity-maker/main.go new file mode 100644 index 0000000..c2823f1 --- /dev/null +++ b/cmd/entity-maker/main.go @@ -0,0 +1,240 @@ +package main + +import ( + "flag" + "fmt" + "os" + "path/filepath" + + "github.com/entity-maker/entity-maker/internal/config" + "github.com/entity-maker/entity-maker/internal/database" + "github.com/entity-maker/entity-maker/internal/generator" + "github.com/entity-maker/entity-maker/internal/naming" + "github.com/entity-maker/entity-maker/internal/prompt" + "github.com/fatih/color" +) + +func main() { + if err := run(); err != nil { + prompt.PrintError(err.Error()) + os.Exit(1) + } +} + +func run() error { + // Parse command line flags + var ( + dbHost = flag.String("host", "", "Database host") + dbPort = flag.Int("port", 0, "Database port") + dbName = flag.String("db", "", "Database name") + dbSchema = flag.String("schema", "", "Database schema") + dbUser = flag.String("user", "", "Database user") + dbPassword = flag.String("password", "", "Database password") + dbTable = flag.String("table", "", "Database table name") + outputDir = flag.String("output", "", "Output directory") + entityName = flag.String("entity", "", "Entity name override") + ) + flag.Parse() + + // Print header + header := color.New(color.FgCyan, color.Bold) + header.Println("\n╔════════════════════════════════════════╗") + header.Println("║ Entity Maker for TelevendCore ║") + header.Println("╚════════════════════════════════════════╝") + + // Load configuration + cfg, err := config.Load() + if err != nil { + return fmt.Errorf("failed to load config: %w", err) + } + + // Prompt for parameters + prompt.PrintHeader("Database Connection") + + // Override config with command line flags if provided + if *dbHost != "" { + cfg.DBHost = *dbHost + } + if *dbPort != 0 { + cfg.DBPort = *dbPort + } + if *dbName != "" { + cfg.DBName = *dbName + } + if *dbSchema != "" { + cfg.DBSchema = *dbSchema + } + if *dbUser != "" { + cfg.DBUser = *dbUser + } + if *dbPassword != "" { + cfg.DBPassword = *dbPassword + } + if *dbTable != "" { + cfg.DBTable = *dbTable + } + if *outputDir != "" { + cfg.OutputDir = *outputDir + } + if *entityName != "" { + cfg.EntityName = *entityName + } + + // Prompt for missing parameters + cfg.DBHost, err = prompt.PromptString("Database host", cfg.DBHost, false) + if err != nil { + return err + } + + cfg.DBPort, err = prompt.PromptInt("Database port", cfg.DBPort, false) + if err != nil { + return err + } + + cfg.DBName, err = prompt.PromptString("Database name", cfg.DBName, true) + if err != nil { + return err + } + + cfg.DBSchema, err = prompt.PromptString("Database schema", cfg.DBSchema, false) + if err != nil { + return err + } + + cfg.DBUser, err = prompt.PromptString("Database user", cfg.DBUser, false) + if err != nil { + return err + } + + cfg.DBPassword, err = prompt.PromptString("Database password", cfg.DBPassword, false) + if err != nil { + return err + } + + cfg.DBTable, err = prompt.PromptString("Table name", cfg.DBTable, true) + if err != nil { + return err + } + + cfg.OutputDir, err = prompt.PromptDirectory("Output directory", cfg.OutputDir, true) + if err != nil { + return err + } + + cfg.EntityName, err = prompt.PromptString("Entity name (optional)", cfg.EntityName, false) + if err != nil { + return err + } + + // Save configuration + if err := cfg.Save(); err != nil { + prompt.PrintError(fmt.Sprintf("Failed to save config: %v", err)) + // Don't fail, just warn + } else { + prompt.PrintSuccess("Configuration saved") + } + + // Connect to database + prompt.PrintHeader("Connecting to Database") + + dbClient, err := database.NewClient(database.Config{ + Host: cfg.DBHost, + Port: cfg.DBPort, + Database: cfg.DBName, + Schema: cfg.DBSchema, + User: cfg.DBUser, + Password: cfg.DBPassword, + }) + if err != nil { + return fmt.Errorf("database connection failed: %w", err) + } + defer dbClient.Close() + + prompt.PrintSuccess(fmt.Sprintf("Connected to %s@%s:%d/%s", cfg.DBUser, cfg.DBHost, cfg.DBPort, cfg.DBName)) + + // Check if table exists + exists, err := dbClient.TableExists(cfg.DBTable) + if err != nil { + return fmt.Errorf("failed to check table existence: %w", err) + } + if !exists { + return fmt.Errorf("table '%s' does not exist in schema '%s'", cfg.DBTable, cfg.DBSchema) + } + + prompt.PrintSuccess(fmt.Sprintf("Found table '%s'", cfg.DBTable)) + + // Introspect table + prompt.PrintHeader("Introspecting Table") + + tableInfo, err := dbClient.IntrospectTable(cfg.DBTable) + if err != nil { + return fmt.Errorf("failed to introspect table: %w", err) + } + + prompt.PrintSuccess(fmt.Sprintf("Found %d columns", len(tableInfo.Columns))) + prompt.PrintSuccess(fmt.Sprintf("Found %d foreign keys", len(tableInfo.ForeignKeys))) + prompt.PrintSuccess(fmt.Sprintf("Found %d enum types", len(tableInfo.EnumTypes))) + + // Create generation context + ctx := generator.NewContext(tableInfo, cfg.EntityName) + + // Create output directory + moduleName := naming.SingularizeTableName(cfg.DBTable) + moduleDir := filepath.Join(cfg.OutputDir, moduleName) + + if err := os.MkdirAll(moduleDir, 0755); err != nil { + return fmt.Errorf("failed to create output directory: %w", err) + } + + prompt.PrintSuccess(fmt.Sprintf("Created directory: %s", moduleDir)) + + // Generate files + prompt.PrintHeader("Generating Files") + + files := map[string]func(*generator.Context) (string, error){ + "table.py": generator.GenerateTable, + "model.py": generator.GenerateModel, + "filter.py": generator.GenerateFilter, + "load_options.py": generator.GenerateLoadOptions, + "repository.py": generator.GenerateRepository, + "manager.py": generator.GenerateManager, + "factory.py": generator.GenerateFactory, + "mapper.py": generator.GenerateMapper, + "__init__.py": generator.GenerateInit, + } + + // Add enum.py if there are enum types + if len(tableInfo.EnumTypes) > 0 { + files["enum.py"] = generator.GenerateEnum + } + + // Generate and write each file + for filename, genFunc := range files { + content, err := genFunc(ctx) + if err != nil { + return fmt.Errorf("failed to generate %s: %w", filename, err) + } + + filePath := filepath.Join(moduleDir, filename) + if err := os.WriteFile(filePath, []byte(content), 0644); err != nil { + return fmt.Errorf("failed to write %s: %w", filename, err) + } + + prompt.PrintSuccess(fmt.Sprintf("Generated %s", filename)) + } + + // Print summary + prompt.PrintHeader("Summary") + fmt.Printf("\n") + fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName)) + fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName)) + fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir)) + fmt.Printf(" Files: %s\n", color.GreenString("%d files generated", len(files))) + fmt.Printf("\n") + + success := color.New(color.FgGreen, color.Bold) + success.Println("✓ Code generation completed successfully!") + fmt.Println() + + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..098089f --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/entity-maker/entity-maker + +go 1.25.3 + +require ( + github.com/BurntSushi/toml v1.5.0 + github.com/fatih/color v1.18.0 + github.com/lib/pq v1.10.9 +) + +require ( + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + golang.org/x/sys v0.25.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f25dfcb --- /dev/null +++ b/go.sum @@ -0,0 +1,15 @@ +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34= +golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..29e5481 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,82 @@ +package config + +import ( + "os" + "path/filepath" + + "github.com/BurntSushi/toml" +) + +// Config represents the application configuration +type Config struct { + DBHost string `toml:"db_host"` + DBPort int `toml:"db_port"` + DBName string `toml:"db_name"` + DBSchema string `toml:"db_schema"` + DBUser string `toml:"db_user"` + DBPassword string `toml:"db_password"` + DBTable string `toml:"db_table"` + OutputDir string `toml:"output_dir"` + EntityName string `toml:"entity_name"` +} + +// DefaultConfig returns a new config with default values +func DefaultConfig() *Config { + return &Config{ + DBHost: "localhost", + DBPort: 5432, + DBSchema: "public", + DBUser: "postgres", + DBPassword: "postgres", + } +} + +// Load reads the configuration file from ~/.config/entity-maker.toml +func Load() (*Config, error) { + configPath := getConfigPath() + + config := DefaultConfig() + + // Check if config file exists + if _, err := os.Stat(configPath); os.IsNotExist(err) { + return config, nil + } + + // Read and decode the config file + if _, err := toml.DecodeFile(configPath, config); err != nil { + return nil, err + } + + return config, nil +} + +// Save writes the configuration to ~/.config/entity-maker.toml +func (c *Config) Save() error { + configPath := getConfigPath() + + // Create config directory if it doesn't exist + configDir := filepath.Dir(configPath) + if err := os.MkdirAll(configDir, 0755); err != nil { + return err + } + + // Create or truncate the config file + f, err := os.Create(configPath) + if err != nil { + return err + } + defer f.Close() + + // Encode and write the config + encoder := toml.NewEncoder(f) + return encoder.Encode(c) +} + +// getConfigPath returns the full path to the config file +func getConfigPath() string { + homeDir, err := os.UserHomeDir() + if err != nil { + homeDir = "." + } + return filepath.Join(homeDir, ".config", "entity-maker.toml") +} diff --git a/internal/database/client.go b/internal/database/client.go new file mode 100644 index 0000000..26258fc --- /dev/null +++ b/internal/database/client.go @@ -0,0 +1,67 @@ +package database + +import ( + "database/sql" + "fmt" + + _ "github.com/lib/pq" +) + +// Client represents a database connection +type Client struct { + db *sql.DB + schema string +} + +// Config holds database connection parameters +type Config struct { + Host string + Port int + Database string + Schema string + User string + Password string +} + +// NewClient creates a new database client +func NewClient(cfg Config) (*Client, error) { + connStr := fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", + cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database, + ) + + db, err := sql.Open("postgres", connStr) + if err != nil { + return nil, fmt.Errorf("failed to open database: %w", err) + } + + if err := db.Ping(); err != nil { + db.Close() + return nil, fmt.Errorf("failed to connect to database: %w", err) + } + + return &Client{ + db: db, + schema: cfg.Schema, + }, nil +} + +// Close closes the database connection +func (c *Client) Close() error { + return c.db.Close() +} + +// TableExists checks if a table exists in the schema +func (c *Client) TableExists(tableName string) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = $1 AND table_name = $2 + ) + ` + + var exists bool + err := c.db.QueryRow(query, c.schema, tableName).Scan(&exists) + return exists, err +} diff --git a/internal/database/introspector.go b/internal/database/introspector.go new file mode 100644 index 0000000..81e2281 --- /dev/null +++ b/internal/database/introspector.go @@ -0,0 +1,341 @@ +package database + +import ( + "database/sql" + "fmt" + "strings" +) + +// Column represents a database column +type Column struct { + Name string + DataType string + IsNullable bool + ColumnDefault sql.NullString + CharMaxLength sql.NullInt64 + NumericPrecision sql.NullInt64 + NumericScale sql.NullInt64 + UdtName string // User-defined type name (for enums) + IsPrimaryKey bool + IsAutoIncrement bool +} + +// ForeignKey represents a foreign key constraint +type ForeignKey struct { + ColumnName string + ForeignTableSchema string + ForeignTableName string + ForeignColumnName string + ConstraintName string +} + +// EnumType represents a PostgreSQL enum type +type EnumType struct { + TypeName string + Values []string +} + +// TableInfo contains all information about a table +type TableInfo struct { + Schema string + TableName string + Columns []Column + ForeignKeys []ForeignKey + EnumTypes map[string]EnumType +} + +// IntrospectTable retrieves comprehensive information about a table +func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) { + info := &TableInfo{ + Schema: c.schema, + TableName: tableName, + EnumTypes: make(map[string]EnumType), + } + + // Get columns + columns, err := c.getColumns(tableName) + if err != nil { + return nil, err + } + info.Columns = columns + + // Get primary keys and mark columns + primaryKeys, err := c.getPrimaryKeys(tableName) + if err != nil { + return nil, err + } + for i := range info.Columns { + for _, pk := range primaryKeys { + if info.Columns[i].Name == pk { + info.Columns[i].IsPrimaryKey = true + } + } + } + + // Get foreign keys + foreignKeys, err := c.getForeignKeys(tableName) + if err != nil { + return nil, err + } + info.ForeignKeys = foreignKeys + + // Get enum types for columns that use them + for _, col := range info.Columns { + if col.DataType == "USER-DEFINED" && col.UdtName != "" { + if _, exists := info.EnumTypes[col.UdtName]; !exists { + enumType, err := c.getEnumType(col.UdtName) + if err != nil { + return nil, err + } + info.EnumTypes[col.UdtName] = enumType + } + } + } + + return info, nil +} + +// getColumns retrieves column information for a table +func (c *Client) getColumns(tableName string) ([]Column, error) { + query := ` + SELECT + column_name, + data_type, + is_nullable = 'YES' as is_nullable, + column_default, + character_maximum_length, + numeric_precision, + numeric_scale, + udt_name + FROM information_schema.columns + WHERE table_schema = $1 AND table_name = $2 + ORDER BY ordinal_position + ` + + rows, err := c.db.Query(query, c.schema, tableName) + if err != nil { + return nil, fmt.Errorf("failed to query columns: %w", err) + } + defer rows.Close() + + var columns []Column + for rows.Next() { + var col Column + err := rows.Scan( + &col.Name, + &col.DataType, + &col.IsNullable, + &col.ColumnDefault, + &col.CharMaxLength, + &col.NumericPrecision, + &col.NumericScale, + &col.UdtName, + ) + if err != nil { + return nil, err + } + + // Check if column is auto-increment + if col.ColumnDefault.Valid { + defaultVal := strings.ToLower(col.ColumnDefault.String) + col.IsAutoIncrement = strings.Contains(defaultVal, "nextval") + } + + columns = append(columns, col) + } + + return columns, rows.Err() +} + +// getPrimaryKeys retrieves primary key column names for a table +func (c *Client) getPrimaryKeys(tableName string) ([]string, error) { + query := ` + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = ($1 || '.' || $2)::regclass + AND i.indisprimary + ` + + rows, err := c.db.Query(query, c.schema, tableName) + if err != nil { + return nil, fmt.Errorf("failed to query primary keys: %w", err) + } + defer rows.Close() + + var keys []string + for rows.Next() { + var key string + if err := rows.Scan(&key); err != nil { + return nil, err + } + keys = append(keys, key) + } + + return keys, rows.Err() +} + +// getForeignKeys retrieves foreign key information for a table +func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) { + query := ` + SELECT + kcu.column_name, + ccu.table_schema AS foreign_table_schema, + ccu.table_name AS foreign_table_name, + ccu.column_name AS foreign_column_name, + tc.constraint_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu + ON tc.constraint_name = kcu.constraint_name + AND tc.table_schema = kcu.table_schema + JOIN information_schema.constraint_column_usage AS ccu + ON ccu.constraint_name = tc.constraint_name + AND ccu.table_schema = tc.table_schema + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_schema = $1 + AND tc.table_name = $2 + ORDER BY kcu.ordinal_position + ` + + rows, err := c.db.Query(query, c.schema, tableName) + if err != nil { + return nil, fmt.Errorf("failed to query foreign keys: %w", err) + } + defer rows.Close() + + var fks []ForeignKey + for rows.Next() { + var fk ForeignKey + err := rows.Scan( + &fk.ColumnName, + &fk.ForeignTableSchema, + &fk.ForeignTableName, + &fk.ForeignColumnName, + &fk.ConstraintName, + ) + if err != nil { + return nil, err + } + fks = append(fks, fk) + } + + return fks, rows.Err() +} + +// getEnumType retrieves enum values for a PostgreSQL enum type +func (c *Client) getEnumType(typeName string) (EnumType, error) { + query := ` + SELECT e.enumlabel + FROM pg_type t + JOIN pg_enum e ON t.oid = e.enumtypid + WHERE t.typname = $1 + ORDER BY e.enumsortorder + ` + + rows, err := c.db.Query(query, typeName) + if err != nil { + return EnumType{}, fmt.Errorf("failed to query enum type: %w", err) + } + defer rows.Close() + + var values []string + for rows.Next() { + var value string + if err := rows.Scan(&value); err != nil { + return EnumType{}, err + } + values = append(values, value) + } + + return EnumType{ + TypeName: typeName, + Values: values, + }, rows.Err() +} + +// GetPythonType converts PostgreSQL data type to Python type +func GetPythonType(col Column) string { + switch col.DataType { + case "integer", "smallint", "bigint": + return "int" + case "numeric", "decimal", "real", "double precision": + return "Decimal" + case "boolean": + return "bool" + case "character varying", "varchar", "text", "char", "character": + return "str" + case "timestamp with time zone", "timestamp without time zone", "timestamp": + return "datetime" + case "date": + return "date" + case "time with time zone", "time without time zone", "time": + return "time" + case "json", "jsonb": + return "dict" + case "uuid": + return "UUID" + case "bytea": + return "bytes" + case "USER-DEFINED": + // This is likely an enum + return "str" // Will be replaced with specific enum type in generator + default: + return "Any" + } +} + +// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type +func GetSQLAlchemyType(col Column) string { + switch col.DataType { + case "integer": + return "Integer" + case "smallint": + return "SmallInteger" + case "bigint": + return "BigInteger" + case "numeric", "decimal": + if col.NumericPrecision.Valid && col.NumericScale.Valid { + return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64) + } + return "Numeric" + case "real": + return "Float" + case "double precision": + return "Float" + case "boolean": + return "Boolean" + case "character varying", "varchar": + if col.CharMaxLength.Valid { + return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64) + } + return "String" + case "char", "character": + if col.CharMaxLength.Valid { + return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64) + } + return "String(1)" + case "text": + return "Text" + case "timestamp with time zone": + return "DateTime(timezone=True)" + case "timestamp without time zone", "timestamp": + return "DateTime" + case "date": + return "Date" + case "time with time zone", "time without time zone", "time": + return "Time" + case "json": + return "JSON" + case "jsonb": + return "JSONB" + case "uuid": + return "UUID" + case "bytea": + return "LargeBinary" + case "USER-DEFINED": + // Will be handled separately for enums + return "Enum" + default: + return "String" + } +} diff --git a/internal/generator/enum.go b/internal/generator/enum.go new file mode 100644 index 0000000..2f61421 --- /dev/null +++ b/internal/generator/enum.go @@ -0,0 +1,40 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/naming" +) + +// GenerateEnum generates the enum types file +func GenerateEnum(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from enum import StrEnum\n\n") + b.WriteString("from televend_core.databases.enum import EnumMixin\n\n\n") + + // Generate each enum type + for _, enumType := range ctx.TableInfo.EnumTypes { + enumName := naming.ToPascalCase(enumType.TypeName) + + b.WriteString(fmt.Sprintf("class %s(EnumMixin, StrEnum):\n", enumName)) + + if len(enumType.Values) == 0 { + b.WriteString(" pass\n") + } else { + for _, value := range enumType.Values { + // Convert value to valid Python identifier + // Usually enum values are already uppercase like "OPEN", "IN_PROGRESS" + identifier := strings.ToUpper(strings.ReplaceAll(value, " ", "_")) + identifier = strings.ReplaceAll(identifier, "-", "_") + + b.WriteString(fmt.Sprintf(" %s = \"%s\"\n", identifier, value)) + } + } + b.WriteString("\n") + } + + return b.String(), nil +} diff --git a/internal/generator/factory.go b/internal/generator/factory.go new file mode 100644 index 0000000..06742a7 --- /dev/null +++ b/internal/generator/factory.go @@ -0,0 +1,203 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/database" + "github.com/entity-maker/entity-maker/internal/naming" +) + +// GenerateFactory generates the factory class +func GenerateFactory(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from __future__ import annotations\n\n") + b.WriteString("from typing import Type\n\n") + b.WriteString("import factory\n\n") + + // Import factories for related models + fkImports := make(map[string]string) // module_name -> entity_name + for _, fk := range ctx.TableInfo.ForeignKeys { + moduleName := GetRelationshipModuleName(fk.ForeignTableName) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + fkImports[moduleName] = entityName + } + + for moduleName, entityName := range fkImports { + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n", + moduleName)) + b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName)) + b.WriteString(")\n") + } + + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n", + ctx.ModuleName)) + b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName)) + b.WriteString(")\n") + b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n") + b.WriteString(" CustomSelfAttribute,\n") + b.WriteString(" TelevendBaseFactory,\n") + b.WriteString(")\n\n\n") + + // Class definition + b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName)) + + // Add boolean fields with defaults + for _, col := range ctx.TableInfo.Columns { + if col.DataType == "boolean" { + defaultValue := "True" + if col.Name == "alive" { + defaultValue = "True" + } else { + defaultValue = "False" + } + b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue)) + } + } + + // Add id field + for _, col := range ctx.TableInfo.Columns { + if col.IsPrimaryKey { + b.WriteString(fmt.Sprintf(" %s = None\n", col.Name)) + } + } + + b.WriteString("\n") + + // Generate faker fields for each column + for _, col := range ctx.TableInfo.Columns { + if col.IsPrimaryKey || col.DataType == "boolean" { + continue + } + + // Skip foreign keys, we'll handle them separately + if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil { + continue + } + + fakerDef := generateFakerField(col, ctx) + if fakerDef != "" { + b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef)) + } + } + + // Generate foreign key relationships + if len(ctx.TableInfo.ForeignKeys) > 0 { + b.WriteString("\n") + for _, fk := range ctx.TableInfo.ForeignKeys { + relationName := GetRelationshipName(fk.ColumnName) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + + b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n", + relationName, relationName, entityName)) + b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n", + fk.ColumnName, relationName, relationName)) + b.WriteString("\n") + } + } + + // Meta class + b.WriteString(" class Meta:\n") + b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName)) + b.WriteString("\n") + + // create_minimal method + b.WriteString(" @classmethod\n") + b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n", + ctx.EntityName, ctx.EntityName)) + b.WriteString(" minimal_params = {\n") + + // Add foreign key params + for _, fk := range ctx.TableInfo.ForeignKeys { + relationName := GetRelationshipName(fk.ColumnName) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + + // Check if this FK is required + var col *database.Column + for i := range ctx.TableInfo.Columns { + if ctx.TableInfo.Columns[i].Name == fk.ColumnName { + col = &ctx.TableInfo.Columns[i] + break + } + } + + if col != nil && !col.IsNullable { + // Required FK + b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n", + relationName, relationName, entityName)) + } else { + // Optional FK + b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName)) + } + } + + b.WriteString(" }\n") + b.WriteString(" minimal_params.update(kwargs)\n") + b.WriteString(" return cls.create(**minimal_params)\n") + + return b.String(), nil +} + +func generateFakerField(col database.Column, ctx *Context) string { + pythonType := database.GetPythonType(col) + + switch pythonType { + case "Decimal": + precision := 7 + scale := 4 + if col.NumericPrecision.Valid { + precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64) + } + if col.NumericScale.Valid { + scale = int(col.NumericScale.Int64) + } + return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)", + precision, scale) + + case "int": + if col.DataType == "bigint" { + return "factory.Faker(\"pyint\")" + } + return "factory.Faker(\"pyint\")" + + case "str": + // Check if it's an enum + if col.DataType == "USER-DEFINED" && col.UdtName != "" { + if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists { + if len(enumType.Values) > 0 { + enumName := naming.ToPascalCase(enumType.TypeName) + return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName) + } + } + } + + maxLen := 255 + if col.CharMaxLength.Valid { + maxLen = int(col.CharMaxLength.Int64) + } + if col.DataType == "text" { + return "factory.Faker(\"text\", max_nb_chars=500)" + } + return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen) + + case "datetime": + return "factory.Faker(\"date_time\")" + + case "date": + return "factory.Faker(\"date\")" + + case "time": + return "factory.Faker(\"time\")" + + case "bool": + return "factory.Faker(\"boolean\")" + + case "dict": + return "factory.Faker(\"pydict\")" + + default: + return "None" + } +} diff --git a/internal/generator/filter.go b/internal/generator/filter.go new file mode 100644 index 0000000..4c2b4a8 --- /dev/null +++ b/internal/generator/filter.go @@ -0,0 +1,77 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/database" +) + +// GenerateFilter generates the filter class +func GenerateFilter(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from televend_core.databases.base_filter import BaseFilter\n") + b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n", + ctx.ModuleName, ctx.EntityName)) + b.WriteString("\n\n") + + // Class definition + b.WriteString(fmt.Sprintf("class %sFilter(BaseFilter):\n", ctx.EntityName)) + b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName)) + + // Generate filters based on rules: + // - Boolean fields: EQ operator + // - ID fields: both EQ and IN operators + // - Text fields: EQ operator + + hasFilters := false + + for _, col := range ctx.TableInfo.Columns { + if !ShouldGenerateFilter(col) { + continue + } + + hasFilters = true + + // Boolean fields + if col.DataType == "boolean" { + defaultVal := "None" + if col.Name == "alive" { + defaultVal = "True" + } + b.WriteString(fmt.Sprintf(" %s: bool | None = filterfield(operator=EQ, default=%s)\n", + col.Name, defaultVal)) + } + + // ID fields (both EQ and IN) + if strings.HasSuffix(col.Name, "_id") || col.Name == "id" { + // Single ID with EQ + pythonType := database.GetPythonType(col) + b.WriteString(fmt.Sprintf(" %s: %s | None = filterfield(operator=EQ)\n", + col.Name, pythonType)) + + // Multiple IDs with IN (plural field name) + pluralFieldName := GetFilterFieldName(col.Name, true) + b.WriteString(fmt.Sprintf(" %s: list[%s] | None = filterfield(field=\"%s\", operator=IN)\n", + pluralFieldName, pythonType, col.Name)) + } + + // Text fields + if (col.DataType == "character varying" || col.DataType == "varchar" || + col.DataType == "text" || col.DataType == "char" || col.DataType == "character") && + !strings.HasSuffix(col.Name, "_id") && col.Name != "id" { + b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n", + col.Name)) + } + } + + // If no filters were generated, add a pass statement + if !hasFilters { + b.WriteString(" pass\n") + } + + return b.String(), nil +} diff --git a/internal/generator/generator.go b/internal/generator/generator.go new file mode 100644 index 0000000..1141b71 --- /dev/null +++ b/internal/generator/generator.go @@ -0,0 +1,195 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/database" + "github.com/entity-maker/entity-maker/internal/naming" +) + +// Context contains all information needed for code generation +type Context struct { + TableInfo *database.TableInfo + EntityName string // Singular, PascalCase (e.g., "CashbagConform") + ModuleName string // Singular, snake_case (e.g., "cashbag_conform") + TableConstant string // Uppercase with TABLE suffix (e.g., "CASHBAG_CONFORM_TABLE") +} + +// NewContext creates a new generation context +func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context { + moduleName := naming.SingularizeTableName(tableInfo.TableName) + + entityName := entityNameOverride + if entityName == "" { + entityName = naming.ToPascalCase(moduleName) + } + + tableConstant := strings.ToUpper(moduleName) + "_TABLE" + + return &Context{ + TableInfo: tableInfo, + EntityName: entityName, + ModuleName: moduleName, + TableConstant: tableConstant, + } +} + +// GetRelationshipName returns the relationship name for a foreign key +// Strips _id suffix and converts to snake_case +func GetRelationshipName(fkColumnName string) string { + name := fkColumnName + if strings.HasSuffix(name, "_id") { + name = name[:len(name)-3] + } + return name +} + +// GetRelationshipEntityName returns the entity name for a foreign key's target table +func GetRelationshipEntityName(tableName string) string { + singular := naming.SingularizeTableName(tableName) + return naming.ToPascalCase(singular) +} + +// GetRelationshipModuleName returns the module name for a foreign key's target table +func GetRelationshipModuleName(tableName string) string { + return naming.SingularizeTableName(tableName) +} + +// GetFilterFieldName returns the filter field name for a column +// For ID fields with IN operator, pluralizes the name +func GetFilterFieldName(columnName string, useIN bool) string { + if useIN && strings.HasSuffix(columnName, "_id") { + // Remove _id, pluralize, add back _ids + base := columnName[:len(columnName)-3] + if base == "" { + return "ids" + } + return naming.Pluralize(base) + "_ids" + } + if useIN && columnName == "id" { + return "ids" + } + return columnName +} + +// ShouldGenerateFilter determines if a column should have a filter +func ShouldGenerateFilter(col database.Column) bool { + // Generate filters for boolean, ID, and text fields + if col.DataType == "boolean" { + return true + } + if strings.HasSuffix(col.Name, "_id") || col.Name == "id" { + return true + } + if col.DataType == "character varying" || col.DataType == "varchar" || + col.DataType == "text" || col.DataType == "char" || col.DataType == "character" { + return true + } + return false +} + +// GetPythonTypeForColumn returns the Python type annotation for a column +func GetPythonTypeForColumn(col database.Column, ctx *Context) string { + baseType := database.GetPythonType(col) + + // Handle enum types + if col.DataType == "USER-DEFINED" && col.UdtName != "" { + if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists { + baseType = naming.ToPascalCase(enumType.TypeName) + } + } + + // Handle Decimal type + if baseType == "Decimal" { + baseType = "Decimal" + } + + return baseType +} + +// NeedsDecimalImport checks if any column uses Decimal type +func NeedsDecimalImport(columns []database.Column) bool { + for _, col := range columns { + if database.GetPythonType(col) == "Decimal" { + return true + } + } + return false +} + +// NeedsDatetimeImport checks if any column uses datetime type +func NeedsDatetimeImport(columns []database.Column) bool { + for _, col := range columns { + pyType := database.GetPythonType(col) + if pyType == "datetime" || pyType == "date" || pyType == "time" { + return true + } + } + return false +} + +// GetRequiredColumns returns columns that are not nullable and don't have defaults +func GetRequiredColumns(columns []database.Column) []database.Column { + var required []database.Column + for _, col := range columns { + if !col.IsNullable && !col.ColumnDefault.Valid && !col.IsAutoIncrement && !col.IsPrimaryKey { + required = append(required, col) + } + } + return required +} + +// GetOptionalColumns returns columns that are nullable or have defaults +func GetOptionalColumns(columns []database.Column) []database.Column { + var optional []database.Column + for _, col := range columns { + if col.IsNullable || col.ColumnDefault.Valid { + optional = append(optional, col) + } + } + return optional +} + +// GetForeignKeyForColumn returns the foreign key info for a column, if it exists +func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey { + for _, fk := range foreignKeys { + if fk.ColumnName == columnName { + return &fk + } + } + return nil +} + +// GenerateFiles generates all Python files for the entity +func GenerateFiles(ctx *Context, outputDir string) error { + generators := map[string]func(*Context) (string, error){ + "table.py": GenerateTable, + "model.py": GenerateModel, + "filter.py": GenerateFilter, + "load_options.py": GenerateLoadOptions, + "repository.py": GenerateRepository, + "manager.py": GenerateManager, + "factory.py": GenerateFactory, + "mapper.py": GenerateMapper, + "__init__.py": GenerateInit, + } + + // Generate enum.py if there are enum types + if len(ctx.TableInfo.EnumTypes) > 0 { + generators["enum.py"] = GenerateEnum + } + + for filename, generator := range generators { + content, err := generator(ctx) + if err != nil { + return fmt.Errorf("failed to generate %s: %w", filename, err) + } + + // Write to file (this will be handled by the main function) + // For now, just return the content + _ = content + } + + return nil +} diff --git a/internal/generator/init.go b/internal/generator/init.go new file mode 100644 index 0000000..43ad7e7 --- /dev/null +++ b/internal/generator/init.go @@ -0,0 +1,7 @@ +package generator + +// GenerateInit generates an empty __init__.py file +func GenerateInit(ctx *Context) (string, error) { + // Empty __init__.py file + return "", nil +} diff --git a/internal/generator/load_options.go b/internal/generator/load_options.go new file mode 100644 index 0000000..e30058c --- /dev/null +++ b/internal/generator/load_options.go @@ -0,0 +1,38 @@ +package generator + +import ( + "fmt" + "strings" +) + +// GenerateLoadOptions generates the load options class +func GenerateLoadOptions(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n") + b.WriteString("from televend_core.databases.common.load_options import joinload\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n", + ctx.ModuleName, ctx.EntityName)) + b.WriteString("\n\n") + + // Class definition + b.WriteString(fmt.Sprintf("class %sLoadOptions(LoadOptions):\n", ctx.EntityName)) + b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName)) + + // Generate load options for all foreign key relationships + hasRelationships := false + for _, fk := range ctx.TableInfo.ForeignKeys { + hasRelationships = true + relationName := GetRelationshipName(fk.ColumnName) + b.WriteString(fmt.Sprintf(" load_%s: bool = joinload(relations=[\"%s\"])\n", + relationName, relationName)) + } + + // If no relationships, add pass + if !hasRelationships { + b.WriteString(" pass\n") + } + + return b.String(), nil +} diff --git a/internal/generator/manager.go b/internal/generator/manager.go new file mode 100644 index 0000000..b731a41 --- /dev/null +++ b/internal/generator/manager.go @@ -0,0 +1,34 @@ +package generator + +import ( + "fmt" + "strings" +) + +// GenerateManager generates the manager class +func GenerateManager(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from televend_core.databases.base_manager import CRUDManager\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n", + ctx.ModuleName)) + b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName)) + b.WriteString(")\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n", + ctx.ModuleName, ctx.EntityName)) + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n", + ctx.ModuleName)) + b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName)) + b.WriteString(")\n") + b.WriteString("\n\n") + + // Class definition + b.WriteString(fmt.Sprintf("class %sManager(\n", ctx.EntityName)) + b.WriteString(fmt.Sprintf(" CRUDManager[%s, %sFilter, %sRepository]\n", + ctx.EntityName, ctx.EntityName, ctx.EntityName)) + b.WriteString("):\n") + b.WriteString(fmt.Sprintf(" repository_cls = %sRepository\n", ctx.EntityName)) + + return b.String(), nil +} diff --git a/internal/generator/mapper.go b/internal/generator/mapper.go new file mode 100644 index 0000000..148d357 --- /dev/null +++ b/internal/generator/mapper.go @@ -0,0 +1,48 @@ +package generator + +import ( + "fmt" + "strings" +) + +// GenerateMapper generates the mapper snippet (without imports) +func GenerateMapper(ctx *Context) (string, error) { + var b strings.Builder + + // Mapper registration (snippet only, no imports) + b.WriteString(" mapper_registry.map_imperatively(\n") + b.WriteString(fmt.Sprintf(" class_=%s,\n", ctx.EntityName)) + b.WriteString(fmt.Sprintf(" local_table=%s,\n", ctx.TableConstant)) + b.WriteString(" properties={\n") + + // Generate relationships for all foreign keys + fkRelationships := make(map[string][]string) // entity -> []column_names + + for _, fk := range ctx.TableInfo.ForeignKeys { + relationName := GetRelationshipName(fk.ColumnName) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + + // Group by entity name to handle multiple FKs to same table + fkRelationships[entityName] = append(fkRelationships[entityName], fk.ColumnName) + + if len(fkRelationships[entityName]) == 1 { + // First FK to this table + b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName)) + b.WriteString(fmt.Sprintf(" %s, lazy=relationship_loading_strategy.value\n", entityName)) + } else { + // Multiple FKs to same table, need to specify foreign_keys + b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName)) + b.WriteString(fmt.Sprintf(" %s,\n", entityName)) + b.WriteString(" lazy=relationship_loading_strategy.value,\n") + b.WriteString(fmt.Sprintf(" foreign_keys=%s.columns.%s,\n", + ctx.TableConstant, fk.ColumnName)) + } + + b.WriteString(" ),\n") + } + + b.WriteString(" },\n") + b.WriteString(" )\n") + + return b.String(), nil +} diff --git a/internal/generator/model.go b/internal/generator/model.go new file mode 100644 index 0000000..b5832cc --- /dev/null +++ b/internal/generator/model.go @@ -0,0 +1,124 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/naming" +) + +// GenerateModel generates the dataclass model +func GenerateModel(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from dataclasses import dataclass\n") + + // Check what we need to import + needsDatetime := NeedsDatetimeImport(ctx.TableInfo.Columns) + needsDecimal := NeedsDecimalImport(ctx.TableInfo.Columns) + + if needsDatetime { + b.WriteString("from datetime import datetime\n") + } + if needsDecimal { + b.WriteString("from decimal import Decimal\n") + } + + b.WriteString("\n") + b.WriteString("from televend_core.databases.base_model import Base\n") + + // Import related models for foreign keys + fkImports := make(map[string]string) // module_name -> entity_name + for _, fk := range ctx.TableInfo.ForeignKeys { + moduleName := GetRelationshipModuleName(fk.ForeignTableName) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + fkImports[moduleName] = entityName + } + + // Import enum types + if len(ctx.TableInfo.EnumTypes) > 0 { + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n", + ctx.ModuleName)) + for _, enumType := range ctx.TableInfo.EnumTypes { + enumName := naming.ToPascalCase(enumType.TypeName) + b.WriteString(fmt.Sprintf(" %s,\n", enumName)) + } + b.WriteString(")\n") + } + + // Write foreign key imports + for moduleName, entityName := range fkImports { + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n", + moduleName, entityName)) + } + + b.WriteString("\n\n") + + // Class definition + b.WriteString("@dataclass\n") + b.WriteString(fmt.Sprintf("class %s(Base):\n", ctx.EntityName)) + + // Get required and optional columns + requiredCols := GetRequiredColumns(ctx.TableInfo.Columns) + optionalCols := GetOptionalColumns(ctx.TableInfo.Columns) + + // Required fields (non-nullable, no default, not auto-increment, not PK) + for _, col := range requiredCols { + fieldName := col.Name + pythonType := GetPythonTypeForColumn(col, ctx) + + // Regular field + b.WriteString(fmt.Sprintf(" %s: %s\n", fieldName, pythonType)) + + // Add relationship field if this is a foreign key and column ends with _id + // (to avoid name clashes with FK columns that don't follow _id convention) + if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil { + if strings.HasSuffix(col.Name, "_id") { + relationName := GetRelationshipName(col.Name) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + b.WriteString(fmt.Sprintf(" %s: %s\n", relationName, entityName)) + } + } + } + + // Empty line between required and optional + if len(requiredCols) > 0 && len(optionalCols) > 0 { + b.WriteString("\n") + } + + // Optional fields + for _, col := range optionalCols { + // Skip primary key, we'll add it at the end + if col.IsPrimaryKey { + continue + } + + fieldName := col.Name + pythonType := GetPythonTypeForColumn(col, ctx) + + b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", fieldName, pythonType)) + + // Add relationship field if this is a foreign key and column ends with _id + // (to avoid name clashes with FK columns that don't follow _id convention) + if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil { + if strings.HasSuffix(col.Name, "_id") { + relationName := GetRelationshipName(col.Name) + entityName := GetRelationshipEntityName(fk.ForeignTableName) + b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", relationName, entityName)) + } + } + } + + // Add primary key at the end + for _, col := range ctx.TableInfo.Columns { + if col.IsPrimaryKey { + b.WriteString("\n") + pythonType := GetPythonTypeForColumn(col, ctx) + b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", col.Name, pythonType)) + break + } + } + + return b.String(), nil +} diff --git a/internal/generator/repository.go b/internal/generator/repository.go new file mode 100644 index 0000000..8221569 --- /dev/null +++ b/internal/generator/repository.go @@ -0,0 +1,28 @@ +package generator + +import ( + "fmt" + "strings" +) + +// GenerateRepository generates the repository class +func GenerateRepository(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n", + ctx.ModuleName)) + b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName)) + b.WriteString(")\n") + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n", + ctx.ModuleName, ctx.EntityName)) + b.WriteString("\n\n") + + // Class definition + b.WriteString(fmt.Sprintf("class %sRepository(CRUDRepository[%s, %sFilter]):\n", + ctx.EntityName, ctx.EntityName, ctx.EntityName)) + b.WriteString(fmt.Sprintf(" model_cls = %s\n", ctx.EntityName)) + + return b.String(), nil +} diff --git a/internal/generator/table.go b/internal/generator/table.go new file mode 100644 index 0000000..b1da4c1 --- /dev/null +++ b/internal/generator/table.go @@ -0,0 +1,160 @@ +package generator + +import ( + "fmt" + "strings" + + "github.com/entity-maker/entity-maker/internal/database" + "github.com/entity-maker/entity-maker/internal/naming" +) + +// GenerateTable generates the SQLAlchemy table definition +func GenerateTable(ctx *Context) (string, error) { + var b strings.Builder + + // Imports + b.WriteString("from sqlalchemy import (\n") + + // Collect unique imports + imports := make(map[string]bool) + imports["Column"] = true + imports["Table"] = true + + for _, col := range ctx.TableInfo.Columns { + sqlType := database.GetSQLAlchemyType(col) + + // Extract base type name (before parentheses) + typeName := strings.Split(sqlType, "(")[0] + imports[typeName] = true + + if col.IsPrimaryKey { + imports["Integer"] = true + } + + // Check for foreign keys + if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil { + imports["ForeignKey"] = true + } + + // Check for enums + if col.DataType == "USER-DEFINED" { + imports["Enum"] = true + } + } + + // Sort and write imports + importList := []string{} + for imp := range imports { + importList = append(importList, imp) + } + + // Write imports in a reasonable order + orderedImports := []string{ + "BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum", + "Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary", + "Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID", + } + + first := true + for _, imp := range orderedImports { + if imports[imp] { + if !first { + b.WriteString(",\n") + } else { + first = false + } + b.WriteString(" " + imp) + } + } + b.WriteString(",\n)\n\n") + + // Import enum types if they exist + if len(ctx.TableInfo.EnumTypes) > 0 { + for _, enumType := range ctx.TableInfo.EnumTypes { + enumName := naming.ToPascalCase(enumType.TypeName) + b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n", + ctx.ModuleName)) + b.WriteString(fmt.Sprintf(" %s,\n", enumName)) + b.WriteString(")\n") + } + } + + b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n") + + // Table definition + b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant)) + b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName)) + b.WriteString(" metadata_obj,\n") + + // Columns + for _, col := range ctx.TableInfo.Columns { + b.WriteString(generateColumnDefinition(col, ctx)) + } + + b.WriteString(")\n") + + return b.String(), nil +} + +func generateColumnDefinition(col database.Column, ctx *Context) string { + var parts []string + + // Column name + parts = append(parts, fmt.Sprintf("\"%s\"", col.Name)) + + // Column type + sqlType := database.GetSQLAlchemyType(col) + + // Handle enum types + if col.DataType == "USER-DEFINED" && col.UdtName != "" { + if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists { + enumName := naming.ToPascalCase(enumType.TypeName) + sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )", + enumName, enumType.TypeName) + } + } + + parts = append(parts, sqlType) + + // Foreign key + if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil { + fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")", + fk.ForeignTableName, fk.ForeignColumnName) + parts = append(parts, fkDef) + } + + // Primary key + if col.IsPrimaryKey { + parts = append(parts, "primary_key=True") + if col.IsAutoIncrement { + parts = append(parts, "autoincrement=True") + } + } + + // Nullable + if !col.IsNullable && !col.IsPrimaryKey { + parts = append(parts, "nullable=False") + } + + // Unique + // Note: We don't have unique constraint info in our introspection yet + // This would need to be added if needed + + // Format the column definition + result := " Column(" + + // Check if we need multiline formatting (for complex types like Enum) + if strings.Contains(sqlType, "\n") { + // Multiline format + result += parts[0] + ",\n " + parts[1] + for _, part := range parts[2:] { + result += ",\n " + part + } + result += ",\n ),\n" + } else { + // Single line format + result += strings.Join(parts, ", ") + "),\n" + } + + return result +} diff --git a/internal/naming/naming.go b/internal/naming/naming.go new file mode 100644 index 0000000..fa1ab44 --- /dev/null +++ b/internal/naming/naming.go @@ -0,0 +1,227 @@ +package naming + +import ( + "strings" + "unicode" +) + +// ToSnakeCase converts a string to snake_case +func ToSnakeCase(s string) string { + var result []rune + for i, r := range s { + if unicode.IsUpper(r) { + if i > 0 { + result = append(result, '_') + } + result = append(result, unicode.ToLower(r)) + } else { + result = append(result, r) + } + } + return string(result) +} + +// ToPascalCase converts snake_case to PascalCase +func ToPascalCase(s string) string { + parts := strings.Split(s, "_") + for i, part := range parts { + if len(part) > 0 { + parts[i] = strings.ToUpper(part[:1]) + part[1:] + } + } + return strings.Join(parts, "") +} + +// Singularize converts a plural word to singular +// Uses common English pluralization rules +func Singularize(word string) string { + // Handle empty string + if word == "" { + return word + } + + // Common irregular plurals + irregulars := map[string]string{ + "people": "person", + "men": "man", + "women": "woman", + "children": "child", + "teeth": "tooth", + "feet": "foot", + "mice": "mouse", + "geese": "goose", + "oxen": "ox", + "sheep": "sheep", + "fish": "fish", + "deer": "deer", + "series": "series", + "species": "species", + "quizzes": "quiz", + "analyses": "analysis", + "diagnoses": "diagnosis", + "oases": "oasis", + "theses": "thesis", + "crises": "crisis", + "phenomena": "phenomenon", + "criteria": "criterion", + "data": "datum", + } + + lower := strings.ToLower(word) + if singular, ok := irregulars[lower]; ok { + return preserveCase(word, singular) + } + + // Handle words ending in 'ies' -> 'y' + if strings.HasSuffix(lower, "ies") && len(word) > 3 { + return word[:len(word)-3] + "y" + } + + // Handle words ending in 'ves' -> 'fe' or 'f' + if strings.HasSuffix(lower, "ves") && len(word) > 3 { + base := word[:len(word)-3] + // Common words that end in 'fe' + if strings.HasSuffix(strings.ToLower(base), "li") || + strings.HasSuffix(strings.ToLower(base), "wi") || + strings.HasSuffix(strings.ToLower(base), "kni") { + return base + "fe" + } + return base + "f" + } + + // Handle words ending in 'xes', 'ses', 'shes', 'ches' -> remove 'es' + if strings.HasSuffix(lower, "xes") || + strings.HasSuffix(lower, "ses") || + strings.HasSuffix(lower, "shes") || + strings.HasSuffix(lower, "ches") { + if len(word) > 2 { + return word[:len(word)-2] + } + } + + // Handle words ending in 'oes' -> 'o' + if strings.HasSuffix(lower, "oes") && len(word) > 3 { + return word[:len(word)-2] + } + + // Handle simple 's' suffix + if strings.HasSuffix(lower, "s") && len(word) > 1 { + // Don't remove 's' from words that naturally end in 's' + if !strings.HasSuffix(lower, "ss") && + !strings.HasSuffix(lower, "us") && + !strings.HasSuffix(lower, "is") { + return word[:len(word)-1] + } + } + + return word +} + +// Pluralize converts a singular word to plural +func Pluralize(word string) string { + if word == "" { + return word + } + + // Common irregular plurals + irregulars := map[string]string{ + "person": "people", + "man": "men", + "woman": "women", + "child": "children", + "tooth": "teeth", + "foot": "feet", + "mouse": "mice", + "goose": "geese", + "ox": "oxen", + "sheep": "sheep", + "fish": "fish", + "deer": "deer", + "series": "series", + "species": "species", + "quiz": "quizzes", + "analysis": "analyses", + "diagnosis": "diagnoses", + "oasis": "oases", + "thesis": "theses", + "crisis": "crises", + "phenomenon": "phenomena", + "criterion": "criteria", + "datum": "data", + } + + lower := strings.ToLower(word) + if plural, ok := irregulars[lower]; ok { + return preserveCase(word, plural) + } + + // Already plural (ends in 's' and not special case) + if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") { + return word + } + + // Handle words ending in 'y' -> 'ies' + if strings.HasSuffix(lower, "y") && len(word) > 1 { + prevChar := rune(lower[len(lower)-2]) + if !isVowel(prevChar) { + return word[:len(word)-1] + "ies" + } + } + + // Handle words ending in 'f' or 'fe' -> 'ves' + if strings.HasSuffix(lower, "f") { + return word[:len(word)-1] + "ves" + } + if strings.HasSuffix(lower, "fe") { + return word[:len(word)-2] + "ves" + } + + // Handle words ending in 'o' -> 'oes' + if strings.HasSuffix(lower, "o") && len(word) > 1 { + prevChar := rune(lower[len(lower)-2]) + if !isVowel(prevChar) { + return word + "es" + } + } + + // Handle words ending in 'x', 's', 'sh', 'ch' -> add 'es' + if strings.HasSuffix(lower, "x") || + strings.HasSuffix(lower, "s") || + strings.HasSuffix(lower, "sh") || + strings.HasSuffix(lower, "ch") { + return word + "es" + } + + // Default: just add 's' + return word + "s" +} + +// preserveCase preserves the case pattern of the original word +func preserveCase(original, replacement string) string { + if len(original) == 0 { + return replacement + } + + if unicode.IsUpper(rune(original[0])) { + return strings.ToUpper(replacement[:1]) + replacement[1:] + } + return replacement +} + +// isVowel checks if a rune is a vowel +func isVowel(r rune) bool { + vowels := "aeiouAEIOU" + return strings.ContainsRune(vowels, r) +} + +// SingularizeTableName converts a table name to its singular form +// Handles snake_case table names +func SingularizeTableName(tableName string) string { + parts := strings.Split(tableName, "_") + if len(parts) > 0 { + // Only singularize the last part + lastIdx := len(parts) - 1 + parts[lastIdx] = Singularize(parts[lastIdx]) + } + return strings.Join(parts, "_") +} diff --git a/internal/prompt/prompt.go b/internal/prompt/prompt.go new file mode 100644 index 0000000..a71db41 --- /dev/null +++ b/internal/prompt/prompt.go @@ -0,0 +1,186 @@ +package prompt + +import ( + "bufio" + "fmt" + "os" + "strconv" + "strings" + + "github.com/fatih/color" +) + +var ( + cyan = color.New(color.FgCyan).SprintFunc() + green = color.New(color.FgGreen).SprintFunc() + yellow = color.New(color.FgYellow).SprintFunc() + red = color.New(color.FgRed).SprintFunc() +) + +// Reader interface for testing +type Reader interface { + ReadString(delim byte) (string, error) +} + +// PromptString prompts for a string value +func PromptString(label string, defaultValue string, required bool) (string, error) { + reader := bufio.NewReader(os.Stdin) + return promptStringWithReader(reader, label, defaultValue, required) +} + +func promptStringWithReader(reader Reader, label string, defaultValue string, required bool) (string, error) { + for { + prompt := fmt.Sprintf("%s", cyan(label)) + if defaultValue != "" { + prompt += fmt.Sprintf(" [%s]", green(defaultValue)) + } + prompt += ": " + + fmt.Print(prompt) + + input, err := reader.ReadString('\n') + if err != nil { + return "", err + } + + input = strings.TrimSpace(input) + + // If input is empty, use default value + if input == "" { + if defaultValue != "" { + // Echo the selected default value + fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultValue), defaultValue) + fmt.Printf("\033[1A\r%s\n", fullPrompt) + return defaultValue, nil + } + if required { + fmt.Println(red("✗ This field is required")) + continue + } + return "", nil + } + + return input, nil + } +} + +// PromptInt prompts for an integer value +func PromptInt(label string, defaultValue int, required bool) (int, error) { + reader := bufio.NewReader(os.Stdin) + return promptIntWithReader(reader, label, defaultValue, required) +} + +func promptIntWithReader(reader Reader, label string, defaultValue int, required bool) (int, error) { + for { + defaultStr := "" + if defaultValue != 0 { + defaultStr = strconv.Itoa(defaultValue) + } + + prompt := fmt.Sprintf("%s", cyan(label)) + if defaultStr != "" { + prompt += fmt.Sprintf(" [%s]", green(defaultStr)) + } + prompt += ": " + + fmt.Print(prompt) + + input, err := reader.ReadString('\n') + if err != nil { + return 0, err + } + + input = strings.TrimSpace(input) + + // If input is empty, use default value + if input == "" { + if defaultValue != 0 { + // Echo the selected default value + fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultStr), defaultStr) + fmt.Printf("\033[1A\r%s\n", fullPrompt) + return defaultValue, nil + } + if required { + fmt.Println(red("✗ This field is required")) + continue + } + return 0, nil + } + + // Parse integer + value, err := strconv.Atoi(input) + if err != nil { + fmt.Println(red("✗ Please enter a valid number")) + continue + } + + return value, nil + } +} + +// ValidateDirectory checks if a directory exists or can be created +func ValidateDirectory(path string) error { + // Check if path exists + info, err := os.Stat(path) + if err == nil { + // Path exists, check if it's a directory + if !info.IsDir() { + return fmt.Errorf("path exists but is not a directory") + } + return nil + } + + // Path doesn't exist, try to create it + if os.IsNotExist(err) { + if err := os.MkdirAll(path, 0755); err != nil { + return fmt.Errorf("cannot create directory: %w", err) + } + return nil + } + + return err +} + +// PromptDirectory prompts for a directory path and validates it +func PromptDirectory(label string, defaultValue string, required bool) (string, error) { + for { + path, err := PromptString(label, defaultValue, required) + if err != nil { + return "", err + } + + if path == "" && !required { + return "", nil + } + + // Validate directory + if err := ValidateDirectory(path); err != nil { + fmt.Printf("%s %s\n", red("✗"), err.Error()) + continue + } + + return path, nil + } +} + +// PrintHeader prints a colored header +func PrintHeader(text string) { + header := color.New(color.FgCyan, color.Bold) + header.Println("\n" + text) + header.Println(strings.Repeat("=", len(text))) +} + +// PrintSuccess prints a success message +func PrintSuccess(text string) { + fmt.Printf("%s %s\n", green("✓"), text) +} + +// PrintError prints an error message +func PrintError(text string) { + fmt.Printf("%s %s\n", red("✗"), text) +} + +// PrintInfo prints an info message +func PrintInfo(text string) { + fmt.Printf("%s %s\n", yellow("ℹ"), text) +}