package internal

import (
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"io"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

// dbFile is the path of the SQLite database file, relative to the process's
// working directory.
const dbFile = "teachat.db"

// db is the package-wide database handle. It is assigned by InitDB and used
// by every other function in this file.
var db *sql.DB

// InitDB opens the SQLite database at dbFile, creates any missing tables and
// migrates databases created by older versions of the schema. Every failure
// is fatal (log.Fatal), since nothing else in this file works without db.
func InitDB() {
	var err error
	db, err = sql.Open("sqlite3", dbFile)
	if err != nil {
		log.Fatal(err)
	}

	// NOTE(review): SQLite only enforces the FOREIGN KEY clauses below when
	// foreign_keys is enabled on the connection; nothing in this file enables
	// it — confirm whether enforcement is expected.
	query := `
	CREATE TABLE IF NOT EXISTS config (key TEXT PRIMARY KEY, val TEXT);
	CREATE TABLE IF NOT EXISTS users (
		pubkey TEXT PRIMARY KEY,
		username TEXT UNIQUE,
		identity_key BLOB,
		prekey BLOB,
		prekey_signature BLOB,
		created_at DATETIME
	);
	CREATE TABLE IF NOT EXISTS rooms (
		id TEXT PRIMARY KEY,
		name TEXT,
		creator TEXT,
		is_dm BOOLEAN,
		created_at DATETIME
	);
	CREATE TABLE IF NOT EXISTS room_members (
		room_id TEXT,
		username TEXT,
		shared_secret BLOB,
		joined_at DATETIME,
		PRIMARY KEY (room_id, username),
		FOREIGN KEY (room_id) REFERENCES rooms(id),
		FOREIGN KEY (username) REFERENCES users(username)
	);
	CREATE TABLE IF NOT EXISTS messages (
		id INTEGER PRIMARY KEY,
		room_id TEXT,
		timestamp DATETIME,
		sender TEXT,
		ciphertext BLOB,
		nonce BLOB,
		FOREIGN KEY (room_id) REFERENCES rooms(id)
	);
	`
	if _, err := db.Exec(query); err != nil {
		log.Fatal(err)
	}

	// Migration for older DBs whose rooms table predates the creator column.
	// The ALTER runs only when the column is actually missing, so a genuine
	// migration failure is reported instead of being swallowed along with
	// the expected "duplicate column" error on up-to-date databases.
	var hasCreator int
	if err := db.QueryRow("SELECT COUNT(*) FROM pragma_table_info('rooms') WHERE name = 'creator'").Scan(&hasCreator); err != nil {
		log.Fatal(err)
	}
	if hasCreator == 0 {
		if _, err := db.Exec("ALTER TABLE rooms ADD COLUMN creator TEXT"); err != nil {
			log.Fatal(err)
		}
	}
}

// ensureSalt returns the salt stored base64-encoded under the 'global_salt'
// key of the config table. On first use it generates saltSize random bytes
// with crypto/rand, stores them, and returns them.
//
// Failures are fatal: returning a salt that was never persisted, or silently
// returning a corrupt one, would yield a different salt on the next start.
func ensureSalt() []byte {
	var b64 string
	err := db.QueryRow("SELECT val FROM config WHERE key='global_salt'").Scan(&b64)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		salt := make([]byte, saltSize)
		if _, err := io.ReadFull(rand.Reader, salt); err != nil {
			log.Fatal(err)
		}
		b64 = base64.StdEncoding.EncodeToString(salt)
		if _, err := db.Exec("INSERT INTO config (key, val) VALUES ('global_salt', ?)", b64); err != nil {
			log.Fatal(err)
		}
		return salt
	case err != nil:
		log.Fatal(err)
	}

	salt, err := base64.StdEncoding.DecodeString(b64)
	if err != nil {
		log.Fatal(err)
	}
	return salt
}

// authorizeUser inserts a new users row binding pubkey to username together
// with the given key material, stamped with the current time. The insert
// error (for example a PRIMARY KEY or UNIQUE violation) is returned as-is.
func authorizeUser(pubkey, username string, identityKey, prekey, prekeySignature []byte) error {
	_, err := db.Exec("INSERT INTO users (pubkey, username, identity_key, prekey, prekey_signature, created_at) VALUES (?, ?, ?, ?, ?, ?)",
		pubkey, username, identityKey, prekey, prekeySignature, time.Now())
	return err
}

// getUsername returns the username registered for pubkey. It returns
// sql.ErrNoRows (unwrapped) when no such user exists.
func getUsername(pubkey string) (string, error) {
	var username string
	err := db.QueryRow("SELECT username FROM users WHERE pubkey = ?", pubkey).Scan(&username)
	return username, err
}

// getUserKeys returns the identity key, prekey and prekey signature stored
// for username. It returns sql.ErrNoRows (unwrapped) for an unknown user.
func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte, err error) {
	err = db.QueryRow("SELECT identity_key, prekey, prekey_signature FROM users WHERE username = ?", username).
		Scan(&identityKey, &prekey, &prekeySignature)
	return
}

// listUsers returns every registered username in ascending order. A scan or
// iteration error is returned rather than silently dropping rows.
func listUsers() ([]string, error) {
	rows, err := db.Query("SELECT username FROM users ORDER BY username ASC")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []string
	for rows.Next() {
		var username string
		if err := rows.Scan(&username); err != nil {
			return nil, err
		}
		users = append(users, username)
	}
	// rows.Next reports false both at the end of the result set and on an
	// iteration error; rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return users, nil
}

// createRoom inserts a rooms row with the given id, display name, creator and
// DM flag, stamped with the current time.
func createRoom(roomID, name, creator string, isDM bool) error {
	_, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, created_at) VALUES (?, ?, ?, ?, ?)",
		roomID, name, creator, isDM, time.Now())
	return err
}

// getRoomIDByName returns the id of a non-DM room with the given name. If
// several match, an arbitrary one is returned (LIMIT 1 without ORDER BY).
// It returns sql.ErrNoRows (unwrapped) when none exists.
func getRoomIDByName(name string) (string, error) {
	var id string
	err := db.QueryRow("SELECT id FROM rooms WHERE name = ? AND is_dm = 0 LIMIT 1", name).Scan(&id)
	return id, err
}

// getAnyRoomSecret returns the shared_secret stored for an arbitrary member of
// roomID (LIMIT 1 without ORDER BY, so which member is unspecified).
// NOTE(review): presumably every member row of a room holds the same secret —
// confirm, otherwise the result depends on SQLite's row order.
// In a real decentralized system this wouldn't be possible, but here the
// server acts as the key distributor.
func getAnyRoomSecret(roomID string) ([]byte, error) {
	var secret []byte
	err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? LIMIT 1", roomID).Scan(&secret)
	return secret, err
}

// deleteRoom removes roomID together with its messages and memberships in a
// single transaction; either all three deletes commit or none do.
func deleteRoom(roomID string) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// After a successful Commit, Rollback just returns sql.ErrTxDone, so the
	// deferred call only has an effect on the early-return error paths.
	defer tx.Rollback()

	// Dependent rows first, then the room itself.
	for _, stmt := range []string{
		"DELETE FROM messages WHERE room_id = ?",
		"DELETE FROM room_members WHERE room_id = ?",
		"DELETE FROM rooms WHERE id = ?",
	} {
		if _, err := tx.Exec(stmt, roomID); err != nil {
			return err
		}
	}
	return tx.Commit()
}

// leaveRoom deletes username's membership of roomID. Leaving a room the user
// is not a member of is not an error.
func leaveRoom(roomID, username string) error {
	_, err := db.Exec("DELETE FROM room_members WHERE room_id = ? AND username = ?", roomID, username)
	return err
}

// joinRoom records username as a member of roomID holding sharedSecret.
// If the user is already a member, the existing row (including its
// shared_secret and joined_at) is left unchanged and nil is returned.
func joinRoom(roomID, username string, sharedSecret []byte) error {
	// INSERT OR IGNORE skips the insert on a (room_id, username) primary-key
	// conflict within one statement, replacing a separate existence check
	// whose error was ignored and which could race with a concurrent join.
	_, err := db.Exec("INSERT OR IGNORE INTO room_members (room_id, username, shared_secret, joined_at) VALUES (?, ?, ?, ?)",
		roomID, username, sharedSecret, time.Now())
	return err
}

// getMemberJoinedAt returns when username joined roomID. It returns
// sql.ErrNoRows (unwrapped) if the user is not a member.
func getMemberJoinedAt(roomID, username string) (time.Time, error) {
	var joinedAt time.Time
	err := db.QueryRow("SELECT joined_at FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&joinedAt)
	return joinedAt, err
}

// listUserRooms returns every room username is a member of, newest first
// (by rooms.created_at). A NULL creator (rooms migrated from older schemas)
// is reported as "". Scan and iteration errors are returned.
func listUserRooms(username string) ([]struct {
	ID      string
	Name    string
	Creator string
	IsDM    bool
}, error) {
	rows, err := db.Query(`
		SELECT r.id, r.name, COALESCE(r.creator, ''), r.is_dm
		FROM rooms r
		JOIN room_members rm ON r.id = rm.room_id
		WHERE rm.username = ?
		ORDER BY r.created_at DESC`, username)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var rooms []struct {
		ID      string
		Name    string
		Creator string
		IsDM    bool
	}
	for rows.Next() {
		var room struct {
			ID      string
			Name    string
			Creator string
			IsDM    bool
		}
		if err := rows.Scan(&room.ID, &room.Name, &room.Creator, &room.IsDM); err != nil {
			return nil, err
		}
		rooms = append(rooms, room)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return rooms, nil
}

// getRoomSharedSecret returns the shared secret stored for username in
// roomID. It returns sql.ErrNoRows (unwrapped) if the user is not a member.
func getRoomSharedSecret(roomID, username string) ([]byte, error) {
	var secret []byte
	err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&secret)
	return secret, err
}

// saveMessage stores an encrypted message for roomID, stamped with the
// current time.
func saveMessage(roomID, sender string, ciphertext, nonce []byte) error {
	_, err := db.Exec("INSERT INTO messages (room_id, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?)",
		roomID, time.Now(), sender, ciphertext, nonce)
	return err
}

// loadMessages returns all stored messages of roomID, oldest first. Messages
// with equal timestamps are ordered by insertion id so the order is
// deterministic. Scan and iteration errors are returned.
func loadMessages(roomID string) ([]struct {
	Timestamp  time.Time
	Sender     string
	Ciphertext []byte
	Nonce      []byte
}, error) {
	rows, err := db.Query("SELECT timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC, id ASC", roomID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var messages []struct {
		Timestamp  time.Time
		Sender     string
		Ciphertext []byte
		Nonce      []byte
	}
	for rows.Next() {
		var msg struct {
			Timestamp  time.Time
			Sender     string
			Ciphertext []byte
			Nonce      []byte
		}
		if err := rows.Scan(&msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
			return nil, err
		}
		messages = append(messages, msg)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return messages, nil
}