319 lines
8.6 KiB
Go
319 lines
8.6 KiB
Go
package internal
|
|
|
|
import (
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"fmt"
	"io"
	"log"
	"time"

	_ "github.com/mattn/go-sqlite3"
)
|
|
|
|
// dbFile is the path of the SQLite database file that InitDB opens.
const dbFile = "teachat.db"

// db is the package-wide database handle. It is nil until InitDB assigns it.
var db *sql.DB
|
|
|
|
func InitDB() {
|
|
var err error
|
|
db, err = sql.Open("sqlite3", dbFile)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
query := `
|
|
CREATE TABLE IF NOT EXISTS config (key TEXT PRIMARY KEY, val TEXT);
|
|
CREATE TABLE IF NOT EXISTS users (
|
|
pubkey TEXT PRIMARY KEY,
|
|
username TEXT UNIQUE,
|
|
identity_key BLOB,
|
|
prekey BLOB,
|
|
prekey_signature BLOB,
|
|
enc_priv_key BLOB,
|
|
priv_key_salt BLOB,
|
|
priv_key_nonce BLOB,
|
|
created_at DATETIME
|
|
);
|
|
CREATE TABLE IF NOT EXISTS rooms (
|
|
id TEXT PRIMARY KEY,
|
|
name TEXT,
|
|
creator TEXT,
|
|
is_dm BOOLEAN,
|
|
current_epoch INTEGER DEFAULT 0,
|
|
created_at DATETIME
|
|
);
|
|
CREATE TABLE IF NOT EXISTS room_members (
|
|
room_id TEXT,
|
|
username TEXT,
|
|
joined_at DATETIME,
|
|
PRIMARY KEY (room_id, username),
|
|
FOREIGN KEY (room_id) REFERENCES rooms(id),
|
|
FOREIGN KEY (username) REFERENCES users(username)
|
|
);
|
|
CREATE TABLE IF NOT EXISTS user_room_keys (
|
|
username TEXT,
|
|
room_id TEXT,
|
|
epoch INTEGER,
|
|
encrypted_key BLOB,
|
|
PRIMARY KEY (username, room_id, epoch),
|
|
FOREIGN KEY (username) REFERENCES users(username),
|
|
FOREIGN KEY (room_id) REFERENCES rooms(id)
|
|
);
|
|
CREATE TABLE IF NOT EXISTS messages (
|
|
id INTEGER PRIMARY KEY,
|
|
room_id TEXT,
|
|
epoch INTEGER,
|
|
timestamp DATETIME,
|
|
sender TEXT,
|
|
ciphertext BLOB,
|
|
nonce BLOB,
|
|
FOREIGN KEY (room_id) REFERENCES rooms(id)
|
|
);
|
|
`
|
|
if _, err := db.Exec(query); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}
|
|
|
|
func ensureSalt() []byte {
|
|
row := db.QueryRow("SELECT val FROM config WHERE key='global_salt'")
|
|
var b64 string
|
|
if err := row.Scan(&b64); err == sql.ErrNoRows {
|
|
salt := make([]byte, saltSize)
|
|
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
b64 = base64.StdEncoding.EncodeToString(salt)
|
|
db.Exec("INSERT INTO config (key, val) VALUES ('global_salt', ?)", b64)
|
|
return salt
|
|
} else if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
salt, _ := base64.StdEncoding.DecodeString(b64)
|
|
return salt
|
|
}
|
|
|
|
func authorizeUser(pubkey, username string, identityKey, prekey, prekeySignature []byte) error {
|
|
_, err := db.Exec("INSERT INTO users (pubkey, username, identity_key, prekey, prekey_signature, created_at) VALUES (?, ?, ?, ?, ?, ?)",
|
|
pubkey, username, identityKey, prekey, prekeySignature, time.Now())
|
|
return err
|
|
}
|
|
|
|
func storeUserEncryptedKey(username string, encKey, salt, nonce []byte) error {
|
|
_, err := db.Exec("UPDATE users SET enc_priv_key=?, priv_key_salt=?, priv_key_nonce=? WHERE username=?",
|
|
encKey, salt, nonce, username)
|
|
return err
|
|
}
|
|
|
|
func getUserEncryptedKey(username string) (encKey, salt, nonce []byte, err error) {
|
|
err = db.QueryRow("SELECT enc_priv_key, priv_key_salt, priv_key_nonce FROM users WHERE username = ?", username).Scan(&encKey, &salt, &nonce)
|
|
return
|
|
}
|
|
|
|
func getUsername(pubkey string) (string, error) {
|
|
var username string
|
|
err := db.QueryRow("SELECT username FROM users WHERE pubkey = ?", pubkey).Scan(&username)
|
|
return username, err
|
|
}
|
|
|
|
func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte, err error) {
|
|
err = db.QueryRow("SELECT identity_key, prekey, prekey_signature FROM users WHERE username = ?", username).Scan(&identityKey, &prekey, &prekeySignature)
|
|
return
|
|
}
|
|
|
|
func getMemberIdentityKey(username string) ([32]byte, error) {
|
|
var b []byte
|
|
var key [32]byte
|
|
err := db.QueryRow("SELECT identity_key FROM users WHERE username = ?", username).Scan(&b)
|
|
if err != nil {
|
|
return key, err
|
|
}
|
|
copy(key[:], b)
|
|
return key, nil
|
|
}
|
|
|
|
func createRoom(roomID, name, creator string, isDM bool) error {
|
|
_, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, current_epoch, created_at) VALUES (?, ?, ?, ?, 0, ?)",
|
|
roomID, name, creator, isDM, time.Now())
|
|
return err
|
|
}
|
|
|
|
func getRoomIDByName(name string) (string, error) {
|
|
var id string
|
|
err := db.QueryRow("SELECT id FROM rooms WHERE name = ? AND is_dm = 0 LIMIT 1", name).Scan(&id)
|
|
return id, err
|
|
}
|
|
|
|
func getRoomCreator(roomID string) (string, error) {
|
|
var creator string
|
|
err := db.QueryRow("SELECT creator FROM rooms WHERE id = ?", roomID).Scan(&creator)
|
|
return creator, err
|
|
}
|
|
|
|
func getRoomCurrentEpoch(roomID string) (int, error) {
|
|
var epoch int
|
|
err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch)
|
|
return epoch, err
|
|
}
|
|
|
|
func incrementRoomEpoch(roomID string) (int, error) {
|
|
var newEpoch int
|
|
err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch)
|
|
return newEpoch, err
|
|
}
|
|
|
|
func deleteRoom(roomID string) error {
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, err := tx.Exec("DELETE FROM messages WHERE room_id = ?", roomID); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
if _, err := tx.Exec("DELETE FROM user_room_keys WHERE room_id = ?", roomID); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
if _, err := tx.Exec("DELETE FROM room_members WHERE room_id = ?", roomID); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
if _, err := tx.Exec("DELETE FROM rooms WHERE id = ?", roomID); err != nil {
|
|
tx.Rollback()
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func leaveRoom(roomID, username string) error {
|
|
_, err := db.Exec("DELETE FROM room_members WHERE room_id = ? AND username = ?", roomID, username)
|
|
return err
|
|
}
|
|
|
|
func joinRoomMember(roomID, username string) error {
|
|
var count int
|
|
db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count)
|
|
if count > 0 {
|
|
return nil
|
|
}
|
|
_, err := db.Exec("INSERT INTO room_members (room_id, username, joined_at) VALUES (?, ?, ?)",
|
|
roomID, username, time.Now())
|
|
return err
|
|
}
|
|
|
|
func getRoomMembers(roomID string) ([]string, error) {
|
|
rows, err := db.Query("SELECT username FROM room_members WHERE room_id = ?", roomID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var users []string
|
|
for rows.Next() {
|
|
var u string
|
|
if err := rows.Scan(&u); err == nil {
|
|
users = append(users, u)
|
|
}
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func storeUserRoomKey(username, roomID string, epoch int, encryptedKey []byte) error {
|
|
_, err := db.Exec(`
|
|
INSERT INTO user_room_keys (username, room_id, epoch, encrypted_key)
|
|
VALUES (?, ?, ?, ?)
|
|
ON CONFLICT(username, room_id, epoch) DO UPDATE SET encrypted_key = excluded.encrypted_key`,
|
|
username, roomID, epoch, encryptedKey)
|
|
return err
|
|
}
|
|
|
|
func getUserRoomKey(username, roomID string, epoch int) ([]byte, error) {
|
|
var key []byte
|
|
err := db.QueryRow("SELECT encrypted_key FROM user_room_keys WHERE username = ? AND room_id = ? AND epoch = ?",
|
|
username, roomID, epoch).Scan(&key)
|
|
return key, err
|
|
}
|
|
|
|
func listUserRooms(username string) ([]struct {
|
|
ID string
|
|
Name string
|
|
Creator string
|
|
IsDM bool
|
|
}, error) {
|
|
rows, err := db.Query(`
|
|
SELECT r.id, r.name, COALESCE(r.creator, ''), r.is_dm
|
|
FROM rooms r
|
|
JOIN room_members rm ON r.id = rm.room_id
|
|
WHERE rm.username = ?
|
|
ORDER BY r.created_at DESC`, username)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var rooms []struct {
|
|
ID string
|
|
Name string
|
|
Creator string
|
|
IsDM bool
|
|
}
|
|
for rows.Next() {
|
|
var room struct {
|
|
ID string
|
|
Name string
|
|
Creator string
|
|
IsDM bool
|
|
}
|
|
if err := rows.Scan(&room.ID, &room.Name, &room.Creator, &room.IsDM); err != nil {
|
|
continue
|
|
}
|
|
rooms = append(rooms, room)
|
|
}
|
|
return rooms, nil
|
|
}
|
|
|
|
func saveMessage(roomID, sender string, epoch int, ciphertext, nonce []byte) error {
|
|
_, err := db.Exec("INSERT INTO messages (room_id, epoch, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?, ?)",
|
|
roomID, epoch, time.Now(), sender, ciphertext, nonce)
|
|
return err
|
|
}
|
|
|
|
func loadMessages(roomID string) ([]struct {
|
|
Epoch int
|
|
Timestamp time.Time
|
|
Sender string
|
|
Ciphertext []byte
|
|
Nonce []byte
|
|
}, error) {
|
|
rows, err := db.Query("SELECT epoch, timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var messages []struct {
|
|
Epoch int
|
|
Timestamp time.Time
|
|
Sender string
|
|
Ciphertext []byte
|
|
Nonce []byte
|
|
}
|
|
|
|
for rows.Next() {
|
|
var msg struct {
|
|
Epoch int
|
|
Timestamp time.Time
|
|
Sender string
|
|
Ciphertext []byte
|
|
Nonce []byte
|
|
}
|
|
if err := rows.Scan(&msg.Epoch, &msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
|
|
continue
|
|
}
|
|
messages = append(messages, msg)
|
|
}
|
|
|
|
return messages, nil
|
|
}
|