himbot/lib/db.go
Atridad Lahiji 3af77b6bea
All checks were successful
Docker Deploy / build-and-push (push) Successful in 1m26s
no
2024-11-22 17:34:41 -06:00

142 lines
3.2 KiB
Go

package lib
import (
"database/sql"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
"strings"
"github.com/tursodatabase/go-libsql"
)
var (
	// DBClient is the package-wide database handle. Within this file it is
	// assigned only by InitDB, which opens it and then runs migrations on it.
	DBClient *sql.DB

	// DBConnector holds a libsql connector.
	// NOTE(review): nothing in this file assigns it — confirm whether other
	// files set or read it before removing.
	DBConnector *libsql.Connector
)
// InitDB opens the bot's database through the "libsql" database/sql driver,
// stores the handle in DBClient, and applies any pending schema migrations.
//
// The database file is /data/himbot.db unless the /data directory does not
// exist, in which case ./himbot.db (relative to the working directory) is used.
//
// Errors from opening the database or from running migrations are returned
// to the caller; InitDB never terminates the process itself.
func InitDB() error {
	// Prefer the /data location (if present); fall back to the working directory.
	dbPath := "file:/data/himbot.db"
	if _, err := os.Stat("/data"); os.IsNotExist(err) {
		dbPath = "file:./himbot.db"
	}

	db, err := sql.Open("libsql", dbPath)
	if err != nil {
		// Return instead of os.Exit so the caller decides how to handle it;
		// the signature already promises an error.
		return fmt.Errorf("failed to open db %s: %w", dbPath, err)
	}

	DBClient = db
	return runMigrations()
}
// Migration is one versioned schema change, built by loadMigrations from a
// migrations/<version>_<name>.up.sql file and its matching .down.sql file.
type Migration struct {
	// Version is the numeric prefix of the file name. runMigrations applies
	// migrations in ascending Version order and records each applied Version.
	Version int

	// Up is the SQL that applies the change; Down is the SQL from the
	// matching .down.sql file.
	// NOTE(review): Down is not executed by any code in this file.
	Up, Down string
}
// loadMigrations reads every migration from the migrations/ directory
// (resolved relative to the process's current working directory) and returns
// them sorted by ascending Version.
//
// A migration is a pair of files, <version>_<name>.up.sql and
// <version>_<name>.down.sql, where <version> is a decimal integer such as
// 000001. An error is returned if the glob fails, if either file of a pair
// cannot be read, if an .up.sql file name does not begin with "<integer>_",
// or if two .up.sql files share the same version. A missing or empty
// migrations/ directory yields an empty result and no error.
func loadMigrations() ([]Migration, error) {
	upFiles, err := filepath.Glob("migrations/*.up.sql")
	if err != nil {
		return nil, fmt.Errorf("failed to read migration files: %w", err)
	}

	migrations := make([]Migration, 0, len(upFiles))
	// seen maps each version to the file that claimed it. Without this check,
	// a second file with the same version would be skipped silently once the
	// first had been recorded in schema_migrations.
	seen := make(map[int]string, len(upFiles))

	for _, upFile := range upFiles {
		// Extract version from filename (000001_create_users_table.up.sql -> 1).
		// The Sscanf error is checked so a malformed name is rejected instead
		// of silently becoming version 0.
		var version int
		if _, err := fmt.Sscanf(filepath.Base(upFile), "%d_", &version); err != nil {
			return nil, fmt.Errorf("invalid migration file name %s (want <version>_<name>.up.sql): %w", upFile, err)
		}
		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s", version, prev, upFile)
		}
		seen[version] = upFile

		// Swap only the suffix; Glob guarantees upFile ends in ".up.sql".
		downFile := strings.TrimSuffix(upFile, ".up.sql") + ".down.sql"

		upSQL, err := ioutil.ReadFile(upFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", upFile, err)
		}
		downSQL, err := ioutil.ReadFile(downFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", downFile, err)
		}

		migrations = append(migrations, Migration{
			Version: version,
			Up:      string(upSQL),
			Down:    string(downSQL),
		})
	}

	// Glob returns names in lexical order ("10_" before "2_"); order numerically.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
	return migrations, nil
}
func runMigrations() error {
// Create migrations table if it doesn't exist
_, err := DBClient.Exec(`
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
return fmt.Errorf("failed to create schema_migrations table: %w", err)
}
migrations, err := loadMigrations()
if err != nil {
return err
}
for _, migration := range migrations {
var exists bool
err := DBClient.QueryRow(
"SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = ?)",
migration.Version).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if !exists {
tx, err := DBClient.Begin()
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
// Run migration
_, err = tx.Exec(migration.Up)
if err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d: %w", migration.Version, err)
}
// Record migration
_, err = tx.Exec(
"INSERT INTO schema_migrations (version) VALUES (?)",
migration.Version)
if err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
}
log.Printf("Applied migration %d", migration.Version)
}
}
log.Println("Database migrations completed successfully")
return nil
}