342 lines
8.7 KiB
Go
342 lines
8.7 KiB
Go
package database
|
|
|
|
import (
|
|
"database/sql"
|
|
"testing"
|
|
)
|
|
|
|
func TestGetPythonType(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
col Column
|
|
expected string
|
|
}{
|
|
// Integer types
|
|
{"integer", Column{DataType: "integer"}, "int"},
|
|
{"smallint", Column{DataType: "smallint"}, "int"},
|
|
{"bigint", Column{DataType: "bigint"}, "int"},
|
|
|
|
// Numeric types
|
|
{"numeric", Column{DataType: "numeric"}, "Decimal"},
|
|
{"decimal", Column{DataType: "decimal"}, "Decimal"},
|
|
{"real", Column{DataType: "real"}, "Decimal"},
|
|
{"double precision", Column{DataType: "double precision"}, "Decimal"},
|
|
|
|
// Boolean
|
|
{"boolean", Column{DataType: "boolean"}, "bool"},
|
|
|
|
// String types
|
|
{"varchar", Column{DataType: "character varying"}, "str"},
|
|
{"varchar short", Column{DataType: "varchar"}, "str"},
|
|
{"text", Column{DataType: "text"}, "str"},
|
|
{"char", Column{DataType: "char"}, "str"},
|
|
{"character", Column{DataType: "character"}, "str"},
|
|
|
|
// Date/Time types
|
|
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
|
|
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
|
|
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
|
|
{"date", Column{DataType: "date"}, "date"},
|
|
{"time with tz", Column{DataType: "time with time zone"}, "time"},
|
|
{"time without tz", Column{DataType: "time without time zone"}, "time"},
|
|
{"time", Column{DataType: "time"}, "time"},
|
|
|
|
// JSON types
|
|
{"json", Column{DataType: "json"}, "dict"},
|
|
{"jsonb", Column{DataType: "jsonb"}, "dict"},
|
|
|
|
// Other types
|
|
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
|
{"bytea", Column{DataType: "bytea"}, "bytes"},
|
|
|
|
// User-defined (enum)
|
|
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
|
|
|
|
// Unknown type
|
|
{"unknown", Column{DataType: "unknown_type"}, "Any"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetPythonType(tt.col)
|
|
if result != tt.expected {
|
|
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGetSQLAlchemyType(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
col Column
|
|
expected string
|
|
}{
|
|
// Integer types
|
|
{"integer", Column{DataType: "integer"}, "Integer"},
|
|
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
|
|
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
|
|
|
|
// Numeric types with precision
|
|
{
|
|
"numeric with precision",
|
|
Column{
|
|
DataType: "numeric",
|
|
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
|
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
|
},
|
|
"Numeric(12, 4)",
|
|
},
|
|
{
|
|
"numeric without precision",
|
|
Column{DataType: "numeric"},
|
|
"Numeric",
|
|
},
|
|
{"real", Column{DataType: "real"}, "Float"},
|
|
{"double precision", Column{DataType: "double precision"}, "Float"},
|
|
|
|
// Boolean
|
|
{"boolean", Column{DataType: "boolean"}, "Boolean"},
|
|
|
|
// String types
|
|
{
|
|
"varchar with length",
|
|
Column{
|
|
DataType: "character varying",
|
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
|
|
},
|
|
"String(255)",
|
|
},
|
|
{
|
|
"varchar without length",
|
|
Column{DataType: "varchar"},
|
|
"String",
|
|
},
|
|
{
|
|
"char with length",
|
|
Column{
|
|
DataType: "char",
|
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
|
|
},
|
|
"String(10)",
|
|
},
|
|
{
|
|
"char without length",
|
|
Column{DataType: "character"},
|
|
"String(1)",
|
|
},
|
|
{"text", Column{DataType: "text"}, "Text"},
|
|
|
|
// Date/Time types
|
|
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
|
|
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
|
|
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
|
|
{"date", Column{DataType: "date"}, "Date"},
|
|
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
|
|
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
|
|
{"time", Column{DataType: "time"}, "Time"},
|
|
|
|
// JSON types
|
|
{"json", Column{DataType: "json"}, "JSON"},
|
|
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
|
|
|
|
// Other types
|
|
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
|
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
|
|
|
|
// User-defined (enum)
|
|
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
|
|
|
|
// Unknown type
|
|
{"unknown", Column{DataType: "unknown_type"}, "String"},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
result := GetSQLAlchemyType(tt.col)
|
|
if result != tt.expected {
|
|
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestColumn(t *testing.T) {
|
|
col := Column{
|
|
Name: "test_column",
|
|
DataType: "varchar",
|
|
IsNullable: true,
|
|
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
|
|
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
|
|
NumericPrecision: sql.NullInt64{Valid: false},
|
|
NumericScale: sql.NullInt64{Valid: false},
|
|
UdtName: "",
|
|
IsPrimaryKey: false,
|
|
IsAutoIncrement: false,
|
|
}
|
|
|
|
if col.Name != "test_column" {
|
|
t.Errorf("Expected Name 'test_column', got %q", col.Name)
|
|
}
|
|
|
|
if !col.IsNullable {
|
|
t.Error("Expected IsNullable to be true")
|
|
}
|
|
|
|
if !col.ColumnDefault.Valid {
|
|
t.Error("Expected ColumnDefault to be valid")
|
|
}
|
|
|
|
if col.ColumnDefault.String != "default_value" {
|
|
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
|
|
}
|
|
}
|
|
|
|
func TestForeignKey(t *testing.T) {
|
|
fk := ForeignKey{
|
|
ColumnName: "user_id",
|
|
ForeignTableSchema: "public",
|
|
ForeignTableName: "users",
|
|
ForeignColumnName: "id",
|
|
ConstraintName: "fk_user_id",
|
|
}
|
|
|
|
if fk.ColumnName != "user_id" {
|
|
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
|
|
}
|
|
|
|
if fk.ForeignTableName != "users" {
|
|
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
|
|
}
|
|
|
|
if fk.ForeignColumnName != "id" {
|
|
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
|
|
}
|
|
}
|
|
|
|
func TestEnumType(t *testing.T) {
|
|
enum := EnumType{
|
|
TypeName: "status_enum",
|
|
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
|
}
|
|
|
|
if enum.TypeName != "status_enum" {
|
|
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
|
|
}
|
|
|
|
if len(enum.Values) != 3 {
|
|
t.Errorf("Expected 3 values, got %d", len(enum.Values))
|
|
}
|
|
|
|
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
|
|
for i, val := range enum.Values {
|
|
if val != expectedValues[i] {
|
|
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestTableInfo(t *testing.T) {
|
|
tableInfo := &TableInfo{
|
|
Schema: "public",
|
|
TableName: "users",
|
|
Columns: []Column{
|
|
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
|
{Name: "name", DataType: "varchar"},
|
|
},
|
|
ForeignKeys: []ForeignKey{
|
|
{ColumnName: "company_id", ForeignTableName: "companies"},
|
|
},
|
|
EnumTypes: map[string]EnumType{
|
|
"status_enum": {
|
|
TypeName: "status_enum",
|
|
Values: []string{"ACTIVE", "INACTIVE"},
|
|
},
|
|
},
|
|
}
|
|
|
|
if tableInfo.Schema != "public" {
|
|
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
|
|
}
|
|
|
|
if tableInfo.TableName != "users" {
|
|
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
|
|
}
|
|
|
|
if len(tableInfo.Columns) != 2 {
|
|
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
|
|
}
|
|
|
|
if len(tableInfo.ForeignKeys) != 1 {
|
|
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
|
|
}
|
|
|
|
if len(tableInfo.EnumTypes) != 1 {
|
|
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
|
|
}
|
|
|
|
// Test primary key detection
|
|
foundPK := false
|
|
for _, col := range tableInfo.Columns {
|
|
if col.IsPrimaryKey {
|
|
foundPK = true
|
|
if col.Name != "id" {
|
|
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
|
|
}
|
|
}
|
|
}
|
|
if !foundPK {
|
|
t.Error("Expected to find primary key column")
|
|
}
|
|
}
|
|
|
|
func TestConfig(t *testing.T) {
|
|
cfg := Config{
|
|
Host: "localhost",
|
|
Port: 5432,
|
|
Database: "testdb",
|
|
Schema: "public",
|
|
User: "testuser",
|
|
Password: "testpass",
|
|
}
|
|
|
|
if cfg.Host != "localhost" {
|
|
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
|
|
}
|
|
|
|
if cfg.Port != 5432 {
|
|
t.Errorf("Expected Port 5432, got %d", cfg.Port)
|
|
}
|
|
|
|
if cfg.Database != "testdb" {
|
|
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
|
|
}
|
|
|
|
if cfg.Schema != "public" {
|
|
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
|
|
}
|
|
}
|
|
|
|
// Benchmark tests
|
|
func BenchmarkGetPythonType(b *testing.B) {
|
|
col := Column{DataType: "character varying"}
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
GetPythonType(col)
|
|
}
|
|
}
|
|
|
|
func BenchmarkGetSQLAlchemyType(b *testing.B) {
|
|
col := Column{
|
|
DataType: "numeric",
|
|
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
|
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
|
}
|
|
b.ResetTimer()
|
|
|
|
for i := 0; i < b.N; i++ {
|
|
GetSQLAlchemyType(col)
|
|
}
|
|
}
|