First working version
This commit is contained in:
341
internal/database/introspector.go
Normal file
341
internal/database/introspector.go
Normal file
@ -0,0 +1,341 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Column represents a database column
|
||||
type Column struct {
|
||||
Name string
|
||||
DataType string
|
||||
IsNullable bool
|
||||
ColumnDefault sql.NullString
|
||||
CharMaxLength sql.NullInt64
|
||||
NumericPrecision sql.NullInt64
|
||||
NumericScale sql.NullInt64
|
||||
UdtName string // User-defined type name (for enums)
|
||||
IsPrimaryKey bool
|
||||
IsAutoIncrement bool
|
||||
}
|
||||
|
||||
// ForeignKey represents a foreign key constraint
|
||||
type ForeignKey struct {
|
||||
ColumnName string
|
||||
ForeignTableSchema string
|
||||
ForeignTableName string
|
||||
ForeignColumnName string
|
||||
ConstraintName string
|
||||
}
|
||||
|
||||
// EnumType represents a PostgreSQL enum type
|
||||
type EnumType struct {
|
||||
TypeName string
|
||||
Values []string
|
||||
}
|
||||
|
||||
// TableInfo contains all information about a table
|
||||
type TableInfo struct {
|
||||
Schema string
|
||||
TableName string
|
||||
Columns []Column
|
||||
ForeignKeys []ForeignKey
|
||||
EnumTypes map[string]EnumType
|
||||
}
|
||||
|
||||
// IntrospectTable retrieves comprehensive information about a table
|
||||
func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) {
|
||||
info := &TableInfo{
|
||||
Schema: c.schema,
|
||||
TableName: tableName,
|
||||
EnumTypes: make(map[string]EnumType),
|
||||
}
|
||||
|
||||
// Get columns
|
||||
columns, err := c.getColumns(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.Columns = columns
|
||||
|
||||
// Get primary keys and mark columns
|
||||
primaryKeys, err := c.getPrimaryKeys(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range info.Columns {
|
||||
for _, pk := range primaryKeys {
|
||||
if info.Columns[i].Name == pk {
|
||||
info.Columns[i].IsPrimaryKey = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Get foreign keys
|
||||
foreignKeys, err := c.getForeignKeys(tableName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.ForeignKeys = foreignKeys
|
||||
|
||||
// Get enum types for columns that use them
|
||||
for _, col := range info.Columns {
|
||||
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
|
||||
if _, exists := info.EnumTypes[col.UdtName]; !exists {
|
||||
enumType, err := c.getEnumType(col.UdtName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info.EnumTypes[col.UdtName] = enumType
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// getColumns retrieves column information for a table
|
||||
func (c *Client) getColumns(tableName string) ([]Column, error) {
|
||||
query := `
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable = 'YES' as is_nullable,
|
||||
column_default,
|
||||
character_maximum_length,
|
||||
numeric_precision,
|
||||
numeric_scale,
|
||||
udt_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = $1 AND table_name = $2
|
||||
ORDER BY ordinal_position
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query columns: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var columns []Column
|
||||
for rows.Next() {
|
||||
var col Column
|
||||
err := rows.Scan(
|
||||
&col.Name,
|
||||
&col.DataType,
|
||||
&col.IsNullable,
|
||||
&col.ColumnDefault,
|
||||
&col.CharMaxLength,
|
||||
&col.NumericPrecision,
|
||||
&col.NumericScale,
|
||||
&col.UdtName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if column is auto-increment
|
||||
if col.ColumnDefault.Valid {
|
||||
defaultVal := strings.ToLower(col.ColumnDefault.String)
|
||||
col.IsAutoIncrement = strings.Contains(defaultVal, "nextval")
|
||||
}
|
||||
|
||||
columns = append(columns, col)
|
||||
}
|
||||
|
||||
return columns, rows.Err()
|
||||
}
|
||||
|
||||
// getPrimaryKeys retrieves primary key column names for a table
|
||||
func (c *Client) getPrimaryKeys(tableName string) ([]string, error) {
|
||||
query := `
|
||||
SELECT a.attname
|
||||
FROM pg_index i
|
||||
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
|
||||
WHERE i.indrelid = ($1 || '.' || $2)::regclass
|
||||
AND i.indisprimary
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query primary keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var keys []string
|
||||
for rows.Next() {
|
||||
var key string
|
||||
if err := rows.Scan(&key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
|
||||
return keys, rows.Err()
|
||||
}
|
||||
|
||||
// getForeignKeys retrieves foreign key information for a table
|
||||
func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) {
|
||||
query := `
|
||||
SELECT
|
||||
kcu.column_name,
|
||||
ccu.table_schema AS foreign_table_schema,
|
||||
ccu.table_name AS foreign_table_name,
|
||||
ccu.column_name AS foreign_column_name,
|
||||
tc.constraint_name
|
||||
FROM information_schema.table_constraints AS tc
|
||||
JOIN information_schema.key_column_usage AS kcu
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
JOIN information_schema.constraint_column_usage AS ccu
|
||||
ON ccu.constraint_name = tc.constraint_name
|
||||
AND ccu.table_schema = tc.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND tc.table_schema = $1
|
||||
AND tc.table_name = $2
|
||||
ORDER BY kcu.ordinal_position
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, c.schema, tableName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to query foreign keys: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var fks []ForeignKey
|
||||
for rows.Next() {
|
||||
var fk ForeignKey
|
||||
err := rows.Scan(
|
||||
&fk.ColumnName,
|
||||
&fk.ForeignTableSchema,
|
||||
&fk.ForeignTableName,
|
||||
&fk.ForeignColumnName,
|
||||
&fk.ConstraintName,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fks = append(fks, fk)
|
||||
}
|
||||
|
||||
return fks, rows.Err()
|
||||
}
|
||||
|
||||
// getEnumType retrieves enum values for a PostgreSQL enum type
|
||||
func (c *Client) getEnumType(typeName string) (EnumType, error) {
|
||||
query := `
|
||||
SELECT e.enumlabel
|
||||
FROM pg_type t
|
||||
JOIN pg_enum e ON t.oid = e.enumtypid
|
||||
WHERE t.typname = $1
|
||||
ORDER BY e.enumsortorder
|
||||
`
|
||||
|
||||
rows, err := c.db.Query(query, typeName)
|
||||
if err != nil {
|
||||
return EnumType{}, fmt.Errorf("failed to query enum type: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var values []string
|
||||
for rows.Next() {
|
||||
var value string
|
||||
if err := rows.Scan(&value); err != nil {
|
||||
return EnumType{}, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
return EnumType{
|
||||
TypeName: typeName,
|
||||
Values: values,
|
||||
}, rows.Err()
|
||||
}
|
||||
|
||||
// GetPythonType converts PostgreSQL data type to Python type
|
||||
func GetPythonType(col Column) string {
|
||||
switch col.DataType {
|
||||
case "integer", "smallint", "bigint":
|
||||
return "int"
|
||||
case "numeric", "decimal", "real", "double precision":
|
||||
return "Decimal"
|
||||
case "boolean":
|
||||
return "bool"
|
||||
case "character varying", "varchar", "text", "char", "character":
|
||||
return "str"
|
||||
case "timestamp with time zone", "timestamp without time zone", "timestamp":
|
||||
return "datetime"
|
||||
case "date":
|
||||
return "date"
|
||||
case "time with time zone", "time without time zone", "time":
|
||||
return "time"
|
||||
case "json", "jsonb":
|
||||
return "dict"
|
||||
case "uuid":
|
||||
return "UUID"
|
||||
case "bytea":
|
||||
return "bytes"
|
||||
case "USER-DEFINED":
|
||||
// This is likely an enum
|
||||
return "str" // Will be replaced with specific enum type in generator
|
||||
default:
|
||||
return "Any"
|
||||
}
|
||||
}
|
||||
|
||||
// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type
|
||||
func GetSQLAlchemyType(col Column) string {
|
||||
switch col.DataType {
|
||||
case "integer":
|
||||
return "Integer"
|
||||
case "smallint":
|
||||
return "SmallInteger"
|
||||
case "bigint":
|
||||
return "BigInteger"
|
||||
case "numeric", "decimal":
|
||||
if col.NumericPrecision.Valid && col.NumericScale.Valid {
|
||||
return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64)
|
||||
}
|
||||
return "Numeric"
|
||||
case "real":
|
||||
return "Float"
|
||||
case "double precision":
|
||||
return "Float"
|
||||
case "boolean":
|
||||
return "Boolean"
|
||||
case "character varying", "varchar":
|
||||
if col.CharMaxLength.Valid {
|
||||
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||
}
|
||||
return "String"
|
||||
case "char", "character":
|
||||
if col.CharMaxLength.Valid {
|
||||
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
|
||||
}
|
||||
return "String(1)"
|
||||
case "text":
|
||||
return "Text"
|
||||
case "timestamp with time zone":
|
||||
return "DateTime(timezone=True)"
|
||||
case "timestamp without time zone", "timestamp":
|
||||
return "DateTime"
|
||||
case "date":
|
||||
return "Date"
|
||||
case "time with time zone", "time without time zone", "time":
|
||||
return "Time"
|
||||
case "json":
|
||||
return "JSON"
|
||||
case "jsonb":
|
||||
return "JSONB"
|
||||
case "uuid":
|
||||
return "UUID"
|
||||
case "bytea":
|
||||
return "LargeBinary"
|
||||
case "USER-DEFINED":
|
||||
// Will be handled separately for enums
|
||||
return "Enum"
|
||||
default:
|
||||
return "String"
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user