This commit is contained in:
Eden Kirin
2025-10-31 19:04:40 +01:00
parent ca10b01fb0
commit 4e4827d640
12 changed files with 2612 additions and 10 deletions

View File

@ -0,0 +1,199 @@
package generator
import (
"strings"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
)
func TestSanitizePythonIdentifier(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"starts with digit", "12_HOUR", "_12_HOUR"},
{"starts with letter", "HOUR_12", "HOUR_12"},
{"all digits", "24", "_24"},
{"underscore first", "_12_HOUR", "_12_HOUR"},
{"empty", "", ""},
{"single digit", "1", "_1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := sanitizePythonIdentifier(tt.input)
if result != tt.expected {
t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGenerateEnum(t *testing.T) {
tests := []struct {
name string
enumTypes map[string]database.EnumType
expectError bool
checkFunc func(t *testing.T, result string)
}{
{
name: "simple enum",
enumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED", "PENDING"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class StatusEnum") {
t.Error("Expected class StatusEnum")
}
if !strings.Contains(result, "OPEN = \"OPEN\"") {
t.Error("Expected OPEN = \"OPEN\"")
}
if !strings.Contains(result, "CLOSED = \"CLOSED\"") {
t.Error("Expected CLOSED = \"CLOSED\"")
}
if !strings.Contains(result, "PENDING = \"PENDING\"") {
t.Error("Expected PENDING = \"PENDING\"")
}
},
},
{
name: "enum with spaces and hyphens",
enumTypes: map[string]database.EnumType{
"time_format_enum": {
TypeName: "time_format_enum",
Values: []string{"12-hour", "24-hour"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class TimeFormatEnum") {
t.Error("Expected class TimeFormatEnum")
}
if !strings.Contains(result, "_12_HOUR = \"12-hour\"") {
t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)")
}
if !strings.Contains(result, "_24_HOUR = \"24-hour\"") {
t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)")
}
},
},
{
name: "enum with duplicates after normalization",
enumTypes: map[string]database.EnumType{
"measurement_enum": {
TypeName: "measurement_enum",
Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class MeasurementEnum") {
t.Error("Expected class MeasurementEnum")
}
// Should only have one INTERNATIONAL and one IMPERIAL
internationalCount := strings.Count(result, "INTERNATIONAL = ")
if internationalCount != 1 {
t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount)
}
imperialCount := strings.Count(result, "IMPERIAL = ")
if imperialCount != 1 {
t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount)
}
},
},
{
name: "empty enum",
enumTypes: map[string]database.EnumType{
"empty_enum": {
TypeName: "empty_enum",
Values: []string{},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class EmptyEnum") {
t.Error("Expected class EmptyEnum")
}
if !strings.Contains(result, "pass") {
t.Error("Expected 'pass' for empty enum")
}
},
},
{
name: "multiple enums",
enumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED"},
},
"priority_enum": {
TypeName: "priority_enum",
Values: []string{"HIGH", "LOW"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "class StatusEnum") {
t.Error("Expected class StatusEnum")
}
if !strings.Contains(result, "class PriorityEnum") {
t.Error("Expected class PriorityEnum")
}
},
},
{
name: "enum with special characters",
enumTypes: map[string]database.EnumType{
"special_enum": {
TypeName: "special_enum",
Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"},
},
},
expectError: false,
checkFunc: func(t *testing.T, result string) {
if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") {
t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"")
}
if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") {
t.Error("Expected ON_HOLD = \"ON HOLD\"")
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := &Context{
TableInfo: &database.TableInfo{
EnumTypes: tt.enumTypes,
},
EntityName: "TestEntity",
ModuleName: "test_entity",
}
result, err := GenerateEnum(ctx)
if (err != nil) != tt.expectError {
t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError)
return
}
if tt.checkFunc != nil {
tt.checkFunc(t, result)
}
// Check common requirements
if !strings.Contains(result, "from enum import StrEnum") {
t.Error("Expected import of StrEnum")
}
if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") {
t.Error("Expected import of EnumMixin")
}
})
}
}

View File

@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string {
}
// GetFilterFieldName returns the filter field name for a column
// For ID fields with IN operator, pluralizes the name
// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids)
func GetFilterFieldName(columnName string, useIN bool) string {
if useIN && strings.HasSuffix(columnName, "_id") {
// Remove _id, pluralize, add back _ids
base := columnName[:len(columnName)-3]
if base == "" {
return "ids"
}
return naming.Pluralize(base) + "_ids"
// Replace _id with _ids
return columnName[:len(columnName)-2] + "ids"
}
if useIN && columnName == "id" {
return "ids"

View File

@ -0,0 +1,979 @@
package generator
import (
"database/sql"
"strings"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
)
func TestGetRelationshipName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"with _id suffix", "author_info_id", "author_info"},
{"with user_id", "user_id", "user"},
{"without _id", "status", "status"},
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipEntityName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "User"},
{"compound plural", "user_accounts", "UserAccount"},
{"ies ending", "companies", "Company"},
{"already singular", "user", "User"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipEntityName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipModuleName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "user"},
{"compound plural", "user_accounts", "user_account"},
{"ies ending", "companies", "company"},
{"already singular", "user", "user"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipModuleName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetFilterFieldName(t *testing.T) {
tests := []struct {
name string
columnName string
useIN bool
expected string
}{
{"id with IN", "id", true, "ids"},
{"id without IN", "id", false, "id"},
{"user_id with IN", "user_id", true, "user_ids"},
{"user_id without IN", "user_id", false, "user_id"},
{"machine_id with IN", "machine_id", true, "machine_ids"},
{"status without IN", "status", false, "status"},
{"status with IN", "status", true, "status"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFilterFieldName(tt.columnName, tt.useIN)
if result != tt.expected {
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
tt.columnName, tt.useIN, result, tt.expected)
}
})
}
}
func TestShouldGenerateFilter(t *testing.T) {
tests := []struct {
name string
col database.Column
expected bool
}{
{
name: "boolean field",
col: database.Column{
Name: "alive",
DataType: "boolean",
},
expected: true,
},
{
name: "id field",
col: database.Column{
Name: "id",
DataType: "integer",
},
expected: true,
},
{
name: "foreign key field",
col: database.Column{
Name: "user_id",
DataType: "integer",
},
expected: true,
},
{
name: "text field",
col: database.Column{
Name: "name",
DataType: "text",
},
expected: true,
},
{
name: "varchar field",
col: database.Column{
Name: "email",
DataType: "character varying",
},
expected: true,
},
{
name: "numeric field",
col: database.Column{
Name: "amount",
DataType: "numeric",
},
expected: false,
},
{
name: "timestamp field",
col: database.Column{
Name: "created_at",
DataType: "timestamp with time zone",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ShouldGenerateFilter(tt.col)
if result != tt.expected {
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
}
})
}
}
func TestNeedsDecimalImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has numeric column",
columns: []database.Column{
{Name: "amount", DataType: "numeric"},
},
expected: true,
},
{
name: "has decimal column",
columns: []database.Column{
{Name: "price", DataType: "decimal"},
},
expected: true,
},
{
name: "no numeric columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
{
name: "empty columns",
columns: []database.Column{},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDecimalImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestNeedsDatetimeImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has timestamp column",
columns: []database.Column{
{Name: "created_at", DataType: "timestamp with time zone"},
},
expected: true,
},
{
name: "has date column",
columns: []database.Column{
{Name: "birth_date", DataType: "date"},
},
expected: true,
},
{
name: "no datetime columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDatetimeImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetRequiredColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
}
result := GetRequiredColumns(columns)
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
if len(result) != 1 {
t.Errorf("Expected 1 required column, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "name" {
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
}
}
func TestGetOptionalColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
}
result := GetOptionalColumns(columns)
// Should include 'email' (nullable) and 'created_at' (has default)
if len(result) != 2 {
t.Errorf("Expected 2 optional columns, got %d", len(result))
}
}
func TestGetForeignKeyForColumn(t *testing.T) {
fks := []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users"},
{ColumnName: "company_id", ForeignTableName: "companies"},
}
tests := []struct {
name string
columnName string
expectNil bool
expectFK string
}{
{"found user_id", "user_id", false, "users"},
{"found company_id", "company_id", false, "companies"},
{"not found", "status_id", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetForeignKeyForColumn(tt.columnName, fks)
if tt.expectNil {
if result != nil {
t.Errorf("Expected nil, got %+v", result)
}
} else {
if result == nil {
t.Errorf("Expected FK, got nil")
} else if result.ForeignTableName != tt.expectFK {
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
}
}
})
}
}
func TestGenerateInit(t *testing.T) {
ctx := &Context{
EntityName: "User",
ModuleName: "user",
}
result, err := GenerateInit(ctx)
if err != nil {
t.Fatalf("GenerateInit failed: %v", err)
}
// __init__.py should be empty
if result != "" {
t.Errorf("Expected empty __init__.py, got %q", result)
}
}
func TestGenerateRepository(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateRepository(ctx)
if err != nil {
t.Fatalf("GenerateRepository failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserRepository",
"CRUDRepository",
"model_cls = User",
"from televend_core.databases.base_repository import CRUDRepository",
"from televend_core.databases.televend_repositories.user.filter import",
"from televend_core.databases.televend_repositories.user.model import User",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestGenerateManager(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateManager(ctx)
if err != nil {
t.Fatalf("GenerateManager failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserManager",
"CRUDManager",
"repository_cls = UserRepository",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestNewContext(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "user_accounts",
}
// Test without override
ctx := NewContext(tableInfo, "")
if ctx.EntityName != "UserAccount" {
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
}
if ctx.ModuleName != "user_account" {
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
}
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
}
// Test with override
ctx = NewContext(tableInfo, "CustomUser")
if ctx.EntityName != "CustomUser" {
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
}
}
func TestGenerateTable(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple table with basic types",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from sqlalchemy import",
"Column",
"Table",
"Integer",
"String",
"USER_TABLE = Table(",
`"users"`,
"metadata_obj",
`Column("id", Integer, primary_key=True, autoincrement=True)`,
`Column("name", String(255), nullable=False)`,
`Column("email", String)`,
},
},
{
name: "table with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"ForeignKey",
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
},
},
{
name: "table with enum",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{
"order_status": {
TypeName: "order_status",
Values: []string{"pending", "completed", "cancelled"},
},
},
},
expected: []string{
"Enum",
"from televend_core.databases.televend_repositories.order.enum import",
"OrderStatus",
`Enum(`,
`*OrderStatus.to_value_list()`,
`name="order_status"`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateTable(ctx)
if err != nil {
t.Fatalf("GenerateTable failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateModel(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple model",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from dataclasses import dataclass",
"from televend_core.databases.base_model import Base",
"@dataclass",
"class User(Base):",
"name: str",
"email: str | None = None",
"id: int | None = None",
},
},
{
name: "model with foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.televend_repositories.user.model import User",
"user_id: int",
"user: User",
"title: str",
},
},
{
name: "model with datetime and decimal",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", IsNullable: false},
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from datetime import datetime",
"from decimal import Decimal",
"amount: Decimal",
"created_at: datetime | None = None",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateModel(ctx)
if err != nil {
t.Fatalf("GenerateModel failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFilter(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
notExpect []string
}{
{
name: "filter with boolean and id fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "name", DataType: "varchar"},
{Name: "alive", DataType: "boolean"},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"from televend_core.databases.base_filter import BaseFilter",
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
"class UserFilter(BaseFilter):",
"model_cls = User",
"id: int | None = filterfield(operator=EQ)",
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
"name: str | None = filterfield(operator=EQ)",
"alive: bool | None = filterfield(operator=EQ, default=True)",
"user_id: int | None = filterfield(operator=EQ)",
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
},
notExpect: []string{
"default=None",
},
},
{
name: "filter with no filterable fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "logs",
Columns: []database.Column{
{Name: "timestamp", DataType: "timestamp with time zone"},
{Name: "amount", DataType: "numeric"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class LogFilter(BaseFilter):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFilter(ctx)
if err != nil {
t.Fatalf("GenerateFilter failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
for _, notExpected := range tt.notExpect {
if strings.Contains(result, notExpected) {
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
notExpected, result)
}
}
})
}
}
func TestGenerateLoadOptions(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "load options with relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
{Name: "category_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
},
},
expected: []string{
"from televend_core.databases.base_load_options import LoadOptions",
"from televend_core.databases.common.load_options import joinload",
"class PostLoadOptions(LoadOptions):",
"model_cls = Post",
`load_user: bool = joinload(relations=["user"])`,
`load_category: bool = joinload(relations=["category"])`,
},
},
{
name: "load options with no relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "settings",
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class SettingLoadOptions(LoadOptions):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateLoadOptions(ctx)
if err != nil {
t.Fatalf("GenerateLoadOptions failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFactory(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "factory with basic fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
{Name: "alive", DataType: "boolean"},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from __future__ import annotations",
"from typing import Type",
"import factory",
"class UserFactory(TelevendBaseFactory):",
"alive = True",
"id = None",
`name = factory.Faker("pystr", max_chars=100)`,
"class Meta:",
"model = User",
"def create_minimal",
},
},
{
name: "factory with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.televend_repositories.user.factory import",
"UserFactory",
`user = CustomSelfAttribute("..user", UserFactory)`,
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
},
},
{
name: "factory with decimal field",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFactory(ctx)
if err != nil {
t.Fatalf("GenerateFactory failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateMapper(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "mapper with single foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
"mapper_registry.map_imperatively(",
"class_=Post,",
"local_table=POST_TABLE,",
"properties={",
`"user": relationship(`,
"User, lazy=relationship_loading_strategy.value",
},
},
{
name: "mapper with multiple foreign keys to same table",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "messages",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "sender_id", DataType: "integer"},
{Name: "receiver_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
`"sender": relationship(`,
"User, lazy=relationship_loading_strategy.value",
`"receiver": relationship(`,
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateMapper(ctx)
if err != nil {
t.Fatalf("GenerateMapper failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGetPythonTypeForColumn(t *testing.T) {
ctx := &Context{
TableInfo: &database.TableInfo{
EnumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"active", "inactive"},
},
},
},
}
tests := []struct {
name string
col database.Column
expected string
}{
{
name: "integer type",
col: database.Column{DataType: "integer"},
expected: "int",
},
{
name: "varchar type",
col: database.Column{DataType: "varchar"},
expected: "str",
},
{
name: "enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
expected: "StatusEnum",
},
{
name: "unknown enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
expected: "str",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPythonTypeForColumn(tt.col, ctx)
if result != tt.expected {
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}