This commit is contained in:
Eden Kirin
2025-10-31 19:04:40 +01:00
parent ca10b01fb0
commit 4e4827d640
12 changed files with 2612 additions and 10 deletions

View File

@ -0,0 +1,979 @@
package generator
import (
"database/sql"
"strings"
"testing"
"github.com/entity-maker/entity-maker/internal/database"
)
func TestGetRelationshipName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"with _id suffix", "author_info_id", "author_info"},
{"with user_id", "user_id", "user"},
{"without _id", "status", "status"},
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
{"empty", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipEntityName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "User"},
{"compound plural", "user_accounts", "UserAccount"},
{"ies ending", "companies", "Company"},
{"already singular", "user", "User"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipEntityName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetRelationshipModuleName(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{"plural table", "users", "user"},
{"compound plural", "user_accounts", "user_account"},
{"ies ending", "companies", "company"},
{"already singular", "user", "user"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetRelationshipModuleName(tt.input)
if result != tt.expected {
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
}
})
}
}
func TestGetFilterFieldName(t *testing.T) {
tests := []struct {
name string
columnName string
useIN bool
expected string
}{
{"id with IN", "id", true, "ids"},
{"id without IN", "id", false, "id"},
{"user_id with IN", "user_id", true, "user_ids"},
{"user_id without IN", "user_id", false, "user_id"},
{"machine_id with IN", "machine_id", true, "machine_ids"},
{"status without IN", "status", false, "status"},
{"status with IN", "status", true, "status"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetFilterFieldName(tt.columnName, tt.useIN)
if result != tt.expected {
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
tt.columnName, tt.useIN, result, tt.expected)
}
})
}
}
func TestShouldGenerateFilter(t *testing.T) {
tests := []struct {
name string
col database.Column
expected bool
}{
{
name: "boolean field",
col: database.Column{
Name: "alive",
DataType: "boolean",
},
expected: true,
},
{
name: "id field",
col: database.Column{
Name: "id",
DataType: "integer",
},
expected: true,
},
{
name: "foreign key field",
col: database.Column{
Name: "user_id",
DataType: "integer",
},
expected: true,
},
{
name: "text field",
col: database.Column{
Name: "name",
DataType: "text",
},
expected: true,
},
{
name: "varchar field",
col: database.Column{
Name: "email",
DataType: "character varying",
},
expected: true,
},
{
name: "numeric field",
col: database.Column{
Name: "amount",
DataType: "numeric",
},
expected: false,
},
{
name: "timestamp field",
col: database.Column{
Name: "created_at",
DataType: "timestamp with time zone",
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ShouldGenerateFilter(tt.col)
if result != tt.expected {
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
}
})
}
}
func TestNeedsDecimalImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has numeric column",
columns: []database.Column{
{Name: "amount", DataType: "numeric"},
},
expected: true,
},
{
name: "has decimal column",
columns: []database.Column{
{Name: "price", DataType: "decimal"},
},
expected: true,
},
{
name: "no numeric columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
{
name: "empty columns",
columns: []database.Column{},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDecimalImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestNeedsDatetimeImport(t *testing.T) {
tests := []struct {
name string
columns []database.Column
expected bool
}{
{
name: "has timestamp column",
columns: []database.Column{
{Name: "created_at", DataType: "timestamp with time zone"},
},
expected: true,
},
{
name: "has date column",
columns: []database.Column{
{Name: "birth_date", DataType: "date"},
},
expected: true,
},
{
name: "no datetime columns",
columns: []database.Column{
{Name: "id", DataType: "integer"},
{Name: "name", DataType: "varchar"},
},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := NeedsDatetimeImport(tt.columns)
if result != tt.expected {
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
}
})
}
}
func TestGetRequiredColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
}
result := GetRequiredColumns(columns)
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
if len(result) != 1 {
t.Errorf("Expected 1 required column, got %d", len(result))
}
if len(result) > 0 && result[0].Name != "name" {
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
}
}
func TestGetOptionalColumns(t *testing.T) {
columns := []database.Column{
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
}
result := GetOptionalColumns(columns)
// Should include 'email' (nullable) and 'created_at' (has default)
if len(result) != 2 {
t.Errorf("Expected 2 optional columns, got %d", len(result))
}
}
func TestGetForeignKeyForColumn(t *testing.T) {
fks := []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users"},
{ColumnName: "company_id", ForeignTableName: "companies"},
}
tests := []struct {
name string
columnName string
expectNil bool
expectFK string
}{
{"found user_id", "user_id", false, "users"},
{"found company_id", "company_id", false, "companies"},
{"not found", "status_id", true, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetForeignKeyForColumn(tt.columnName, fks)
if tt.expectNil {
if result != nil {
t.Errorf("Expected nil, got %+v", result)
}
} else {
if result == nil {
t.Errorf("Expected FK, got nil")
} else if result.ForeignTableName != tt.expectFK {
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
}
}
})
}
}
func TestGenerateInit(t *testing.T) {
ctx := &Context{
EntityName: "User",
ModuleName: "user",
}
result, err := GenerateInit(ctx)
if err != nil {
t.Fatalf("GenerateInit failed: %v", err)
}
// __init__.py should be empty
if result != "" {
t.Errorf("Expected empty __init__.py, got %q", result)
}
}
func TestGenerateRepository(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateRepository(ctx)
if err != nil {
t.Fatalf("GenerateRepository failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserRepository",
"CRUDRepository",
"model_cls = User",
"from televend_core.databases.base_repository import CRUDRepository",
"from televend_core.databases.televend_repositories.user.filter import",
"from televend_core.databases.televend_repositories.user.model import User",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestGenerateManager(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{},
}
ctx := NewContext(tableInfo, "")
result, err := GenerateManager(ctx)
if err != nil {
t.Fatalf("GenerateManager failed: %v", err)
}
// Check for expected content
expectedStrings := []string{
"class UserManager",
"CRUDManager",
"repository_cls = UserRepository",
}
for _, expected := range expectedStrings {
if !strings.Contains(result, expected) {
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
}
func TestNewContext(t *testing.T) {
tableInfo := &database.TableInfo{
Schema: "public",
TableName: "user_accounts",
}
// Test without override
ctx := NewContext(tableInfo, "")
if ctx.EntityName != "UserAccount" {
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
}
if ctx.ModuleName != "user_account" {
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
}
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
}
// Test with override
ctx = NewContext(tableInfo, "CustomUser")
if ctx.EntityName != "CustomUser" {
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
}
}
func TestGenerateTable(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple table with basic types",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from sqlalchemy import",
"Column",
"Table",
"Integer",
"String",
"USER_TABLE = Table(",
`"users"`,
"metadata_obj",
`Column("id", Integer, primary_key=True, autoincrement=True)`,
`Column("name", String(255), nullable=False)`,
`Column("email", String)`,
},
},
{
name: "table with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"ForeignKey",
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
},
},
{
name: "table with enum",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{
"order_status": {
TypeName: "order_status",
Values: []string{"pending", "completed", "cancelled"},
},
},
},
expected: []string{
"Enum",
"from televend_core.databases.televend_repositories.order.enum import",
"OrderStatus",
`Enum(`,
`*OrderStatus.to_value_list()`,
`name="order_status"`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateTable(ctx)
if err != nil {
t.Fatalf("GenerateTable failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateModel(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "simple model",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", IsNullable: false},
{Name: "email", DataType: "varchar", IsNullable: true},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from dataclasses import dataclass",
"from televend_core.databases.base_model import Base",
"@dataclass",
"class User(Base):",
"name: str",
"email: str | None = None",
"id: int | None = None",
},
},
{
name: "model with foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text", IsNullable: false},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.televend_repositories.user.model import User",
"user_id: int",
"user: User",
"title: str",
},
},
{
name: "model with datetime and decimal",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", IsNullable: false},
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from datetime import datetime",
"from decimal import Decimal",
"amount: Decimal",
"created_at: datetime | None = None",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateModel(ctx)
if err != nil {
t.Fatalf("GenerateModel failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFilter(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
notExpect []string
}{
{
name: "filter with boolean and id fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "name", DataType: "varchar"},
{Name: "alive", DataType: "boolean"},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"from televend_core.databases.base_filter import BaseFilter",
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
"class UserFilter(BaseFilter):",
"model_cls = User",
"id: int | None = filterfield(operator=EQ)",
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
"name: str | None = filterfield(operator=EQ)",
"alive: bool | None = filterfield(operator=EQ, default=True)",
"user_id: int | None = filterfield(operator=EQ)",
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
},
notExpect: []string{
"default=None",
},
},
{
name: "filter with no filterable fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "logs",
Columns: []database.Column{
{Name: "timestamp", DataType: "timestamp with time zone"},
{Name: "amount", DataType: "numeric"},
},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class LogFilter(BaseFilter):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFilter(ctx)
if err != nil {
t.Fatalf("GenerateFilter failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
for _, notExpected := range tt.notExpect {
if strings.Contains(result, notExpected) {
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
notExpected, result)
}
}
})
}
}
func TestGenerateLoadOptions(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "load options with relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
{Name: "category_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
},
},
expected: []string{
"from televend_core.databases.base_load_options import LoadOptions",
"from televend_core.databases.common.load_options import joinload",
"class PostLoadOptions(LoadOptions):",
"model_cls = Post",
`load_user: bool = joinload(relations=["user"])`,
`load_category: bool = joinload(relations=["category"])`,
},
},
{
name: "load options with no relationships",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "settings",
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
ForeignKeys: []database.ForeignKey{},
},
expected: []string{
"class SettingLoadOptions(LoadOptions):",
"pass",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateLoadOptions(ctx)
if err != nil {
t.Fatalf("GenerateLoadOptions failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateFactory(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "factory with basic fields",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "users",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
{Name: "alive", DataType: "boolean"},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from __future__ import annotations",
"from typing import Type",
"import factory",
"class UserFactory(TelevendBaseFactory):",
"alive = True",
"id = None",
`name = factory.Faker("pystr", max_chars=100)`,
"class Meta:",
"model = User",
"def create_minimal",
},
},
{
name: "factory with foreign keys",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "user_id", DataType: "integer", IsNullable: false},
{Name: "title", DataType: "text"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
"from televend_core.databases.televend_repositories.user.factory import",
"UserFactory",
`user = CustomSelfAttribute("..user", UserFactory)`,
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
},
},
{
name: "factory with decimal field",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "orders",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
},
ForeignKeys: []database.ForeignKey{},
EnumTypes: map[string]database.EnumType{},
},
expected: []string{
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateFactory(ctx)
if err != nil {
t.Fatalf("GenerateFactory failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGenerateMapper(t *testing.T) {
tests := []struct {
name string
tableInfo *database.TableInfo
expected []string
}{
{
name: "mapper with single foreign key",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "posts",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "user_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
"mapper_registry.map_imperatively(",
"class_=Post,",
"local_table=POST_TABLE,",
"properties={",
`"user": relationship(`,
"User, lazy=relationship_loading_strategy.value",
},
},
{
name: "mapper with multiple foreign keys to same table",
tableInfo: &database.TableInfo{
Schema: "public",
TableName: "messages",
Columns: []database.Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "sender_id", DataType: "integer"},
{Name: "receiver_id", DataType: "integer"},
},
ForeignKeys: []database.ForeignKey{
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
},
},
expected: []string{
`"sender": relationship(`,
"User, lazy=relationship_loading_strategy.value",
`"receiver": relationship(`,
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := NewContext(tt.tableInfo, "")
result, err := GenerateMapper(ctx)
if err != nil {
t.Fatalf("GenerateMapper failed: %v", err)
}
for _, expected := range tt.expected {
if !strings.Contains(result, expected) {
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
expected, result)
}
}
})
}
}
func TestGetPythonTypeForColumn(t *testing.T) {
ctx := &Context{
TableInfo: &database.TableInfo{
EnumTypes: map[string]database.EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"active", "inactive"},
},
},
},
}
tests := []struct {
name string
col database.Column
expected string
}{
{
name: "integer type",
col: database.Column{DataType: "integer"},
expected: "int",
},
{
name: "varchar type",
col: database.Column{DataType: "varchar"},
expected: "str",
},
{
name: "enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
expected: "StatusEnum",
},
{
name: "unknown enum type",
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
expected: "str",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPythonTypeForColumn(tt.col, ctx)
if result != tt.expected {
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}