Tests
This commit is contained in:
341
internal/database/database_test.go
Normal file
341
internal/database/database_test.go
Normal file
@ -0,0 +1,341 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetPythonType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col Column
|
||||
expected string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", Column{DataType: "integer"}, "int"},
|
||||
{"smallint", Column{DataType: "smallint"}, "int"},
|
||||
{"bigint", Column{DataType: "bigint"}, "int"},
|
||||
|
||||
// Numeric types
|
||||
{"numeric", Column{DataType: "numeric"}, "Decimal"},
|
||||
{"decimal", Column{DataType: "decimal"}, "Decimal"},
|
||||
{"real", Column{DataType: "real"}, "Decimal"},
|
||||
{"double precision", Column{DataType: "double precision"}, "Decimal"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", Column{DataType: "boolean"}, "bool"},
|
||||
|
||||
// String types
|
||||
{"varchar", Column{DataType: "character varying"}, "str"},
|
||||
{"varchar short", Column{DataType: "varchar"}, "str"},
|
||||
{"text", Column{DataType: "text"}, "str"},
|
||||
{"char", Column{DataType: "char"}, "str"},
|
||||
{"character", Column{DataType: "character"}, "str"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
|
||||
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
|
||||
{"timestamp", Column{DataType: "timestamp"}, "datetime"},
|
||||
{"date", Column{DataType: "date"}, "date"},
|
||||
{"time with tz", Column{DataType: "time with time zone"}, "time"},
|
||||
{"time without tz", Column{DataType: "time without time zone"}, "time"},
|
||||
{"time", Column{DataType: "time"}, "time"},
|
||||
|
||||
// JSON types
|
||||
{"json", Column{DataType: "json"}, "dict"},
|
||||
{"jsonb", Column{DataType: "jsonb"}, "dict"},
|
||||
|
||||
// Other types
|
||||
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||
{"bytea", Column{DataType: "bytea"}, "bytes"},
|
||||
|
||||
// User-defined (enum)
|
||||
{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
|
||||
|
||||
// Unknown type
|
||||
{"unknown", Column{DataType: "unknown_type"}, "Any"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetPythonType(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetSQLAlchemyType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
col Column
|
||||
expected string
|
||||
}{
|
||||
// Integer types
|
||||
{"integer", Column{DataType: "integer"}, "Integer"},
|
||||
{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
|
||||
{"bigint", Column{DataType: "bigint"}, "BigInteger"},
|
||||
|
||||
// Numeric types with precision
|
||||
{
|
||||
"numeric with precision",
|
||||
Column{
|
||||
DataType: "numeric",
|
||||
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||
},
|
||||
"Numeric(12, 4)",
|
||||
},
|
||||
{
|
||||
"numeric without precision",
|
||||
Column{DataType: "numeric"},
|
||||
"Numeric",
|
||||
},
|
||||
{"real", Column{DataType: "real"}, "Float"},
|
||||
{"double precision", Column{DataType: "double precision"}, "Float"},
|
||||
|
||||
// Boolean
|
||||
{"boolean", Column{DataType: "boolean"}, "Boolean"},
|
||||
|
||||
// String types
|
||||
{
|
||||
"varchar with length",
|
||||
Column{
|
||||
DataType: "character varying",
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
|
||||
},
|
||||
"String(255)",
|
||||
},
|
||||
{
|
||||
"varchar without length",
|
||||
Column{DataType: "varchar"},
|
||||
"String",
|
||||
},
|
||||
{
|
||||
"char with length",
|
||||
Column{
|
||||
DataType: "char",
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
|
||||
},
|
||||
"String(10)",
|
||||
},
|
||||
{
|
||||
"char without length",
|
||||
Column{DataType: "character"},
|
||||
"String(1)",
|
||||
},
|
||||
{"text", Column{DataType: "text"}, "Text"},
|
||||
|
||||
// Date/Time types
|
||||
{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
|
||||
{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
|
||||
{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
|
||||
{"date", Column{DataType: "date"}, "Date"},
|
||||
{"time with tz", Column{DataType: "time with time zone"}, "Time"},
|
||||
{"time without tz", Column{DataType: "time without time zone"}, "Time"},
|
||||
{"time", Column{DataType: "time"}, "Time"},
|
||||
|
||||
// JSON types
|
||||
{"json", Column{DataType: "json"}, "JSON"},
|
||||
{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
|
||||
|
||||
// Other types
|
||||
{"uuid", Column{DataType: "uuid"}, "UUID"},
|
||||
{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
|
||||
|
||||
// User-defined (enum)
|
||||
{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
|
||||
|
||||
// Unknown type
|
||||
{"unknown", Column{DataType: "unknown_type"}, "String"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := GetSQLAlchemyType(tt.col)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestColumn(t *testing.T) {
|
||||
col := Column{
|
||||
Name: "test_column",
|
||||
DataType: "varchar",
|
||||
IsNullable: true,
|
||||
ColumnDefault: sql.NullString{Valid: true, String: "default_value"},
|
||||
CharMaxLength: sql.NullInt64{Valid: true, Int64: 100},
|
||||
NumericPrecision: sql.NullInt64{Valid: false},
|
||||
NumericScale: sql.NullInt64{Valid: false},
|
||||
UdtName: "",
|
||||
IsPrimaryKey: false,
|
||||
IsAutoIncrement: false,
|
||||
}
|
||||
|
||||
if col.Name != "test_column" {
|
||||
t.Errorf("Expected Name 'test_column', got %q", col.Name)
|
||||
}
|
||||
|
||||
if !col.IsNullable {
|
||||
t.Error("Expected IsNullable to be true")
|
||||
}
|
||||
|
||||
if !col.ColumnDefault.Valid {
|
||||
t.Error("Expected ColumnDefault to be valid")
|
||||
}
|
||||
|
||||
if col.ColumnDefault.String != "default_value" {
|
||||
t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
|
||||
}
|
||||
}
|
||||
|
||||
func TestForeignKey(t *testing.T) {
|
||||
fk := ForeignKey{
|
||||
ColumnName: "user_id",
|
||||
ForeignTableSchema: "public",
|
||||
ForeignTableName: "users",
|
||||
ForeignColumnName: "id",
|
||||
ConstraintName: "fk_user_id",
|
||||
}
|
||||
|
||||
if fk.ColumnName != "user_id" {
|
||||
t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
|
||||
}
|
||||
|
||||
if fk.ForeignTableName != "users" {
|
||||
t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
|
||||
}
|
||||
|
||||
if fk.ForeignColumnName != "id" {
|
||||
t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnumType(t *testing.T) {
|
||||
enum := EnumType{
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"OPEN", "CLOSED", "PENDING"},
|
||||
}
|
||||
|
||||
if enum.TypeName != "status_enum" {
|
||||
t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
|
||||
}
|
||||
|
||||
if len(enum.Values) != 3 {
|
||||
t.Errorf("Expected 3 values, got %d", len(enum.Values))
|
||||
}
|
||||
|
||||
expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
|
||||
for i, val := range enum.Values {
|
||||
if val != expectedValues[i] {
|
||||
t.Errorf("Expected value %q at index %d, got %q", expectedValues[i], i, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTableInfo(t *testing.T) {
|
||||
tableInfo := &TableInfo{
|
||||
Schema: "public",
|
||||
TableName: "users",
|
||||
Columns: []Column{
|
||||
{Name: "id", DataType: "integer", IsPrimaryKey: true},
|
||||
{Name: "name", DataType: "varchar"},
|
||||
},
|
||||
ForeignKeys: []ForeignKey{
|
||||
{ColumnName: "company_id", ForeignTableName: "companies"},
|
||||
},
|
||||
EnumTypes: map[string]EnumType{
|
||||
"status_enum": {
|
||||
TypeName: "status_enum",
|
||||
Values: []string{"ACTIVE", "INACTIVE"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
if tableInfo.Schema != "public" {
|
||||
t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
|
||||
}
|
||||
|
||||
if tableInfo.TableName != "users" {
|
||||
t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
|
||||
}
|
||||
|
||||
if len(tableInfo.Columns) != 2 {
|
||||
t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
|
||||
}
|
||||
|
||||
if len(tableInfo.ForeignKeys) != 1 {
|
||||
t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
|
||||
}
|
||||
|
||||
if len(tableInfo.EnumTypes) != 1 {
|
||||
t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
|
||||
}
|
||||
|
||||
// Test primary key detection
|
||||
foundPK := false
|
||||
for _, col := range tableInfo.Columns {
|
||||
if col.IsPrimaryKey {
|
||||
foundPK = true
|
||||
if col.Name != "id" {
|
||||
t.Errorf("Expected primary key to be 'id', got %q", col.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !foundPK {
|
||||
t.Error("Expected to find primary key column")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
cfg := Config{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
Database: "testdb",
|
||||
Schema: "public",
|
||||
User: "testuser",
|
||||
Password: "testpass",
|
||||
}
|
||||
|
||||
if cfg.Host != "localhost" {
|
||||
t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
|
||||
}
|
||||
|
||||
if cfg.Port != 5432 {
|
||||
t.Errorf("Expected Port 5432, got %d", cfg.Port)
|
||||
}
|
||||
|
||||
if cfg.Database != "testdb" {
|
||||
t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
|
||||
}
|
||||
|
||||
if cfg.Schema != "public" {
|
||||
t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
|
||||
}
|
||||
}
|
||||
|
||||
// Benchmark tests
|
||||
func BenchmarkGetPythonType(b *testing.B) {
|
||||
col := Column{DataType: "character varying"}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetPythonType(col)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGetSQLAlchemyType(b *testing.B) {
|
||||
col := Column{
|
||||
DataType: "numeric",
|
||||
NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
|
||||
NumericScale: sql.NullInt64{Valid: true, Int64: 4},
|
||||
}
|
||||
b.ResetTimer()
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
GetSQLAlchemyType(col)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user