First working version
This commit is contained in:
203
internal/generator/factory.go
Normal file
203
internal/generator/factory.go
Normal file
@ -0,0 +1,203 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateFactory generates the factory class
|
||||
func GenerateFactory(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from __future__ import annotations\n\n")
|
||||
b.WriteString("from typing import Type\n\n")
|
||||
b.WriteString("import factory\n\n")
|
||||
|
||||
// Import factories for related models
|
||||
fkImports := make(map[string]string) // module_name -> entity_name
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
fkImports[moduleName] = entityName
|
||||
}
|
||||
|
||||
for moduleName, entityName := range fkImports {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
|
||||
moduleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
||||
b.WriteString(" CustomSelfAttribute,\n")
|
||||
b.WriteString(" TelevendBaseFactory,\n")
|
||||
b.WriteString(")\n\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
|
||||
|
||||
// Add boolean fields with defaults
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.DataType == "boolean" {
|
||||
defaultValue := "True"
|
||||
if col.Name == "alive" {
|
||||
defaultValue = "True"
|
||||
} else {
|
||||
defaultValue = "False"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue))
|
||||
}
|
||||
}
|
||||
|
||||
// Add id field
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
b.WriteString(fmt.Sprintf(" %s = None\n", col.Name))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
|
||||
// Generate faker fields for each column
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey || col.DataType == "boolean" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip foreign keys, we'll handle them separately
|
||||
if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fakerDef := generateFakerField(col, ctx)
|
||||
if fakerDef != "" {
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef))
|
||||
}
|
||||
}
|
||||
|
||||
// Generate foreign key relationships
|
||||
if len(ctx.TableInfo.ForeignKeys) > 0 {
|
||||
b.WriteString("\n")
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n",
|
||||
relationName, relationName, entityName))
|
||||
b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n",
|
||||
fk.ColumnName, relationName, relationName))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Meta class
|
||||
b.WriteString(" class Meta:\n")
|
||||
b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName))
|
||||
b.WriteString("\n")
|
||||
|
||||
// create_minimal method
|
||||
b.WriteString(" @classmethod\n")
|
||||
b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n",
|
||||
ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString(" minimal_params = {\n")
|
||||
|
||||
// Add foreign key params
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
// Check if this FK is required
|
||||
var col *database.Column
|
||||
for i := range ctx.TableInfo.Columns {
|
||||
if ctx.TableInfo.Columns[i].Name == fk.ColumnName {
|
||||
col = &ctx.TableInfo.Columns[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if col != nil && !col.IsNullable {
|
||||
// Required FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n",
|
||||
relationName, relationName, entityName))
|
||||
} else {
|
||||
// Optional FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" minimal_params.update(kwargs)\n")
|
||||
b.WriteString(" return cls.create(**minimal_params)\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func generateFakerField(col database.Column, ctx *Context) string {
|
||||
pythonType := database.GetPythonType(col)
|
||||
|
||||
switch pythonType {
|
||||
case "Decimal":
|
||||
precision := 7
|
||||
scale := 4
|
||||
if col.NumericPrecision.Valid {
|
||||
precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64)
|
||||
}
|
||||
if col.NumericScale.Valid {
|
||||
scale = int(col.NumericScale.Int64)
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)",
|
||||
precision, scale)
|
||||
|
||||
case "int":
|
||||
if col.DataType == "bigint" {
|
||||
return "factory.Faker(\"pyint\")"
|
||||
}
|
||||
return "factory.Faker(\"pyint\")"
|
||||
|
||||
case "str":
|
||||
// Check if it's an enum
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
if len(enumType.Values) > 0 {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxLen := 255
|
||||
if col.CharMaxLength.Valid {
|
||||
maxLen = int(col.CharMaxLength.Int64)
|
||||
}
|
||||
if col.DataType == "text" {
|
||||
return "factory.Faker(\"text\", max_nb_chars=500)"
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen)
|
||||
|
||||
case "datetime":
|
||||
return "factory.Faker(\"date_time\")"
|
||||
|
||||
case "date":
|
||||
return "factory.Faker(\"date\")"
|
||||
|
||||
case "time":
|
||||
return "factory.Faker(\"time\")"
|
||||
|
||||
case "bool":
|
||||
return "factory.Faker(\"boolean\")"
|
||||
|
||||
case "dict":
|
||||
return "factory.Faker(\"pydict\")"
|
||||
|
||||
default:
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user