Tests
This commit is contained in:
199
internal/generator/enum_test.go
Normal file
199
internal/generator/enum_test.go
Normal file
@ -0,0 +1,199 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
func TestSanitizePythonIdentifier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"starts with digit", "12_HOUR", "_12_HOUR"},
|
||||
{"starts with letter", "HOUR_12", "HOUR_12"},
|
||||
{"all digits", "24", "_24"},
|
||||
{"underscore first", "_12_HOUR", "_12_HOUR"},
|
||||
{"empty", "", ""},
|
||||
{"single digit", "1", "_1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := sanitizePythonIdentifier(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("sanitizePythonIdentifier(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateEnum(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
enumTypes map[string]database.EnumType
|
||||
expectError bool
|
||||
checkFunc func(t *testing.T, result string)
|
||||
}{
|
||||
{
|
||||
name: "simple enum",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class StatusEnum") {
|
||||
t.Error("Expected class StatusEnum")
|
||||
}
|
||||
if !strings.Contains(result, "OPEN = \"OPEN\"") {
|
||||
t.Error("Expected OPEN = \"OPEN\"")
|
||||
}
|
||||
if !strings.Contains(result, "CLOSED = \"CLOSED\"") {
|
||||
t.Error("Expected CLOSED = \"CLOSED\"")
|
||||
}
|
||||
if !strings.Contains(result, "PENDING = \"PENDING\"") {
|
||||
t.Error("Expected PENDING = \"PENDING\"")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with spaces and hyphens",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"time_format_enum": {
|
||||
TypeName: "time_format_enum",
|
||||
Values: []string{"12-hour", "24-hour"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class TimeFormatEnum") {
|
||||
t.Error("Expected class TimeFormatEnum")
|
||||
}
|
||||
if !strings.Contains(result, "_12_HOUR = \"12-hour\"") {
|
||||
t.Error("Expected _12_HOUR = \"12-hour\" (sanitized)")
|
||||
}
|
||||
if !strings.Contains(result, "_24_HOUR = \"24-hour\"") {
|
||||
t.Error("Expected _24_HOUR = \"24-hour\" (sanitized)")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with duplicates after normalization",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"measurement_enum": {
|
||||
TypeName: "measurement_enum",
|
||||
Values: []string{"international", "INTERNATIONAL", "imperial", "IMPERIAL"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class MeasurementEnum") {
|
||||
t.Error("Expected class MeasurementEnum")
|
||||
}
|
||||
// Should only have one INTERNATIONAL and one IMPERIAL
|
||||
internationalCount := strings.Count(result, "INTERNATIONAL = ")
|
||||
if internationalCount != 1 {
|
||||
t.Errorf("Expected 1 INTERNATIONAL, got %d", internationalCount)
|
||||
}
|
||||
imperialCount := strings.Count(result, "IMPERIAL = ")
|
||||
if imperialCount != 1 {
|
||||
t.Errorf("Expected 1 IMPERIAL, got %d", imperialCount)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty enum",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"empty_enum": {
|
||||
TypeName: "empty_enum",
|
||||
Values: []string{},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class EmptyEnum") {
|
||||
t.Error("Expected class EmptyEnum")
|
||||
}
|
||||
if !strings.Contains(result, "pass") {
|
||||
t.Error("Expected 'pass' for empty enum")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple enums",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED"},
|
||||
},
|
||||
"priority_enum": {
|
||||
TypeName: "priority_enum",
|
||||
Values: []string{"HIGH", "LOW"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "class StatusEnum") {
|
||||
t.Error("Expected class StatusEnum")
|
||||
}
|
||||
if !strings.Contains(result, "class PriorityEnum") {
|
||||
t.Error("Expected class PriorityEnum")
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "enum with special characters",
|
||||
enumTypes: map[string]database.EnumType{
|
||||
"special_enum": {
|
||||
TypeName: "special_enum",
|
||||
Values: []string{"IN-PROGRESS", "ON HOLD", "DONE"},
|
||||
},
|
||||
},
|
||||
expectError: false,
|
||||
checkFunc: func(t *testing.T, result string) {
|
||||
if !strings.Contains(result, "IN_PROGRESS = \"IN-PROGRESS\"") {
|
||||
t.Error("Expected IN_PROGRESS = \"IN-PROGRESS\"")
|
||||
}
|
||||
if !strings.Contains(result, "ON_HOLD = \"ON HOLD\"") {
|
||||
t.Error("Expected ON_HOLD = \"ON HOLD\"")
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := &Context{
|
||||
TableInfo: &database.TableInfo{
|
||||
EnumTypes: tt.enumTypes,
|
||||
},
|
||||
EntityName: "TestEntity",
|
||||
ModuleName: "test_entity",
|
||||
}
|
||||
|
||||
result, err := GenerateEnum(ctx)
|
||||
if (err != nil) != tt.expectError {
|
||||
t.Errorf("GenerateEnum() error = %v, expectError %v", err, tt.expectError)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.checkFunc != nil {
|
||||
tt.checkFunc(t, result)
|
||||
}
|
||||
|
||||
// Check common requirements
|
||||
if !strings.Contains(result, "from enum import StrEnum") {
|
||||
t.Error("Expected import of StrEnum")
|
||||
}
|
||||
if !strings.Contains(result, "from televend_core.databases.enum import EnumMixin") {
|
||||
t.Error("Expected import of EnumMixin")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -57,15 +57,11 @@ func GetRelationshipModuleName(tableName string) string {
|
||||
}
|
||||
|
||||
// GetFilterFieldName returns the filter field name for a column
|
||||
// For ID fields with IN operator, pluralizes the name
|
||||
// For ID fields with IN operator, changes _id to _ids (e.g., machine_id -> machine_ids)
|
||||
func GetFilterFieldName(columnName string, useIN bool) string {
|
||||
if useIN && strings.HasSuffix(columnName, "_id") {
|
||||
// Remove _id, pluralize, add back _ids
|
||||
base := columnName[:len(columnName)-3]
|
||||
if base == "" {
|
||||
return "ids"
|
||||
}
|
||||
return naming.Pluralize(base) + "_ids"
|
||||
// Replace _id with _ids
|
||||
return columnName[:len(columnName)-2] + "ids"
|
||||
}
|
||||
if useIN && columnName == "id" {
|
||||
return "ids"
|
||||
|
||||
979
internal/generator/generator_test.go
Normal file
979
internal/generator/generator_test.go
Normal file
@ -0,0 +1,979 @@
|
||||
package generator
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/entity-maker/entity-maker/internal/database"
|
||||
)
|
||||
|
||||
func TestGetRelationshipName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"with _id suffix", "author_info_id", "author_info"},
|
||||
{"with user_id", "user_id", "user"},
|
||||
{"without _id", "status", "status"},
|
||||
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
|
||||
{"empty", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRelationshipEntityName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"plural table", "users", "User"},
|
||||
{"compound plural", "user_accounts", "UserAccount"},
|
||||
{"ies ending", "companies", "Company"},
|
||||
{"already singular", "user", "User"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipEntityName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRelationshipModuleName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{"plural table", "users", "user"},
|
||||
{"compound plural", "user_accounts", "user_account"},
|
||||
{"ies ending", "companies", "company"},
|
||||
{"already singular", "user", "user"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetRelationshipModuleName(tt.input)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFilterFieldName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
useIN bool
|
||||
expected string
|
||||
}{
|
||||
{"id with IN", "id", true, "ids"},
|
||||
{"id without IN", "id", false, "id"},
|
||||
{"user_id with IN", "user_id", true, "user_ids"},
|
||||
{"user_id without IN", "user_id", false, "user_id"},
|
||||
{"machine_id with IN", "machine_id", true, "machine_ids"},
|
||||
{"status without IN", "status", false, "status"},
|
||||
{"status with IN", "status", true, "status"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetFilterFieldName(tt.columnName, tt.useIN)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
|
||||
tt.columnName, tt.useIN, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldGenerateFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "boolean field",
|
||||
col: database.Column{
|
||||
Name: "alive",
|
||||
DataType: "boolean",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "id field",
|
||||
col: database.Column{
|
||||
Name: "id",
|
||||
DataType: "integer",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "foreign key field",
|
||||
col: database.Column{
|
||||
Name: "user_id",
|
||||
DataType: "integer",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "text field",
|
||||
col: database.Column{
|
||||
Name: "name",
|
||||
DataType: "text",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "varchar field",
|
||||
col: database.Column{
|
||||
Name: "email",
|
||||
DataType: "character varying",
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "numeric field",
|
||||
col: database.Column{
|
||||
Name: "amount",
|
||||
DataType: "numeric",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "timestamp field",
|
||||
col: database.Column{
|
||||
Name: "created_at",
|
||||
DataType: "timestamp with time zone",
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ShouldGenerateFilter(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDecimalImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columns []database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has numeric column",
|
||||
columns: []database.Column{
|
||||
{Name: "amount", DataType: "numeric"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has decimal column",
|
||||
columns: []database.Column{
|
||||
{Name: "price", DataType: "decimal"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no numeric columns",
|
||||
columns: []database.Column{
|
||||
{Name: "id", DataType: "integer"},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty columns",
|
||||
columns: []database.Column{},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NeedsDecimalImport(tt.columns)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNeedsDatetimeImport(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
columns []database.Column
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "has timestamp column",
|
||||
columns: []database.Column{
|
||||
{Name: "created_at", DataType: "timestamp with time zone"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "has date column",
|
||||
columns: []database.Column{
|
||||
{Name: "birth_date", DataType: "date"},
|
||||
},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "no datetime columns",
|
||||
columns: []database.Column{
|
||||
{Name: "id", DataType: "integer"},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := NeedsDatetimeImport(tt.columns)
|
||||
if result != tt.expected {
|
||||
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRequiredColumns(t *testing.T) {
|
||||
columns := []database.Column{
|
||||
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
|
||||
}
|
||||
|
||||
result := GetRequiredColumns(columns)
|
||||
|
||||
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
|
||||
if len(result) != 1 {
|
||||
t.Errorf("Expected 1 required column, got %d", len(result))
|
||||
}
|
||||
|
||||
if len(result) > 0 && result[0].Name != "name" {
|
||||
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOptionalColumns(t *testing.T) {
|
||||
columns := []database.Column{
|
||||
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
}
|
||||
|
||||
result := GetOptionalColumns(columns)
|
||||
|
||||
// Should include 'email' (nullable) and 'created_at' (has default)
|
||||
if len(result) != 2 {
|
||||
t.Errorf("Expected 2 optional columns, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetForeignKeyForColumn(t *testing.T) {
|
||||
fks := []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users"},
|
||||
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
columnName string
|
||||
expectNil bool
|
||||
expectFK string
|
||||
}{
|
||||
{"found user_id", "user_id", false, "users"},
|
||||
{"found company_id", "company_id", false, "companies"},
|
||||
{"not found", "status_id", true, ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetForeignKeyForColumn(tt.columnName, fks)
|
||||
if tt.expectNil {
|
||||
if result != nil {
|
||||
t.Errorf("Expected nil, got %+v", result)
|
||||
}
|
||||
} else {
|
||||
if result == nil {
|
||||
t.Errorf("Expected FK, got nil")
|
||||
} else if result.ForeignTableName != tt.expectFK {
|
||||
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateInit(t *testing.T) {
|
||||
ctx := &Context{
|
||||
EntityName: "User",
|
||||
ModuleName: "user",
|
||||
}
|
||||
|
||||
result, err := GenerateInit(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateInit failed: %v", err)
|
||||
}
|
||||
|
||||
// __init__.py should be empty
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty __init__.py, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateRepository(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{},
|
||||
}
|
||||
|
||||
ctx := NewContext(tableInfo, "")
|
||||
|
||||
result, err := GenerateRepository(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateRepository failed: %v", err)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
expectedStrings := []string{
|
||||
"class UserRepository",
|
||||
"CRUDRepository",
|
||||
"model_cls = User",
|
||||
"from televend_core.databases.base_repository import CRUDRepository",
|
||||
"from televend_core.databases.televend_repositories.user.filter import",
|
||||
"from televend_core.databases.televend_repositories.user.model import User",
|
||||
}
|
||||
|
||||
for _, expected := range expectedStrings {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateManager(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{},
|
||||
}
|
||||
|
||||
ctx := NewContext(tableInfo, "")
|
||||
|
||||
result, err := GenerateManager(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateManager failed: %v", err)
|
||||
}
|
||||
|
||||
// Check for expected content
|
||||
expectedStrings := []string{
|
||||
"class UserManager",
|
||||
"CRUDManager",
|
||||
"repository_cls = UserRepository",
|
||||
}
|
||||
|
||||
for _, expected := range expectedStrings {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewContext(t *testing.T) {
|
||||
tableInfo := &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "user_accounts",
|
||||
}
|
||||
|
||||
// Test without override
|
||||
ctx := NewContext(tableInfo, "")
|
||||
if ctx.EntityName != "UserAccount" {
|
||||
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
|
||||
}
|
||||
if ctx.ModuleName != "user_account" {
|
||||
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
|
||||
}
|
||||
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
|
||||
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
|
||||
}
|
||||
|
||||
// Test with override
|
||||
ctx = NewContext(tableInfo, "CustomUser")
|
||||
if ctx.EntityName != "CustomUser" {
|
||||
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTable(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple table with basic types",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
|
||||
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from sqlalchemy import",
|
||||
"Column",
|
||||
"Table",
|
||||
"Integer",
|
||||
"String",
|
||||
"USER_TABLE = Table(",
|
||||
`"users"`,
|
||||
"metadata_obj",
|
||||
`Column("id", Integer, primary_key=True, autoincrement=True)`,
|
||||
`Column("name", String(255), nullable=False)`,
|
||||
`Column("email", String)`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "table with foreign keys",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"ForeignKey",
|
||||
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "table with enum",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{
|
||||
"order_status": {
|
||||
TypeName: "order_status",
|
||||
Values: []string{"pending", "completed", "cancelled"},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"Enum",
|
||||
"from televend_core.databases.televend_repositories.order.enum import",
|
||||
"OrderStatus",
|
||||
`Enum(`,
|
||||
`*OrderStatus.to_value_list()`,
|
||||
`name="order_status"`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateTable(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateTable failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "simple model",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", IsNullable: false},
|
||||
{Name: "email", DataType: "varchar", IsNullable: true},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from dataclasses import dataclass",
|
||||
"from televend_core.databases.base_model import Base",
|
||||
"@dataclass",
|
||||
"class User(Base):",
|
||||
"name: str",
|
||||
"email: str | None = None",
|
||||
"id: int | None = None",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with foreign key",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text", IsNullable: false},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.televend_repositories.user.model import User",
|
||||
"user_id: int",
|
||||
"user: User",
|
||||
"title: str",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "model with datetime and decimal",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "amount", DataType: "numeric", IsNullable: false},
|
||||
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from datetime import datetime",
|
||||
"from decimal import Decimal",
|
||||
"amount: Decimal",
|
||||
"created_at: datetime | None = None",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateModel(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateModel failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFilter(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
notExpect []string
|
||||
}{
|
||||
{
|
||||
name: "filter with boolean and id fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
{Name: "alive", DataType: "boolean"},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.base_filter import BaseFilter",
|
||||
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
|
||||
"class UserFilter(BaseFilter):",
|
||||
"model_cls = User",
|
||||
"id: int | None = filterfield(operator=EQ)",
|
||||
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
|
||||
"name: str | None = filterfield(operator=EQ)",
|
||||
"alive: bool | None = filterfield(operator=EQ, default=True)",
|
||||
"user_id: int | None = filterfield(operator=EQ)",
|
||||
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
|
||||
},
|
||||
notExpect: []string{
|
||||
"default=None",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter with no filterable fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "logs",
|
||||
Columns: []database.Column{
|
||||
{Name: "timestamp", DataType: "timestamp with time zone"},
|
||||
{Name: "amount", DataType: "numeric"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"class LogFilter(BaseFilter):",
|
||||
"pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateFilter(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFilter failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
|
||||
for _, notExpected := range tt.notExpect {
|
||||
if strings.Contains(result, notExpected) {
|
||||
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
|
||||
notExpected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateLoadOptions(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "load options with relationships",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
{Name: "category_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.base_load_options import LoadOptions",
|
||||
"from televend_core.databases.common.load_options import joinload",
|
||||
"class PostLoadOptions(LoadOptions):",
|
||||
"model_cls = Post",
|
||||
`load_user: bool = joinload(relations=["user"])`,
|
||||
`load_category: bool = joinload(relations=["category"])`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "load options with no relationships",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "settings",
|
||||
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
},
|
||||
expected: []string{
|
||||
"class SettingLoadOptions(LoadOptions):",
|
||||
"pass",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateLoadOptions(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateLoadOptions failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateFactory(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "factory with basic fields",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
|
||||
{Name: "alive", DataType: "boolean"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from __future__ import annotations",
|
||||
"from typing import Type",
|
||||
"import factory",
|
||||
"class UserFactory(TelevendBaseFactory):",
|
||||
"alive = True",
|
||||
"id = None",
|
||||
`name = factory.Faker("pystr", max_chars=100)`,
|
||||
"class Meta:",
|
||||
"model = User",
|
||||
"def create_minimal",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "factory with foreign keys",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "user_id", DataType: "integer", IsNullable: false},
|
||||
{Name: "title", DataType: "text"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
"from televend_core.databases.televend_repositories.user.factory import",
|
||||
"UserFactory",
|
||||
`user = CustomSelfAttribute("..user", UserFactory)`,
|
||||
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
|
||||
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "factory with decimal field",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "orders",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
||||
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{},
|
||||
EnumTypes: map[string]database.EnumType{},
|
||||
},
|
||||
expected: []string{
|
||||
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateFactory(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateFactory failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateMapper(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
tableInfo *database.TableInfo
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "mapper with single foreign key",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "posts",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "user_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
"mapper_registry.map_imperatively(",
|
||||
"class_=Post,",
|
||||
"local_table=POST_TABLE,",
|
||||
"properties={",
|
||||
`"user": relationship(`,
|
||||
"User, lazy=relationship_loading_strategy.value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "mapper with multiple foreign keys to same table",
|
||||
tableInfo: &database.TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "messages",
|
||||
Columns: []database.Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "sender_id", DataType: "integer"},
|
||||
{Name: "receiver_id", DataType: "integer"},
|
||||
},
|
||||
ForeignKeys: []database.ForeignKey{
|
||||
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
||||
},
|
||||
},
|
||||
expected: []string{
|
||||
`"sender": relationship(`,
|
||||
"User, lazy=relationship_loading_strategy.value",
|
||||
`"receiver": relationship(`,
|
||||
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := NewContext(tt.tableInfo, "")
|
||||
result, err := GenerateMapper(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateMapper failed: %v", err)
|
||||
}
|
||||
|
||||
for _, expected := range tt.expected {
|
||||
if !strings.Contains(result, expected) {
|
||||
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
|
||||
expected, result)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPythonTypeForColumn(t *testing.T) {
|
||||
ctx := &Context{
|
||||
TableInfo: &database.TableInfo{
|
||||
EnumTypes: map[string]database.EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"active", "inactive"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
col database.Column
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "integer type",
|
||||
col: database.Column{DataType: "integer"},
|
||||
expected: "int",
|
||||
},
|
||||
{
|
||||
name: "varchar type",
|
||||
col: database.Column{DataType: "varchar"},
|
||||
expected: "str",
|
||||
},
|
||||
{
|
||||
name: "enum type",
|
||||
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
|
||||
expected: "StatusEnum",
|
||||
},
|
||||
{
|
||||
name: "unknown enum type",
|
||||
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
|
||||
expected: "str",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPythonTypeForColumn(tt.col, ctx)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user