First working version
This commit is contained in:
82
internal/config/config.go
Normal file
82
internal/config/config.go
Normal file
@ -0,0 +1,82 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
)
|
||||
|
||||
// Config represents the application configuration
|
||||
type Config struct {
|
||||
DBHost string `toml:"db_host"`
|
||||
DBPort int `toml:"db_port"`
|
||||
DBName string `toml:"db_name"`
|
||||
DBSchema string `toml:"db_schema"`
|
||||
DBUser string `toml:"db_user"`
|
||||
DBPassword string `toml:"db_password"`
|
||||
DBTable string `toml:"db_table"`
|
||||
OutputDir string `toml:"output_dir"`
|
||||
EntityName string `toml:"entity_name"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a new config with default values
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
DBHost: "localhost",
|
||||
DBPort: 5432,
|
||||
DBSchema: "public",
|
||||
DBUser: "postgres",
|
||||
DBPassword: "postgres",
|
||||
}
|
||||
}
|
||||
|
||||
// Load reads the configuration file from ~/.config/entity-maker.toml
|
||||
func Load() (*Config, error) {
|
||||
configPath := getConfigPath()
|
||||
|
||||
config := DefaultConfig()
|
||||
|
||||
// Check if config file exists
|
||||
if _, err := os.Stat(configPath); os.IsNotExist(err) {
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Read and decode the config file
|
||||
if _, err := toml.DecodeFile(configPath, config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
// Save writes the configuration to ~/.config/entity-maker.toml
|
||||
func (c *Config) Save() error {
|
||||
configPath := getConfigPath()
|
||||
|
||||
// Create config directory if it doesn't exist
|
||||
configDir := filepath.Dir(configPath)
|
||||
if err := os.MkdirAll(configDir, 0755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create or truncate the config file
|
||||
f, err := os.Create(configPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
// Encode and write the config
|
||||
encoder := toml.NewEncoder(f)
|
||||
return encoder.Encode(c)
|
||||
}
|
||||
|
||||
// getConfigPath returns the full path to the config file
|
||||
func getConfigPath() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
homeDir = "."
|
||||
}
|
||||
return filepath.Join(homeDir, ".config", "entity-maker.toml")
|
||||
}
|
||||
67
internal/database/client.go
Normal file
67
internal/database/client.go
Normal file
@ -0,0 +1,67 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// Client represents a database connection
|
||||
type Client struct {
|
||||
db *sql.DB
|
||||
schema string
|
||||
}
|
||||
|
||||
// Config holds database connection parameters
|
||||
type Config struct {
|
||||
Host string
|
||||
Port int
|
||||
Database string
|
||||
Schema string
|
||||
User string
|
||||
Password string
|
||||
}
|
||||
|
||||
// NewClient creates a new database client
|
||||
func NewClient(cfg Config) (*Client, error) {
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
|
||||
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database,
|
||||
)
|
||||
|
||||
db, err := sql.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
return nil, fmt.Errorf("failed to connect to database: %w", err)
|
||||
}
|
||||
|
||||
return &Client{
|
||||
db: db,
|
||||
schema: cfg.Schema,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the database connection
|
||||
func (c *Client) Close() error {
|
||||
return c.db.Close()
|
||||
}
|
||||
|
||||
// TableExists checks if a table exists in the schema
|
||||
func (c *Client) TableExists(tableName string) (bool, error) {
|
||||
query := `
|
||||
SELECT EXISTS (
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
)
|
||||
`
|
||||
|
||||
var exists bool
|
||||
err := c.db.QueryRow(query, c.schema, tableName).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
341
internal/database/introspector.go
Normal file
341
internal/database/introspector.go
Normal file
@ -0,0 +1,341 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Column represents a database column
|
||||
type Column struct {
|
||||
Name string
|
||||
DataType string
|
||||
IsNullable bool
|
||||
ColumnDefault sql.NullString
|
||||
CharMaxLength sql.NullInt64
|
||||
NumericPrecision sql.NullInt64
|
||||
NumericScale sql.NullInt64
|
||||
UdtName string // User-defined type name (for enums)
|
||||
IsPrimaryKey bool
|
||||
IsAutoIncrement bool
|
||||
}
|
||||
|
||||
// ForeignKey represents a foreign key constraint
|
||||
type ForeignKey struct {
|
||||
ColumnName string
|
||||
ForeignTableSchema string
|
||||
ForeignTableName string
|
||||
ForeignColumnName string
|
||||
ConstraintName string
|
||||
}
|
||||
|
||||
// EnumType represents a PostgreSQL enum type
|
||||
type EnumType struct {
|
||||
TypeName string
|
||||
Values []string
|
||||
}
|
||||
|
||||
// TableInfo contains all information about a table
|
||||
type TableInfo struct {
|
||||
Schema string
|
||||
TableName string
|
||||
Columns []Column
|
||||
ForeignKeys []ForeignKey
|
||||
EnumTypes map[string]EnumType
|
||||
}
|
||||
|
||||
// IntrospectTable retrieves comprehensive information about a table
|
||||
func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) {
|
||||
info := &TableInfo{
|
||||
Schema: c.schema,
|
||||
TableName: tableName,
|
||||
EnumTypes: make(map[string]EnumType),
|
||||
}
|
||||
|
||||
// Get columns
|
||||
columns, err := c.getColumns(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Columns = columns
|
||||
|
||||
// Get primary keys and mark columns
|
||||
primaryKeys, err := c.getPrimaryKeys(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range info.Columns {
|
||||
for _, pk := range primaryKeys {
|
||||
if info.Columns[i].Name == pk {
|
||||
info.Columns[i].IsPrimaryKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get foreign keys
|
||||
foreignKeys, err := c.getForeignKeys(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.ForeignKeys = foreignKeys
|
||||
|
||||
// Get enum types for columns that use them
|
||||
for _, col := range info.Columns {
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if _, exists := info.EnumTypes[col.UdtName]; !exists {
|
||||
enumType, err := c.getEnumType(col.UdtName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.EnumTypes[col.UdtName] = enumType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// getColumns retrieves column information for a table
|
||||
func (c *Client) getColumns(tableName string) ([]Column, error) {
|
||||
query := `
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable = 'YES' as is_nullable,
|
||||
column_default,
|
||||
character_maximum_length,
|
||||
numeric_precision,
|
||||
numeric_scale,
|
||||
udt_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
ORDER BY ordinal_position
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []Column
|
||||
for rows.Next() {
|
||||
var col Column
|
||||
err := rows.Scan(
|
||||
&col.Name,
|
||||
&col.DataType,
|
||||
&col.IsNullable,
|
||||
&col.ColumnDefault,
|
||||
&col.CharMaxLength,
|
||||
&col.NumericPrecision,
|
||||
&col.NumericScale,
|
||||
&col.UdtName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if column is auto-increment
|
||||
if col.ColumnDefault.Valid {
|
||||
defaultVal := strings.ToLower(col.ColumnDefault.String)
|
||||
col.IsAutoIncrement = strings.Contains(defaultVal, "nextval")
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
return columns, rows.Err()
|
||||
}
|
||||
|
||||
// getPrimaryKeys retrieves primary key column names for a table
|
||||
func (c *Client) getPrimaryKeys(tableName string) ([]string, error) {
|
||||
query := `
|
||||
SELECT a.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||
WHERE i.indrelid = ($1 || '.' || $2)::regclass
|
||||
AND i.indisprimary
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
if err := rows.Scan(&key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, rows.Err()
|
||||
}
|
||||
|
||||
// getForeignKeys retrieves foreign key information for a table
|
||||
func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) {
|
||||
query := `
|
||||
SELECT
|
||||
kcu.column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name,
|
||||
tc.constraint_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema = $1
|
||||
AND tc.table_name = $2
|
||||
ORDER BY kcu.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var fks []ForeignKey
|
||||
for rows.Next() {
|
||||
var fk ForeignKey
|
||||
err := rows.Scan(
|
||||
&fk.ColumnName,
|
||||
&fk.ForeignTableSchema,
|
||||
&fk.ForeignTableName,
|
||||
&fk.ForeignColumnName,
|
||||
&fk.ConstraintName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
|
||||
return fks, rows.Err()
|
||||
}
|
||||
|
||||
// getEnumType retrieves enum values for a PostgreSQL enum type
|
||||
func (c *Client) getEnumType(typeName string) (EnumType, error) {
|
||||
query := `
|
||||
SELECT e.enumlabel
|
||||
FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = $1
|
||||
ORDER BY e.enumsortorder
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, typeName)
|
||||
if err != nil {
|
||||
return EnumType{}, fmt.Errorf("failed to query enum type: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var values []string
|
||||
for rows.Next() {
|
||||
var value string
|
||||
if err := rows.Scan(&value); err != nil {
|
||||
return EnumType{}, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return EnumType{
|
||||
TypeName: typeName,
|
||||
Values: values,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// GetPythonType converts PostgreSQL data type to Python type
|
||||
func GetPythonType(col Column) string {
|
||||
switch col.DataType {
|
||||
case "integer", "smallint", "bigint":
|
||||
return "int"
|
||||
case "numeric", "decimal", "real", "double precision":
|
||||
return "Decimal"
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "character varying", "varchar", "text", "char", "character":
|
||||
return "str"
|
||||
case "timestamp with time zone", "timestamp without time zone", "timestamp":
|
||||
return "datetime"
|
||||
case "date":
|
||||
return "date"
|
||||
case "time with time zone", "time without time zone", "time":
|
||||
return "time"
|
||||
case "json", "jsonb":
|
||||
return "dict"
|
||||
case "uuid":
|
||||
return "UUID"
|
||||
case "bytea":
|
||||
return "bytes"
|
||||
case "USER-DEFINED":
|
||||
// This is likely an enum
|
||||
return "str" // Will be replaced with specific enum type in generator
|
||||
default:
|
||||
return "Any"
|
||||
}
|
||||
}
|
||||
|
||||
// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type
|
||||
func GetSQLAlchemyType(col Column) string {
|
||||
switch col.DataType {
|
||||
case "integer":
|
||||
return "Integer"
|
||||
case "smallint":
|
||||
return "SmallInteger"
|
||||
case "bigint":
|
||||
return "BigInteger"
|
||||
case "numeric", "decimal":
|
||||
if col.NumericPrecision.Valid && col.NumericScale.Valid {
|
||||
return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64)
|
||||
}
|
||||
return "Numeric"
|
||||
case "real":
|
||||
return "Float"
|
||||
case "double precision":
|
||||
return "Float"
|
||||
case "boolean":
|
||||
return "Boolean"
|
||||
case "character varying", "varchar":
|
||||
if col.CharMaxLength.Valid {
|
||||
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||
}
|
||||
return "String"
|
||||
case "char", "character":
|
||||
if col.CharMaxLength.Valid {
|
||||
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||
}
|
||||
return "String(1)"
|
||||
case "text":
|
||||
return "Text"
|
||||
case "timestamp with time zone":
|
||||
return "DateTime(timezone=True)"
|
||||
case "timestamp without time zone", "timestamp":
|
||||
return "DateTime"
|
||||
case "date":
|
||||
return "Date"
|
||||
case "time with time zone", "time without time zone", "time":
|
||||
return "Time"
|
||||
case "json":
|
||||
return "JSON"
|
||||
case "jsonb":
|
||||
return "JSONB"
|
||||
case "uuid":
|
||||
return "UUID"
|
||||
case "bytea":
|
||||
return "LargeBinary"
|
||||
case "USER-DEFINED":
|
||||
// Will be handled separately for enums
|
||||
return "Enum"
|
||||
default:
|
||||
return "String"
|
||||
}
|
||||
}
|
||||
40
internal/generator/enum.go
Normal file
40
internal/generator/enum.go
Normal file
@ -0,0 +1,40 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateEnum generates the enum types file
|
||||
func GenerateEnum(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from enum import StrEnum\n\n")
|
||||
b.WriteString("from televend_core.databases.enum import EnumMixin\n\n\n")
|
||||
|
||||
// Generate each enum type
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
|
||||
b.WriteString(fmt.Sprintf("class %s(EnumMixin, StrEnum):\n", enumName))
|
||||
|
||||
if len(enumType.Values) == 0 {
|
||||
b.WriteString(" pass\n")
|
||||
} else {
|
||||
for _, value := range enumType.Values {
|
||||
// Convert value to valid Python identifier
|
||||
// Usually enum values are already uppercase like "OPEN", "IN_PROGRESS"
|
||||
identifier := strings.ToUpper(strings.ReplaceAll(value, " ", "_"))
|
||||
identifier = strings.ReplaceAll(identifier, "-", "_")
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s = \"%s\"\n", identifier, value))
|
||||
}
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
203
internal/generator/factory.go
Normal file
203
internal/generator/factory.go
Normal file
@ -0,0 +1,203 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateFactory generates the factory class
|
||||
func GenerateFactory(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from __future__ import annotations\n\n")
|
||||
b.WriteString("from typing import Type\n\n")
|
||||
b.WriteString("import factory\n\n")
|
||||
|
||||
// Import factories for related models
|
||||
fkImports := make(map[string]string) // module_name -> entity_name
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
fkImports[moduleName] = entityName
|
||||
}
|
||||
|
||||
for moduleName, entityName := range fkImports {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
|
||||
moduleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
||||
b.WriteString(" CustomSelfAttribute,\n")
|
||||
b.WriteString(" TelevendBaseFactory,\n")
|
||||
b.WriteString(")\n\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
|
||||
|
||||
// Add boolean fields with defaults
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.DataType == "boolean" {
|
||||
defaultValue := "True"
|
||||
if col.Name == "alive" {
|
||||
defaultValue = "True"
|
||||
} else {
|
||||
defaultValue = "False"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue))
|
||||
}
|
||||
}
|
||||
|
||||
// Add id field
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
b.WriteString(fmt.Sprintf(" %s = None\n", col.Name))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
|
||||
// Generate faker fields for each column
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey || col.DataType == "boolean" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip foreign keys, we'll handle them separately
|
||||
if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fakerDef := generateFakerField(col, ctx)
|
||||
if fakerDef != "" {
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef))
|
||||
}
|
||||
}
|
||||
|
||||
// Generate foreign key relationships
|
||||
if len(ctx.TableInfo.ForeignKeys) > 0 {
|
||||
b.WriteString("\n")
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n",
|
||||
relationName, relationName, entityName))
|
||||
b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n",
|
||||
fk.ColumnName, relationName, relationName))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Meta class
|
||||
b.WriteString(" class Meta:\n")
|
||||
b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName))
|
||||
b.WriteString("\n")
|
||||
|
||||
// create_minimal method
|
||||
b.WriteString(" @classmethod\n")
|
||||
b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n",
|
||||
ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString(" minimal_params = {\n")
|
||||
|
||||
// Add foreign key params
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
// Check if this FK is required
|
||||
var col *database.Column
|
||||
for i := range ctx.TableInfo.Columns {
|
||||
if ctx.TableInfo.Columns[i].Name == fk.ColumnName {
|
||||
col = &ctx.TableInfo.Columns[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if col != nil && !col.IsNullable {
|
||||
// Required FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n",
|
||||
relationName, relationName, entityName))
|
||||
} else {
|
||||
// Optional FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" minimal_params.update(kwargs)\n")
|
||||
b.WriteString(" return cls.create(**minimal_params)\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func generateFakerField(col database.Column, ctx *Context) string {
|
||||
pythonType := database.GetPythonType(col)
|
||||
|
||||
switch pythonType {
|
||||
case "Decimal":
|
||||
precision := 7
|
||||
scale := 4
|
||||
if col.NumericPrecision.Valid {
|
||||
precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64)
|
||||
}
|
||||
if col.NumericScale.Valid {
|
||||
scale = int(col.NumericScale.Int64)
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)",
|
||||
precision, scale)
|
||||
|
||||
case "int":
|
||||
if col.DataType == "bigint" {
|
||||
return "factory.Faker(\"pyint\")"
|
||||
}
|
||||
return "factory.Faker(\"pyint\")"
|
||||
|
||||
case "str":
|
||||
// Check if it's an enum
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
if len(enumType.Values) > 0 {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxLen := 255
|
||||
if col.CharMaxLength.Valid {
|
||||
maxLen = int(col.CharMaxLength.Int64)
|
||||
}
|
||||
if col.DataType == "text" {
|
||||
return "factory.Faker(\"text\", max_nb_chars=500)"
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen)
|
||||
|
||||
case "datetime":
|
||||
return "factory.Faker(\"date_time\")"
|
||||
|
||||
case "date":
|
||||
return "factory.Faker(\"date\")"
|
||||
|
||||
case "time":
|
||||
return "factory.Faker(\"time\")"
|
||||
|
||||
case "bool":
|
||||
return "factory.Faker(\"boolean\")"
|
||||
|
||||
case "dict":
|
||||
return "factory.Faker(\"pydict\")"
|
||||
|
||||
default:
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
77
internal/generator/filter.go
Normal file
77
internal/generator/filter.go
Normal file
@ -0,0 +1,77 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
// GenerateFilter generates the filter class
|
||||
func GenerateFilter(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
|
||||
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sFilter(BaseFilter):\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||
|
||||
// Generate filters based on rules:
|
||||
// - Boolean fields: EQ operator
|
||||
// - ID fields: both EQ and IN operators
|
||||
// - Text fields: EQ operator
|
||||
|
||||
hasFilters := false
|
||||
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if !ShouldGenerateFilter(col) {
|
||||
continue
|
||||
}
|
||||
|
||||
hasFilters = true
|
||||
|
||||
// Boolean fields
|
||||
if col.DataType == "boolean" {
|
||||
defaultVal := "None"
|
||||
if col.Name == "alive" {
|
||||
defaultVal = "True"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s: bool | None = filterfield(operator=EQ, default=%s)\n",
|
||||
col.Name, defaultVal))
|
||||
}
|
||||
|
||||
// ID fields (both EQ and IN)
|
||||
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||
// Single ID with EQ
|
||||
pythonType := database.GetPythonType(col)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = filterfield(operator=EQ)\n",
|
||||
col.Name, pythonType))
|
||||
|
||||
// Multiple IDs with IN (plural field name)
|
||||
pluralFieldName := GetFilterFieldName(col.Name, true)
|
||||
b.WriteString(fmt.Sprintf(" %s: list[%s] | None = filterfield(field=\"%s\", operator=IN)\n",
|
||||
pluralFieldName, pythonType, col.Name))
|
||||
}
|
||||
|
||||
// Text fields
|
||||
if (col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
|
||||
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
|
||||
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
|
||||
col.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// If no filters were generated, add a pass statement
|
||||
if !hasFilters {
|
||||
b.WriteString(" pass\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
195
internal/generator/generator.go
Normal file
195
internal/generator/generator.go
Normal file
@ -0,0 +1,195 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// Context contains all information needed for code generation
|
||||
type Context struct {
|
||||
TableInfo *database.TableInfo
|
||||
EntityName string // Singular, PascalCase (e.g., "CashbagConform")
|
||||
ModuleName string // Singular, snake_case (e.g., "cashbag_conform")
|
||||
TableConstant string // Uppercase with TABLE suffix (e.g., "CASHBAG_CONFORM_TABLE")
|
||||
}
|
||||
|
||||
// NewContext creates a new generation context
|
||||
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
|
||||
moduleName := naming.SingularizeTableName(tableInfo.TableName)
|
||||
|
||||
entityName := entityNameOverride
|
||||
if entityName == "" {
|
||||
entityName = naming.ToPascalCase(moduleName)
|
||||
}
|
||||
|
||||
tableConstant := strings.ToUpper(moduleName) + "_TABLE"
|
||||
|
||||
return &Context{
|
||||
TableInfo: tableInfo,
|
||||
EntityName: entityName,
|
||||
ModuleName: moduleName,
|
||||
TableConstant: tableConstant,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRelationshipName returns the relationship name for a foreign key
|
||||
// Strips _id suffix and converts to snake_case
|
||||
func GetRelationshipName(fkColumnName string) string {
|
||||
name := fkColumnName
|
||||
if strings.HasSuffix(name, "_id") {
|
||||
name = name[:len(name)-3]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// GetRelationshipEntityName returns the entity name for a foreign key's target table
|
||||
func GetRelationshipEntityName(tableName string) string {
|
||||
singular := naming.SingularizeTableName(tableName)
|
||||
return naming.ToPascalCase(singular)
|
||||
}
|
||||
|
||||
// GetRelationshipModuleName returns the module name for a foreign key's target table
|
||||
func GetRelationshipModuleName(tableName string) string {
|
||||
return naming.SingularizeTableName(tableName)
|
||||
}
|
||||
|
||||
// GetFilterFieldName returns the filter field name for a column
|
||||
// For ID fields with IN operator, pluralizes the name
|
||||
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||
// Remove _id, pluralize, add back _ids
|
||||
base := columnName[:len(columnName)-3]
|
||||
if base == "" {
|
||||
return "ids"
|
||||
}
|
||||
return naming.Pluralize(base) + "_ids"
|
||||
}
|
||||
if useIN && columnName == "id" {
|
||||
return "ids"
|
||||
}
|
||||
return columnName
|
||||
}
|
||||
|
||||
// ShouldGenerateFilter determines if a column should have a filter
|
||||
func ShouldGenerateFilter(col database.Column) bool {
|
||||
// Generate filters for boolean, ID, and text fields
|
||||
if col.DataType == "boolean" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||
return true
|
||||
}
|
||||
if col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPythonTypeForColumn returns the Python type annotation for a column
|
||||
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
|
||||
baseType := database.GetPythonType(col)
|
||||
|
||||
// Handle enum types
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
baseType = naming.ToPascalCase(enumType.TypeName)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Decimal type
|
||||
if baseType == "Decimal" {
|
||||
baseType = "Decimal"
|
||||
}
|
||||
|
||||
return baseType
|
||||
}
|
||||
|
||||
// NeedsDecimalImport checks if any column uses Decimal type
|
||||
func NeedsDecimalImport(columns []database.Column) bool {
|
||||
for _, col := range columns {
|
||||
if database.GetPythonType(col) == "Decimal" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NeedsDatetimeImport checks if any column uses datetime type
|
||||
func NeedsDatetimeImport(columns []database.Column) bool {
|
||||
for _, col := range columns {
|
||||
pyType := database.GetPythonType(col)
|
||||
if pyType == "datetime" || pyType == "date" || pyType == "time" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRequiredColumns returns columns that are not nullable and don't have defaults
|
||||
func GetRequiredColumns(columns []database.Column) []database.Column {
|
||||
var required []database.Column
|
||||
for _, col := range columns {
|
||||
if !col.IsNullable && !col.ColumnDefault.Valid && !col.IsAutoIncrement && !col.IsPrimaryKey {
|
||||
required = append(required, col)
|
||||
}
|
||||
}
|
||||
return required
|
||||
}
|
||||
|
||||
// GetOptionalColumns returns columns that are nullable or have defaults
|
||||
func GetOptionalColumns(columns []database.Column) []database.Column {
|
||||
var optional []database.Column
|
||||
for _, col := range columns {
|
||||
if col.IsNullable || col.ColumnDefault.Valid {
|
||||
optional = append(optional, col)
|
||||
}
|
||||
}
|
||||
return optional
|
||||
}
|
||||
|
||||
// GetForeignKeyForColumn returns the foreign key info for a column, if it exists
|
||||
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
|
||||
for _, fk := range foreignKeys {
|
||||
if fk.ColumnName == columnName {
|
||||
return &fk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateFiles generates all Python files for the entity
|
||||
func GenerateFiles(ctx *Context, outputDir string) error {
|
||||
generators := map[string]func(*Context) (string, error){
|
||||
"table.py": GenerateTable,
|
||||
"model.py": GenerateModel,
|
||||
"filter.py": GenerateFilter,
|
||||
"load_options.py": GenerateLoadOptions,
|
||||
"repository.py": GenerateRepository,
|
||||
"manager.py": GenerateManager,
|
||||
"factory.py": GenerateFactory,
|
||||
"mapper.py": GenerateMapper,
|
||||
"__init__.py": GenerateInit,
|
||||
}
|
||||
|
||||
// Generate enum.py if there are enum types
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
generators["enum.py"] = GenerateEnum
|
||||
}
|
||||
|
||||
for filename, generator := range generators {
|
||||
content, err := generator(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Write to file (this will be handled by the main function)
|
||||
// For now, just return the content
|
||||
_ = content
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
internal/generator/init.go
Normal file
7
internal/generator/init.go
Normal file
@ -0,0 +1,7 @@
|
||||
package generator
|
||||
|
||||
// GenerateInit generates an empty __init__.py file
|
||||
func GenerateInit(ctx *Context) (string, error) {
|
||||
// Empty __init__.py file
|
||||
return "", nil
|
||||
}
|
||||
38
internal/generator/load_options.go
Normal file
38
internal/generator/load_options.go
Normal file
@ -0,0 +1,38 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateLoadOptions generates the load options class
|
||||
func GenerateLoadOptions(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
|
||||
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sLoadOptions(LoadOptions):\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||
|
||||
// Generate load options for all foreign key relationships
|
||||
hasRelationships := false
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
hasRelationships = true
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
b.WriteString(fmt.Sprintf(" load_%s: bool = joinload(relations=[\"%s\"])\n",
|
||||
relationName, relationName))
|
||||
}
|
||||
|
||||
// If no relationships, add pass
|
||||
if !hasRelationships {
|
||||
b.WriteString(" pass\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
34
internal/generator/manager.go
Normal file
34
internal/generator/manager.go
Normal file
@ -0,0 +1,34 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateManager generates the manager class
|
||||
func GenerateManager(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sManager(\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" CRUDManager[%s, %sFilter, %sRepository]\n",
|
||||
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString("):\n")
|
||||
b.WriteString(fmt.Sprintf(" repository_cls = %sRepository\n", ctx.EntityName))
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
48
internal/generator/mapper.go
Normal file
48
internal/generator/mapper.go
Normal file
@ -0,0 +1,48 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateMapper generates the mapper snippet (without imports)
|
||||
func GenerateMapper(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Mapper registration (snippet only, no imports)
|
||||
b.WriteString(" mapper_registry.map_imperatively(\n")
|
||||
b.WriteString(fmt.Sprintf(" class_=%s,\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" local_table=%s,\n", ctx.TableConstant))
|
||||
b.WriteString(" properties={\n")
|
||||
|
||||
// Generate relationships for all foreign keys
|
||||
fkRelationships := make(map[string][]string) // entity -> []column_names
|
||||
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
// Group by entity name to handle multiple FKs to same table
|
||||
fkRelationships[entityName] = append(fkRelationships[entityName], fk.ColumnName)
|
||||
|
||||
if len(fkRelationships[entityName]) == 1 {
|
||||
// First FK to this table
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||
b.WriteString(fmt.Sprintf(" %s, lazy=relationship_loading_strategy.value\n", entityName))
|
||||
} else {
|
||||
// Multiple FKs to same table, need to specify foreign_keys
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", entityName))
|
||||
b.WriteString(" lazy=relationship_loading_strategy.value,\n")
|
||||
b.WriteString(fmt.Sprintf(" foreign_keys=%s.columns.%s,\n",
|
||||
ctx.TableConstant, fk.ColumnName))
|
||||
}
|
||||
|
||||
b.WriteString(" ),\n")
|
||||
}
|
||||
|
||||
b.WriteString(" },\n")
|
||||
b.WriteString(" )\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
124
internal/generator/model.go
Normal file
124
internal/generator/model.go
Normal file
@ -0,0 +1,124 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateModel generates the dataclass model
|
||||
func GenerateModel(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from dataclasses import dataclass\n")
|
||||
|
||||
// Check what we need to import
|
||||
needsDatetime := NeedsDatetimeImport(ctx.TableInfo.Columns)
|
||||
needsDecimal := NeedsDecimalImport(ctx.TableInfo.Columns)
|
||||
|
||||
if needsDatetime {
|
||||
b.WriteString("from datetime import datetime\n")
|
||||
}
|
||||
if needsDecimal {
|
||||
b.WriteString("from decimal import Decimal\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString("from televend_core.databases.base_model import Base\n")
|
||||
|
||||
// Import related models for foreign keys
|
||||
fkImports := make(map[string]string) // module_name -> entity_name
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
fkImports[moduleName] = entityName
|
||||
}
|
||||
|
||||
// Import enum types
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||
ctx.ModuleName))
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||
}
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
|
||||
// Write foreign key imports
|
||||
for moduleName, entityName := range fkImports {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
moduleName, entityName))
|
||||
}
|
||||
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString("@dataclass\n")
|
||||
b.WriteString(fmt.Sprintf("class %s(Base):\n", ctx.EntityName))
|
||||
|
||||
// Get required and optional columns
|
||||
requiredCols := GetRequiredColumns(ctx.TableInfo.Columns)
|
||||
optionalCols := GetOptionalColumns(ctx.TableInfo.Columns)
|
||||
|
||||
// Required fields (non-nullable, no default, not auto-increment, not PK)
|
||||
for _, col := range requiredCols {
|
||||
fieldName := col.Name
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
|
||||
// Regular field
|
||||
b.WriteString(fmt.Sprintf(" %s: %s\n", fieldName, pythonType))
|
||||
|
||||
// Add relationship field if this is a foreign key and column ends with _id
|
||||
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
if strings.HasSuffix(col.Name, "_id") {
|
||||
relationName := GetRelationshipName(col.Name)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s\n", relationName, entityName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Empty line between required and optional
|
||||
if len(requiredCols) > 0 && len(optionalCols) > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Optional fields
|
||||
for _, col := range optionalCols {
|
||||
// Skip primary key, we'll add it at the end
|
||||
if col.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := col.Name
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", fieldName, pythonType))
|
||||
|
||||
// Add relationship field if this is a foreign key and column ends with _id
|
||||
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
if strings.HasSuffix(col.Name, "_id") {
|
||||
relationName := GetRelationshipName(col.Name)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", relationName, entityName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add primary key at the end
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
b.WriteString("\n")
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", col.Name, pythonType))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
28
internal/generator/repository.go
Normal file
28
internal/generator/repository.go
Normal file
@ -0,0 +1,28 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateRepository generates the repository class
|
||||
func GenerateRepository(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sRepository(CRUDRepository[%s, %sFilter]):\n",
|
||||
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n", ctx.EntityName))
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
160
internal/generator/table.go
Normal file
160
internal/generator/table.go
Normal file
@ -0,0 +1,160 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateTable generates the SQLAlchemy table definition
|
||||
func GenerateTable(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from sqlalchemy import (\n")
|
||||
|
||||
// Collect unique imports
|
||||
imports := make(map[string]bool)
|
||||
imports["Column"] = true
|
||||
imports["Table"] = true
|
||||
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
sqlType := database.GetSQLAlchemyType(col)
|
||||
|
||||
// Extract base type name (before parentheses)
|
||||
typeName := strings.Split(sqlType, "(")[0]
|
||||
imports[typeName] = true
|
||||
|
||||
if col.IsPrimaryKey {
|
||||
imports["Integer"] = true
|
||||
}
|
||||
|
||||
// Check for foreign keys
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
imports["ForeignKey"] = true
|
||||
}
|
||||
|
||||
// Check for enums
|
||||
if col.DataType == "USER-DEFINED" {
|
||||
imports["Enum"] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Sort and write imports
|
||||
importList := []string{}
|
||||
for imp := range imports {
|
||||
importList = append(importList, imp)
|
||||
}
|
||||
|
||||
// Write imports in a reasonable order
|
||||
orderedImports := []string{
|
||||
"BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum",
|
||||
"Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary",
|
||||
"Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID",
|
||||
}
|
||||
|
||||
first := true
|
||||
for _, imp := range orderedImports {
|
||||
if imports[imp] {
|
||||
if !first {
|
||||
b.WriteString(",\n")
|
||||
} else {
|
||||
first = false
|
||||
}
|
||||
b.WriteString(" " + imp)
|
||||
}
|
||||
}
|
||||
b.WriteString(",\n)\n\n")
|
||||
|
||||
// Import enum types if they exist
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
|
||||
|
||||
// Table definition
|
||||
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
||||
b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName))
|
||||
b.WriteString(" metadata_obj,\n")
|
||||
|
||||
// Columns
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
b.WriteString(generateColumnDefinition(col, ctx))
|
||||
}
|
||||
|
||||
b.WriteString(")\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func generateColumnDefinition(col database.Column, ctx *Context) string {
|
||||
var parts []string
|
||||
|
||||
// Column name
|
||||
parts = append(parts, fmt.Sprintf("\"%s\"", col.Name))
|
||||
|
||||
// Column type
|
||||
sqlType := database.GetSQLAlchemyType(col)
|
||||
|
||||
// Handle enum types
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )",
|
||||
enumName, enumType.TypeName)
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, sqlType)
|
||||
|
||||
// Foreign key
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")",
|
||||
fk.ForeignTableName, fk.ForeignColumnName)
|
||||
parts = append(parts, fkDef)
|
||||
}
|
||||
|
||||
// Primary key
|
||||
if col.IsPrimaryKey {
|
||||
parts = append(parts, "primary_key=True")
|
||||
if col.IsAutoIncrement {
|
||||
parts = append(parts, "autoincrement=True")
|
||||
}
|
||||
}
|
||||
|
||||
// Nullable
|
||||
if !col.IsNullable && !col.IsPrimaryKey {
|
||||
parts = append(parts, "nullable=False")
|
||||
}
|
||||
|
||||
// Unique
|
||||
// Note: We don't have unique constraint info in our introspection yet
|
||||
// This would need to be added if needed
|
||||
|
||||
// Format the column definition
|
||||
result := " Column("
|
||||
|
||||
// Check if we need multiline formatting (for complex types like Enum)
|
||||
if strings.Contains(sqlType, "\n") {
|
||||
// Multiline format
|
||||
result += parts[0] + ",\n " + parts[1]
|
||||
for _, part := range parts[2:] {
|
||||
result += ",\n " + part
|
||||
}
|
||||
result += ",\n ),\n"
|
||||
} else {
|
||||
// Single line format
|
||||
result += strings.Join(parts, ", ") + "),\n"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
227
internal/naming/naming.go
Normal file
227
internal/naming/naming.go
Normal file
@ -0,0 +1,227 @@
|
||||
package naming
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// ToSnakeCase converts a string to snake_case
|
||||
func ToSnakeCase(s string) string {
|
||||
var result []rune
|
||||
for i, r := range s {
|
||||
if unicode.IsUpper(r) {
|
||||
if i > 0 {
|
||||
result = append(result, '_')
|
||||
}
|
||||
result = append(result, unicode.ToLower(r))
|
||||
} else {
|
||||
result = append(result, r)
|
||||
}
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
|
||||
// ToPascalCase converts snake_case to PascalCase
|
||||
func ToPascalCase(s string) string {
|
||||
parts := strings.Split(s, "_")
|
||||
for i, part := range parts {
|
||||
if len(part) > 0 {
|
||||
parts[i] = strings.ToUpper(part[:1]) + part[1:]
|
||||
}
|
||||
}
|
||||
return strings.Join(parts, "")
|
||||
}
|
||||
|
||||
// Singularize converts a plural word to singular
|
||||
// Uses common English pluralization rules
|
||||
func Singularize(word string) string {
|
||||
// Handle empty string
|
||||
if word == "" {
|
||||
return word
|
||||
}
|
||||
|
||||
// Common irregular plurals
|
||||
irregulars := map[string]string{
|
||||
"people": "person",
|
||||
"men": "man",
|
||||
"women": "woman",
|
||||
"children": "child",
|
||||
"teeth": "tooth",
|
||||
"feet": "foot",
|
||||
"mice": "mouse",
|
||||
"geese": "goose",
|
||||
"oxen": "ox",
|
||||
"sheep": "sheep",
|
||||
"fish": "fish",
|
||||
"deer": "deer",
|
||||
"series": "series",
|
||||
"species": "species",
|
||||
"quizzes": "quiz",
|
||||
"analyses": "analysis",
|
||||
"diagnoses": "diagnosis",
|
||||
"oases": "oasis",
|
||||
"theses": "thesis",
|
||||
"crises": "crisis",
|
||||
"phenomena": "phenomenon",
|
||||
"criteria": "criterion",
|
||||
"data": "datum",
|
||||
}
|
||||
|
||||
lower := strings.ToLower(word)
|
||||
if singular, ok := irregulars[lower]; ok {
|
||||
return preserveCase(word, singular)
|
||||
}
|
||||
|
||||
// Handle words ending in 'ies' -> 'y'
|
||||
if strings.HasSuffix(lower, "ies") && len(word) > 3 {
|
||||
return word[:len(word)-3] + "y"
|
||||
}
|
||||
|
||||
// Handle words ending in 'ves' -> 'fe' or 'f'
|
||||
if strings.HasSuffix(lower, "ves") && len(word) > 3 {
|
||||
base := word[:len(word)-3]
|
||||
// Common words that end in 'fe'
|
||||
if strings.HasSuffix(strings.ToLower(base), "li") ||
|
||||
strings.HasSuffix(strings.ToLower(base), "wi") ||
|
||||
strings.HasSuffix(strings.ToLower(base), "kni") {
|
||||
return base + "fe"
|
||||
}
|
||||
return base + "f"
|
||||
}
|
||||
|
||||
// Handle words ending in 'xes', 'ses', 'shes', 'ches' -> remove 'es'
|
||||
if strings.HasSuffix(lower, "xes") ||
|
||||
strings.HasSuffix(lower, "ses") ||
|
||||
strings.HasSuffix(lower, "shes") ||
|
||||
strings.HasSuffix(lower, "ches") {
|
||||
if len(word) > 2 {
|
||||
return word[:len(word)-2]
|
||||
}
|
||||
}
|
||||
|
||||
// Handle words ending in 'oes' -> 'o'
|
||||
if strings.HasSuffix(lower, "oes") && len(word) > 3 {
|
||||
return word[:len(word)-2]
|
||||
}
|
||||
|
||||
// Handle simple 's' suffix
|
||||
if strings.HasSuffix(lower, "s") && len(word) > 1 {
|
||||
// Don't remove 's' from words that naturally end in 's'
|
||||
if !strings.HasSuffix(lower, "ss") &&
|
||||
!strings.HasSuffix(lower, "us") &&
|
||||
!strings.HasSuffix(lower, "is") {
|
||||
return word[:len(word)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return word
|
||||
}
|
||||
|
||||
// Pluralize converts a singular word to plural
|
||||
func Pluralize(word string) string {
|
||||
if word == "" {
|
||||
return word
|
||||
}
|
||||
|
||||
// Common irregular plurals
|
||||
irregulars := map[string]string{
|
||||
"person": "people",
|
||||
"man": "men",
|
||||
"woman": "women",
|
||||
"child": "children",
|
||||
"tooth": "teeth",
|
||||
"foot": "feet",
|
||||
"mouse": "mice",
|
||||
"goose": "geese",
|
||||
"ox": "oxen",
|
||||
"sheep": "sheep",
|
||||
"fish": "fish",
|
||||
"deer": "deer",
|
||||
"series": "series",
|
||||
"species": "species",
|
||||
"quiz": "quizzes",
|
||||
"analysis": "analyses",
|
||||
"diagnosis": "diagnoses",
|
||||
"oasis": "oases",
|
||||
"thesis": "theses",
|
||||
"crisis": "crises",
|
||||
"phenomenon": "phenomena",
|
||||
"criterion": "criteria",
|
||||
"datum": "data",
|
||||
}
|
||||
|
||||
lower := strings.ToLower(word)
|
||||
if plural, ok := irregulars[lower]; ok {
|
||||
return preserveCase(word, plural)
|
||||
}
|
||||
|
||||
// Already plural (ends in 's' and not special case)
|
||||
if strings.HasSuffix(lower, "s") && !strings.HasSuffix(lower, "us") {
|
||||
return word
|
||||
}
|
||||
|
||||
// Handle words ending in 'y' -> 'ies'
|
||||
if strings.HasSuffix(lower, "y") && len(word) > 1 {
|
||||
prevChar := rune(lower[len(lower)-2])
|
||||
if !isVowel(prevChar) {
|
||||
return word[:len(word)-1] + "ies"
|
||||
}
|
||||
}
|
||||
|
||||
// Handle words ending in 'f' or 'fe' -> 'ves'
|
||||
if strings.HasSuffix(lower, "f") {
|
||||
return word[:len(word)-1] + "ves"
|
||||
}
|
||||
if strings.HasSuffix(lower, "fe") {
|
||||
return word[:len(word)-2] + "ves"
|
||||
}
|
||||
|
||||
// Handle words ending in 'o' -> 'oes'
|
||||
if strings.HasSuffix(lower, "o") && len(word) > 1 {
|
||||
prevChar := rune(lower[len(lower)-2])
|
||||
if !isVowel(prevChar) {
|
||||
return word + "es"
|
||||
}
|
||||
}
|
||||
|
||||
// Handle words ending in 'x', 's', 'sh', 'ch' -> add 'es'
|
||||
if strings.HasSuffix(lower, "x") ||
|
||||
strings.HasSuffix(lower, "s") ||
|
||||
strings.HasSuffix(lower, "sh") ||
|
||||
strings.HasSuffix(lower, "ch") {
|
||||
return word + "es"
|
||||
}
|
||||
|
||||
// Default: just add 's'
|
||||
return word + "s"
|
||||
}
|
||||
|
||||
// preserveCase preserves the case pattern of the original word
|
||||
func preserveCase(original, replacement string) string {
|
||||
if len(original) == 0 {
|
||||
return replacement
|
||||
}
|
||||
|
||||
if unicode.IsUpper(rune(original[0])) {
|
||||
return strings.ToUpper(replacement[:1]) + replacement[1:]
|
||||
}
|
||||
return replacement
|
||||
}
|
||||
|
||||
// isVowel checks if a rune is a vowel
|
||||
func isVowel(r rune) bool {
|
||||
vowels := "aeiouAEIOU"
|
||||
return strings.ContainsRune(vowels, r)
|
||||
}
|
||||
|
||||
// SingularizeTableName converts a table name to its singular form
|
||||
// Handles snake_case table names
|
||||
func SingularizeTableName(tableName string) string {
|
||||
parts := strings.Split(tableName, "_")
|
||||
if len(parts) > 0 {
|
||||
// Only singularize the last part
|
||||
lastIdx := len(parts) - 1
|
||||
parts[lastIdx] = Singularize(parts[lastIdx])
|
||||
}
|
||||
return strings.Join(parts, "_")
|
||||
}
|
||||
186
internal/prompt/prompt.go
Normal file
186
internal/prompt/prompt.go
Normal file
@ -0,0 +1,186 @@
|
||||
package prompt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/fatih/color"
|
||||
)
|
||||
|
||||
var (
|
||||
cyan = color.New(color.FgCyan).SprintFunc()
|
||||
green = color.New(color.FgGreen).SprintFunc()
|
||||
yellow = color.New(color.FgYellow).SprintFunc()
|
||||
red = color.New(color.FgRed).SprintFunc()
|
||||
)
|
||||
|
||||
// Reader interface for testing
|
||||
type Reader interface {
|
||||
ReadString(delim byte) (string, error)
|
||||
}
|
||||
|
||||
// PromptString prompts for a string value
|
||||
func PromptString(label string, defaultValue string, required bool) (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
return promptStringWithReader(reader, label, defaultValue, required)
|
||||
}
|
||||
|
||||
func promptStringWithReader(reader Reader, label string, defaultValue string, required bool) (string, error) {
|
||||
for {
|
||||
prompt := fmt.Sprintf("%s", cyan(label))
|
||||
if defaultValue != "" {
|
||||
prompt += fmt.Sprintf(" [%s]", green(defaultValue))
|
||||
}
|
||||
prompt += ": "
|
||||
|
||||
fmt.Print(prompt)
|
||||
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// If input is empty, use default value
|
||||
if input == "" {
|
||||
if defaultValue != "" {
|
||||
// Echo the selected default value
|
||||
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultValue), defaultValue)
|
||||
fmt.Printf("\033[1A\r%s\n", fullPrompt)
|
||||
return defaultValue, nil
|
||||
}
|
||||
if required {
|
||||
fmt.Println(red("✗ This field is required"))
|
||||
continue
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return input, nil
|
||||
}
|
||||
}
|
||||
|
||||
// PromptInt prompts for an integer value
|
||||
func PromptInt(label string, defaultValue int, required bool) (int, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
return promptIntWithReader(reader, label, defaultValue, required)
|
||||
}
|
||||
|
||||
func promptIntWithReader(reader Reader, label string, defaultValue int, required bool) (int, error) {
|
||||
for {
|
||||
defaultStr := ""
|
||||
if defaultValue != 0 {
|
||||
defaultStr = strconv.Itoa(defaultValue)
|
||||
}
|
||||
|
||||
prompt := fmt.Sprintf("%s", cyan(label))
|
||||
if defaultStr != "" {
|
||||
prompt += fmt.Sprintf(" [%s]", green(defaultStr))
|
||||
}
|
||||
prompt += ": "
|
||||
|
||||
fmt.Print(prompt)
|
||||
|
||||
input, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
input = strings.TrimSpace(input)
|
||||
|
||||
// If input is empty, use default value
|
||||
if input == "" {
|
||||
if defaultValue != 0 {
|
||||
// Echo the selected default value
|
||||
fullPrompt := fmt.Sprintf("%s [%s]: %s", cyan(label), green(defaultStr), defaultStr)
|
||||
fmt.Printf("\033[1A\r%s\n", fullPrompt)
|
||||
return defaultValue, nil
|
||||
}
|
||||
if required {
|
||||
fmt.Println(red("✗ This field is required"))
|
||||
continue
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Parse integer
|
||||
value, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
fmt.Println(red("✗ Please enter a valid number"))
|
||||
continue
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
// ValidateDirectory checks if a directory exists or can be created
|
||||
func ValidateDirectory(path string) error {
|
||||
// Check if path exists
|
||||
info, err := os.Stat(path)
|
||||
if err == nil {
|
||||
// Path exists, check if it's a directory
|
||||
if !info.IsDir() {
|
||||
return fmt.Errorf("path exists but is not a directory")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Path doesn't exist, try to create it
|
||||
if os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(path, 0755); err != nil {
|
||||
return fmt.Errorf("cannot create directory: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// PromptDirectory prompts for a directory path and validates it
|
||||
func PromptDirectory(label string, defaultValue string, required bool) (string, error) {
|
||||
for {
|
||||
path, err := PromptString(label, defaultValue, required)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if path == "" && !required {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Validate directory
|
||||
if err := ValidateDirectory(path); err != nil {
|
||||
fmt.Printf("%s %s\n", red("✗"), err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
return path, nil
|
||||
}
|
||||
}
|
||||
|
||||
// PrintHeader prints a colored header
|
||||
func PrintHeader(text string) {
|
||||
header := color.New(color.FgCyan, color.Bold)
|
||||
header.Println("\n" + text)
|
||||
header.Println(strings.Repeat("=", len(text)))
|
||||
}
|
||||
|
||||
// PrintSuccess prints a success message
|
||||
func PrintSuccess(text string) {
|
||||
fmt.Printf("%s %s\n", green("✓"), text)
|
||||
}
|
||||
|
||||
// PrintError prints an error message
|
||||
func PrintError(text string) {
|
||||
fmt.Printf("%s %s\n", red("✗"), text)
|
||||
}
|
||||
|
||||
// PrintInfo prints an info message
|
||||
func PrintInfo(text string) {
|
||||
fmt.Printf("%s %s\n", yellow("ℹ"), text)
|
||||
}
|
||||
Reference in New Issue
Block a user