Files
sprintpadawan/lib/db.go
T
2026-04-23 22:33:33 -06:00

122 lines
2.6 KiB
Go

package lib
import (
"database/sql"
"log"
"time"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/bcrypt"
)
// DB is the package-level database handle used by the query helpers in this
// file. It is nil until InitDB assigns it.
var DB *sql.DB
// User holds the values scanned from one row of the users table.
type User struct {
// ID receives the users.id column.
ID int
// Username receives the users.username column.
Username string
// PasswordHash receives the users.password_hash column.
PasswordHash string
}
// InitDB opens the SQLite database file "./app.db", stores the handle in the
// package-level DB variable, and makes sure the users and sessions tables
// exist.
//
// The path is relative, so the file is resolved against the working directory
// of the running process. Every failure is fatal: the failing step and the
// underlying error are logged and the process exits through log.Fatalf.
func InitDB() {
	var err error
	DB, err = sql.Open("sqlite3", "./app.db")
	if err != nil {
		log.Fatalf("opening sqlite database: %v", err)
	}

	// sql.Open may only validate its arguments without connecting. Ping forces
	// a real connection so an unusable database file is reported as such,
	// instead of surfacing later as a confusing "create table" failure.
	if err = DB.Ping(); err != nil {
		log.Fatalf("connecting to sqlite database: %v", err)
	}

	// Schema statements, applied in order. sessions is listed after users
	// because it declares a foreign key that references users(id).
	tables := []struct {
		name, ddl string
	}{
		{
			name: "users",
			ddl: `
	CREATE TABLE IF NOT EXISTS users (
		id INTEGER PRIMARY KEY AUTOINCREMENT,
		username TEXT UNIQUE NOT NULL,
		password_hash TEXT NOT NULL
	);`,
		},
		{
			name: "sessions",
			ddl: `
	CREATE TABLE IF NOT EXISTS sessions (
		id TEXT PRIMARY KEY,
		user_id INTEGER NOT NULL,
		expires_at DATETIME NOT NULL,
		FOREIGN KEY(user_id) REFERENCES users(id)
	);`,
		},
	}
	for _, tbl := range tables {
		if _, err = DB.Exec(tbl.ddl); err != nil {
			log.Fatalf("creating %s table: %v", tbl.name, err)
		}
	}
}
// HashPassword returns the bcrypt hash of password, computed with a cost
// factor of 14, as a string. If bcrypt reports an error, the returned string
// is empty and the error is passed through unchanged.
func HashPassword(password string) (string, error) {
	const cost = 14

	hashed, err := bcrypt.GenerateFromPassword([]byte(password), cost)
	if err != nil {
		return "", err
	}
	return string(hashed), nil
}
// CheckPasswordHash reports whether password matches the bcrypt hash. Any
// error returned by the comparison is reported as false.
func CheckPasswordHash(password, hash string) bool {
	return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil
}
// CreateUser hashes password with HashPassword and inserts a new row into the
// users table holding username and the resulting hash.
//
// It returns the hashing error if hashing fails (no row is inserted in that
// case), otherwise whatever error the INSERT reports, or nil on success.
func CreateUser(username, password string) error {
	const insertUser = "INSERT INTO users (username, password_hash) VALUES (?, ?)"

	hashed, hashErr := HashPassword(password)
	if hashErr != nil {
		return hashErr
	}

	if _, execErr := DB.Exec(insertUser, username, hashed); execErr != nil {
		return execErr
	}
	return nil
}
// GetUserByUsername selects the users row whose username equals the given
// value and returns it as a *User.
//
// If scanning the row fails, the result is nil and the Scan error is returned
// unchanged; database/sql reports a missing row as sql.ErrNoRows.
func GetUserByUsername(username string) (*User, error) {
	const selectByName = "SELECT id, username, password_hash FROM users WHERE username = ?"

	var found User
	if err := DB.QueryRow(selectByName, username).Scan(&found.ID, &found.Username, &found.PasswordHash); err != nil {
		return nil, err
	}
	return &found, nil
}
// CreateSession inserts a new row into the sessions table for userID and
// returns the new session's ID.
//
// The ID is a randomly generated UUID in string form, and the row's
// expires_at is set to 24 hours after the current time. On any failure the
// returned ID is empty and the error is passed through unchanged.
func CreateSession(userID int) (string, error) {
	const (
		sessionTTL    = 24 * time.Hour
		insertSession = "INSERT INTO sessions (id, user_id, expires_at) VALUES (?, ?, ?)"
	)

	// NewRandom reports a failing random source as an error; uuid.New would
	// panic instead, which callers of this function cannot handle.
	id, err := uuid.NewRandom()
	if err != nil {
		return "", err
	}
	sessionID := id.String()

	expiresAt := time.Now().Add(sessionTTL)
	if _, err := DB.Exec(insertSession, sessionID, userID, expiresAt); err != nil {
		return "", err
	}
	return sessionID, nil
}
// GetUserFromSession returns the user that owns the session with the given
// ID, considering only sessions whose expires_at is greater than the current
// time passed to the query.
//
// If scanning the row fails, the result is nil and the Scan error is returned
// unchanged; database/sql reports a missing row as sql.ErrNoRows.
func GetUserFromSession(sessionID string) (*User, error) {
	// Joins sessions to users and filters out expired sessions in SQL.
	const selectOwner = `
SELECT u.id, u.username, u.password_hash
FROM users u
JOIN sessions s ON u.id = s.user_id
WHERE s.id = ? AND s.expires_at > ?
`

	var owner User
	lookup := DB.QueryRow(selectOwner, sessionID, time.Now())
	if err := lookup.Scan(&owner.ID, &owner.Username, &owner.PasswordHash); err != nil {
		return nil, err
	}
	return &owner, nil
}
// DeleteSession issues a DELETE for the sessions row with the given ID and
// returns the error from executing that statement, or nil on success.
func DeleteSession(sessionID string) error {
	const deleteByID = "DELETE FROM sessions WHERE id = ?"

	if _, err := DB.Exec(deleteByID, sessionID); err != nil {
		return err
	}
	return nil
}