himbot/lib/db.go
2024-11-04 11:14:43 -06:00

167 lines
3.8 KiB
Go

package lib
import (
"database/sql"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"sort"
"strings"
_ "github.com/tursodatabase/libsql-client-go/libsql"
)
// DBClient is the package-wide database handle. It is nil until InitDB has
// opened the connection; the query helpers in this package use it directly.
var DBClient *sql.DB

// InitDB opens the libSQL database described by the DATABASE_URL and
// DATABASE_AUTH_TOKEN environment variables, stores the handle in DBClient,
// and then applies any pending schema migrations.
//
// It returns an error (it never exits the process) when either variable is
// unset or empty, when the connection cannot be opened, or when applying the
// migrations fails.
func InitDB() error {
	dbURL := os.Getenv("DATABASE_URL")
	dbToken := os.Getenv("DATABASE_AUTH_TOKEN")
	if dbURL == "" || dbToken == "" {
		return fmt.Errorf("database configuration missing")
	}

	client, err := sql.Open("libsql", dsnWithAuthToken(dbURL, dbToken))
	if err != nil {
		// Report only the base URL: the full DSN embeds the auth token.
		return fmt.Errorf("failed to open db %s: %w", dbURL, err)
	}
	DBClient = client
	return runMigrations()
}

// dsnWithAuthToken appends token to dbURL as the "authToken" query
// parameter, joining with "&" when dbURL already carries a query string and
// with "?" otherwise.
// NOTE(review): the token is not percent-escaped; this assumes it contains
// only URL-safe characters (e.g. a base64url JWT) — confirm.
func dsnWithAuthToken(dbURL, token string) string {
	sep := "?"
	if strings.Contains(dbURL, "?") {
		sep = "&"
	}
	return dbURL + sep + "authToken=" + token
}
// Migration is one versioned schema change read from the migrations
// directory. Up holds the contents of the "<version>_<name>.up.sql" file and
// Down the contents of the matching ".down.sql" file.
type Migration struct {
	Version  int // numeric prefix of the file name; migrations run in ascending order
	Up, Down string
}
// loadMigrations reads every "*.up.sql" file in the "migrations" directory
// (resolved relative to the process working directory) together with its
// matching "*.down.sql" file, and returns them sorted by ascending version.
func loadMigrations() ([]Migration, error) {
	return loadMigrationsFrom("migrations")
}

// loadMigrationsFrom loads the migrations stored in dir. Files must be named
// "<version>_<name>.up.sql" with a sibling "<version>_<name>.down.sql".
//
// It returns an error if globbing fails, if a file name has no leading
// integer version followed by "_", if two up files share a version (the
// second would otherwise be silently skipped once the first is recorded),
// or if either file of a pair cannot be read.
func loadMigrationsFrom(dir string) ([]Migration, error) {
	upFiles, err := filepath.Glob(filepath.Join(dir, "*.up.sql"))
	if err != nil {
		return nil, fmt.Errorf("failed to read migration files: %w", err)
	}

	migrations := make([]Migration, 0, len(upFiles))
	seen := make(map[int]string, len(upFiles))
	for _, upFile := range upFiles {
		// Extract version from filename (000001_create_users_table.up.sql -> 1).
		// A parse failure must be an error; ignoring it would yield version 0.
		var version int
		if _, err := fmt.Sscanf(filepath.Base(upFile), "%d_", &version); err != nil {
			return nil, fmt.Errorf("failed to parse version from migration file %s (want <version>_<name>.up.sql): %w", upFile, err)
		}
		if prev, dup := seen[version]; dup {
			return nil, fmt.Errorf("duplicate migration version %d: %s and %s", version, prev, upFile)
		}
		seen[version] = upFile

		// Glob guarantees the ".up.sql" suffix, so swap only that suffix.
		downFile := strings.TrimSuffix(upFile, ".up.sql") + ".down.sql"

		upSQL, err := ioutil.ReadFile(upFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", upFile, err)
		}
		downSQL, err := ioutil.ReadFile(downFile)
		if err != nil {
			return nil, fmt.Errorf("failed to read migration file %s: %w", downFile, err)
		}

		migrations = append(migrations, Migration{
			Version: version,
			Up:      string(upSQL),
			Down:    string(downSQL),
		})
	}

	// Glob returns lexical order ("10_" before "2_"); order numerically instead.
	sort.Slice(migrations, func(i, j int) bool {
		return migrations[i].Version < migrations[j].Version
	})
	return migrations, nil
}
func runMigrations() error {
// Create migrations table if it doesn't exist
_, err := DBClient.Exec(`
CREATE TABLE IF NOT EXISTS schema_migrations (
version INTEGER PRIMARY KEY,
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
)`)
if err != nil {
return fmt.Errorf("failed to create schema_migrations table: %w", err)
}
migrations, err := loadMigrations()
if err != nil {
return err
}
for _, migration := range migrations {
var exists bool
err := DBClient.QueryRow(
"SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = ?)",
migration.Version).Scan(&exists)
if err != nil {
return fmt.Errorf("failed to check migration status: %w", err)
}
if !exists {
tx, err := DBClient.Begin()
if err != nil {
return fmt.Errorf("failed to start transaction: %w", err)
}
// Run migration
_, err = tx.Exec(migration.Up)
if err != nil {
tx.Rollback()
return fmt.Errorf("failed to apply migration %d: %w", migration.Version, err)
}
// Record migration
_, err = tx.Exec(
"INSERT INTO schema_migrations (version) VALUES (?)",
migration.Version)
if err != nil {
tx.Rollback()
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
}
log.Printf("Applied migration %d", migration.Version)
}
}
log.Println("Database migrations completed successfully")
return nil
}
// StoreUser records a Discord user in the users table the first time it is
// seen. If a row with the same discord_id already exists, nothing is written
// (the stored username is not updated) and nil is returned.
//
// The existence check and the insert are one INSERT ... SELECT ... WHERE NOT
// EXISTS statement instead of a separate SELECT followed by an INSERT. A
// single statement is atomic in SQLite/libSQL, so two concurrent calls for the
// same new user cannot both pass the check. With a separate SELECT and INSERT
// they could, producing a duplicate row or a constraint error depending on the
// schema.
func StoreUser(discordID, username string) error {
	res, err := DBClient.Exec(
		`INSERT INTO users (discord_id, username)
SELECT ?, ?
WHERE NOT EXISTS (SELECT 1 FROM users WHERE discord_id = ?)`,
		discordID, username, discordID)
	if err != nil {
		return fmt.Errorf("failed to store user: %w", err)
	}

	// Log only when a row was actually inserted. If the driver cannot report
	// the affected-row count, skip the log rather than failing a write that
	// has already succeeded.
	if n, err := res.RowsAffected(); err == nil && n > 0 {
		log.Printf("New user stored: %s (%s)", username, discordID)
	}
	return nil
}