This commit is contained in:
Eden Kirin
2025-10-31 19:04:40 +01:00
parent ca10b01fb0
commit 4e4827d640
12 changed files with 2612 additions and 10 deletions

View File

@ -0,0 +1,341 @@
package database
import (
"database/sql"
"testing"
)
func TestGetPythonType(t *testing.T) {
tests := []struct {
name string
col Column
expected string
}{
// Integer types
{"integer", Column{DataType: "integer"}, "int"},
{"smallint", Column{DataType: "smallint"}, "int"},
{"bigint", Column{DataType: "bigint"}, "int"},
// Numeric types
{"numeric", Column{DataType: "numeric"}, "Decimal"},
{"decimal", Column{DataType: "decimal"}, "Decimal"},
{"real", Column{DataType: "real"}, "Decimal"},
{"double precision", Column{DataType: "double precision"}, "Decimal"},
// Boolean
{"boolean", Column{DataType: "boolean"}, "bool"},
// String types
{"varchar", Column{DataType: "character varying"}, "str"},
{"varchar short", Column{DataType: "varchar"}, "str"},
{"text", Column{DataType: "text"}, "str"},
{"char", Column{DataType: "char"}, "str"},
{"character", Column{DataType: "character"}, "str"},
// Date/Time types
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
{"date", Column{DataType: "date"}, "date"},
{"time with tz", Column{DataType: "time with time zone"}, "time"},
{"time without tz", Column{DataType: "time without time zone"}, "time"},
{"time", Column{DataType: "time"}, "time"},
// JSON types
{"json", Column{DataType: "json"}, "dict"},
{"jsonb", Column{DataType: "jsonb"}, "dict"},
// Other types
{"uuid", Column{DataType: "uuid"}, "UUID"},
{"bytea", Column{DataType: "bytea"}, "bytes"},
// User-defined (enum)
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
// Unknown type
{"unknown", Column{DataType: "unknown_type"}, "Any"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetPythonType(tt.col)
if result != tt.expected {
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}
func TestGetSQLAlchemyType(t *testing.T) {
tests := []struct {
name string
col Column
expected string
}{
// Integer types
{"integer", Column{DataType: "integer"}, "Integer"},
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
// Numeric types with precision
{
"numeric with precision",
Column{
DataType: "numeric",
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
},
"Numeric(12, 4)",
},
{
"numeric without precision",
Column{DataType: "numeric"},
"Numeric",
},
{"real", Column{DataType: "real"}, "Float"},
{"double precision", Column{DataType: "double precision"}, "Float"},
// Boolean
{"boolean", Column{DataType: "boolean"}, "Boolean"},
// String types
{
"varchar with length",
Column{
DataType: "character varying",
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
},
"String(255)",
},
{
"varchar without length",
Column{DataType: "varchar"},
"String",
},
{
"char with length",
Column{
DataType: "char",
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
},
"String(10)",
},
{
"char without length",
Column{DataType: "character"},
"String(1)",
},
{"text", Column{DataType: "text"}, "Text"},
// Date/Time types
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
{"date", Column{DataType: "date"}, "Date"},
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
{"time", Column{DataType: "time"}, "Time"},
// JSON types
{"json", Column{DataType: "json"}, "JSON"},
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
// Other types
{"uuid", Column{DataType: "uuid"}, "UUID"},
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
// User-defined (enum)
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
// Unknown type
{"unknown", Column{DataType: "unknown_type"}, "String"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := GetSQLAlchemyType(tt.col)
if result != tt.expected {
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
}
})
}
}
func TestColumn(t *testing.T) {
col := Column{
Name: "test_column",
DataType: "varchar",
IsNullable: true,
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
NumericPrecision: sql.NullInt64{Valid: false},
NumericScale: sql.NullInt64{Valid: false},
UdtName: "",
IsPrimaryKey: false,
IsAutoIncrement: false,
}
if col.Name != "test_column" {
t.Errorf("Expected Name 'test_column', got %q", col.Name)
}
if !col.IsNullable {
t.Error("Expected IsNullable to be true")
}
if !col.ColumnDefault.Valid {
t.Error("Expected ColumnDefault to be valid")
}
if col.ColumnDefault.String != "default_value" {
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
}
}
func TestForeignKey(t *testing.T) {
fk := ForeignKey{
ColumnName: "user_id",
ForeignTableSchema: "public",
ForeignTableName: "users",
ForeignColumnName: "id",
ConstraintName: "fk_user_id",
}
if fk.ColumnName != "user_id" {
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
}
if fk.ForeignTableName != "users" {
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
}
if fk.ForeignColumnName != "id" {
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
}
}
func TestEnumType(t *testing.T) {
enum := EnumType{
TypeName: "status_enum",
Values: []string{"OPEN", "CLOSED", "PENDING"},
}
if enum.TypeName != "status_enum" {
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
}
if len(enum.Values) != 3 {
t.Errorf("Expected 3 values, got %d", len(enum.Values))
}
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
for i, val := range enum.Values {
if val != expectedValues[i] {
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
}
}
}
func TestTableInfo(t *testing.T) {
tableInfo := &TableInfo{
Schema: "public",
TableName: "users",
Columns: []Column{
{Name: "id", DataType: "integer", IsPrimaryKey: true},
{Name: "name", DataType: "varchar"},
},
ForeignKeys: []ForeignKey{
{ColumnName: "company_id", ForeignTableName: "companies"},
},
EnumTypes: map[string]EnumType{
"status_enum": {
TypeName: "status_enum",
Values: []string{"ACTIVE", "INACTIVE"},
},
},
}
if tableInfo.Schema != "public" {
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
}
if tableInfo.TableName != "users" {
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
}
if len(tableInfo.Columns) != 2 {
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
}
if len(tableInfo.ForeignKeys) != 1 {
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
}
if len(tableInfo.EnumTypes) != 1 {
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
}
// Test primary key detection
foundPK := false
for _, col := range tableInfo.Columns {
if col.IsPrimaryKey {
foundPK = true
if col.Name != "id" {
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
}
}
}
if !foundPK {
t.Error("Expected to find primary key column")
}
}
func TestConfig(t *testing.T) {
cfg := Config{
Host: "localhost",
Port: 5432,
Database: "testdb",
Schema: "public",
User: "testuser",
Password: "testpass",
}
if cfg.Host != "localhost" {
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
}
if cfg.Port != 5432 {
t.Errorf("Expected Port 5432, got %d", cfg.Port)
}
if cfg.Database != "testdb" {
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
}
if cfg.Schema != "public" {
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
}
}
// Benchmark tests
func BenchmarkGetPythonType(b *testing.B) {
col := Column{DataType: "character varying"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
GetPythonType(col)
}
}
func BenchmarkGetSQLAlchemyType(b *testing.B) {
col := Column{
DataType: "numeric",
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
GetSQLAlchemyType(col)
}
}