package generator import ( "fmt" "strings" "github.com/entity-maker/entity-maker/internal/database" "github.com/entity-maker/entity-maker/internal/naming" ) // GenerateFactory generates the factory class func GenerateFactory(ctx *Context) (string, error) { var b strings.Builder // Imports b.WriteString("from __future__ import annotations\n\n") b.WriteString("from typing import Type\n\n") b.WriteString("import factory\n\n") // Import factories for related models fkImports := make(map[string]string) // module_name -> entity_name for _, fk := range ctx.TableInfo.ForeignKeys { moduleName := GetRelationshipModuleName(fk.ForeignTableName) entityName := GetRelationshipEntityName(fk.ForeignTableName) fkImports[moduleName] = entityName } for moduleName, entityName := range fkImports { b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n", moduleName)) b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName)) b.WriteString(")\n") } b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n", ctx.ModuleName)) b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName)) b.WriteString(")\n") b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n") b.WriteString(" CustomSelfAttribute,\n") b.WriteString(" TelevendBaseFactory,\n") b.WriteString(")\n\n\n") // Class definition b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName)) // Add boolean fields with defaults for _, col := range ctx.TableInfo.Columns { if col.DataType == "boolean" { defaultValue := "True" if col.Name == "alive" { defaultValue = "True" } else { defaultValue = "False" } b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue)) } } // Add id field for _, col := range ctx.TableInfo.Columns { if col.IsPrimaryKey { b.WriteString(fmt.Sprintf(" %s = None\n", col.Name)) } } b.WriteString("\n") // Generate faker fields for each column for _, col := range ctx.TableInfo.Columns { if col.IsPrimaryKey || col.DataType == "boolean" { continue } // Skip foreign keys, we'll handle them separately if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil { continue } fakerDef := generateFakerField(col, ctx) if fakerDef != "" { b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef)) } } // Generate foreign key relationships if len(ctx.TableInfo.ForeignKeys) > 0 { b.WriteString("\n") for _, fk := range ctx.TableInfo.ForeignKeys { relationName := GetRelationshipName(fk.ColumnName) entityName := GetRelationshipEntityName(fk.ForeignTableName) b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n", relationName, relationName, entityName)) b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n", fk.ColumnName, relationName, relationName)) b.WriteString("\n") } } // Meta class b.WriteString(" class Meta:\n") b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName)) b.WriteString("\n") // create_minimal method b.WriteString(" @classmethod\n") b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n", ctx.EntityName, ctx.EntityName)) b.WriteString(" minimal_params = {\n") // Add foreign key params for _, fk := range ctx.TableInfo.ForeignKeys { relationName := GetRelationshipName(fk.ColumnName) entityName := GetRelationshipEntityName(fk.ForeignTableName) // Check if this FK is required var col *database.Column for i := range ctx.TableInfo.Columns { if ctx.TableInfo.Columns[i].Name == fk.ColumnName { col = &ctx.TableInfo.Columns[i] break } } if col != nil && !col.IsNullable { // Required FK b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n", relationName, relationName, entityName)) } else { // Optional FK b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName)) } } b.WriteString(" }\n") b.WriteString(" minimal_params.update(kwargs)\n") b.WriteString(" return cls.create(**minimal_params)\n") return b.String(), nil } func generateFakerField(col database.Column, ctx *Context) string { pythonType := database.GetPythonType(col) switch pythonType { case "Decimal": precision := 7 scale := 4 if col.NumericPrecision.Valid { precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64) } if col.NumericScale.Valid { scale = int(col.NumericScale.Int64) } return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)", precision, scale) case "int": if col.DataType == "bigint" { return "factory.Faker(\"pyint\")" } return "factory.Faker(\"pyint\")" case "str": // Check if it's an enum if col.DataType == "USER-DEFINED" && col.UdtName != "" { if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists { if len(enumType.Values) > 0 { enumName := naming.ToPascalCase(enumType.TypeName) return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName) } } } maxLen := 255 if col.CharMaxLength.Valid { maxLen = int(col.CharMaxLength.Int64) } if col.DataType == "text" { return "factory.Faker(\"text\", max_nb_chars=500)" } return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen) case "datetime": return "factory.Faker(\"date_time\")" case "date": return "factory.Faker(\"date\")" case "time": return "factory.Faker(\"time\")" case "bool": return "factory.Faker(\"boolean\")" case "dict": return "factory.Faker(\"pydict\")" default: return "None" } }