package database

import (
	"database/sql"
	"testing"
)

// Benchmark result sinks. Storing each benchmarked call's result in a
// package-level variable prevents the compiler from treating the call as
// dead code and optimizing it away, which would make the benchmark meaningless.
var (
	pythonTypeSink     string
	sqlAlchemyTypeSink string
)

// TestGetPythonType checks the Python type-hint name that GetPythonType
// returns for each PostgreSQL-style data type name in the table below.
func TestGetPythonType(t *testing.T) {
	tests := []struct {
		name     string
		col      Column
		expected string
	}{
		// Integer types
		{"integer", Column{DataType: "integer"}, "int"},
		{"smallint", Column{DataType: "smallint"}, "int"},
		{"bigint", Column{DataType: "bigint"}, "int"},
		// Numeric types
		// NOTE(review): real/double precision map to "Decimal" here, while
		// GetSQLAlchemyType maps them to "Float" (which typically yields a
		// Python float). Confirm this pairing is intentional.
		{"numeric", Column{DataType: "numeric"}, "Decimal"},
		{"decimal", Column{DataType: "decimal"}, "Decimal"},
		{"real", Column{DataType: "real"}, "Decimal"},
		{"double precision", Column{DataType: "double precision"}, "Decimal"},
		// Boolean
		{"boolean", Column{DataType: "boolean"}, "bool"},
		// String types
		{"varchar", Column{DataType: "character varying"}, "str"},
		{"varchar short", Column{DataType: "varchar"}, "str"},
		{"text", Column{DataType: "text"}, "str"},
		{"char", Column{DataType: "char"}, "str"},
		{"character", Column{DataType: "character"}, "str"},
		// Date/Time types
		{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "datetime"},
		{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "datetime"},
		{"timestamp", Column{DataType: "timestamp"}, "datetime"},
		{"date", Column{DataType: "date"}, "date"},
		{"time with tz", Column{DataType: "time with time zone"}, "time"},
		{"time without tz", Column{DataType: "time without time zone"}, "time"},
		{"time", Column{DataType: "time"}, "time"},
		// JSON types
		{"json", Column{DataType: "json"}, "dict"},
		{"jsonb", Column{DataType: "jsonb"}, "dict"},
		// Other types
		{"uuid", Column{DataType: "uuid"}, "UUID"},
		{"bytea", Column{DataType: "bytea"}, "bytes"},
		// User-defined (enum)
		{"user-defined", Column{DataType: "USER-DEFINED", UdtName: "status_enum"}, "str"},
		// Unknown type falls back to the catch-all hint.
		{"unknown", Column{DataType: "unknown_type"}, "Any"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := GetPythonType(tt.col)
			if result != tt.expected {
				t.Errorf("GetPythonType(%+v) = %q, want %q", tt.col, result, tt.expected)
			}
		})
	}
}

// TestGetSQLAlchemyType checks the SQLAlchemy column-type expression that
// GetSQLAlchemyType renders, including precision/scale and length arguments
// when the corresponding sql.NullInt64 fields are valid.
func TestGetSQLAlchemyType(t *testing.T) {
	tests := []struct {
		name     string
		col      Column
		expected string
	}{
		// Integer types
		{"integer", Column{DataType: "integer"}, "Integer"},
		{"smallint", Column{DataType: "smallint"}, "SmallInteger"},
		{"bigint", Column{DataType: "bigint"}, "BigInteger"},
		// Numeric types with precision
		{
			"numeric with precision",
			Column{
				DataType:         "numeric",
				NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
				NumericScale:     sql.NullInt64{Valid: true, Int64: 4},
			},
			"Numeric(12, 4)",
		},
		{
			"numeric without precision",
			Column{DataType: "numeric"},
			"Numeric",
		},
		{"real", Column{DataType: "real"}, "Float"},
		{"double precision", Column{DataType: "double precision"}, "Float"},
		// Boolean
		{"boolean", Column{DataType: "boolean"}, "Boolean"},
		// String types
		{
			"varchar with length",
			Column{
				DataType:      "character varying",
				CharMaxLength: sql.NullInt64{Valid: true, Int64: 255},
			},
			"String(255)",
		},
		{
			"varchar without length",
			Column{DataType: "varchar"},
			"String",
		},
		{
			"char with length",
			Column{
				DataType:      "char",
				CharMaxLength: sql.NullInt64{Valid: true, Int64: 10},
			},
			"String(10)",
		},
		{
			// A fixed-width char with no recorded length is expected to
			// render as a single-character string column.
			"char without length",
			Column{DataType: "character"},
			"String(1)",
		},
		{"text", Column{DataType: "text"}, "Text"},
		// Date/Time types
		// NOTE(review): "time with time zone" expects plain "Time", whereas the
		// timestamp-with-tz case expects "DateTime(timezone=True)". Confirm
		// whether Time(timezone=True) was intended.
		{"timestamp with tz", Column{DataType: "timestamp with time zone"}, "DateTime(timezone=True)"},
		{"timestamp without tz", Column{DataType: "timestamp without time zone"}, "DateTime"},
		{"timestamp", Column{DataType: "timestamp"}, "DateTime"},
		{"date", Column{DataType: "date"}, "Date"},
		{"time with tz", Column{DataType: "time with time zone"}, "Time"},
		{"time without tz", Column{DataType: "time without time zone"}, "Time"},
		{"time", Column{DataType: "time"}, "Time"},
		// JSON types
		{"json", Column{DataType: "json"}, "JSON"},
		{"jsonb", Column{DataType: "jsonb"}, "JSONB"},
		// Other types
		{"uuid", Column{DataType: "uuid"}, "UUID"},
		{"bytea", Column{DataType: "bytea"}, "LargeBinary"},
		// User-defined (enum)
		{"user-defined", Column{DataType: "USER-DEFINED"}, "Enum"},
		// Unknown type falls back to a generic string column.
		{"unknown", Column{DataType: "unknown_type"}, "String"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := GetSQLAlchemyType(tt.col)
			if result != tt.expected {
				t.Errorf("GetSQLAlchemyType(%+v) = %q, want %q", tt.col, result, tt.expected)
			}
		})
	}
}

// TestColumn builds a Column literal and verifies a few of its fields read
// back as assigned. This mainly guards the field names and types at compile
// time; it does not exercise any package logic.
func TestColumn(t *testing.T) {
	col := Column{
		Name:             "test_column",
		DataType:         "varchar",
		IsNullable:       true,
		ColumnDefault:    sql.NullString{Valid: true, String: "default_value"},
		CharMaxLength:    sql.NullInt64{Valid: true, Int64: 100},
		NumericPrecision: sql.NullInt64{Valid: false},
		NumericScale:     sql.NullInt64{Valid: false},
		UdtName:          "",
		IsPrimaryKey:     false,
		IsAutoIncrement:  false,
	}

	if col.Name != "test_column" {
		t.Errorf("Expected Name 'test_column', got %q", col.Name)
	}
	if !col.IsNullable {
		t.Error("Expected IsNullable to be true")
	}
	if !col.ColumnDefault.Valid {
		t.Error("Expected ColumnDefault to be valid")
	}
	if col.ColumnDefault.String != "default_value" {
		t.Errorf("Expected ColumnDefault 'default_value', got %q", col.ColumnDefault.String)
	}
}

// TestForeignKey builds a ForeignKey literal and verifies its column and
// referenced-table fields read back as assigned.
func TestForeignKey(t *testing.T) {
	fk := ForeignKey{
		ColumnName:         "user_id",
		ForeignTableSchema: "public",
		ForeignTableName:   "users",
		ForeignColumnName:  "id",
		ConstraintName:     "fk_user_id",
	}

	if fk.ColumnName != "user_id" {
		t.Errorf("Expected ColumnName 'user_id', got %q", fk.ColumnName)
	}
	if fk.ForeignTableName != "users" {
		t.Errorf("Expected ForeignTableName 'users', got %q", fk.ForeignTableName)
	}
	if fk.ForeignColumnName != "id" {
		t.Errorf("Expected ForeignColumnName 'id', got %q", fk.ForeignColumnName)
	}
}

// TestEnumType builds an EnumType literal and verifies its name and the
// order of its values.
func TestEnumType(t *testing.T) {
	expectedValues := []string{"OPEN", "CLOSED", "PENDING"}
	enum := EnumType{
		TypeName: "status_enum",
		Values:   []string{"OPEN", "CLOSED", "PENDING"},
	}

	if enum.TypeName != "status_enum" {
		t.Errorf("Expected TypeName 'status_enum', got %q", enum.TypeName)
	}
	// Stop on a length mismatch: the element loop below indexes enum.Values
	// by position and would otherwise panic or report misleading diffs.
	if len(enum.Values) != len(expectedValues) {
		t.Fatalf("Expected %d values, got %d", len(expectedValues), len(enum.Values))
	}
	for i, want := range expectedValues {
		if got := enum.Values[i]; got != want {
			t.Errorf("Expected value %q at index %d, got %q", want, i, got)
		}
	}
}

// TestTableInfo builds a TableInfo with columns, a foreign key and an enum
// type, checks the collection sizes, and checks that exactly the "id" column
// is flagged as a primary key.
func TestTableInfo(t *testing.T) {
	tableInfo := &TableInfo{
		Schema:    "public",
		TableName: "users",
		Columns: []Column{
			{Name: "id", DataType: "integer", IsPrimaryKey: true},
			{Name: "name", DataType: "varchar"},
		},
		ForeignKeys: []ForeignKey{
			{ColumnName: "company_id", ForeignTableName: "companies"},
		},
		EnumTypes: map[string]EnumType{
			"status_enum": {
				TypeName: "status_enum",
				Values:   []string{"ACTIVE", "INACTIVE"},
			},
		},
	}

	if tableInfo.Schema != "public" {
		t.Errorf("Expected Schema 'public', got %q", tableInfo.Schema)
	}
	if tableInfo.TableName != "users" {
		t.Errorf("Expected TableName 'users', got %q", tableInfo.TableName)
	}
	if len(tableInfo.Columns) != 2 {
		t.Errorf("Expected 2 columns, got %d", len(tableInfo.Columns))
	}
	if len(tableInfo.ForeignKeys) != 1 {
		t.Errorf("Expected 1 foreign key, got %d", len(tableInfo.ForeignKeys))
	}
	if len(tableInfo.EnumTypes) != 1 {
		t.Errorf("Expected 1 enum type, got %d", len(tableInfo.EnumTypes))
	}

	// Test primary key detection
	foundPK := false
	for _, col := range tableInfo.Columns {
		if !col.IsPrimaryKey {
			continue
		}
		foundPK = true
		if col.Name != "id" {
			t.Errorf("Expected primary key to be 'id', got %q", col.Name)
		}
	}
	if !foundPK {
		t.Error("Expected to find primary key column")
	}
}

// TestConfig builds a Config literal and verifies its connection fields read
// back as assigned.
func TestConfig(t *testing.T) {
	cfg := Config{
		Host:     "localhost",
		Port:     5432,
		Database: "testdb",
		Schema:   "public",
		User:     "testuser",
		Password: "testpass",
	}

	if cfg.Host != "localhost" {
		t.Errorf("Expected Host 'localhost', got %q", cfg.Host)
	}
	if cfg.Port != 5432 {
		t.Errorf("Expected Port 5432, got %d", cfg.Port)
	}
	if cfg.Database != "testdb" {
		t.Errorf("Expected Database 'testdb', got %q", cfg.Database)
	}
	if cfg.Schema != "public" {
		t.Errorf("Expected Schema 'public', got %q", cfg.Schema)
	}
}

// Benchmark tests

// BenchmarkGetPythonType measures GetPythonType on a varchar column.
func BenchmarkGetPythonType(b *testing.B) {
	col := Column{DataType: "character varying"}
	b.ReportAllocs()
	b.ResetTimer()

	var r string
	for i := 0; i < b.N; i++ {
		r = GetPythonType(col)
	}
	// Publish the result so the loop body cannot be eliminated as dead code.
	pythonTypeSink = r
}

// BenchmarkGetSQLAlchemyType measures GetSQLAlchemyType on a numeric column
// with precision and scale, which exercises the argument-formatting path.
func BenchmarkGetSQLAlchemyType(b *testing.B) {
	col := Column{
		DataType:         "numeric",
		NumericPrecision: sql.NullInt64{Valid: true, Int64: 12},
		NumericScale:     sql.NullInt64{Valid: true, Int64: 4},
	}
	b.ReportAllocs()
	b.ResetTimer()

	var r string
	for i := 0; i < b.N; i++ {
		r = GetSQLAlchemyType(col)
	}
	// Publish the result so the loop body cannot be eliminated as dead code.
	sqlAlchemyTypeSink = r
}