Files
teachat/internal/database.go
2025-12-28 00:30:12 -07:00

269 lines
7.0 KiB
Go

package internal
import (
"crypto/rand"
"database/sql"
"encoding/base64"
"io"
"log"
"time"
_ "github.com/mattn/go-sqlite3"
)
const (
	// dbFile is the SQLite database file opened by InitDB. It is a relative
	// path, so it resolves against the process's working directory.
	dbFile = "teachat.db"
)

var (
	// db is the package-wide database handle shared by every helper in this
	// file. It stays nil until InitDB assigns it.
	db *sql.DB
)
// InitDB opens the SQLite database at dbFile and stores the handle in the
// package-level db. It then creates any missing tables and applies the
// migrations needed by databases written by older versions. Every failure is
// fatal: the process exits via log.Fatal.
func InitDB() {
	var err error
	db, err = sql.Open("sqlite3", dbFile)
	if err != nil {
		log.Fatal(err)
	}
	// sql.Open only validates its arguments. Ping opens a real connection, so
	// an unusable database file is reported here rather than at the first
	// statement.
	if err := db.Ping(); err != nil {
		log.Fatal(err)
	}
	query := `
CREATE TABLE IF NOT EXISTS config (key TEXT PRIMARY KEY, val TEXT);
CREATE TABLE IF NOT EXISTS users (
pubkey TEXT PRIMARY KEY,
username TEXT UNIQUE,
identity_key BLOB,
prekey BLOB,
prekey_signature BLOB,
created_at DATETIME
);
CREATE TABLE IF NOT EXISTS rooms (
id TEXT PRIMARY KEY,
name TEXT,
creator TEXT,
is_dm BOOLEAN,
created_at DATETIME
);
CREATE TABLE IF NOT EXISTS room_members (
room_id TEXT,
username TEXT,
shared_secret BLOB,
joined_at DATETIME,
PRIMARY KEY (room_id, username),
FOREIGN KEY (room_id) REFERENCES rooms(id),
FOREIGN KEY (username) REFERENCES users(username)
);
CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY,
room_id TEXT,
timestamp DATETIME,
sender TEXT,
ciphertext BLOB,
nonce BLOB,
FOREIGN KEY (room_id) REFERENCES rooms(id)
);
`
	if _, err := db.Exec(query); err != nil {
		log.Fatal(err)
	}
	// Migration for older DBs: CREATE TABLE IF NOT EXISTS leaves an existing
	// rooms table unchanged, so a table created before the creator column
	// existed has to be altered explicitly.
	if err := addColumnIfMissing("rooms", "creator", "TEXT"); err != nil {
		log.Fatalf("migrating rooms.creator: %v", err)
	}
}

// addColumnIfMissing adds column (of SQL type colType) to table unless the
// table already has a column with that name. The old approach issued the
// ALTER unconditionally and discarded its error. Checking first means a
// genuine failure (e.g. a locked or read-only database) is reported instead of
// being confused with "column already exists".
//
// table, column and colType are spliced into DDL text and must be trusted
// constants, never user input.
func addColumnIfMissing(table, column, colType string) error {
	var n int
	err := db.QueryRow(
		"SELECT COUNT(*) FROM pragma_table_info(?) WHERE name = ?",
		table, column,
	).Scan(&n)
	if err != nil {
		return err
	}
	if n > 0 {
		return nil
	}
	_, err = db.Exec("ALTER TABLE " + table + " ADD COLUMN " + column + " " + colType)
	return err
}
// ensureSalt returns the global salt stored under key 'global_salt' in the
// config table. On first use, when no such row exists, it generates saltSize
// random bytes, stores them base64-encoded and returns them.
//
// Every database, randomness or decoding failure is fatal (log.Fatal).
// Returning a salt that was never persisted would make the next start find no
// row and generate a different salt. Returning a partially decoded value
// would silently hand callers the wrong salt.
func ensureSalt() []byte {
	var b64 string
	err := db.QueryRow("SELECT val FROM config WHERE key='global_salt'").Scan(&b64)
	switch {
	case err == sql.ErrNoRows:
		salt := make([]byte, saltSize)
		if _, err := io.ReadFull(rand.Reader, salt); err != nil {
			log.Fatal(err)
		}
		b64 = base64.StdEncoding.EncodeToString(salt)
		// The salt must be durable before anyone uses it.
		if _, err := db.Exec("INSERT INTO config (key, val) VALUES ('global_salt', ?)", b64); err != nil {
			log.Fatal(err)
		}
		return salt
	case err != nil:
		log.Fatal(err)
	}
	salt, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		log.Fatalf("decoding global_salt from config: %v", err)
	}
	return salt
}
// authorizeUser inserts a users row binding pubkey to username, together with
// the user's identity key, prekey and prekey signature, and sets created_at to
// the current time. Because pubkey is the primary key and username is unique,
// registering either a second time makes the insert fail. That error is
// returned unchanged.
func authorizeUser(pubkey, username string, identityKey, prekey, prekeySignature []byte) error {
	const stmt = "INSERT INTO users (pubkey, username, identity_key, prekey, prekey_signature, created_at) VALUES (?, ?, ?, ?, ?, ?)"
	registeredAt := time.Now()
	if _, err := db.Exec(stmt, pubkey, username, identityKey, prekey, prekeySignature, registeredAt); err != nil {
		return err
	}
	return nil
}
// getUsername returns the username registered for pubkey. If no user has that
// key, the error is sql.ErrNoRows and the name is the empty string.
func getUsername(pubkey string) (string, error) {
	row := db.QueryRow("SELECT username FROM users WHERE pubkey = ?", pubkey)
	var name string
	if err := row.Scan(&name); err != nil {
		return name, err
	}
	return name, nil
}
// getUserKeys loads the identity key, prekey and prekey signature stored for
// username. If there is no such user, err is sql.ErrNoRows and the key slices
// are nil.
func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte, err error) {
	row := db.QueryRow("SELECT identity_key, prekey, prekey_signature FROM users WHERE username = ?", username)
	err = row.Scan(&identityKey, &prekey, &prekeySignature)
	return identityKey, prekey, prekeySignature, err
}
// listUsers returns every registered username in ascending order, or nil if
// there are none. Rows whose username cannot be scanned (for example a NULL
// value) are skipped. An error that aborts iteration is returned rather than
// being treated as the end of the results.
func listUsers() ([]string, error) {
	rows, err := db.Query("SELECT username FROM users ORDER BY username ASC")
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var users []string
	for rows.Next() {
		var username string
		if err := rows.Scan(&username); err != nil {
			// Best-effort: skip an unreadable row rather than fail the listing.
			continue
		}
		users = append(users, username)
	}
	// rows.Next returns false both at the end of the results and on error;
	// only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}
// createRoom inserts a rooms row with the given ID, display name, creator and
// DM flag. Its created_at is set to the current time. Any insert error (e.g.
// a duplicate roomID, which is the primary key) is returned.
func createRoom(roomID, name, creator string, isDM bool) error {
	createdAt := time.Now()
	_, err := db.Exec(
		"INSERT INTO rooms (id, name, creator, is_dm, created_at) VALUES (?, ?, ?, ?, ?)",
		roomID, name, creator, isDM, createdAt,
	)
	return err
}
// getRoomIDByName returns the ID of a non-DM room whose name equals name.
// DM rooms are never matched. If several rooms share the name, LIMIT 1 picks
// one with no ordering guarantee. sql.ErrNoRows means no such room exists.
func getRoomIDByName(name string) (string, error) {
	const q = "SELECT id FROM rooms WHERE name = ? AND is_dm = 0 LIMIT 1"
	var roomID string
	if err := db.QueryRow(q, name).Scan(&roomID); err != nil {
		return roomID, err
	}
	return roomID, nil
}
// getAnyRoomSecret returns the shared secret stored on some membership row of
// roomID. Which member's row is used is unspecified (LIMIT 1 with no ORDER
// BY). In a real decentralized system this lookup wouldn't be possible, but
// here the server acts as the key distributor. sql.ErrNoRows means the room
// has no members.
func getAnyRoomSecret(roomID string) ([]byte, error) {
	row := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? LIMIT 1", roomID)
	var secret []byte
	err := row.Scan(&secret)
	return secret, err
}
// deleteRoom removes a room, its messages and its memberships inside one
// transaction, so either all three deletes take effect or none do. The first
// failing statement's error is returned. Otherwise the result of Commit is
// returned.
func deleteRoom(roomID string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// On any early return this undoes the partial work. After a successful
	// Commit, Rollback only reports sql.ErrTxDone, which is ignored.
	defer tx.Rollback()

	// Dependent rows first, then the room row they reference.
	for _, stmt := range []string{
		"DELETE FROM messages WHERE room_id = ?",
		"DELETE FROM room_members WHERE room_id = ?",
		"DELETE FROM rooms WHERE id = ?",
	} {
		if _, err := tx.Exec(stmt, roomID); err != nil {
			return err
		}
	}
	return tx.Commit()
}
// leaveRoom deletes username's membership of roomID. RowsAffected is not
// checked, so leaving a room the user is not in succeeds silently.
func leaveRoom(roomID, username string) error {
	const q = "DELETE FROM room_members WHERE room_id = ? AND username = ?"
	if _, err := db.Exec(q, roomID, username); err != nil {
		return err
	}
	return nil
}
// joinRoom records username as a member of roomID holding sharedSecret, with
// joined_at set to the current time. It is idempotent: if the membership
// already exists, the row (including its secret and join time) is left
// untouched and nil is returned.
//
// The existence check and the insert are a single statement keyed on the
// (room_id, username) primary key. This means two concurrent joins cannot
// race into a constraint error. It also means a failed lookup can no longer
// be mistaken for "not yet a member", which the earlier separate COUNT query
// allowed because its error was ignored.
func joinRoom(roomID, username string, sharedSecret []byte) error {
	_, err := db.Exec(
		"INSERT INTO room_members (room_id, username, shared_secret, joined_at) VALUES (?, ?, ?, ?) ON CONFLICT (room_id, username) DO NOTHING",
		roomID, username, sharedSecret, time.Now(),
	)
	return err
}
// getMemberJoinedAt returns the joined_at time of username's membership in
// roomID, scanned straight into a time.Time. It returns sql.ErrNoRows if
// username is not a member.
func getMemberJoinedAt(roomID, username string) (time.Time, error) {
	row := db.QueryRow("SELECT joined_at FROM room_members WHERE room_id = ? AND username = ?", roomID, username)
	var joined time.Time
	err := row.Scan(&joined)
	return joined, err
}
// listUserRooms returns the rooms username is a member of, newest first by
// room creation time, or nil if there are none. A NULL creator is reported as
// "". Rows that fail to scan are skipped. An error that aborts iteration is
// returned rather than being treated as the end of the results.
func listUserRooms(username string) ([]struct {
	ID      string
	Name    string
	Creator string
	IsDM    bool
}, error) {
	// room is an alias, so []room is exactly the declared return type. The
	// element fields are just not spelled out again below.
	type room = struct {
		ID      string
		Name    string
		Creator string
		IsDM    bool
	}
	rows, err := db.Query(`
SELECT r.id, r.name, COALESCE(r.creator, ''), r.is_dm
FROM rooms r
JOIN room_members rm ON r.id = rm.room_id
WHERE rm.username = ?
ORDER BY r.created_at DESC`, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var rooms []room
	for rows.Next() {
		var r room
		if err := rows.Scan(&r.ID, &r.Name, &r.Creator, &r.IsDM); err != nil {
			// Best-effort: skip an unreadable row rather than fail the listing.
			continue
		}
		rooms = append(rooms, r)
	}
	// rows.Next returns false both at the end of the results and on error;
	// only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return rooms, nil
}
// getRoomSharedSecret returns the shared secret stored on username's
// membership of roomID. It returns sql.ErrNoRows if no such membership exists.
func getRoomSharedSecret(roomID, username string) ([]byte, error) {
	const q = "SELECT shared_secret FROM room_members WHERE room_id = ? AND username = ?"
	var secret []byte
	if err := db.QueryRow(q, roomID, username).Scan(&secret); err != nil {
		return secret, err
	}
	return secret, nil
}
// saveMessage appends a message to roomID's history, with its timestamp set
// to the current time. The body is stored exactly as the given ciphertext and
// nonce; this function performs no encryption itself.
func saveMessage(roomID, sender string, ciphertext, nonce []byte) error {
	const stmt = "INSERT INTO messages (room_id, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?)"
	sentAt := time.Now()
	if _, err := db.Exec(stmt, roomID, sentAt, sender, ciphertext, nonce); err != nil {
		return err
	}
	return nil
}
// loadMessages returns every stored message of roomID, oldest first, or nil
// if there are none. Messages with identical timestamps are ordered by id.
// SQLite normally assigns ids for an INTEGER PRIMARY KEY in increasing order,
// which makes the result deterministic instead of arbitrary. Rows that fail to
// scan are skipped. An error that aborts iteration is returned rather than
// being treated as the end of the results.
func loadMessages(roomID string) ([]struct {
	Timestamp  time.Time
	Sender     string
	Ciphertext []byte
	Nonce      []byte
}, error) {
	// message is an alias, so []message is exactly the declared return type.
	type message = struct {
		Timestamp  time.Time
		Sender     string
		Ciphertext []byte
		Nonce      []byte
	}
	rows, err := db.Query("SELECT timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC, id ASC", roomID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var messages []message
	for rows.Next() {
		var msg message
		if err := rows.Scan(&msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
			// Best-effort: skip an unreadable row rather than fail the history.
			continue
		}
		messages = append(messages, msg)
	}
	// rows.Next returns false both at the end of the results and on error;
	// only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}