980 lines
26 KiB
Go
980 lines
26 KiB
Go
package generator
|
|
|
|
import (
|
|
"database/sql"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/entity-maker/entity-maker/internal/database"
|
|
)
|
|
|
|
func TestGetRelationshipName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{"with _id suffix", "author_info_id", "author_info"},
|
|
{"with user_id", "user_id", "user"},
|
|
{"without _id", "status", "status"},
|
|
{"just id", "id", "id"}, // "id" doesn't have "_id" suffix, so returns as-is
|
|
{"empty", "", ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetRelationshipName(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("GetRelationshipName(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRelationshipEntityName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{"plural table", "users", "User"},
|
|
{"compound plural", "user_accounts", "UserAccount"},
|
|
{"ies ending", "companies", "Company"},
|
|
{"already singular", "user", "User"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetRelationshipEntityName(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("GetRelationshipEntityName(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRelationshipModuleName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
expected string
|
|
}{
|
|
{"plural table", "users", "user"},
|
|
{"compound plural", "user_accounts", "user_account"},
|
|
{"ies ending", "companies", "company"},
|
|
{"already singular", "user", "user"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetRelationshipModuleName(tt.input)
|
|
if result != tt.expected {
|
|
t.Errorf("GetRelationshipModuleName(%q) = %q, want %q", tt.input, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetFilterFieldName(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
columnName string
|
|
useIN bool
|
|
expected string
|
|
}{
|
|
{"id with IN", "id", true, "ids"},
|
|
{"id without IN", "id", false, "id"},
|
|
{"user_id with IN", "user_id", true, "user_ids"},
|
|
{"user_id without IN", "user_id", false, "user_id"},
|
|
{"machine_id with IN", "machine_id", true, "machine_ids"},
|
|
{"status without IN", "status", false, "status"},
|
|
{"status with IN", "status", true, "status"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetFilterFieldName(tt.columnName, tt.useIN)
|
|
if result != tt.expected {
|
|
t.Errorf("GetFilterFieldName(%q, %v) = %q, want %q",
|
|
tt.columnName, tt.useIN, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestShouldGenerateFilter(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
col database.Column
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "boolean field",
|
|
col: database.Column{
|
|
Name: "alive",
|
|
DataType: "boolean",
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "id field",
|
|
col: database.Column{
|
|
Name: "id",
|
|
DataType: "integer",
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "foreign key field",
|
|
col: database.Column{
|
|
Name: "user_id",
|
|
DataType: "integer",
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "text field",
|
|
col: database.Column{
|
|
Name: "name",
|
|
DataType: "text",
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "varchar field",
|
|
col: database.Column{
|
|
Name: "email",
|
|
DataType: "character varying",
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "numeric field",
|
|
col: database.Column{
|
|
Name: "amount",
|
|
DataType: "numeric",
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "timestamp field",
|
|
col: database.Column{
|
|
Name: "created_at",
|
|
DataType: "timestamp with time zone",
|
|
},
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := ShouldGenerateFilter(tt.col)
|
|
if result != tt.expected {
|
|
t.Errorf("ShouldGenerateFilter(%+v) = %v, want %v", tt.col, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNeedsDecimalImport(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
columns []database.Column
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "has numeric column",
|
|
columns: []database.Column{
|
|
{Name: "amount", DataType: "numeric"},
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "has decimal column",
|
|
columns: []database.Column{
|
|
{Name: "price", DataType: "decimal"},
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "no numeric columns",
|
|
columns: []database.Column{
|
|
{Name: "id", DataType: "integer"},
|
|
{Name: "name", DataType: "varchar"},
|
|
},
|
|
expected: false,
|
|
},
|
|
{
|
|
name: "empty columns",
|
|
columns: []database.Column{},
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := NeedsDecimalImport(tt.columns)
|
|
if result != tt.expected {
|
|
t.Errorf("NeedsDecimalImport() = %v, want %v", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestNeedsDatetimeImport(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
columns []database.Column
|
|
expected bool
|
|
}{
|
|
{
|
|
name: "has timestamp column",
|
|
columns: []database.Column{
|
|
{Name: "created_at", DataType: "timestamp with time zone"},
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "has date column",
|
|
columns: []database.Column{
|
|
{Name: "birth_date", DataType: "date"},
|
|
},
|
|
expected: true,
|
|
},
|
|
{
|
|
name: "no datetime columns",
|
|
columns: []database.Column{
|
|
{Name: "id", DataType: "integer"},
|
|
{Name: "name", DataType: "varchar"},
|
|
},
|
|
expected: false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := NeedsDatetimeImport(tt.columns)
|
|
if result != tt.expected {
|
|
t.Errorf("NeedsDatetimeImport() = %v, want %v", result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetRequiredColumns(t *testing.T) {
|
|
columns := []database.Column{
|
|
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
|
{Name: "name", DataType: "varchar", IsNullable: false, ColumnDefault: sql.NullString{Valid: false}},
|
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
|
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
|
{Name: "count", DataType: "integer", IsNullable: false, IsAutoIncrement: true},
|
|
}
|
|
|
|
result := GetRequiredColumns(columns)
|
|
|
|
// Should only include 'name' (not nullable, no default, not PK, not auto-increment)
|
|
if len(result) != 1 {
|
|
t.Errorf("Expected 1 required column, got %d", len(result))
|
|
}
|
|
|
|
if len(result) > 0 && result[0].Name != "name" {
|
|
t.Errorf("Expected required column to be 'name', got %q", result[0].Name)
|
|
}
|
|
}
|
|
|
|
func TestGetOptionalColumns(t *testing.T) {
|
|
columns := []database.Column{
|
|
{Name: "id", DataType: "integer", IsNullable: false, IsPrimaryKey: true},
|
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
|
{Name: "created_at", DataType: "timestamp", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
|
}
|
|
|
|
result := GetOptionalColumns(columns)
|
|
|
|
// Should include 'email' (nullable) and 'created_at' (has default)
|
|
if len(result) != 2 {
|
|
t.Errorf("Expected 2 optional columns, got %d", len(result))
|
|
}
|
|
}
|
|
|
|
func TestGetForeignKeyForColumn(t *testing.T) {
|
|
fks := []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users"},
|
|
{ColumnName: "company_id", ForeignTableName: "companies"},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
columnName string
|
|
expectNil bool
|
|
expectFK string
|
|
}{
|
|
{"found user_id", "user_id", false, "users"},
|
|
{"found company_id", "company_id", false, "companies"},
|
|
{"not found", "status_id", true, ""},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetForeignKeyForColumn(tt.columnName, fks)
|
|
if tt.expectNil {
|
|
if result != nil {
|
|
t.Errorf("Expected nil, got %+v", result)
|
|
}
|
|
} else {
|
|
if result == nil {
|
|
t.Errorf("Expected FK, got nil")
|
|
} else if result.ForeignTableName != tt.expectFK {
|
|
t.Errorf("Expected FK table %q, got %q", tt.expectFK, result.ForeignTableName)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateInit(t *testing.T) {
|
|
ctx := &Context{
|
|
EntityName: "User",
|
|
ModuleName: "user",
|
|
}
|
|
|
|
result, err := GenerateInit(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateInit failed: %v", err)
|
|
}
|
|
|
|
// __init__.py should be empty
|
|
if result != "" {
|
|
t.Errorf("Expected empty __init__.py, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestGenerateRepository(t *testing.T) {
|
|
tableInfo := &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{},
|
|
}
|
|
|
|
ctx := NewContext(tableInfo, "")
|
|
|
|
result, err := GenerateRepository(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateRepository failed: %v", err)
|
|
}
|
|
|
|
// Check for expected content
|
|
expectedStrings := []string{
|
|
"class UserRepository",
|
|
"CRUDRepository",
|
|
"model_cls = User",
|
|
"from televend_core.databases.base_repository import CRUDRepository",
|
|
"from televend_core.databases.cloud_repositories.user.filter import",
|
|
"from televend_core.databases.cloud_repositories.user.model import User",
|
|
}
|
|
|
|
for _, expected := range expectedStrings {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected repository to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestGenerateManager(t *testing.T) {
|
|
tableInfo := &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{},
|
|
}
|
|
|
|
ctx := NewContext(tableInfo, "")
|
|
|
|
result, err := GenerateManager(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateManager failed: %v", err)
|
|
}
|
|
|
|
// Check for expected content
|
|
expectedStrings := []string{
|
|
"class UserManager",
|
|
"CRUDManager",
|
|
"repository_cls = UserRepository",
|
|
}
|
|
|
|
for _, expected := range expectedStrings {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected manager to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestNewContext(t *testing.T) {
|
|
tableInfo := &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "user_accounts",
|
|
}
|
|
|
|
// Test without override
|
|
ctx := NewContext(tableInfo, "")
|
|
if ctx.EntityName != "UserAccount" {
|
|
t.Errorf("Expected EntityName 'UserAccount', got %q", ctx.EntityName)
|
|
}
|
|
if ctx.ModuleName != "user_account" {
|
|
t.Errorf("Expected ModuleName 'user_account', got %q", ctx.ModuleName)
|
|
}
|
|
if ctx.TableConstant != "USER_ACCOUNT_TABLE" {
|
|
t.Errorf("Expected TableConstant 'USER_ACCOUNT_TABLE', got %q", ctx.TableConstant)
|
|
}
|
|
|
|
// Test with override
|
|
ctx = NewContext(tableInfo, "CustomUser")
|
|
if ctx.EntityName != "CustomUser" {
|
|
t.Errorf("Expected EntityName 'CustomUser', got %q", ctx.EntityName)
|
|
}
|
|
}
|
|
|
|
func TestGenerateTable(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "simple table with basic types",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsAutoIncrement: true, IsNullable: false},
|
|
{Name: "name", DataType: "character varying", CharMaxLength: sql.NullInt64{Valid: true, Int64: 255}, IsNullable: false},
|
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from sqlalchemy import",
|
|
"Column",
|
|
"Table",
|
|
"Integer",
|
|
"String",
|
|
"USER_TABLE = Table(",
|
|
`"users"`,
|
|
"metadata_obj",
|
|
`Column("id", Integer, primary_key=True, autoincrement=True)`,
|
|
`Column("name", String(255), nullable=False)`,
|
|
`Column("email", String)`,
|
|
},
|
|
},
|
|
{
|
|
name: "table with foreign keys",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "posts",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
|
{Name: "title", DataType: "text", IsNullable: false},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"ForeignKey",
|
|
`ForeignKey("users.id", deferrable=True, initially="DEFERRED")`,
|
|
},
|
|
},
|
|
{
|
|
name: "table with enum",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "orders",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "status", DataType: "USER-DEFINED", UdtName: "order_status", IsNullable: false},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{
|
|
"order_status": {
|
|
TypeName: "order_status",
|
|
Values: []string{"pending", "completed", "cancelled"},
|
|
},
|
|
},
|
|
},
|
|
expected: []string{
|
|
"Enum",
|
|
"from televend_core.databases.cloud_repositories.order.enum import",
|
|
"OrderStatus",
|
|
`Enum(`,
|
|
`*OrderStatus.to_value_list()`,
|
|
`name="order_status"`,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateTable(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateTable failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected table to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateModel(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "simple model",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "name", DataType: "varchar", IsNullable: false},
|
|
{Name: "email", DataType: "varchar", IsNullable: true},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from dataclasses import dataclass",
|
|
"from televend_core.databases.base_model import Base",
|
|
"@dataclass",
|
|
"class User(Base):",
|
|
"name: str",
|
|
"email: str | None = None",
|
|
"id: int | None = None",
|
|
},
|
|
},
|
|
{
|
|
name: "model with foreign key",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "posts",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
|
{Name: "title", DataType: "text", IsNullable: false},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from televend_core.databases.cloud_repositories.user.model import User",
|
|
"user_id: int",
|
|
"user: User",
|
|
"title: str",
|
|
},
|
|
},
|
|
{
|
|
name: "model with datetime and decimal",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "orders",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "amount", DataType: "numeric", IsNullable: false},
|
|
{Name: "created_at", DataType: "timestamp with time zone", IsNullable: false, ColumnDefault: sql.NullString{Valid: true, String: "now()"}},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from datetime import datetime",
|
|
"from decimal import Decimal",
|
|
"amount: Decimal",
|
|
"created_at: datetime | None = None",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateModel(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateModel failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected model to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateFilter(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
notExpect []string
|
|
}{
|
|
{
|
|
name: "filter with boolean and id fields",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
|
{Name: "name", DataType: "varchar"},
|
|
{Name: "alive", DataType: "boolean"},
|
|
{Name: "user_id", DataType: "integer"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
},
|
|
expected: []string{
|
|
"from televend_core.databases.base_filter import BaseFilter",
|
|
"from televend_core.databases.common.filters.filters import EQ, IN, filterfield",
|
|
"class UserFilter(BaseFilter):",
|
|
"model_cls = User",
|
|
"id: int | None = filterfield(operator=EQ)",
|
|
"ids: list[int] | None = filterfield(field=\"id\", operator=IN)",
|
|
"name: str | None = filterfield(operator=EQ)",
|
|
"alive: bool | None = filterfield(operator=EQ, default=True)",
|
|
"user_id: int | None = filterfield(operator=EQ)",
|
|
"user_ids: list[int] | None = filterfield(field=\"user_id\", operator=IN)",
|
|
},
|
|
notExpect: []string{
|
|
"default=None",
|
|
},
|
|
},
|
|
{
|
|
name: "filter with no filterable fields",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "logs",
|
|
Columns: []database.Column{
|
|
{Name: "timestamp", DataType: "timestamp with time zone"},
|
|
{Name: "amount", DataType: "numeric"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
},
|
|
expected: []string{
|
|
"class LogFilter(BaseFilter):",
|
|
"pass",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateFilter(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateFilter failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected filter to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
|
|
for _, notExpected := range tt.notExpect {
|
|
if strings.Contains(result, notExpected) {
|
|
t.Errorf("Did not expect filter to contain %q, but it does.\nGenerated:\n%s",
|
|
notExpected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateLoadOptions(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "load options with relationships",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "posts",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
|
{Name: "user_id", DataType: "integer"},
|
|
{Name: "category_id", DataType: "integer"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
{ColumnName: "category_id", ForeignTableName: "categories", ForeignColumnName: "id"},
|
|
},
|
|
},
|
|
expected: []string{
|
|
"from televend_core.databases.base_load_options import LoadOptions",
|
|
"from televend_core.databases.common.load_options import joinload",
|
|
"class PostLoadOptions(LoadOptions):",
|
|
"model_cls = Post",
|
|
`load_user: bool = joinload(relations=["user"])`,
|
|
`load_category: bool = joinload(relations=["category"])`,
|
|
},
|
|
},
|
|
{
|
|
name: "load options with no relationships",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "settings",
|
|
Columns: []database.Column{{Name: "id", DataType: "integer", IsPrimaryKey: true}},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
},
|
|
expected: []string{
|
|
"class SettingLoadOptions(LoadOptions):",
|
|
"pass",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateLoadOptions(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateLoadOptions failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected load_options to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateFactory(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "factory with basic fields",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "name", DataType: "varchar", CharMaxLength: sql.NullInt64{Valid: true, Int64: 100}},
|
|
{Name: "alive", DataType: "boolean"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from __future__ import annotations",
|
|
"from typing import Type",
|
|
"import factory",
|
|
"class UserFactory(CloudBaseFactory):",
|
|
"alive = True",
|
|
"id = None",
|
|
`name = factory.Faker("pystr", max_chars=100)`,
|
|
"class Meta:",
|
|
"model = User",
|
|
"def create_minimal",
|
|
},
|
|
},
|
|
{
|
|
name: "factory with foreign keys",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "posts",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "user_id", DataType: "integer", IsNullable: false},
|
|
{Name: "title", DataType: "text"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
"from televend_core.databases.cloud_repositories.user.factory import",
|
|
"UserFactory",
|
|
`user = CustomSelfAttribute("..user", UserFactory)`,
|
|
"user_id = factory.LazyAttribute(lambda a: a.user.id if a.user else None)",
|
|
`"user": kwargs.pop("user", None) or UserFactory.create_minimal()`,
|
|
},
|
|
},
|
|
{
|
|
name: "factory with decimal field",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "orders",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true, IsNullable: false},
|
|
{Name: "amount", DataType: "numeric", NumericPrecision: sql.NullInt64{Valid: true, Int64: 10}, NumericScale: sql.NullInt64{Valid: true, Int64: 2}},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{},
|
|
EnumTypes: map[string]database.EnumType{},
|
|
},
|
|
expected: []string{
|
|
`amount = factory.Faker("pydecimal", left_digits=8, right_digits=2, positive=True)`,
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateFactory(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateFactory failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected factory to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateMapper(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
tableInfo *database.TableInfo
|
|
expected []string
|
|
}{
|
|
{
|
|
name: "mapper with single foreign key",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "posts",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
|
{Name: "user_id", DataType: "integer"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "user_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
},
|
|
},
|
|
expected: []string{
|
|
"mapper_registry.map_imperatively(",
|
|
"class_=Post,",
|
|
"local_table=POST_TABLE,",
|
|
"properties={",
|
|
`"user": relationship(`,
|
|
"User, lazy=relationship_loading_strategy.value",
|
|
},
|
|
},
|
|
{
|
|
name: "mapper with multiple foreign keys to same table",
|
|
tableInfo: &database.TableInfo{
|
|
Schema: "public",
|
|
TableName: "messages",
|
|
Columns: []database.Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
|
{Name: "sender_id", DataType: "integer"},
|
|
{Name: "receiver_id", DataType: "integer"},
|
|
},
|
|
ForeignKeys: []database.ForeignKey{
|
|
{ColumnName: "sender_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
{ColumnName: "receiver_id", ForeignTableName: "users", ForeignColumnName: "id"},
|
|
},
|
|
},
|
|
expected: []string{
|
|
`"sender": relationship(`,
|
|
"User, lazy=relationship_loading_strategy.value",
|
|
`"receiver": relationship(`,
|
|
"foreign_keys=MESSAGE_TABLE.columns.receiver_id,",
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
ctx := NewContext(tt.tableInfo, "")
|
|
result, err := GenerateMapper(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GenerateMapper failed: %v", err)
|
|
}
|
|
|
|
for _, expected := range tt.expected {
|
|
if !strings.Contains(result, expected) {
|
|
t.Errorf("Expected mapper to contain %q, but it doesn't.\nGenerated:\n%s",
|
|
expected, result)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetPythonTypeForColumn(t *testing.T) {
|
|
ctx := &Context{
|
|
TableInfo: &database.TableInfo{
|
|
EnumTypes: map[string]database.EnumType{
|
|
"status_enum": {
|
|
TypeName: "status_enum",
|
|
Values: []string{"active", "inactive"},
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
col database.Column
|
|
expected string
|
|
}{
|
|
{
|
|
name: "integer type",
|
|
col: database.Column{DataType: "integer"},
|
|
expected: "int",
|
|
},
|
|
{
|
|
name: "varchar type",
|
|
col: database.Column{DataType: "varchar"},
|
|
expected: "str",
|
|
},
|
|
{
|
|
name: "enum type",
|
|
col: database.Column{DataType: "USER-DEFINED", UdtName: "status_enum"},
|
|
expected: "StatusEnum",
|
|
},
|
|
{
|
|
name: "unknown enum type",
|
|
col: database.Column{DataType: "USER-DEFINED", UdtName: "unknown_enum"},
|
|
expected: "str",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetPythonTypeForColumn(tt.col, ctx)
|
|
if result != tt.expected {
|
|
t.Errorf("GetPythonTypeForColumn(%+v) = %q, want %q", tt.col, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|