122 lines
2.6 KiB
Go
122 lines
2.6 KiB
Go
package lib
|
|
|
|
import (
|
|
"database/sql"
|
|
"log"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"golang.org/x/crypto/bcrypt"
|
|
)
|
|
|
|
var (
	// DB is the package-wide database handle shared by every helper in this
	// file. It is nil until InitDB assigns it.
	DB *sql.DB
)
|
|
|
|
// User mirrors one row of the users table.
type User struct {
	ID           int    // users.id (INTEGER PRIMARY KEY AUTOINCREMENT)
	Username     string // users.username (UNIQUE, NOT NULL)
	PasswordHash string // users.password_hash: bcrypt output from HashPassword
}
|
|
|
|
// init sqlite db — always creates app.db at project root (run from root)
|
|
func InitDB() {
|
|
var err error
|
|
DB, err = sql.Open("sqlite3", "./app.db")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// make users table
|
|
userTable := `
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
username TEXT UNIQUE NOT NULL,
|
|
password_hash TEXT NOT NULL
|
|
);`
|
|
_, err = DB.Exec(userTable)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// make sessions table
|
|
sessionTable := `
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
id TEXT PRIMARY KEY,
|
|
user_id INTEGER NOT NULL,
|
|
expires_at DATETIME NOT NULL,
|
|
FOREIGN KEY(user_id) REFERENCES users(id)
|
|
);`
|
|
_, err = DB.Exec(sessionTable)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
// hash password
|
|
func HashPassword(password string) (string, error) {
|
|
bytes, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
|
return string(bytes), err
|
|
}
|
|
|
|
// check password
|
|
func CheckPasswordHash(password, hash string) bool {
|
|
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
|
return err == nil
|
|
}
|
|
|
|
// create new user
|
|
func CreateUser(username, password string) error {
|
|
hash, err := HashPassword(password)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = DB.Exec("INSERT INTO users (username, password_hash) VALUES (?, ?)", username, hash)
|
|
return err
|
|
}
|
|
|
|
// get user by name
|
|
func GetUserByUsername(username string) (*User, error) {
|
|
row := DB.QueryRow("SELECT id, username, password_hash FROM users WHERE username = ?", username)
|
|
u := &User{}
|
|
err := row.Scan(&u.ID, &u.Username, &u.PasswordHash)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// create session token
|
|
func CreateSession(userID int) (string, error) {
|
|
sessionID := uuid.New().String()
|
|
expiresAt := time.Now().Add(24 * time.Hour)
|
|
|
|
_, err := DB.Exec("INSERT INTO sessions (id, user_id, expires_at) VALUES (?, ?, ?)", sessionID, userID, expiresAt)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return sessionID, nil
|
|
}
|
|
|
|
// get user from session
|
|
func GetUserFromSession(sessionID string) (*User, error) {
|
|
row := DB.QueryRow(`
|
|
SELECT u.id, u.username, u.password_hash
|
|
FROM users u
|
|
JOIN sessions s ON u.id = s.user_id
|
|
WHERE s.id = ? AND s.expires_at > ?
|
|
`, sessionID, time.Now())
|
|
|
|
u := &User{}
|
|
err := row.Scan(&u.ID, &u.Username, &u.PasswordHash)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return u, nil
|
|
}
|
|
|
|
// delete session
|
|
func DeleteSession(sessionID string) error {
|
|
_, err := DB.Exec("DELETE FROM sessions WHERE id = ?", sessionID)
|
|
return err
|
|
}
|