package internal

import (
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"fmt"
	"io"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

// dbFile is the path of the SQLite database file opened by InitDB.
const dbFile = "teachat.db"

// db is the package-wide database handle. It is assigned by InitDB and used
// by every helper in this file.
var db *sql.DB

// InitDB opens the SQLite database at dbFile and creates all tables that do
// not exist yet. Any failure terminates the process via log.Fatal.
//
// NOTE(review): nothing here issues "PRAGMA foreign_keys = ON"; unless the
// connection enables it elsewhere, SQLite does not enforce the FOREIGN KEY
// clauses declared below.
func InitDB() {
	var err error
	db, err = sql.Open("sqlite3", dbFile)
	if err != nil {
		log.Fatal(err)
	}

	query := `
	CREATE TABLE IF NOT EXISTS config (key TEXT PRIMARY KEY, val TEXT);

	CREATE TABLE IF NOT EXISTS users (
		pubkey TEXT PRIMARY KEY,
		username TEXT UNIQUE,
		identity_key BLOB,
		prekey BLOB,
		prekey_signature BLOB,
		created_at DATETIME
	);

	CREATE TABLE IF NOT EXISTS rooms (
		id TEXT PRIMARY KEY,
		name TEXT,
		creator TEXT,
		is_dm BOOLEAN,
		current_epoch INTEGER DEFAULT 0,
		created_at DATETIME
	);

	CREATE TABLE IF NOT EXISTS room_members (
		room_id TEXT,
		username TEXT,
		joined_at DATETIME,
		PRIMARY KEY (room_id, username),
		FOREIGN KEY (room_id) REFERENCES rooms(id),
		FOREIGN KEY (username) REFERENCES users(username)
	);

	-- Stores the encrypted room key for a specific user for a specific epoch
	CREATE TABLE IF NOT EXISTS user_room_keys (
		username TEXT,
		room_id TEXT,
		epoch INTEGER,
		encrypted_key BLOB,
		PRIMARY KEY (username, room_id, epoch),
		FOREIGN KEY (username) REFERENCES users(username),
		FOREIGN KEY (room_id) REFERENCES rooms(id)
	);

	CREATE TABLE IF NOT EXISTS messages (
		id INTEGER PRIMARY KEY,
		room_id TEXT,
		epoch INTEGER,
		timestamp DATETIME,
		sender TEXT,
		ciphertext BLOB,
		nonce BLOB,
		FOREIGN KEY (room_id) REFERENCES rooms(id)
	);
	`
	if _, err := db.Exec(query); err != nil {
		log.Fatal(err)
	}
}

// ensureSalt returns the global salt stored under the 'global_salt' key of the
// config table. If no salt is stored yet, it generates saltSize random bytes,
// persists them base64-encoded, and returns them.
//
// Every failure is fatal: a salt that cannot be read back, decoded, or
// persisted would otherwise silently differ between runs.
func ensureSalt() []byte {
	var b64 string
	err := db.QueryRow("SELECT val FROM config WHERE key='global_salt'").Scan(&b64)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		salt := make([]byte, saltSize)
		if _, err := io.ReadFull(rand.Reader, salt); err != nil {
			log.Fatal(err)
		}
		b64 = base64.StdEncoding.EncodeToString(salt)
		// The salt must be persisted; otherwise the next start would
		// generate a different one.
		if _, err := db.Exec("INSERT INTO config (key, val) VALUES ('global_salt', ?)", b64); err != nil {
			log.Fatalf("storing global salt: %v", err)
		}
		return salt
	case err != nil:
		log.Fatal(err)
	}

	salt, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		log.Fatalf("decoding stored global salt: %v", err)
	}
	return salt
}

// authorizeUser inserts a new row into users with the given public key,
// username and key material, stamping created_at with the current time.
// It returns the error from the INSERT (for example when pubkey or username
// already exists).
func authorizeUser(pubkey, username string, identityKey, prekey, prekeySignature []byte) error {
	_, err := db.Exec(
		"INSERT INTO users (pubkey, username, identity_key, prekey, prekey_signature, created_at) VALUES (?, ?, ?, ?, ?, ?)",
		pubkey, username, identityKey, prekey, prekeySignature, time.Now(),
	)
	return err
}

// getUsername looks up the username registered for pubkey. The error is
// sql.ErrNoRows when no such user exists.
func getUsername(pubkey string) (string, error) {
	var username string
	err := db.QueryRow("SELECT username FROM users WHERE pubkey = ?", pubkey).Scan(&username)
	return username, err
}

// getUserKeys returns the identity key, prekey and prekey signature stored for
// username. The error is sql.ErrNoRows when no such user exists.
func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte, err error) {
	err = db.QueryRow(
		"SELECT identity_key, prekey, prekey_signature FROM users WHERE username = ?",
		username,
	).Scan(&identityKey, &prekey, &prekeySignature)
	return identityKey, prekey, prekeySignature, err
}

// getMemberIdentityKey returns the identity key stored for username as a
// fixed-size 32-byte array.
//
// It returns an error if the user does not exist (sql.ErrNoRows) or if the
// stored key is not exactly 32 bytes long; a short or missing key is never
// silently zero-padded.
func getMemberIdentityKey(username string) ([32]byte, error) {
	var (
		raw []byte
		key [32]byte
	)
	if err := db.QueryRow("SELECT identity_key FROM users WHERE username = ?", username).Scan(&raw); err != nil {
		return key, err
	}
	if len(raw) != len(key) {
		return key, fmt.Errorf("identity key for user %q is %d bytes, want %d", username, len(raw), len(key))
	}
	copy(key[:], raw)
	return key, nil
}

// createRoom inserts a new room with epoch 0 and the current time as
// created_at.
func createRoom(roomID, name, creator string, isDM bool) error {
	_, err := db.Exec(
		"INSERT INTO rooms (id, name, creator, is_dm, current_epoch, created_at) VALUES (?, ?, ?, ?, 0, ?)",
		roomID, name, creator, isDM, time.Now(),
	)
	return err
}

// getRoomIDByName returns the id of one non-DM room with the given name. The
// error is sql.ErrNoRows when there is none.
func getRoomIDByName(name string) (string, error) {
	var id string
	err := db.QueryRow("SELECT id FROM rooms WHERE name = ? AND is_dm = 0 LIMIT 1", name).Scan(&id)
	return id, err
}

// getRoomCurrentEpoch returns the current_epoch value of the room. The error
// is sql.ErrNoRows when the room does not exist.
func getRoomCurrentEpoch(roomID string) (int, error) {
	var epoch int
	err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch)
	return epoch, err
}

// incrementRoomEpoch bumps the room's current_epoch by one and returns the new
// value. The increment and the read-back happen in a single
// UPDATE ... RETURNING statement. The error is sql.ErrNoRows when the room
// does not exist.
func incrementRoomEpoch(roomID string) (int, error) {
	var newEpoch int
	err := db.QueryRow(
		"UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch",
		roomID,
	).Scan(&newEpoch)
	return newEpoch, err
}

// deleteRoom removes a room together with its messages, per-user room keys and
// memberships inside a single transaction. Dependent rows are deleted before
// the room row itself. On any failure the transaction is rolled back and the
// error is returned.
func deleteRoom(roomID string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit this Rollback is a no-op that returns
	// sql.ErrTxDone, so it is safe to defer unconditionally.
	defer tx.Rollback()

	statements := []string{
		"DELETE FROM messages WHERE room_id = ?",
		"DELETE FROM user_room_keys WHERE room_id = ?",
		"DELETE FROM room_members WHERE room_id = ?",
		"DELETE FROM rooms WHERE id = ?",
	}
	for _, stmt := range statements {
		if _, err := tx.Exec(stmt, roomID); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// leaveRoom removes username's membership row for the room. Removing a
// membership that does not exist is not an error.
func leaveRoom(roomID, username string) error {
	_, err := db.Exec("DELETE FROM room_members WHERE room_id = ? AND username = ?", roomID, username)
	return err
}

// joinRoomMember records username as a member of the room, stamping joined_at
// with the current time. If the membership already exists the call is a
// no-op and returns nil.
//
// The existence check and the insert are one statement, so two concurrent
// joins cannot race into a primary-key violation.
func joinRoomMember(roomID, username string) error {
	_, err := db.Exec(`
		INSERT INTO room_members (room_id, username, joined_at)
		VALUES (?, ?, ?)
		ON CONFLICT(room_id, username) DO NOTHING`,
		roomID, username, time.Now(),
	)
	return err
}

// getRoomMembers returns the usernames of all members of the room.
//
// A row that cannot be scanned is logged and skipped. An error that aborts the
// iteration is returned, so a truncated list is never reported as success.
func getRoomMembers(roomID string) ([]string, error) {
	rows, err := db.Query("SELECT username FROM room_members WHERE room_id = ?", roomID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []string
	for rows.Next() {
		var u string
		if err := rows.Scan(&u); err != nil {
			log.Printf("getRoomMembers: skipping unreadable row for room %q: %v", roomID, err)
			continue
		}
		users = append(users, u)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}

// storeUserRoomKey saves the encrypted room key for (username, roomID, epoch).
// If a key is already stored for that triple it is overwritten.
func storeUserRoomKey(username, roomID string, epoch int, encryptedKey []byte) error {
	_, err := db.Exec(`
		INSERT INTO user_room_keys (username, room_id, epoch, encrypted_key)
		VALUES (?, ?, ?, ?)
		ON CONFLICT(username, room_id, epoch) DO UPDATE SET encrypted_key = excluded.encrypted_key`,
		username, roomID, epoch, encryptedKey,
	)
	return err
}

// getUserRoomKey returns the encrypted room key stored for
// (username, roomID, epoch). The error is sql.ErrNoRows when none is stored.
func getUserRoomKey(username, roomID string, epoch int) ([]byte, error) {
	var key []byte
	err := db.QueryRow(
		"SELECT encrypted_key FROM user_room_keys WHERE username = ? AND room_id = ? AND epoch = ?",
		username, roomID, epoch,
	).Scan(&key)
	return key, err
}

// listUserRooms returns the rooms username is a member of, newest first
// (ordered by the room's created_at). A NULL creator is reported as "".
//
// A row that cannot be scanned is logged and skipped. An error that aborts the
// iteration is returned.
func listUserRooms(username string) ([]struct {
	ID      string
	Name    string
	Creator string
	IsDM    bool
}, error) {
	// Alias for the anonymous struct in the signature, so the element type is
	// spelled out only once inside the body.
	type roomRow = struct {
		ID      string
		Name    string
		Creator string
		IsDM    bool
	}

	rows, err := db.Query(`
		SELECT r.id, r.name, COALESCE(r.creator, ''), r.is_dm
		FROM rooms r
		JOIN room_members rm ON r.id = rm.room_id
		WHERE rm.username = ?
		ORDER BY r.created_at DESC`, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var rooms []roomRow
	for rows.Next() {
		var r roomRow
		if err := rows.Scan(&r.ID, &r.Name, &r.Creator, &r.IsDM); err != nil {
			log.Printf("listUserRooms: skipping unreadable row for user %q: %v", username, err)
			continue
		}
		rooms = append(rooms, r)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return rooms, nil
}

// saveMessage appends an encrypted message to the room, stamped with the
// current time.
func saveMessage(roomID, sender string, epoch int, ciphertext, nonce []byte) error {
	_, err := db.Exec(
		"INSERT INTO messages (room_id, epoch, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?, ?)",
		roomID, epoch, time.Now(), sender, ciphertext, nonce,
	)
	return err
}

// loadMessages returns all stored messages of the room, oldest first. Messages
// with equal timestamps are ordered by their row id so the result is
// deterministic.
//
// A row that cannot be scanned is logged and skipped. An error that aborts the
// iteration is returned.
func loadMessages(roomID string) ([]struct {
	Epoch      int
	Timestamp  time.Time
	Sender     string
	Ciphertext []byte
	Nonce      []byte
}, error) {
	// Alias for the anonymous struct in the signature, so the element type is
	// spelled out only once inside the body.
	type messageRow = struct {
		Epoch      int
		Timestamp  time.Time
		Sender     string
		Ciphertext []byte
		Nonce      []byte
	}

	rows, err := db.Query(
		"SELECT epoch, timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC, id ASC",
		roomID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []messageRow
	for rows.Next() {
		var m messageRow
		if err := rows.Scan(&m.Epoch, &m.Timestamp, &m.Sender, &m.Ciphertext, &m.Nonce); err != nil {
			log.Printf("loadMessages: skipping unreadable row for room %q: %v", roomID, err)
			continue
		}
		messages = append(messages, m)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}