diff --git a/internal/crypto.go b/internal/crypto.go index 8412522..fe75718 100644 --- a/internal/crypto.go +++ b/internal/crypto.go @@ -6,6 +6,7 @@ import ( "crypto/ed25519" "crypto/rand" "crypto/sha256" + "errors" "io" "golang.org/x/crypto/argon2" @@ -18,20 +19,12 @@ const ( nonceSize = 12 ) -// IdentityKeyPair represents a user's long-term identity key type IdentityKeyPair struct { PublicKey [32]byte PrivateKey [32]byte } -// PrekeyBundle contains public keys for establishing a session -type PrekeyBundle struct { - IdentityKey [32]byte - Prekey [32]byte - PrekeySignature []byte -} - -// generateIdentityKeyPair creates a new Curve25519 key pair for identity +// Creates a new Curve25519 key pair for identity func generateIdentityKeyPair() (*IdentityKeyPair, error) { var privateKey [32]byte if _, err := io.ReadFull(rand.Reader, privateKey[:]); err != nil { @@ -47,7 +40,7 @@ func generateIdentityKeyPair() (*IdentityKeyPair, error) { }, nil } -// generateSignedPrekey creates a prekey signed with an Ed25519 signing key +// Creates a prekey signed with an Ed25519 signing key func generateSignedPrekey() ([32]byte, []byte, error) { var prekey [32]byte var prekeyPriv [32]byte @@ -64,23 +57,20 @@ func generateSignedPrekey() ([32]byte, []byte, error) { // Sign the prekey signature := ed25519.Sign(signingPriv, prekey[:]) - - // In practice, we'd store signingPub to verify later. For now, we include it in signature combinedSig := append(signingPub, signature...) return prekey, combinedSig, nil } -// performDH does an ECDH key exchange +// Does an ECDH key exchange func performDH(privateKey, publicKey [32]byte) ([32]byte, error) { var sharedSecret [32]byte curve25519.ScalarMult(&sharedSecret, &privateKey, &publicKey) return sharedSecret, nil } -// deriveSharedSecret combines multiple DH outputs to create a shared secret (simplified X3DH) +// Combines multiple DH outputs to create a shared secret func deriveSharedSecret(dh1, dh2, dh3 [32]byte) []byte { - // KDF: hash all DH outputs together h := sha256.New() h.Write(dh1[:]) h.Write(dh2[:]) @@ -88,7 +78,7 @@ func deriveSharedSecret(dh1, dh2, dh3 [32]byte) []byte { return h.Sum(nil) } -// encryptUserKey encrypts a user's private key with their passphrase +// Encrypts a user's private key with their passphrase func encryptUserKey(privateKey [32]byte, passphrase string) ([]byte, []byte, []byte, error) { salt := make([]byte, saltSize) if _, err := io.ReadFull(rand.Reader, salt); err != nil { @@ -104,7 +94,7 @@ func encryptUserKey(privateKey [32]byte, passphrase string) ([]byte, []byte, []b return ciphertext, nonce, salt, nil } -// decryptUserKey decrypts a user's private key with their passphrase +// Decrypts a user's private key with their passphrase func decryptUserKey(ciphertext, nonce, salt []byte, passphrase string) ([32]byte, error) { var privateKey [32]byte key := deriveKey(passphrase, salt) @@ -112,7 +102,7 @@ func decryptUserKey(ciphertext, nonce, salt []byte, passphrase string) ([32]byte if err != nil { return privateKey, err } - copy(privateKey[:], plain) + copy(privateKey[:], []byte(plain)) return privateKey, nil } @@ -152,3 +142,68 @@ func decryptMsg(ciphertext, nonce, key []byte) (string, error) { } return string(plain), nil } + +// Format: EphemeralPub (32) || Nonce (12) || Ciphertext (...) +func EncryptKeyForUser(recipientPubKey [32]byte, payload []byte) ([]byte, error) { + var ephPriv, ephPub [32]byte + if _, err := io.ReadFull(rand.Reader, ephPriv[:]); err != nil { + return nil, err + } + curve25519.ScalarBaseMult(&ephPub, &ephPriv) + + sharedSecret, _ := performDH(ephPriv, recipientPubKey) + + kdf := sha256.Sum256(sharedSecret[:]) + encryptionKey := kdf[:] + + block, err := aes.NewCipher(encryptionKey) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := io.ReadFull(rand.Reader, nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nonce, payload, nil) + + out := make([]byte, 0, 32+len(nonce)+len(ciphertext)) + out = append(out, ephPub[:]...) + out = append(out, nonce...) + out = append(out, ciphertext...) + + return out, nil +} + +// Decrypts a blob using the user's private identity key. +func DecryptKeyForUser(myPrivKey [32]byte, blob []byte) ([]byte, error) { + if len(blob) < 32+12 { + return nil, errors.New("invalid key blob size") + } + + var ephPub [32]byte + copy(ephPub[:], blob[:32]) + nonce := blob[32 : 32+12] + ciphertext := blob[32+12:] + + sharedSecret, _ := performDH(myPrivKey, ephPub) + + kdf := sha256.Sum256(sharedSecret[:]) + decryptionKey := kdf[:] + + block, err := aes.NewCipher(decryptionKey) + if err != nil { + return nil, err + } + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + return gcm.Open(nil, nonce, ciphertext, nil) +} diff --git a/internal/database.go b/internal/database.go index 5756eef..da3e363 100644 --- a/internal/database.go +++ b/internal/database.go @@ -11,7 +11,7 @@ import ( _ "github.com/mattn/go-sqlite3" ) -const dbFile = "teachat.db" +const dbFile = "encrypted_chat.db" var db *sql.DB @@ -37,20 +37,31 @@ func InitDB() { name TEXT, creator TEXT, is_dm BOOLEAN, + current_epoch INTEGER DEFAULT 0, created_at DATETIME ); CREATE TABLE IF NOT EXISTS room_members ( room_id TEXT, username TEXT, - shared_secret BLOB, joined_at DATETIME, PRIMARY KEY (room_id, username), FOREIGN KEY (room_id) REFERENCES rooms(id), FOREIGN KEY (username) REFERENCES users(username) ); + -- Stores the encrypted room key for a specific user for a specific epoch + CREATE TABLE IF NOT EXISTS user_room_keys ( + username TEXT, + room_id TEXT, + epoch INTEGER, + encrypted_key BLOB, + PRIMARY KEY (username, room_id, epoch), + FOREIGN KEY (username) REFERENCES users(username), + FOREIGN KEY (room_id) REFERENCES rooms(id) + ); CREATE TABLE IF NOT EXISTS messages ( id INTEGER PRIMARY KEY, room_id TEXT, + epoch INTEGER, timestamp DATETIME, sender TEXT, ciphertext BLOB, @@ -61,9 +72,6 @@ func InitDB() { if _, err := db.Exec(query); err != nil { log.Fatal(err) } - - // Migration for older DBs - db.Exec("ALTER TABLE rooms ADD COLUMN creator TEXT") } func ensureSalt() []byte { @@ -101,26 +109,19 @@ func getUserKeys(username string) (identityKey, prekey, prekeySignature []byte, return } -func listUsers() ([]string, error) { - rows, err := db.Query("SELECT username FROM users ORDER BY username ASC") +func getMemberIdentityKey(username string) ([32]byte, error) { + var b []byte + var key [32]byte + err := db.QueryRow("SELECT identity_key FROM users WHERE username = ?", username).Scan(&b) if err != nil { - return nil, err + return key, err } - defer rows.Close() - - var users []string - for rows.Next() { - var username string - if err := rows.Scan(&username); err != nil { - continue - } - users = append(users, username) - } - return users, nil + copy(key[:], b) + return key, nil } func createRoom(roomID, name, creator string, isDM bool) error { - _, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, created_at) VALUES (?, ?, ?, ?, ?)", + _, err := db.Exec("INSERT INTO rooms (id, name, creator, is_dm, current_epoch, created_at) VALUES (?, ?, ?, ?, 0, ?)", roomID, name, creator, isDM, time.Now()) return err } @@ -131,12 +132,17 @@ func getRoomIDByName(name string) (string, error) { return id, err } -// getAnyRoomSecret gets the shared secret from ANY member of the room. -// In a real decentralized system this wouldn't be possible, but here the server acts as the key distributor. -func getAnyRoomSecret(roomID string) ([]byte, error) { - var secret []byte - err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? LIMIT 1", roomID).Scan(&secret) - return secret, err +func getRoomCurrentEpoch(roomID string) (int, error) { + var epoch int + err := db.QueryRow("SELECT current_epoch FROM rooms WHERE id = ?", roomID).Scan(&epoch) + return epoch, err +} + +func incrementRoomEpoch(roomID string) (int, error) { + var newEpoch int + // Atomic increment + err := db.QueryRow("UPDATE rooms SET current_epoch = current_epoch + 1 WHERE id = ? RETURNING current_epoch", roomID).Scan(&newEpoch) + return newEpoch, err } func deleteRoom(roomID string) error { @@ -148,6 +154,10 @@ func deleteRoom(roomID string) error { tx.Rollback() return err } + if _, err := tx.Exec("DELETE FROM user_room_keys WHERE room_id = ?", roomID); err != nil { + tx.Rollback() + return err + } if _, err := tx.Exec("DELETE FROM room_members WHERE room_id = ?", roomID); err != nil { tx.Rollback() return err @@ -164,22 +174,48 @@ func leaveRoom(roomID, username string) error { return err } -func joinRoom(roomID, username string, sharedSecret []byte) error { +func joinRoomMember(roomID, username string) error { // Check if already joined var count int db.QueryRow("SELECT COUNT(*) FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&count) if count > 0 { return nil } - _, err := db.Exec("INSERT INTO room_members (room_id, username, shared_secret, joined_at) VALUES (?, ?, ?, ?)", - roomID, username, sharedSecret, time.Now()) + _, err := db.Exec("INSERT INTO room_members (room_id, username, joined_at) VALUES (?, ?, ?)", + roomID, username, time.Now()) return err } -func getMemberJoinedAt(roomID, username string) (time.Time, error) { - var joinedAt time.Time - err := db.QueryRow("SELECT joined_at FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&joinedAt) - return joinedAt, err +func getRoomMembers(roomID string) ([]string, error) { + rows, err := db.Query("SELECT username FROM room_members WHERE room_id = ?", roomID) + if err != nil { + return nil, err + } + defer rows.Close() + var users []string + for rows.Next() { + var u string + if err := rows.Scan(&u); err == nil { + users = append(users, u) + } + } + return users, nil +} + +func storeUserRoomKey(username, roomID string, epoch int, encryptedKey []byte) error { + _, err := db.Exec(` + INSERT INTO user_room_keys (username, room_id, epoch, encrypted_key) + VALUES (?, ?, ?, ?) + ON CONFLICT(username, room_id, epoch) DO UPDATE SET encrypted_key = excluded.encrypted_key`, + username, roomID, epoch, encryptedKey) + return err +} + +func getUserRoomKey(username, roomID string, epoch int) ([]byte, error) { + var key []byte + err := db.QueryRow("SELECT encrypted_key FROM user_room_keys WHERE username = ? AND room_id = ? AND epoch = ?", + username, roomID, epoch).Scan(&key) + return key, err } func listUserRooms(username string) ([]struct { @@ -220,31 +256,27 @@ func listUserRooms(username string) ([]struct { return rooms, nil } -func getRoomSharedSecret(roomID, username string) ([]byte, error) { - var secret []byte - err := db.QueryRow("SELECT shared_secret FROM room_members WHERE room_id = ? AND username = ?", roomID, username).Scan(&secret) - return secret, err -} - -func saveMessage(roomID, sender string, ciphertext, nonce []byte) error { - _, err := db.Exec("INSERT INTO messages (room_id, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?)", - roomID, time.Now(), sender, ciphertext, nonce) +func saveMessage(roomID, sender string, epoch int, ciphertext, nonce []byte) error { + _, err := db.Exec("INSERT INTO messages (room_id, epoch, timestamp, sender, ciphertext, nonce) VALUES (?, ?, ?, ?, ?, ?)", + roomID, epoch, time.Now(), sender, ciphertext, nonce) return err } func loadMessages(roomID string) ([]struct { + Epoch int Timestamp time.Time Sender string Ciphertext []byte Nonce []byte }, error) { - rows, err := db.Query("SELECT timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID) + rows, err := db.Query("SELECT epoch, timestamp, sender, ciphertext, nonce FROM messages WHERE room_id = ? ORDER BY timestamp ASC", roomID) if err != nil { return nil, err } defer rows.Close() var messages []struct { + Epoch int Timestamp time.Time Sender string Ciphertext []byte @@ -253,12 +285,13 @@ func loadMessages(roomID string) ([]struct { for rows.Next() { var msg struct { + Epoch int Timestamp time.Time Sender string Ciphertext []byte Nonce []byte } - if err := rows.Scan(&msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil { + if err := rows.Scan(&msg.Epoch, &msg.Timestamp, &msg.Sender, &msg.Ciphertext, &msg.Nonce); err != nil { continue } messages = append(messages, msg) diff --git a/internal/model.go b/internal/model.go index 87ebb85..af1d85f 100644 --- a/internal/model.go +++ b/internal/model.go @@ -17,13 +17,13 @@ import ( ) var ( - msgStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).PaddingLeft(1) - senderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("205")).Bold(true) - errStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true) - sysStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Italic(true) - roomStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("86")).Bold(true) - dmStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("141")).Bold(true) - noteStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("228")).Bold(true) + msgStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("252")).PaddingLeft(1) + senderStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("205")).Bold(true) + errStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("196")).Bold(true) + sysStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("117")).Italic(true) + roomStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("86")).Bold(true) + dmStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("141")).Bold(true) + noteStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("228")).Bold(true) commandStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("39")).Italic(true) timeStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("246")) lockedStyle = lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Italic(true) @@ -49,16 +49,17 @@ type model struct { privKeySalt []byte privKeyNonce []byte - currentRoomID string - currentRoomKey []byte - currentRoomJoinedAt time.Time - rooms []struct { + currentRoomID string + currentRoomEpoch int + // We cache keys: Epoch -> Plaintext Key + roomKeyCache map[int][]byte + + rooms []struct { ID string Name string Creator string IsDM bool } - availableUsers []string err error @@ -117,6 +118,7 @@ func initialModel(s ssh.Session) model { input: ti, viewport: vp, updateChan: updates.subscribe(), + roomKeyCache: make(map[int][]byte), } } @@ -225,6 +227,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if val == "/leave" { leaveRoom(m.currentRoomID, m.username) + m.rotateRoomKey(m.currentRoomID) // Rotate key on leave m.exitChat() return m, tea.ClearScreen } @@ -269,8 +272,9 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { func (m *model) exitChat() { m.state = 1 m.currentRoomID = "" - m.currentRoomKey = nil + m.currentRoomEpoch = 0 m.messages = nil + m.roomKeyCache = make(map[int][]byte) m.viewport.SetContent("") m.loadRooms() m.input.Reset() @@ -299,7 +303,6 @@ func (m *model) handleRoomListInput(text string) { m.handleJoinRoom(roomName) return } - // Keep /new as an alias for /join if strings.HasPrefix(text, "/new ") { roomName := strings.TrimPrefix(text, "/new ") m.handleJoinRoom(roomName) @@ -319,81 +322,74 @@ func (m *model) handleRoomListInput(text string) { } func (m *model) handleJoinRoom(roomName string) { - // Check if room exists existingID, err := getRoomIDByName(roomName) - - // Room Exists: Join it if err == nil && existingID != "" { - // Need key - secret, err := getAnyRoomSecret(existingID) - if err != nil { - m.err = fmt.Errorf("could not join: room key unreachable") - return - } - - if err := joinRoom(existingID, m.username, secret); err != nil { + // Join existing + if err := joinRoomMember(existingID, m.username); err != nil { m.err = err return } + // ROTATE KEY so new user gets a key, but doesn't get old keys + m.rotateRoomKey(existingID) m.enterRoom(existingID, roomName) return } - // Room Does Not Exist: Create it + // Create new roomID := generateRoomID() if err := createRoom(roomID, roomName, m.username, false); err != nil { m.err = err return } - - sharedSecret := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, sharedSecret); err != nil { + if err := joinRoomMember(roomID, m.username); err != nil { m.err = err return } - - if err := joinRoom(roomID, m.username, sharedSecret); err != nil { - m.err = err - return - } - + // Initial Key + m.rotateRoomKey(roomID) m.enterRoom(roomID, roomName) } +func (m *model) rotateRoomKey(roomID string) { + newKey := make([]byte, 32) + io.ReadFull(rand.Reader, newKey) + + newEpoch, err := incrementRoomEpoch(roomID) + if err != nil { + m.err = err + return + } + + members, err := getRoomMembers(roomID) + if err != nil { + m.err = err + return + } + + for _, user := range members { + idKey, err := getMemberIdentityKey(user) + if err != nil { + continue + } + encrypted, err := EncryptKeyForUser(idKey, newKey) + if err != nil { + continue + } + storeUserRoomKey(user, roomID, newEpoch, encrypted) + } +} + func (m *model) handleSelectUserForDM(username string) { if username == "/cancel" { m.state = 1 m.input.Placeholder = "Enter room # or: /join , /dm , /list" return } - if username == m.username { m.createNoteToSelf() return } - identityKey, prekey, prekeySignature, err := getUserKeys(username) - if err != nil { - m.err = fmt.Errorf("user not found: %w", err) - return - } - - var theirIdentityKey, theirPrekey [32]byte - copy(theirIdentityKey[:], identityKey) - copy(theirPrekey[:], prekey) - - ephemeralKey, err := generateIdentityKeyPair() - if err != nil { - m.err = err - return - } - - dh1, _ := performDH(m.identityKey.PrivateKey, theirPrekey) - dh2, _ := performDH(ephemeralKey.PrivateKey, theirIdentityKey) - dh3, _ := performDH(ephemeralKey.PrivateKey, theirPrekey) - - sharedSecret := deriveSharedSecret(dh1, dh2, dh3) - roomID := generateRoomID() roomName := fmt.Sprintf("DM: %s <-> %s", m.username, username) @@ -401,56 +397,31 @@ func (m *model) handleSelectUserForDM(username string) { m.err = err return } + joinRoomMember(roomID, m.username) + joinRoomMember(roomID, username) - if err := joinRoom(roomID, m.username, sharedSecret); err != nil { - m.err = err - return - } - + m.rotateRoomKey(roomID) m.enterRoom(roomID, roomName) - _ = prekeySignature } func (m *model) createNoteToSelf() { roomID := generateRoomID() roomName := "[Note to Self]" - - sharedSecret := make([]byte, 32) - if _, err := io.ReadFull(rand.Reader, sharedSecret); err != nil { - m.err = err - return - } - if err := createRoom(roomID, roomName, m.username, true); err != nil { m.err = err return } - - if err := joinRoom(roomID, m.username, sharedSecret); err != nil { - m.err = err - return - } - + joinRoomMember(roomID, m.username) + m.rotateRoomKey(roomID) m.enterRoom(roomID, roomName) } func (m *model) enterRoom(roomID, roomName string) { - secret, err := getRoomSharedSecret(roomID, m.username) - if err != nil { - m.err = err - return - } - - joinedAt, err := getMemberJoinedAt(roomID, m.username) - if err != nil { - joinedAt = time.Now() // Fallback - } - + epoch, _ := getRoomCurrentEpoch(roomID) m.currentRoomID = roomID - m.currentRoomKey = secret - m.currentRoomJoinedAt = joinedAt + m.currentRoomEpoch = epoch m.state = 2 - m.input.Placeholder = fmt.Sprintf("[%s] /back to menu, /leave to quit room", roomName) + m.input.Placeholder = fmt.Sprintf("[%s] /back, /leave, /delete", roomName) m.loadMessages() } @@ -470,14 +441,42 @@ func generateRoomID() string { return base64.URLEncoding.EncodeToString(b) } +func (m *model) getEpochKey(epoch int) []byte { + // Check cache + if key, ok := m.roomKeyCache[epoch]; ok { + return key + } + + encKey, err := getUserRoomKey(m.username, m.currentRoomID, epoch) + if err != nil || encKey == nil { + return nil + } + + key, err := DecryptKeyForUser(m.identityKey.PrivateKey, encKey) + if err != nil { + return nil + } + + m.roomKeyCache[epoch] = key + return key +} + func (m *model) saveMessage(sender, text string) { - ct, nonce, err := encryptMsg(text, m.currentRoomKey) + m.currentRoomEpoch, _ = getRoomCurrentEpoch(m.currentRoomID) + + key := m.getEpochKey(m.currentRoomEpoch) + if key == nil { + m.err = fmt.Errorf("no key for current epoch") + return + } + + ct, nonce, err := encryptMsg(text, key) if err != nil { m.err = err return } - err = saveMessage(m.currentRoomID, sender, ct, nonce) + err = saveMessage(m.currentRoomID, sender, m.currentRoomEpoch, ct, nonce) if err != nil { m.err = err return @@ -499,10 +498,19 @@ func (m *model) loadMessages() { var msgs []chatMsg for _, dbMsg := range dbMessages { - plain, err := decryptMsg(dbMsg.Ciphertext, dbMsg.Nonce, m.currentRoomKey) - if err != nil { - plain = "[Decryption Failed]" + key := m.getEpochKey(dbMsg.Epoch) + var plain string + if key == nil { + plain = "[Unreadable: Old History or Key Rotated]" + } else { + p, err := decryptMsg(dbMsg.Ciphertext, dbMsg.Nonce, key) + if err != nil { + plain = "[Decryption Failed]" + } else { + plain = p + } } + msgs = append(msgs, chatMsg{ Timestamp: dbMsg.Timestamp, Sender: dbMsg.Sender, @@ -518,12 +526,10 @@ func (m *model) updateViewport() { for _, msg := range m.messages { dateStr := msg.Timestamp.Format("Jan 2 15:04") - // Mask history before join time - // Add 1 second buffer to avoid hiding immediate first messages - if msg.Timestamp.Before(m.currentRoomJoinedAt.Add(-1 * time.Second)) { + if strings.Contains(msg.Content, "[Unreadable") { b.WriteString(fmt.Sprintf("%s %s\n", timeStyle.Render(dateStr), - lockedStyle.Render("[Encrypted message - History hidden]"), + lockedStyle.Render(msg.Content), )) } else { b.WriteString(fmt.Sprintf("%s %s: %s\n",