This commit is contained in:
2025-12-28 00:55:33 -07:00
parent d1fe671704
commit a046e0eb0e
3 changed files with 256 additions and 162 deletions

View File

@@ -6,6 +6,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"errors"
"io" "io"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
@@ -18,20 +19,12 @@ const (
nonceSize = 12 nonceSize = 12
) )
// IdentityKeyPair represents a user's long-term identity key
type IdentityKeyPair struct { type IdentityKeyPair struct {
PublicKey [32]byte PublicKey [32]byte
PrivateKey [32]byte PrivateKey [32]byte
} }
// PrekeyBundle contains public keys for establishing a session // Creates a new Curve25519 key pair for identity
type PrekeyBundle struct {
IdentityKey [32]byte
Prekey [32]byte
PrekeySignature []byte
}
// generateIdentityKeyPair creates a new Curve25519 key pair for identity
func generateIdentityKeyPair() (*IdentityKeyPair, error) { func generateIdentityKeyPair() (*IdentityKeyPair, error) {
var privateKey [32]byte var privateKey [32]byte
if _, err := io.ReadFull(rand.Reader, privateKey[:]); err != nil { if _, err := io.ReadFull(rand.Reader, privateKey[:]); err != nil {
@@ -47,7 +40,7 @@ func generateIdentityKeyPair() (*IdentityKeyPair, error) {
}, nil }, nil
} }
// generateSignedPrekey creates a prekey signed with an Ed25519 signing key // Creates a prekey signed with an Ed25519 signing key
func generateSignedPrekey() ([32]byte, []byte, error) { func generateSignedPrekey() ([32]byte, []byte, error) {
var prekey [32]byte var prekey [32]byte
var prekeyPriv [32]byte var prekeyPriv [32]byte
@@ -64,23 +57,20 @@ func generateSignedPrekey() ([32]byte, []byte, error) {
// Sign the prekey // Sign the prekey
signature := ed25519.Sign(signingPriv, prekey[:]) signature := ed25519.Sign(signingPriv, prekey[:])
// In practice, we'd store signingPub to verify later. For now, we include it in signature
combinedSig := append(signingPub, signature...) combinedSig := append(signingPub, signature...)
return prekey, combinedSig, nil return prekey, combinedSig, nil
} }
// performDH does an ECDH key exchange // Does an ECDH key exchange
func performDH(privateKey, publicKey [32]byte) ([32]byte, error) { func performDH(privateKey, publicKey [32]byte) ([32]byte, error) {
var sharedSecret [32]byte var sharedSecret [32]byte
curve25519.ScalarMult(&sharedSecret, &privateKey, &publicKey) curve25519.ScalarMult(&sharedSecret, &privateKey, &publicKey)
return sharedSecret, nil return sharedSecret, nil
} }
// deriveSharedSecret combines multiple DH outputs to create a shared secret (simplified X3DH) // Combines multiple DH outputs to create a shared secret
func deriveSharedSecret(dh1, dh2, dh3 [32]byte) []byte { func deriveSharedSecret(dh1, dh2, dh3 [32]byte) []byte {
// KDF: hash all DH outputs together
h := sha256.New() h := sha256.New()
h.Write(dh1[:]) h.Write(dh1[:])
h.Write(dh2[:]) h.Write(dh2[:])
@@ -88,7 +78,7 @@ func deriveSharedSecret(dh1, dh2, dh3 [32]byte) []byte {
return h.Sum(nil) return h.Sum(nil)
} }
// encryptUserKey encrypts a user's private key with their passphrase // Encrypts a user's private key with their passphrase
func encryptUserKey(privateKey [32]byte, passphrase string) ([]byte, []byte, []byte, error) { func encryptUserKey(privateKey [32]byte, passphrase string) ([]byte, []byte, []byte, error) {
salt := make([]byte, saltSize) salt := make([]byte, saltSize)
if _, err := io.ReadFull(rand.Reader, salt); err != nil { if _, err := io.ReadFull(rand.Reader, salt); err != nil {
@@ -104,7 +94,7 @@ func encryptUserKey(privateKey [32]byte, passphrase string) ([]byte, []byte, []b
return ciphertext, nonce, salt, nil return ciphertext, nonce, salt, nil
} }
// decryptUserKey decrypts a user's private key with their passphrase // Decrypts a user's private key with their passphrase
func decryptUserKey(ciphertext, nonce, salt []byte, passphrase string) ([32]byte, error) { func decryptUserKey(ciphertext, nonce, salt []byte, passphrase string) ([32]byte, error) {
var privateKey [32]byte var privateKey [32]byte
key := deriveKey(passphrase, salt) key := deriveKey(passphrase, salt)
@@ -112,7 +102,7 @@ func decryptUserKey(ciphertext, nonce, salt []byte, passphrase string) ([32]byte
if err != nil { if err != nil {
return privateKey, err return privateKey, err
} }
copy(privateKey[:], plain) copy(privateKey[:], []byte(plain))
return privateKey, nil return privateKey, nil
} }
@@ -152,3 +142,68 @@ func decryptMsg(ciphertext, nonce, key []byte) (string, error) {
} }
return string(plain), nil return string(plain), nil
} }
// Format: EphemeralPub (32) || Nonce (12) || Ciphertext (...)
func EncryptKeyForUser(recipientPubKey [32]byte, payload []byte) ([]byte, error) {
var ephPriv, ephPub [32]byte
if _, err := io.ReadFull(rand.Reader, ephPriv[:]); err != nil {
return nil, err
}
curve25519.ScalarBaseMult(&ephPub, &ephPriv)
sharedSecret, _ := performDH(ephPriv, recipientPubKey)
kdf := sha256.Sum256(sharedSecret[:])
encryptionKey := kdf[:]
block, err := aes.NewCipher(encryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return nil, err
}
ciphertext := gcm.Seal(nil, nonce, payload, nil)
out := make([]byte, 0, 32+len(nonce)+len(ciphertext))
out = append(out, ephPub[:]...)
out = append(out, nonce...)
out = append(out, ciphertext...)
return out, nil
}
// Decrypts a blob using the user's private identity key.
func DecryptKeyForUser(myPrivKey [32]byte, blob []byte) ([]byte, error) {
if len(blob) < 32+12 {
return nil, errors.New("invalid key blob size")
}
var ephPub [32]byte
copy(ephPub[:], blob[:32])
nonce := blob[32 : 32+12]
ciphertext := blob[32+12:]
sharedSecret, _ := performDH(myPrivKey, ephPub)
kdf := sha256.Sum256(sharedSecret[:])
decryptionKey := kdf[:]
block, err := aes.NewCipher(decryptionKey)
if err != nil {
return nil, err
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
return gcm.Open(nil, nonce, ciphertext, nil)
}

View File

@@ -11,7 +11,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
const dbFile = "teachat.db" const dbFile = "encrypted_chat.db"
var db *sql.DB var db *sql.DB
@@ -37,20 +37,31 @@ func InitDB() {
name TEXT, name TEXT,
creator TEXT, creator TEXT,
is_dm BOOLEAN, is_dm BOOLEAN,
current_epoch INTEGER DEFAULT 0,
created_at DATETIME created_at DATETIME
); );
CREATE TABLE IF NOT EXISTS room_members ( CREATE TABLE IF NOT EXISTS room_members (
room_id TEXT, room_id TEXT,
username TEXT, username TEXT,
shared_secret BLOB,
joined_at DATETIME, joined_at DATETIME,
PRIMARY KEY (room_id, username), PRIMARY KEY (room_id, username),
FOREIGN KEY (room_id) REFERENCES rooms(id), FOREIGN KEY (room_id) REFERENCES rooms(id),
FOREIGN KEY (username) REFERENCES users(username) FOREIGN KEY (username) REFERENCES users(username)
); );
-- Stores the encrypted room key for a specific user for a specific epoch
CREATE TABLE IF NOT EXISTS user_room_keys (
username TEXT,
room_id TEXT,
epoch INTEGER,
encrypted_key BLOB,
PRIMARY KEY (username, room_id, epoch),
FOREIGN KEY (username) REFERENCES users(username),
FOREIGN KEY (room_id) REFERENCES rooms(id)
);
CREATE TABLE IF NOT EXISTS messages ( CREATE TABLE IF NOT EXISTS messages (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
room_id TEXT, room_id TEXT,
epoch INTEGER,
timestamp DATETIME, timestamp DATETIME,
sender TEXT, sender TEXT,
ciphertext BLOB, ciphertext BLOB,
@@ -61,9 +72,6 @@ func InitDB() {
if _, err := db.Exec(query); err != nil { if _, err := db.Exec(query); err != nil {
log.Fatal(err) log.Fatal(err)
} }
// Migration for older DBs
db.Exec("ALTER TABLE rooms ADD COLUMN creator TEXT")
} }
func ensureSalt() []byte { func ensureSalt() []byte {
@@ -101,26 +109,19 @@ func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte,
return return
} }
func listUsers() ([]string, error) { func getMemberIdentityKey(username string) ([32]byte, error) {
rows, err := db.Query("SELECT username FROM users ORDER BY username ASC") var b []byte
var key [32]byte
err := db.QueryRow("SELECT identity_key FROM users WHERE username = ?", username).Scan(&b)
if err != nil { if err != nil {
return nil, err return key, err
} }
defer rows.Close() copy(key[:], b)
return key, nil
var users []string
for rows.Next() {
var username string
if err := rows.Scan(&username); err != nil {
continue
}
users = append(users, username)
}
return users, nil
} }
func createRoom(roomID, name, creator string, isDM bool) error { func createRoom(roomID, name, creator string, isDM bool) error {
_, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, created_at) VALUES (?, ?, ?, ?, ?)", _, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, current_epoch, created_at) VALUES (?, ?, ?, ?, 0, ?)",
roomID, name, creator, isDM, time.Now()) roomID, name, creator, isDM, time.Now())
return err return err
} }
@@ -131,12 +132,17 @@ func getRoomIDByName(name string) (string, error) {
return id, err return id, err
} }
// getAnyRoomSecret gets the shared secret from ANY member of the room. func getRoomCurrentEpoch(roomID string) (int, error) {
// In a real decentralized system this wouldn't be possible, but here the server acts as the key distributor. var epoch int
func getAnyRoomSecret(roomID string) ([]byte, error) { err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch)
var secret []byte return epoch, err
err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? LIMIT 1", roomID).Scan(&secret) }
return secret, err
func incrementRoomEpoch(roomID string) (int, error) {
var newEpoch int
// Atomic increment
err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch)
return newEpoch, err
} }
func deleteRoom(roomID string) error { func deleteRoom(roomID string) error {
@@ -148,6 +154,10 @@ func deleteRoom(roomID string) error {
tx.Rollback() tx.Rollback()
return err return err
} }
if _, err := tx.Exec("DELETE FROM user_room_keys WHERE room_id = ?", roomID); err != nil {
tx.Rollback()
return err
}
if _, err := tx.Exec("DELETE FROM room_members WHERE room_id = ?", roomID); err != nil { if _, err := tx.Exec("DELETE FROM room_members WHERE room_id = ?", roomID); err != nil {
tx.Rollback() tx.Rollback()
return err return err
@@ -164,22 +174,48 @@ func leaveRoom(roomID, username string) error {
return err return err
} }
func joinRoom(roomID, username string, sharedSecret []byte) error { func joinRoomMember(roomID, username string) error {
// Check if already joined // Check if already joined
var count int var count int
db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count) db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count)
if count > 0 { if count > 0 {
return nil return nil
} }
_, err := db.Exec("INSERT INTO room_members (room_id, username, shared_secret, joined_at) VALUES (?, ?, ?, ?)", _, err := db.Exec("INSERT INTO room_members (room_id, username, joined_at) VALUES (?, ?, ?)",
roomID, username, sharedSecret, time.Now()) roomID, username, time.Now())
return err return err
} }
func getMemberJoinedAt(roomID, username string) (time.Time, error) { func getRoomMembers(roomID string) ([]string, error) {
var joinedAt time.Time rows, err := db.Query("SELECT username FROM room_members WHERE room_id = ?", roomID)
err := db.QueryRow("SELECT joined_at FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&joinedAt) if err != nil {
return joinedAt, err return nil, err
}
defer rows.Close()
var users []string
for rows.Next() {
var u string
if err := rows.Scan(&u); err == nil {
users = append(users, u)
}
}
return users, nil
}
func storeUserRoomKey(username, roomID string, epoch int, encryptedKey []byte) error {
_, err := db.Exec(`
INSERT INTO user_room_keys (username, room_id, epoch, encrypted_key)
VALUES (?, ?, ?, ?)
ON CONFLICT(username, room_id, epoch) DO UPDATE SET encrypted_key = excluded.encrypted_key`,
username, roomID, epoch, encryptedKey)
return err
}
func getUserRoomKey(username, roomID string, epoch int) ([]byte, error) {
var key []byte
err := db.QueryRow("SELECT encrypted_key FROM user_room_keys WHERE username = ? AND room_id = ? AND epoch = ?",
username, roomID, epoch).Scan(&key)
return key, err
} }
func listUserRooms(username string) ([]struct { func listUserRooms(username string) ([]struct {
@@ -220,31 +256,27 @@ func listUserRooms(username string) ([]struct {
return rooms, nil return rooms, nil
} }
func getRoomSharedSecret(roomID, username string) ([]byte, error) { func saveMessage(roomID, sender string, epoch int, ciphertext, nonce []byte) error {
var secret []byte _, err := db.Exec("INSERT INTO messages (room_id, epoch, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?, ?)",
err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&secret) roomID, epoch, time.Now(), sender, ciphertext, nonce)
return secret, err
}
func saveMessage(roomID, sender string, ciphertext, nonce []byte) error {
_, err := db.Exec("INSERT INTO messages (room_id, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?)",
roomID, time.Now(), sender, ciphertext, nonce)
return err return err
} }
func loadMessages(roomID string) ([]struct { func loadMessages(roomID string) ([]struct {
Epoch int
Timestamp time.Time Timestamp time.Time
Sender string Sender string
Ciphertext []byte Ciphertext []byte
Nonce []byte Nonce []byte
}, error) { }, error) {
rows, err := db.Query("SELECT timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID) rows, err := db.Query("SELECT epoch, timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
var messages []struct { var messages []struct {
Epoch int
Timestamp time.Time Timestamp time.Time
Sender string Sender string
Ciphertext []byte Ciphertext []byte
@@ -253,12 +285,13 @@ func loadMessages(roomID string) ([]struct {
for rows.Next() { for rows.Next() {
var msg struct { var msg struct {
Epoch int
Timestamp time.Time Timestamp time.Time
Sender string Sender string
Ciphertext []byte Ciphertext []byte
Nonce []byte Nonce []byte
} }
if err := rows.Scan(&msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil { if err := rows.Scan(&msg.Epoch, &msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil {
continue continue
} }
messages = append(messages, msg) messages = append(messages, msg)

View File

@@ -17,13 +17,13 @@ import (
) )
var ( var (
msgStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).PaddingLeft(1) msgStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).PaddingLeft(1)
senderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("205")).Bold(true) senderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("205")).Bold(true)
errStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true) errStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true)
sysStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Italic(true) sysStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Italic(true)
roomStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("86")).Bold(true) roomStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("86")).Bold(true)
dmStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("141")).Bold(true) dmStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("141")).Bold(true)
noteStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("228")).Bold(true) noteStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("228")).Bold(true)
commandStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Italic(true) commandStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Italic(true)
timeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("246")) timeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("246"))
lockedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Italic(true) lockedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Italic(true)
@@ -49,16 +49,17 @@ type model struct {
privKeySalt []byte privKeySalt []byte
privKeyNonce []byte privKeyNonce []byte
currentRoomID string currentRoomID string
currentRoomKey []byte currentRoomEpoch int
currentRoomJoinedAt time.Time // We cache keys: Epoch -> Plaintext Key
rooms []struct { roomKeyCache map[int][]byte
rooms []struct {
ID string ID string
Name string Name string
Creator string Creator string
IsDM bool IsDM bool
} }
availableUsers []string
err error err error
@@ -117,6 +118,7 @@ func initialModel(s ssh.Session) model {
input: ti, input: ti,
viewport: vp, viewport: vp,
updateChan: updates.subscribe(), updateChan: updates.subscribe(),
roomKeyCache: make(map[int][]byte),
} }
} }
@@ -225,6 +227,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
if val == "/leave" { if val == "/leave" {
leaveRoom(m.currentRoomID, m.username) leaveRoom(m.currentRoomID, m.username)
m.rotateRoomKey(m.currentRoomID) // Rotate key on leave
m.exitChat() m.exitChat()
return m, tea.ClearScreen return m, tea.ClearScreen
} }
@@ -269,8 +272,9 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
func (m *model) exitChat() { func (m *model) exitChat() {
m.state = 1 m.state = 1
m.currentRoomID = "" m.currentRoomID = ""
m.currentRoomKey = nil m.currentRoomEpoch = 0
m.messages = nil m.messages = nil
m.roomKeyCache = make(map[int][]byte)
m.viewport.SetContent("") m.viewport.SetContent("")
m.loadRooms() m.loadRooms()
m.input.Reset() m.input.Reset()
@@ -299,7 +303,6 @@ func (m *model) handleRoomListInput(text string) {
m.handleJoinRoom(roomName) m.handleJoinRoom(roomName)
return return
} }
// Keep /new as an alias for /join
if strings.HasPrefix(text, "/new ") { if strings.HasPrefix(text, "/new ") {
roomName := strings.TrimPrefix(text, "/new ") roomName := strings.TrimPrefix(text, "/new ")
m.handleJoinRoom(roomName) m.handleJoinRoom(roomName)
@@ -319,81 +322,74 @@ func (m *model) handleRoomListInput(text string) {
} }
func (m *model) handleJoinRoom(roomName string) { func (m *model) handleJoinRoom(roomName string) {
// Check if room exists
existingID, err := getRoomIDByName(roomName) existingID, err := getRoomIDByName(roomName)
// Room Exists: Join it
if err == nil && existingID != "" { if err == nil && existingID != "" {
// Need key // Join existing
secret, err := getAnyRoomSecret(existingID) if err := joinRoomMember(existingID, m.username); err != nil {
if err != nil {
m.err = fmt.Errorf("could not join: room key unreachable")
return
}
if err := joinRoom(existingID, m.username, secret); err != nil {
m.err = err m.err = err
return return
} }
// ROTATE KEY so new user gets a key, but doesn't get old keys
m.rotateRoomKey(existingID)
m.enterRoom(existingID, roomName) m.enterRoom(existingID, roomName)
return return
} }
// Room Does Not Exist: Create it // Create new
roomID := generateRoomID() roomID := generateRoomID()
if err := createRoom(roomID, roomName, m.username, false); err != nil { if err := createRoom(roomID, roomName, m.username, false); err != nil {
m.err = err m.err = err
return return
} }
if err := joinRoomMember(roomID, m.username); err != nil {
sharedSecret := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, sharedSecret); err != nil {
m.err = err m.err = err
return return
} }
// Initial Key
if err := joinRoom(roomID, m.username, sharedSecret); err != nil { m.rotateRoomKey(roomID)
m.err = err
return
}
m.enterRoom(roomID, roomName) m.enterRoom(roomID, roomName)
} }
func (m *model) rotateRoomKey(roomID string) {
newKey := make([]byte, 32)
io.ReadFull(rand.Reader, newKey)
newEpoch, err := incrementRoomEpoch(roomID)
if err != nil {
m.err = err
return
}
members, err := getRoomMembers(roomID)
if err != nil {
m.err = err
return
}
for _, user := range members {
idKey, err := getMemberIdentityKey(user)
if err != nil {
continue
}
encrypted, err := EncryptKeyForUser(idKey, newKey)
if err != nil {
continue
}
storeUserRoomKey(user, roomID, newEpoch, encrypted)
}
}
func (m *model) handleSelectUserForDM(username string) { func (m *model) handleSelectUserForDM(username string) {
if username == "/cancel" { if username == "/cancel" {
m.state = 1 m.state = 1
m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>, /list" m.input.Placeholder = "Enter room # or: /join <name>, /dm <username>, /list"
return return
} }
if username == m.username { if username == m.username {
m.createNoteToSelf() m.createNoteToSelf()
return return
} }
identityKey, prekey, prekeySignature, err := getUserKeys(username)
if err != nil {
m.err = fmt.Errorf("user not found: %w", err)
return
}
var theirIdentityKey, theirPrekey [32]byte
copy(theirIdentityKey[:], identityKey)
copy(theirPrekey[:], prekey)
ephemeralKey, err := generateIdentityKeyPair()
if err != nil {
m.err = err
return
}
dh1, _ := performDH(m.identityKey.PrivateKey, theirPrekey)
dh2, _ := performDH(ephemeralKey.PrivateKey, theirIdentityKey)
dh3, _ := performDH(ephemeralKey.PrivateKey, theirPrekey)
sharedSecret := deriveSharedSecret(dh1, dh2, dh3)
roomID := generateRoomID() roomID := generateRoomID()
roomName := fmt.Sprintf("DM: %s <-> %s", m.username, username) roomName := fmt.Sprintf("DM: %s <-> %s", m.username, username)
@@ -401,56 +397,31 @@ func (m *model) handleSelectUserForDM(username string) {
m.err = err m.err = err
return return
} }
joinRoomMember(roomID, m.username)
joinRoomMember(roomID, username)
if err := joinRoom(roomID, m.username, sharedSecret); err != nil { m.rotateRoomKey(roomID)
m.err = err
return
}
m.enterRoom(roomID, roomName) m.enterRoom(roomID, roomName)
_ = prekeySignature
} }
func (m *model) createNoteToSelf() { func (m *model) createNoteToSelf() {
roomID := generateRoomID() roomID := generateRoomID()
roomName := "[Note to Self]" roomName := "[Note to Self]"
sharedSecret := make([]byte, 32)
if _, err := io.ReadFull(rand.Reader, sharedSecret); err != nil {
m.err = err
return
}
if err := createRoom(roomID, roomName, m.username, true); err != nil { if err := createRoom(roomID, roomName, m.username, true); err != nil {
m.err = err m.err = err
return return
} }
joinRoomMember(roomID, m.username)
if err := joinRoom(roomID, m.username, sharedSecret); err != nil { m.rotateRoomKey(roomID)
m.err = err
return
}
m.enterRoom(roomID, roomName) m.enterRoom(roomID, roomName)
} }
func (m *model) enterRoom(roomID, roomName string) { func (m *model) enterRoom(roomID, roomName string) {
secret, err := getRoomSharedSecret(roomID, m.username) epoch, _ := getRoomCurrentEpoch(roomID)
if err != nil {
m.err = err
return
}
joinedAt, err := getMemberJoinedAt(roomID, m.username)
if err != nil {
joinedAt = time.Now() // Fallback
}
m.currentRoomID = roomID m.currentRoomID = roomID
m.currentRoomKey = secret m.currentRoomEpoch = epoch
m.currentRoomJoinedAt = joinedAt
m.state = 2 m.state = 2
m.input.Placeholder = fmt.Sprintf("[%s] /back to menu, /leave to quit room", roomName) m.input.Placeholder = fmt.Sprintf("[%s] /back, /leave, /delete", roomName)
m.loadMessages() m.loadMessages()
} }
@@ -470,14 +441,42 @@ func generateRoomID() string {
return base64.URLEncoding.EncodeToString(b) return base64.URLEncoding.EncodeToString(b)
} }
func (m *model) getEpochKey(epoch int) []byte {
// Check cache
if key, ok := m.roomKeyCache[epoch]; ok {
return key
}
encKey, err := getUserRoomKey(m.username, m.currentRoomID, epoch)
if err != nil || encKey == nil {
return nil
}
key, err := DecryptKeyForUser(m.identityKey.PrivateKey, encKey)
if err != nil {
return nil
}
m.roomKeyCache[epoch] = key
return key
}
func (m *model) saveMessage(sender, text string) { func (m *model) saveMessage(sender, text string) {
ct, nonce, err := encryptMsg(text, m.currentRoomKey) m.currentRoomEpoch, _ = getRoomCurrentEpoch(m.currentRoomID)
key := m.getEpochKey(m.currentRoomEpoch)
if key == nil {
m.err = fmt.Errorf("no key for current epoch")
return
}
ct, nonce, err := encryptMsg(text, key)
if err != nil { if err != nil {
m.err = err m.err = err
return return
} }
err = saveMessage(m.currentRoomID, sender, ct, nonce) err = saveMessage(m.currentRoomID, sender, m.currentRoomEpoch, ct, nonce)
if err != nil { if err != nil {
m.err = err m.err = err
return return
@@ -499,10 +498,19 @@ func (m *model) loadMessages() {
var msgs []chatMsg var msgs []chatMsg
for _, dbMsg := range dbMessages { for _, dbMsg := range dbMessages {
plain, err := decryptMsg(dbMsg.Ciphertext, dbMsg.Nonce, m.currentRoomKey) key := m.getEpochKey(dbMsg.Epoch)
if err != nil { var plain string
plain = "[Decryption Failed]" if key == nil {
plain = "[Unreadable: Old History or Key Rotated]"
} else {
p, err := decryptMsg(dbMsg.Ciphertext, dbMsg.Nonce, key)
if err != nil {
plain = "[Decryption Failed]"
} else {
plain = p
}
} }
msgs = append(msgs, chatMsg{ msgs = append(msgs, chatMsg{
Timestamp: dbMsg.Timestamp, Timestamp: dbMsg.Timestamp,
Sender: dbMsg.Sender, Sender: dbMsg.Sender,
@@ -518,12 +526,10 @@ func (m *model) updateViewport() {
for _, msg := range m.messages { for _, msg := range m.messages {
dateStr := msg.Timestamp.Format("Jan 2 15:04") dateStr := msg.Timestamp.Format("Jan 2 15:04")
// Mask history before join time if strings.Contains(msg.Content, "[Unreadable") {
// Add 1 second buffer to avoid hiding immediate first messages
if msg.Timestamp.Before(m.currentRoomJoinedAt.Add(-1 * time.Second)) {
b.WriteString(fmt.Sprintf("%s %s\n", b.WriteString(fmt.Sprintf("%s %s\n",
timeStyle.Render(dateStr), timeStyle.Render(dateStr),
lockedStyle.Render("[Encrypted message - History hidden]"), lockedStyle.Render(msg.Content),
)) ))
} else { } else {
b.WriteString(fmt.Sprintf("%s %s: %s\n", b.WriteString(fmt.Sprintf("%s %s: %s\n",