Files
entity-maker/internal/generator/generator.go
Eden Kirin 4e4827d640 Tests
2025-10-31 19:04:40 +01:00

192 lines
5.5 KiB
Go

package generator
import (
"fmt"
"strings"
"github.com/entity-maker/entity-maker/internal/database"
"github.com/entity-maker/entity-maker/internal/naming"
)
// Context bundles the table metadata together with the derived names that the
// individual file generators share.
type Context struct {
	// TableInfo is the table metadata the code is generated from.
	TableInfo *database.TableInfo
	// EntityName is the singular, PascalCase entity name (e.g. "CashbagConform").
	EntityName string
	// ModuleName is the singular, snake_case module name (e.g. "cashbag_conform").
	ModuleName string
	// TableConstant is the upper-case table constant ending in "_TABLE"
	// (e.g. "CASHBAG_CONFORM_TABLE").
	TableConstant string
}
// NewContext builds the generation Context for tableInfo.
//
// The module name is the singularized table name. The entity name is
// entityNameOverride when it is non-empty; otherwise it is the PascalCase form
// of the module name. The table constant is the upper-cased module name
// followed by "_TABLE".
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
	module := naming.SingularizeTableName(tableInfo.TableName)

	var entity string
	if entityNameOverride != "" {
		entity = entityNameOverride
	} else {
		entity = naming.ToPascalCase(module)
	}

	return &Context{
		TableInfo:     tableInfo,
		EntityName:    entity,
		ModuleName:    module,
		TableConstant: strings.ToUpper(module) + "_TABLE",
	}
}
// GetRelationshipName derives a relationship name from a foreign-key column
// name by removing one trailing "_id" suffix (e.g. "machine_id" -> "machine").
//
// Names without that suffix are returned unchanged. No other case or format
// conversion is applied; the input is assumed to already be snake_case.
func GetRelationshipName(fkColumnName string) string {
	return strings.TrimSuffix(fkColumnName, "_id")
}
// GetRelationshipEntityName returns the PascalCase entity name for a foreign
// key's target table: the singularized table name converted to PascalCase.
func GetRelationshipEntityName(tableName string) string {
	return naming.ToPascalCase(naming.SingularizeTableName(tableName))
}
// GetRelationshipModuleName returns the module name used for a foreign key's
// target table, which is simply the singularized form of tableName.
func GetRelationshipModuleName(tableName string) string {
	moduleName := naming.SingularizeTableName(tableName)
	return moduleName
}
// GetFilterFieldName maps a column name to the name of its filter field.
//
// When useIN is true, identifier columns are pluralised for an IN-style
// filter: a trailing "_id" becomes "_ids" (machine_id -> machine_ids) and a
// bare "id" becomes "ids". In every other case columnName is returned as is.
func GetFilterFieldName(columnName string, useIN bool) string {
	if !useIN {
		return columnName
	}
	switch {
	case strings.HasSuffix(columnName, "_id"):
		// "..._id" + "s" == "..._ids"
		return columnName + "s"
	case columnName == "id":
		return "ids"
	default:
		return columnName
	}
}
// ShouldGenerateFilter reports whether a filter should be generated for col.
//
// Filters are produced for identifier columns (named "id" or ending in "_id"),
// boolean columns, and character/text columns ("character varying",
// "varchar", "text", "char", "character"). Every other column is skipped.
func ShouldGenerateFilter(col database.Column) bool {
	if col.Name == "id" || strings.HasSuffix(col.Name, "_id") {
		return true
	}
	switch col.DataType {
	case "boolean",
		"character varying", "varchar",
		"text",
		"char", "character":
		return true
	default:
		return false
	}
}
// GetPythonTypeForColumn returns the Python type annotation for col.
//
// The base type comes from database.GetPythonType. For columns whose DataType
// is "USER-DEFINED" and whose UdtName names an entry in
// ctx.TableInfo.EnumTypes, the PascalCase name of that enum type is returned
// instead. If ctx or ctx.TableInfo is nil, the enum lookup is skipped and the
// base type is returned rather than panicking.
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
	baseType := database.GetPythonType(col)

	if col.DataType != "USER-DEFINED" || col.UdtName == "" {
		return baseType
	}
	// Guard against a partially built context; a nil map lookup is safe,
	// but dereferencing a nil ctx/TableInfo is not.
	if ctx == nil || ctx.TableInfo == nil {
		return baseType
	}
	if enumType, ok := ctx.TableInfo.EnumTypes[col.UdtName]; ok {
		return naming.ToPascalCase(enumType.TypeName)
	}
	return baseType
}
// NeedsDecimalImport reports whether at least one column maps to the Python
// type "Decimal" according to database.GetPythonType.
func NeedsDecimalImport(columns []database.Column) bool {
	for i := range columns {
		if database.GetPythonType(columns[i]) == "Decimal" {
			return true
		}
	}
	return false
}
// NeedsDatetimeImport reports whether at least one column maps to one of the
// Python types "datetime", "date" or "time" according to
// database.GetPythonType.
func NeedsDatetimeImport(columns []database.Column) bool {
	for i := range columns {
		switch database.GetPythonType(columns[i]) {
		case "datetime", "date", "time":
			return true
		}
	}
	return false
}
// GetRequiredColumns returns, in input order, the columns that are not
// nullable, have no column default, are not auto-increment, and are not
// flagged as primary key. It returns nil when no column qualifies.
//
// NOTE(review): this is not the complement of GetOptionalColumns. A NOT NULL
// primary-key or auto-increment column without a default appears in neither
// list; confirm that is intended.
func GetRequiredColumns(columns []database.Column) []database.Column {
	var required []database.Column
	for _, c := range columns {
		skip := c.IsNullable || c.ColumnDefault.Valid || c.IsAutoIncrement || c.IsPrimaryKey
		if skip {
			continue
		}
		required = append(required, c)
	}
	return required
}
// GetOptionalColumns returns, in input order, the columns that are nullable
// or have a column default. It returns nil when no column qualifies.
func GetOptionalColumns(columns []database.Column) []database.Column {
	var optional []database.Column
	for i := range columns {
		if !columns[i].IsNullable && !columns[i].ColumnDefault.Valid {
			continue
		}
		optional = append(optional, columns[i])
	}
	return optional
}
// GetForeignKeyForColumn returns the first foreign key whose ColumnName equals
// columnName, or nil when there is none.
//
// The returned pointer refers to a copy of the matching element, so mutating
// it does not modify foreignKeys.
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
	for i := range foreignKeys {
		if foreignKeys[i].ColumnName != columnName {
			continue
		}
		match := foreignKeys[i]
		return &match
	}
	return nil
}
// GenerateFiles runs every Python file generator for the entity described by
// ctx. It returns the first generator error, wrapped with the file name.
//
// Generators run in a fixed order. enum.py is included only when
// ctx.TableInfo.EnumTypes is non-empty. Because the order is fixed, the
// reported error is deterministic when several generators would fail; ranging
// over a map would pick an arbitrary one.
//
// NOTE(review): the generated content is discarded and outputDir is unused.
// Writing files is expected to happen in the caller. Until that changes, this
// function only checks that every generator succeeds.
func GenerateFiles(ctx *Context, outputDir string) error {
	type fileGenerator struct {
		filename string
		generate func(*Context) (string, error)
	}

	generators := []fileGenerator{
		{"table.py", GenerateTable},
		{"model.py", GenerateModel},
		{"filter.py", GenerateFilter},
		{"load_options.py", GenerateLoadOptions},
		{"repository.py", GenerateRepository},
		{"manager.py", GenerateManager},
		{"factory.py", GenerateFactory},
		{"mapper.py", GenerateMapper},
	}
	if len(ctx.TableInfo.EnumTypes) > 0 {
		generators = append(generators, fileGenerator{"enum.py", GenerateEnum})
	}
	generators = append(generators, fileGenerator{"__init__.py", GenerateInit})

	for _, g := range generators {
		// Content is intentionally dropped; writing to outputDir is the
		// caller's responsibility for now.
		if _, err := g.generate(ctx); err != nil {
			return fmt.Errorf("failed to generate %s: %w", g.filename, err)
		}
	}
	return nil
}