161 lines
4.1 KiB
Go
161 lines
4.1 KiB
Go
package generator
|
|
|
|
import (
	"fmt"
	"sort"
	"strings"

	"github.com/entity-maker/entity-maker/internal/database"
	"github.com/entity-maker/entity-maker/internal/naming"
)
|
|
|
|
// GenerateTable generates the SQLAlchemy table definition
|
|
func GenerateTable(ctx *Context) (string, error) {
|
|
var b strings.Builder
|
|
|
|
// Imports
|
|
b.WriteString("from sqlalchemy import (\n")
|
|
|
|
// Collect unique imports
|
|
imports := make(map[string]bool)
|
|
imports["Column"] = true
|
|
imports["Table"] = true
|
|
|
|
for _, col := range ctx.TableInfo.Columns {
|
|
sqlType := database.GetSQLAlchemyType(col)
|
|
|
|
// Extract base type name (before parentheses)
|
|
typeName := strings.Split(sqlType, "(")[0]
|
|
imports[typeName] = true
|
|
|
|
if col.IsPrimaryKey {
|
|
imports["Integer"] = true
|
|
}
|
|
|
|
// Check for foreign keys
|
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
|
imports["ForeignKey"] = true
|
|
}
|
|
|
|
// Check for enums
|
|
if col.DataType == "USER-DEFINED" {
|
|
imports["Enum"] = true
|
|
}
|
|
}
|
|
|
|
// Sort and write imports
|
|
importList := []string{}
|
|
for imp := range imports {
|
|
importList = append(importList, imp)
|
|
}
|
|
|
|
// Write imports in a reasonable order
|
|
orderedImports := []string{
|
|
"BigInteger", "Boolean", "Column", "Date", "DateTime", "Enum",
|
|
"Float", "ForeignKey", "Integer", "JSON", "JSONB", "LargeBinary",
|
|
"Numeric", "SmallInteger", "String", "Table", "Text", "Time", "UUID",
|
|
}
|
|
|
|
first := true
|
|
for _, imp := range orderedImports {
|
|
if imports[imp] {
|
|
if !first {
|
|
b.WriteString(",\n")
|
|
} else {
|
|
first = false
|
|
}
|
|
b.WriteString(" " + imp)
|
|
}
|
|
}
|
|
b.WriteString(",\n)\n\n")
|
|
|
|
// Import enum types if they exist
|
|
if len(ctx.TableInfo.EnumTypes) > 0 {
|
|
for _, enumType := range ctx.TableInfo.EnumTypes {
|
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
|
b.WriteString(fmt.Sprintf("from televend_core.databases.cloud_repositories.%s.enum import (\n",
|
|
ctx.ModuleName))
|
|
b.WriteString(fmt.Sprintf(" %s,\n", enumName))
|
|
b.WriteString(")\n")
|
|
}
|
|
}
|
|
|
|
b.WriteString("from televend_core.databases.cloud_repositories.table_meta import metadata_obj\n\n")
|
|
|
|
// Table definition
|
|
b.WriteString(fmt.Sprintf("%s = Table(\n", ctx.TableConstant))
|
|
b.WriteString(fmt.Sprintf(" \"%s\",\n", ctx.TableInfo.TableName))
|
|
b.WriteString(" metadata_obj,\n")
|
|
|
|
// Columns
|
|
for _, col := range ctx.TableInfo.Columns {
|
|
b.WriteString(generateColumnDefinition(col, ctx))
|
|
}
|
|
|
|
b.WriteString(")\n")
|
|
|
|
return b.String(), nil
|
|
}
|
|
|
|
func generateColumnDefinition(col database.Column, ctx *Context) string {
|
|
var parts []string
|
|
|
|
// Column name
|
|
parts = append(parts, fmt.Sprintf("\"%s\"", col.Name))
|
|
|
|
// Column type
|
|
sqlType := database.GetSQLAlchemyType(col)
|
|
|
|
// Handle enum types
|
|
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
|
if enumType, exists := ctx.TableInfo.EnumTypes[col.UdtName]; exists {
|
|
enumName := naming.ToPascalCase(enumType.TypeName)
|
|
sqlType = fmt.Sprintf("Enum(\n *%s.to_value_list(),\n name=\"%s\",\n )",
|
|
enumName, enumType.TypeName)
|
|
}
|
|
}
|
|
|
|
parts = append(parts, sqlType)
|
|
|
|
// Foreign key
|
|
if fk := GetForeignKeyForColumn(col.Name, ctx.TableInfo.ForeignKeys); fk != nil {
|
|
fkDef := fmt.Sprintf("ForeignKey(\"%s.%s\", deferrable=True, initially=\"DEFERRED\")",
|
|
fk.ForeignTableName, fk.ForeignColumnName)
|
|
parts = append(parts, fkDef)
|
|
}
|
|
|
|
// Primary key
|
|
if col.IsPrimaryKey {
|
|
parts = append(parts, "primary_key=True")
|
|
if col.IsAutoIncrement {
|
|
parts = append(parts, "autoincrement=True")
|
|
}
|
|
}
|
|
|
|
// Nullable
|
|
if !col.IsNullable && !col.IsPrimaryKey {
|
|
parts = append(parts, "nullable=False")
|
|
}
|
|
|
|
// Unique
|
|
// Note: We don't have unique constraint info in our introspection yet
|
|
// This would need to be added if needed
|
|
|
|
// Format the column definition
|
|
result := " Column("
|
|
|
|
// Check if we need multiline formatting (for complex types like Enum)
|
|
if strings.Contains(sqlType, "\n") {
|
|
// Multiline format
|
|
result += parts[0] + ",\n " + parts[1]
|
|
for _, part := range parts[2:] {
|
|
result += ",\n " + part
|
|
}
|
|
result += ",\n ),\n"
|
|
} else {
|
|
// Single line format
|
|
result += strings.Join(parts, ", ") + "),\n"
|
|
}
|
|
|
|
return result
|
|
}
|