180 lines
4.2 KiB
Go
180 lines
4.2 KiB
Go
package lib
|
|
|
|
import (
|
|
"database/sql"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/tursodatabase/go-libsql"
|
|
)
|
|
|
|
var DBClient *sql.DB
|
|
var DBConnector *libsql.Connector
|
|
|
|
func InitDB() error {
|
|
dbUrl := os.Getenv("DATABASE_URL")
|
|
dbToken := os.Getenv("DATABASE_AUTH_TOKEN")
|
|
|
|
if dbUrl == "" || dbToken == "" {
|
|
return fmt.Errorf("database configuration missing")
|
|
}
|
|
|
|
// Determine DB path based on /data directory existence
|
|
dbPath := "himbot.db" // default to local
|
|
if _, err := os.Stat("/data"); !os.IsNotExist(err) {
|
|
dbPath = "/data/himbot.db"
|
|
}
|
|
|
|
connector, connectorError := libsql.NewEmbeddedReplicaConnector(
|
|
dbPath,
|
|
dbUrl,
|
|
libsql.WithAuthToken(dbToken),
|
|
)
|
|
|
|
if connectorError != nil {
|
|
fmt.Fprintf(os.Stderr, "failed to open db %s: %s", dbUrl, connectorError)
|
|
os.Exit(1)
|
|
}
|
|
// finalDBUrl := fmt.Sprintf("%s?authToken=%s", dbUrl, dbToken)
|
|
|
|
client := sql.OpenDB(connector)
|
|
|
|
DBClient = client
|
|
DBConnector = connector
|
|
|
|
return runMigrations()
|
|
}
|
|
|
|
// Migration is one versioned schema change loaded from the migrations
// directory.
type Migration struct {
	// Version is the integer parsed from the start of the migration file name.
	Version int

	// Up holds the contents of the .up.sql file and Down the contents of the
	// matching .down.sql file.
	Up, Down string
}
|
|
|
|
func loadMigrations() ([]Migration, error) {
|
|
var migrations []Migration
|
|
migrationFiles, err := filepath.Glob("migrations/*.up.sql")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read migration files: %w", err)
|
|
}
|
|
|
|
for _, upFile := range migrationFiles {
|
|
// Extract version from filename (000001_create_users_table.up.sql -> 1)
|
|
baseName := filepath.Base(upFile)
|
|
version := 0
|
|
fmt.Sscanf(baseName, "%d_", &version)
|
|
|
|
downFile := strings.Replace(upFile, ".up.sql", ".down.sql", 1)
|
|
|
|
upSQL, err := ioutil.ReadFile(upFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read migration file %s: %w", upFile, err)
|
|
}
|
|
|
|
downSQL, err := ioutil.ReadFile(downFile)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read migration file %s: %w", downFile, err)
|
|
}
|
|
|
|
migrations = append(migrations, Migration{
|
|
Version: version,
|
|
Up: string(upSQL),
|
|
Down: string(downSQL),
|
|
})
|
|
}
|
|
|
|
sort.Slice(migrations, func(i, j int) bool {
|
|
return migrations[i].Version < migrations[j].Version
|
|
})
|
|
|
|
return migrations, nil
|
|
}
|
|
|
|
func runMigrations() error {
|
|
// Create migrations table if it doesn't exist
|
|
_, err := DBClient.Exec(`
|
|
CREATE TABLE IF NOT EXISTS schema_migrations (
|
|
version INTEGER PRIMARY KEY,
|
|
applied_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
)`)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create schema_migrations table: %w", err)
|
|
}
|
|
|
|
migrations, err := loadMigrations()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, migration := range migrations {
|
|
var exists bool
|
|
err := DBClient.QueryRow(
|
|
"SELECT EXISTS(SELECT 1 FROM schema_migrations WHERE version = ?)",
|
|
migration.Version).Scan(&exists)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check migration status: %w", err)
|
|
}
|
|
|
|
if !exists {
|
|
tx, err := DBClient.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to start transaction: %w", err)
|
|
}
|
|
|
|
// Run migration
|
|
_, err = tx.Exec(migration.Up)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to apply migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
// Record migration
|
|
_, err = tx.Exec(
|
|
"INSERT INTO schema_migrations (version) VALUES (?)",
|
|
migration.Version)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
return fmt.Errorf("failed to record migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("failed to commit migration %d: %w", migration.Version, err)
|
|
}
|
|
|
|
log.Printf("Applied migration %d", migration.Version)
|
|
}
|
|
}
|
|
|
|
log.Println("Database migrations completed successfully")
|
|
return nil
|
|
}
|
|
|
|
func StoreUser(discordID, username string) error {
|
|
// Check if user exists
|
|
var exists bool
|
|
err := DBClient.QueryRow(
|
|
"SELECT EXISTS(SELECT 1 FROM users WHERE discord_id = ?)",
|
|
discordID).Scan(&exists)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to check user existence: %w", err)
|
|
}
|
|
|
|
// If user doesn't exist, insert them
|
|
if !exists {
|
|
_, err = DBClient.Exec(
|
|
"INSERT INTO users (discord_id, username) VALUES (?, ?)",
|
|
discordID, username)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to store user: %w", err)
|
|
}
|
|
log.Printf("New user stored: %s (%s)", username, discordID)
|
|
}
|
|
|
|
return nil
|
|
}
|