Update
This commit is contained in:
@@ -11,7 +11,7 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
const dbFile = "teachat.db"
|
||||
const dbFile = "encrypted_chat.db"
|
||||
|
||||
var db *sql.DB
|
||||
|
||||
@@ -37,20 +37,31 @@ func InitDB() {
|
||||
name TEXT,
|
||||
creator TEXT,
|
||||
is_dm BOOLEAN,
|
||||
current_epoch INTEGER DEFAULT 0,
|
||||
created_at DATETIME
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS room_members (
|
||||
room_id TEXT,
|
||||
username TEXT,
|
||||
shared_secret BLOB,
|
||||
joined_at DATETIME,
|
||||
PRIMARY KEY (room_id, username),
|
||||
FOREIGN KEY (room_id) REFERENCES rooms(id),
|
||||
FOREIGN KEY (username) REFERENCES users(username)
|
||||
);
|
||||
-- Stores the encrypted room key for a specific user for a specific epoch
|
||||
CREATE TABLE IF NOT EXISTS user_room_keys (
|
||||
username TEXT,
|
||||
room_id TEXT,
|
||||
epoch INTEGER,
|
||||
encrypted_key BLOB,
|
||||
PRIMARY KEY (username, room_id, epoch),
|
||||
FOREIGN KEY (username) REFERENCES users(username),
|
||||
FOREIGN KEY (room_id) REFERENCES rooms(id)
|
||||
);
|
||||
CREATE TABLE IF NOT EXISTS messages (
|
||||
id INTEGER PRIMARY KEY,
|
||||
room_id TEXT,
|
||||
epoch INTEGER,
|
||||
timestamp DATETIME,
|
||||
sender TEXT,
|
||||
ciphertext BLOB,
|
||||
@@ -61,9 +72,6 @@ func InitDB() {
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Migration for older DBs
|
||||
db.Exec("ALTER TABLE rooms ADD COLUMN creator TEXT")
|
||||
}
|
||||
|
||||
func ensureSalt() []byte {
|
||||
@@ -101,26 +109,19 @@ func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte,
|
||||
return
|
||||
}
|
||||
|
||||
func listUsers() ([]string, error) {
|
||||
rows, err := db.Query("SELECT username FROM users ORDER BY username ASC")
|
||||
func getMemberIdentityKey(username string) ([32]byte, error) {
|
||||
var b []byte
|
||||
var key [32]byte
|
||||
err := db.QueryRow("SELECT identity_key FROM users WHERE username = ?", username).Scan(&b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return key, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var users []string
|
||||
for rows.Next() {
|
||||
var username string
|
||||
if err := rows.Scan(&username); err != nil {
|
||||
continue
|
||||
}
|
||||
users = append(users, username)
|
||||
}
|
||||
return users, nil
|
||||
copy(key[:], b)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func createRoom(roomID, name, creator string, isDM bool) error {
|
||||
_, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, created_at) VALUES (?, ?, ?, ?, ?)",
|
||||
_, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, current_epoch, created_at) VALUES (?, ?, ?, ?, 0, ?)",
|
||||
roomID, name, creator, isDM, time.Now())
|
||||
return err
|
||||
}
|
||||
@@ -131,12 +132,17 @@ func getRoomIDByName(name string) (string, error) {
|
||||
return id, err
|
||||
}
|
||||
|
||||
// getAnyRoomSecret gets the shared secret from ANY member of the room.
|
||||
// In a real decentralized system this wouldn't be possible, but here the server acts as the key distributor.
|
||||
func getAnyRoomSecret(roomID string) ([]byte, error) {
|
||||
var secret []byte
|
||||
err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? LIMIT 1", roomID).Scan(&secret)
|
||||
return secret, err
|
||||
func getRoomCurrentEpoch(roomID string) (int, error) {
|
||||
var epoch int
|
||||
err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch)
|
||||
return epoch, err
|
||||
}
|
||||
|
||||
func incrementRoomEpoch(roomID string) (int, error) {
|
||||
var newEpoch int
|
||||
// Atomic increment
|
||||
err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch)
|
||||
return newEpoch, err
|
||||
}
|
||||
|
||||
func deleteRoom(roomID string) error {
|
||||
@@ -148,6 +154,10 @@ func deleteRoom(roomID string) error {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM user_room_keys WHERE room_id = ?", roomID); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if _, err := tx.Exec("DELETE FROM room_members WHERE room_id = ?", roomID); err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
@@ -164,22 +174,48 @@ func leaveRoom(roomID, username string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func joinRoom(roomID, username string, sharedSecret []byte) error {
|
||||
func joinRoomMember(roomID, username string) error {
|
||||
// Check if already joined
|
||||
var count int
|
||||
db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count)
|
||||
if count > 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := db.Exec("INSERT INTO room_members (room_id, username, shared_secret, joined_at) VALUES (?, ?, ?, ?)",
|
||||
roomID, username, sharedSecret, time.Now())
|
||||
_, err := db.Exec("INSERT INTO room_members (room_id, username, joined_at) VALUES (?, ?, ?)",
|
||||
roomID, username, time.Now())
|
||||
return err
|
||||
}
|
||||
|
||||
func getMemberJoinedAt(roomID, username string) (time.Time, error) {
|
||||
var joinedAt time.Time
|
||||
err := db.QueryRow("SELECT joined_at FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&joinedAt)
|
||||
return joinedAt, err
|
||||
func getRoomMembers(roomID string) ([]string, error) {
|
||||
rows, err := db.Query("SELECT username FROM room_members WHERE room_id = ?", roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var users []string
|
||||
for rows.Next() {
|
||||
var u string
|
||||
if err := rows.Scan(&u); err == nil {
|
||||
users = append(users, u)
|
||||
}
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func storeUserRoomKey(username, roomID string, epoch int, encryptedKey []byte) error {
|
||||
_, err := db.Exec(`
|
||||
INSERT INTO user_room_keys (username, room_id, epoch, encrypted_key)
|
||||
VALUES (?, ?, ?, ?)
|
||||
ON CONFLICT(username, room_id, epoch) DO UPDATE SET encrypted_key = excluded.encrypted_key`,
|
||||
username, roomID, epoch, encryptedKey)
|
||||
return err
|
||||
}
|
||||
|
||||
func getUserRoomKey(username, roomID string, epoch int) ([]byte, error) {
|
||||
var key []byte
|
||||
err := db.QueryRow("SELECT encrypted_key FROM user_room_keys WHERE username = ? AND room_id = ? AND epoch = ?",
|
||||
username, roomID, epoch).Scan(&key)
|
||||
return key, err
|
||||
}
|
||||
|
||||
func listUserRooms(username string) ([]struct {
|
||||
@@ -220,31 +256,27 @@ func listUserRooms(username string) ([]struct {
|
||||
return rooms, nil
|
||||
}
|
||||
|
||||
func getRoomSharedSecret(roomID, username string) ([]byte, error) {
|
||||
var secret []byte
|
||||
err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&secret)
|
||||
return secret, err
|
||||
}
|
||||
|
||||
func saveMessage(roomID, sender string, ciphertext, nonce []byte) error {
|
||||
_, err := db.Exec("INSERT INTO messages (room_id, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?)",
|
||||
roomID, time.Now(), sender, ciphertext, nonce)
|
||||
func saveMessage(roomID, sender string, epoch int, ciphertext, nonce []byte) error {
|
||||
_, err := db.Exec("INSERT INTO messages (room_id, epoch, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
roomID, epoch, time.Now(), sender, ciphertext, nonce)
|
||||
return err
|
||||
}
|
||||
|
||||
func loadMessages(roomID string) ([]struct {
|
||||
Epoch int
|
||||
Timestamp time.Time
|
||||
Sender string
|
||||
Ciphertext []byte
|
||||
Nonce []byte
|
||||
}, error) {
|
||||
rows, err := db.Query("SELECT timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID)
|
||||
rows, err := db.Query("SELECT epoch, timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var messages []struct {
|
||||
Epoch int
|
||||
Timestamp time.Time
|
||||
Sender string
|
||||
Ciphertext []byte
|
||||
@@ -253,12 +285,13 @@ func loadMessages(roomID string) ([]struct {
|
||||
|
||||
for rows.Next() {
|
||||
var msg struct {
|
||||
Epoch int
|
||||
Timestamp time.Time
|
||||
Sender string
|
||||
Ciphertext []byte
|
||||
Nonce []byte
|
||||
}
|
||||
if err := rows.Scan(&msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
|
||||
if err := rows.Scan(&msg.Epoch, &msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
|
||||
continue
|
||||
}
|
||||
messages = append(messages, msg)
|
||||
|
||||
Reference in New Issue
Block a user