First working version

This commit is contained in:
Eden Kirin
2025-10-31 14:36:41 +01:00
parent f8edfc0fc1
commit f9f67b6c93
22 changed files with 2277 additions and 0 deletions

4
.gitignore vendored
View File

@ -1 +1,5 @@
/build
/output /output
*.toml
/entity-maker

10
Makefile Normal file
View File

@ -0,0 +1,10 @@
EXEC=entity-maker
.PHONY: build
build:
@go build -ldflags "-s -w" -o ./build/${EXEC} ./cmd/entity-maker/main.go
upgrade-packages:
@go get -u ./...

136
README.md Normal file
View File

@ -0,0 +1,136 @@
# Entity Maker
A command-line tool for generating SQLAlchemy entities from PostgreSQL database tables.
## Features
- Connects to PostgreSQL database and introspects table structure
- Generates Python 3.13 code using SQLAlchemy
- Creates complete entity scaffolding including:
- SQLAlchemy table definition
- Dataclass model
- Filter class
- Load options
- Repository
- Manager
- Factory (for testing)
- Mapper
- Enum types (if table uses PostgreSQL enums)
## Installation
### Build from source
```bash
go build -o entity-maker ./cmd/entity-maker
```
## Usage
### Interactive Mode
Run the application without arguments to use interactive prompts:
```bash
./entity-maker
```
The application will prompt you for:
- Database host (default: localhost)
- Database port (default: 5432)
- Database name (required)
- Database schema (default: public)
- Database user (default: postgres)
- Database password (default: postgres)
- Table name (required)
- Output directory (required)
- Entity name override (optional)
### Command-Line Arguments
You can also provide parameters via command-line flags:
```bash
./entity-maker \
-host localhost \
-port 5432 \
-db mydb \
-schema public \
-user postgres \
-password secret \
-table users \
-output ./output \
-entity User
```
### Available Flags
- `-host` - Database host
- `-port` - Database port
- `-db` - Database name
- `-schema` - Database schema
- `-user` - Database user
- `-password` - Database password
- `-table` - Table name to generate entities from
- `-output` - Output directory path
- `-entity` - Entity name override (optional, defaults to singularized table name)
## Configuration
The application saves your settings to `~/.config/entity-maker.toml` for future use. When you run the application again, it will use these saved values as defaults.
## Generated Files
For a table named `users`, the following files will be generated in `output/user/`:
- `table.py` - SQLAlchemy Table definition
- `model.py` - Dataclass model with type hints
- `filter.py` - Filter class for queries
- `load_options.py` - Relationship loading options
- `repository.py` - CRUD repository
- `manager.py` - Manager class
- `factory.py` - Factory for testing (uses factory_boy)
- `mapper.py` - SQLAlchemy mapper configuration
- `enum.py` - Enum types (only if table uses PostgreSQL enums)
- `__init__.py` - Package initialization
## Example
Generate entities for the `cashbag_conforms` table:
```bash
./entity-maker -db mydb -table cashbag_conforms -output ./output
```
This will create files in `./output/cashbag_conform/`:
- Entity name: `CashbagConform`
- Module name: `cashbag_conform`
- All supporting files
## Requirements
- PostgreSQL database
- Go 1.21+ (for building)
- Python 3.13+ (for generated code)
- SQLAlchemy (for generated code)
## Project Structure
```
entity-maker/
├── cmd/
│ └── entity-maker/ # Main application
├── internal/
│ ├── config/ # Configuration management
│ ├── database/ # Database connection and introspection
│ ├── generator/ # Code generators
│ ├── naming/ # Naming utilities (pluralization, case conversion)
│ └── prompt/ # Interactive CLI prompts
├── example/ # Example output files
├── CLAUDE.md # Detailed specification
└── README.md # This file
```
## License
See LICENSE file for details.

240
cmd/entity-maker/main.go Normal file
View File

@ -0,0 +1,240 @@
package main
import (
"flag"
"fmt"
"os"
"path/filepath"
"github.com/entity-maker/entity-maker/internal/config"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/generator"
"github.com/entity-maker/entity-maker/internal/naming"
"github.com/entity-maker/entity-maker/internal/prompt"
"github.com/fatih/color"
)
func main() {
if err := run(); err != nil {
prompt.PrintError(err.Error())
os.Exit(1)
}
}
func run() error {
// Parse command line flags
var (
dbHost = flag.String("host", "", "Database host")
dbPort = flag.Int("port", 0, "Database port")
dbName = flag.String("db", "", "Database name")
dbSchema = flag.String("schema", "", "Database schema")
dbUser = flag.String("user", "", "Database user")
dbPassword = flag.String("password", "", "Database password")
dbTable = flag.String("table", "", "Database table name")
outputDir = flag.String("output", "", "Output directory")
entityName = flag.String("entity", "", "Entity name override")
)
flag.Parse()
// Print header
header := color.New(color.FgCyan, color.Bold)
header.Println("\n╔════════════════════════════════════════╗")
header.Println("║ Entity Maker for TelevendCore ║")
header.Println("╚════════════════════════════════════════╝")
// Load configuration
cfg, err := config.Load()
if err != nil {
return fmt.Errorf("failed to load config: %w", err)
}
// Prompt for parameters
prompt.PrintHeader("Database Connection")
// Override config with command line flags if provided
if *dbHost != "" {
cfg.DBHost = *dbHost
}
if *dbPort != 0 {
cfg.DBPort = *dbPort
}
if *dbName != "" {
cfg.DBName = *dbName
}
if *dbSchema != "" {
cfg.DBSchema = *dbSchema
}
if *dbUser != "" {
cfg.DBUser = *dbUser
}
if *dbPassword != "" {
cfg.DBPassword = *dbPassword
}
if *dbTable != "" {
cfg.DBTable = *dbTable
}
if *outputDir != "" {
cfg.OutputDir = *outputDir
}
if *entityName != "" {
cfg.EntityName = *entityName
}
// Prompt for missing parameters
cfg.DBHost, err = prompt.PromptString("Database host", cfg.DBHost, false)
if err != nil {
return err
}
cfg.DBPort, err = prompt.PromptInt("Database port", cfg.DBPort, false)
if err != nil {
return err
}
cfg.DBName, err = prompt.PromptString("Database name", cfg.DBName, true)
if err != nil {
return err
}
cfg.DBSchema, err = prompt.PromptString("Database schema", cfg.DBSchema, false)
if err != nil {
return err
}
cfg.DBUser, err = prompt.PromptString("Database user", cfg.DBUser, false)
if err != nil {
return err
}
cfg.DBPassword, err = prompt.PromptString("Database password", cfg.DBPassword, false)
if err != nil {
return err
}
cfg.DBTable, err = prompt.PromptString("Table name", cfg.DBTable, true)
if err != nil {
return err
}
cfg.OutputDir, err = prompt.PromptDirectory("Output directory", cfg.OutputDir, true)
if err != nil {
return err
}
cfg.EntityName, err = prompt.PromptString("Entity name (optional)", cfg.EntityName, false)
if err != nil {
return err
}
// Save configuration
if err := cfg.Save(); err != nil {
prompt.PrintError(fmt.Sprintf("Failed to save config: %v", err))
// Don't fail, just warn
} else {
prompt.PrintSuccess("Configuration saved")
}
// Connect to database
prompt.PrintHeader("Connecting to Database")
dbClient, err := database.NewClient(database.Config{
Host: cfg.DBHost,
Port: cfg.DBPort,
Database: cfg.DBName,
Schema: cfg.DBSchema,
User: cfg.DBUser,
Password: cfg.DBPassword,
})
if err != nil {
return fmt.Errorf("database connection failed: %w", err)
}
defer dbClient.Close()
prompt.PrintSuccess(fmt.Sprintf("Connected to %s@%s:%d/%s", cfg.DBUser, cfg.DBHost, cfg.DBPort, cfg.DBName))
// Check if table exists
exists, err := dbClient.TableExists(cfg.DBTable)
if err != nil {
return fmt.Errorf("failed to check table existence: %w", err)
}
if !exists {
return fmt.Errorf("table '%s' does not exist in schema '%s'", cfg.DBTable, cfg.DBSchema)
}
prompt.PrintSuccess(fmt.Sprintf("Found table '%s'", cfg.DBTable))
// Introspect table
prompt.PrintHeader("Introspecting Table")
tableInfo, err := dbClient.IntrospectTable(cfg.DBTable)
if err != nil {
return fmt.Errorf("failed to introspect table: %w", err)
}
prompt.PrintSuccess(fmt.Sprintf("Found %d columns", len(tableInfo.Columns)))
prompt.PrintSuccess(fmt.Sprintf("Found %d foreign keys", len(tableInfo.ForeignKeys)))
prompt.PrintSuccess(fmt.Sprintf("Found %d enum types", len(tableInfo.EnumTypes)))
// Create generation context
ctx := generator.NewContext(tableInfo, cfg.EntityName)
// Create output directory
moduleName := naming.SingularizeTableName(cfg.DBTable)
moduleDir := filepath.Join(cfg.OutputDir, moduleName)
if err := os.MkdirAll(moduleDir, 0755); err != nil {
return fmt.Errorf("failed to create output directory: %w", err)
}
prompt.PrintSuccess(fmt.Sprintf("Created directory: %s", moduleDir))
// Generate files
prompt.PrintHeader("Generating Files")
files := map[string]func(*generator.Context) (string, error){
"table.py": generator.GenerateTable,
"model.py": generator.GenerateModel,
"filter.py": generator.GenerateFilter,
"load_options.py": generator.GenerateLoadOptions,
"repository.py": generator.GenerateRepository,
"manager.py": generator.GenerateManager,
"factory.py": generator.GenerateFactory,
"mapper.py": generator.GenerateMapper,
"__init__.py": generator.GenerateInit,
}
// Add enum.py if there are enum types
if len(tableInfo.EnumTypes) > 0 {
files["enum.py"] = generator.GenerateEnum
}
// Generate and write each file
for filename, genFunc := range files {
content, err := genFunc(ctx)
if err != nil {
return fmt.Errorf("failed to generate %s: %w", filename, err)
}
filePath := filepath.Join(moduleDir, filename)
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
return fmt.Errorf("failed to write %s: %w", filename, err)
}
prompt.PrintSuccess(fmt.Sprintf("Generated %s", filename))
}
// Print summary
prompt.PrintHeader("Summary")
fmt.Printf("\n")
fmt.Printf(" Entity name: %s\n", color.GreenString(ctx.EntityName))
fmt.Printf(" Module name: %s\n", color.GreenString(ctx.ModuleName))
fmt.Printf(" Output dir: %s\n", color.GreenString(moduleDir))
fmt.Printf(" Files: %s\n", color.GreenString("%d files generated", len(files)))
fmt.Printf("\n")
success := color.New(color.FgGreen, color.Bold)
success.Println("✓ Code generation completed successfully!")
fmt.Println()
return nil
}

15
go.mod Normal file
View File

@ -0,0 +1,15 @@
module github.com/entity-maker/entity-maker
go 1.25.3
require (
github.com/BurntSushi/toml v1.5.0
github.com/fatih/color v1.18.0
github.com/lib/pq v1.10.9
)
require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
golang.org/x/sys v0.25.0 // indirect
)

15
go.sum Normal file
View File

@ -0,0 +1,15 @@
github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg=
github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.25.0 h1:r+8e+loiHxRqhXVl6ML1nO3l1+oFoWbnlu2Ehimmi34=
golang.org/x/sys v0.25.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=

82
internal/config/config.go Normal file
View File

@ -0,0 +1,82 @@
package config
import (
"os"
"path/filepath"
"github.com/BurntSushi/toml"
)
// Config represents the application configuration
type Config struct {
DBHost string `toml:"db_host"`
DBPort int `toml:"db_port"`
DBName string `toml:"db_name"`
DBSchema string `toml:"db_schema"`
DBUser string `toml:"db_user"`
DBPassword string `toml:"db_password"`
DBTable string `toml:"db_table"`
OutputDir string `toml:"output_dir"`
EntityName string `toml:"entity_name"`
}
// DefaultConfig returns a new config with default values
func DefaultConfig() *Config {
return &Config{
DBHost: "localhost",
DBPort: 5432,
DBSchema: "public",
DBUser: "postgres",
DBPassword: "postgres",
}
}
// Load reads the configuration file from ~/.config/entity-maker.toml
func Load() (*Config, error) {
configPath := getConfigPath()
config := DefaultConfig()
// Check if config file exists
if _, err := os.Stat(configPath); os.IsNotExist(err) {
return config, nil
}
// Read and decode the config file
if _, err := toml.DecodeFile(configPath, config); err != nil {
return nil, err
}
return config, nil
}
// Save writes the configuration to ~/.config/entity-maker.toml
func (c *Config) Save() error {
configPath := getConfigPath()
// Create config directory if it doesn't exist
configDir := filepath.Dir(configPath)
if err := os.MkdirAll(configDir, 0755); err != nil {
return err
}
// Create or truncate the config file
f, err := os.Create(configPath)
if err != nil {
return err
}
defer f.Close()
// Encode and write the config
encoder := toml.NewEncoder(f)
return encoder.Encode(c)
}
// getConfigPath returns the full path to the config file
func getConfigPath() string {
homeDir, err := os.UserHomeDir()
if err != nil {
homeDir = "."
}
return filepath.Join(homeDir, ".config", "entity-maker.toml")
}

View File

@ -0,0 +1,67 @@
package database
import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
)
// Client represents a database connection
type Client struct {
db *sql.DB
schema string
}
// Config holds database connection parameters
type Config struct {
Host string
Port int
Database string
Schema string
User string
Password string
}
// NewClient creates a new database client
func NewClient(cfg Config) (*Client, error) {
connStr := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database,
)
db, err := sql.Open("postgres", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
return &Client{
db: db,
schema: cfg.Schema,
}, nil
}
// Close closes the database connection
func (c *Client) Close() error {
return c.db.Close()
}
// TableExists checks if a table exists in the schema
func (c *Client) TableExists(tableName string) (bool, error) {
query := `
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = $1 AND table_name = $2
)
`
var exists bool
err := c.db.QueryRow(query, c.schema, tableName).Scan(&exists)
return exists, err
}

View File

@ -0,0 +1,341 @@
package database
import (
"database/sql"
"fmt"
"strings"
)
// Column represents a database column
type Column struct {
Name string
DataType string
IsNullable bool
ColumnDefault sql.NullString
CharMaxLength sql.NullInt64
NumericPrecision sql.NullInt64
NumericScale sql.NullInt64
UdtName string // User-defined type name (for enums)
IsPrimaryKey bool
IsAutoIncrement bool
}
// ForeignKey represents a foreign key constraint
type ForeignKey struct {
ColumnName string
ForeignTableSchema string
ForeignTableName string
ForeignColumnName string
ConstraintName string
}
// EnumType represents a PostgreSQL enum type
type EnumType struct {
TypeName string
Values []string
}
// TableInfo contains all information about a table
type TableInfo struct {
Schema string
TableName string
Columns []Column
ForeignKeys []ForeignKey
EnumTypes map[string]EnumType
}
// IntrospectTable retrieves comprehensive information about a table
func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) {
info := &TableInfo{
Schema: c.schema,
TableName: tableName,
EnumTypes: make(map[string]EnumType),
}
// Get columns
columns, err := c.getColumns(tableName)
if err != nil {
return nil, err
}
info.Columns = columns
// Get primary keys and mark columns
primaryKeys, err := c.getPrimaryKeys(tableName)
if err != nil {
return nil, err
}
for i := range info.Columns {
for _, pk := range primaryKeys {
if info.Columns[i].Name == pk {
info.Columns[i].IsPrimaryKey = true
}
}
}
// Get foreign keys
foreignKeys, err := c.getForeignKeys(tableName)
if err != nil {
return nil, err
}
info.ForeignKeys = foreignKeys
// Get enum types for columns that use them
for _, col := range info.Columns {
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
if _, exists := info.EnumTypes[col.UdtName]; !exists {
enumType, err := c.getEnumType(col.UdtName)
if err != nil {
return nil, err
}
info.EnumTypes[col.UdtName] = enumType
}
}
}
return info, nil
}
// getColumns retrieves column information for a table
func (c *Client) getColumns(tableName string) ([]Column, error) {
query := `
SELECT
column_name,
data_type,
is_nullable = 'YES' as is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale,
udt_name
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query columns: %w", err)
}
defer rows.Close()
var columns []Column
for rows.Next() {
var col Column
err := rows.Scan(
&col.Name,
&col.DataType,
&col.IsNullable,
&col.ColumnDefault,
&col.CharMaxLength,
&col.NumericPrecision,
&col.NumericScale,
&col.UdtName,
)
if err != nil {
return nil, err
}
// Check if column is auto-increment
if col.ColumnDefault.Valid {
defaultVal := strings.ToLower(col.ColumnDefault.String)
col.IsAutoIncrement = strings.Contains(defaultVal, "nextval")
}
columns = append(columns, col)
}
return columns, rows.Err()
}
// getPrimaryKeys retrieves primary key column names for a table
func (c *Client) getPrimaryKeys(tableName string) ([]string, error) {
query := `
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = ($1 || '.' || $2)::regclass
AND i.indisprimary
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query primary keys: %w", err)
}
defer rows.Close()
var keys []string
for rows.Next() {
var key string
if err := rows.Scan(&key); err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, rows.Err()
}
// getForeignKeys retrieves foreign key information for a table
func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) {
query := `
SELECT
kcu.column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name,
tc.constraint_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = $1
AND tc.table_name = $2
ORDER BY kcu.ordinal_position
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query foreign keys: %w", err)
}
defer rows.Close()
var fks []ForeignKey
for rows.Next() {
var fk ForeignKey
err := rows.Scan(
&fk.ColumnName,
&fk.ForeignTableSchema,
&fk.ForeignTableName,
&fk.ForeignColumnName,
&fk.ConstraintName,
)
if err != nil {
return nil, err
}
fks = append(fks, fk)
}
return fks, rows.Err()
}
// getEnumType retrieves enum values for a PostgreSQL enum type
func (c *Client) getEnumType(typeName string) (EnumType, error) {
query := `
SELECT e.enumlabel
FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
WHERE t.typname = $1
ORDER BY e.enumsortorder
`
rows, err := c.db.Query(query, typeName)
if err != nil {
return EnumType{}, fmt.Errorf("failed to query enum type: %w", err)
}
defer rows.Close()
var values []string
for rows.Next() {
var value string
if err := rows.Scan(&value); err != nil {
return EnumType{}, err
}
values = append(values, value)
}
return EnumType{
TypeName: typeName,
Values: values,
}, rows.Err()
}
// GetPythonType converts PostgreSQL data type to Python type
func GetPythonType(col Column) string {
switch col.DataType {
case "integer", "smallint", "bigint":
return "int"
case "numeric", "decimal", "real", "double precision":
return "Decimal"
case "boolean":
return "bool"
case "character varying", "varchar", "text", "char", "character":
return "str"
case "timestamp with time zone", "timestamp without time zone", "timestamp":
return "datetime"
case "date":
return "date"
case "time with time zone", "time without time zone", "time":
return "time"
case "json", "jsonb":
return "dict"
case "uuid":
return "UUID"
case "bytea":
return "bytes"
case "USER-DEFINED":
// This is likely an enum
return "str" // Will be replaced with specific enum type in generator
default:
return "Any"
}
}
// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type
func GetSQLAlchemyType(col Column) string {
switch col.DataType {
case "integer":
return "Integer"
case "smallint":
return "SmallInteger"
case "bigint":
return "BigInteger"
case "numeric", "decimal":
if col.NumericPrecision.Valid && col.NumericScale.Valid {
return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64)
}
return "Numeric"
case "real":
return "Float"
case "double precision":
return "Float"
case "boolean":
return "Boolean"
case "character varying", "varchar":
if col.CharMaxLength.Valid {
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
}
return "String"
case "char", "character":
if col.CharMaxLength.Valid {
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
}
return "String(1)"
case "text":
return "Text"
case "timestamp with time zone":
return "DateTime(timezone=True)"
case "timestamp without time zone", "timestamp":
return "DateTime"
case "date":
return "Date"
case "time with time zone", "time without time zone", "time":
return "Time"
case "json":
return "JSON"
case "jsonb":
return "JSONB"
case "uuid":
return "UUID"
case "bytea":
return "LargeBinary"
case "USER-DEFINED":
// Will be handled separately for enums
return "Enum"
default:
return "String"
}
}

View File

@ -0,0 +1,40 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/naming"
)
// GenerateEnum generates the enum types file
func GenerateEnum(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from enum import StrEnum\n\n")
b.WriteString("from televend_core.databases.enum import EnumMixin\n\n\n")
// Generate each enum type
for _, enumType := range ctx.TableInfo.EnumTypes {
enumName := naming.ToPascalCase(enumType.TypeName)
b.WriteString(fmt.Sprintf("class %s(EnumMixin, StrEnum):\n", enumName))
if len(enumType.Values) == 0 {
b.WriteString(" pass\n")
} else {
for _, value := range enumType.Values {
// Convert value to valid Python identifier
// Usually enum values are already uppercase like "OPEN", "IN_PROGRESS"
identifier := strings.ToUpper(strings.ReplaceAll(value, " ", "_"))
identifier = strings.ReplaceAll(identifier, "-", "_")
b.WriteString(fmt.Sprintf(" %s = \"%s\"\n", identifier, value))
}
}
b.WriteString("\n")
}
return b.String(), nil
}

View File

@ -0,0 +1,203 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/naming"
)
// GenerateFactory generates the factory class
func GenerateFactory(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from __future__ import annotations\n\n")
b.WriteString("from typing import Type\n\n")
b.WriteString("import factory\n\n")
// Import factories for related models
fkImports := make(map[string]string) // module_name -> entity_name
for _, fk := range ctx.TableInfo.ForeignKeys {
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
fkImports[moduleName] = entityName
}
for moduleName, entityName := range fkImports {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
moduleName))
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
b.WriteString(")\n")
}
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
b.WriteString(" CustomSelfAttribute,\n")
b.WriteString(" TelevendBaseFactory,\n")
b.WriteString(")\n\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
// Add boolean fields with defaults
for _, col := range ctx.TableInfo.Columns {
if col.DataType == "boolean" {
defaultValue := "True"
if col.Name == "alive" {
defaultValue = "True"
} else {
defaultValue = "False"
}
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue))
}
}
// Add id field
for _, col := range ctx.TableInfo.Columns {
if col.IsPrimaryKey {
b.WriteString(fmt.Sprintf(" %s = None\n", col.Name))
}
}
b.WriteString("\n")
// Generate faker fields for each column
for _, col := range ctx.TableInfo.Columns {
if col.IsPrimaryKey || col.DataType == "boolean" {
continue
}
// Skip foreign keys, we'll handle them separately
if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil {
continue
}
fakerDef := generateFakerField(col, ctx)
if fakerDef != "" {
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef))
}
}
// Generate foreign key relationships
if len(ctx.TableInfo.ForeignKeys) > 0 {
b.WriteString("\n")
for _, fk := range ctx.TableInfo.ForeignKeys {
relationName := GetRelationshipName(fk.ColumnName)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n",
relationName, relationName, entityName))
b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n",
fk.ColumnName, relationName, relationName))
b.WriteString("\n")
}
}
// Meta class
b.WriteString(" class Meta:\n")
b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName))
b.WriteString("\n")
// create_minimal method
b.WriteString(" @classmethod\n")
b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n",
ctx.EntityName, ctx.EntityName))
b.WriteString(" minimal_params = {\n")
// Add foreign key params
for _, fk := range ctx.TableInfo.ForeignKeys {
relationName := GetRelationshipName(fk.ColumnName)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
// Check if this FK is required
var col *database.Column
for i := range ctx.TableInfo.Columns {
if ctx.TableInfo.Columns[i].Name == fk.ColumnName {
col = &ctx.TableInfo.Columns[i]
break
}
}
if col != nil && !col.IsNullable {
// Required FK
b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n",
relationName, relationName, entityName))
} else {
// Optional FK
b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName))
}
}
b.WriteString(" }\n")
b.WriteString(" minimal_params.update(kwargs)\n")
b.WriteString(" return cls.create(**minimal_params)\n")
return b.String(), nil
}
func generateFakerField(col database.Column, ctx *Context) string {
pythonType := database.GetPythonType(col)
switch pythonType {
case "Decimal":
precision := 7
scale := 4
if col.NumericPrecision.Valid {
precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64)
}
if col.NumericScale.Valid {
scale = int(col.NumericScale.Int64)
}
return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)",
precision, scale)
case "int":
if col.DataType == "bigint" {
return "factory.Faker(\"pyint\")"
}
return "factory.Faker(\"pyint\")"
case "str":
// Check if it's an enum
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
if len(enumType.Values) > 0 {
enumName := naming.ToPascalCase(enumType.TypeName)
return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName)
}
}
}
maxLen := 255
if col.CharMaxLength.Valid {
maxLen = int(col.CharMaxLength.Int64)
}
if col.DataType == "text" {
return "factory.Faker(\"text\", max_nb_chars=500)"
}
return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen)
case "datetime":
return "factory.Faker(\"date_time\")"
case "date":
return "factory.Faker(\"date\")"
case "time":
return "factory.Faker(\"time\")"
case "bool":
return "factory.Faker(\"boolean\")"
case "dict":
return "factory.Faker(\"pydict\")"
default:
return "None"
}
}

View File

@ -0,0 +1,77 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/database"
)
// GenerateFilter generates the filter class
func GenerateFilter(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sFilter(BaseFilter):\n", ctx.EntityName))
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
// Generate filters based on rules:
// - Boolean fields: EQ operator
// - ID fields: both EQ and IN operators
// - Text fields: EQ operator
hasFilters := false
for _, col := range ctx.TableInfo.Columns {
if !ShouldGenerateFilter(col) {
continue
}
hasFilters = true
// Boolean fields
if col.DataType == "boolean" {
defaultVal := "None"
if col.Name == "alive" {
defaultVal = "True"
}
b.WriteString(fmt.Sprintf(" %s: bool | None = filterfield(operator=EQ, default=%s)\n",
col.Name, defaultVal))
}
// ID fields (both EQ and IN)
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
// Single ID with EQ
pythonType := database.GetPythonType(col)
b.WriteString(fmt.Sprintf(" %s: %s | None = filterfield(operator=EQ)\n",
col.Name, pythonType))
// Multiple IDs with IN (plural field name)
pluralFieldName := GetFilterFieldName(col.Name, true)
b.WriteString(fmt.Sprintf(" %s: list[%s] | None = filterfield(field=\"%s\", operator=IN)\n",
pluralFieldName, pythonType, col.Name))
}
// Text fields
if (col.DataType == "character varying" || col.DataType == "varchar" ||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
col.Name))
}
}
// If no filters were generated, add a pass statement
if !hasFilters {
b.WriteString(" pass\n")
}
return b.String(), nil
}

View File

@ -0,0 +1,195 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/naming"
)
// Context contains all information needed for code generation
type Context struct {
TableInfo *database.TableInfo
EntityName string // Singular, PascalCase (e.g., "CashbagConform")
ModuleName string // Singular, snake_case (e.g., "cashbag_conform")
TableConstant string // Uppercase with TABLE suffix (e.g., "CASHBAG_CONFORM_TABLE")
}
// NewContext creates a new generation context
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
moduleName := naming.SingularizeTableName(tableInfo.TableName)
entityName := entityNameOverride
if entityName == "" {
entityName = naming.ToPascalCase(moduleName)
}
tableConstant := strings.ToUpper(moduleName) + "_TABLE"
return &Context{
TableInfo: tableInfo,
EntityName: entityName,
ModuleName: moduleName,
TableConstant: tableConstant,
}
}
// GetRelationshipName returns the relationship name for a foreign key
// Strips _id suffix and converts to snake_case
func GetRelationshipName(fkColumnName string) string {
name := fkColumnName
if strings.HasSuffix(name, "_id") {
name = name[:len(name)-3]
}
return name
}
// GetRelationshipEntityName returns the entity name for a foreign key's target table
func GetRelationshipEntityName(tableName string) string {
singular := naming.SingularizeTableName(tableName)
return naming.ToPascalCase(singular)
}
// GetRelationshipModuleName returns the module name for a foreign key's target table
func GetRelationshipModuleName(tableName string) string {
return naming.SingularizeTableName(tableName)
}
// GetFilterFieldName returns the filter field name for a column
// For ID fields with IN operator, pluralizes the name
func GetFilterFieldName(columnName string, useIN bool) string {
if useIN && strings.HasSuffix(columnName, "_id") {
// Remove _id, pluralize, add back _ids
base := columnName[:len(columnName)-3]
if base == "" {
return "ids"
}
return naming.Pluralize(base) + "_ids"
}
if useIN && columnName == "id" {
return "ids"
}
return columnName
}
// ShouldGenerateFilter determines if a column should have a filter
func ShouldGenerateFilter(col database.Column) bool {
// Generate filters for boolean, ID, and text fields
if col.DataType == "boolean" {
return true
}
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
return true
}
if col.DataType == "character varying" || col.DataType == "varchar" ||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character" {
return true
}
return false
}
// GetPythonTypeForColumn returns the Python type annotation for a column
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
baseType := database.GetPythonType(col)
// Handle enum types
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
baseType = naming.ToPascalCase(enumType.TypeName)
}
}
// Handle Decimal type
if baseType == "Decimal" {
baseType = "Decimal"
}
return baseType
}
// NeedsDecimalImport checks if any column uses Decimal type
func NeedsDecimalImport(columns []database.Column) bool {
for _, col := range columns {
if database.GetPythonType(col) == "Decimal" {
return true
}
}
return false
}
// NeedsDatetimeImport checks if any column uses datetime type
func NeedsDatetimeImport(columns []database.Column) bool {
for _, col := range columns {
pyType := database.GetPythonType(col)
if pyType == "datetime" || pyType == "date" || pyType == "time" {
return true
}
}
return false
}
// GetRequiredColumns returns columns that are not nullable and don't have defaults
func GetRequiredColumns(columns []database.Column) []database.Column {
var required []database.Column
for _, col := range columns {
if !col.IsNullable && !col.ColumnDefault.Valid && !col.IsAutoIncrement && !col.IsPrimaryKey {
required = append(required, col)
}
}
return required
}
// GetOptionalColumns returns columns that are nullable or have defaults
func GetOptionalColumns(columns []database.Column) []database.Column {
var optional []database.Column
for _, col := range columns {
if col.IsNullable || col.ColumnDefault.Valid {
optional = append(optional, col)
}
}
return optional
}
// GetForeignKeyForColumn returns the foreign key info for a column, if it exists
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
for _, fk := range foreignKeys {
if fk.ColumnName == columnName {
return &fk
}
}
return nil
}
// GenerateFiles generates all Python files for the entity
func GenerateFiles(ctx *Context, outputDir string) error {
generators := map[string]func(*Context) (string, error){
"table.py": GenerateTable,
"model.py": GenerateModel,
"filter.py": GenerateFilter,
"load_options.py": GenerateLoadOptions,
"repository.py": GenerateRepository,
"manager.py": GenerateManager,
"factory.py": GenerateFactory,
"mapper.py": GenerateMapper,
"__init__.py": GenerateInit,
}
// Generate enum.py if there are enum types
if len(ctx.TableInfo.EnumTypes) > 0 {
generators["enum.py"] = GenerateEnum
}
for filename, generator := range generators {
content, err := generator(ctx)
if err != nil {
return fmt.Errorf("failed to generate %s: %w", filename, err)
}
// Write to file (this will be handled by the main function)
// For now, just return the content
_ = content
}
return nil
}

View File

@ -0,0 +1,7 @@
package generator
// GenerateInit generates an empty __init__.py file
func GenerateInit(ctx *Context) (string, error) {
// Empty __init__.py file
return "", nil
}

View File

@ -0,0 +1,38 @@
package generator
import (
"fmt"
"strings"
)
// GenerateLoadOptions generates the load options class
func GenerateLoadOptions(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sLoadOptions(LoadOptions):\n", ctx.EntityName))
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
// Generate load options for all foreign key relationships
hasRelationships := false
for _, fk := range ctx.TableInfo.ForeignKeys {
hasRelationships = true
relationName := GetRelationshipName(fk.ColumnName)
b.WriteString(fmt.Sprintf(" load_%s: bool = joinload(relations=[\"%s\"])\n",
relationName, relationName))
}
// If no relationships, add pass
if !hasRelationships {
b.WriteString(" pass\n")
}
return b.String(), nil
}

View File

@ -0,0 +1,34 @@
package generator
import (
"fmt"
"strings"
)
// GenerateManager generates the manager class
func GenerateManager(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString("\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sManager(\n", ctx.EntityName))
b.WriteString(fmt.Sprintf(" CRUDManager[%s, %sFilter, %sRepository]\n",
ctx.EntityName, ctx.EntityName, ctx.EntityName))
b.WriteString("):\n")
b.WriteString(fmt.Sprintf(" repository_cls = %sRepository\n", ctx.EntityName))
return b.String(), nil
}

View File

@ -0,0 +1,48 @@
package generator
import (
"fmt"
"strings"
)
// GenerateMapper generates the mapper snippet (without imports)
func GenerateMapper(ctx *Context) (string, error) {
var b strings.Builder
// Mapper registration (snippet only, no imports)
b.WriteString(" mapper_registry.map_imperatively(\n")
b.WriteString(fmt.Sprintf(" class_=%s,\n", ctx.EntityName))
b.WriteString(fmt.Sprintf(" local_table=%s,\n", ctx.TableConstant))
b.WriteString(" properties={\n")
// Generate relationships for all foreign keys
fkRelationships := make(map[string][]string) // entity -> []column_names
for _, fk := range ctx.TableInfo.ForeignKeys {
relationName := GetRelationshipName(fk.ColumnName)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
// Group by entity name to handle multiple FKs to same table
fkRelationships[entityName] = append(fkRelationships[entityName], fk.ColumnName)
if len(fkRelationships[entityName]) == 1 {
// First FK to this table
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
b.WriteString(fmt.Sprintf(" %s, lazy=relationship_loading_strategy.value\n", entityName))
} else {
// Multiple FKs to same table, need to specify foreign_keys
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
b.WriteString(fmt.Sprintf(" %s,\n", entityName))
b.WriteString(" lazy=relationship_loading_strategy.value,\n")
b.WriteString(fmt.Sprintf(" foreign_keys=%s.columns.%s,\n",
ctx.TableConstant, fk.ColumnName))
}
b.WriteString(" ),\n")
}
b.WriteString(" },\n")
b.WriteString(" )\n")
return b.String(), nil
}

124
internal/generator/model.go Normal file
View File

@ -0,0 +1,124 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/naming"
)
// GenerateModel generates the dataclass model
func GenerateModel(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from dataclasses import dataclass\n")
// Check what we need to import
needsDatetime := NeedsDatetimeImport(ctx.TableInfo.Columns)
needsDecimal := NeedsDecimalImport(ctx.TableInfo.Columns)
if needsDatetime {
b.WriteString("from datetime import datetime\n")
}
if needsDecimal {
b.WriteString("from decimal import Decimal\n")
}
b.WriteString("\n")
b.WriteString("from televend_core.databases.base_model import Base\n")
// Import related models for foreign keys
fkImports := make(map[string]string) // module_name -> entity_name
for _, fk := range ctx.TableInfo.ForeignKeys {
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
fkImports[moduleName] = entityName
}
// Import enum types
if len(ctx.TableInfo.EnumTypes) > 0 {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
ctx.ModuleName))
for _, enumType := range ctx.TableInfo.EnumTypes {
enumName := naming.ToPascalCase(enumType.TypeName)
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
}
b.WriteString(")\n")
}
// Write foreign key imports
for moduleName, entityName := range fkImports {
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
moduleName, entityName))
}
b.WriteString("\n\n")
// Class definition
b.WriteString("@dataclass\n")
b.WriteString(fmt.Sprintf("class %s(Base):\n", ctx.EntityName))
// Get required and optional columns
requiredCols := GetRequiredColumns(ctx.TableInfo.Columns)
optionalCols := GetOptionalColumns(ctx.TableInfo.Columns)
// Required fields (non-nullable, no default, not auto-increment, not PK)
for _, col := range requiredCols {
fieldName := col.Name
pythonType := GetPythonTypeForColumn(col, ctx)
// Regular field
b.WriteString(fmt.Sprintf(" %s: %s\n", fieldName, pythonType))
// Add relationship field if this is a foreign key and column ends with _id
// (to avoid name clashes with FK columns that don't follow _id convention)
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
if strings.HasSuffix(col.Name, "_id") {
relationName := GetRelationshipName(col.Name)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
b.WriteString(fmt.Sprintf(" %s: %s\n", relationName, entityName))
}
}
}
// Empty line between required and optional
if len(requiredCols) > 0 && len(optionalCols) > 0 {
b.WriteString("\n")
}
// Optional fields
for _, col := range optionalCols {
// Skip primary key, we'll add it at the end
if col.IsPrimaryKey {
continue
}
fieldName := col.Name
pythonType := GetPythonTypeForColumn(col, ctx)
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", fieldName, pythonType))
// Add relationship field if this is a foreign key and column ends with _id
// (to avoid name clashes with FK columns that don't follow _id convention)
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
if strings.HasSuffix(col.Name, "_id") {
relationName := GetRelationshipName(col.Name)
entityName := GetRelationshipEntityName(fk.ForeignTableName)
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", relationName, entityName))
}
}
}
// Add primary key at the end
for _, col := range ctx.TableInfo.Columns {
if col.IsPrimaryKey {
b.WriteString("\n")
pythonType := GetPythonTypeForColumn(col, ctx)
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", col.Name, pythonType))
break
}
}
return b.String(), nil
}

View File

@ -0,0 +1,28 @@
package generator
import (
"fmt"
"strings"
)
// GenerateRepository generates the repository class
func GenerateRepository(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
b.WriteString(")\n")
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
ctx.ModuleName, ctx.EntityName))
b.WriteString("\n\n")
// Class definition
b.WriteString(fmt.Sprintf("class %sRepository(CRUDRepository[%s, %sFilter]):\n",
ctx.EntityName, ctx.EntityName, ctx.EntityName))
b.WriteString(fmt.Sprintf(" model_cls = %s\n", ctx.EntityName))
return b.String(), nil
}

160
internal/generator/table.go Normal file
View File

@ -0,0 +1,160 @@
package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/naming"
)
// GenerateTable generates the SQLAlchemy table definition
func GenerateTable(ctx *Context) (string, error) {
var b strings.Builder
// Imports
b.WriteString("from sqlalchemy import (\n")
// Collect unique imports
imports := make(map[string]bool)
imports["Column"] = true
imports["Table"] = true
for _, col := range ctx.TableInfo.Columns {
sqlType := database.GetSQLAlchemyType(col)
// Extract base type name (before parentheses)
typeName := strings.Split(sqlType, "(")[0]
imports[typeName] = true
if col.IsPrimaryKey {
imports["Integer"] = true
}
// Check for foreign keys
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
imports["ForeignKey"] = true
}
// Check for enums
if col.DataType == "USER-DEFINED" {
imports["Enum"] = true
}
}
// Sort and write imports
importList := []string{}
for imp := range imports {
importList = append(importList, imp)
}
// Write imports in a reasonable order
orderedImports := []string{
"BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum",
"Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary",
"Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID",
}
first := true
for _, imp := range orderedImports {
if imports[imp] {
if !first {
b.WriteString(",\n")
} else {
first = false
}
b.WriteString(" " + imp)
}
}
b.WriteString(",\n)\n\n")
// Import enum types if they exist
if len(ctx.TableInfo.EnumTypes) > 0 {
for _, enumType := range ctx.TableInfo.EnumTypes {
enumName := naming.ToPascalCase(enumType.TypeName)
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
ctx.ModuleName))
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
b.WriteString(")\n")
}
}
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
// Table definition
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName))
b.WriteString(" metadata_obj,\n")
// Columns
for _, col := range ctx.TableInfo.Columns {
b.WriteString(generateColumnDefinition(col, ctx))
}
b.WriteString(")\n")
return b.String(), nil
}
func generateColumnDefinition(col database.Column, ctx *Context) string {
var parts []string
// Column name
parts = append(parts, fmt.Sprintf("\"%s\"", col.Name))
// Column type
sqlType := database.GetSQLAlchemyType(col)
// Handle enum types
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
enumName := naming.ToPascalCase(enumType.TypeName)
sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )",
enumName, enumType.TypeName)
}
}
parts = append(parts, sqlType)
// Foreign key
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")",
fk.ForeignTableName, fk.ForeignColumnName)
parts = append(parts, fkDef)
}
// Primary key
if col.IsPrimaryKey {
parts = append(parts, "primary_key=True")
if col.IsAutoIncrement {
parts = append(parts, "autoincrement=True")
}
}
// Nullable
if !col.IsNullable && !col.IsPrimaryKey {
parts = append(parts, "nullable=False")
}
// Unique
// Note: We don't have unique constraint info in our introspection yet
// This would need to be added if needed
// Format the column definition
result := " Column("
// Check if we need multiline formatting (for complex types like Enum)
if strings.Contains(sqlType, "\n") {
// Multiline format
result += parts[0] + ",\n " + parts[1]
for _, part := range parts[2:] {
result += ",\n " + part
}
result += ",\n ),\n"
} else {
// Single line format
result += strings.Join(parts, ", ") + "),\n"
}
return result
}

227
internal/naming/naming.go Normal file
View File

@ -0,0 +1,227 @@
package naming
import (
"strings"
"unicode"
)
// ToSnakeCase converts a string to snake_case
func ToSnakeCase(s string) string {
var result []rune
for i, r := range s {
if unicode.IsUpper(r) {
if i > 0 {
result = append(result, '_')
}
result = append(result, unicode.ToLower(r))
} else {
result = append(result, r)
}
}
return string(result)
}
// ToPascalCase converts snake_case to PascalCase
func ToPascalCase(s string) string {
parts := strings.Split(s, "_")
for i, part := range parts {
if len(part) > 0 {
parts[i] = strings.ToUpper(part[:1]) + part[1:]
}
}
return strings.Join(parts, "")
}
// Singularize converts a plural word to singular
// Uses common English pluralization rules
func Singularize(word string) string {
// Handle empty string
if word == "" {
return word
}
// Common irregular plurals
irregulars := map[string]string{
"people": "person",
"men": "man",
"women": "woman",
"children": "child",
"teeth": "tooth",
"feet": "foot",
"mice": "mouse",
"geese": "goose",
"oxen": "ox",
"sheep": "sheep",
"fish": "fish",
"deer": "deer",
"series": "series",
"species": "species",
"quizzes": "quiz",
"analyses": "analysis",
"diagnoses": "diagnosis",
"oases": "oasis",
"theses": "thesis",
"crises": "crisis",
"phenomena": "phenomenon",
"criteria": "criterion",
"data": "datum",
}
lower := strings.ToLower(word)
if singular, ok := irregulars[lower]; ok {
return preserveCase(word, singular)
}
// Handle words ending in 'ies' -> 'y'
if strings.HasSuffix(lower, "ies") && len(word) > 3 {
return word[:len(word)-3] + "y"
}
// Handle words ending in 'ves' -> 'fe' or 'f'
if strings.HasSuffix(lower, "ves") && len(word) > 3 {
base := word[:len(word)-3]
// Common words that end in 'fe'
if strings.HasSuffix(strings.ToLower(base), "li") ||
strings.HasSuffix(strings.ToLower(base), "wi") ||
strings.HasSuffix(strings.ToLower(base), "kni") {
return base + "fe"
}
return base + "f"
}
// Handle words ending in 'xes', 'ses', 'shes', 'ches' -> remove 'es'
if strings.HasSuffix(lower, "xes") ||
strings.HasSuffix(lower, "ses") ||
strings.HasSuffix(lower, "shes") ||
strings.HasSuffix(lower, "ches") {
if len(word) > 2 {
return word[:len(word)-2]
}
}
// Handle words ending in 'oes' -> 'o'
if strings.HasSuffix(lower, "oes") && len(word) > 3 {
return word[:len(word)-2]
}
// Handle simple 's' suffix
if strings.HasSuffix(lower, "s") && len(word) > 1 {
// Don't remove 's' from words that naturally end in 's'
if !strings.HasSuffix(lower, "ss") &&
!strings.HasSuffix(lower, "us") &&
!strings.HasSuffix(lower, "is") {
return word[:len(word)-1]
}
}
return word
}
// Pluralize converts a singular word to plural
func Pluralize(word string) string {
if word == "" {
return word
}
// Common irregular plurals
irregulars := map[string]string{
"person": "people",
"man": "men",
"woman": "women",
"child": "children",
"tooth": "teeth",
"foot": "feet",
"mouse": "mice",
"goose": "geese",
"ox": "oxen",
"sheep": "sheep",
"fish": "fish",
"deer": "deer",
"series": "series",
"species": "species",
"quiz": "quizzes",
"analysis": "analyses",
"diagnosis": "diagnoses",
"oasis": "oases",
"thesis": "theses",
"crisis": "crises",
"phenomenon": "phenomena",
"criterion": "criteria",
"datum": "data",
}
lower := strings.ToLower(word)
if plural, ok := irregulars[lower]; ok {
return preserveCase(word, plural)
}
// Already plural (ends in 's' and not special case)
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
return word
}
// Handle words ending in 'y' -> 'ies'
if strings.HasSuffix(lower, "y") && len(word) > 1 {
prevChar := rune(lower[len(lower)-2])
if !isVowel(prevChar) {
return word[:len(word)-1] + "ies"
}
}
// Handle words ending in 'f' or 'fe' -> 'ves'
if strings.HasSuffix(lower, "f") {
return word[:len(word)-1] + "ves"
}
if strings.HasSuffix(lower, "fe") {
return word[:len(word)-2] + "ves"
}
// Handle words ending in 'o' -> 'oes'
if strings.HasSuffix(lower, "o") && len(word) > 1 {
prevChar := rune(lower[len(lower)-2])
if !isVowel(prevChar) {
return word + "es"
}
}
// Handle words ending in 'x', 's', 'sh', 'ch' -> add 'es'
if strings.HasSuffix(lower, "x") ||
strings.HasSuffix(lower, "s") ||
strings.HasSuffix(lower, "sh") ||
strings.HasSuffix(lower, "ch") {
return word + "es"
}
// Default: just add 's'
return word + "s"
}
// preserveCase preserves the case pattern of the original word
func preserveCase(original, replacement string) string {
if len(original) == 0 {
return replacement
}
if unicode.IsUpper(rune(original[0])) {
return strings.ToUpper(replacement[:1]) + replacement[1:]
}
return replacement
}
// isVowel checks if a rune is a vowel
func isVowel(r rune) bool {
vowels := "aeiouAEIOU"
return strings.ContainsRune(vowels, r)
}
// SingularizeTableName converts a table name to its singular form
// Handles snake_case table names
func SingularizeTableName(tableName string) string {
parts := strings.Split(tableName, "_")
if len(parts) > 0 {
// Only singularize the last part
lastIdx := len(parts) - 1
parts[lastIdx] = Singularize(parts[lastIdx])
}
return strings.Join(parts, "_")
}

186
internal/prompt/prompt.go Normal file
View File

@ -0,0 +1,186 @@
package prompt
import (
"bufio"
"fmt"
"os"
"strconv"
"strings"
"github.com/fatih/color"
)
var (
cyan = color.New(color.FgCyan).SprintFunc()
green = color.New(color.FgGreen).SprintFunc()
yellow = color.New(color.FgYellow).SprintFunc()
red = color.New(color.FgRed).SprintFunc()
)
// Reader interface for testing
type Reader interface {
ReadString(delim byte) (string, error)
}
// PromptString prompts for a string value
func PromptString(label string, defaultValue string, required bool) (string, error) {
reader := bufio.NewReader(os.Stdin)
return promptStringWithReader(reader, label, defaultValue, required)
}
func promptStringWithReader(reader Reader, label string, defaultValue string, required bool) (string, error) {
for {
prompt := fmt.Sprintf("%s", cyan(label))
if defaultValue != "" {
prompt += fmt.Sprintf(" [%s]", green(defaultValue))
}
prompt += ": "
fmt.Print(prompt)
input, err := reader.ReadString('\n')
if err != nil {
return "", err
}
input = strings.TrimSpace(input)
// If input is empty, use default value
if input == "" {
if defaultValue != "" {
// Echo the selected default value
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultValue), defaultValue)
fmt.Printf("\033[1A\r%s\n", fullPrompt)
return defaultValue, nil
}
if required {
fmt.Println(red("✗ This field is required"))
continue
}
return "", nil
}
return input, nil
}
}
// PromptInt prompts for an integer value
func PromptInt(label string, defaultValue int, required bool) (int, error) {
reader := bufio.NewReader(os.Stdin)
return promptIntWithReader(reader, label, defaultValue, required)
}
func promptIntWithReader(reader Reader, label string, defaultValue int, required bool) (int, error) {
for {
defaultStr := ""
if defaultValue != 0 {
defaultStr = strconv.Itoa(defaultValue)
}
prompt := fmt.Sprintf("%s", cyan(label))
if defaultStr != "" {
prompt += fmt.Sprintf(" [%s]", green(defaultStr))
}
prompt += ": "
fmt.Print(prompt)
input, err := reader.ReadString('\n')
if err != nil {
return 0, err
}
input = strings.TrimSpace(input)
// If input is empty, use default value
if input == "" {
if defaultValue != 0 {
// Echo the selected default value
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultStr), defaultStr)
fmt.Printf("\033[1A\r%s\n", fullPrompt)
return defaultValue, nil
}
if required {
fmt.Println(red("✗ This field is required"))
continue
}
return 0, nil
}
// Parse integer
value, err := strconv.Atoi(input)
if err != nil {
fmt.Println(red("✗ Please enter a valid number"))
continue
}
return value, nil
}
}
// ValidateDirectory checks if a directory exists or can be created
func ValidateDirectory(path string) error {
// Check if path exists
info, err := os.Stat(path)
if err == nil {
// Path exists, check if it's a directory
if !info.IsDir() {
return fmt.Errorf("path exists but is not a directory")
}
return nil
}
// Path doesn't exist, try to create it
if os.IsNotExist(err) {
if err := os.MkdirAll(path, 0755); err != nil {
return fmt.Errorf("cannot create directory: %w", err)
}
return nil
}
return err
}
// PromptDirectory prompts for a directory path and validates it
func PromptDirectory(label string, defaultValue string, required bool) (string, error) {
for {
path, err := PromptString(label, defaultValue, required)
if err != nil {
return "", err
}
if path == "" && !required {
return "", nil
}
// Validate directory
if err := ValidateDirectory(path); err != nil {
fmt.Printf("%s %s\n", red("✗"), err.Error())
continue
}
return path, nil
}
}
// PrintHeader prints a colored header
func PrintHeader(text string) {
header := color.New(color.FgCyan, color.Bold)
header.Println("\n" + text)
header.Println(strings.Repeat("=", len(text)))
}
// PrintSuccess prints a success message
func PrintSuccess(text string) {
fmt.Printf("%s %s\n", green("✓"), text)
}
// PrintError prints an error message
func PrintError(text string) {
fmt.Printf("%s %s\n", red("✗"), text)
}
// PrintInfo prints an info message
func PrintInfo(text string) {
fmt.Printf("%s %s\n", yellow(""), text)
}