First working version
This commit is contained in:
40
internal/generator/enum.go
Normal file
40
internal/generator/enum.go
Normal file
@ -0,0 +1,40 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateEnum generates the enum types file
|
||||
func GenerateEnum(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from enum import StrEnum\n\n")
|
||||
b.WriteString("from televend_core.databases.enum import EnumMixin\n\n\n")
|
||||
|
||||
// Generate each enum type
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
|
||||
b.WriteString(fmt.Sprintf("class %s(EnumMixin, StrEnum):\n", enumName))
|
||||
|
||||
if len(enumType.Values) == 0 {
|
||||
b.WriteString(" pass\n")
|
||||
} else {
|
||||
for _, value := range enumType.Values {
|
||||
// Convert value to valid Python identifier
|
||||
// Usually enum values are already uppercase like "OPEN", "IN_PROGRESS"
|
||||
identifier := strings.ToUpper(strings.ReplaceAll(value, " ", "_"))
|
||||
identifier = strings.ReplaceAll(identifier, "-", "_")
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s = \"%s\"\n", identifier, value))
|
||||
}
|
||||
}
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
203
internal/generator/factory.go
Normal file
203
internal/generator/factory.go
Normal file
@ -0,0 +1,203 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateFactory generates the factory class
|
||||
func GenerateFactory(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from __future__ import annotations\n\n")
|
||||
b.WriteString("from typing import Type\n\n")
|
||||
b.WriteString("import factory\n\n")
|
||||
|
||||
// Import factories for related models
|
||||
fkImports := make(map[string]string) // module_name -> entity_name
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
fkImports[moduleName] = entityName
|
||||
}
|
||||
|
||||
for moduleName, entityName := range fkImports {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.factory import (\n",
|
||||
moduleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFactory,\n", entityName))
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString("from televend_core.test_extras.factory_boy_utils import (\n")
|
||||
b.WriteString(" CustomSelfAttribute,\n")
|
||||
b.WriteString(" TelevendBaseFactory,\n")
|
||||
b.WriteString(")\n\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sFactory(TelevendBaseFactory):\n", ctx.EntityName))
|
||||
|
||||
// Add boolean fields with defaults
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.DataType == "boolean" {
|
||||
defaultValue := "True"
|
||||
if col.Name == "alive" {
|
||||
defaultValue = "True"
|
||||
} else {
|
||||
defaultValue = "False"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, defaultValue))
|
||||
}
|
||||
}
|
||||
|
||||
// Add id field
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
b.WriteString(fmt.Sprintf(" %s = None\n", col.Name))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
|
||||
// Generate faker fields for each column
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey || col.DataType == "boolean" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Skip foreign keys, we'll handle them separately
|
||||
if GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
fakerDef := generateFakerField(col, ctx)
|
||||
if fakerDef != "" {
|
||||
b.WriteString(fmt.Sprintf(" %s = %s\n", col.Name, fakerDef))
|
||||
}
|
||||
}
|
||||
|
||||
// Generate foreign key relationships
|
||||
if len(ctx.TableInfo.ForeignKeys) > 0 {
|
||||
b.WriteString("\n")
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s = CustomSelfAttribute(\"..%s\", %sFactory)\n",
|
||||
relationName, relationName, entityName))
|
||||
b.WriteString(fmt.Sprintf(" %s = factory.LazyAttribute(lambda a: a.%s.id if a.%s else None)\n",
|
||||
fk.ColumnName, relationName, relationName))
|
||||
b.WriteString("\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Meta class
|
||||
b.WriteString(" class Meta:\n")
|
||||
b.WriteString(fmt.Sprintf(" model = %s\n", ctx.EntityName))
|
||||
b.WriteString("\n")
|
||||
|
||||
// create_minimal method
|
||||
b.WriteString(" @classmethod\n")
|
||||
b.WriteString(fmt.Sprintf(" def create_minimal(cls: Type[%sFactory], **kwargs) -> %s:\n",
|
||||
ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString(" minimal_params = {\n")
|
||||
|
||||
// Add foreign key params
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
// Check if this FK is required
|
||||
var col *database.Column
|
||||
for i := range ctx.TableInfo.Columns {
|
||||
if ctx.TableInfo.Columns[i].Name == fk.ColumnName {
|
||||
col = &ctx.TableInfo.Columns[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if col != nil && !col.IsNullable {
|
||||
// Required FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": kwargs.pop(\"%s\", None) or %sFactory.create_minimal(),\n",
|
||||
relationName, relationName, entityName))
|
||||
} else {
|
||||
// Optional FK
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": None,\n", relationName))
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString(" }\n")
|
||||
b.WriteString(" minimal_params.update(kwargs)\n")
|
||||
b.WriteString(" return cls.create(**minimal_params)\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func generateFakerField(col database.Column, ctx *Context) string {
|
||||
pythonType := database.GetPythonType(col)
|
||||
|
||||
switch pythonType {
|
||||
case "Decimal":
|
||||
precision := 7
|
||||
scale := 4
|
||||
if col.NumericPrecision.Valid {
|
||||
precision = int(col.NumericPrecision.Int64) - int(col.NumericScale.Int64)
|
||||
}
|
||||
if col.NumericScale.Valid {
|
||||
scale = int(col.NumericScale.Int64)
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pydecimal\", left_digits=%d, right_digits=%d, positive=True)",
|
||||
precision, scale)
|
||||
|
||||
case "int":
|
||||
if col.DataType == "bigint" {
|
||||
return "factory.Faker(\"pyint\")"
|
||||
}
|
||||
return "factory.Faker(\"pyint\")"
|
||||
|
||||
case "str":
|
||||
// Check if it's an enum
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
if len(enumType.Values) > 0 {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
return fmt.Sprintf("factory.Iterator(%s.to_value_list())", enumName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
maxLen := 255
|
||||
if col.CharMaxLength.Valid {
|
||||
maxLen = int(col.CharMaxLength.Int64)
|
||||
}
|
||||
if col.DataType == "text" {
|
||||
return "factory.Faker(\"text\", max_nb_chars=500)"
|
||||
}
|
||||
return fmt.Sprintf("factory.Faker(\"pystr\", max_chars=%d)", maxLen)
|
||||
|
||||
case "datetime":
|
||||
return "factory.Faker(\"date_time\")"
|
||||
|
||||
case "date":
|
||||
return "factory.Faker(\"date\")"
|
||||
|
||||
case "time":
|
||||
return "factory.Faker(\"time\")"
|
||||
|
||||
case "bool":
|
||||
return "factory.Faker(\"boolean\")"
|
||||
|
||||
case "dict":
|
||||
return "factory.Faker(\"pydict\")"
|
||||
|
||||
default:
|
||||
return "None"
|
||||
}
|
||||
}
|
||||
77
internal/generator/filter.go
Normal file
77
internal/generator/filter.go
Normal file
@ -0,0 +1,77 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
// GenerateFilter generates the filter class
|
||||
func GenerateFilter(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_filter import BaseFilter\n")
|
||||
b.WriteString("from televend_core.databases.common.filters.filters import EQ, IN, filterfield\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sFilter(BaseFilter):\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||
|
||||
// Generate filters based on rules:
|
||||
// - Boolean fields: EQ operator
|
||||
// - ID fields: both EQ and IN operators
|
||||
// - Text fields: EQ operator
|
||||
|
||||
hasFilters := false
|
||||
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if !ShouldGenerateFilter(col) {
|
||||
continue
|
||||
}
|
||||
|
||||
hasFilters = true
|
||||
|
||||
// Boolean fields
|
||||
if col.DataType == "boolean" {
|
||||
defaultVal := "None"
|
||||
if col.Name == "alive" {
|
||||
defaultVal = "True"
|
||||
}
|
||||
b.WriteString(fmt.Sprintf(" %s: bool | None = filterfield(operator=EQ, default=%s)\n",
|
||||
col.Name, defaultVal))
|
||||
}
|
||||
|
||||
// ID fields (both EQ and IN)
|
||||
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||
// Single ID with EQ
|
||||
pythonType := database.GetPythonType(col)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = filterfield(operator=EQ)\n",
|
||||
col.Name, pythonType))
|
||||
|
||||
// Multiple IDs with IN (plural field name)
|
||||
pluralFieldName := GetFilterFieldName(col.Name, true)
|
||||
b.WriteString(fmt.Sprintf(" %s: list[%s] | None = filterfield(field=\"%s\", operator=IN)\n",
|
||||
pluralFieldName, pythonType, col.Name))
|
||||
}
|
||||
|
||||
// Text fields
|
||||
if (col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character") &&
|
||||
!strings.HasSuffix(col.Name, "_id") && col.Name != "id" {
|
||||
b.WriteString(fmt.Sprintf(" %s: str | None = filterfield(operator=EQ)\n",
|
||||
col.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// If no filters were generated, add a pass statement
|
||||
if !hasFilters {
|
||||
b.WriteString(" pass\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
195
internal/generator/generator.go
Normal file
195
internal/generator/generator.go
Normal file
@ -0,0 +1,195 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// Context contains all information needed for code generation
|
||||
type Context struct {
|
||||
TableInfo *database.TableInfo
|
||||
EntityName string // Singular, PascalCase (e.g., "CashbagConform")
|
||||
ModuleName string // Singular, snake_case (e.g., "cashbag_conform")
|
||||
TableConstant string // Uppercase with TABLE suffix (e.g., "CASHBAG_CONFORM_TABLE")
|
||||
}
|
||||
|
||||
// NewContext creates a new generation context
|
||||
func NewContext(tableInfo *database.TableInfo, entityNameOverride string) *Context {
|
||||
moduleName := naming.SingularizeTableName(tableInfo.TableName)
|
||||
|
||||
entityName := entityNameOverride
|
||||
if entityName == "" {
|
||||
entityName = naming.ToPascalCase(moduleName)
|
||||
}
|
||||
|
||||
tableConstant := strings.ToUpper(moduleName) + "_TABLE"
|
||||
|
||||
return &Context{
|
||||
TableInfo: tableInfo,
|
||||
EntityName: entityName,
|
||||
ModuleName: moduleName,
|
||||
TableConstant: tableConstant,
|
||||
}
|
||||
}
|
||||
|
||||
// GetRelationshipName returns the relationship name for a foreign key
|
||||
// Strips _id suffix and converts to snake_case
|
||||
func GetRelationshipName(fkColumnName string) string {
|
||||
name := fkColumnName
|
||||
if strings.HasSuffix(name, "_id") {
|
||||
name = name[:len(name)-3]
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
// GetRelationshipEntityName returns the entity name for a foreign key's target table
|
||||
func GetRelationshipEntityName(tableName string) string {
|
||||
singular := naming.SingularizeTableName(tableName)
|
||||
return naming.ToPascalCase(singular)
|
||||
}
|
||||
|
||||
// GetRelationshipModuleName returns the module name for a foreign key's target table
|
||||
func GetRelationshipModuleName(tableName string) string {
|
||||
return naming.SingularizeTableName(tableName)
|
||||
}
|
||||
|
||||
// GetFilterFieldName returns the filter field name for a column
|
||||
// For ID fields with IN operator, pluralizes the name
|
||||
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||
// Remove _id, pluralize, add back _ids
|
||||
base := columnName[:len(columnName)-3]
|
||||
if base == "" {
|
||||
return "ids"
|
||||
}
|
||||
return naming.Pluralize(base) + "_ids"
|
||||
}
|
||||
if useIN && columnName == "id" {
|
||||
return "ids"
|
||||
}
|
||||
return columnName
|
||||
}
|
||||
|
||||
// ShouldGenerateFilter determines if a column should have a filter
|
||||
func ShouldGenerateFilter(col database.Column) bool {
|
||||
// Generate filters for boolean, ID, and text fields
|
||||
if col.DataType == "boolean" {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(col.Name, "_id") || col.Name == "id" {
|
||||
return true
|
||||
}
|
||||
if col.DataType == "character varying" || col.DataType == "varchar" ||
|
||||
col.DataType == "text" || col.DataType == "char" || col.DataType == "character" {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetPythonTypeForColumn returns the Python type annotation for a column
|
||||
func GetPythonTypeForColumn(col database.Column, ctx *Context) string {
|
||||
baseType := database.GetPythonType(col)
|
||||
|
||||
// Handle enum types
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
baseType = naming.ToPascalCase(enumType.TypeName)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle Decimal type
|
||||
if baseType == "Decimal" {
|
||||
baseType = "Decimal"
|
||||
}
|
||||
|
||||
return baseType
|
||||
}
|
||||
|
||||
// NeedsDecimalImport checks if any column uses Decimal type
|
||||
func NeedsDecimalImport(columns []database.Column) bool {
|
||||
for _, col := range columns {
|
||||
if database.GetPythonType(col) == "Decimal" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// NeedsDatetimeImport checks if any column uses datetime type
|
||||
func NeedsDatetimeImport(columns []database.Column) bool {
|
||||
for _, col := range columns {
|
||||
pyType := database.GetPythonType(col)
|
||||
if pyType == "datetime" || pyType == "date" || pyType == "time" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetRequiredColumns returns columns that are not nullable and don't have defaults
|
||||
func GetRequiredColumns(columns []database.Column) []database.Column {
|
||||
var required []database.Column
|
||||
for _, col := range columns {
|
||||
if !col.IsNullable && !col.ColumnDefault.Valid && !col.IsAutoIncrement && !col.IsPrimaryKey {
|
||||
required = append(required, col)
|
||||
}
|
||||
}
|
||||
return required
|
||||
}
|
||||
|
||||
// GetOptionalColumns returns columns that are nullable or have defaults
|
||||
func GetOptionalColumns(columns []database.Column) []database.Column {
|
||||
var optional []database.Column
|
||||
for _, col := range columns {
|
||||
if col.IsNullable || col.ColumnDefault.Valid {
|
||||
optional = append(optional, col)
|
||||
}
|
||||
}
|
||||
return optional
|
||||
}
|
||||
|
||||
// GetForeignKeyForColumn returns the foreign key info for a column, if it exists
|
||||
func GetForeignKeyForColumn(columnName string, foreignKeys []database.ForeignKey) *database.ForeignKey {
|
||||
for _, fk := range foreignKeys {
|
||||
if fk.ColumnName == columnName {
|
||||
return &fk
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateFiles generates all Python files for the entity
|
||||
func GenerateFiles(ctx *Context, outputDir string) error {
|
||||
generators := map[string]func(*Context) (string, error){
|
||||
"table.py": GenerateTable,
|
||||
"model.py": GenerateModel,
|
||||
"filter.py": GenerateFilter,
|
||||
"load_options.py": GenerateLoadOptions,
|
||||
"repository.py": GenerateRepository,
|
||||
"manager.py": GenerateManager,
|
||||
"factory.py": GenerateFactory,
|
||||
"mapper.py": GenerateMapper,
|
||||
"__init__.py": GenerateInit,
|
||||
}
|
||||
|
||||
// Generate enum.py if there are enum types
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
generators["enum.py"] = GenerateEnum
|
||||
}
|
||||
|
||||
for filename, generator := range generators {
|
||||
content, err := generator(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate %s: %w", filename, err)
|
||||
}
|
||||
|
||||
// Write to file (this will be handled by the main function)
|
||||
// For now, just return the content
|
||||
_ = content
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
7
internal/generator/init.go
Normal file
7
internal/generator/init.go
Normal file
@ -0,0 +1,7 @@
|
||||
package generator
|
||||
|
||||
// GenerateInit generates an empty __init__.py file
|
||||
func GenerateInit(ctx *Context) (string, error) {
|
||||
// Empty __init__.py file
|
||||
return "", nil
|
||||
}
|
||||
38
internal/generator/load_options.go
Normal file
38
internal/generator/load_options.go
Normal file
@ -0,0 +1,38 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateLoadOptions generates the load options class
|
||||
func GenerateLoadOptions(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_load_options import LoadOptions\n")
|
||||
b.WriteString("from televend_core.databases.common.load_options import joinload\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sLoadOptions(LoadOptions):\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n\n", ctx.EntityName))
|
||||
|
||||
// Generate load options for all foreign key relationships
|
||||
hasRelationships := false
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
hasRelationships = true
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
b.WriteString(fmt.Sprintf(" load_%s: bool = joinload(relations=[\"%s\"])\n",
|
||||
relationName, relationName))
|
||||
}
|
||||
|
||||
// If no relationships, add pass
|
||||
if !hasRelationships {
|
||||
b.WriteString(" pass\n")
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
34
internal/generator/manager.go
Normal file
34
internal/generator/manager.go
Normal file
@ -0,0 +1,34 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateManager generates the manager class
|
||||
func GenerateManager(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_manager import CRUDManager\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.repository import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sRepository,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sManager(\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" CRUDManager[%s, %sFilter, %sRepository]\n",
|
||||
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString("):\n")
|
||||
b.WriteString(fmt.Sprintf(" repository_cls = %sRepository\n", ctx.EntityName))
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
48
internal/generator/mapper.go
Normal file
48
internal/generator/mapper.go
Normal file
@ -0,0 +1,48 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateMapper generates the mapper snippet (without imports)
|
||||
func GenerateMapper(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Mapper registration (snippet only, no imports)
|
||||
b.WriteString(" mapper_registry.map_imperatively(\n")
|
||||
b.WriteString(fmt.Sprintf(" class_=%s,\n", ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" local_table=%s,\n", ctx.TableConstant))
|
||||
b.WriteString(" properties={\n")
|
||||
|
||||
// Generate relationships for all foreign keys
|
||||
fkRelationships := make(map[string][]string) // entity -> []column_names
|
||||
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
relationName := GetRelationshipName(fk.ColumnName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
|
||||
// Group by entity name to handle multiple FKs to same table
|
||||
fkRelationships[entityName] = append(fkRelationships[entityName], fk.ColumnName)
|
||||
|
||||
if len(fkRelationships[entityName]) == 1 {
|
||||
// First FK to this table
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||
b.WriteString(fmt.Sprintf(" %s, lazy=relationship_loading_strategy.value\n", entityName))
|
||||
} else {
|
||||
// Multiple FKs to same table, need to specify foreign_keys
|
||||
b.WriteString(fmt.Sprintf(" \"%s\": relationship(\n", relationName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", entityName))
|
||||
b.WriteString(" lazy=relationship_loading_strategy.value,\n")
|
||||
b.WriteString(fmt.Sprintf(" foreign_keys=%s.columns.%s,\n",
|
||||
ctx.TableConstant, fk.ColumnName))
|
||||
}
|
||||
|
||||
b.WriteString(" ),\n")
|
||||
}
|
||||
|
||||
b.WriteString(" },\n")
|
||||
b.WriteString(" )\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
124
internal/generator/model.go
Normal file
124
internal/generator/model.go
Normal file
@ -0,0 +1,124 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateModel generates the dataclass model
|
||||
func GenerateModel(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from dataclasses import dataclass\n")
|
||||
|
||||
// Check what we need to import
|
||||
needsDatetime := NeedsDatetimeImport(ctx.TableInfo.Columns)
|
||||
needsDecimal := NeedsDecimalImport(ctx.TableInfo.Columns)
|
||||
|
||||
if needsDatetime {
|
||||
b.WriteString("from datetime import datetime\n")
|
||||
}
|
||||
if needsDecimal {
|
||||
b.WriteString("from decimal import Decimal\n")
|
||||
}
|
||||
|
||||
b.WriteString("\n")
|
||||
b.WriteString("from televend_core.databases.base_model import Base\n")
|
||||
|
||||
// Import related models for foreign keys
|
||||
fkImports := make(map[string]string) // module_name -> entity_name
|
||||
for _, fk := range ctx.TableInfo.ForeignKeys {
|
||||
moduleName := GetRelationshipModuleName(fk.ForeignTableName)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
fkImports[moduleName] = entityName
|
||||
}
|
||||
|
||||
// Import enum types
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||
ctx.ModuleName))
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||
}
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
|
||||
// Write foreign key imports
|
||||
for moduleName, entityName := range fkImports {
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
moduleName, entityName))
|
||||
}
|
||||
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString("@dataclass\n")
|
||||
b.WriteString(fmt.Sprintf("class %s(Base):\n", ctx.EntityName))
|
||||
|
||||
// Get required and optional columns
|
||||
requiredCols := GetRequiredColumns(ctx.TableInfo.Columns)
|
||||
optionalCols := GetOptionalColumns(ctx.TableInfo.Columns)
|
||||
|
||||
// Required fields (non-nullable, no default, not auto-increment, not PK)
|
||||
for _, col := range requiredCols {
|
||||
fieldName := col.Name
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
|
||||
// Regular field
|
||||
b.WriteString(fmt.Sprintf(" %s: %s\n", fieldName, pythonType))
|
||||
|
||||
// Add relationship field if this is a foreign key and column ends with _id
|
||||
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
if strings.HasSuffix(col.Name, "_id") {
|
||||
relationName := GetRelationshipName(col.Name)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s\n", relationName, entityName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Empty line between required and optional
|
||||
if len(requiredCols) > 0 && len(optionalCols) > 0 {
|
||||
b.WriteString("\n")
|
||||
}
|
||||
|
||||
// Optional fields
|
||||
for _, col := range optionalCols {
|
||||
// Skip primary key, we'll add it at the end
|
||||
if col.IsPrimaryKey {
|
||||
continue
|
||||
}
|
||||
|
||||
fieldName := col.Name
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", fieldName, pythonType))
|
||||
|
||||
// Add relationship field if this is a foreign key and column ends with _id
|
||||
// (to avoid name clashes with FK columns that don't follow _id convention)
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
if strings.HasSuffix(col.Name, "_id") {
|
||||
relationName := GetRelationshipName(col.Name)
|
||||
entityName := GetRelationshipEntityName(fk.ForeignTableName)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", relationName, entityName))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Add primary key at the end
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
b.WriteString("\n")
|
||||
pythonType := GetPythonTypeForColumn(col, ctx)
|
||||
b.WriteString(fmt.Sprintf(" %s: %s | None = None\n", col.Name, pythonType))
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
28
internal/generator/repository.go
Normal file
28
internal/generator/repository.go
Normal file
@ -0,0 +1,28 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateRepository generates the repository class
|
||||
func GenerateRepository(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from televend_core.databases.base_repository import CRUDRepository\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.filter import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %sFilter,\n", ctx.EntityName))
|
||||
b.WriteString(")\n")
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.model import %s\n",
|
||||
ctx.ModuleName, ctx.EntityName))
|
||||
b.WriteString("\n\n")
|
||||
|
||||
// Class definition
|
||||
b.WriteString(fmt.Sprintf("class %sRepository(CRUDRepository[%s, %sFilter]):\n",
|
||||
ctx.EntityName, ctx.EntityName, ctx.EntityName))
|
||||
b.WriteString(fmt.Sprintf(" model_cls = %s\n", ctx.EntityName))
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
160
internal/generator/table.go
Normal file
160
internal/generator/table.go
Normal file
@ -0,0 +1,160 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
"github.com/entity-maker/entity-maker/internal/naming"
|
||||
)
|
||||
|
||||
// GenerateTable generates the SQLAlchemy table definition
|
||||
func GenerateTable(ctx *Context) (string, error) {
|
||||
var b strings.Builder
|
||||
|
||||
// Imports
|
||||
b.WriteString("from sqlalchemy import (\n")
|
||||
|
||||
// Collect unique imports
|
||||
imports := make(map[string]bool)
|
||||
imports["Column"] = true
|
||||
imports["Table"] = true
|
||||
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
sqlType := database.GetSQLAlchemyType(col)
|
||||
|
||||
// Extract base type name (before parentheses)
|
||||
typeName := strings.Split(sqlType, "(")[0]
|
||||
imports[typeName] = true
|
||||
|
||||
if col.IsPrimaryKey {
|
||||
imports["Integer"] = true
|
||||
}
|
||||
|
||||
// Check for foreign keys
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
imports["ForeignKey"] = true
|
||||
}
|
||||
|
||||
// Check for enums
|
||||
if col.DataType == "USER-DEFINED" {
|
||||
imports["Enum"] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Sort and write imports
|
||||
importList := []string{}
|
||||
for imp := range imports {
|
||||
importList = append(importList, imp)
|
||||
}
|
||||
|
||||
// Write imports in a reasonable order
|
||||
orderedImports := []string{
|
||||
"BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum",
|
||||
"Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary",
|
||||
"Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID",
|
||||
}
|
||||
|
||||
first := true
|
||||
for _, imp := range orderedImports {
|
||||
if imports[imp] {
|
||||
if !first {
|
||||
b.WriteString(",\n")
|
||||
} else {
|
||||
first = false
|
||||
}
|
||||
b.WriteString(" " + imp)
|
||||
}
|
||||
}
|
||||
b.WriteString(",\n)\n\n")
|
||||
|
||||
// Import enum types if they exist
|
||||
if len(ctx.TableInfo.EnumTypes) > 0 {
|
||||
for _, enumType := range ctx.TableInfo.EnumTypes {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
b.WriteString(fmt.Sprintf("from televend_core.databases.televend_repositories.%s.enum import (\n",
|
||||
ctx.ModuleName))
|
||||
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
||||
b.WriteString(")\n")
|
||||
}
|
||||
}
|
||||
|
||||
b.WriteString("from televend_core.databases.televend_repositories.table_meta import metadata_obj\n\n")
|
||||
|
||||
// Table definition
|
||||
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
||||
b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName))
|
||||
b.WriteString(" metadata_obj,\n")
|
||||
|
||||
// Columns
|
||||
for _, col := range ctx.TableInfo.Columns {
|
||||
b.WriteString(generateColumnDefinition(col, ctx))
|
||||
}
|
||||
|
||||
b.WriteString(")\n")
|
||||
|
||||
return b.String(), nil
|
||||
}
|
||||
|
||||
func generateColumnDefinition(col database.Column, ctx *Context) string {
|
||||
var parts []string
|
||||
|
||||
// Column name
|
||||
parts = append(parts, fmt.Sprintf("\"%s\"", col.Name))
|
||||
|
||||
// Column type
|
||||
sqlType := database.GetSQLAlchemyType(col)
|
||||
|
||||
// Handle enum types
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
||||
enumName := naming.ToPascalCase(enumType.TypeName)
|
||||
sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )",
|
||||
enumName, enumType.TypeName)
|
||||
}
|
||||
}
|
||||
|
||||
parts = append(parts, sqlType)
|
||||
|
||||
// Foreign key
|
||||
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
||||
fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")",
|
||||
fk.ForeignTableName, fk.ForeignColumnName)
|
||||
parts = append(parts, fkDef)
|
||||
}
|
||||
|
||||
// Primary key
|
||||
if col.IsPrimaryKey {
|
||||
parts = append(parts, "primary_key=True")
|
||||
if col.IsAutoIncrement {
|
||||
parts = append(parts, "autoincrement=True")
|
||||
}
|
||||
}
|
||||
|
||||
// Nullable
|
||||
if !col.IsNullable && !col.IsPrimaryKey {
|
||||
parts = append(parts, "nullable=False")
|
||||
}
|
||||
|
||||
// Unique
|
||||
// Note: We don't have unique constraint info in our introspection yet
|
||||
// This would need to be added if needed
|
||||
|
||||
// Format the column definition
|
||||
result := " Column("
|
||||
|
||||
// Check if we need multiline formatting (for complex types like Enum)
|
||||
if strings.Contains(sqlType, "\n") {
|
||||
// Multiline format
|
||||
result += parts[0] + ",\n " + parts[1]
|
||||
for _, part := range parts[2:] {
|
||||
result += ",\n " + part
|
||||
}
|
||||
result += ",\n ),\n"
|
||||
} else {
|
||||
// Single line format
|
||||
result += strings.Join(parts, ", ") + "),\n"
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user