First working version

This commit is contained in:
Eden Kirin
2025-10-31 14:36:41 +01:00
parent f8edfc0fc1
commit f9f67b6c93
22 changed files with 2277 additions and 0 deletions

View File

@ -0,0 +1,67 @@
package database
import (
"database/sql"
"fmt"
_ "github.com/lib/pq"
)
// Client represents a database connection
type Client struct {
db *sql.DB
schema string
}
// Config holds database connection parameters
type Config struct {
Host string
Port int
Database string
Schema string
User string
Password string
}
// NewClient creates a new database client
func NewClient(cfg Config) (*Client, error) {
connStr := fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
cfg.Host, cfg.Port, cfg.User, cfg.Password, cfg.Database,
)
db, err := sql.Open("postgres", connStr)
if err != nil {
return nil, fmt.Errorf("failed to open database: %w", err)
}
if err := db.Ping(); err != nil {
db.Close()
return nil, fmt.Errorf("failed to connect to database: %w", err)
}
return &Client{
db: db,
schema: cfg.Schema,
}, nil
}
// Close closes the database connection
func (c *Client) Close() error {
return c.db.Close()
}
// TableExists checks if a table exists in the schema
func (c *Client) TableExists(tableName string) (bool, error) {
query := `
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = $1 AND table_name = $2
)
`
var exists bool
err := c.db.QueryRow(query, c.schema, tableName).Scan(&exists)
return exists, err
}

View File

@ -0,0 +1,341 @@
package database
import (
"database/sql"
"fmt"
"strings"
)
// Column represents a database column
type Column struct {
Name string
DataType string
IsNullable bool
ColumnDefault sql.NullString
CharMaxLength sql.NullInt64
NumericPrecision sql.NullInt64
NumericScale sql.NullInt64
UdtName string // User-defined type name (for enums)
IsPrimaryKey bool
IsAutoIncrement bool
}
// ForeignKey represents a foreign key constraint
type ForeignKey struct {
ColumnName string
ForeignTableSchema string
ForeignTableName string
ForeignColumnName string
ConstraintName string
}
// EnumType represents a PostgreSQL enum type
type EnumType struct {
TypeName string
Values []string
}
// TableInfo contains all information about a table
type TableInfo struct {
Schema string
TableName string
Columns []Column
ForeignKeys []ForeignKey
EnumTypes map[string]EnumType
}
// IntrospectTable retrieves comprehensive information about a table
func (c *Client) IntrospectTable(tableName string) (*TableInfo, error) {
info := &TableInfo{
Schema: c.schema,
TableName: tableName,
EnumTypes: make(map[string]EnumType),
}
// Get columns
columns, err := c.getColumns(tableName)
if err != nil {
return nil, err
}
info.Columns = columns
// Get primary keys and mark columns
primaryKeys, err := c.getPrimaryKeys(tableName)
if err != nil {
return nil, err
}
for i := range info.Columns {
for _, pk := range primaryKeys {
if info.Columns[i].Name == pk {
info.Columns[i].IsPrimaryKey = true
}
}
}
// Get foreign keys
foreignKeys, err := c.getForeignKeys(tableName)
if err != nil {
return nil, err
}
info.ForeignKeys = foreignKeys
// Get enum types for columns that use them
for _, col := range info.Columns {
if col.DataType == "USER-DEFINED" && col.UdtName != "" {
if _, exists := info.EnumTypes[col.UdtName]; !exists {
enumType, err := c.getEnumType(col.UdtName)
if err != nil {
return nil, err
}
info.EnumTypes[col.UdtName] = enumType
}
}
}
return info, nil
}
// getColumns retrieves column information for a table
func (c *Client) getColumns(tableName string) ([]Column, error) {
query := `
SELECT
column_name,
data_type,
is_nullable = 'YES' as is_nullable,
column_default,
character_maximum_length,
numeric_precision,
numeric_scale,
udt_name
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query columns: %w", err)
}
defer rows.Close()
var columns []Column
for rows.Next() {
var col Column
err := rows.Scan(
&col.Name,
&col.DataType,
&col.IsNullable,
&col.ColumnDefault,
&col.CharMaxLength,
&col.NumericPrecision,
&col.NumericScale,
&col.UdtName,
)
if err != nil {
return nil, err
}
// Check if column is auto-increment
if col.ColumnDefault.Valid {
defaultVal := strings.ToLower(col.ColumnDefault.String)
col.IsAutoIncrement = strings.Contains(defaultVal, "nextval")
}
columns = append(columns, col)
}
return columns, rows.Err()
}
// getPrimaryKeys retrieves primary key column names for a table
func (c *Client) getPrimaryKeys(tableName string) ([]string, error) {
query := `
SELECT a.attname
FROM pg_index i
JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
WHERE i.indrelid = ($1 || '.' || $2)::regclass
AND i.indisprimary
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query primary keys: %w", err)
}
defer rows.Close()
var keys []string
for rows.Next() {
var key string
if err := rows.Scan(&key); err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, rows.Err()
}
// getForeignKeys retrieves foreign key information for a table
func (c *Client) getForeignKeys(tableName string) ([]ForeignKey, error) {
query := `
SELECT
kcu.column_name,
ccu.table_schema AS foreign_table_schema,
ccu.table_name AS foreign_table_name,
ccu.column_name AS foreign_column_name,
tc.constraint_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu
ON tc.constraint_name = kcu.constraint_name
AND tc.table_schema = kcu.table_schema
JOIN information_schema.constraint_column_usage AS ccu
ON ccu.constraint_name = tc.constraint_name
AND ccu.table_schema = tc.table_schema
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_schema = $1
AND tc.table_name = $2
ORDER BY kcu.ordinal_position
`
rows, err := c.db.Query(query, c.schema, tableName)
if err != nil {
return nil, fmt.Errorf("failed to query foreign keys: %w", err)
}
defer rows.Close()
var fks []ForeignKey
for rows.Next() {
var fk ForeignKey
err := rows.Scan(
&fk.ColumnName,
&fk.ForeignTableSchema,
&fk.ForeignTableName,
&fk.ForeignColumnName,
&fk.ConstraintName,
)
if err != nil {
return nil, err
}
fks = append(fks, fk)
}
return fks, rows.Err()
}
// getEnumType retrieves enum values for a PostgreSQL enum type
func (c *Client) getEnumType(typeName string) (EnumType, error) {
query := `
SELECT e.enumlabel
FROM pg_type t
JOIN pg_enum e ON t.oid = e.enumtypid
WHERE t.typname = $1
ORDER BY e.enumsortorder
`
rows, err := c.db.Query(query, typeName)
if err != nil {
return EnumType{}, fmt.Errorf("failed to query enum type: %w", err)
}
defer rows.Close()
var values []string
for rows.Next() {
var value string
if err := rows.Scan(&value); err != nil {
return EnumType{}, err
}
values = append(values, value)
}
return EnumType{
TypeName: typeName,
Values: values,
}, rows.Err()
}
// GetPythonType converts PostgreSQL data type to Python type
func GetPythonType(col Column) string {
switch col.DataType {
case "integer", "smallint", "bigint":
return "int"
case "numeric", "decimal", "real", "double precision":
return "Decimal"
case "boolean":
return "bool"
case "character varying", "varchar", "text", "char", "character":
return "str"
case "timestamp with time zone", "timestamp without time zone", "timestamp":
return "datetime"
case "date":
return "date"
case "time with time zone", "time without time zone", "time":
return "time"
case "json", "jsonb":
return "dict"
case "uuid":
return "UUID"
case "bytea":
return "bytes"
case "USER-DEFINED":
// This is likely an enum
return "str" // Will be replaced with specific enum type in generator
default:
return "Any"
}
}
// GetSQLAlchemyType converts PostgreSQL data type to SQLAlchemy type
func GetSQLAlchemyType(col Column) string {
switch col.DataType {
case "integer":
return "Integer"
case "smallint":
return "SmallInteger"
case "bigint":
return "BigInteger"
case "numeric", "decimal":
if col.NumericPrecision.Valid && col.NumericScale.Valid {
return fmt.Sprintf("Numeric(%d, %d)", col.NumericPrecision.Int64, col.NumericScale.Int64)
}
return "Numeric"
case "real":
return "Float"
case "double precision":
return "Float"
case "boolean":
return "Boolean"
case "character varying", "varchar":
if col.CharMaxLength.Valid {
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
}
return "String"
case "char", "character":
if col.CharMaxLength.Valid {
return fmt.Sprintf("String(%d)", col.CharMaxLength.Int64)
}
return "String(1)"
case "text":
return "Text"
case "timestamp with time zone":
return "DateTime(timezone=True)"
case "timestamp without time zone", "timestamp":
return "DateTime"
case "date":
return "Date"
case "time with time zone", "time without time zone", "time":
return "Time"
case "json":
return "JSON"
case "jsonb":
return "JSONB"
case "uuid":
return "UUID"
case "bytea":
return "LargeBinary"
case "USER-DEFINED":
// Will be handled separately for enums
return "Enum"
default:
return "String"
}
}