package generator

import (
	"fmt"
	"strings"

	"github.com/entity-maker/entity-maker/internal/database"
	"github.com/entity-maker/entity-maker/internal/naming"
)

// Context contains all information needed for code generation.
type Context struct {
	// TableInfo is the table metadata the generators work from.
	TableInfo *database.TableInfo
	// EntityName is singular PascalCase (e.g., "CashbagConform").
	EntityName string
	// ModuleName is singular snake_case (e.g., "cashbag_conform").
	ModuleName string
	// TableConstant is the upper-cased module name with a _TABLE suffix
	// (e.g., "CASHBAG_CONFORM_TABLE").
	TableConstant string
}

// NewContext creates a new generation context for tableInfo.
//
// The module name is derived from tableInfo.TableName via
// naming.SingularizeTableName. When entityNameOverride is empty the entity
// name is the PascalCase form of the module name; otherwise the override is
// used verbatim. tableInfo must not be nil.
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
	moduleName := naming.SingularizeTableName(tableInfo.TableName)

	entityName := entityNameOverride
	if entityName == "" {
		entityName = naming.ToPascalCase(moduleName)
	}

	return &Context{
		TableInfo:     tableInfo,
		EntityName:    entityName,
		ModuleName:    moduleName,
		TableConstant: strings.ToUpper(moduleName) + "_TABLE",
	}
}

// GetRelationshipName returns the relationship name for a foreign key column.
//
// A trailing "_id" is removed; any other name is returned unchanged. No case
// conversion is performed.
func GetRelationshipName(fkColumnName string) string {
	return strings.TrimSuffix(fkColumnName, "_id")
}

// GetRelationshipEntityName returns the entity name for a foreign key's
// target table: the singularized table name converted to PascalCase.
func GetRelationshipEntityName(tableName string) string {
	return naming.ToPascalCase(naming.SingularizeTableName(tableName))
}

// GetRelationshipModuleName returns the module name for a foreign key's
// target table: the singularized table name.
func GetRelationshipModuleName(tableName string) string {
	return naming.SingularizeTableName(tableName)
}

// GetFilterFieldName returns the filter field name for a column.
//
// Without useIN the column name is returned unchanged. With useIN:
//   - "id" becomes "ids";
//   - "<base>_id" becomes naming.Pluralize(base) + "_ids";
//   - a bare "_id" (empty base) becomes "ids";
//   - every other name is returned unchanged.
func GetFilterFieldName(columnName string, useIN bool) string {
	if !useIN {
		return columnName
	}
	if columnName == "id" {
		return "ids"
	}
	if !strings.HasSuffix(columnName, "_id") {
		return columnName
	}

	base := strings.TrimSuffix(columnName, "_id")
	if base == "" {
		// Nothing to pluralize in front of the suffix.
		return "ids"
	}
	return naming.Pluralize(base) + "_ids"
}

// ShouldGenerateFilter reports whether a column should have a filter.
//
// Filters are generated for boolean columns, ID columns (named "id" or ending
// in "_id"), and text columns ("character varying", "varchar", "text",
// "char", "character").
func ShouldGenerateFilter(col database.Column) bool {
	if col.Name == "id" || strings.HasSuffix(col.Name, "_id") {
		return true
	}

	switch col.DataType {
	case "boolean",
		"character varying", "varchar", "text", "char", "character":
		return true
	default:
		return false
	}
}

// GetPythonTypeForColumn returns the Python type annotation for a column.
//
// The result is database.GetPythonType(col), except for "USER-DEFINED"
// columns whose UdtName is registered in ctx.TableInfo.EnumTypes: those use
// the PascalCase form of the enum's TypeName. When ctx or ctx.TableInfo is
// nil no enum lookup is possible and the base type is returned.
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
	baseType := database.GetPythonType(col)

	if col.DataType != "USER-DEFINED" || col.UdtName == "" {
		return baseType
	}
	if ctx == nil || ctx.TableInfo == nil {
		return baseType
	}

	if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
		return naming.ToPascalCase(enumType.TypeName)
	}
	return baseType
}

// NeedsDecimalImport reports whether any column maps to the Python "Decimal"
// type according to database.GetPythonType.
func NeedsDecimalImport(columns []database.Column) bool {
	for _, col := range columns {
		if database.GetPythonType(col) == "Decimal" {
			return true
		}
	}
	return false
}

// NeedsDatetimeImport reports whether any column maps to the Python
// "datetime", "date", or "time" type according to database.GetPythonType.
func NeedsDatetimeImport(columns []database.Column) bool {
	for _, col := range columns {
		switch database.GetPythonType(col) {
		case "datetime", "date", "time":
			return true
		}
	}
	return false
}

// GetRequiredColumns returns the columns that must be supplied explicitly:
// not nullable, no valid column default, not auto-increment, and not a
// primary key. The result is nil when no column qualifies.
func GetRequiredColumns(columns []database.Column) []database.Column {
	var required []database.Column
	for _, col := range columns {
		if col.IsNullable || col.ColumnDefault.Valid {
			continue
		}
		if col.IsAutoIncrement || col.IsPrimaryKey {
			continue
		}
		required = append(required, col)
	}
	return required
}

// GetOptionalColumns returns the columns that are nullable or have a valid
// column default. The result is nil when no column qualifies.
//
// Note that a non-nullable, default-less primary key or auto-increment column
// is returned by neither this function nor GetRequiredColumns.
func GetOptionalColumns(columns []database.Column) []database.Column {
	var optional []database.Column
	for _, col := range columns {
		if col.IsNullable || col.ColumnDefault.Valid {
			optional = append(optional, col)
		}
	}
	return optional
}

// GetForeignKeyForColumn returns the foreign key info for a column, or nil if
// no foreign key in foreignKeys has that column name.
//
// The returned pointer refers to a copy of the matching element, so changes
// made through it do not affect the foreignKeys slice.
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
	for i := range foreignKeys {
		if foreignKeys[i].ColumnName != columnName {
			continue
		}
		// Copy only the match instead of taking the address of a range variable.
		match := foreignKeys[i]
		return &match
	}
	return nil
}

// GenerateFiles runs every Python file generator for the entity and returns
// the first error encountered, wrapped with the name of the file being
// generated.
//
// Generators run in a fixed order so that the reported error is
// deterministic; enum.py is generated only when the table has enum types.
// The generated content is not written anywhere by this function and
// outputDir is currently unused: writing is left to the caller. An error is
// returned when ctx or ctx.TableInfo is nil.
func GenerateFiles(ctx *Context, outputDir string) error {
	if ctx == nil || ctx.TableInfo == nil {
		return fmt.Errorf("failed to generate files: missing table information")
	}

	type fileGenerator struct {
		filename string
		generate func(*Context) (string, error)
	}

	generators := []fileGenerator{
		{"table.py", GenerateTable},
		{"model.py", GenerateModel},
		{"filter.py", GenerateFilter},
		{"load_options.py", GenerateLoadOptions},
		{"repository.py", GenerateRepository},
		{"manager.py", GenerateManager},
		{"factory.py", GenerateFactory},
		{"mapper.py", GenerateMapper},
		{"__init__.py", GenerateInit},
	}

	// Generate enum.py only if there are enum types.
	if len(ctx.TableInfo.EnumTypes) > 0 {
		generators = append(generators, fileGenerator{"enum.py", GenerateEnum})
	}

	for _, g := range generators {
		// The content is intentionally discarded here; only generation
		// errors are surfaced.
		if _, err := g.generate(ctx); err != nil {
			return fmt.Errorf("failed to generate %s: %w", g.filename, err)
		}
	}
	return nil
}